diff --git a/include/aidge/utils/Recipies.hpp b/include/aidge/utils/Recipies.hpp index c110c9cf8e2ccc84112f7ac48b438f470ee21465..4b21a4f59366d6dca1b75dc353dc00c9a78960f3 100644 --- a/include/aidge/utils/Recipies.hpp +++ b/include/aidge/utils/Recipies.hpp @@ -17,6 +17,8 @@ #include "aidge/graph/Node.hpp" #include "aidge/graph/GraphView.hpp" +#include "aidge/graphRegex/matchFsm/MatchResult.hpp" + namespace Aidge{ @@ -27,7 +29,12 @@ namespace Aidge{ * * @param nodes Strict set of Node to merge. */ -void fuseMulAdd(std::set<std::shared_ptr<Node>> nodes); +//void fuseMulAdd(std::set<std::shared_ptr<Node>> nodes); + +void fuseMulAdd(std::shared_ptr<MatchSolution> solution); + +void fuseMulAdd(std::shared_ptr<Node> matmul,std::shared_ptr<Node> add); + /** * @brief Merge ``MatMul`` and :cpp:function:`Aidge::Add` Node into a :cpp:function:`Aidge::FC` Node. * diff --git a/src/recipies/FuseMulAdd.cpp b/src/recipies/FuseMulAdd.cpp index 528d57e31a5ecf3f5a633a20205e79f7926a1f61..8061659670106d1745e967ab443e6b9a60e2cb6d 100644 --- a/src/recipies/FuseMulAdd.cpp +++ b/src/recipies/FuseMulAdd.cpp @@ -25,27 +25,17 @@ // Graph Regex #include "aidge/graphmatching/GRegex.hpp" #include "aidge/graphmatching/NodeRegex.hpp" +//Graph Regex +#include "aidge/graphRegex/GraphRegex.hpp" + using namespace Aidge; -void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){ +void Aidge::fuseMulAdd(std::shared_ptr<Node> matmul,std::shared_ptr<Node> add){//std::set<std::shared_ptr<Node>> nodes){ // Fuse Mulmat & Add into FC // Inputs : old nodes (pointers on mul & add) - assert(nodes.size() == 2 && "Wrong number of nodes to replace\n"); - // Too bad we lose information on the type after matching, how to keep the information after matching (not only for the type) ? + assert((matmul->type() == "MatMul" && add->type() == "Add") && "Wrong type for the nodes to replace"); - // Step 0 : Assert the nodes types are correct to be fused - std::shared_ptr<Node> add; - std::shared_ptr<Node> matmul; - for (const auto& element : nodes) { - assert((element->type() == "MatMul" || element->type() == "Add") && "Wrong type for the nodes to replace"); - if (element->type() == "MatMul"){ - matmul = element; - } - else if (element->type() == "Add") { - add = element; - } - } // Step 1 : Create FC // Fetch the output dimension throught the bias size @@ -78,17 +68,55 @@ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){ } + +void Aidge::fuseMulAdd(std::shared_ptr<MatchSolution> solution){ + + assert(solution->at("MatMul").size() == 1 && "Wrong number of nodes MatMul to replace\n"); + assert(solution->at("Add").size() == 1 && "Wrong number of nodes Add to replace\n"); + + for (const auto& matmul : solution->at("MatMul")) { + for (const auto& add : solution->at("Add")) { + fuseMulAdd(matmul,add); + } + } +} + + void Aidge::fuseMulAdd(std::shared_ptr<GraphView> graphView){ - std::map<std::string,NodeRegex*> nodesRegex ; - nodesRegex["MatMul"] = new NodeRegex("MatMul"); - nodesRegex["Add"] = new NodeRegex("Add"); - std::vector<std::string> seqRegex; - seqRegex.push_back("MatMul -> Add;"); - GRegex GReg(nodesRegex, seqRegex); - Match matches = GReg.match(graphView); - std::vector<std::set<std::shared_ptr<Node>>> matchNodes = matches.getMatchNodes(); - for (size_t i = 0; i < matches.getNbMatch(); ++i) { - fuseMulAdd(matchNodes[i]); + // std::map<std::string,NodeRegex*> nodesRegex ; + // nodesRegex["MatMul"] = new NodeRegex("MatMul"); + // nodesRegex["Add"] = new NodeRegex("Add"); + + std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>(); + regex->setNodeKey("Add","getType($) =='Add'"); + regex->setNodeKey("MatMul","getType($) =='MatMul'"); + regex->addQuery("MatMul -> Add ;"); + + for (const auto& solution : regex->match(graphView)) { + + fuseMulAdd(solution); + + // // solution->at("MatMul"); + // // solution->at("Add"); + // assert(solution->at("MatMul").size() == 1 && "Wrong number of nodes MatMul to replace\n"); + // assert(solution->at("Add").size() == 1 && "Wrong number of nodes Add to replace\n"); + // for (const auto& matmul : solution->at("MatMul")) { + // for (const auto& add : solution->at("Add")) { + // fuseMulAdd(matmul,add); + // } + // } + + } + + + // std::vector<std::string> seqRegex; + // seqRegex.push_back("MatMul -> Add;"); + // GRegex GReg(nodesRegex, seqRegex); + // Match matches = GReg.match(graphView); + // std::vector<std::set<std::shared_ptr<Node>>> matchNodes = matches.getMatchNodes(); + // for (size_t i = 0; i < matches.getNbMatch(); ++i) { + // fuseMulAdd(matchNodes[i]); + // } }