diff --git a/include/aidge/graphRegex/GraphRegex.hpp b/include/aidge/graphRegex/GraphRegex.hpp index 12a5139a36135979639d2447b869568b943ee840..b62a42fcfeb258e5c659eaeb6681190482f37aa4 100644 --- a/include/aidge/graphRegex/GraphRegex.hpp +++ b/include/aidge/graphRegex/GraphRegex.hpp @@ -11,6 +11,11 @@ namespace Aidge{ +/** + * type for recipes function use in query and resolve +*/ +using RecipesFunctionType = std::function<void(std::shared_ptr<MatchSolution>)>; + /** * @brief class which is the hight level interface for graph matching, used to simplify match definition * @@ -19,9 +24,10 @@ class GraphRegex{ private: - std::vector<std::string> mQuery; + //std::vector<std::string> mQuery; std::vector<std::shared_ptr<ConditionalInterpreter>> mAllTest; std::map<std::string, std::function<bool(NodePtr)>> mAllLambda; + std::map<std::string,RecipesFunctionType> mQueryRecipe; public: GraphRegex(){}; @@ -31,7 +37,15 @@ class GraphRegex{ * @brief add a topology query to the match * @param query the topology query to find **/ - void addQuery(const std::string query); + //void addQuery(const std::string query); + + /** + * @brief add a topology query to the match and a function for recipe + * @param query the topology query to find + * @param f the funct + **/ + void addQuery(const std::string query,RecipesFunctionType f = nullptr); + /** * @brief get all the types of a graph and set it as type key in the query @@ -53,13 +67,19 @@ class GraphRegex{ **/ void setNodeKey(const std::string key,std::function<bool(NodePtr)> f); - /*** + /** * @brief brief match the queries in the graph - * @param Reference the graph were the querys in search + * @param ref the graph were the querys in search * @return the result */ std::set<std::shared_ptr<MatchSolution>> match(std::shared_ptr<GraphView> ref); + /*** + * @brief match the queries in the graph and applied the recipes fuction + * @param ref the graph were the querys in search + */ + void appliedRecipes(std::shared_ptr<GraphView> ref); + private: void _generateCombinationsStart(const std::set<NodePtr>& elements, std::size_t n, std::size_t index, diff --git a/src/graphRegex/GraphRegex.cpp b/src/graphRegex/GraphRegex.cpp index ef0db8c88f3e753f9b9633b1ffb05bbec6d00424..9a9b53da615f77dbdb8e597763411a2e84920b2a 100644 --- a/src/graphRegex/GraphRegex.cpp +++ b/src/graphRegex/GraphRegex.cpp @@ -26,10 +26,17 @@ void GraphRegex::setKeyFromGraph(std::shared_ptr<GraphView> ref){ -void GraphRegex::addQuery(const std::string query){ - mQuery.push_back(query); -} +// void GraphRegex::addQuery(const std::string query){ +// //TODO one query only but the same string is a same query but +// //2 different string it's maybe the same query , we need to check the AST +// mQueryRecipe[query] = nullptr; +// } + +void GraphRegex::addQuery(const std::string query,RecipesFunctionType f ){ + mQueryRecipe[query] = f; + +} // Function to generate all combinations of n elements from a set @@ -87,7 +94,9 @@ std::set<std::shared_ptr<MatchSolution>> GraphRegex::match(std::shared_ptr<Graph std::vector<std::shared_ptr<MatchSolution>> solutions = {}; - for (const std::string& query : mQuery) { + //for (const std::string& query : mQuery) { + for (auto it = mQueryRecipe.begin(); it != mQueryRecipe.end(); ++it) { + const std::string query = it->first; std::shared_ptr<GraphFsmInterpreter> fsmGenerator = std::make_shared<GraphFsmInterpreter>(query,mAllTest); std::shared_ptr<FsmGraph> fsm = fsmGenerator->interpret(); @@ -108,6 +117,15 @@ std::set<std::shared_ptr<MatchSolution>> GraphRegex::match(std::shared_ptr<Graph return _findLargestCompatibleSet(solutions); } +void GraphRegex::appliedRecipes(std::shared_ptr<GraphView> ref){ + std::set<std::shared_ptr<MatchSolution>> matchRef = match(ref); + for (const auto& solution : matchRef) { + if(mQueryRecipe[solution->getQuery()] != nullptr){ + mQueryRecipe[solution->getQuery()](solution); + } + } +} + void GraphRegex::setNodeKey(const std::string key, const std::string conditionalExpressions ){ mAllTest.push_back(std::make_shared<ConditionalInterpreter>(key,conditionalExpressions)); _majConditionalInterpreterLambda();