diff --git a/include/aidge/graphRegex/GraphFsmInterpreter.hpp b/include/aidge/graphRegex/GraphFsmInterpreter.hpp index 9e92b6fe8fc9d5e44cb8051e687e33d7192e0eb7..e2fd43b9e641e8cb4a695e3a3eecf5975610d564 100644 --- a/include/aidge/graphRegex/GraphFsmInterpreter.hpp +++ b/include/aidge/graphRegex/GraphFsmInterpreter.hpp @@ -19,13 +19,16 @@ namespace Aidge { std::size_t mActGroupe; std::map<std::string,std::shared_ptr<ConditionalInterpreter>> mNodesCondition; + const std::string mGraphMatchExpr; public: - GraphFsmInterpreter(const std::string graphMatchExpr,std::map<std::string,std::shared_ptr<ConditionalInterpreter>> nodesCondition); + GraphFsmInterpreter(const std::string graphMatchExpr,std::vector<std::shared_ptr<ConditionalInterpreter>> & nodesCondition); virtual ~GraphFsmInterpreter() =default; std::shared_ptr<FsmGraph> interpret(void); + + private: diff --git a/include/aidge/graphRegex/GraphLexer.hpp b/include/aidge/graphRegex/GraphLexer.hpp index e4137ab093c466b7349007da91e032dae48eda51..bd65dfc15d18533676b19e148a98185d3844acbd 100644 --- a/include/aidge/graphRegex/GraphLexer.hpp +++ b/include/aidge/graphRegex/GraphLexer.hpp @@ -36,6 +36,9 @@ namespace Aidge { bool isEnd(void); + const std::string getQuery(); + + /** * @brief Get the representation of the class * @return string @@ -46,7 +49,7 @@ namespace Aidge { /** * @brief Constructs an error message to display the character not understood by the lexer - * @return error mesage + * @return error message */ std::runtime_error badTokenError(const std::string& currentChars,std::size_t position); diff --git a/include/aidge/graphRegex/GraphParser.hpp b/include/aidge/graphRegex/GraphParser.hpp index 73406203a8be87e1df75cc694ab1ff281c27fbfa..29ee8c7b294eae2b8d8196de1702cb7e194cfa84 100644 --- a/include/aidge/graphRegex/GraphParser.hpp +++ b/include/aidge/graphRegex/GraphParser.hpp @@ -30,6 +30,13 @@ class GraphParser{ std::shared_ptr<AstNode<gRegexTokenTypes>> parse(void); + /** + * @brief get the query that be use in the parsing + * @return query + */ + const std::string getQuery(); + + private: /** * @brief restart at the start of the ConditionalExpressions for LEXER and restart mCurrentToken diff --git a/include/aidge/graphRegex/GraphRegex.hpp b/include/aidge/graphRegex/GraphRegex.hpp new file mode 100644 index 0000000000000000000000000000000000000000..12a5139a36135979639d2447b869568b943ee840 --- /dev/null +++ b/include/aidge/graphRegex/GraphRegex.hpp @@ -0,0 +1,87 @@ +#ifndef AIDGE_CORE_GRAPH_REGEX_H_ +#define AIDGE_CORE_GRAPH_REGEX_H_ + +#include <string> + +#include "aidge/graphRegex/matchFsm/MatchResult.hpp" +#include "aidge/graphRegex/matchFsm/FsmGraph.hpp" +#include "aidge/graphRegex/GraphFsmInterpreter.hpp" +#include "aidge/graph/GraphView.hpp" +#include "aidge/graph/Node.hpp" + +namespace Aidge{ + +/** + * @brief class which is the hight level interface for graph matching, used to simplify match definition + * + */ +class GraphRegex{ + + private: + + std::vector<std::string> mQuery; + std::vector<std::shared_ptr<ConditionalInterpreter>> mAllTest; + std::map<std::string, std::function<bool(NodePtr)>> mAllLambda; + + public: + GraphRegex(){}; + virtual ~GraphRegex() = default; + + /** + * @brief add a topology query to the match + * @param query the topology query to find + **/ + void addQuery(const std::string query); + + /** + * @brief get all the types of a graph and set it as type key in the query + * @param Reference graph use to get all the node types + **/ + void setKeyFromGraph(std::shared_ptr<GraphView> ref); + + /** + * @brief set a node test manually + * @param key the ref of this test used in the query + * @param ConditionalExpressions expression to test the node + **/ + void setNodeKey(const std::string key, const std::string conditionalExpressions ); + + /** + * @brief set a specific lambda that can be used in setQueryKey + * @param key ref to the lambda to use in the + * @param f expression to test the node ConditionalExpressions + **/ + void setNodeKey(const std::string key,std::function<bool(NodePtr)> f); + + /*** + * @brief brief match the queries in the graph + * @param Reference the graph were the querys in search + * @return the result + */ + std::set<std::shared_ptr<MatchSolution>> match(std::shared_ptr<GraphView> ref); + + private: + + void _generateCombinationsStart(const std::set<NodePtr>& elements, std::size_t n, std::size_t index, + std::vector<NodePtr>& current, std::set<std::vector<NodePtr>>& combinations); + + + + void _findLargestCompatibleSet( + const std::vector<std::shared_ptr<MatchSolution>>& solutions, + std::set<std::shared_ptr<MatchSolution>>& currentSet, + std::set<std::shared_ptr<MatchSolution>>& largestSet, + size_t currentIndex + ); + + std::set<std::shared_ptr<MatchSolution>> _findLargestCompatibleSet( + const std::vector<std::shared_ptr<MatchSolution>>& solutions + ); + + void _majConditionalInterpreterLambda(); + +}; +} + + +#endif //AIDGE_CORE_GRAPH_REGEX_H_ \ No newline at end of file diff --git a/include/aidge/graphRegex/matchFsm/FsmEdge.hpp b/include/aidge/graphRegex/matchFsm/FsmEdge.hpp index c3eae528808dbdb8023718c961b7c45cbf4afac9..3e63f92337f6394382f6d92ef9f6dd7b5098a454 100644 --- a/include/aidge/graphRegex/matchFsm/FsmEdge.hpp +++ b/include/aidge/graphRegex/matchFsm/FsmEdge.hpp @@ -87,7 +87,7 @@ namespace Aidge{ * @brief set a new source to the edge * @return FsmNode */ - void reSetSouceNode(const std::shared_ptr<FsmNode>& newSource); + void reSetSourceNode(const std::shared_ptr<FsmNode>& newSource); /** * @brief get dest FsmNode * @return FsmNode diff --git a/include/aidge/graphRegex/matchFsm/FsmGraph.hpp b/include/aidge/graphRegex/matchFsm/FsmGraph.hpp index 0a74551367dd492cb0abb820e4c5ce5a601d071e..d718009e87e5360981ff93ff808124581917c089 100644 --- a/include/aidge/graphRegex/matchFsm/FsmGraph.hpp +++ b/include/aidge/graphRegex/matchFsm/FsmGraph.hpp @@ -18,78 +18,89 @@ class FsmGraph { private: /** - * @brief all node origine + * @brief all node Origin */ - std::set<std::size_t> mAllOrigine; + std::set<std::size_t> mAllOrigin; std::set<std::shared_ptr<FsmEdge>> mEdges; + + + const std::string mQuery; + public: - FsmGraph(/* args */); + + FsmGraph(const std::string query); virtual ~FsmGraph() = default; -std::shared_ptr<MatchResult> test(std::vector<NodePtr>& StartNodes); - - - -const std::set<std::shared_ptr<FsmEdge>>& getEdge(void); -/** - * @brief add edge in the graph, as FsmEdge know the source and dest FsmNode these nodes are also add to the graph -*/ -void addEdge(std::shared_ptr<FsmEdge>& edge); - -/** - * @brief get the liste of the starting states - * @details we need to use a vector because the order of the nodes is important for start node initialization \ref test() -*/ -const std::vector<std::shared_ptr<FsmNode>> getStartNodes(void); - -/** - * @brief get the set of the valide states - * @return set of valide state -*/ -const std::set<std::shared_ptr<FsmNode>> getValidNodes(void); - -/** - * @brief get the set of all the node in the graph - * @return set of all nodes -*/ -const std::set<std::shared_ptr<FsmNode>> getNodes(void); - -/** - * @brief set a groupe idx for all the nodes in the graph -*/ -void setGroupe(std::size_t groupeIdx); - -/** - * @brief make the union beteen this graph and an input graph - * @param fsmGraph graph to union -*/ -void unionG(const std::shared_ptr<FsmGraph> fsmGraph); - - -/** - * @brief make the union beteen this graph and an input graph and merge the valide state to the start state - * @param fsmGraph graph to merge -*/ -void mergeOneStartOneValid(const std::shared_ptr< FsmGraph> fsmGraph); -/** - * @brief get the number of sub FSM - * @return number of sub Fsm -*/ -std::size_t getNbSubFsm(void); - -/** - * @brief increment the origine of all node in the graph - * @param incr the incrémentation value -*/ -void incOrigineAllNodeBy(std::size_t incr); + std::vector<std::shared_ptr<MatchSolution>> test(const std::vector<NodePtr>& StartNodes); -private: -/** - * @brief merge tow node of the graph - * @param node -*/ -void _mergeNode(std::shared_ptr<FsmNode> source,std::shared_ptr<FsmNode> dest); + + const std::set<std::shared_ptr<FsmEdge>>& getEdge(void); + /** + * @brief add edge in the graph, as FsmEdge know the source and dest FsmNode these nodes are also add to the graph + */ + void addEdge(std::shared_ptr<FsmEdge>& edge); + + /** + * @brief get the list of the starting states + * @details we need to use a vector because the order of the nodes is important for start node initialization \ref test() + */ + const std::vector<std::shared_ptr<FsmNode>> getStartNodes(void); + + /** + * @brief get the set of the valid states + * @return set of valide state + */ + const std::set<std::shared_ptr<FsmNode>> getValidNodes(void); + + /** + * @brief get the set of all the node in the graph + * @return set of all nodes + */ + const std::set<std::shared_ptr<FsmNode>> getNodes(void); + + /** + * @brief set a groupe idx for all the nodes in the graph + */ + void setGroupe(std::size_t groupeIdx); + + /** + * @brief make the union between this graph and an input graph + * @param fsmGraph graph to union + */ + void unionG(const std::shared_ptr<FsmGraph> fsmGraph); + + + /** + * @brief make the union between this graph and an input graph and merge the valid state to the start state + * @param fsmGraph graph to merge + */ + void mergeOneStartOneValid(const std::shared_ptr< FsmGraph> fsmGraph); + /** + * @brief get the number of sub FSM + * @return number of sub Fsm + */ + std::size_t getNbSubFsm(void); + + /** + * @brief get the number of start state + * @return number of start state + */ + std::size_t getNbStart(void); + + /** + * @brief increment the origin of all nodes in the graph + * @param incr value + */ + void incOriginAllNodeBy(std::size_t incr); + + private: + + /** + * @brief merge tow node of the graph + * @param node + */ + void _mergeNode(std::shared_ptr<FsmNode> source,std::shared_ptr<FsmNode> dest); }; diff --git a/include/aidge/graphRegex/matchFsm/FsmNode.hpp b/include/aidge/graphRegex/matchFsm/FsmNode.hpp index 2776ff8eb297fd5ad9a4c425fb386adde0a25269..7987c5ce33522ca7d43de1918d53e68738af6d18 100644 --- a/include/aidge/graphRegex/matchFsm/FsmNode.hpp +++ b/include/aidge/graphRegex/matchFsm/FsmNode.hpp @@ -33,7 +33,7 @@ namespace Aidge{ * @details a state can be and/or : * - a valide state, the match is valide if it stop on this edge * - a start state , the match start on this state - * The state is also define by this origine (is the unique id of it's expretion ) + * The state is also define by this Origin (is the unique id of it's expretion ) * and it's groupe (for inner expression TODO) */ class FsmNode : public std::enable_shared_from_this<FsmNode> @@ -49,8 +49,8 @@ namespace Aidge{ */ std::set<std::weak_ptr<FsmNode>,lex_compare<FsmNode>> mParents; - std::size_t mOrigineStm = 0; - std::size_t mGroupeStm = 0; + std::size_t mOriginFsm = 0; + std::size_t mGroupeFsm = 0; bool mIsAValid; bool mIsAStart; @@ -59,7 +59,7 @@ namespace Aidge{ FsmNode(bool isAValid,bool isAStart ); virtual ~FsmNode() = default; /** - * @brief use to MAG the actual context , and return all the posible new context + * @brief use to MAG the actual context , and return all the possible new context * @details one input context can generate a multitude of contexts because a graph node * can have more than one child, and each traversal possibility is a new context. * @param actContext the actual context @@ -68,8 +68,8 @@ namespace Aidge{ const std::vector<std::shared_ptr<FsmRunTimeContext>> test( std::shared_ptr<FsmRunTimeContext>); - std::size_t getOrigine(void); - void incOrigine(std::size_t inc); + std::size_t getOrigin(void); + void incOrigin(std::size_t inc); void rmEdge(std::shared_ptr<FsmEdge>); diff --git a/include/aidge/graphRegex/matchFsm/FsmRunTimeContext.hpp b/include/aidge/graphRegex/matchFsm/FsmRunTimeContext.hpp index 6f1b9fc2bfe68195f67cfc0bf17d57aed5345219..2f6066ba4cd97284c43b509c9d5eb988b65b53a5 100644 --- a/include/aidge/graphRegex/matchFsm/FsmRunTimeContext.hpp +++ b/include/aidge/graphRegex/matchFsm/FsmRunTimeContext.hpp @@ -152,7 +152,7 @@ namespace Aidge{ std::set<NodePtr> getValidNodes(void); std::set<NodePtr> getValidNodesNoCommon(void); - std::map<std::shared_ptr<ConditionalInterpreter>,std::set<NodePtr>> getValid(void); + std::map<std::shared_ptr<ConditionalInterpreter>,std::set<NodePtr>>& getValid(void); NodePtr getActNode(void); diff --git a/include/aidge/graphRegex/matchFsm/MatchResult.hpp b/include/aidge/graphRegex/matchFsm/MatchResult.hpp index ac2f2a627a9d88b3cabeac4b181af2f3b7566d72..29b9abb616a80899b9c2ad8d5e01e5f00e674757 100644 --- a/include/aidge/graphRegex/matchFsm/MatchResult.hpp +++ b/include/aidge/graphRegex/matchFsm/MatchResult.hpp @@ -11,9 +11,31 @@ namespace Aidge{ +/** + * @brief contained the result of one match and the associate key , the query and the start node +*/ + +class MatchSolution{ +private: + std::map<std::string,std::set<NodePtr>> mSolution; + const std::string mQueryFrom; + const std::vector<NodePtr> mStartNode; + +public: + MatchSolution(std::vector<std::shared_ptr<FsmRunTimeContext>>& precedence,const std::string query,const std::vector<NodePtr> startNode); + const std::set<NodePtr> & at(const std::string key); + const std::set<NodePtr> getAll(); + bool areCompatible(std::shared_ptr<MatchSolution> solution); + + const std::string& getQuery(){ return mQueryFrom ;} + const std::vector<NodePtr>& getStartNode(){ return mStartNode ;} + +}; + + /** * @brief class that old the result of a matching - * give acess to all node ant there tag in the expression + * give access to all node and there tag in the expression */ class MatchResult { @@ -22,17 +44,20 @@ private: std::vector<std::shared_ptr<FsmRunTimeContext>> mAllValid; /* - the Run time of eatch sub FSM , to have a valide match we need a set of one run time per FSM compatible - the id must be contigue + the Run time of each sub FSM , to have a valid match we need a set of one run time per FSM compatible + the id must be continue */ std::vector<std::vector<std::shared_ptr<FsmRunTimeContext>>> mIdToRunTime; - std::vector<std::set<NodePtr>> mSolve; + std::vector<std::shared_ptr<MatchSolution>> mSolve; std::size_t mNbSubStm; + + public: - MatchResult(std::vector<std::shared_ptr<FsmRunTimeContext>> allValid, std::size_t nbSubStm); + MatchResult(std::vector<std::shared_ptr<FsmRunTimeContext>> allValid, std::size_t nbSubStm, + const std::string& query,const std::vector<NodePtr>& startNodes); virtual ~MatchResult() = default; @@ -40,16 +65,18 @@ public: * @brief get the set of the node match for une expression * @return the set of node of the graph that corresponding to an expression */ - std::set<NodePtr> getBiggerSolution(void); + std::shared_ptr<MatchSolution> getBiggerSolution(void); + + std::vector<std::shared_ptr<MatchSolution>> getSolutions(void); private: /** - * @brief recurent function use to inite mSolve in the constructor + * @brief recurrent function use to init mSolve in the constructor * **/ -void _generateCombinationd( std::size_t idxSubStm, std::vector<std::shared_ptr<FsmRunTimeContext>>& precedence); +void _generateCombination( std::size_t idxSubStm, std::vector<std::shared_ptr<FsmRunTimeContext>>& precedence,const std::string& query,const std::vector<NodePtr>& startNodes); }; diff --git a/include/aidge/nodeTester/ConditionalInterpreter.hpp b/include/aidge/nodeTester/ConditionalInterpreter.hpp index 165fac1c2ae98bf76b73c039de9fc975e9845cc9..674e942c7b77a6e572b0ffbaa90a2571f7a8118a 100644 --- a/include/aidge/nodeTester/ConditionalInterpreter.hpp +++ b/include/aidge/nodeTester/ConditionalInterpreter.hpp @@ -198,6 +198,13 @@ class ConditionalRegisterFunction { */ ConditionalData* run(const std::string key,std::vector<ConditionalData*> & datas); + bool isLambdaRegister(const std::string &key) { + if(mWlambda.find(key) != mWlambda.end()){ + return true; + } + return false; + } + private: /// @brief map of name and the converted function. std::map<const std::string, std::function<ConditionalData*(std::vector<ConditionalData*> &)>> mWlambda; @@ -227,7 +234,7 @@ class ConditionalInterpreter * @brief the registery for the lambda fuction * @see ConditionalRegisterFunction */ - ConditionalRegisterFunction mLambdaRegiter; + ConditionalRegisterFunction mLambdaRegister; std::vector<ConditionalData*> mResolution ; @@ -241,15 +248,25 @@ class ConditionalInterpreter } public: + + const std::string mKey; + /** * @brief Constructor * @param ConditionalExpressions The expression of the test to be performed on the nodes */ - ConditionalInterpreter(const std::string ConditionalExpressions); + ConditionalInterpreter(const std::string key,const std::string ConditionalExpressions); ~ConditionalInterpreter(){clearRes();} + /** + * @brief get the condition key + * @return the key + */ + + const std::string& getKey(); + /** * @brief Test a node depending of the ConditionalExpressions * @details the AST is visit using \ref visit() whith the $ init whit the nodeOp @@ -266,7 +283,7 @@ class ConditionalInterpreter */ void insertLambda(const std::string key,std::function<bool(Aidge::NodePtr)> f); - + bool isLambdaRegister(const std::string &key); ///// private: diff --git a/src/graphRegex/GraphFsmInterpreter.cpp b/src/graphRegex/GraphFsmInterpreter.cpp index 2984ab4fb3864244c9e32dbfcda9ef2ae080acf0..03e86487513065af47d91fc5265335bba456e64e 100644 --- a/src/graphRegex/GraphFsmInterpreter.cpp +++ b/src/graphRegex/GraphFsmInterpreter.cpp @@ -3,15 +3,24 @@ using namespace Aidge; -GraphFsmInterpreter::GraphFsmInterpreter(const std::string graphMatchExpr,std::map<std::string,std::shared_ptr<ConditionalInterpreter>> nodesCondition):mParser(graphMatchExpr){ +GraphFsmInterpreter::GraphFsmInterpreter(const std::string graphMatchExpr,std::vector<std::shared_ptr<ConditionalInterpreter>>&nodesCondition):mParser(graphMatchExpr){ mActGroupe = 0; - mNodesCondition = nodesCondition; + + for (const auto &obj : nodesCondition) { + if(mNodesCondition.find(obj->getKey()) ==mNodesCondition.end()){ + mNodesCondition[obj->getKey()] = obj; + }else{ + throw std::logic_error("GraphFsmInterpreter Bad Key" ); + } + } } std::shared_ptr<FsmGraph> GraphFsmInterpreter::interpret(void){ mActGroupe = 0; std::shared_ptr<AstNode<gRegexTokenTypes>> tree = mParser.parse(); - return visit(tree); + std::shared_ptr<FsmGraph> out = visit(tree); + return out; } + std::shared_ptr<FsmGraph> GraphFsmInterpreter::visit(std::shared_ptr<AstNode<gRegexTokenTypes>> AstTree){ std::vector<std::shared_ptr<AstNode<gRegexTokenTypes>>> nextAstNodes = AstTree->getChilds(); @@ -44,7 +53,7 @@ std::shared_ptr<FsmGraph> GraphFsmInterpreter::keyF(std::shared_ptr<AstNode<gReg std::shared_ptr<FsmNode> start = std::make_shared<FsmNode>(false,true); std::shared_ptr<FsmNode> valid = std::make_shared<FsmNode>(true,false); - std::shared_ptr<FsmGraph> graph = std::make_shared<FsmGraph>(); + std::shared_ptr<FsmGraph> graph = std::make_shared<FsmGraph>(mParser.getQuery()); std::shared_ptr<FsmEdge> edge; @@ -66,7 +75,7 @@ std::shared_ptr<FsmGraph> GraphFsmInterpreter::keyF(std::shared_ptr<AstNode<gReg std::shared_ptr<FsmGraph> GraphFsmInterpreter::sepF(std::shared_ptr<FsmGraph> leftFsm,std::shared_ptr<FsmGraph> rigthFsm){ size_t idxLeft = leftFsm->getNbSubFsm(); - rigthFsm->incOrigineAllNodeBy(idxLeft); + rigthFsm->incOriginAllNodeBy(idxLeft); leftFsm->unionG(rigthFsm); //the rigthFsm is no longer usfull return leftFsm; diff --git a/src/graphRegex/GraphLexer.cpp b/src/graphRegex/GraphLexer.cpp index 61214f96a090fef5d28cb0ce1a009644d9570880..f504ad025940c88058ce5949259c464ae2cedfb6 100644 --- a/src/graphRegex/GraphLexer.cpp +++ b/src/graphRegex/GraphLexer.cpp @@ -133,6 +133,11 @@ bool GraphLexer::isEnd(void){ return mPosition >= mRegularExpressions.length(); } + +const std::string GraphLexer::getQuery(){ + return mRegularExpressions; +} + std::runtime_error GraphLexer::badTokenError(const std::string& currentChars,std::size_t position){ std::ostringstream errorMessage; errorMessage << "\nBad syntax " << currentChars << " :\n" << mRegularExpressions << "\n"; diff --git a/src/graphRegex/GraphParser.cpp b/src/graphRegex/GraphParser.cpp index 5aa653c482dae82c2e9fa02bfc36b2ffc821785f..9c3d10114d777cf7755432a5723a3b70b81d37a1 100644 --- a/src/graphRegex/GraphParser.cpp +++ b/src/graphRegex/GraphParser.cpp @@ -9,6 +9,10 @@ mLexer(gRegexExpressions) } +const std::string GraphParser::getQuery(){ + return mLexer.getQuery(); +} + std::shared_ptr<AstNode<gRegexTokenTypes>> GraphParser::parse(void){ std::shared_ptr<AstNode<gRegexTokenTypes>> astTree = constructAstAllExpr(); diff --git a/src/graphRegex/GraphRegex.cpp b/src/graphRegex/GraphRegex.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ef0db8c88f3e753f9b9633b1ffb05bbec6d00424 --- /dev/null +++ b/src/graphRegex/GraphRegex.cpp @@ -0,0 +1,140 @@ +#include "aidge/graphRegex/GraphRegex.hpp" +using namespace Aidge; + + +void GraphRegex::setKeyFromGraph(std::shared_ptr<GraphView> ref){ + + for (const NodePtr& node : ref->getNodes()) { + std::string type = node->type(); + bool isIn = false; + for(const auto &test:mAllTest){ + if(test->getKey() == type){ + isIn = true; + break; + } + } + if(!isIn){ + mAllTest.push_back(std::make_shared<ConditionalInterpreter>(type,"getType($) =='" + type + "'")); + } + // auto it = mAllTest.find(type); + // if (it == mAllTest.end()) { + // mAllTest[type] = std::make_shared<ConditionalInterpreter>(type,"getType($) =='" + type + "'"); + // } + // //if the key exist it's ok, but not make 2 ConditionalInterpreter + } +} + + + +void GraphRegex::addQuery(const std::string query){ + mQuery.push_back(query); +} + + + +// Function to generate all combinations of n elements from a set +void GraphRegex::_generateCombinationsStart(const std::set<NodePtr>& elements, std::size_t n, std::size_t index, std::vector<NodePtr>& current, std::set<std::vector<NodePtr>>& combinations) { + if (n == 0) { + combinations.insert(current); + return; + } + for (auto it = elements.begin(); it != elements.end(); ++it) { + current.push_back(*it); + _generateCombinationsStart(elements, n - 1, index + 1, current, combinations); + current.pop_back(); + } +} + + +void GraphRegex::_findLargestCompatibleSet( + const std::vector<std::shared_ptr<MatchSolution>>& solutions, + std::set<std::shared_ptr<MatchSolution>>& currentSet, + std::set<std::shared_ptr<MatchSolution>>& largestSet, + size_t currentIndex +) { + if (currentIndex >= solutions.size()) { + if (currentSet.size() > largestSet.size()) { + largestSet = currentSet; + } + return; + } + + for (size_t i = currentIndex; i < solutions.size(); ++i) { + if (std::all_of(currentSet.begin(), currentSet.end(), + [&](const std::shared_ptr<MatchSolution>& solution) { + return solution->areCompatible(solutions[i]); + } + )) { + currentSet.insert(solutions[i]); + _findLargestCompatibleSet(solutions, currentSet, largestSet, i + 1); + currentSet.erase(solutions[i]); + } + } +} + +std::set<std::shared_ptr<MatchSolution>> GraphRegex::_findLargestCompatibleSet( + const std::vector<std::shared_ptr<MatchSolution>>& solutions +) { + std::set<std::shared_ptr<MatchSolution>> largestSet; + std::set<std::shared_ptr<MatchSolution>> currentSet; + _findLargestCompatibleSet(solutions, currentSet, largestSet, 0); + return largestSet; +} + + + +std::set<std::shared_ptr<MatchSolution>> GraphRegex::match(std::shared_ptr<GraphView> ref){ + + std::vector<std::shared_ptr<MatchSolution>> solutions = {}; + + for (const std::string& query : mQuery) { + + std::shared_ptr<GraphFsmInterpreter> fsmGenerator = std::make_shared<GraphFsmInterpreter>(query,mAllTest); + std::shared_ptr<FsmGraph> fsm = fsmGenerator->interpret(); + + // generate all the start possibility + std::size_t nb_startSt = fsm->getNbStart(); + std::set<std::vector<NodePtr>> combinations; + std::vector<NodePtr> current; + _generateCombinationsStart(ref->getNodes(), nb_startSt, 0, current, combinations); + + + // all start + for (const auto& combination : combinations) { + std::vector<std::shared_ptr<MatchSolution>> solution = fsm->test(combination); + solutions.insert(solutions.end(), solution.begin(), solution.end()); + } + } + return _findLargestCompatibleSet(solutions); +} + +void GraphRegex::setNodeKey(const std::string key, const std::string conditionalExpressions ){ + mAllTest.push_back(std::make_shared<ConditionalInterpreter>(key,conditionalExpressions)); + _majConditionalInterpreterLambda(); +} + + +void GraphRegex::setNodeKey(const std::string key,std::function<bool(NodePtr)> f){ + //we can applied to all key but it's not efficient + if(mAllLambda.find(key) != mAllLambda.end()){ + throw std::runtime_error(key + " is define"); + } + mAllLambda[key] = f; + _majConditionalInterpreterLambda(); +} + +void GraphRegex::_majConditionalInterpreterLambda(){ + + for (const auto& test : mAllTest) { + for (const auto& pair : mAllLambda) { + const std::string& key = pair.first; + const std::function<bool(NodePtr)>& lambda = pair.second; + + if(!test->isLambdaRegister(key)){ + test->insertLambda(key,lambda); + } + + } + } +} + diff --git a/src/graphRegex/matchFsm/FsmEdge.cpp b/src/graphRegex/matchFsm/FsmEdge.cpp index 593da06abe18576d435ae55718d379aa5b682d60..ab307e023209ab770fc63f0550811279bd42eb46 100644 --- a/src/graphRegex/matchFsm/FsmEdge.cpp +++ b/src/graphRegex/matchFsm/FsmEdge.cpp @@ -24,7 +24,7 @@ void FsmEdge::updateRelative( const std::map<size_t,int>& relativePos ){ std::shared_ptr<FsmNode> FsmEdge::getSourceNode(void){ return mNodeSource; } -void FsmEdge::reSetSouceNode(const std::shared_ptr<FsmNode>& newSource){ +void FsmEdge::reSetSourceNode(const std::shared_ptr<FsmNode>& newSource){ mNodeSource->rmEdge(shared_from_this()); mNodeSource = newSource; mNodeSource->addEdge(shared_from_this()); @@ -258,6 +258,11 @@ const std::string lexeme) std::string commonId = m[2]; size_t commonIdx = commonId.empty() ? 0 : std::stoi(commonId) + 1; std::string commonKey = edgeType + std::to_string(commonIdx); + + if(allTest.find(edgeType) == allTest.end()){ + throw std::invalid_argument("Bad Node Test " + edgeType ); + } + return std::make_shared<FsmEdgeCommon> (source, dest, allTest.at(edgeType), commonKey); } else { throw std::invalid_argument("error lexem COMMON " + lexeme); @@ -267,6 +272,11 @@ const std::string lexeme) std::smatch m; if (std::regex_match(lexeme, m, uniqueRegex)) { std::string edgeType = m[1]; + + if(allTest.find(edgeType) == allTest.end()){ + throw std::invalid_argument("Bad Node Test " + edgeType ); + } + return std::make_shared<FsmEdgeUnique>(source, dest, allTest.at(edgeType)); } else { throw std::invalid_argument("error lexem UNIQUE \"" + std::string(lexeme) +" eee\""); diff --git a/src/graphRegex/matchFsm/FsmGraph.cpp b/src/graphRegex/matchFsm/FsmGraph.cpp index 5a9f00d728cd2cd9f58c2228361f8393de2a3d9d..a56474e042cc44a68938b1d19e19a0c6841cb8cb 100644 --- a/src/graphRegex/matchFsm/FsmGraph.cpp +++ b/src/graphRegex/matchFsm/FsmGraph.cpp @@ -4,12 +4,13 @@ using namespace Aidge; -FsmGraph::FsmGraph(/* args */){ +FsmGraph::FsmGraph(const std::string query):mQuery(query){ } //TODO - std::shared_ptr<MatchResult> FsmGraph::test(std::vector<NodePtr>& startNodes){ + std::vector<std::shared_ptr<MatchSolution>> FsmGraph::test(const std::vector<NodePtr>& startNodes){ + std::vector<std::shared_ptr<Aidge::FsmNode>> startNodesFsm = getStartNodes(); if(startNodes.size() != startNodesFsm.size()){ throw std::runtime_error("bad number of Start nodes"); @@ -60,9 +61,9 @@ FsmGraph::FsmGraph(/* args */){ walks.swap(nextWalks); nextWalks.clear(); } - - - return std::make_shared<MatchResult>(allValidContext,getNbSubFsm()); + + MatchResult allMatch(allValidContext,getNbSubFsm(),mQuery,startNodes); + return allMatch.getSolutions(); } @@ -77,8 +78,8 @@ const std::set<std::shared_ptr<FsmEdge>>& FsmGraph::getEdge(void){ void FsmGraph::addEdge(std::shared_ptr<FsmEdge>& edge){ edge->updateWeak(); mEdges.insert(edge); - mAllOrigine.insert(edge->getDestNode()->getOrigine()); - mAllOrigine.insert(edge->getSourceNode()->getOrigine()); + mAllOrigin.insert(edge->getDestNode()->getOrigin()); + mAllOrigin.insert(edge->getSourceNode()->getOrigin()); } const std::vector<std::shared_ptr<FsmNode>> FsmGraph::getStartNodes(void){ @@ -151,19 +152,23 @@ void FsmGraph::mergeOneStartOneValid(const std::shared_ptr<FsmGraph> fsmGraph){ } std::size_t FsmGraph::getNbSubFsm(void){ - return mAllOrigine.size(); + return mAllOrigin.size(); +} + +std::size_t FsmGraph::getNbStart(void){ + return getStartNodes().size(); } -void FsmGraph::incOrigineAllNodeBy(std::size_t incr){ +void FsmGraph::incOriginAllNodeBy(std::size_t incr){ std::set<std::shared_ptr<FsmNode>> nodes = getNodes(); for(auto node :nodes){ - node->incOrigine(incr); + node->incOrigin(incr); } std::set<std::size_t> updatedOrigin; - for(auto origin : mAllOrigine){ + for(auto origin : mAllOrigin){ updatedOrigin.insert(origin + incr); } - mAllOrigine.swap(updatedOrigin); + mAllOrigin.swap(updatedOrigin); } void FsmGraph::_mergeNode(std::shared_ptr<FsmNode> source,std::shared_ptr<FsmNode> dest){ @@ -187,7 +192,7 @@ void FsmGraph::_mergeNode(std::shared_ptr<FsmNode> source,std::shared_ptr<FsmNod if(edge->getDestNode() == source ){ edge->reSetDestNode(dest); }else if(edge->getSourceNode() == source ){ - edge->reSetSouceNode(dest); + edge->reSetSourceNode(dest); } } diff --git a/src/graphRegex/matchFsm/FsmNode.cpp b/src/graphRegex/matchFsm/FsmNode.cpp index 84b4a0c3fdbe0730a12a2a62db9158e2538d646f..7bc4cf105b43a540bd0e9c686af35dd220611a09 100644 --- a/src/graphRegex/matchFsm/FsmNode.cpp +++ b/src/graphRegex/matchFsm/FsmNode.cpp @@ -53,11 +53,11 @@ const std::vector<std::shared_ptr<FsmRunTimeContext>> FsmNode::test( std::shared -std::size_t FsmNode::getOrigine(void){ - return mOrigineStm; +std::size_t FsmNode::getOrigin(void){ + return mOriginFsm; } -void FsmNode::incOrigine(std::size_t inc){ - mOrigineStm += inc; +void FsmNode::incOrigin(std::size_t inc){ + mOriginFsm += inc; } void FsmNode::rmEdge(std::shared_ptr<FsmEdge> edge){ mEdges.erase(edge); @@ -93,7 +93,7 @@ const std::set<std::weak_ptr<FsmEdge>,lex_compare<FsmEdge>>& FsmNode::getEdges(v } void FsmNode::setGroupe(std::size_t groupeIdx){ - mGroupeStm = groupeIdx; + mGroupeFsm = groupeIdx; } diff --git a/src/graphRegex/matchFsm/FsmRunTimeContext.cpp b/src/graphRegex/matchFsm/FsmRunTimeContext.cpp index 787cf2322a5b8e7001cdc59325345000dbb61553..ddf6a46cc7c75dc853d71ba98b051b4263a31164 100644 --- a/src/graphRegex/matchFsm/FsmRunTimeContext.cpp +++ b/src/graphRegex/matchFsm/FsmRunTimeContext.cpp @@ -155,7 +155,7 @@ void FsmRunTimeContext::setValid(NodePtr node,std::shared_ptr<ConditionalInterpr } std::size_t FsmRunTimeContext::getSubStmId(void){ - return mActState->getOrigine(); + return mActState->getOrigin(); } NodePtr FsmRunTimeContext::getCommonNodeFromIdx(std::size_t commonIdx){ @@ -207,7 +207,7 @@ std::set<NodePtr> FsmRunTimeContext::getValidNodesNoCommon(void){ return differenceSet; } -std::map<std::shared_ptr<ConditionalInterpreter>,std::set<NodePtr>> FsmRunTimeContext::getValid(void){ +std::map<std::shared_ptr<ConditionalInterpreter>,std::set<NodePtr>>& FsmRunTimeContext::getValid(void){ return mValidNodes; } diff --git a/src/graphRegex/matchFsm/MatchResult.cpp b/src/graphRegex/matchFsm/MatchResult.cpp index c35f1a7348e365baa8a27854ee6b0a833e342ee7..c871b3d0e22f3fa1f28b7bcea46ee8b9f61a3178 100644 --- a/src/graphRegex/matchFsm/MatchResult.cpp +++ b/src/graphRegex/matchFsm/MatchResult.cpp @@ -2,10 +2,63 @@ using namespace Aidge; + MatchSolution::MatchSolution(std::vector<std::shared_ptr<FsmRunTimeContext>>& precedence,const std::string query,const std::vector<NodePtr> startNode):mQueryFrom(query),mStartNode(startNode){ + //reformat the solution + for (const auto& context : precedence) { + for (const auto& pair : context->getValid()) { + + if(mSolution.find(pair.first->getKey()) == mSolution.end()){ + mSolution[pair.first->getKey()] = pair.second; + }else{ + mSolution[pair.first->getKey()].insert(pair.second.begin(), pair.second.end()); + } + } + } + } + + + const std::set<NodePtr> & MatchSolution::at(const std::string key){ + + return mSolution[key]; -MatchResult::MatchResult(std::vector<std::shared_ptr<FsmRunTimeContext>> allValid, std::size_t nbSubStm):mIdToRunTime(nbSubStm){ + } + + const std::set<NodePtr> MatchSolution::getAll(){ + + // Create a unique set to store all the elements + std::set<NodePtr> uniqueSet; + + // Iterate through the map and insert elements from each set into the unique set + for (const auto& pair : mSolution) { + const std::set<NodePtr>& nodeSet = pair.second; + + // Insert elements from the current set into the unique set + uniqueSet.insert(nodeSet.begin(), nodeSet.end()); + } + + return uniqueSet; + + } + + bool MatchSolution::areCompatible(std::shared_ptr<MatchSolution> solution){ + std::set<NodePtr> set1 = solution->getAll(); + std::set<NodePtr> set2 = getAll(); + std::set<NodePtr> intersection ; + std::set_intersection(set1.begin(), set1.end(), set2.begin(), set2.end(), std::inserter(intersection, intersection.begin())); + if (intersection.empty()) { + return true; + } + return false; + } + + + +//////////////////////////////// +// +//////////////////////////////// +MatchResult::MatchResult(std::vector<std::shared_ptr<FsmRunTimeContext>> allValid, std::size_t nbSubStm, +const std::string& query,const std::vector<NodePtr>& startNodes):mIdToRunTime(nbSubStm),mNbSubStm(nbSubStm){ mAllValid = allValid; - mNbSubStm = nbSubStm; //mIdToRunTimm for (const auto& contextPtr : allValid) { @@ -13,25 +66,26 @@ MatchResult::MatchResult(std::vector<std::shared_ptr<FsmRunTimeContext>> allVali } std::vector<std::shared_ptr<FsmRunTimeContext>> precedence; - //make all solution posible - _generateCombinationd(0,precedence); + //make all solution possible + _generateCombination(0,precedence,query,startNodes); //sort by solution number of elements - std::sort(mSolve.begin(), mSolve.end(), [](const std::set<NodePtr>& set1, const std::set<NodePtr>& set2) { - return set1.size() < set2.size(); + std::sort(mSolve.begin(), mSolve.end(), [](std::shared_ptr<MatchSolution>& set1, std::shared_ptr<MatchSolution>& set2) { + return set1->getAll().size() < set2->getAll().size(); }); } -void MatchResult::_generateCombinationd( std::size_t idxSubStm, std::vector<std::shared_ptr<FsmRunTimeContext>>& precedence){ +void MatchResult::_generateCombination( std::size_t idxSubStm, std::vector<std::shared_ptr<FsmRunTimeContext>>& precedence, +const std::string& query,const std::vector<NodePtr>& startNodes){ //it's end , we are below the number of stm if (idxSubStm == mNbSubStm) { - //precedence containe a liste of FSM compatible, we just need to - //check if all the node have been valide by at least one contetext + //precedence contain a list of FSM compatible, we just need to + //check if all the nodes have been validated by at least one context - //1) make the set of all node for the comput graph that are valide in all the FsmRunTimeContext + //1) make the set of all node for the compute graph that are valid in all the FsmRunTimeContext std::set<NodePtr> validNode; std::set<NodePtr> rejectNode; for (const auto& contextPtr : precedence) { @@ -40,11 +94,11 @@ void MatchResult::_generateCombinationd( std::size_t idxSubStm, std::vector<std: std::set<NodePtr> tmpR = contextPtr->getRejectedNodes(); rejectNode.insert(tmpR.begin(),tmpR.end()); } - // 2) all RejectedNodes need to be valide by an others stm + // 2) all RejectedNodes need to be valid by an others stm // if it's not the case the match is not valid if(std::includes(validNode.begin(), validNode.end(), rejectNode.begin(), rejectNode.end())){ //we can save the solution - mSolve.push_back(validNode); + mSolve.push_back(std::make_shared<MatchSolution>(precedence,query,startNodes)); } precedence.pop_back(); return; @@ -55,10 +109,10 @@ void MatchResult::_generateCombinationd( std::size_t idxSubStm, std::vector<std: { if(idxSubStm == 0){ precedence.push_back(contextPtrOneFsm); - _generateCombinationd(idxSubStm+1,precedence); + _generateCombination(idxSubStm+1,precedence,query,startNodes); }else{ - //test if the new context is compatible whith all the context in the precedence + //test if the new context is compatible with all the context in the precedence // bool compatibleSolutionFsm = true; for (const auto& contextPtrOfOtherFsm : precedence) { @@ -70,7 +124,7 @@ void MatchResult::_generateCombinationd( std::size_t idxSubStm, std::vector<std: if(compatibleSolutionFsm){ precedence.push_back(contextPtrOneFsm); - _generateCombinationd(idxSubStm+1,precedence); + _generateCombination(idxSubStm+1,precedence,query,startNodes); } } @@ -83,11 +137,16 @@ void MatchResult::_generateCombinationd( std::size_t idxSubStm, std::vector<std: } -std::set<NodePtr> MatchResult::getBiggerSolution(void){ +std::shared_ptr<MatchSolution> MatchResult::getBiggerSolution(void){ + if(mSolve.empty()){ - return std::set<NodePtr>(); + return nullptr; }else{ return mSolve[0]; } +} + +std::vector<std::shared_ptr<MatchSolution>> MatchResult::getSolutions(void){ + return mSolve; } \ No newline at end of file diff --git a/src/nodeTester/ConditionalInterpreter.cpp b/src/nodeTester/ConditionalInterpreter.cpp index e01bdd76a28576451a1a09202d5fd1e87a4856e5..59515d0acd77a6202e698ca1e8f1bb28b105266c 100644 --- a/src/nodeTester/ConditionalInterpreter.cpp +++ b/src/nodeTester/ConditionalInterpreter.cpp @@ -18,19 +18,29 @@ using namespace Aidge; } } + ////////////////////// //ConditionalInterpreter /////////////////////// - ConditionalInterpreter::ConditionalInterpreter(const std::string ConditionalExpressions) - :mLambdaRegiter() + ConditionalInterpreter::ConditionalInterpreter(const std::string key,const std::string ConditionalExpressions) + :mLambdaRegister(),mKey(key) { ConditionalParser conditionalParser = ConditionalParser(ConditionalExpressions); mTree = conditionalParser.parse(); + ///lambda by default - mLambdaRegiter.insert("getType",+[](NodePtr NodeOp){return NodeOp->type();}); + mLambdaRegister.insert("getType",+[](NodePtr NodeOp){return NodeOp->type();}); } + + bool ConditionalInterpreter::isLambdaRegister(const std::string &key){ + return mLambdaRegister.isLambdaRegister(key); + } + + const std::string& ConditionalInterpreter::getKey(){ + return mKey; + } bool ConditionalInterpreter::test( const NodePtr nodeOp) @@ -39,16 +49,16 @@ using namespace Aidge; clearRes(); try{ std::vector<ConditionalData*> r = visit({mTree},nodeOp); - - if (mResolution.size() != 1){ - throw std::runtime_error("Multi-output interpretation output"); - }else{ - if (!mResolution[0]->isTypeEqualTo<bool>()){ - throw std::runtime_error("TEST OUT MUST BE A BOOL "); + + if (mResolution.size() != 1){ + throw std::runtime_error("Multi output interpretation output"); }else{ - return mResolution[0]->getValue<bool>(); + if (!mResolution[0]->isTypeEqualTo<bool>()){ + throw std::runtime_error("TEST OUT MUST BE A BOOL "); + }else{ + return mResolution[0]->getValue<bool>(); + } } - } }catch(const std::exception& e){ std::ostringstream errorMessage; @@ -58,7 +68,7 @@ using namespace Aidge; } void ConditionalInterpreter::insertLambda(const std::string key,std::function<bool(Aidge::NodePtr)> f){ - mLambdaRegiter.insert<std::function<bool(Aidge::NodePtr)> >(key, f); + mLambdaRegister.insert<std::function<bool(Aidge::NodePtr)> >(key, f); } ///// @@ -169,8 +179,8 @@ using namespace Aidge; } }catch(const std::exception& e){ std::ostringstream errorMessage; - errorMessage << "Error in visiting AST for node"<< nodeOp->name() << "\n\t" << e.what() << "\n"; - throw std::runtime_error(errorMessage.str()); + errorMessage << "Error in visiting AST for node "<< nodeOp->name() << "\n\t" << e.what() << "\n"; + throw std::runtime_error(errorMessage.str()); } } @@ -210,7 +220,7 @@ using namespace Aidge; //if the lambda have input ConditionalData* data; try { - data = mLambdaRegiter.run(node->getValue(),mResolution); + data = mLambdaRegister.run(node->getValue(),mResolution); } catch (const std::exception& e) { std::ostringstream errorMessage; errorMessage << "Error in conditional interpretation when run the "<< node->getValue() <<" Lambda\n\t" << e.what() << "\n"; @@ -230,7 +240,7 @@ using namespace Aidge; auto b = mResolution[1]; if (a->getType() != b->getType()){ - throw std::runtime_error("EQ Unsuported between type :" + a->getType() +" "+ b->getType()); + throw std::runtime_error("EQ Unsupported between type :" + a->getType() +" "+ b->getType()); } @@ -262,7 +272,7 @@ using namespace Aidge; auto b = mResolution[1]; if (a->getType() != b->getType()){ - throw std::runtime_error("NEQ Unsuported between type :" + a->getType() +" "+ b->getType()); + throw std::runtime_error("NEQ Unsupported between type :" + a->getType() +" "+ b->getType()); } ConditionalData* data = new ConditionalData; diff --git a/unit_tests/CMakeLists.txt b/unit_tests/CMakeLists.txt index 9d9f81516b0cd2611484ee9e3e06e838833200db..5ccfa3832a8ce2522f18ab07e11a78cf8b462a40 100644 --- a/unit_tests/CMakeLists.txt +++ b/unit_tests/CMakeLists.txt @@ -10,6 +10,8 @@ FetchContent_MakeAvailable(Catch2) file(GLOB_RECURSE src_files "*.cpp") +#file(GLOB_RECURSE src_files "graphRegex/Test_GraphRegex.cpp") + add_executable(tests${module_name} ${src_files}) target_link_libraries(tests${module_name} PUBLIC ${module_name}) diff --git a/unit_tests/graphRegex/Test_Fsm.cpp b/unit_tests/graphRegex/Test_Fsm.cpp index e5950f21b323f07b380ae95f70637ca48a173481..c011a50455e9e21f3df66c3ed46a835bed5346b3 100644 --- a/unit_tests/graphRegex/Test_Fsm.cpp +++ b/unit_tests/graphRegex/Test_Fsm.cpp @@ -14,10 +14,10 @@ using namespace Aidge; TEST_CASE("matchFSM", "FsmEdge") { - std::shared_ptr<FsmNode> nodeA = std::make_shared<FsmNode>(true,false); - std::shared_ptr<FsmNode> nodeB = std::make_shared<FsmNode>(false,true); - std::shared_ptr<ConditionalInterpreter> toTest = std::make_shared<ConditionalInterpreter>("true==true"); - FsmEdgeUnique EdgeToTest(nodeA,nodeB,toTest); + std::shared_ptr<FsmNode> nodeA = std::make_shared<FsmNode>(true,false); + std::shared_ptr<FsmNode> nodeB = std::make_shared<FsmNode>(false,true); + std::shared_ptr<ConditionalInterpreter> toTest = std::make_shared<ConditionalInterpreter>("A","true==true"); + FsmEdgeUnique EdgeToTest(nodeA,nodeB,toTest); SECTION("FsmEdgeUnique constructor") { REQUIRE(EdgeToTest.getSourceNode() == nodeA); @@ -28,7 +28,7 @@ TEST_CASE("matchFSM", "FsmEdge") { SECTION("FsmEdgeCommon constructor") { std::shared_ptr<FsmNode> nodeA = std::make_shared<FsmNode>(true,false); std::shared_ptr<FsmNode> nodeB = std::make_shared<FsmNode>(false,true); - std::shared_ptr<ConditionalInterpreter> toTest = std::make_shared<ConditionalInterpreter>("true==true"); + std::shared_ptr<ConditionalInterpreter> toTest = std::make_shared<ConditionalInterpreter>("A","true==true"); FsmEdgeCommon EdgeToTest(nodeA,nodeB,toTest,"A"); @@ -40,7 +40,7 @@ TEST_CASE("matchFSM", "FsmEdge") { SECTION("FsmEdgeRef constructor") { std::shared_ptr<FsmNode> nodeA = std::make_shared<FsmNode>(true,false); std::shared_ptr<FsmNode> nodeB = std::make_shared<FsmNode>(false,true); - std::shared_ptr<ConditionalInterpreter> toTest = std::make_shared<ConditionalInterpreter>("true==true"); + std::shared_ptr<ConditionalInterpreter> toTest = std::make_shared<ConditionalInterpreter>("A","true==true"); FsmEdgeRef EdgeToTest(nodeA,nodeB,0,-1); @@ -52,7 +52,7 @@ TEST_CASE("matchFSM", "FsmEdge") { SECTION("FsmEdgeEmpty constructor") { std::shared_ptr<FsmNode> nodeA = std::make_shared<FsmNode>(true,false); std::shared_ptr<FsmNode> nodeB = std::make_shared<FsmNode>(false,true); - std::shared_ptr<ConditionalInterpreter> toTest = std::make_shared<ConditionalInterpreter>("true==true"); + std::shared_ptr<ConditionalInterpreter> toTest = std::make_shared<ConditionalInterpreter>("A","true==true"); FsmEdgeEmpty EdgeToTest(nodeA,nodeB); @@ -65,9 +65,9 @@ TEST_CASE("matchFSM", "FsmEdge") { SECTION("FsmEdgeFactory"){ std::map<std::string, std::shared_ptr<ConditionalInterpreter>> allTest = { - {"A",std::make_shared<ConditionalInterpreter>("true==true")}, - {"B",std::make_shared<ConditionalInterpreter>("true==true")}, - {"C",std::make_shared<ConditionalInterpreter>("true==true")} + {"A",std::make_shared<ConditionalInterpreter>("A","true==true")}, + {"B",std::make_shared<ConditionalInterpreter>("B","true==true")}, + {"C",std::make_shared<ConditionalInterpreter>("C","true==true")} }; // make(std::shared_ptr<FsmNode> source, std::shared_ptr<FsmNode> dest, @@ -103,11 +103,11 @@ TEST_CASE("matchFSM", "FsmEdge") { std::shared_ptr<FsmNode> nodeC = std::make_shared<FsmNode>(false,true); //make the edges - std::shared_ptr<ConditionalInterpreter> toTest = std::make_shared<ConditionalInterpreter>("true==true"); + std::shared_ptr<ConditionalInterpreter> toTest = std::make_shared<ConditionalInterpreter>("A","true==true"); std::shared_ptr<FsmEdge> edgeAB = std::make_shared<FsmEdgeUnique>(nodeA,nodeB,toTest); std::shared_ptr<FsmEdge> edgeBC = std::make_shared<FsmEdgeUnique>(nodeB,nodeC,toTest); - std::shared_ptr<FsmGraph> graph = std::make_shared<FsmGraph>(); + std::shared_ptr<FsmGraph> graph = std::make_shared<FsmGraph>(""); graph->addEdge(edgeAB); graph->addEdge(edgeBC); @@ -120,7 +120,7 @@ TEST_CASE("matchFSM", "FsmEdge") { SECTION("graph merge") { - std::shared_ptr<ConditionalInterpreter> toTest = std::make_shared<ConditionalInterpreter>("true==true"); + std::shared_ptr<ConditionalInterpreter> toTest = std::make_shared<ConditionalInterpreter>("A","true==true"); //make the nodes std::shared_ptr<FsmNode> nodeA = std::make_shared<FsmNode>(false,true); @@ -132,7 +132,7 @@ TEST_CASE("matchFSM", "FsmEdge") { std::shared_ptr<FsmEdge> edgeAB = std::make_shared<FsmEdgeUnique>(nodeA,nodeB,toTest); std::shared_ptr<FsmEdge> edgeBC = std::make_shared<FsmEdgeUnique>(nodeB,nodeC,toTest); - std::shared_ptr<FsmGraph> graph = std::make_shared<FsmGraph>(); + std::shared_ptr<FsmGraph> graph = std::make_shared<FsmGraph>(""); graph->addEdge(edgeAB); graph->addEdge(edgeBC); @@ -149,7 +149,7 @@ TEST_CASE("matchFSM", "FsmEdge") { std::shared_ptr<FsmEdge> edge2AB = std::make_shared<FsmEdgeUnique>(node2A,node2B,toTest); std::shared_ptr<FsmEdge> edge2BC = std::make_shared<FsmEdgeUnique>(node2B,node2C,toTest); - std::shared_ptr<FsmGraph> graph2 = std::make_shared<FsmGraph>(); + std::shared_ptr<FsmGraph> graph2 = std::make_shared<FsmGraph>(""); graph2->addEdge(edge2AB); @@ -184,7 +184,7 @@ TEST_CASE("matchFSM", "FsmEdge") { // std::shared_ptr<ConditionalInterpreter> toTest = std::make_shared<ConditionalInterpreter>("true==true"); // std::shared_ptr<FsmEdgeUnique> edge = std::make_shared<FsmEdgeUnique>(nodeA,nodeB,toTest); -// std::shared_ptr<FsmGraph> graph = std::make_shared<FsmGraph>(); +// std::shared_ptr<FsmGraph> graph = std::make_shared<FsmGraph>(""); // graph->addEdge(edge); diff --git a/unit_tests/graphRegex/Test_FsmMatch.cpp b/unit_tests/graphRegex/Test_FsmMatch.cpp index 1fe75be1a47033f75af7ccc4dc5202774444cd10..4b0a009a4b142f56334b133919025e5e83b7435a 100644 --- a/unit_tests/graphRegex/Test_FsmMatch.cpp +++ b/unit_tests/graphRegex/Test_FsmMatch.cpp @@ -14,14 +14,14 @@ using namespace Aidge; TEST_CASE("FsmMatch") { SECTION("Construction") { - std::map<std::string,std::shared_ptr<ConditionalInterpreter>> allTest = { - {"A",std::make_shared<ConditionalInterpreter>("isConv($)==true")}, - {"B",std::make_shared<ConditionalInterpreter>("isConv($)==true")}, - {"C",std::make_shared<ConditionalInterpreter>("true==true")} + std::vector<std::shared_ptr<ConditionalInterpreter>> allTest = { + std::make_shared<ConditionalInterpreter>("A","isConv($)==true"), + std::make_shared<ConditionalInterpreter>("B","isConv($)==true"), + std::make_shared<ConditionalInterpreter>("C","true==true") }; - allTest["A"]->insertLambda("isConv",+[](NodePtr NodeOp){return NodeOp->type() == "Conv";}); - allTest["B"]->insertLambda("isConv",+[](NodePtr NodeOp){return NodeOp->type() == "Conv";}); + allTest[0]->insertLambda("isConv",+[](NodePtr NodeOp){return NodeOp->type() == "Conv";}); + allTest[1]->insertLambda("isConv",+[](NodePtr NodeOp){return NodeOp->type() == "Conv";}); std::shared_ptr<GraphFsmInterpreter> fsmGenerator = std::make_shared<GraphFsmInterpreter>("A->A",allTest); std::shared_ptr<FsmGraph> fsm = fsmGenerator->interpret(); @@ -41,14 +41,14 @@ TEST_CASE("FsmMatch") { g1->addChild(conv1, "c"); - REQUIRE(allTest["A"]->test(conv) == true); - REQUIRE(allTest["B"]->test(conv) == true); + REQUIRE(allTest[0]->test(conv) == true); + REQUIRE(allTest[1]->test(conv) == true); std::vector<std::shared_ptr<Node>> startNodes = {conv}; auto result = fsm->test(startNodes); - REQUIRE( result->getBiggerSolution() == std::set<NodePtr>{conv,conv1}); + REQUIRE( result[0]->getAll() == std::set<NodePtr>{conv,conv1}); } @@ -70,19 +70,20 @@ TEST_CASE("FsmMatch") { ///////////// - std::map<std::string,std::shared_ptr<ConditionalInterpreter>> allTest = { - {"A",std::make_shared<ConditionalInterpreter>("isConv($)==true")}, - {"B",std::make_shared<ConditionalInterpreter>("isFc($)==true")} + std::vector<std::shared_ptr<ConditionalInterpreter>> allTest = { + std::make_shared<ConditionalInterpreter>("A","isConv($)==true"), + std::make_shared<ConditionalInterpreter>("B","isFc($)==true") }; - allTest["A"]->insertLambda("isConv",+[](NodePtr NodeOp){return NodeOp->type() == "Conv";}); - allTest["B"]->insertLambda("isFc",+[](NodePtr NodeOp){return NodeOp->type() == "Fc";}); + allTest[0]->insertLambda("isConv",+[](NodePtr NodeOp){return NodeOp->type() == "Conv";}); + allTest[1]->insertLambda("isFc",+[](NodePtr NodeOp){return NodeOp->type() == "Fc";}); std::shared_ptr<GraphFsmInterpreter> fsmGenerator = std::make_shared<GraphFsmInterpreter>("A#->A; A#->B",allTest); std::shared_ptr<FsmGraph> fsm = fsmGenerator->interpret(); std::vector<std::shared_ptr<Node>> startNodes = {conv,conv}; auto result = fsm->test(startNodes); - REQUIRE( result->getBiggerSolution() == std::set<NodePtr>{conv,conv1,conv2}); + + REQUIRE( result[0]->getAll() == std::set<NodePtr>{conv,conv1,conv2}); } diff --git a/unit_tests/graphRegex/Test_GraphFsmInterpreter.cpp b/unit_tests/graphRegex/Test_GraphFsmInterpreter.cpp index 9ce090506c9a61abd928b3ae590ee838afb05999..e789677d44efa68071017a9832fa01b5ed340f75 100644 --- a/unit_tests/graphRegex/Test_GraphFsmInterpreter.cpp +++ b/unit_tests/graphRegex/Test_GraphFsmInterpreter.cpp @@ -8,10 +8,10 @@ using namespace Aidge; TEST_CASE("GraphFsmInterpreter", "GraphFsmInterpreter") { SECTION("Construction") { - std::map<std::string,std::shared_ptr<ConditionalInterpreter>> allTest = { - {"A",std::make_shared<ConditionalInterpreter>("true==true")}, - {"B",std::make_shared<ConditionalInterpreter>("true==true")}, - {"C",std::make_shared<ConditionalInterpreter>("true==true")} + std::vector<std::shared_ptr<ConditionalInterpreter>> allTest = { + std::make_shared<ConditionalInterpreter>("A","true==true"), + std::make_shared<ConditionalInterpreter>("B","true==true"), + std::make_shared<ConditionalInterpreter>("C","true==true") }; //GraphFsmInterpreter("A->B",allTest); diff --git a/unit_tests/graphRegex/Test_GraphRegex.cpp b/unit_tests/graphRegex/Test_GraphRegex.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b30560ea3ea696821d2422bf760a11973a104e85 --- /dev/null +++ b/unit_tests/graphRegex/Test_GraphRegex.cpp @@ -0,0 +1,84 @@ + +#include <catch2/catch_test_macros.hpp> +#include "aidge/graphRegex/GraphRegex.hpp" + +#include "aidge/operator/Conv.hpp" +#include "aidge/operator/GenericOperator.hpp" + +using namespace Aidge; + +TEST_CASE("GraphRegexUser") { + + SECTION("INIT") { + + const std::string query = "Conv->FC"; + + std::shared_ptr<GraphRegex> sut = std::make_shared<GraphRegex>(); + + std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); + std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 1, 1, "c"); + std::shared_ptr<Node> fc = GenericOperator("FC", 1, 1, 1, "c1"); + std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 1, 1, "c2"); + std::shared_ptr<Node> fc2 = GenericOperator("FC", 1, 1, 1, "c3"); + + g1->add(conv); + g1->addChild(fc, "c"); + g1->addChild(conv2, "c1"); + g1->addChild(fc2, "c2"); + + + sut->setKeyFromGraph(g1); + sut->addQuery(query); + + for (const auto& solution : sut->match(g1)) { + + REQUIRE(solution->getQuery() == query); + if(solution->getStartNode() == std::vector<NodePtr>{conv}){ + REQUIRE(solution->at("Conv") == std::set<NodePtr>{conv} ); + REQUIRE(solution->at("FC") == std::set<NodePtr>{fc} ); + }else if (solution->getStartNode() == std::vector<NodePtr>{conv2}) + { + REQUIRE(solution->at("Conv") == std::set<NodePtr>{conv2} ); + REQUIRE(solution->at("FC") == std::set<NodePtr>{fc2} ); + } + } + //REQUIRE( sut->match(g1)[1]->getAll() == std::set<NodePtr>{conv,fc}); + + } + + SECTION("CC") { + + + + std::shared_ptr<GraphRegex> sut = std::make_shared<GraphRegex>(); + + + std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); + std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 1, 1, "c"); + std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 1, 1, "c1"); + std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 1, 1, "c2"); + std::shared_ptr<Node> conv3 = GenericOperator("Conv", 1, 1, 1, "c3"); + + g1->add(conv); + g1->addChild(conv1, "c"); + g1->addChild(conv2, "c1"); + g1->addChild(conv3, "c2"); + + + sut->setKeyFromGraph(g1); + + const std::string query = "Conv->Conv"; + const std::string query2 = "Conv->FC"; + + sut->setNodeKey("FC","getType($) =='FC'"); + + sut->addQuery(query); + sut->addQuery(query2); + + + for (const auto& solution : sut->match(g1)) { + REQUIRE(solution->getQuery() == query); + } + + } +} \ No newline at end of file diff --git a/unit_tests/nodeTester/Test_ConditionalInterpreter.cpp b/unit_tests/nodeTester/Test_ConditionalInterpreter.cpp index 8b502fb546e2f1396b629ebc78bc1bd4d67842e2..6143b7e3d8c4331c178afa6267de723cbea7dfdb 100644 --- a/unit_tests/nodeTester/Test_ConditionalInterpreter.cpp +++ b/unit_tests/nodeTester/Test_ConditionalInterpreter.cpp @@ -13,7 +13,7 @@ TEST_CASE("ConditionalInterpreter", "ConditionalInterpreter") { SECTION("custom Lambda") { const std::string test = " !toto($) == true " ; - ConditionalInterpreter conditionalParser = ConditionalInterpreter(test); + ConditionalInterpreter conditionalParser = ConditionalInterpreter("A",test); conditionalParser.insertLambda("toto",+[](NodePtr NodeOp){return false;}); std::shared_ptr<Node> nodeOp = GenericOperator("conv", 0, 0, 0, "Gop1"); @@ -24,7 +24,7 @@ TEST_CASE("ConditionalInterpreter", "ConditionalInterpreter") { SECTION("syntax error") { const std::string test = "'A' == 'A' ,&& "; - REQUIRE_THROWS_AS( ConditionalInterpreter(test), std::runtime_error); + REQUIRE_THROWS_AS( ConditionalInterpreter("A",test), std::runtime_error); } @@ -32,7 +32,7 @@ TEST_CASE("ConditionalInterpreter", "ConditionalInterpreter") { SECTION("test false int ") { const std::string test = " 10 == 11 " ; - ConditionalInterpreter conditionalParser = ConditionalInterpreter(test); + ConditionalInterpreter conditionalParser = ConditionalInterpreter("A",test); std::shared_ptr<Node> nodeOp = GenericOperator("conv", 0, 0, 0, "Gop1"); bool result = conditionalParser.test(nodeOp); REQUIRE(result == false); @@ -40,7 +40,7 @@ TEST_CASE("ConditionalInterpreter", "ConditionalInterpreter") { SECTION("test true int ") { const std::string test = " 42 == 42 " ; - ConditionalInterpreter conditionalParser = ConditionalInterpreter(test); + ConditionalInterpreter conditionalParser = ConditionalInterpreter("A",test); std::shared_ptr<Node> nodeOp = GenericOperator("conv", 0, 0, 0, "Gop1"); bool result = conditionalParser.test(nodeOp); REQUIRE(result == true); @@ -48,7 +48,7 @@ TEST_CASE("ConditionalInterpreter", "ConditionalInterpreter") { SECTION("test false str ") { const std::string test = " 'toto' == 'Corgi' " ; - ConditionalInterpreter conditionalParser = ConditionalInterpreter(test); + ConditionalInterpreter conditionalParser = ConditionalInterpreter("A",test); std::shared_ptr<Node> nodeOp = GenericOperator("conv", 0, 0, 0, "Gop1"); bool result = conditionalParser.test(nodeOp); REQUIRE(result == false); @@ -57,7 +57,7 @@ TEST_CASE("ConditionalInterpreter", "ConditionalInterpreter") { SECTION("test true str ") { const std::string test = " 'Corgi' == 'Corgi' " ; - ConditionalInterpreter conditionalParser = ConditionalInterpreter(test); + ConditionalInterpreter conditionalParser = ConditionalInterpreter("A",test); std::shared_ptr<Node> nodeOp = GenericOperator("conv", 0, 0, 0, "Gop1"); bool result = conditionalParser.test(nodeOp); REQUIRE(result == true);