diff --git a/include/aidge/graphRegex/GraphRegex.hpp b/include/aidge/graphRegex/GraphRegex.hpp index 12a5139a36135979639d2447b869568b943ee840..b62a42fcfeb258e5c659eaeb6681190482f37aa4 100644 --- a/include/aidge/graphRegex/GraphRegex.hpp +++ b/include/aidge/graphRegex/GraphRegex.hpp @@ -11,6 +11,11 @@ namespace Aidge{ +/** + * type for recipes function use in query and resolve +*/ +using RecipesFunctionType = std::function<void(std::shared_ptr<MatchSolution>)>; + /** * @brief class which is the hight level interface for graph matching, used to simplify match definition * @@ -19,9 +24,10 @@ class GraphRegex{ private: - std::vector<std::string> mQuery; + //std::vector<std::string> mQuery; std::vector<std::shared_ptr<ConditionalInterpreter>> mAllTest; std::map<std::string, std::function<bool(NodePtr)>> mAllLambda; + std::map<std::string,RecipesFunctionType> mQueryRecipe; public: GraphRegex(){}; @@ -31,7 +37,15 @@ class GraphRegex{ * @brief add a topology query to the match * @param query the topology query to find **/ - void addQuery(const std::string query); + //void addQuery(const std::string query); + + /** + * @brief add a topology query to the match and a function for recipe + * @param query the topology query to find + * @param f the funct + **/ + void addQuery(const std::string query,RecipesFunctionType f = nullptr); + /** * @brief get all the types of a graph and set it as type key in the query @@ -53,13 +67,19 @@ class GraphRegex{ **/ void setNodeKey(const std::string key,std::function<bool(NodePtr)> f); - /*** + /** * @brief brief match the queries in the graph - * @param Reference the graph were the querys in search + * @param ref the graph were the querys in search * @return the result */ std::set<std::shared_ptr<MatchSolution>> match(std::shared_ptr<GraphView> ref); + /*** + * @brief match the queries in the graph and applied the recipes fuction + * @param ref the graph were the querys in search + */ + void appliedRecipes(std::shared_ptr<GraphView> ref); + private: void _generateCombinationsStart(const std::set<NodePtr>& elements, std::size_t n, std::size_t index, diff --git a/include/aidge/graphRegex/matchFsm/FsmEdge.hpp b/include/aidge/graphRegex/matchFsm/FsmEdge.hpp index 3e63f92337f6394382f6d92ef9f6dd7b5098a454..a6cc3e59247d4be98caa9881182bfba1c44e0178 100644 --- a/include/aidge/graphRegex/matchFsm/FsmEdge.hpp +++ b/include/aidge/graphRegex/matchFsm/FsmEdge.hpp @@ -116,7 +116,7 @@ namespace Aidge{ }; /** - * @brief class spesialisation for not commun node (node that must be match one Unique) transition + * @brief class specialization for not commun node (node that must be match one Unique) transition */ class FsmEdgeUnique:public FsmEdge { @@ -127,7 +127,7 @@ namespace Aidge{ }; /** - * @brief class spesialisation for commun node transition + * @brief class specialization for commun node transition * @see FsmEdge */ class FsmEdgeCommon:public FsmEdge @@ -181,7 +181,7 @@ namespace Aidge{ }; /** - * @brief class spesialisation for ref empty transition + * @brief class specialization for ref empty transition * @see FsmEdge */ class FsmEdgeEmpty:public FsmEdge @@ -195,6 +195,20 @@ namespace Aidge{ }; + /** + * @brief class specialization for ref empty transition + * @see FsmEdge + */ + class FsmEdgeNone:public FsmEdge + { + + public: + FsmEdgeNone(std::shared_ptr<FsmNode> source,std::shared_ptr<FsmNode> dest); + const EdgeTestResult test(const std::shared_ptr<FsmRunTimeContext> /*stmContext*/) override; + + }; + + //////////////////////// // FACTORY diff --git a/include/aidge/operator/MetaOperator.hpp b/include/aidge/operator/MetaOperator.hpp index 72058dfcba6e811a01a22e261208741879638cad..baa7a486ced74375792cf1ebd3b2f7056168f027 100644 --- a/include/aidge/operator/MetaOperator.hpp +++ b/include/aidge/operator/MetaOperator.hpp @@ -78,12 +78,6 @@ public: void computeOutputDims() override final { // Forward dims of micro-graph mGraph->forwardDims(); - - // Associate outputs to micro-graph outputs for custom implementation - for (size_t outputIdx = 0; outputIdx < mOutputOps.size(); ++outputIdx) { - const auto& outputOp = mOutputOps[outputIdx]; - mOutputs[outputIdx] = outputOp.first->getOutput(outputOp.second); - } } bool outputDimsForwarded() const override final { return !(mOutputs[0]->empty()); } diff --git a/src/graphRegex/GraphFsmInterpreter.cpp b/src/graphRegex/GraphFsmInterpreter.cpp index 03e86487513065af47d91fc5265335bba456e64e..18b768c6567e64caf6841ed4a339f13fd16f69d6 100644 --- a/src/graphRegex/GraphFsmInterpreter.cpp +++ b/src/graphRegex/GraphFsmInterpreter.cpp @@ -128,7 +128,7 @@ std::shared_ptr<FsmGraph> GraphFsmInterpreter::qomF(std::shared_ptr<FsmGraph> fs for(auto valid : allValid){ if(haveCommon){ /* - the // quantif case + the // quantify case get the go back and make a lexeme id(number) we need to go back to the ref delta min #TODO */ @@ -145,7 +145,7 @@ std::shared_ptr<FsmGraph> GraphFsmInterpreter::qomF(std::shared_ptr<FsmGraph> fs edge = FsmEdgeFactory::make(valid,start,FsmEdgeTypes::REF,mNodesCondition, lexem.str()); }else{ /* - the sequensial quantif case + the sequencial quantify case no reference to common */ edge = FsmEdgeFactory::make(valid,start,FsmEdgeTypes::EMPTY,mNodesCondition,""); diff --git a/src/graphRegex/GraphRegex.cpp b/src/graphRegex/GraphRegex.cpp index ef0db8c88f3e753f9b9633b1ffb05bbec6d00424..9a9b53da615f77dbdb8e597763411a2e84920b2a 100644 --- a/src/graphRegex/GraphRegex.cpp +++ b/src/graphRegex/GraphRegex.cpp @@ -26,10 +26,17 @@ void GraphRegex::setKeyFromGraph(std::shared_ptr<GraphView> ref){ -void GraphRegex::addQuery(const std::string query){ - mQuery.push_back(query); -} +// void GraphRegex::addQuery(const std::string query){ +// //TODO one query only but the same string is a same query but +// //2 different string it's maybe the same query , we need to check the AST +// mQueryRecipe[query] = nullptr; +// } + +void GraphRegex::addQuery(const std::string query,RecipesFunctionType f ){ + mQueryRecipe[query] = f; + +} // Function to generate all combinations of n elements from a set @@ -87,7 +94,9 @@ std::set<std::shared_ptr<MatchSolution>> GraphRegex::match(std::shared_ptr<Graph std::vector<std::shared_ptr<MatchSolution>> solutions = {}; - for (const std::string& query : mQuery) { + //for (const std::string& query : mQuery) { + for (auto it = mQueryRecipe.begin(); it != mQueryRecipe.end(); ++it) { + const std::string query = it->first; std::shared_ptr<GraphFsmInterpreter> fsmGenerator = std::make_shared<GraphFsmInterpreter>(query,mAllTest); std::shared_ptr<FsmGraph> fsm = fsmGenerator->interpret(); @@ -108,6 +117,15 @@ std::set<std::shared_ptr<MatchSolution>> GraphRegex::match(std::shared_ptr<Graph return _findLargestCompatibleSet(solutions); } +void GraphRegex::appliedRecipes(std::shared_ptr<GraphView> ref){ + std::set<std::shared_ptr<MatchSolution>> matchRef = match(ref); + for (const auto& solution : matchRef) { + if(mQueryRecipe[solution->getQuery()] != nullptr){ + mQueryRecipe[solution->getQuery()](solution); + } + } +} + void GraphRegex::setNodeKey(const std::string key, const std::string conditionalExpressions ){ mAllTest.push_back(std::make_shared<ConditionalInterpreter>(key,conditionalExpressions)); _majConditionalInterpreterLambda(); diff --git a/src/graphRegex/matchFsm/FsmEdge.cpp b/src/graphRegex/matchFsm/FsmEdge.cpp index ab307e023209ab770fc63f0550811279bd42eb46..d16dcf9505f5c3324fa621df2895065b7b019e19 100644 --- a/src/graphRegex/matchFsm/FsmEdge.cpp +++ b/src/graphRegex/matchFsm/FsmEdge.cpp @@ -226,6 +226,14 @@ const EdgeTestResult FsmEdgeEmpty::test(const std::shared_ptr<FsmRunTimeContext> } return {true,std::set<NodePtr>({opNode})};//none } +////////////// + +FsmEdgeNone::FsmEdgeNone(std::shared_ptr<FsmNode> source,std::shared_ptr<FsmNode> dest) +:FsmEdge(source,dest,nullptr) +{} + const EdgeTestResult FsmEdgeNone::test(const std::shared_ptr<FsmRunTimeContext> /*stmContext*/){ + return {false,std::set<NodePtr>()}; + } /// factory std::shared_ptr<FsmEdge> FsmEdgeFactory::make( @@ -260,7 +268,10 @@ const std::string lexeme) std::string commonKey = edgeType + std::to_string(commonIdx); if(allTest.find(edgeType) == allTest.end()){ - throw std::invalid_argument("Bad Node Test " + edgeType ); + //if the key is not linked to a condition + //by default, it is initialized by a edge that is always false + return std::make_shared<FsmEdgeNone>(source, dest); + //throw std::invalid_argument("Bad Node Test " + edgeType ); } return std::make_shared<FsmEdgeCommon> (source, dest, allTest.at(edgeType), commonKey); @@ -274,7 +285,11 @@ const std::string lexeme) std::string edgeType = m[1]; if(allTest.find(edgeType) == allTest.end()){ - throw std::invalid_argument("Bad Node Test " + edgeType ); + + //if the key is not linked to a condition + //by default, it is initialized by a edge that is always false + return std::make_shared<FsmEdgeNone>(source, dest); + //throw std::invalid_argument("Bad Node Test " + edgeType ); } return std::make_shared<FsmEdgeUnique>(source, dest, allTest.at(edgeType)); diff --git a/src/operator/MetaOperator.cpp b/src/operator/MetaOperator.cpp index c1f58c68686d9359fa3b8ea4b5eb54244e988895..23a98152a2b155b5e059c25e616eee47040c0aed 100644 --- a/src/operator/MetaOperator.cpp +++ b/src/operator/MetaOperator.cpp @@ -22,10 +22,6 @@ Aidge::MetaOperator_Op::MetaOperator_Op(const char *type, const std::shared_ptr< for (std::size_t i = 0; i < mInputs.size(); ++i) { mInputs[i] = std::make_shared<Tensor>(); } - mOutputs = std::vector<std::shared_ptr<Tensor>>(mGraph->outputs().size()); - for (std::size_t i = 0; i < mOutputs.size(); ++i) { - mOutputs[i] = std::make_shared<Tensor>(); - } // Fill inputsNodes and outputsNodes when there is no ambiguity if (inputNodes.empty()) { @@ -46,7 +42,7 @@ Aidge::MetaOperator_Op::MetaOperator_Op(const char *type, const std::shared_ptr< AIDGE_ASSERT(mGraph->inView(inputNode), "input node must be in the graph"); const std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs = inputNode->inputs(); - + int inputIdx = 0; // input idx relative to the current node for (const auto& in : inputNodeinputs) { if (in.first == nullptr || !mGraph->inView(in.first)) { @@ -71,8 +67,15 @@ Aidge::MetaOperator_Op::MetaOperator_Op(const char *type, const std::shared_ptr< } } + AIDGE_INTERNAL_ASSERT(mInputOps.size() == mGraph->inputs().size()); AIDGE_INTERNAL_ASSERT(mOutputOps.size() == mGraph->outputs().size()); + mOutputs = std::vector<std::shared_ptr<Tensor>>(mGraph->outputs().size()); + // Associate outputs to micro-graph outputs for custom implementation + for (size_t outputIdx = 0; outputIdx < mOutputOps.size(); ++outputIdx) { + const auto& outputOp = mOutputOps[outputIdx]; + mOutputs[outputIdx] = outputOp.first->getOutput(outputOp.second); + } } Aidge::NbElts_t Aidge::MetaOperator_Op::getNbRequiredData(const IOIndex_t inputIdx) const { @@ -114,7 +117,7 @@ void Aidge::MetaOperator_Op::updateConsummerProducer() { // Lazy initialization mScheduler = std::make_shared<SequentialScheduler>(mGraph); } - + // TODO: check that generateScheduling() can be called multiple time to iteratively update the schedule. // It could be a good idea to unify updateConsummerProducer() and generateScheduling() into a "updateScheduling()" mScheduler->generateScheduling(); diff --git a/unit_tests/graphRegex/Test_GraphRegex.cpp b/unit_tests/graphRegex/Test_GraphRegex.cpp index b30560ea3ea696821d2422bf760a11973a104e85..19859fd16345ff7f8d85b24e43d23c02f9ec22ee 100644 --- a/unit_tests/graphRegex/Test_GraphRegex.cpp +++ b/unit_tests/graphRegex/Test_GraphRegex.cpp @@ -2,6 +2,15 @@ #include <catch2/catch_test_macros.hpp> #include "aidge/graphRegex/GraphRegex.hpp" + +#include "aidge/data/Tensor.hpp" +#include "aidge/graph/GraphView.hpp" +#include "aidge/operator/Add.hpp" +#include "aidge/operator/FC.hpp" +#include "aidge/operator/MatMul.hpp" +#include "aidge/operator/Producer.hpp" +#include "aidge/utils/Recipies.hpp" + #include "aidge/operator/Conv.hpp" #include "aidge/operator/GenericOperator.hpp" @@ -46,13 +55,9 @@ TEST_CASE("GraphRegexUser") { } - SECTION("CC") { - - - + SECTION("2 query") { std::shared_ptr<GraphRegex> sut = std::make_shared<GraphRegex>(); - std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 1, 1, "c"); std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 1, 1, "c1"); @@ -81,4 +86,93 @@ TEST_CASE("GraphRegexUser") { } } + + + SECTION("Not define node Test") { + + //test if the FC is not define only match query not query2 + std::shared_ptr<GraphRegex> sut = std::make_shared<GraphRegex>(); + + std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); + std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 1, 1, "c"); + std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 1, 1, "c1"); + std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 1, 1, "c2"); + std::shared_ptr<Node> conv3 = GenericOperator("FC", 1, 1, 1, "c3"); + + g1->add(conv); + g1->addChild(conv1, "c"); + g1->addChild(conv2, "c1"); + g1->addChild(conv3, "c2"); + + + //sut->setKeyFromGraph(g1); + + const std::string query = "Conv->Conv"; + const std::string query2 = "Conv->FC"; + + sut->setNodeKey("Conv","getType($) =='Conv'"); + + sut->addQuery(query); + sut->addQuery(query2); + + + for (const auto& solution : sut->match(g1)) { + REQUIRE(solution->getQuery() == query); + } + + } + + + SECTION("Applied Recipes"){ + + // generate the original GraphView + auto matmul0 = MatMul(5, "matmul0"); + auto add0 = Add<2>("add0"); + auto matmul1 = MatMul(5, "matmul1"); + auto add1 = Add<2>("add1"); + + auto b0 = Producer({5}, "B0"); + auto w0 = Producer({5, 5}, "W0"); + auto b1 = Producer({5}, "B1"); + auto w1 = Producer({5,5},"W1"); + auto input = Producer({2,5}, "input"); + + input->addChild(matmul0, 0, 0); + w0->addChild(matmul0, 0, 1); + + matmul0->addChild(add0, 0, 0); + b0->addChild(add0, 0, 1); + + add0->addChild(matmul1, 0, 0); + w1->addChild(matmul1, 0, 1); + + matmul1->addChild(add1, 0, 0); + b1->addChild(add1, 0, 1); + + auto fc = GenericOperator("FC", 1, 1, 1, "c"); + auto fl = GenericOperator("Flatten", 1, 1, 1, "c"); + + + auto g = std::make_shared<GraphView>(); + g->add({matmul0, add0, matmul1, add1, b0, b1,fl,fc}); + + std::shared_ptr<GraphRegex> kitchenBook = std::make_shared<GraphRegex>(); + + kitchenBook->setNodeKey("Add","getType($) =='Add'"); + kitchenBook->setNodeKey("MatMul","getType($) =='MatMul'"); + kitchenBook->setNodeKey("Flatten","getType($) =='Flatten'"); + kitchenBook->setNodeKey("FC","getType($) =='FC'"); + + kitchenBook->addQuery("MatMul->Add",static_cast<void(*)(std::shared_ptr<MatchSolution>)>(fuseMulAdd)); + kitchenBook->addQuery("Flatten->FC",static_cast<void(*)(std::shared_ptr<MatchSolution>)>(removeFlatten)); + + kitchenBook->appliedRecipes(g); + + std::set<std::shared_ptr<Node>> newNodes = g->getNodes(); + REQUIRE(newNodes != std::set<std::shared_ptr<Node>>({w0, matmul0, b0, add0, w1, matmul1, b1, add1,fc})); + //REQUIRE(newNodes.size() == 6); + + + } + } \ No newline at end of file