diff --git a/aidge_core/unit_tests/test_recipes.py b/aidge_core/unit_tests/test_recipes.py index 8a0a470221e118fd450be7a7bf1bf6ede2df6178..c8dd4c727fbaf8224e8d04111a5054caeb5e5c99 100644 --- a/aidge_core/unit_tests/test_recipes.py +++ b/aidge_core/unit_tests/test_recipes.py @@ -65,7 +65,7 @@ class test_recipes(unittest.TestCase): graph_view.add(b1) old_nodes = graph_view.get_nodes() - aidge_core.fuse_mul_add(graph_view) + aidge_core.matmul_to_fc(graph_view) self.assertTrue(len(graph_view.get_nodes()) == len(old_nodes) - 2) self.assertTrue("MatMul0" not in [i.name() for i in graph_view.get_nodes()]) diff --git a/include/aidge/recipes/Recipes.hpp b/include/aidge/recipes/Recipes.hpp index 205c9f966b7d7cf984dd591daf110d1304216ec0..c42b285dacb6c59c5fa30388c268f1680152a5e0 100644 --- a/include/aidge/recipes/Recipes.hpp +++ b/include/aidge/recipes/Recipes.hpp @@ -31,18 +31,14 @@ void constantFolding(std::shared_ptr<GraphView> graph); * * @param nodes Strict set of Node to merge. */ -//void fuseMulAdd(std::set<std::shared_ptr<Node>> nodes); - -void fuseMulAdd(std::shared_ptr<MatchSolution> solution); - -void fuseMulAdd(std::shared_ptr<Node> matmul,std::shared_ptr<Node> add); +void matMulToFC(std::shared_ptr<Node> matmul, std::shared_ptr<Node> add = nullptr); /** * @brief Merge ``MatMul`` and :cpp:function:`Aidge::Add` Node into a :cpp:function:`Aidge::FC` Node. * * @param graphView Graph view to use graph matching on, in order to apply transformations. */ -void fuseMulAdd(std::shared_ptr<GraphView> graphView); +void matMulToFC(std::shared_ptr<GraphView> graphView); /** * @brief Remove a node type. diff --git a/python_binding/recipes/pybind_Recipes.cpp b/python_binding/recipes/pybind_Recipes.cpp index c0392287a756b6272a59275b6d12b3a70c1c9420..1c04a320d85a833cc3c0b666390edc7a8648214b 100644 --- a/python_binding/recipes/pybind_Recipes.cpp +++ b/python_binding/recipes/pybind_Recipes.cpp @@ -25,14 +25,14 @@ void init_Recipes(py::module &m) { - m.def("fuse_mul_add", static_cast<void(*)(std::shared_ptr<GraphView>)>(fuseMulAdd), py::arg("graph_view"), R"mydelimiter( + m.def("matmul_to_fc", static_cast<void(*)(std::shared_ptr<GraphView>)>(matMulToFC), py::arg("graph_view"), R"mydelimiter( Recipe to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator. :param graph_view: Graph view on which we want to apply the recipe :type graph_view: :py:class:`aidge_core.GraphView` )mydelimiter"); - // m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter( + // m.def("matmul_to_fc", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(matMulToFC), py::arg("nodes"), R"mydelimiter( // recipe to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator. // :param nodes: The MatMul and Add nodes to fuse. @@ -84,13 +84,6 @@ void init_Recipes(py::module &m) // :type nodes: list of :py:class:`aidge_core.Node` // )mydelimiter"); - // m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter( - // Recipe to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator. - - // :param nodes: The MatMul and Add nodes to fuse. - // :type nodes: list of :py:class:`aidge_core.Node` - // )mydelimiter"); - m.def("fuse_batchnorm", static_cast<void(*)(std::shared_ptr<GraphView>)>(fuseBatchNorm), py::arg("graph_view"), R"mydelimiter( Recipe to remove a flatten operator. diff --git a/src/recipes/FuseMulAdd.cpp b/src/recipes/MatMulToFC.cpp similarity index 63% rename from src/recipes/FuseMulAdd.cpp rename to src/recipes/MatMulToFC.cpp index 6112fc47ece6bb361ebad626be7b5a6b1c2189bd..af775ced740e69e34ec8b728b8482226a98fb155 100644 --- a/src/recipes/FuseMulAdd.cpp +++ b/src/recipes/MatMulToFC.cpp @@ -22,28 +22,27 @@ #include "aidge/operator/GenericOperator.hpp" #include "aidge/utils/ErrorHandling.hpp" #include "aidge/operator/MatMul.hpp" +#include "aidge/graph/Matching.hpp" -//Graph Regex -#include "aidge/graphRegex/GraphRegex.hpp" - -void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr<Aidge::Node> addNode) { //std::set<std::shared_ptr<Node>> nodes){ +void Aidge::matMulToFC(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr<Aidge::Node> addNode) { // Fuse Mulmat & Add into FC // Inputs : old nodes (pointers on mul & add) - - assert((matmulNode->type() == "MatMul" && addNode->type() == "Add") && "Wrong type for the nodes to replace"); + AIDGE_ASSERT((matmulNode->type() == "MatMul" && (addNode == nullptr || addNode->type() == "Add")), "Wrong type for the nodes to replace"); // Step 1 : Create FC // Fetch the output dimension throught the bias size std::shared_ptr<Node> bias = nullptr; - if (addNode->getParent(0) == matmulNode) { - AIDGE_ASSERT(addNode->getParent(1), "No bias detected to produce the fuseMulAdd recipe."); - bias = addNode->getParent(1); - } - else if (addNode->getParent(1) == matmulNode) { - AIDGE_ASSERT(addNode->getParent(0), "No bias detected to produce the fuseMulAdd recipe."); - bias = addNode->getParent(0); + if (addNode) { + if (addNode->getParent(0) == matmulNode) { + AIDGE_ASSERT(addNode->getParent(1), "No bias detected to produce the matMulToFC recipe."); + bias = addNode->getParent(1); + } + else if (addNode->getParent(1) == matmulNode) { + AIDGE_ASSERT(addNode->getParent(0), "No bias detected to produce the matMulToFC recipe."); + bias = addNode->getParent(0); + } } std::shared_ptr<Node> weight = nullptr; @@ -75,24 +74,9 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr< } AIDGE_ASSERT(weight != nullptr, "Could not deduce weight input for MatMul operator."); - // TODO: find another way to get OutChannels for FC operator. - // This poor fix supposes that one of Add inputs is a const and has the same outChannels as the output - DimSize_t outSize = 0; - AIDGE_ASSERT(addNode->getOperator()->operatorType() == OperatorType::Tensor, "Operator must be of Tensor type."); - const auto& op = std::static_pointer_cast<OperatorTensor>(addNode->getOperator()); - for (size_t i = 0; i < op->nbInputs(); i++) - { - const auto& inTensor = op->getInput(i); - if(inTensor->nbDims() > 0) { - outSize = inTensor->dims()[inTensor->nbDims()-1]; - break; - } - } - AIDGE_ASSERT(outSize, "Could not get output number of channels for FC operator."); - // Instanciate FC std::string fcName = matmulNode->name(); - if (!addNode->name().empty()) { + if (addNode && !addNode->name().empty()) { fcName += "_" + addNode->name(); } @@ -105,7 +89,6 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr< bias->cloneSharedOperators()->addChild(fc, 0, 2); } - // Step 3 : Update all graphviews that contains at least one node to replace // Case 1 : If all nodes are in a graph view : delete old nodes & branch input & output // Case 2 : If not all nodes are in a graph view : only delete the nodes from the graphview @@ -115,33 +98,11 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr< } +void Aidge::matMulToFC(std::shared_ptr<Aidge::GraphView> graphView){ + const auto matches = SinglePassGraphMatching(graphView).match("MatMul->Add#?"); -void Aidge::fuseMulAdd(std::shared_ptr<Aidge::MatchSolution> solution){ - - assert(solution->at("MatMul").size() == 1 && "Wrong number of nodes MatMul to replace\n"); - assert(solution->at("Add").size() == 1 && "Wrong number of nodes Add to replace\n"); - - for (const auto& matmulNode : solution->at("MatMul")) { - for (const auto& addNode : solution->at("Add")) { - fuseMulAdd(matmulNode,addNode); - } - } -} - - -void Aidge::fuseMulAdd(std::shared_ptr<Aidge::GraphView> graphView){ - - - std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>(); - regex->setNodeKey("Add","getType($) =='Add'"); - regex->setNodeKey("MatMul","getType($) =='MatMul'"); - regex->addQuery("MatMul -> Add ;"); - - for (const auto& solution : regex->match(graphView)) { - - fuseMulAdd(solution); - - - + for (const auto& match : matches) { + const auto it = match.anchors.find("Add"); + matMulToFC(match.startNode, (it != match.anchors.end()) ? it->second.at("#") : nullptr); } } diff --git a/unit_tests/graphRegex/Test_GraphRegex.cpp b/unit_tests/graphRegex/Test_GraphRegex.cpp index fbbc3f766857f15af0da8004c35078993d71e973..e05e105d34a981e33cc1a0baaffa2702f1f6bbbb 100644 --- a/unit_tests/graphRegex/Test_GraphRegex.cpp +++ b/unit_tests/graphRegex/Test_GraphRegex.cpp @@ -189,7 +189,7 @@ TEST_CASE("GraphRegexUser") { kitchenBook->setNodeKey("Flatten","getType($) =='Flatten'"); kitchenBook->setNodeKey("FC","getType($) =='FC'"); - kitchenBook->addQuery("MatMul->Add",static_cast<void(*)(std::shared_ptr<MatchSolution>)>(fuseMulAdd)); + //kitchenBook->addQuery("MatMul->Add",static_cast<void(*)(std::shared_ptr<MatchSolution>)>(fuseMulAdd)); kitchenBook->addQuery("Flatten->FC",static_cast<void(*)(std::shared_ptr<MatchSolution>)>(removeFlatten)); kitchenBook->appliedRecipes(g); diff --git a/unit_tests/recipes/Test_FuseMulAdd.cpp b/unit_tests/recipes/Test_MatMulToFC.cpp similarity index 96% rename from unit_tests/recipes/Test_FuseMulAdd.cpp rename to unit_tests/recipes/Test_MatMulToFC.cpp index 9ea151039f07e5c688572d61b746d8fc26f1c3fe..358204da6cba4c2d047e4d4f4b8ca3f6a06ebb54 100644 --- a/unit_tests/recipes/Test_FuseMulAdd.cpp +++ b/unit_tests/recipes/Test_MatMulToFC.cpp @@ -23,7 +23,7 @@ namespace Aidge { -TEST_CASE("[cpu/recipes] FuseMulAdd", "[FuseMulAdd][recipes]") { +TEST_CASE("[cpu/recipes] MatMulToFC", "[MatMulToFC][recipes]") { // generate the original GraphView auto matmul0 = MatMul("matmul0"); auto add0 = Add(2, "add0"); @@ -60,7 +60,7 @@ TEST_CASE("[cpu/recipes] FuseMulAdd", "[FuseMulAdd][recipes]") { REQUIRE(((add1->getParent(0) == matmul1) && (add1->getParent(1) == b1))); // Transform GraphView inplace - fuseMulAdd(g); + matMulToFC(g); // Check new GraphView std::set<std::shared_ptr<Node>> newNodes = g->getNodes();