diff --git a/aidge_core/unit_tests/test_recipes.py b/aidge_core/unit_tests/test_recipes.py index 8a0a470221e118fd450be7a7bf1bf6ede2df6178..c8dd4c727fbaf8224e8d04111a5054caeb5e5c99 100644 --- a/aidge_core/unit_tests/test_recipes.py +++ b/aidge_core/unit_tests/test_recipes.py @@ -65,7 +65,7 @@ class test_recipes(unittest.TestCase): graph_view.add(b1) old_nodes = graph_view.get_nodes() - aidge_core.fuse_mul_add(graph_view) + aidge_core.matmul_to_fc(graph_view) self.assertTrue(len(graph_view.get_nodes()) == len(old_nodes) - 2) self.assertTrue("MatMul0" not in [i.name() for i in graph_view.get_nodes()]) diff --git a/include/aidge/recipes/Recipes.hpp b/include/aidge/recipes/Recipes.hpp index 205c9f966b7d7cf984dd591daf110d1304216ec0..c42b285dacb6c59c5fa30388c268f1680152a5e0 100644 --- a/include/aidge/recipes/Recipes.hpp +++ b/include/aidge/recipes/Recipes.hpp @@ -31,18 +31,14 @@ void constantFolding(std::shared_ptr<GraphView> graph); * * @param nodes Strict set of Node to merge. */ -//void fuseMulAdd(std::set<std::shared_ptr<Node>> nodes); - -void fuseMulAdd(std::shared_ptr<MatchSolution> solution); - -void fuseMulAdd(std::shared_ptr<Node> matmul,std::shared_ptr<Node> add); +void matMulToFC(std::shared_ptr<Node> matmul, std::shared_ptr<Node> add = nullptr); /** * @brief Merge ``MatMul`` and :cpp:function:`Aidge::Add` Node into a :cpp:function:`Aidge::FC` Node. * * @param graphView Graph view to use graph matching on, in order to apply transformations. */ -void fuseMulAdd(std::shared_ptr<GraphView> graphView); +void matMulToFC(std::shared_ptr<GraphView> graphView); /** * @brief Remove a node type. diff --git a/python_binding/recipes/pybind_Recipes.cpp b/python_binding/recipes/pybind_Recipes.cpp index c0392287a756b6272a59275b6d12b3a70c1c9420..1c04a320d85a833cc3c0b666390edc7a8648214b 100644 --- a/python_binding/recipes/pybind_Recipes.cpp +++ b/python_binding/recipes/pybind_Recipes.cpp @@ -25,14 +25,14 @@ void init_Recipes(py::module &m) { - m.def("fuse_mul_add", static_cast<void(*)(std::shared_ptr<GraphView>)>(fuseMulAdd), py::arg("graph_view"), R"mydelimiter( + m.def("matmul_to_fc", static_cast<void(*)(std::shared_ptr<GraphView>)>(matMulToFC), py::arg("graph_view"), R"mydelimiter( Recipe to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator. :param graph_view: Graph view on which we want to apply the recipe :type graph_view: :py:class:`aidge_core.GraphView` )mydelimiter"); - // m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter( + // m.def("matmul_to_fc", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(matMulToFC), py::arg("nodes"), R"mydelimiter( // recipe to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator. // :param nodes: The MatMul and Add nodes to fuse. @@ -84,13 +84,6 @@ void init_Recipes(py::module &m) // :type nodes: list of :py:class:`aidge_core.Node` // )mydelimiter"); - // m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter( - // Recipe to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator. - - // :param nodes: The MatMul and Add nodes to fuse. - // :type nodes: list of :py:class:`aidge_core.Node` - // )mydelimiter"); - m.def("fuse_batchnorm", static_cast<void(*)(std::shared_ptr<GraphView>)>(fuseBatchNorm), py::arg("graph_view"), R"mydelimiter( Recipe to remove a flatten operator. diff --git a/src/recipes/FuseMulAdd.cpp b/src/recipes/MatMulToFC.cpp similarity index 60% rename from src/recipes/FuseMulAdd.cpp rename to src/recipes/MatMulToFC.cpp index 6112fc47ece6bb361ebad626be7b5a6b1c2189bd..9b5addd3bb971b3f61980a582d4cce6435c57219 100644 --- a/src/recipes/FuseMulAdd.cpp +++ b/src/recipes/MatMulToFC.cpp @@ -22,28 +22,29 @@ #include "aidge/operator/GenericOperator.hpp" #include "aidge/utils/ErrorHandling.hpp" #include "aidge/operator/MatMul.hpp" +#include "aidge/graph/Matching.hpp" -//Graph Regex -#include "aidge/graphRegex/GraphRegex.hpp" - -void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr<Aidge::Node> addNode) { //std::set<std::shared_ptr<Node>> nodes){ +void Aidge::matMulToFC(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr<Aidge::Node> addNode) { // Fuse Mulmat & Add into FC // Inputs : old nodes (pointers on mul & add) - - assert((matmulNode->type() == "MatMul" && addNode->type() == "Add") && "Wrong type for the nodes to replace"); + AIDGE_ASSERT((matmulNode->type() == "MatMul" && (addNode == nullptr || addNode->type() == "Add")), + "Wrong type for the nodes to replace: {} and {}", + matmulNode->type(), (addNode) ? addNode->type() : "nullptr"); // Step 1 : Create FC // Fetch the output dimension throught the bias size std::shared_ptr<Node> bias = nullptr; - if (addNode->getParent(0) == matmulNode) { - AIDGE_ASSERT(addNode->getParent(1), "No bias detected to produce the fuseMulAdd recipe."); - bias = addNode->getParent(1); - } - else if (addNode->getParent(1) == matmulNode) { - AIDGE_ASSERT(addNode->getParent(0), "No bias detected to produce the fuseMulAdd recipe."); - bias = addNode->getParent(0); + if (addNode) { + if (addNode->getParent(0) == matmulNode) { + AIDGE_ASSERT(addNode->getParent(1), "No bias detected to produce the matMulToFC recipe."); + bias = addNode->getParent(1); + } + else if (addNode->getParent(1) == matmulNode) { + AIDGE_ASSERT(addNode->getParent(0), "No bias detected to produce the matMulToFC recipe."); + bias = addNode->getParent(0); + } } std::shared_ptr<Node> weight = nullptr; @@ -75,24 +76,9 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr< } AIDGE_ASSERT(weight != nullptr, "Could not deduce weight input for MatMul operator."); - // TODO: find another way to get OutChannels for FC operator. - // This poor fix supposes that one of Add inputs is a const and has the same outChannels as the output - DimSize_t outSize = 0; - AIDGE_ASSERT(addNode->getOperator()->operatorType() == OperatorType::Tensor, "Operator must be of Tensor type."); - const auto& op = std::static_pointer_cast<OperatorTensor>(addNode->getOperator()); - for (size_t i = 0; i < op->nbInputs(); i++) - { - const auto& inTensor = op->getInput(i); - if(inTensor->nbDims() > 0) { - outSize = inTensor->dims()[inTensor->nbDims()-1]; - break; - } - } - AIDGE_ASSERT(outSize, "Could not get output number of channels for FC operator."); - // Instanciate FC std::string fcName = matmulNode->name(); - if (!addNode->name().empty()) { + if (addNode && !addNode->name().empty()) { fcName += "_" + addNode->name(); } @@ -105,43 +91,26 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr< bias->cloneSharedOperators()->addChild(fc, 0, 2); } - // Step 3 : Update all graphviews that contains at least one node to replace // Case 1 : If all nodes are in a graph view : delete old nodes & branch input & output // Case 2 : If not all nodes are in a graph view : only delete the nodes from the graphview // Maybe create a central mechanism to update automatically all graph views rather than each node have graphview presence memory? - auto newNodes = std::set<std::shared_ptr<Node>>({fc, fc->getParent(1), fc->getParent(2)}); - GraphView::replace({matmulNode, addNode, bias, weight}, newNodes); - -} - - -void Aidge::fuseMulAdd(std::shared_ptr<Aidge::MatchSolution> solution){ - - assert(solution->at("MatMul").size() == 1 && "Wrong number of nodes MatMul to replace\n"); - assert(solution->at("Add").size() == 1 && "Wrong number of nodes Add to replace\n"); - - for (const auto& matmulNode : solution->at("MatMul")) { - for (const auto& addNode : solution->at("Add")) { - fuseMulAdd(matmulNode,addNode); - } + if (addNode) { + auto newNodes = std::set<std::shared_ptr<Node>>({fc, fc->getParent(1), fc->getParent(2)}); + GraphView::replace({matmulNode, addNode, bias, weight}, newNodes); + } + else { + auto newNodes = std::set<std::shared_ptr<Node>>({fc, fc->getParent(1)}); + GraphView::replace({matmulNode, weight}, newNodes); } -} - - -void Aidge::fuseMulAdd(std::shared_ptr<Aidge::GraphView> graphView){ - - - std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>(); - regex->setNodeKey("Add","getType($) =='Add'"); - regex->setNodeKey("MatMul","getType($) =='MatMul'"); - regex->addQuery("MatMul -> Add ;"); - - for (const auto& solution : regex->match(graphView)) { - - fuseMulAdd(solution); +} +void Aidge::matMulToFC(std::shared_ptr<Aidge::GraphView> graphView){ + const auto matches = SinglePassGraphMatching(graphView).match("MatMul->Add#?"); + for (const auto& match : matches) { + const auto it = match.anchors.find("Add"); + matMulToFC(match.graph->rootNode(), (it != match.anchors.end()) ? it->second.at("#") : nullptr); } } diff --git a/unit_tests/graphRegex/Test_GraphRegex.cpp b/unit_tests/graphRegex/Test_GraphRegex.cpp index fbbc3f766857f15af0da8004c35078993d71e973..e05e105d34a981e33cc1a0baaffa2702f1f6bbbb 100644 --- a/unit_tests/graphRegex/Test_GraphRegex.cpp +++ b/unit_tests/graphRegex/Test_GraphRegex.cpp @@ -189,7 +189,7 @@ TEST_CASE("GraphRegexUser") { kitchenBook->setNodeKey("Flatten","getType($) =='Flatten'"); kitchenBook->setNodeKey("FC","getType($) =='FC'"); - kitchenBook->addQuery("MatMul->Add",static_cast<void(*)(std::shared_ptr<MatchSolution>)>(fuseMulAdd)); + //kitchenBook->addQuery("MatMul->Add",static_cast<void(*)(std::shared_ptr<MatchSolution>)>(fuseMulAdd)); kitchenBook->addQuery("Flatten->FC",static_cast<void(*)(std::shared_ptr<MatchSolution>)>(removeFlatten)); kitchenBook->appliedRecipes(g); diff --git a/unit_tests/recipes/Test_FuseMulAdd.cpp b/unit_tests/recipes/Test_FuseMulAdd.cpp deleted file mode 100644 index 9ea151039f07e5c688572d61b746d8fc26f1c3fe..0000000000000000000000000000000000000000 --- a/unit_tests/recipes/Test_FuseMulAdd.cpp +++ /dev/null @@ -1,74 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#include <catch2/catch_test_macros.hpp> -#include <set> - -#include "aidge/data/Tensor.hpp" -#include "aidge/graph/GraphView.hpp" -#include "aidge/operator/Add.hpp" -#include "aidge/operator/FC.hpp" -#include "aidge/operator/MatMul.hpp" -#include "aidge/operator/Producer.hpp" -#include "aidge/recipes/Recipes.hpp" - -namespace Aidge { - - -TEST_CASE("[cpu/recipes] FuseMulAdd", "[FuseMulAdd][recipes]") { - // generate the original GraphView - auto matmul0 = MatMul("matmul0"); - auto add0 = Add(2, "add0"); - auto matmul1 = MatMul("matmul1"); - auto add1 = Add(2, "add1"); - - auto b0 = Producer({5}, "B0"); - auto w0 = Producer({5, 5}, "W0"); - auto b1 = Producer({5}, "B1"); - auto w1 = Producer({5,5},"W1"); - auto input = Producer({2,5}, "input"); - - input->addChild(matmul0, 0, 0); - w0->addChild(matmul0, 0, 1); - - matmul0->addChild(add0, 0, 0); - b0->addChild(add0, 0, 1); - - add0->addChild(matmul1, 0, 1); - w1->addChild(matmul1, 0, 0); - - matmul1->addChild(add1, 0, 0); - b1->addChild(add1, 0, 1); - - auto g = std::make_shared<GraphView>(); - g->add({w0, matmul0, b0, add0, w1, matmul1, b1, add1}); - - // Check original graph - REQUIRE(g->getNodes() == - std::set<std::shared_ptr<Node>>({w0, matmul0, b0, add0, w1, matmul1, b1, add1})); - REQUIRE(((matmul0->getParent(0) == input) && (matmul0->getParent(1) == w0))); - REQUIRE(((add0->getParent(0) == matmul0) && (add0->getParent(1) == b0))); - REQUIRE(((matmul1->getParent(1) == add0) && (matmul1->getParent(0) == w1))); - REQUIRE(((add1->getParent(0) == matmul1) && (add1->getParent(1) == b1))); - - // Transform GraphView inplace - fuseMulAdd(g); - - // Check new GraphView - std::set<std::shared_ptr<Node>> newNodes = g->getNodes(); - REQUIRE(newNodes != std::set<std::shared_ptr<Node>>({w0, matmul0, b0, add0, w1, matmul1, b1, add1})); - REQUIRE(newNodes.size() == 6); - for (const auto& node : newNodes) { - REQUIRE(((node->type() == "Producer") || (node->type() == "FC"))); - } -} - -} // namespace Aidge diff --git a/unit_tests/recipes/Test_MatMulToFC.cpp b/unit_tests/recipes/Test_MatMulToFC.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2adf882ca69e0d5ca5f050d1b89cfb09d81b536b --- /dev/null +++ b/unit_tests/recipes/Test_MatMulToFC.cpp @@ -0,0 +1,118 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <catch2/catch_test_macros.hpp> +#include <set> + +#include "aidge/data/Tensor.hpp" +#include "aidge/graph/GraphView.hpp" +#include "aidge/operator/Add.hpp" +#include "aidge/operator/FC.hpp" +#include "aidge/operator/MatMul.hpp" +#include "aidge/operator/Producer.hpp" +#include "aidge/recipes/Recipes.hpp" + +namespace Aidge { + + +TEST_CASE("[cpu/recipes] MatMulToFC", "[MatMulToFC][recipes]") { + SECTION("with Add") { + // generate the original GraphView + auto matmul0 = MatMul("matmul0"); + auto add0 = Add(2, "add0"); + auto matmul1 = MatMul("matmul1"); + auto add1 = Add(2, "add1"); + + auto b0 = Producer({5}, "B0"); + auto w0 = Producer({5, 5}, "W0"); + auto b1 = Producer({5}, "B1"); + auto w1 = Producer({5,5},"W1"); + auto input = Producer({2,5}, "input"); + + input->addChild(matmul0, 0, 0); + w0->addChild(matmul0, 0, 1); + + matmul0->addChild(add0, 0, 0); + b0->addChild(add0, 0, 1); + + add0->addChild(matmul1, 0, 1); + w1->addChild(matmul1, 0, 0); + + matmul1->addChild(add1, 0, 0); + b1->addChild(add1, 0, 1); + + auto g = std::make_shared<GraphView>(); + g->add({w0, matmul0, b0, add0, w1, matmul1, b1, add1}); + + // Check original graph + REQUIRE(g->getNodes() == + std::set<std::shared_ptr<Node>>({w0, matmul0, b0, add0, w1, matmul1, b1, add1})); + REQUIRE(((matmul0->getParent(0) == input) && (matmul0->getParent(1) == w0))); + REQUIRE(((add0->getParent(0) == matmul0) && (add0->getParent(1) == b0))); + REQUIRE(((matmul1->getParent(1) == add0) && (matmul1->getParent(0) == w1))); + REQUIRE(((add1->getParent(0) == matmul1) && (add1->getParent(1) == b1))); + + // Transform GraphView inplace + matMulToFC(g); + + // Check new GraphView + std::set<std::shared_ptr<Node>> newNodes = g->getNodes(); + REQUIRE(newNodes != std::set<std::shared_ptr<Node>>({w0, matmul0, b0, add0, w1, matmul1, b1, add1})); + REQUIRE(newNodes.size() == 6); + for (const auto& node : newNodes) { + REQUIRE(((node->type() == "Producer") || (node->type() == "FC"))); + } + } + + SECTION("without Add") { + // generate the original GraphView + auto matmul0 = MatMul("matmul0"); + auto matmul1 = MatMul("matmul1"); + auto add1 = Add(2, "add1"); + + auto w0 = Producer({5, 5}, "W0"); + auto b1 = Producer({5}, "B1"); + auto w1 = Producer({5,5},"W1"); + auto input = Producer({2,5}, "input"); + + input->addChild(matmul0, 0, 0); + w0->addChild(matmul0, 0, 1); + + matmul0->addChild(matmul1, 0, 1); + w1->addChild(matmul1, 0, 0); + + matmul1->addChild(add1, 0, 0); + b1->addChild(add1, 0, 1); + + auto g = std::make_shared<GraphView>(); + g->add({w0, matmul0, w1, matmul1, b1, add1}); + + // Check original graph + REQUIRE(g->getNodes() == + std::set<std::shared_ptr<Node>>({w0, matmul0, w1, matmul1, b1, add1})); + REQUIRE(((matmul0->getParent(0) == input) && (matmul0->getParent(1) == w0))); + REQUIRE(((matmul1->getParent(1) == matmul0) && (matmul1->getParent(0) == w1))); + REQUIRE(((add1->getParent(0) == matmul1) && (add1->getParent(1) == b1))); + + // Transform GraphView inplace + matMulToFC(g); + + // Check new GraphView + std::set<std::shared_ptr<Node>> newNodes = g->getNodes(); + REQUIRE(newNodes != std::set<std::shared_ptr<Node>>({w0, matmul0, w1, matmul1, b1, add1})); + REQUIRE(newNodes.size() == 5); + for (const auto& node : newNodes) { + REQUIRE(((node->type() == "Producer") || (node->type() == "FC"))); + } + } +} + +} // namespace Aidge