#include <catch2/catch_test_macros.hpp>

#include <memory>
#include <set>
#include <vector>

#include "aidge/graph/GraphView.hpp"
#include "aidge/operator/Conv.hpp"
#include "aidge/operator/GenericOperator.hpp"
#include "aidge/operator/Producer.hpp"

#include "aidge/graphRegex/GraphFsmInterpreter.hpp"

using namespace Aidge;

/// Tests the FsmGraph produced by GraphFsmInterpreter from a graph-regex
/// expression plus a list of named ConditionalInterpreter node predicates,
/// by running it on small hand-built GraphViews and checking the matched nodes.
TEST_CASE("FsmMatch") {

    // Builds the FSM for "A->A" and runs it on the two-node chain c -> c1,
    // where both nodes are of type "Conv".
    SECTION("Construction") {
        std::vector<std::shared_ptr<ConditionalInterpreter>> allTest = {
            std::make_shared<ConditionalInterpreter>("A","isConv($)==true"),
            std::make_shared<ConditionalInterpreter>("B","isConv($)==true"),
            // "C" is declared but not referenced by the "A->A" expression below.
            std::make_shared<ConditionalInterpreter>("C","true==true")
        };

        // The unary '+' converts each capture-less lambda into a plain function pointer.
        allTest[0]->insertLambda("isConv",+[](NodePtr NodeOp){return NodeOp->type() == "Conv";});
        allTest[1]->insertLambda("isConv",+[](NodePtr NodeOp){return NodeOp->type() == "Conv";});

        std::shared_ptr<GraphFsmInterpreter> fsmGenerator = std::make_shared<GraphFsmInterpreter>("A->A",allTest);
        std::shared_ptr<FsmGraph> fsm = fsmGenerator->interpret();

        // Structural checks on the generated FSM (currently disabled).
        //REQUIRE(fsm->getNodes().size() == 3);
        //REQUIRE(fsm->getStartNodes().size() == 1);

        std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph");
        std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 0, 1, "c");
        std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 0, 1, "c1");

        g1->add(conv);
        g1->addChild(conv1, "c");

        REQUIRE(allTest[0]->test(conv) == true);
        REQUIRE(allTest[1]->test(conv) == true);

        std::vector<std::shared_ptr<Node>> startNodes = {conv};
        auto result = fsm->test(startNodes);

        // Guard the index: if nothing matched, report a failure instead of
        // reading past the end of an empty vector (undefined behavior).
        REQUIRE_FALSE(result.empty());
        REQUIRE( result[0]->getAll() == std::set<NodePtr>{conv,conv1});
    }

    // Two branches hanging off the same root: c -> c1 ("Conv") and c -> c2 ("Fc").
    SECTION("2 branche graph"){

        std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph");
        std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 0, 1, "c");
        std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 0, 1, "c1");
        // Note: despite its variable name, this node has type "Fc".
        std::shared_ptr<Node> conv2 = GenericOperator("Fc", 1, 0, 1, "c2");

        g1->add(conv);
        g1->addChild(conv1,conv);
        g1->addChild(conv2,conv);

        REQUIRE(g1->getNodes() == std::set<std::shared_ptr<Node>>({conv,conv1,conv2}));
        REQUIRE(g1->inputNodes() == std::set<std::shared_ptr<Node>>({conv}));
        REQUIRE(g1->outputNodes() == std::set<std::shared_ptr<Node>>({conv1,conv2}));

        /////////////

        std::vector<std::shared_ptr<ConditionalInterpreter>> allTest = {
            std::make_shared<ConditionalInterpreter>("A","isConv($)==true"),
            std::make_shared<ConditionalInterpreter>("B","isFc($)==true")
        };

        allTest[0]->insertLambda("isConv",+[](NodePtr NodeOp){return NodeOp->type() == "Conv";});
        allTest[1]->insertLambda("isFc",+[](NodePtr NodeOp){return NodeOp->type() == "Fc";});

        // Two ';'-separated sequences that both start from "A#".
        // NOTE(review): presumably '#' marks a node shared between sequences — confirm
        // against the graph-regex grammar.
        std::shared_ptr<GraphFsmInterpreter> fsmGenerator = std::make_shared<GraphFsmInterpreter>("A#->A; A#->B",allTest);
        std::shared_ptr<FsmGraph> fsm = fsmGenerator->interpret();

        // Both start nodes are the shared root; presumably one per sequence in the expression.
        std::vector<std::shared_ptr<Node>> startNodes = {conv,conv};
        auto result = fsm->test(startNodes);

        // Guard the index: if nothing matched, report a failure instead of
        // reading past the end of an empty vector (undefined behavior).
        REQUIRE_FALSE(result.empty());
        REQUIRE( result[0]->getAll() == std::set<NodePtr>{conv,conv1,conv2});
    }
}