From 6682b88517dfc30c210763fc0233aef76526868c Mon Sep 17 00:00:00 2001
From: cmoineau <cyril.moineau@cea.fr>
Date: Fri, 22 Sep 2023 11:39:37 +0000
Subject: [PATCH] [Recipies] Add method FuseBatchNorm.

---
 include/aidge/operator/BatchNorm.hpp        |   3 +-
 include/aidge/operator/Conv.hpp             |  10 +-
 python_binding/recipies/pybind_Recipies.cpp |  21 +++
 src/recipies/FuseBatchNorm.cpp              | 146 ++++++++++++++++++++
 4 files changed, 176 insertions(+), 4 deletions(-)
 create mode 100644 src/recipies/FuseBatchNorm.cpp

diff --git a/include/aidge/operator/BatchNorm.hpp b/include/aidge/operator/BatchNorm.hpp
index 6861c1359..d808e0ce6 100644
--- a/include/aidge/operator/BatchNorm.hpp
+++ b/include/aidge/operator/BatchNorm.hpp
@@ -76,7 +76,6 @@ public:
         if (!mInputs[0]->empty()) {
             for (std::size_t i = nbDataInputs(); i < nbInputs(); ++i) {
                 if(mInputs[i]->size() != mInputs[0]->dims()[1]) {
-                    assert(!mInputs[0]->hasImpl() && "Incompatible size with already implemented learnable parameter");
                     mInputs[i]->resize(std::array<DimSize_t, 1>({mInputs[0]->dims()[1]}));
                 }
             }
@@ -160,4 +159,4 @@ template <>
 const char *const EnumStrings<Aidge::BatchNormParam>::data[] = { "Epsilon", "Momentum" };
 }
 
-#endif //AIDGE_CORE_OPERATOR_BATCHNORM_H_
\ No newline at end of file
+#endif //AIDGE_CORE_OPERATOR_BATCHNORM_H_
diff --git a/include/aidge/operator/Conv.hpp b/include/aidge/operator/Conv.hpp
index 1edc94b96..96804f811 100644
--- a/include/aidge/operator/Conv.hpp
+++ b/include/aidge/operator/Conv.hpp
@@ -195,8 +195,14 @@ inline std::shared_ptr<Node> Conv(
 namespace {
 template <>
-const char *const EnumStrings<Aidge::ConvParam>::data[] = {"StrideDims", "DilationDims", "InChannels", "OutChannels",
-                                                           "KernelDims", "PaddingDims"};
+const char *const EnumStrings<Aidge::ConvParam>::data[] = {
+    "StrideDims",
+    "DilationDims",
+    "InChannels",
+    "OutChannels",
+    "KernelDims",
+    "PaddingDims"
+};
 }
 
 #endif /* AIDGE_CORE_OPERATOR_CONV_H_ */
diff --git a/python_binding/recipies/pybind_Recipies.cpp b/python_binding/recipies/pybind_Recipies.cpp
index d1c392f74..93c131ef7 100644
--- a/python_binding/recipies/pybind_Recipies.cpp
+++ b/python_binding/recipies/pybind_Recipies.cpp
@@ -20,6 +20,8 @@ namespace py = pybind11;
 
 namespace Aidge {
 void init_Recipies(py::module &m) {
+
+
   m.def("fuse_mul_add", static_cast<void(*)(std::shared_ptr<GraphView>)>(fuseMulAdd), py::arg("graph_view"), R"mydelimiter(
     Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator.
 
@@ -32,6 +34,7 @@ void init_Recipies(py::module &m) {
     :param nodes: The MatMul and Add nodes to fuse.
     :type nodes: list of :py:class:`aidge_core.Node`
     )mydelimiter");
+
   m.def("remove_flatten", static_cast<void(*)(std::shared_ptr<GraphView>)>(removeFlatten), py::arg("graph_view"), R"mydelimiter(
     Recipie to remove a flatten operator.
 
@@ -44,6 +47,24 @@ void init_Recipies(py::module &m) {
     :param nodes: The flatten operator to remove.
     :type nodes: list of :py:class:`aidge_core.Node`
     )mydelimiter");
+  m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter(
+    Recipe to fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator.
+
+    :param nodes: The MatMul and Add nodes to fuse.
+    :type nodes: list of :py:class:`aidge_core.Node`
+    )mydelimiter");
+  m.def("fuse_batchnorm", static_cast<void(*)(std::shared_ptr<GraphView>)>(fuseBatchNorm), py::arg("graph_view"), R"mydelimiter(
+    Recipe to fuse a BatchNorm operator into the preceding Conv operator.
+
+    :param graph_view: Graph view on which we want to apply the recipe
+    :type graph_view: :py:class:`aidge_core.GraphView`
+    )mydelimiter");
+  m.def("fuse_batchnorm", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseBatchNorm), py::arg("nodes"), R"mydelimiter(
+    Recipe to fuse a BatchNorm operator into the preceding Conv operator.
+
+    :param nodes: The Conv and BatchNorm nodes to fuse.
+    :type nodes: list of :py:class:`aidge_core.Node`
+    )mydelimiter");
 }
 } // namespace Aidge
 
diff --git a/src/recipies/FuseBatchNorm.cpp b/src/recipies/FuseBatchNorm.cpp
new file mode 100644
index 000000000..2250bf87c
--- /dev/null
+++ b/src/recipies/FuseBatchNorm.cpp
@@ -0,0 +1,146 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+#include <array>
+#include <cassert>
+#include <cmath>
+#include <cstdio>
+#include <map>
+#include <memory>
+#include <set>
+#include <string>
+#include <vector>
+
+#include "aidge/operator/FC.hpp"
+#include "aidge/operator/BatchNorm.hpp"
+#include "aidge/operator/Conv.hpp"
+#include "aidge/utils/Recipies.hpp"
+#include "aidge/graph/GraphView.hpp"
+#include "aidge/graph/Node.hpp"
+#include "aidge/operator/Producer.hpp"
+#include "aidge/operator/GenericOperator.hpp"
+// Graph Regex
+#include "aidge/graphmatching/GRegex.hpp"
+#include "aidge/graphmatching/NodeRegex.hpp"
+
+using namespace Aidge;
+
+void Aidge::fuseBatchNorm(std::set<std::shared_ptr<Node>> nodes){
+
+    assert(nodes.size() == 2 && "Wrong number of nodes to replace\n");
+
+    // Check that the node types are correct for the fusion
+    std::shared_ptr<Node> conv;
+    std::shared_ptr<Node> batchnorm;
+    for (const auto& element : nodes) {
+        assert((element->type() == "Conv" || element->type() == "BatchNorm") && "Wrong type for the nodes to replace");
+        if (element->type() == "Conv") {
+            conv = element;
+        }
+        else if (element->type() == "BatchNorm") {
+            batchnorm = element;
+        }
+    }
+    // TODO : check that the BatchNorm is the only child of the Conv or FC
+    std::shared_ptr<Tensor> scale = batchnorm->input(1).first->getOperator()->getOutput(batchnorm->input(1).second);
+    std::shared_ptr<Tensor> shift = batchnorm->input(2).first->getOperator()->getOutput(batchnorm->input(2).second);
+    std::shared_ptr<Tensor> b_mean = batchnorm->input(3).first->getOperator()->getOutput(batchnorm->input(3).second);
+    std::shared_ptr<Tensor> b_var = batchnorm->input(4).first->getOperator()->getOutput(batchnorm->input(4).second);
+
+    // TODO : find a way to remove the template parameter
+    const float epsilon = std::static_pointer_cast<BatchNorm_Op<2>>(batchnorm->getOperator())->get<float>("Epsilon");
+    DimSize_t convOutDims = std::static_pointer_cast<Conv_Op<2>>(conv->getOperator())->get<DimSize_t>("OutChannels");
+
+    assert(scale->size() == convOutDims);
+    assert(shift->size() == convOutDims);
+    assert(b_mean->size() == convOutDims);
+    assert(b_var->size() == convOutDims);
+    assert(epsilon > 0.0);
+    // TODO : no no_bias attribute?
+
+    float meanVariance = 0.0;
+    unsigned int count = 0;
+
+    for (std::size_t output = 0; output < convOutDims; ++output) {
+        // TODO : get() assumes the data type is float ...
+        if (b_var->get<float>(output) > 1.0e-12) {
+            meanVariance += b_var->get<float>(output);
+            ++count;
+        }
+        else {
+            printf("Zero-variance: %s [%zu]\n", conv->name().c_str(), output);
+        }
+    }
+    if (count > 0)
+        meanVariance /= count;
+    else {
+        printf("variance < 1e-12 for all outputs! Is the network correctly trained?\n");
+    }
+
+    const DimSize_t channelsSize = std::static_pointer_cast<Conv_Op<2>>(conv->getOperator())->get<DimSize_t>("InChannels");
+
+    // TODO : assumes the operator is a Conv2D ...
+    const std::array<DimSize_t, 2> kernelDims = std::static_pointer_cast<Conv_Op<2>>(conv->getOperator())->get<std::array<DimSize_t, 2>>("KernelDims");
+
+    std::shared_ptr<Tensor> weight = conv->input(1).first->getOperator()->getOutput(conv->input(1).second);
+    std::shared_ptr<Tensor> bias = conv->input(2).first->getOperator()->getOutput(conv->input(2).second);
+
+    for (std::size_t output = 0; output < convOutDims; ++output) {
+        // Corrected for the zero-variance issue:
+        // "A Quantization-Friendly Separable Convolution for MobileNets"
+        // https://arxiv.org/pdf/1803.08607.pdf
+        // to help post-training quantization
+        const float factor = scale->get<float>(output)
+            / std::sqrt(epsilon + ((b_var->get<float>(output) > 1.0e-12 || count == 0)
+                        ? b_var->get<float>(output) : meanVariance));
+        // Weights adjustments
+        for (std::size_t channel = 0; channel < channelsSize; ++channel) {
+            // TODO : assumes kernelDims has rank 2
+            for (std::size_t k0 = 0; k0 < kernelDims[0]; ++k0) {
+                for (std::size_t k1 = 0; k1 < kernelDims[1]; ++k1) {
+                    std::vector<DimSize_t> currentIdx = {output, channel, k0, k1};
+                    // TODO : assumes the weights are float
+                    float weightValue = weight->get<float>(currentIdx);
+                    weight->set<float>(currentIdx, weightValue*factor); // TODO : check that this updates the Conv weights in place
+                }
+            }
+        }
+
+        // TODO : if noBias==true, set biasValue to 0
+        float biasValue = bias->get<float>(output);
+        biasValue = shift->get<float>(output) + (biasValue - b_mean->get<float>(output)) * factor;
+        bias->set<float>(output, biasValue);
+    }
+
+    // Remove the BatchNorm node and its four parameter producers from the graph
+    auto g = std::make_shared<GraphView>();
+    g->add(std::set<std::shared_ptr<Node>>({
+        batchnorm,
+        batchnorm->input(1).first,
+        batchnorm->input(2).first,
+        batchnorm->input(3).first,
+        batchnorm->input(4).first
+    }));
+    g->replaceWith({});
+}
+
+void Aidge::fuseBatchNorm(std::shared_ptr<GraphView> graphView){
+    std::map<std::string, NodeRegex*> nodesRegex;
+    nodesRegex["BatchNorm"] = new NodeRegex("BatchNorm");
+    nodesRegex["Conv"] = new NodeRegex("Conv");
+    nodesRegex["FC"] = new NodeRegex("FC");
+
+    std::vector<std::string> seqRegex;
+    seqRegex.push_back("Conv -> BatchNorm;"); // TODO : add (Conv | FC)
+    GRegex GReg(nodesRegex, seqRegex);
+    Match matches = GReg.match(graphView);
+    std::vector<std::set<std::shared_ptr<Node>>> matchNodes = matches.getMatchNodes();
+    for (size_t i = 0; i < matches.getNbMatch(); ++i) {
+        fuseBatchNorm(matchNodes[i]);
+    }
+}
-- 
GitLab
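
Note on the transformation: fuseBatchNorm folds the BatchNorm transform y = scale * (x - mean) / sqrt(variance + epsilon) + shift into the weights and bias of the preceding convolution, so the BatchNorm node can be removed without changing the network's output. Below is a minimal standalone sketch of the per-output-channel arithmetic the patch applies; foldBatchNormChannel and its parameter names are illustrative only (not part of the Aidge API), and the zero-variance fallback used in the patch is omitted for brevity.

    #include <cmath>
    #include <cstddef>

    // Fold one BatchNorm channel into the matching Conv output channel.
    // `weights` points to the n kernel coefficients of that output channel.
    void foldBatchNormChannel(float* weights, std::size_t n, float& bias,
                              float scale, float shift, float mean,
                              float variance, float epsilon) {
        // Same factor as in the patch: scale / sqrt(epsilon + variance)
        const float factor = scale / std::sqrt(epsilon + variance);
        for (std::size_t i = 0; i < n; ++i) {
            weights[i] *= factor;              // w' = w * factor
        }
        bias = shift + (bias - mean) * factor; // b' = shift + (b - mean) * factor
    }

Because the convolution is linear, rescaling its weights by `factor` and shifting its bias this way is exactly equivalent to running the original Conv followed by the BatchNorm, which is why the patch only rewrites the Producer tensors and then drops the BatchNorm node and its four parameter inputs.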