diff --git a/src/recipes/ConvToMatMul.cpp b/src/recipes/ConvToMatMul.cpp index 9b88ffc73204b44cf857213d1fdfff49b3191f73..31462861e1bbe29cb467ad719576ec86c2d46f7f 100644 --- a/src/recipes/ConvToMatMul.cpp +++ b/src/recipes/ConvToMatMul.cpp @@ -24,7 +24,7 @@ #include "aidge/recipes/Recipes.hpp" size_t Aidge::convToMatMul(std::shared_ptr<GraphView> graphView) { - const auto matches = SinglePassGraphMatching(graphView).match("Conv"); + const auto matches = SinglePassGraphMatching(graphView).match("Conv2D"); size_t nbReplaced = 0; for (const auto& match : matches) { diff --git a/src/recipes/FuseBatchNorm.cpp b/src/recipes/FuseBatchNorm.cpp index 34722c19f8c0fddaffa7357136f1512a027e1617..4c4de25282c487d023f9c184b015ac332e716b7b 100644 --- a/src/recipes/FuseBatchNorm.cpp +++ b/src/recipes/FuseBatchNorm.cpp @@ -190,9 +190,10 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, } void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::GraphView> graphView) { - auto matches = SinglePassGraphMatching(graphView).match("(Conv|ConvDepthWise|PaddedConv|PaddedConvDepthWise)->BatchNorm"); + auto matches = SinglePassGraphMatching(graphView).match("(Conv2D|ConvDepthWise2D|PaddedConv2D|PaddedConvDepthWise2D)->BatchNorm2D"); for (auto match : matches) { + fmt::println("Match !"); auto rootNode = match.graph->rootNode(); fuseBatchNorm(rootNode, *rootNode->getChildren().begin()); }