From 6682b88517dfc30c210763fc0233aef76526868c Mon Sep 17 00:00:00 2001
From: cmoineau <cyril.moineau@cea.fr>
Date: Fri, 22 Sep 2023 11:39:37 +0000
Subject: [PATCH] [Recipies] Add method FuseBatchNorm.

---
 include/aidge/operator/BatchNorm.hpp        |   3 +-
 include/aidge/operator/Conv.hpp             |  10 +-
 python_binding/recipies/pybind_Recipies.cpp |  21 +++
 src/recipies/FuseBatchNorm.cpp              | 146 ++++++++++++++++++++
 4 files changed, 176 insertions(+), 4 deletions(-)
 create mode 100644 src/recipies/FuseBatchNorm.cpp

diff --git a/include/aidge/operator/BatchNorm.hpp b/include/aidge/operator/BatchNorm.hpp
index 6861c1359..d808e0ce6 100644
--- a/include/aidge/operator/BatchNorm.hpp
+++ b/include/aidge/operator/BatchNorm.hpp
@@ -76,7 +76,6 @@ public:
         if (!mInputs[0]->empty()) {
             for (std::size_t i = nbDataInputs(); i < nbInputs(); ++i) {
                 if(mInputs[i]->size() != mInputs[0]->dims()[1]) {
-                    assert(!mInputs[0]->hasImpl() && "Incompatible size with already implemented learnable parameter");
                     mInputs[i]->resize(std::array<DimSize_t, 1>({mInputs[0]->dims()[1]}));
                 }
             }
@@ -160,4 +159,4 @@ template <>
 const char *const EnumStrings<Aidge::BatchNormParam>::data[] = { "Epsilon", "Momentum" };
 }
 
-#endif //AIDGE_CORE_OPERATOR_BATCHNORM_H_
\ No newline at end of file
+#endif //AIDGE_CORE_OPERATOR_BATCHNORM_H_
diff --git a/include/aidge/operator/Conv.hpp b/include/aidge/operator/Conv.hpp
index 1edc94b96..96804f811 100644
--- a/include/aidge/operator/Conv.hpp
+++ b/include/aidge/operator/Conv.hpp
@@ -195,8 +195,14 @@ inline std::shared_ptr<Node> Conv(
 
 namespace {
 template <>
-const char *const EnumStrings<Aidge::ConvParam>::data[] = {"StrideDims", "DilationDims", "InChannels", "OutChannels",
-                                                          "KernelDims", "PaddingDims"};
+const char *const EnumStrings<Aidge::ConvParam>::data[] = {
+    "StrideDims",
+    "DilationDims",
+    "InChannels",
+    "OutChannels",
+    "KernelDims",
+    "PaddingDims"
+};
 }
 
 #endif /* AIDGE_CORE_OPERATOR_CONV_H_ */
diff --git a/python_binding/recipies/pybind_Recipies.cpp b/python_binding/recipies/pybind_Recipies.cpp
index d1c392f74..93c131ef7 100644
--- a/python_binding/recipies/pybind_Recipies.cpp
+++ b/python_binding/recipies/pybind_Recipies.cpp
@@ -20,6 +20,8 @@ namespace py = pybind11;
 
 namespace Aidge {
 void init_Recipies(py::module &m) {
+
   m.def("fuse_mul_add", static_cast<void(*)(std::shared_ptr<GraphView>)>(fuseMulAdd), py::arg("graph_view"), R"mydelimiter(
     Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator.
 
@@ -32,6 +34,7 @@ void init_Recipies(py::module &m) {
     :param nodes: The MatMul and Add nodes to fuse.
     :type nodes: list of :py:class:`aidge_core.Node`
     )mydelimiter");
+
   m.def("remove_flatten", static_cast<void(*)(std::shared_ptr<GraphView>)>(removeFlatten), py::arg("graph_view"), R"mydelimiter(
     Recipie to remove a flatten operator.
 
@@ -44,6 +47,24 @@ void init_Recipies(py::module &m) {
     :param nodes: The flatten operator to remove.
     :type nodes: list of :py:class:`aidge_core.Node`
     )mydelimiter");
+  m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter(
+    Recipe to fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator.
+
+    :param nodes: The MatMul and Add nodes to fuse.
+    :type nodes: list of :py:class:`aidge_core.Node`
+    )mydelimiter");
 
+  m.def("fuse_batchnorm", static_cast<void(*)(std::shared_ptr<GraphView>)>(fuseBatchNorm), py::arg("graph_view"), R"mydelimiter(
+    Recipe to fuse a BatchNorm operator into the preceding Conv operator.
+
+    :param graph_view: Graph view on which we want to apply the recipe
+    :type graph_view: :py:class:`aidge_core.GraphView`
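+
+    Example (a minimal sketch, assuming ``graph_view`` contains a ``Conv -> BatchNorm`` sequence):
+
+    >>> import aidge_core
+    >>> aidge_core.fuse_batchnorm(graph_view)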
+    )mydelimiter");
+  m.def("fuse_batchnorm", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseBatchNorm), py::arg("nodes"), R"mydelimiter(
+    Recipe to fuse a BatchNorm operator into the preceding Conv operator.
+
+    :param nodes: The Conv and BatchNorm nodes to fuse.
+    :type nodes: list of :py:class:`aidge_core.Node`
+    )mydelimiter");
 }
 } // namespace Aidge
diff --git a/src/recipies/FuseBatchNorm.cpp b/src/recipies/FuseBatchNorm.cpp
new file mode 100644
index 000000000..2250bf87c
--- /dev/null
+++ b/src/recipies/FuseBatchNorm.cpp
@@ -0,0 +1,146 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+#include <set>
+#include <cassert>
+#include <memory>
+#include <string>
+#include "aidge/operator/FC.hpp"
+#include "aidge/operator/BatchNorm.hpp"
+#include "aidge/operator/Conv.hpp"
+
+#include "aidge/utils/Recipies.hpp"
+#include "aidge/graph/GraphView.hpp"
+#include "aidge/graph/Node.hpp"
+#include "aidge/operator/Producer.hpp"
+#include "aidge/operator/GenericOperator.hpp"
+// Graph Regex
+#include "aidge/graphmatching/GRegex.hpp"
+#include "aidge/graphmatching/NodeRegex.hpp"
+using namespace Aidge;
+
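+// Fuse a BatchNorm node into the preceding Conv node: the normalization
+// parameters are folded into the convolution weights and bias, then the
+// BatchNorm node and its parameter producers are removed from the graph.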
+void Aidge::fuseBatchNorm(std::set<std::shared_ptr<Node>> nodes){
+
+    assert(nodes.size() == 2 && "Wrong number of nodes to replace\n");
+
+    // Assert the nodes types are correct to be fused
+    std::shared_ptr<Node> conv;
+    std::shared_ptr<Node> batchnorm;
+    for (const auto& element : nodes) {
+        assert((element->type() == "Conv" || element->type() == "BatchNorm") && "Wrong type for the nodes to replace");
+        if (element->type() == "Conv"){
+            conv = element;
+        }
+        else if (element->type() == "BatchNorm") {
+            batchnorm = element;
+        }
+    }
+    // TODO: check that the BatchNorm is the only child of the Conv or FC
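+    // BatchNorm parameter inputs: 1 = scale, 2 = shift, 3 = running mean, 4 = running variance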
+    std::shared_ptr<Tensor> scale  = batchnorm->input(1).first->getOperator()->getOutput(batchnorm->input(1).second);
+    std::shared_ptr<Tensor> shift  = batchnorm->input(2).first->getOperator()->getOutput(batchnorm->input(2).second);
+    std::shared_ptr<Tensor> b_mean = batchnorm->input(3).first->getOperator()->getOutput(batchnorm->input(3).second);
+    std::shared_ptr<Tensor> b_var  = batchnorm->input(4).first->getOperator()->getOutput(batchnorm->input(4).second);
+
+
+    // TODO : Find a way to remove the template
+    const float epsilon = std::static_pointer_cast<BatchNorm_Op<2>>(batchnorm->getOperator())->get<float>("Epsilon");
+    DimSize_t convOutDims = std::static_pointer_cast<Conv_Op<2>>(conv->getOperator())->get<DimSize_t>("OutChannels");
+
+
+    assert(scale->size()  == convOutDims);
+    assert(shift->size()  == convOutDims);
+    assert(b_mean->size() == convOutDims);
+    assert(b_var->size()  == convOutDims);
+    assert(epsilon > 0.0);
+    // TODO: no no_bias attribute available?
+    float meanVariance = 0.0;
+    unsigned int count = 0;
+
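+    // Average the variances of the channels whose variance is non-negligible;
+    // this mean serves as a fallback for the near-zero-variance channels below.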
+    for (std::size_t output = 0; output < convOutDims; ++output) {
+        // TODO: get() assumes the data type is float
+        if (b_var->get<float>(output) > 1.0e-12) {
+            meanVariance += b_var->get<float>(output);
+            ++count;
+        }
+        else {
+            printf("Zero-variance: %s [%lu]\n", conv->name().c_str(), output);
+        }
+    }
+    if (count > 0)
+        meanVariance /= count;
+    else {
+        printf("variance < 1e-12 for all outputs! Is the network correctly trained?\n");
+    }
+
+    const DimSize_t channelsSize = std::static_pointer_cast<Conv_Op<2>>(conv->getOperator())->get<DimSize_t>("InChannels");
+
+    // TODO: assumes a 2-D convolution (Conv_Op<2>)
+    const std::array<DimSize_t, 2> kernelDims = std::static_pointer_cast<Conv_Op<2>>(conv->getOperator())->get<std::array<DimSize_t, 2>>("KernelDims");
+
+    std::shared_ptr<Tensor> weight  = conv->input(1).first->getOperator()->getOutput(conv->input(1).second);
+    std::shared_ptr<Tensor> bias  = conv->input(2).first->getOperator()->getOutput(conv->input(2).second);
+
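+    // BatchNorm computes y = scale * (x - mean) / sqrt(var + epsilon) + shift.
+    // Folding it into the convolution gives, per output channel:
+    //   factor  = scale / sqrt(var + epsilon)
+    //   weight' = weight * factor
+    //   bias'   = shift + (bias - mean) * factor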
+    for (std::size_t output = 0; output < convOutDims; ++output) {
+        // Corrected for zero-variance issue:
+        // "A Quantization-Friendly Separable Convolution for MobileNets"
+        // https://arxiv.org/pdf/1803.08607.pdf
+        // to help post-training quantization
+        const float factor = scale->get<float>(output)
+            / std::sqrt(epsilon + ((b_var->get<float>(output) > 1.0e-12 || count == 0)
+                        ? b_var->get<float>(output) : meanVariance));
+        // Weights adjustments
+        for (std::size_t channel = 0; channel < channelsSize; ++channel) {
+            // TODO: assumes a 2-D kernel
+            for(std::size_t k0 = 0; k0 < kernelDims[0]; ++ k0){
+                for(std::size_t k1 = 0; k1 < kernelDims[1]; ++ k1){
+                    std::vector<DimSize_t> currentIdx = {output, channel, k0, k1};
+                    // TODO: assumes the weights are float
+                    float weightValue = weight->get<float>(currentIdx);
+                    weight->set<float>(currentIdx, weightValue * factor); // TODO: check that this actually updates the Conv weights
+                }
+            }
+        }
+
+        // TODO: if noBias==true is set, set biasValue to 0
+        float biasValue = bias->get<float>(output);
+
+        biasValue = shift->get<float>(output) + (biasValue - b_mean->get<float>(output)) * factor;
+
+        bias->set<float>(output, biasValue);
+
+    }
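+    // Gather the BatchNorm node and its four parameter producers, then
+    // replace them with nothing: the fused Conv now feeds its children directly.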
+    auto g = std::make_shared<GraphView>();
+    g->add(std::set<std::shared_ptr<Node>>({
+        batchnorm,
+        batchnorm->input(1).first,
+        batchnorm->input(2).first,
+        batchnorm->input(3).first,
+        batchnorm->input(4).first
+    }));
+    g->replaceWith({});
+
+}
+
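+// Graph-level overload: match every Conv -> BatchNorm sequence with the
+// graph regex engine and fuse each match individually.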
+void Aidge::fuseBatchNorm(std::shared_ptr<GraphView> graphView){
+    std::map<std::string, NodeRegex*> nodesRegex;
+    nodesRegex["BatchNorm"] = new NodeRegex("BatchNorm");
+    nodesRegex["Conv"] = new NodeRegex("Conv");
+    nodesRegex["FC"] = new NodeRegex("FC");
+
+    std::vector<std::string> seqRegex;
+    seqRegex.push_back("Conv -> BatchNorm;"); // TODO: Add (Conv | FC)
+    GRegex GReg(nodesRegex, seqRegex);
+    Match matches = GReg.match(graphView);
+    std::vector<std::set<std::shared_ptr<Node>>> matchNodes = matches.getMatchNodes();
+    for (size_t i = 0; i < matches.getNbMatch(); ++i) {
+        fuseBatchNorm(matchNodes[i]);
+    }
+}
-- 
GitLab