From 83339ba3b6edfddd45fd387388377ca99e019ac0 Mon Sep 17 00:00:00 2001 From: vl241552 <vincent.lorrain@cea.fr> Date: Tue, 19 Sep 2023 14:55:47 +0000 Subject: [PATCH] [graphRegex] Add unitest and fix --- include/aidge/graph/Node.hpp | 12 ++++ include/aidge/graphRegex/matchFsm/FsmEdge.hpp | 22 +++++-- .../aidge/graphRegex/matchFsm/FsmGraph.hpp | 7 +- .../graphRegex/matchFsm/FsmRunTimeContext.hpp | 1 - .../aidge/graphRegex/matchFsm/MatchResult.hpp | 5 +- src/graph/Node.cpp | 29 +++++++++ src/graphRegex/matchFsm/FsmEdge.cpp | 40 +++++++----- src/graphRegex/matchFsm/FsmGraph.cpp | 41 ++++++++++-- src/graphRegex/matchFsm/FsmNode.cpp | 27 ++++++-- src/graphRegex/matchFsm/FsmRunTimeContext.cpp | 21 +++++- src/nodeTester/ConditionalParser.cpp | 2 +- unit_tests/graph/Test_get.cpp | 49 ++++++++++++++ unit_tests/graphRegex/Test_Fsm.cpp | 15 +++-- unit_tests/graphRegex/Test_FsmMatch.cpp | 64 +++++++++++++++++++ .../graphRegex/Test_GraphFsmInterpreter.cpp | 15 +++++ 15 files changed, 302 insertions(+), 48 deletions(-) create mode 100644 unit_tests/graph/Test_get.cpp create mode 100644 unit_tests/graphRegex/Test_FsmMatch.cpp diff --git a/include/aidge/graph/Node.hpp b/include/aidge/graph/Node.hpp index 11def52db..799b8b3e0 100644 --- a/include/aidge/graph/Node.hpp +++ b/include/aidge/graph/Node.hpp @@ -350,6 +350,18 @@ public: */ void resetConnections(bool includeLearnableParam = false); + + /** + * @brief Get the set of pointers to connected node at a distance of a delta. + * @details the recution are cut + * Return a nullptr is nofing found. + * @param delta Input delta. + * @return std::shared_ptr<Node> + */ + + std::set<NodePtr> getNodeDelta(int delta,std::set<Aidge::NodePtr> nodeSee); + + private: /////////////////////////////////////////////////////// // OPERATORS diff --git a/include/aidge/graphRegex/matchFsm/FsmEdge.hpp b/include/aidge/graphRegex/matchFsm/FsmEdge.hpp index 995f14789..bb13e0d2a 100644 --- a/include/aidge/graphRegex/matchFsm/FsmEdge.hpp +++ b/include/aidge/graphRegex/matchFsm/FsmEdge.hpp @@ -36,14 +36,20 @@ namespace Aidge{ * @brief the ptr on the source node */ std::shared_ptr<FsmNode> mNodeSource; - /** + /** * @brief the ptr on the dest node */ std::shared_ptr<FsmNode> mNodeDest; + /** + * @brief the weak ptr + */ + std::weak_ptr<FsmEdge> weakPtr; public: - FsmEdge(std::shared_ptr<FsmNode> source,std::shared_ptr<FsmNode> dest, const std::shared_ptr<ConditionalInterpreter> toTest); + FsmEdge(std::shared_ptr<FsmNode>& source,std::shared_ptr<FsmNode>& dest, const std::shared_ptr<ConditionalInterpreter> toTest); virtual ~FsmEdge() =default; + FsmEdge() : weakPtr(shared_from_this()) {} + /** * @brief test is the validation of the node , it must be deffine for all types of edge @@ -102,6 +108,11 @@ namespace Aidge{ * @see ConditionalInterpreter */ const std::shared_ptr<ConditionalInterpreter> mToTest; + + /** + * @brief update week ptr for the node, TODO best + */ + void updateWeak(void); }; /** @@ -111,8 +122,7 @@ namespace Aidge{ { public: - FsmEdgeUnique(std::shared_ptr<FsmNode> source,std::shared_ptr<FsmNode> dest, const std::shared_ptr<ConditionalInterpreter> toTest); - //~FsmEdgeUnique() override {} + FsmEdgeUnique(std::shared_ptr<FsmNode>& source,std::shared_ptr<FsmNode>& dest, const std::shared_ptr<ConditionalInterpreter> toTest); const EdgeTestResult test(const std::shared_ptr<FsmRunTimeContext> stmContext) override; }; @@ -139,7 +149,7 @@ namespace Aidge{ * @details during construction, * the node key found by the lexer is converted to a unique id and the relative positions are updated. */ - FsmEdgeCommon(std::shared_ptr<FsmNode> source,std::shared_ptr<FsmNode> dest, const std::shared_ptr<ConditionalInterpreter> toTest, const std::string commonKey); + FsmEdgeCommon(std::shared_ptr<FsmNode>& source,std::shared_ptr<FsmNode>& dest, const std::shared_ptr<ConditionalInterpreter> toTest, const std::string commonKey); // ~FsmEdgeCommon() override {} const EdgeTestResult test(const std::shared_ptr<FsmRunTimeContext> stmContext) override; bool isCommon(void) override; @@ -164,7 +174,7 @@ namespace Aidge{ */ const int mdeltaCommonIdx; public: - FsmEdgeRef(std::shared_ptr<FsmNode> source,std::shared_ptr<FsmNode> dest, const size_t refCommonIdx,const int deltaCommonIdx); + FsmEdgeRef(std::shared_ptr<FsmNode>& source,std::shared_ptr<FsmNode>& dest, const size_t refCommonIdx,const int deltaCommonIdx); //~FsmEdgeRef() override {} const EdgeTestResult test(const std::shared_ptr<FsmRunTimeContext> stmContext) override; diff --git a/include/aidge/graphRegex/matchFsm/FsmGraph.hpp b/include/aidge/graphRegex/matchFsm/FsmGraph.hpp index a5257eb44..d2eb54847 100644 --- a/include/aidge/graphRegex/matchFsm/FsmGraph.hpp +++ b/include/aidge/graphRegex/matchFsm/FsmGraph.hpp @@ -17,13 +17,16 @@ namespace Aidge{ class FsmGraph { private: + /** + * @brief all node origine + */ std::set<std::size_t> mAllOrigine; std::set<std::shared_ptr<FsmEdge>> mEdges; public: FsmGraph(/* args */); virtual ~FsmGraph() = default; -std::vector<std::shared_ptr<MatchResult>> test(std::vector<NodePtr>& StartNodes); +std::shared_ptr<MatchResult> test(std::vector<NodePtr>& StartNodes); @@ -31,7 +34,7 @@ const std::set<std::shared_ptr<FsmEdge>>& getEdge(void); /** * @brief add edge in the graph, as FsmEdge know the source and dest FsmNode these nodes are also add to the graph */ -void addEdge(std::shared_ptr<FsmEdge> edge); +void addEdge(std::shared_ptr<FsmEdge>& edge); /** * @brief get the liste of the starting states diff --git a/include/aidge/graphRegex/matchFsm/FsmRunTimeContext.hpp b/include/aidge/graphRegex/matchFsm/FsmRunTimeContext.hpp index 07ada14db..f789969df 100644 --- a/include/aidge/graphRegex/matchFsm/FsmRunTimeContext.hpp +++ b/include/aidge/graphRegex/matchFsm/FsmRunTimeContext.hpp @@ -6,7 +6,6 @@ #include <set> #include <algorithm> -//#include "graphRegex/matchFsm/FsmNode.hpp" #include "aidge/nodeTester/ConditionalInterpreter.hpp" #include "aidge/graph/Node.hpp" diff --git a/include/aidge/graphRegex/matchFsm/MatchResult.hpp b/include/aidge/graphRegex/matchFsm/MatchResult.hpp index 8c96cc0e3..0618dfd54 100644 --- a/include/aidge/graphRegex/matchFsm/MatchResult.hpp +++ b/include/aidge/graphRegex/matchFsm/MatchResult.hpp @@ -16,8 +16,11 @@ class MatchResult { private: /* data */ + std::vector<std::shared_ptr<FsmRunTimeContext>> mAllValid; public: - MatchResult(std::shared_ptr<FsmRunTimeContext> contexte); + MatchResult(std::vector<std::shared_ptr<FsmRunTimeContext>> allValid){ + mAllValid = allValid; + }; virtual ~MatchResult() = default; /** diff --git a/src/graph/Node.cpp b/src/graph/Node.cpp index 5fcc0e113..a23d3c66f 100644 --- a/src/graph/Node.cpp +++ b/src/graph/Node.cpp @@ -321,6 +321,35 @@ void Aidge::Node::resetConnections(bool includeLearnableParam) { } } + +std::set<Aidge::NodePtr> Aidge::Node::getNodeDelta(int delta,std::set<Aidge::NodePtr> nodeSee){ + + std::set<Aidge::NodePtr> out; + nodeSee.insert(shared_from_this()); + + if(delta == 0) { + out.insert(shared_from_this()); + + }else if (delta > 0){ + for (const NodePtr& node : getChildren()) { + if(nodeSee.find(node) == out.end()){ //loop avoidance + for (const NodePtr& ch : node->getNodeDelta(delta-1,nodeSee)){ + out.insert(ch); + } + } + } + }else{ + for (const NodePtr& node : getParents()) { + if(nodeSee.find(node) == out.end()){ //loop avoidance + for (const NodePtr& pr : node->getNodeDelta(delta+1,nodeSee)){ + out.insert(pr); + } + } + } + } + + return out; +} ///////////////////////////////////////////////////////////////////////////////////////////// // private diff --git a/src/graphRegex/matchFsm/FsmEdge.cpp b/src/graphRegex/matchFsm/FsmEdge.cpp index cb5f8c058..19b6c356e 100644 --- a/src/graphRegex/matchFsm/FsmEdge.cpp +++ b/src/graphRegex/matchFsm/FsmEdge.cpp @@ -131,26 +131,35 @@ void FsmEdge::propagateRelativePos(void){ } } +void FsmEdge::updateWeak(void){ + mNodeSource->addEdge(shared_from_this()); + mNodeDest->addParent(mNodeSource); +} - -FsmEdge::FsmEdge(std::shared_ptr<FsmNode> source,std::shared_ptr<FsmNode> dest, const std::shared_ptr<ConditionalInterpreter> toTest) +FsmEdge::FsmEdge(std::shared_ptr<FsmNode>& source,std::shared_ptr<FsmNode>& dest, const std::shared_ptr<ConditionalInterpreter> toTest) :mToTest(toTest) { mNodeSource = source; mNodeDest = dest; + // wen i make the edge I init the nodes + // mNodeSource->addEdge(shared_from_this()); + // mNodeDest->addParent(mNodeSource); } + /////surchage -FsmEdgeUnique::FsmEdgeUnique(std::shared_ptr<FsmNode> source,std::shared_ptr<FsmNode> dest, const std::shared_ptr<ConditionalInterpreter> toTest) +FsmEdgeUnique::FsmEdgeUnique(std::shared_ptr<FsmNode>& source,std::shared_ptr<FsmNode>& dest, const std::shared_ptr<ConditionalInterpreter> toTest) :FsmEdge(source,dest,toTest) { } const EdgeTestResult FsmEdgeUnique::test(const std::shared_ptr<FsmRunTimeContext> stmContext){ - if(stmContext == nullptr){ + auto opNode = stmContext->getActNode(); + + if(opNode == nullptr){ return {false,std::set<NodePtr>()};//none } - auto opNode = stmContext->getActNode(); + if(mToTest->test(opNode) && opNode->getChildren().size() <= 1){ stmContext->setValid(opNode,mToTest); return {true,opNode->getChildren()} ; @@ -160,7 +169,7 @@ const EdgeTestResult FsmEdgeUnique::test(const std::shared_ptr<FsmRunTimeContext } } ///////////////////// -FsmEdgeCommon::FsmEdgeCommon(std::shared_ptr<FsmNode> source,std::shared_ptr<FsmNode> dest, const std::shared_ptr<ConditionalInterpreter> toTest, const std::string commonKey) +FsmEdgeCommon::FsmEdgeCommon(std::shared_ptr<FsmNode>& source,std::shared_ptr<FsmNode>& dest, const std::shared_ptr<ConditionalInterpreter> toTest, const std::string commonKey) :FsmEdge(source,dest,toTest) { //make a uid for common node @@ -173,11 +182,12 @@ FsmEdgeCommon::FsmEdgeCommon(std::shared_ptr<FsmNode> source,std::shared_ptr<Fsm const EdgeTestResult FsmEdgeCommon::test(const std::shared_ptr<FsmRunTimeContext> stmContext){ - if(stmContext == nullptr){ - return {false,std::set<NodePtr>()};//none - } + auto opNode = stmContext->getActNode(); + if(opNode == nullptr){ + return {false,std::set<NodePtr>()};//none + } if(mToTest->test(opNode)){ stmContext->setCommon(opNode,mCommonIdx); stmContext->setValid(opNode,mToTest); @@ -191,7 +201,7 @@ bool FsmEdgeCommon::isCommon(void){ return true; } //////////////////// TODO FsmEdgeEmpty must be size_t -FsmEdgeRef::FsmEdgeRef(std::shared_ptr<FsmNode> source,std::shared_ptr<FsmNode> dest, const size_t refCommonIdx,const int deltaCommonIdx) +FsmEdgeRef::FsmEdgeRef(std::shared_ptr<FsmNode>& source,std::shared_ptr<FsmNode>& dest, const size_t refCommonIdx,const int deltaCommonIdx) :FsmEdge(source,dest,nullptr),mRefCommonIdx(refCommonIdx),mdeltaCommonIdx(deltaCommonIdx) { @@ -200,22 +210,20 @@ const EdgeTestResult FsmEdgeRef::test(const std::shared_ptr<FsmRunTimeContext> s NodePtr refNode = stmContext->getCommonNodeFromIdx(mRefCommonIdx); if (refNode){ - - //return {true,refNode->getNodeDelta(mdeltaCommonIdx)}; TODO + std::set<std::shared_ptr<Node>> see; + return {true,refNode->getNodeDelta(mdeltaCommonIdx,see)}; } return {false,std::set<NodePtr>()}; - - } //////////////////// FsmEdgeEmpty::FsmEdgeEmpty(std::shared_ptr<FsmNode> source,std::shared_ptr<FsmNode> dest) :FsmEdge(source,dest,nullptr) {} const EdgeTestResult FsmEdgeEmpty::test(const std::shared_ptr<FsmRunTimeContext> stmContext){ - if(stmContext == nullptr){ + auto opNode = stmContext->getActNode(); + if(opNode == nullptr){ return {false,std::set<NodePtr>()}; } - auto opNode = stmContext->getActNode(); return {true,std::set<NodePtr>({opNode})};//none } diff --git a/src/graphRegex/matchFsm/FsmGraph.cpp b/src/graphRegex/matchFsm/FsmGraph.cpp index 2ec9fe75e..59f634d43 100644 --- a/src/graphRegex/matchFsm/FsmGraph.cpp +++ b/src/graphRegex/matchFsm/FsmGraph.cpp @@ -9,7 +9,7 @@ FsmGraph::FsmGraph(/* args */){ } //TODO -std::vector<std::shared_ptr<MatchResult>> FsmGraph::test(std::vector<NodePtr>& startNodes){ + std::shared_ptr<MatchResult> FsmGraph::test(std::vector<NodePtr>& startNodes){ std::vector<std::shared_ptr<Aidge::FsmNode>> startNodesFsm = getStartNodes(); if(startNodes.size() != startNodesFsm.size()){ throw std::runtime_error("bad number of Start nodes"); @@ -19,34 +19,63 @@ std::vector<std::shared_ptr<MatchResult>> FsmGraph::test(std::vector<NodePtr>& s for(std::size_t i = 0; i < startNodes.size(); i++){ walks.push_back(std::make_shared<FsmRunTimeContext>(startNodesFsm[i],startNodes[i])); } + std::vector<std::shared_ptr<FsmRunTimeContext>> nextWalks; std::vector<std::shared_ptr<FsmRunTimeContext>> allValidContext; + std::vector<std::shared_ptr<FsmRunTimeContext>> allContextSee; + + + while (!walks.empty()) { - std::vector<std::shared_ptr<FsmRunTimeContext>> nextWalks; - for(auto fsmContext : walks){ + allContextSee.push_back(fsmContext); //if we are in a valid st we save it //it's one solution of the posible solution of the matching if(fsmContext->isOnValidState()){ - //TODO not push same fsm use are_equal - allValidContext.push_back(fsmContext); + //not save 2 time the same end point + if(!std::any_of(allValidContext.begin(), allValidContext.end(), + [&](std::shared_ptr<Aidge::FsmRunTimeContext> oldValid) { + return fsmContext->areEqual(oldValid); + })){ + allValidContext.push_back(fsmContext); + } + } + //dont test 2 time a fsmContext std::vector<std::shared_ptr<FsmRunTimeContext>> tmpNextWalks = fsmContext->getActState()->test(fsmContext); + for(auto PotentialFsmContext : tmpNextWalks){ + + if(!std::any_of(allContextSee.begin(), allContextSee.end(), + [&](std::shared_ptr<Aidge::FsmRunTimeContext> oldSee) { + return PotentialFsmContext->areEqual(oldSee); + })){ + nextWalks.push_back(PotentialFsmContext); + } + } + } + walks.swap(nextWalks); + nextWalks.clear(); } + + return std::make_shared<MatchResult>(allValidContext); } +/////////////// +// FSM construction +/////////////// const std::set<std::shared_ptr<FsmEdge>>& FsmGraph::getEdge(void){ return mEdges; } -void FsmGraph::addEdge(std::shared_ptr<FsmEdge> edge){ +void FsmGraph::addEdge(std::shared_ptr<FsmEdge>& edge){ + edge->updateWeak(); mEdges.insert(edge); mAllOrigine.insert(edge->getDestNode()->getOrigine()); mAllOrigine.insert(edge->getSourceNode()->getOrigine()); diff --git a/src/graphRegex/matchFsm/FsmNode.cpp b/src/graphRegex/matchFsm/FsmNode.cpp index 321b78610..46c23c83b 100644 --- a/src/graphRegex/matchFsm/FsmNode.cpp +++ b/src/graphRegex/matchFsm/FsmNode.cpp @@ -15,19 +15,21 @@ const std::vector<std::shared_ptr<FsmRunTimeContext>> FsmNode::test( std::shared std::vector<std::shared_ptr<FsmRunTimeContext>> out; - std::shared_ptr<FsmRunTimeContext> newFsmContext ; + for(auto edge : mEdges){ if (auto sharedEdge = edge.lock()) { + std::shared_ptr<FsmNode> nextState = sharedEdge->getDestNode(); - newFsmContext = std::make_shared<FsmRunTimeContext>(fsmContext); + + //make copy of the fsmContext + std::shared_ptr<FsmRunTimeContext> newFsmContext = std::make_shared<FsmRunTimeContext>(fsmContext); EdgeTestResult edgeRes = sharedEdge->test(newFsmContext); if(edgeRes.success){ if(edgeRes.node.size() != 0){ for(auto nextNode :edgeRes.node ){ - - if(!newFsmContext->isAlreadyValid(nextNode) || newFsmContext->isCommonDefined(nextNode) ){ + if(!newFsmContext->isAlreadyValid(nextNode)|| newFsmContext->isCommonDefined(nextNode) ){ out.push_back( std::make_shared<FsmRunTimeContext>(newFsmContext,nextState,nextNode)); }else{ @@ -35,6 +37,8 @@ const std::vector<std::shared_ptr<FsmRunTimeContext>> FsmNode::test( std::shared } } + }else{ + out.push_back( std::make_shared<FsmRunTimeContext>(newFsmContext,nextState,nullptr)); } } newFsmContext.reset(); @@ -60,7 +64,12 @@ void FsmNode::rmEdge(std::shared_ptr<FsmEdge> edge){ } void FsmNode::addEdge(std::shared_ptr<FsmEdge> edge){ - mEdges.insert(edge); + std::weak_ptr<FsmEdge> edgeW(edge); + if (!edgeW.expired()) { + mEdges.insert(edgeW); + }else{ + throw std::runtime_error("addEdge FsmNode weak pointer is expired" ); + } } const std::set<std::shared_ptr<FsmNode>> FsmNode::getChildNodes(void){ @@ -110,7 +119,13 @@ void FsmNode::start(void){ void FsmNode::addParent(std::shared_ptr<FsmNode> node){ - mParents.insert(node); + + std::weak_ptr<FsmNode> nodeW(node); + if (!nodeW.expired()) { + mParents.insert(nodeW); + }else{ + throw std::runtime_error("addParent FsmNode weak pointer is expired" ); + } } void FsmNode::rmParent(std::shared_ptr<FsmNode> node){ mParents.erase(node); diff --git a/src/graphRegex/matchFsm/FsmRunTimeContext.cpp b/src/graphRegex/matchFsm/FsmRunTimeContext.cpp index 3e1b5f2a1..f8171733c 100644 --- a/src/graphRegex/matchFsm/FsmRunTimeContext.cpp +++ b/src/graphRegex/matchFsm/FsmRunTimeContext.cpp @@ -51,11 +51,28 @@ bool FsmRunTimeContext::isOnValidState(void){ } bool FsmRunTimeContext::isCommonDefined(NodePtr node){ - return mCommonNodes.find(node) != mCommonNodes.end(); + //return mCommonNodes.find(node) != mCommonNodes.end(); + + std::set<NodePtr> nodes = getCommonNodes(); + for(const auto& nodeC : nodes){ + if(nodeC.get() == node.get()){ + return true; + } + } + return false; } bool FsmRunTimeContext::isAlreadyValid(NodePtr node){ - return getValidNodes().find(node) != getValidNodes().end(); + + std::set<NodePtr> nodes = getValidNodes(); + for(const auto& nodeV : nodes){ + if(nodeV.get() == node.get()){ + return true; + } + } + return false; + + //return getValidNodes().find(node) != getValidNodes().end(); } bool FsmRunTimeContext::areCompatible(std::shared_ptr<FsmRunTimeContext> fsmContext){ diff --git a/src/nodeTester/ConditionalParser.cpp b/src/nodeTester/ConditionalParser.cpp index 7b5dddcc5..3ca2843aa 100644 --- a/src/nodeTester/ConditionalParser.cpp +++ b/src/nodeTester/ConditionalParser.cpp @@ -148,7 +148,7 @@ std::shared_ptr<AstNode<ConditionalTokenTypes>> ConditionalParser::constructAstE //pratt while (mCurrentToken->getType() != ConditionalTokenTypes::STOP ) //security { - std::shared_ptr<ParsingToken<ConditionalTokenTypes>> token = mCurrentToken->copy(); + token = mCurrentToken->copy(); //if the token is not in the map is not a operator so we consider a prec of 0 if (ConditionalPrec.find(token->getType()) ==ConditionalPrec.end() ){ return left; diff --git a/unit_tests/graph/Test_get.cpp b/unit_tests/graph/Test_get.cpp new file mode 100644 index 000000000..bc54fc83f --- /dev/null +++ b/unit_tests/graph/Test_get.cpp @@ -0,0 +1,49 @@ + + +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <catch2/catch_test_macros.hpp> +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/data/Tensor.hpp" +#include "aidge/graph/GraphView.hpp" +#include "aidge/operator/Conv.hpp" +#include "aidge/operator/GenericOperator.hpp" +#include "aidge/operator/Producer.hpp" + + +using namespace Aidge; +TEST_CASE("get Delta") { + + + std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); + std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 1, 1, "c"); + std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 1, 1, "c1"); + std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 1, 1, "c2"); + std::shared_ptr<Node> conv3 = GenericOperator("Conv", 1, 1, 1, "c3"); + std::shared_ptr<Node> conv3_5 = GenericOperator("Conv", 1, 1, 1, "c3.5"); + std::shared_ptr<Node> conv4 = GenericOperator("Conv", 1, 1, 1, "c4"); + std::shared_ptr<Node> conv5 = GenericOperator("Conv", 1, 1, 1, "c5"); + + g1->add(conv); + g1->addChild(conv1, "c"); + + + std::set<Aidge::NodePtr> see; + conv->getNodeDelta(0,see); + + SECTION("Self return") { + + REQUIRE(conv->getNodeDelta(0,see) == std::set<std::shared_ptr<Node>>{conv}); + } + + +} \ No newline at end of file diff --git a/unit_tests/graphRegex/Test_Fsm.cpp b/unit_tests/graphRegex/Test_Fsm.cpp index 518987f53..e5950f21b 100644 --- a/unit_tests/graphRegex/Test_Fsm.cpp +++ b/unit_tests/graphRegex/Test_Fsm.cpp @@ -13,12 +13,13 @@ using namespace Aidge; TEST_CASE("matchFSM", "FsmEdge") { - SECTION("FsmEdgeUnique constructor") { + std::shared_ptr<FsmNode> nodeA = std::make_shared<FsmNode>(true,false); std::shared_ptr<FsmNode> nodeB = std::make_shared<FsmNode>(false,true); std::shared_ptr<ConditionalInterpreter> toTest = std::make_shared<ConditionalInterpreter>("true==true"); FsmEdgeUnique EdgeToTest(nodeA,nodeB,toTest); + SECTION("FsmEdgeUnique constructor") { REQUIRE(EdgeToTest.getSourceNode() == nodeA); REQUIRE(EdgeToTest.getDestNode() == nodeB); REQUIRE(EdgeToTest.isCommon() == false); @@ -103,8 +104,8 @@ TEST_CASE("matchFSM", "FsmEdge") { //make the edges std::shared_ptr<ConditionalInterpreter> toTest = std::make_shared<ConditionalInterpreter>("true==true"); - std::shared_ptr<FsmEdgeUnique> edgeAB = std::make_shared<FsmEdgeUnique>(nodeA,nodeB,toTest); - std::shared_ptr<FsmEdgeUnique> edgeBC = std::make_shared<FsmEdgeUnique>(nodeB,nodeC,toTest); + std::shared_ptr<FsmEdge> edgeAB = std::make_shared<FsmEdgeUnique>(nodeA,nodeB,toTest); + std::shared_ptr<FsmEdge> edgeBC = std::make_shared<FsmEdgeUnique>(nodeB,nodeC,toTest); std::shared_ptr<FsmGraph> graph = std::make_shared<FsmGraph>(); @@ -128,8 +129,8 @@ TEST_CASE("matchFSM", "FsmEdge") { //make the edges - std::shared_ptr<FsmEdgeUnique> edgeAB = std::make_shared<FsmEdgeUnique>(nodeA,nodeB,toTest); - std::shared_ptr<FsmEdgeUnique> edgeBC = std::make_shared<FsmEdgeUnique>(nodeB,nodeC,toTest); + std::shared_ptr<FsmEdge> edgeAB = std::make_shared<FsmEdgeUnique>(nodeA,nodeB,toTest); + std::shared_ptr<FsmEdge> edgeBC = std::make_shared<FsmEdgeUnique>(nodeB,nodeC,toTest); std::shared_ptr<FsmGraph> graph = std::make_shared<FsmGraph>(); graph->addEdge(edgeAB); @@ -145,8 +146,8 @@ TEST_CASE("matchFSM", "FsmEdge") { std::shared_ptr<FsmNode> node2C = std::make_shared<FsmNode>(true,false); - std::shared_ptr<FsmEdgeUnique> edge2AB = std::make_shared<FsmEdgeUnique>(node2A,node2B,toTest); - std::shared_ptr<FsmEdgeUnique> edge2BC = std::make_shared<FsmEdgeUnique>(node2B,node2C,toTest); + std::shared_ptr<FsmEdge> edge2AB = std::make_shared<FsmEdgeUnique>(node2A,node2B,toTest); + std::shared_ptr<FsmEdge> edge2BC = std::make_shared<FsmEdgeUnique>(node2B,node2C,toTest); std::shared_ptr<FsmGraph> graph2 = std::make_shared<FsmGraph>(); diff --git a/unit_tests/graphRegex/Test_FsmMatch.cpp b/unit_tests/graphRegex/Test_FsmMatch.cpp new file mode 100644 index 000000000..168bebebe --- /dev/null +++ b/unit_tests/graphRegex/Test_FsmMatch.cpp @@ -0,0 +1,64 @@ + +#include <catch2/catch_test_macros.hpp> + +#include "aidge/graph/GraphView.hpp" +#include "aidge/operator/Conv.hpp" +#include "aidge/operator/GenericOperator.hpp" +#include "aidge/operator/Producer.hpp" + + +#include "aidge/graphRegex/GraphFsmInterpreter.hpp" + + +using namespace Aidge; +TEST_CASE("FsmMatch") { + + SECTION("Construction") { + std::map<std::string,std::shared_ptr<ConditionalInterpreter>> allTest = { + {"A",std::make_shared<ConditionalInterpreter>("isConv($)==true")}, + {"B",std::make_shared<ConditionalInterpreter>("isConv($)==true")}, + {"C",std::make_shared<ConditionalInterpreter>("true==true")} + }; + + allTest["A"]->insertLambda("isConv",+[](NodePtr NodeOp){return NodeOp->type() == "Conv";}); + allTest["B"]->insertLambda("isConv",+[](NodePtr NodeOp){return NodeOp->type() == "Conv";}); + + //GraphFsmInterpreter("A->B",allTest); + std::shared_ptr<GraphFsmInterpreter> fsmGenerator = std::make_shared<GraphFsmInterpreter>("A->B",allTest); + std::shared_ptr<FsmGraph> fsm = fsmGenerator->interpret(); + + + REQUIRE(fsm->getNodes().size() == 3); + REQUIRE(fsm->getStartNodes().size() == 1); + + + + std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); + std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 1, 1, "c"); + std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 1, 1, "c1"); + std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 1, 1, "c2"); + std::shared_ptr<Node> conv3 = GenericOperator("Conv", 1, 1, 1, "c3"); + std::shared_ptr<Node> conv3_5 = GenericOperator("Conv", 1, 1, 1, "c3.5"); + std::shared_ptr<Node> conv4 = GenericOperator("Conv", 1, 1, 1, "c4"); + std::shared_ptr<Node> conv5 = GenericOperator("Conv", 1, 1, 1, "c5"); + + g1->add(conv); + g1->addChild(conv1, "c"); + + + REQUIRE(allTest["A"]->test(conv) == true); + REQUIRE(allTest["B"]->test(conv) == true); + + std::vector<std::shared_ptr<Node>> startNodes = {conv}; + + fsm->test(startNodes); + + + + } + + + + + +} \ No newline at end of file diff --git a/unit_tests/graphRegex/Test_GraphFsmInterpreter.cpp b/unit_tests/graphRegex/Test_GraphFsmInterpreter.cpp index d151f46c0..9ce090506 100644 --- a/unit_tests/graphRegex/Test_GraphFsmInterpreter.cpp +++ b/unit_tests/graphRegex/Test_GraphFsmInterpreter.cpp @@ -21,7 +21,22 @@ TEST_CASE("GraphFsmInterpreter", "GraphFsmInterpreter") { REQUIRE(fsm->getNodes().size() == 3); REQUIRE(fsm->getStartNodes().size() == 1); + REQUIRE(fsm->getEdge().size() == 2); + + for(auto node : fsm->getNodes()){ + if(node->isValid()){ + REQUIRE(node->getEdges().size() == 0); + }else{ + REQUIRE(node->getEdges().size() == 1); + } + + } + + } + + + } \ No newline at end of file -- GitLab