From b2b97cfc6d191a8fb8bb9adfa853265c2e5cbc45 Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Thu, 22 Aug 2024 11:25:40 +0200
Subject: [PATCH] Fix fuseBatchNorm() using alternate graph matching

---
 include/aidge/recipes/Recipes.hpp |  4 ---
 src/recipes/FuseBatchNorm.cpp     | 47 ++++---------------------------
 2 files changed, 6 insertions(+), 45 deletions(-)

diff --git a/include/aidge/recipes/Recipes.hpp b/include/aidge/recipes/Recipes.hpp
index e33abcaeb..205c9f966 100644
--- a/include/aidge/recipes/Recipes.hpp
+++ b/include/aidge/recipes/Recipes.hpp
@@ -98,10 +98,6 @@ void removeFlatten(std::shared_ptr<GraphView> graphView);
  */
 void fuseBatchNorm(std::shared_ptr<Node> conv,std::shared_ptr<Node> batchnorm);
 
-
-
-void fuseBatchNorm(std::shared_ptr<MatchSolution> solution);
-
 /**
  * @brief Fuse :cpp:function:`Aidge::BatchNorm` with :cpp:function:`Aidge::Conv` or :cpp:function:`Aidge::FC` Nodes.
  * Ref: https://nenadmarkus.com/p/fusing-batchnorm-and-conv/
diff --git a/src/recipes/FuseBatchNorm.cpp b/src/recipes/FuseBatchNorm.cpp
index aa20a056a..e1553fda5 100644
--- a/src/recipes/FuseBatchNorm.cpp
+++ b/src/recipes/FuseBatchNorm.cpp
@@ -16,6 +16,7 @@
 #include "aidge/data/Tensor.hpp"
 #include "aidge/graph/GraphView.hpp"
 #include "aidge/graph/Node.hpp"
+#include "aidge/graph/Matching.hpp"
 #include "aidge/operator/BatchNorm.hpp"
 #include "aidge/operator/Conv.hpp"
 #include "aidge/operator/ConvDepthWise.hpp"
@@ -25,9 +26,6 @@
 #include "aidge/utils/ErrorHandling.hpp"
 #include "aidge/utils/Types.h"
 
-// Graph Regex
-#include "aidge/graphRegex/GraphRegex.hpp"
-
 void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode,
                           std::shared_ptr<Aidge::Node> batchnormNode) {
     // Case: convNode is a MetaOperator ending with a Convolution
@@ -191,44 +189,11 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode,
 
 }
 
-void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::MatchSolution> solution) {
-    assert(solution->at("BatchNorm").size() == 1 && "Wrong number of nodes BatchNorm to replace\n");
-    assert(solution->at("OP").size() == 1 && "Wrong number of nodes OP to replace\n");
-
-    for (const auto& op : solution->at("OP")) {
-        if (op->getOperator()->isAtomic()) {
-            for (const auto& batchNorm : solution->at("BatchNorm")) {
-                fuseBatchNorm(op, batchNorm);
-            }
-        } else {  // op is a MetaOperator
-            auto metaOp = std::dynamic_pointer_cast<MetaOperator_Op>(op->getOperator());
-            if ((metaOp->getMicroGraph()->getOrderedOutputs().size() == 1) &&
-                ((metaOp->getMicroGraph()->getOrderedOutputs()[0].first->type() ==
-                  Conv_Op<2>::Type) ||
-                 (metaOp->getMicroGraph()->getOrderedOutputs()[0].first->type() ==
-                  ConvDepthWise_Op<2>::Type))) {
-                for (const auto& batchNorm : solution->at("BatchNorm")) {
-                    fuseBatchNorm(op, batchNorm);
-                }
-            }
-        }
-    }
-}
-
 void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::GraphView> graphView) {
-    std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>();
-    regex->setNodeKey("BatchNorm", "getType($) =='BatchNorm'");
-    fmt::print("\n============================\nSearching for solutions\n==============================\n");
-    regex->setNodeKey(
-            "OP",
-            "getType($) =='Conv' || getType($) =='ConvDepthWise' || getType($) =='PaddedConv' || getType($) =='PaddedConvDepthWise'");
-            //  || getType($) =='FC' ");
-
-    regex->addQuery("OP -> BatchNorm");
-
-    for (const auto& solution : regex->match(graphView)) {
-
-        fuseBatchNorm(solution);
+    auto matches = SinglePassGraphMatching(graphView).match("(Conv|ConvDepthWise|PaddedConv|PaddedConvDepthWise)->BatchNorm");
 
+    for (auto match : matches) {
+        auto rootNode = match.graph->rootNode();
+        fuseBatchNorm(rootNode, *rootNode->getChildren().begin());
     }
-}
\ No newline at end of file
+}
-- 
GitLab