diff --git a/include/aidge/graphRegex/GraphRegex.hpp b/include/aidge/graphRegex/GraphRegex.hpp index 12a5139a36135979639d2447b869568b943ee840..b62a42fcfeb258e5c659eaeb6681190482f37aa4 100644 --- a/include/aidge/graphRegex/GraphRegex.hpp +++ b/include/aidge/graphRegex/GraphRegex.hpp @@ -11,6 +11,11 @@ namespace Aidge{ +/** + * type for recipes function use in query and resolve +*/ +using RecipesFunctionType = std::function<void(std::shared_ptr<MatchSolution>)>; + /** * @brief class which is the hight level interface for graph matching, used to simplify match definition * @@ -19,9 +24,10 @@ class GraphRegex{ private: - std::vector<std::string> mQuery; + //std::vector<std::string> mQuery; std::vector<std::shared_ptr<ConditionalInterpreter>> mAllTest; std::map<std::string, std::function<bool(NodePtr)>> mAllLambda; + std::map<std::string,RecipesFunctionType> mQueryRecipe; public: GraphRegex(){}; @@ -31,7 +37,15 @@ class GraphRegex{ * @brief add a topology query to the match * @param query the topology query to find **/ - void addQuery(const std::string query); + //void addQuery(const std::string query); + + /** + * @brief add a topology query to the match and a function for recipe + * @param query the topology query to find + * @param f the funct + **/ + void addQuery(const std::string query,RecipesFunctionType f = nullptr); + /** * @brief get all the types of a graph and set it as type key in the query @@ -53,13 +67,19 @@ class GraphRegex{ **/ void setNodeKey(const std::string key,std::function<bool(NodePtr)> f); - /*** + /** * @brief brief match the queries in the graph - * @param Reference the graph were the querys in search + * @param ref the graph were the querys in search * @return the result */ std::set<std::shared_ptr<MatchSolution>> match(std::shared_ptr<GraphView> ref); + /*** + * @brief match the queries in the graph and applied the recipes fuction + * @param ref the graph were the querys in search + */ + void appliedRecipes(std::shared_ptr<GraphView> ref); + private: void _generateCombinationsStart(const std::set<NodePtr>& elements, std::size_t n, std::size_t index, diff --git a/src/graphRegex/GraphRegex.cpp b/src/graphRegex/GraphRegex.cpp index ef0db8c88f3e753f9b9633b1ffb05bbec6d00424..9a9b53da615f77dbdb8e597763411a2e84920b2a 100644 --- a/src/graphRegex/GraphRegex.cpp +++ b/src/graphRegex/GraphRegex.cpp @@ -26,10 +26,17 @@ void GraphRegex::setKeyFromGraph(std::shared_ptr<GraphView> ref){ -void GraphRegex::addQuery(const std::string query){ - mQuery.push_back(query); -} +// void GraphRegex::addQuery(const std::string query){ +// //TODO one query only but the same string is a same query but +// //2 different string it's maybe the same query , we need to check the AST +// mQueryRecipe[query] = nullptr; +// } + +void GraphRegex::addQuery(const std::string query,RecipesFunctionType f ){ + mQueryRecipe[query] = f; + +} // Function to generate all combinations of n elements from a set @@ -87,7 +94,9 @@ std::set<std::shared_ptr<MatchSolution>> GraphRegex::match(std::shared_ptr<Graph std::vector<std::shared_ptr<MatchSolution>> solutions = {}; - for (const std::string& query : mQuery) { + //for (const std::string& query : mQuery) { + for (auto it = mQueryRecipe.begin(); it != mQueryRecipe.end(); ++it) { + const std::string query = it->first; std::shared_ptr<GraphFsmInterpreter> fsmGenerator = std::make_shared<GraphFsmInterpreter>(query,mAllTest); std::shared_ptr<FsmGraph> fsm = fsmGenerator->interpret(); @@ -108,6 +117,15 @@ std::set<std::shared_ptr<MatchSolution>> GraphRegex::match(std::shared_ptr<Graph return _findLargestCompatibleSet(solutions); } +void GraphRegex::appliedRecipes(std::shared_ptr<GraphView> ref){ + std::set<std::shared_ptr<MatchSolution>> matchRef = match(ref); + for (const auto& solution : matchRef) { + if(mQueryRecipe[solution->getQuery()] != nullptr){ + mQueryRecipe[solution->getQuery()](solution); + } + } +} + void GraphRegex::setNodeKey(const std::string key, const std::string conditionalExpressions ){ mAllTest.push_back(std::make_shared<ConditionalInterpreter>(key,conditionalExpressions)); _majConditionalInterpreterLambda(); diff --git a/unit_tests/graphRegex/Test_GraphRegex.cpp b/unit_tests/graphRegex/Test_GraphRegex.cpp index b30560ea3ea696821d2422bf760a11973a104e85..c9feded0b39410815a49cccb2d5746491c2f8cb1 100644 --- a/unit_tests/graphRegex/Test_GraphRegex.cpp +++ b/unit_tests/graphRegex/Test_GraphRegex.cpp @@ -2,6 +2,15 @@ #include <catch2/catch_test_macros.hpp> #include "aidge/graphRegex/GraphRegex.hpp" + +#include "aidge/data/Tensor.hpp" +#include "aidge/graph/GraphView.hpp" +#include "aidge/operator/Add.hpp" +#include "aidge/operator/FC.hpp" +#include "aidge/operator/MatMul.hpp" +#include "aidge/operator/Producer.hpp" +#include "aidge/utils/Recipies.hpp" + #include "aidge/operator/Conv.hpp" #include "aidge/operator/GenericOperator.hpp" @@ -47,12 +56,8 @@ TEST_CASE("GraphRegexUser") { } SECTION("CC") { - - - std::shared_ptr<GraphRegex> sut = std::make_shared<GraphRegex>(); - std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 1, 1, "c"); std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 1, 1, "c1"); @@ -81,4 +86,58 @@ TEST_CASE("GraphRegexUser") { } } + + + SECTION("Applied Recipes"){ + + // generate the original GraphView + auto matmul0 = MatMul(5, "matmul0"); + auto add0 = Add<2>("add0"); + auto matmul1 = MatMul(5, "matmul1"); + auto add1 = Add<2>("add1"); + + auto b0 = Producer({5}, "B0"); + auto w0 = Producer({5, 5}, "W0"); + auto b1 = Producer({5}, "B1"); + auto w1 = Producer({5,5},"W1"); + auto input = Producer({2,5}, "input"); + + input->addChild(matmul0, 0, 0); + w0->addChild(matmul0, 0, 1); + + matmul0->addChild(add0, 0, 0); + b0->addChild(add0, 0, 1); + + add0->addChild(matmul1, 0, 0); + w1->addChild(matmul1, 0, 1); + + matmul1->addChild(add1, 0, 0); + b1->addChild(add1, 0, 1); + + auto fc = GenericOperator("FC", 1, 1, 1, "c"); + auto fl = GenericOperator("Flatten", 1, 1, 1, "c"); + + + auto g = std::make_shared<GraphView>(); + g->add({matmul0, add0, matmul1, add1, b0, b1,fl,fc}); + + std::shared_ptr<GraphRegex> kitchenBook = std::make_shared<GraphRegex>(); + + kitchenBook->setNodeKey("Add","getType($) =='Add'"); + kitchenBook->setNodeKey("MatMul","getType($) =='MatMul'"); + kitchenBook->setNodeKey("Flatten","getType($) =='Flatten'"); + kitchenBook->setNodeKey("FC","getType($) =='FC'"); + + kitchenBook->addQuery("MatMul->Add",static_cast<void(*)(std::shared_ptr<MatchSolution>)>(fuseMulAdd)); + kitchenBook->addQuery("Flatten->FC",static_cast<void(*)(std::shared_ptr<MatchSolution>)>(removeFlatten)); + + kitchenBook->appliedRecipes(g); + + std::set<std::shared_ptr<Node>> newNodes = g->getNodes(); + REQUIRE(newNodes != std::set<std::shared_ptr<Node>>({w0, matmul0, b0, add0, w1, matmul1, b1, add1,fc})); + //REQUIRE(newNodes.size() == 6); + + + } + } \ No newline at end of file