Skip to content
Snippets Groups Projects
Commit 83339ba3 authored by vincent  lorrain's avatar vincent lorrain
Browse files

[graphRegex] Add unitest and fix

parent 483dcace
No related branches found
No related tags found
1 merge request!14Graph regex
Pipeline #31752 failed
Showing
with 302 additions and 48 deletions
......@@ -350,6 +350,18 @@ public:
*/
void resetConnections(bool includeLearnableParam = false);
/**
* @brief Get the set of pointers to connected node at a distance of a delta.
* @details the recution are cut
* Return a nullptr is nofing found.
* @param delta Input delta.
* @return std::shared_ptr<Node>
*/
std::set<NodePtr> getNodeDelta(int delta,std::set<Aidge::NodePtr> nodeSee);
private:
///////////////////////////////////////////////////////
// OPERATORS
......
......@@ -36,14 +36,20 @@ namespace Aidge{
* @brief the ptr on the source node
*/
std::shared_ptr<FsmNode> mNodeSource;
/**
/**
* @brief the ptr on the dest node
*/
std::shared_ptr<FsmNode> mNodeDest;
/**
* @brief the weak ptr
*/
std::weak_ptr<FsmEdge> weakPtr;
public:
FsmEdge(std::shared_ptr<FsmNode> source,std::shared_ptr<FsmNode> dest, const std::shared_ptr<ConditionalInterpreter> toTest);
FsmEdge(std::shared_ptr<FsmNode>& source,std::shared_ptr<FsmNode>& dest, const std::shared_ptr<ConditionalInterpreter> toTest);
virtual ~FsmEdge() =default;
FsmEdge() : weakPtr(shared_from_this()) {}
/**
* @brief test is the validation of the node , it must be deffine for all types of edge
......@@ -102,6 +108,11 @@ namespace Aidge{
* @see ConditionalInterpreter
*/
const std::shared_ptr<ConditionalInterpreter> mToTest;
/**
* @brief update week ptr for the node, TODO best
*/
void updateWeak(void);
};
/**
......@@ -111,8 +122,7 @@ namespace Aidge{
{
public:
FsmEdgeUnique(std::shared_ptr<FsmNode> source,std::shared_ptr<FsmNode> dest, const std::shared_ptr<ConditionalInterpreter> toTest);
//~FsmEdgeUnique() override {}
FsmEdgeUnique(std::shared_ptr<FsmNode>& source,std::shared_ptr<FsmNode>& dest, const std::shared_ptr<ConditionalInterpreter> toTest);
const EdgeTestResult test(const std::shared_ptr<FsmRunTimeContext> stmContext) override;
};
......@@ -139,7 +149,7 @@ namespace Aidge{
* @details during construction,
* the node key found by the lexer is converted to a unique id and the relative positions are updated.
*/
FsmEdgeCommon(std::shared_ptr<FsmNode> source,std::shared_ptr<FsmNode> dest, const std::shared_ptr<ConditionalInterpreter> toTest, const std::string commonKey);
FsmEdgeCommon(std::shared_ptr<FsmNode>& source,std::shared_ptr<FsmNode>& dest, const std::shared_ptr<ConditionalInterpreter> toTest, const std::string commonKey);
// ~FsmEdgeCommon() override {}
const EdgeTestResult test(const std::shared_ptr<FsmRunTimeContext> stmContext) override;
bool isCommon(void) override;
......@@ -164,7 +174,7 @@ namespace Aidge{
*/
const int mdeltaCommonIdx;
public:
FsmEdgeRef(std::shared_ptr<FsmNode> source,std::shared_ptr<FsmNode> dest, const size_t refCommonIdx,const int deltaCommonIdx);
FsmEdgeRef(std::shared_ptr<FsmNode>& source,std::shared_ptr<FsmNode>& dest, const size_t refCommonIdx,const int deltaCommonIdx);
//~FsmEdgeRef() override {}
const EdgeTestResult test(const std::shared_ptr<FsmRunTimeContext> stmContext) override;
......
......@@ -17,13 +17,16 @@ namespace Aidge{
class FsmGraph
{
private:
/**
* @brief all node origine
*/
std::set<std::size_t> mAllOrigine;
std::set<std::shared_ptr<FsmEdge>> mEdges;
public:
FsmGraph(/* args */);
virtual ~FsmGraph() = default;
std::vector<std::shared_ptr<MatchResult>> test(std::vector<NodePtr>& StartNodes);
std::shared_ptr<MatchResult> test(std::vector<NodePtr>& StartNodes);
......@@ -31,7 +34,7 @@ const std::set<std::shared_ptr<FsmEdge>>& getEdge(void);
/**
* @brief add edge in the graph, as FsmEdge know the source and dest FsmNode these nodes are also add to the graph
*/
void addEdge(std::shared_ptr<FsmEdge> edge);
void addEdge(std::shared_ptr<FsmEdge>& edge);
/**
* @brief get the liste of the starting states
......
......@@ -6,7 +6,6 @@
#include <set>
#include <algorithm>
//#include "graphRegex/matchFsm/FsmNode.hpp"
#include "aidge/nodeTester/ConditionalInterpreter.hpp"
#include "aidge/graph/Node.hpp"
......
......@@ -16,8 +16,11 @@ class MatchResult
{
private:
/* data */
std::vector<std::shared_ptr<FsmRunTimeContext>> mAllValid;
public:
MatchResult(std::shared_ptr<FsmRunTimeContext> contexte);
MatchResult(std::vector<std::shared_ptr<FsmRunTimeContext>> allValid){
mAllValid = allValid;
};
virtual ~MatchResult() = default;
/**
......
......@@ -321,6 +321,35 @@ void Aidge::Node::resetConnections(bool includeLearnableParam) {
}
}
std::set<Aidge::NodePtr> Aidge::Node::getNodeDelta(int delta,std::set<Aidge::NodePtr> nodeSee){
std::set<Aidge::NodePtr> out;
nodeSee.insert(shared_from_this());
if(delta == 0) {
out.insert(shared_from_this());
}else if (delta > 0){
for (const NodePtr& node : getChildren()) {
if(nodeSee.find(node) == out.end()){ //loop avoidance
for (const NodePtr& ch : node->getNodeDelta(delta-1,nodeSee)){
out.insert(ch);
}
}
}
}else{
for (const NodePtr& node : getParents()) {
if(nodeSee.find(node) == out.end()){ //loop avoidance
for (const NodePtr& pr : node->getNodeDelta(delta+1,nodeSee)){
out.insert(pr);
}
}
}
}
return out;
}
/////////////////////////////////////////////////////////////////////////////////////////////
// private
......
......@@ -131,26 +131,35 @@ void FsmEdge::propagateRelativePos(void){
}
}
void FsmEdge::updateWeak(void){
mNodeSource->addEdge(shared_from_this());
mNodeDest->addParent(mNodeSource);
}
FsmEdge::FsmEdge(std::shared_ptr<FsmNode> source,std::shared_ptr<FsmNode> dest, const std::shared_ptr<ConditionalInterpreter> toTest)
FsmEdge::FsmEdge(std::shared_ptr<FsmNode>& source,std::shared_ptr<FsmNode>& dest, const std::shared_ptr<ConditionalInterpreter> toTest)
:mToTest(toTest)
{
mNodeSource = source;
mNodeDest = dest;
// wen i make the edge I init the nodes
// mNodeSource->addEdge(shared_from_this());
// mNodeDest->addParent(mNodeSource);
}
/////surchage
FsmEdgeUnique::FsmEdgeUnique(std::shared_ptr<FsmNode> source,std::shared_ptr<FsmNode> dest, const std::shared_ptr<ConditionalInterpreter> toTest)
FsmEdgeUnique::FsmEdgeUnique(std::shared_ptr<FsmNode>& source,std::shared_ptr<FsmNode>& dest, const std::shared_ptr<ConditionalInterpreter> toTest)
:FsmEdge(source,dest,toTest)
{
}
const EdgeTestResult FsmEdgeUnique::test(const std::shared_ptr<FsmRunTimeContext> stmContext){
if(stmContext == nullptr){
auto opNode = stmContext->getActNode();
if(opNode == nullptr){
return {false,std::set<NodePtr>()};//none
}
auto opNode = stmContext->getActNode();
if(mToTest->test(opNode) && opNode->getChildren().size() <= 1){
stmContext->setValid(opNode,mToTest);
return {true,opNode->getChildren()} ;
......@@ -160,7 +169,7 @@ const EdgeTestResult FsmEdgeUnique::test(const std::shared_ptr<FsmRunTimeContext
}
}
/////////////////////
FsmEdgeCommon::FsmEdgeCommon(std::shared_ptr<FsmNode> source,std::shared_ptr<FsmNode> dest, const std::shared_ptr<ConditionalInterpreter> toTest, const std::string commonKey)
FsmEdgeCommon::FsmEdgeCommon(std::shared_ptr<FsmNode>& source,std::shared_ptr<FsmNode>& dest, const std::shared_ptr<ConditionalInterpreter> toTest, const std::string commonKey)
:FsmEdge(source,dest,toTest)
{
//make a uid for common node
......@@ -173,11 +182,12 @@ FsmEdgeCommon::FsmEdgeCommon(std::shared_ptr<FsmNode> source,std::shared_ptr<Fsm
const EdgeTestResult FsmEdgeCommon::test(const std::shared_ptr<FsmRunTimeContext> stmContext){
if(stmContext == nullptr){
return {false,std::set<NodePtr>()};//none
}
auto opNode = stmContext->getActNode();
if(opNode == nullptr){
return {false,std::set<NodePtr>()};//none
}
if(mToTest->test(opNode)){
stmContext->setCommon(opNode,mCommonIdx);
stmContext->setValid(opNode,mToTest);
......@@ -191,7 +201,7 @@ bool FsmEdgeCommon::isCommon(void){
return true;
}
//////////////////// TODO FsmEdgeEmpty must be size_t
FsmEdgeRef::FsmEdgeRef(std::shared_ptr<FsmNode> source,std::shared_ptr<FsmNode> dest, const size_t refCommonIdx,const int deltaCommonIdx)
FsmEdgeRef::FsmEdgeRef(std::shared_ptr<FsmNode>& source,std::shared_ptr<FsmNode>& dest, const size_t refCommonIdx,const int deltaCommonIdx)
:FsmEdge(source,dest,nullptr),mRefCommonIdx(refCommonIdx),mdeltaCommonIdx(deltaCommonIdx)
{
......@@ -200,22 +210,20 @@ const EdgeTestResult FsmEdgeRef::test(const std::shared_ptr<FsmRunTimeContext> s
NodePtr refNode = stmContext->getCommonNodeFromIdx(mRefCommonIdx);
if (refNode){
//return {true,refNode->getNodeDelta(mdeltaCommonIdx)}; TODO
std::set<std::shared_ptr<Node>> see;
return {true,refNode->getNodeDelta(mdeltaCommonIdx,see)};
}
return {false,std::set<NodePtr>()};
}
////////////////////
FsmEdgeEmpty::FsmEdgeEmpty(std::shared_ptr<FsmNode> source,std::shared_ptr<FsmNode> dest)
:FsmEdge(source,dest,nullptr)
{}
const EdgeTestResult FsmEdgeEmpty::test(const std::shared_ptr<FsmRunTimeContext> stmContext){
if(stmContext == nullptr){
auto opNode = stmContext->getActNode();
if(opNode == nullptr){
return {false,std::set<NodePtr>()};
}
auto opNode = stmContext->getActNode();
return {true,std::set<NodePtr>({opNode})};//none
}
......
......@@ -9,7 +9,7 @@ FsmGraph::FsmGraph(/* args */){
}
//TODO
std::vector<std::shared_ptr<MatchResult>> FsmGraph::test(std::vector<NodePtr>& startNodes){
std::shared_ptr<MatchResult> FsmGraph::test(std::vector<NodePtr>& startNodes){
std::vector<std::shared_ptr<Aidge::FsmNode>> startNodesFsm = getStartNodes();
if(startNodes.size() != startNodesFsm.size()){
throw std::runtime_error("bad number of Start nodes");
......@@ -19,34 +19,63 @@ std::vector<std::shared_ptr<MatchResult>> FsmGraph::test(std::vector<NodePtr>& s
for(std::size_t i = 0; i < startNodes.size(); i++){
walks.push_back(std::make_shared<FsmRunTimeContext>(startNodesFsm[i],startNodes[i]));
}
std::vector<std::shared_ptr<FsmRunTimeContext>> nextWalks;
std::vector<std::shared_ptr<FsmRunTimeContext>> allValidContext;
std::vector<std::shared_ptr<FsmRunTimeContext>> allContextSee;
while (!walks.empty())
{
std::vector<std::shared_ptr<FsmRunTimeContext>> nextWalks;
for(auto fsmContext : walks){
allContextSee.push_back(fsmContext);
//if we are in a valid st we save it
//it's one solution of the posible solution of the matching
if(fsmContext->isOnValidState()){
//TODO not push same fsm use are_equal
allValidContext.push_back(fsmContext);
//not save 2 time the same end point
if(!std::any_of(allValidContext.begin(), allValidContext.end(),
[&](std::shared_ptr<Aidge::FsmRunTimeContext> oldValid) {
return fsmContext->areEqual(oldValid);
})){
allValidContext.push_back(fsmContext);
}
}
//dont test 2 time a fsmContext
std::vector<std::shared_ptr<FsmRunTimeContext>> tmpNextWalks = fsmContext->getActState()->test(fsmContext);
for(auto PotentialFsmContext : tmpNextWalks){
if(!std::any_of(allContextSee.begin(), allContextSee.end(),
[&](std::shared_ptr<Aidge::FsmRunTimeContext> oldSee) {
return PotentialFsmContext->areEqual(oldSee);
})){
nextWalks.push_back(PotentialFsmContext);
}
}
}
walks.swap(nextWalks);
nextWalks.clear();
}
return std::make_shared<MatchResult>(allValidContext);
}
///////////////
// FSM construction
///////////////
const std::set<std::shared_ptr<FsmEdge>>& FsmGraph::getEdge(void){
return mEdges;
}
void FsmGraph::addEdge(std::shared_ptr<FsmEdge> edge){
void FsmGraph::addEdge(std::shared_ptr<FsmEdge>& edge){
edge->updateWeak();
mEdges.insert(edge);
mAllOrigine.insert(edge->getDestNode()->getOrigine());
mAllOrigine.insert(edge->getSourceNode()->getOrigine());
......
......@@ -15,19 +15,21 @@ const std::vector<std::shared_ptr<FsmRunTimeContext>> FsmNode::test( std::shared
std::vector<std::shared_ptr<FsmRunTimeContext>> out;
std::shared_ptr<FsmRunTimeContext> newFsmContext ;
for(auto edge : mEdges){
if (auto sharedEdge = edge.lock()) {
std::shared_ptr<FsmNode> nextState = sharedEdge->getDestNode();
newFsmContext = std::make_shared<FsmRunTimeContext>(fsmContext);
//make copy of the fsmContext
std::shared_ptr<FsmRunTimeContext> newFsmContext = std::make_shared<FsmRunTimeContext>(fsmContext);
EdgeTestResult edgeRes = sharedEdge->test(newFsmContext);
if(edgeRes.success){
if(edgeRes.node.size() != 0){
for(auto nextNode :edgeRes.node ){
if(!newFsmContext->isAlreadyValid(nextNode) || newFsmContext->isCommonDefined(nextNode) ){
if(!newFsmContext->isAlreadyValid(nextNode)|| newFsmContext->isCommonDefined(nextNode) ){
out.push_back( std::make_shared<FsmRunTimeContext>(newFsmContext,nextState,nextNode));
}else{
......@@ -35,6 +37,8 @@ const std::vector<std::shared_ptr<FsmRunTimeContext>> FsmNode::test( std::shared
}
}
}else{
out.push_back( std::make_shared<FsmRunTimeContext>(newFsmContext,nextState,nullptr));
}
}
newFsmContext.reset();
......@@ -60,7 +64,12 @@ void FsmNode::rmEdge(std::shared_ptr<FsmEdge> edge){
}
void FsmNode::addEdge(std::shared_ptr<FsmEdge> edge){
mEdges.insert(edge);
std::weak_ptr<FsmEdge> edgeW(edge);
if (!edgeW.expired()) {
mEdges.insert(edgeW);
}else{
throw std::runtime_error("addEdge FsmNode weak pointer is expired" );
}
}
const std::set<std::shared_ptr<FsmNode>> FsmNode::getChildNodes(void){
......@@ -110,7 +119,13 @@ void FsmNode::start(void){
void FsmNode::addParent(std::shared_ptr<FsmNode> node){
mParents.insert(node);
std::weak_ptr<FsmNode> nodeW(node);
if (!nodeW.expired()) {
mParents.insert(nodeW);
}else{
throw std::runtime_error("addParent FsmNode weak pointer is expired" );
}
}
void FsmNode::rmParent(std::shared_ptr<FsmNode> node){
mParents.erase(node);
......
......@@ -51,11 +51,28 @@ bool FsmRunTimeContext::isOnValidState(void){
}
bool FsmRunTimeContext::isCommonDefined(NodePtr node){
return mCommonNodes.find(node) != mCommonNodes.end();
//return mCommonNodes.find(node) != mCommonNodes.end();
std::set<NodePtr> nodes = getCommonNodes();
for(const auto& nodeC : nodes){
if(nodeC.get() == node.get()){
return true;
}
}
return false;
}
bool FsmRunTimeContext::isAlreadyValid(NodePtr node){
return getValidNodes().find(node) != getValidNodes().end();
std::set<NodePtr> nodes = getValidNodes();
for(const auto& nodeV : nodes){
if(nodeV.get() == node.get()){
return true;
}
}
return false;
//return getValidNodes().find(node) != getValidNodes().end();
}
bool FsmRunTimeContext::areCompatible(std::shared_ptr<FsmRunTimeContext> fsmContext){
......
......@@ -148,7 +148,7 @@ std::shared_ptr<AstNode<ConditionalTokenTypes>> ConditionalParser::constructAstE
//pratt
while (mCurrentToken->getType() != ConditionalTokenTypes::STOP ) //security
{
std::shared_ptr<ParsingToken<ConditionalTokenTypes>> token = mCurrentToken->copy();
token = mCurrentToken->copy();
//if the token is not in the map is not a operator so we consider a prec of 0
if (ConditionalPrec.find(token->getType()) ==ConditionalPrec.end() ){
return left;
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <catch2/catch_test_macros.hpp>
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/operator/Conv.hpp"
#include "aidge/operator/GenericOperator.hpp"
#include "aidge/operator/Producer.hpp"
using namespace Aidge;
TEST_CASE("get Delta") {
std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph");
std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 1, 1, "c");
std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 1, 1, "c1");
std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 1, 1, "c2");
std::shared_ptr<Node> conv3 = GenericOperator("Conv", 1, 1, 1, "c3");
std::shared_ptr<Node> conv3_5 = GenericOperator("Conv", 1, 1, 1, "c3.5");
std::shared_ptr<Node> conv4 = GenericOperator("Conv", 1, 1, 1, "c4");
std::shared_ptr<Node> conv5 = GenericOperator("Conv", 1, 1, 1, "c5");
g1->add(conv);
g1->addChild(conv1, "c");
std::set<Aidge::NodePtr> see;
conv->getNodeDelta(0,see);
SECTION("Self return") {
REQUIRE(conv->getNodeDelta(0,see) == std::set<std::shared_ptr<Node>>{conv});
}
}
\ No newline at end of file
......@@ -13,12 +13,13 @@ using namespace Aidge;
TEST_CASE("matchFSM", "FsmEdge") {
SECTION("FsmEdgeUnique constructor") {
std::shared_ptr<FsmNode> nodeA = std::make_shared<FsmNode>(true,false);
std::shared_ptr<FsmNode> nodeB = std::make_shared<FsmNode>(false,true);
std::shared_ptr<ConditionalInterpreter> toTest = std::make_shared<ConditionalInterpreter>("true==true");
FsmEdgeUnique EdgeToTest(nodeA,nodeB,toTest);
SECTION("FsmEdgeUnique constructor") {
REQUIRE(EdgeToTest.getSourceNode() == nodeA);
REQUIRE(EdgeToTest.getDestNode() == nodeB);
REQUIRE(EdgeToTest.isCommon() == false);
......@@ -103,8 +104,8 @@ TEST_CASE("matchFSM", "FsmEdge") {
//make the edges
std::shared_ptr<ConditionalInterpreter> toTest = std::make_shared<ConditionalInterpreter>("true==true");
std::shared_ptr<FsmEdgeUnique> edgeAB = std::make_shared<FsmEdgeUnique>(nodeA,nodeB,toTest);
std::shared_ptr<FsmEdgeUnique> edgeBC = std::make_shared<FsmEdgeUnique>(nodeB,nodeC,toTest);
std::shared_ptr<FsmEdge> edgeAB = std::make_shared<FsmEdgeUnique>(nodeA,nodeB,toTest);
std::shared_ptr<FsmEdge> edgeBC = std::make_shared<FsmEdgeUnique>(nodeB,nodeC,toTest);
std::shared_ptr<FsmGraph> graph = std::make_shared<FsmGraph>();
......@@ -128,8 +129,8 @@ TEST_CASE("matchFSM", "FsmEdge") {
//make the edges
std::shared_ptr<FsmEdgeUnique> edgeAB = std::make_shared<FsmEdgeUnique>(nodeA,nodeB,toTest);
std::shared_ptr<FsmEdgeUnique> edgeBC = std::make_shared<FsmEdgeUnique>(nodeB,nodeC,toTest);
std::shared_ptr<FsmEdge> edgeAB = std::make_shared<FsmEdgeUnique>(nodeA,nodeB,toTest);
std::shared_ptr<FsmEdge> edgeBC = std::make_shared<FsmEdgeUnique>(nodeB,nodeC,toTest);
std::shared_ptr<FsmGraph> graph = std::make_shared<FsmGraph>();
graph->addEdge(edgeAB);
......@@ -145,8 +146,8 @@ TEST_CASE("matchFSM", "FsmEdge") {
std::shared_ptr<FsmNode> node2C = std::make_shared<FsmNode>(true,false);
std::shared_ptr<FsmEdgeUnique> edge2AB = std::make_shared<FsmEdgeUnique>(node2A,node2B,toTest);
std::shared_ptr<FsmEdgeUnique> edge2BC = std::make_shared<FsmEdgeUnique>(node2B,node2C,toTest);
std::shared_ptr<FsmEdge> edge2AB = std::make_shared<FsmEdgeUnique>(node2A,node2B,toTest);
std::shared_ptr<FsmEdge> edge2BC = std::make_shared<FsmEdgeUnique>(node2B,node2C,toTest);
std::shared_ptr<FsmGraph> graph2 = std::make_shared<FsmGraph>();
......
#include <catch2/catch_test_macros.hpp>
#include "aidge/graph/GraphView.hpp"
#include "aidge/operator/Conv.hpp"
#include "aidge/operator/GenericOperator.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/graphRegex/GraphFsmInterpreter.hpp"
using namespace Aidge;
TEST_CASE("FsmMatch") {
SECTION("Construction") {
std::map<std::string,std::shared_ptr<ConditionalInterpreter>> allTest = {
{"A",std::make_shared<ConditionalInterpreter>("isConv($)==true")},
{"B",std::make_shared<ConditionalInterpreter>("isConv($)==true")},
{"C",std::make_shared<ConditionalInterpreter>("true==true")}
};
allTest["A"]->insertLambda("isConv",+[](NodePtr NodeOp){return NodeOp->type() == "Conv";});
allTest["B"]->insertLambda("isConv",+[](NodePtr NodeOp){return NodeOp->type() == "Conv";});
//GraphFsmInterpreter("A->B",allTest);
std::shared_ptr<GraphFsmInterpreter> fsmGenerator = std::make_shared<GraphFsmInterpreter>("A->B",allTest);
std::shared_ptr<FsmGraph> fsm = fsmGenerator->interpret();
REQUIRE(fsm->getNodes().size() == 3);
REQUIRE(fsm->getStartNodes().size() == 1);
std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph");
std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 1, 1, "c");
std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 1, 1, "c1");
std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 1, 1, "c2");
std::shared_ptr<Node> conv3 = GenericOperator("Conv", 1, 1, 1, "c3");
std::shared_ptr<Node> conv3_5 = GenericOperator("Conv", 1, 1, 1, "c3.5");
std::shared_ptr<Node> conv4 = GenericOperator("Conv", 1, 1, 1, "c4");
std::shared_ptr<Node> conv5 = GenericOperator("Conv", 1, 1, 1, "c5");
g1->add(conv);
g1->addChild(conv1, "c");
REQUIRE(allTest["A"]->test(conv) == true);
REQUIRE(allTest["B"]->test(conv) == true);
std::vector<std::shared_ptr<Node>> startNodes = {conv};
fsm->test(startNodes);
}
}
\ No newline at end of file
......@@ -21,7 +21,22 @@ TEST_CASE("GraphFsmInterpreter", "GraphFsmInterpreter") {
REQUIRE(fsm->getNodes().size() == 3);
REQUIRE(fsm->getStartNodes().size() == 1);
REQUIRE(fsm->getEdge().size() == 2);
for(auto node : fsm->getNodes()){
if(node->isValid()){
REQUIRE(node->getEdges().size() == 0);
}else{
REQUIRE(node->getEdges().size() == 1);
}
}
}
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment