From b2b97cfc6d191a8fb8bb9adfa853265c2e5cbc45 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Thu, 22 Aug 2024 11:25:40 +0200 Subject: [PATCH] Fix fuseBatchNorm() using alternate graph matching --- include/aidge/recipes/Recipes.hpp | 4 --- src/recipes/FuseBatchNorm.cpp | 47 ++++--------------------------- 2 files changed, 6 insertions(+), 45 deletions(-) diff --git a/include/aidge/recipes/Recipes.hpp b/include/aidge/recipes/Recipes.hpp index e33abcaeb..205c9f966 100644 --- a/include/aidge/recipes/Recipes.hpp +++ b/include/aidge/recipes/Recipes.hpp @@ -98,10 +98,6 @@ void removeFlatten(std::shared_ptr<GraphView> graphView); */ void fuseBatchNorm(std::shared_ptr<Node> conv,std::shared_ptr<Node> batchnorm); - - -void fuseBatchNorm(std::shared_ptr<MatchSolution> solution); - /** * @brief Fuse :cpp:function:`Aidge::BatchNorm` with :cpp:function:`Aidge::Conv` or :cpp:function:`Aidge::FC` Nodes. * Ref: https://nenadmarkus.com/p/fusing-batchnorm-and-conv/ diff --git a/src/recipes/FuseBatchNorm.cpp b/src/recipes/FuseBatchNorm.cpp index aa20a056a..e1553fda5 100644 --- a/src/recipes/FuseBatchNorm.cpp +++ b/src/recipes/FuseBatchNorm.cpp @@ -16,6 +16,7 @@ #include "aidge/data/Tensor.hpp" #include "aidge/graph/GraphView.hpp" #include "aidge/graph/Node.hpp" +#include "aidge/graph/Matching.hpp" #include "aidge/operator/BatchNorm.hpp" #include "aidge/operator/Conv.hpp" #include "aidge/operator/ConvDepthWise.hpp" @@ -25,9 +26,6 @@ #include "aidge/utils/ErrorHandling.hpp" #include "aidge/utils/Types.h" -// Graph Regex -#include "aidge/graphRegex/GraphRegex.hpp" - void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, std::shared_ptr<Aidge::Node> batchnormNode) { // Case: convNode is a MetaOperator ending with a Convolution @@ -191,44 +189,11 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, } -void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::MatchSolution> solution) { - assert(solution->at("BatchNorm").size() == 1 && "Wrong number of nodes BatchNorm to replace\n"); - assert(solution->at("OP").size() == 1 && "Wrong number of nodes OP to replace\n"); - - for (const auto& op : solution->at("OP")) { - if (op->getOperator()->isAtomic()) { - for (const auto& batchNorm : solution->at("BatchNorm")) { - fuseBatchNorm(op, batchNorm); - } - } else { // op is a MetaOperator - auto metaOp = std::dynamic_pointer_cast<MetaOperator_Op>(op->getOperator()); - if ((metaOp->getMicroGraph()->getOrderedOutputs().size() == 1) && - ((metaOp->getMicroGraph()->getOrderedOutputs()[0].first->type() == - Conv_Op<2>::Type) || - (metaOp->getMicroGraph()->getOrderedOutputs()[0].first->type() == - ConvDepthWise_Op<2>::Type))) { - for (const auto& batchNorm : solution->at("BatchNorm")) { - fuseBatchNorm(op, batchNorm); - } - } - } - } -} - void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::GraphView> graphView) { - std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>(); - regex->setNodeKey("BatchNorm", "getType($) =='BatchNorm'"); - fmt::print("\n============================\nSearching for solutions\n==============================\n"); - regex->setNodeKey( - "OP", - "getType($) =='Conv' || getType($) =='ConvDepthWise' || getType($) =='PaddedConv' || getType($) =='PaddedConvDepthWise'"); - // || getType($) =='FC' "); - - regex->addQuery("OP -> BatchNorm"); - - for (const auto& solution : regex->match(graphView)) { - - fuseBatchNorm(solution); + auto matches = SinglePassGraphMatching(graphView).match("(Conv|ConvDepthWise|PaddedConv|PaddedConvDepthWise)->BatchNorm"); + for (auto match : matches) { + auto rootNode = match.graph->rootNode(); + fuseBatchNorm(rootNode, *rootNode->getChildren().begin()); } -} \ No newline at end of file +} -- GitLab