diff --git a/aidge/_Core/tests/CMakeLists.txt b/aidge/_Core/tests/CMakeLists.txt index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..fbd9ec8bb132fb9856f12bc01c473d0fda7f94cc 100644 --- a/aidge/_Core/tests/CMakeLists.txt +++ b/aidge/_Core/tests/CMakeLists.txt @@ -0,0 +1,25 @@ + +enable_testing() + +Include(FetchContent) + +FetchContent_Declare( + Catch2 + GIT_REPOSITORY https://github.com/catchorg/Catch2.git + GIT_TAG v3.0.1 # or a later release +) + +FetchContent_MakeAvailable(Catch2) + +file(GLOB_RECURSE src_files "*.cpp") + +add_executable(tests_core ${src_files}) + +target_link_libraries(tests_core PUBLIC core) + +target_link_libraries(tests_core PRIVATE Catch2::Catch2WithMain) + +list(APPEND CMAKE_MODULE_PATH ${catch2_SOURCE_DIR}/extras) +include(CTest) +include(Catch) +catch_discover_tests(tests_core) diff --git a/aidge/_Core/tests/Test_Connector.cpp b/aidge/_Core/tests/Test_Connector.cpp new file mode 100644 index 0000000000000000000000000000000000000000..96f5e01653d740cf13d97781343ddd2264edb18c --- /dev/null +++ b/aidge/_Core/tests/Test_Connector.cpp @@ -0,0 +1,246 @@ +#include <catch2/catch_test_macros.hpp> + +#include "graph/Connector.hpp" +#include "graph/Node.hpp" +#include "operator/GenericOperator.hpp" +#include "graph/GraphView.hpp" +#include "graph/OpArgs.hpp" + +using namespace Aidge; + +TEST_CASE("Connector Creation", "[Connector]") { + SECTION("Empty") { + Connector x = Connector(); + REQUIRE(x.index() == gk_IODefaultIndex); + REQUIRE(x.node() == nullptr); + } + SECTION("0 output") { + std::shared_ptr<Node> node = GenericOperator("Producer",1,1,0); + Connector x = Connector(node); + REQUIRE(x.index() == gk_IODefaultIndex); + REQUIRE(x.node() == node); + } + SECTION("1 output") { + std::shared_ptr<Node> node = GenericOperator("ReLU",1,1,1); + Connector x = Connector(node); + REQUIRE(x.index() == 0); + REQUIRE(x.node() == node); + } + SECTION("Several outputs") { + std::shared_ptr<Node> node = GenericOperator("Split",1,1,2); + Connector x = Connector(node); + REQUIRE(x.index() == gk_IODefaultIndex); + REQUIRE(x.node() == node); + } +} + +TEST_CASE("Connector connections Node", "[Connector]") { + SECTION("0 input / 0 output") { + std::shared_ptr<Node> fic = GenericOperator("Display",0,0,0); + Connector x; + x = (*fic)({}); + REQUIRE(x.node() == fic); + } + SECTION("1 input / 0 output") { + std::shared_ptr<Node> fic = GenericOperator("Loss",1,1,0); + Connector x; + x = (*fic)({x}); + REQUIRE(x.node() == fic); + } + SECTION("0 input / 1 output") { // Producers + std::shared_ptr<Node> fic = GenericOperator("Producer",0,0,1); + Connector x = (*fic)({}); + REQUIRE(x.node() == fic); + } + SECTION("1 input / 1 output") { + std::shared_ptr<Node> fic = GenericOperator("Conv",1,1,1); + Connector x(GenericOperator("Producer",0,0,1)); + x = (*fic)({x}); + REQUIRE(x.node() ==fic); + } + SECTION("2+ inputs / 1 output") { // ElemWise + std::shared_ptr<Node> fic = GenericOperator("fictive",3,3,1); + Connector x1(GenericOperator("fictive",0,0,1)); + Connector x2(GenericOperator("fictive",0,0,1)); + Connector x3(GenericOperator("fictive",0,0,1)); + Connector x = (*fic)({x1, x2, x3}); + REQUIRE(x.node() ==fic); + } + SECTION("1 input / 2+ outputs") { // Slice + std::shared_ptr<Node> fic = GenericOperator("fictive",1,1,3); + + Connector x(GenericOperator("fictive2",0,0,1)); + Connector y; + REQUIRE_NOTHROW(y = (*fic)({x})); + REQUIRE(y[0].node() == fic); + REQUIRE(y[1].node() == fic); + REQUIRE(y[2].node() == fic); + } +} + +TEST_CASE("GraphGeneration from Connector", "[GraphView]") { + + auto node01 = GenericOperator("Conv",0,0,1,"g_conv1"); + auto node02 = GenericOperator("ReLU",1,1,1,"g_relu"); + auto node03 = GenericOperator("g_maxpool1", 1,1,1); + auto node04 = GenericOperator("g_conv2_par1",1,1,1); + auto node05 = GenericOperator("g_relu2_par1", 1,1,1); + auto node06 = GenericOperator("g_conv2_par2", 1,1,1); + auto node07 = GenericOperator("g_relu2_par2", 1,1,1); + auto node08 = GenericOperator("g_concat", 2,2,1); + auto node09 = GenericOperator("g_conv3", 1, 1,1); + auto node10 = GenericOperator("g_matmul1", 2,2,1); + Connector a = (*node01)({}); + Connector x = (*node02)({a}); + x = (*node03)({x}); + Connector y = (*node04)({x}); + y = (*node05)({y}); + Connector z = (*node06)({x}); + z = (*node07)({z}); + x = (*node08)({y, z}); + x= (*node09)({x}); + x = (*node10)({a, x}); + std::shared_ptr<GraphView> gv = generateGraph({x}); + gv->save("GraphGeneration"); +} + +TEST_CASE("Connector connection GraphView", "[Connector]") { + SECTION("1 input") { + Connector x = Connector(); + auto prod = GenericOperator("Producer",0,0,1); + auto g = Residual({ + GenericOperator("g_conv1", 1,1,1), + GenericOperator("g_relu", 1,1,1), + GenericOperator("g_maxpool1", 1,1,1), + Parallel({ + Sequential({GenericOperator("g_conv2_par1",1,1,1), GenericOperator("g_relu2_par1", 1,1,1)}), + Sequential({GenericOperator("g_conv2_par2", 1,1,1), GenericOperator("g_relu2_par2", 1,1,1)}) + }), + GenericOperator("g_concat", 2,2,1), + GenericOperator("g_conv3", 1, 1,1), + GenericOperator("g_matmul1", 2,2,1) + }); + x = (*prod)({}); + x = (*g)({x}); + std::shared_ptr<GraphView> g2 = generateGraph({x}); + std::shared_ptr<GraphView> g3 = g; + g3->add(prod); + REQUIRE(*g3== *g2); + } + SECTION("2+ inputs") { + Connector x = (*GenericOperator("Producer",0,0,1))({}); + Connector y = (*GenericOperator("Producer",0,0,1))({}); + Connector z = (*GenericOperator("Producer",0,0,1))({}); + auto g = Sequential({GenericOperator("ElemWise", 3,3,1), + Parallel({ + Sequential({GenericOperator("g_conv2_par1",1,1,1), GenericOperator("g_relu2_par1", 1,1,1)}), + Sequential({GenericOperator("g_conv2_par2", 1,1,1), GenericOperator("g_relu2_par2", 1,1,1)}), + Sequential({GenericOperator("g_conv2_par3", 1,1,1), GenericOperator("g_relu2_par3", 1,1,1)}) + }), + GenericOperator("g_concat", 3,3,1), + GenericOperator("g_conv3", 1, 1,1) + }); + + x = (*g)({x, y, z}); + std::shared_ptr<GraphView> gv = generateGraph({x}); + gv->save("MultiInputSequentialConnector"); + REQUIRE(gv->inputNodes().size() == 0U); + } +} + +TEST_CASE("Connector Mini-graph", "[Connector]") { + Connector x = Connector(); + Connector y = Connector(); + x = (*GenericOperator("Producer",0,0,1))({}); + y = (*GenericOperator("Producer",0,0,1))({}); + for (int i = 0; i<5; ++i) { + x = (*GenericOperator("Conv",1,1,1))({x}); + } + y = (*GenericOperator("ElemWise",2,2,1))({y, x}); + std::shared_ptr<GraphView> g = generateGraph({y}); + g->save("TestGraph"); +} + +TEST_CASE("Structural descrition - Sequential", "[GraphView]") { + // SECTION("Empty Sequence") { + // std::shared_ptr<GraphView> g1 = Sequential(); // Not supported + // REQUIRE(g1->getNodes() == std::set<std::shared_ptr<Node>>()); + // REQUIRE(g1->inputNodes() == std::set<std::shared_ptr<Node>>()); + // REQUIRE(g1->outputNodes() == std::set<std::shared_ptr<Node>>()); + // } + SECTION("1-element Sequence") { + std::shared_ptr<Node> fic = GenericOperator("node1", 1,1,1); + std::shared_ptr<GraphView> g2 = Sequential({fic}); + REQUIRE(g2->getNodes() == std::set<std::shared_ptr<Node>>({fic})); + REQUIRE(g2->inputNodes() == std::set<std::shared_ptr<Node>>({fic})); + REQUIRE(g2->outputNodes() == std::set<std::shared_ptr<Node>>({fic})); + } + SECTION("several-elements simple Sequence") { + std::shared_ptr<Node> fic1 = GenericOperator("node1", 1,1,1); + std::shared_ptr<Node> fic2 = GenericOperator("node1", 1,1,1); + std::shared_ptr<Node> fic3 = GenericOperator("node1", 1,1,1); + std::shared_ptr<GraphView> g2 = Sequential({fic1, fic2, fic3}); + REQUIRE(g2->getNodes() == std::set<std::shared_ptr<Node>>({fic1, fic2, fic3})); + REQUIRE(g2->inputNodes() == std::set<std::shared_ptr<Node>>({fic1})); + REQUIRE(g2->outputNodes() == std::set<std::shared_ptr<Node>>({fic3})); + } + +} + +TEST_CASE("Structural description - Parallel", "[GraphView]") { + // SECTION("Empty Parallel") { + // std::shared_ptr<GraphView> g1 = Parallel(); // Not supported + // REQUIRE(g1->getNodes() == std::set<std::shared_ptr<Node>>()); + // REQUIRE(g1->inputNodes() == std::set<std::shared_ptr<Node>>()); + // REQUIRE(g1->outputNodes() == std::set<std::shared_ptr<Node>>()); + // } + SECTION("1-element Parallel") { + std::shared_ptr<Node> fic = GenericOperator("node1", 1,1,1); + std::shared_ptr<GraphView> g2 = Parallel({fic}); + REQUIRE(g2->getNodes() == std::set<std::shared_ptr<Node>>({fic})); + REQUIRE(g2->inputNodes() == std::set<std::shared_ptr<Node>>({fic})); + REQUIRE(g2->outputNodes() == std::set<std::shared_ptr<Node>>({fic})); + } + SECTION("several-elements simple Parallel") { + std::shared_ptr<Node> fic1 = GenericOperator("node1", 1,1,1); + std::shared_ptr<Node> fic2 = GenericOperator("node1", 1,1,1); + std::shared_ptr<Node> fic3 = GenericOperator("node1", 1,1,1); + std::shared_ptr<GraphView> g2 = Parallel({fic1, fic2, fic3}); + REQUIRE(g2->getNodes() == std::set<std::shared_ptr<Node>>({fic1, fic2, fic3})); + REQUIRE(g2->inputNodes() == std::set<std::shared_ptr<Node>>({fic1, fic2, fic3})); + REQUIRE(g2->outputNodes() == std::set<std::shared_ptr<Node>>({fic1, fic2, fic3})); + } + SECTION("1 Graph in Parallel") { + std::shared_ptr<Node> fic1 = GenericOperator("node1", 1,1,1); + std::shared_ptr<Node> fic2 = GenericOperator("node1", 1,1,1); + std::shared_ptr<Node> fic3 = GenericOperator("node1", 1,1,1); + std::shared_ptr<GraphView> g2 = Parallel({Sequential({fic1, fic2, fic3})}); + REQUIRE(g2->getNodes() == std::set<std::shared_ptr<Node>>({fic1, fic2, fic3})); + REQUIRE(g2->inputNodes() == std::set<std::shared_ptr<Node>>({fic1})); + REQUIRE(g2->outputNodes() == std::set<std::shared_ptr<Node>>({fic3})); + } + SECTION("several Sequential in Parallel") { + std::shared_ptr<Node> fic1 = GenericOperator("node1", 1,1,1); + std::shared_ptr<Node> fic2 = GenericOperator("node1", 1,1,1); + std::shared_ptr<Node> fic3 = GenericOperator("node1", 1,1,1); + std::shared_ptr<Node> fic4 = GenericOperator("node1", 1,1,1); + std::shared_ptr<Node> fic5 = GenericOperator("node1", 1,1,1); + std::shared_ptr<Node> fic6 = GenericOperator("node1", 1,1,1); + std::shared_ptr<GraphView> g2 = Parallel({Sequential({fic1, fic2, fic3}),Sequential({fic4, fic5, fic6})}); + REQUIRE(g2->getNodes() == std::set<std::shared_ptr<Node>>({fic1, fic2, fic3, fic4, fic5, fic6})); + REQUIRE(g2->inputNodes() == std::set<std::shared_ptr<Node>>({fic1, fic4})); + REQUIRE(g2->outputNodes() == std::set<std::shared_ptr<Node>>({fic3, fic6})); + } +} + +TEST_CASE("Strucutral Description - Complex Graph", "[GraphView]") { + std::shared_ptr<Node> firstLayer = GenericOperator("first", 1,1,1); + auto g = Sequential({firstLayer, + GenericOperator("l2",1,1,1), + Parallel({Sequential({GenericOperator("conv1",1,1,1), GenericOperator("relu1",1,1,1)}), + Sequential({GenericOperator("conv2",1,1,1), GenericOperator("relu2",1,1,1)})}), + GenericOperator("concat",2,2,1), + GenericOperator("lastLayer",1,1,1)}); + REQUIRE(g->getNodes().size() == 8U); + REQUIRE(g->inputNodes() == std::set<std::shared_ptr<Node>>({firstLayer})); +} diff --git a/aidge/_Core/tests/Test_GRegex.cpp b/aidge/_Core/tests/Test_GRegex.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9fa7164992a211e0108641101a31f38078f3d499 --- /dev/null +++ b/aidge/_Core/tests/Test_GRegex.cpp @@ -0,0 +1,295 @@ +#include <iostream> +#include <map> +#include <memory> +#include <vector> +#include <utility> +#include <cassert> + +#include <catch2/catch_test_macros.hpp> +//test +#include "graphmatching/GRegex.hpp" +#include "graphmatching/StmFactory.hpp" +#include "graphmatching/SeqStm.hpp" +#include "graphmatching/NodeRegex.hpp" +#include "graphmatching/Match.hpp" +//use +#include "backend/OperatorImpl.hpp" +#include "operator/GenericOperator.hpp" +#include "operator/Producer.hpp" +#include "graph/GraphView.hpp" + +using namespace Aidge; + +TEST_CASE("Create good init GRegex", "[GRegex]") { + // init all input for GRegex + // Nodes Regex map : std::map<std::string,NodeRegex*>& nodesRegex + // Sequential Regex vector : std::vector<std::string>& seqRegexps + + // init the Nodes Regex map + std::map<std::string,NodeRegex*> nodesRegex ; + std::vector<std::string> nodeTypeKey {"A","B","C"}; + for (const std::string& key : nodeTypeKey) { + nodesRegex[key] = new NodeRegex(key); + } + + // init the Sequential Regex vector + std::vector<std::string> seqRegex; + seqRegex.push_back("A->B;"); + + // Instanciate a GRegex + GRegex GReg(nodesRegex, seqRegex); + + // Perform tests + REQUIRE(GReg.getStmInit().size() == 1); + REQUIRE(GReg.getStmFab().getNumberOfStm() == 1); +} + + +TEST_CASE("Function matchFromStartNodes | One Match of Nodes sequence", "[GRegex]") { + // init all input for GRegex + // Nodes Regex map : std::map<std::string,NodeRegex*>& nodesRegex + // Sequential Regex vector : std::vector<std::string>& seqRegexps + + // init the Nodes Regex map + std::map<std::string,NodeRegex*> nodesRegex ; + std::vector<std::string> nodeTypeKey {"Conv","BN","ReLU"}; + for (const std::string& key : nodeTypeKey) { + nodesRegex[key] = new NodeRegex(key); + } + // init the Sequential Regex vector + std::vector<std::string> seqRegex; + seqRegex.push_back("Conv->BN->ReLU;"); + + // Instanciate a GRegex + GRegex GReg(nodesRegex, seqRegex); + + std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); + std::shared_ptr<Node> Conv1 = GenericOperator("Conv", 1, 1, 1); + std::shared_ptr<Node> BN1 = GenericOperator("BN", 1, 1, 1); + std::shared_ptr<Node> ReLU1 = GenericOperator("ReLU", 1, 1, 1); + std::shared_ptr<Node> Random = GenericOperator("Random", 1, 1, 1); + std::shared_ptr<Node> Random2 = GenericOperator("Random2", 1, 1, 1); + + + g1->add(Conv1); + g1->addChild(BN1, Conv1); + g1->addChild(ReLU1, BN1); + g1->addChild(Random, ReLU1); + //g1->addChild(BN1, Random2); + + std::vector<std::shared_ptr<Node>> startNodes1; + std::set<std::shared_ptr<Node>> result; + + startNodes1.push_back(Conv1); + result = GReg.matchFromStartNodes(startNodes1, g1); + + std::set<std::shared_ptr<Node>> true_result; + true_result.insert(Conv1); + true_result.insert(BN1); + true_result.insert(ReLU1); + + // Perform tests + REQUIRE(result == true_result); +} + +TEST_CASE("Function matchFromStartNodes | One Match of parallel branches ", "[GRegex]") { + // init all input for GRegex + // Nodes Regex map : std::map<std::string,NodeRegex*>& nodesRegex + // Sequential Regex vector : std::vector<std::string>& seqRegexps + + // init the Nodes Regex map + std::map<std::string,NodeRegex*> nodesRegex ; + std::vector<std::string> nodeTypeKey {"Add","FC","Conv"}; + for (const std::string& key : nodeTypeKey) { + nodesRegex[key] = new NodeRegex(key); + } + + // init the Sequential Regex vector + std::vector<std::string> seqRegex; + seqRegex.push_back("Add#->Conv;"); + seqRegex.push_back("Add#->FC;"); + + // Instanciate a GRegex + GRegex GReg(nodesRegex, seqRegex); + + // Instanciate a graphView + std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); + std::shared_ptr<Node> Random0 = GenericOperator("Random", 1, 1, 1); + std::shared_ptr<Node> Add1 = GenericOperator("Add", 1, 1, 1); + std::shared_ptr<Node> Conv1 = GenericOperator("Conv", 1, 1, 1); + std::shared_ptr<Node> BN1 = GenericOperator("BN", 1, 1, 1); + std::shared_ptr<Node> ReLU1 = GenericOperator("ReLU", 1, 1, 1); + std::shared_ptr<Node> FC1 = GenericOperator("FC", 1, 1, 1); + std::shared_ptr<Node> Random = GenericOperator("Random", 1, 1, 1); + + g1->add(Random0); + g1->addChild(Add1, Random0); + g1->addChild(Conv1, Add1); + g1->addChild(BN1, Conv1); + g1->addChild(ReLU1, BN1); + g1->addChild(FC1, Add1); + g1->addChild(Random, FC1); + + // Test 1 : Find the match + std::vector<std::shared_ptr<Node>> startNodes; + std::set<std::shared_ptr<Node>> result; + + startNodes.push_back(Add1); + startNodes.push_back(Add1); + result = GReg.matchFromStartNodes(startNodes, g1); + + std::set<std::shared_ptr<Node>> true_result; + true_result.insert(Add1); + true_result.insert(Conv1); + true_result.insert(FC1); + + // Test 2 : Return an empty set when the start nodes are wrong + std::vector<std::shared_ptr<Node>> wrong_startNodes; + std::set<std::shared_ptr<Node>> wrong_start_result; + std::set<std::shared_ptr<Node>> empty_result; + + wrong_startNodes.push_back(Random0); + wrong_startNodes.push_back(Random0); + wrong_start_result = GReg.matchFromStartNodes(wrong_startNodes, g1); + + // Perform tests + REQUIRE(result == true_result); + REQUIRE(wrong_start_result == empty_result); +} + +/* +TEST_CASE("Function matchFromStartNodes | Match a sequence with quantifier ", "[GRegex]") { + // init all input for GRegex + // Nodes Regex map : std::map<std::string,NodeRegex*>& nodesRegex + // Sequential Regex vector : std::vector<std::string>& seqRegexps + + // init the Nodes Regex map + std::map<std::string,NodeRegex*> nodesRegex ; + std::vector<std::string> nodeTypeKey {"FC"}; + for (const std::string& key : nodeTypeKey) { + nodesRegex[key] = new NodeRegex(key); + } + + // init the Sequential Regex vector + std::vector<std::string> seqRegex; + seqRegex.push_back("FC+;"); + + // Instanciate a GRegex + GRegex GReg(nodesRegex, seqRegex); + + + // Instanciate a graphView + std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); + std::shared_ptr<Node> Random0 = GenericOperator("Random", 1, 1, 1); + std::shared_ptr<Node> FC1 = GenericOperator("FC", 1, 1, 1); + std::shared_ptr<Node> FC2 = GenericOperator("FC", 1, 1, 1); + std::shared_ptr<Node> FC3 = GenericOperator("FC", 1, 1, 1); + std::shared_ptr<Node> ReLU1 = GenericOperator("ReLU", 1, 1, 1); + + g1->add(Random0); + g1->addChild(FC1, Random0); + g1->addChild(FC2, FC1); + g1->addChild(FC3, FC2); + g1->addChild(ReLU1, FC3); + + // Test 1 : Find the match + std::vector<std::shared_ptr<Node>> startNodes; + std::set<std::shared_ptr<Node>> result; + + startNodes.push_back(FC1); + result = GReg.matchFromStartNodes(startNodes, g1); + + std::set<std::shared_ptr<Node>> true_result; + true_result.insert(FC1); + true_result.insert(FC2); + true_result.insert(FC3); + + // Test 2 : Return an empty set when the start nodes are wrong + std::vector<std::shared_ptr<Node>> wrong_startNodes; + std::set<std::shared_ptr<Node>> wrong_start_result; + std::set<std::shared_ptr<Node>> empty_result; + + wrong_startNodes.push_back(Random0); + wrong_start_result = GReg.matchFromStartNodes(wrong_startNodes, g1); + + // Perform tests + REQUIRE(result == true_result); + REQUIRE(wrong_start_result == empty_result); +} +*/ + +TEST_CASE("Function match | ALL matches of Nodes sequence", "[GRegex]") { + // init all input for GRegex + // Nodes Regex map : std::map<std::string,NodeRegex*>& nodesRegex + // Sequential Regex vector : std::vector<std::string>& seqRegexps + + // init the Nodes Regex map + std::map<std::string,NodeRegex*> nodesRegex ; + std::vector<std::string> nodeTypeKey {"GEMM"}; + for (const std::string& key : nodeTypeKey) { + nodesRegex[key] = new NodeRegex(key); + } + + // init the Sequential Regex vector + std::vector<std::string> seqRegex; + seqRegex.push_back("GEMM;"); + + // Instanciate a GRegex + GRegex GReg(nodesRegex, seqRegex); + + //init the input graph + std::shared_ptr<GraphView> graphToMatch = std::make_shared<GraphView>("TestGraph"); + std::shared_ptr<Node> Random0 = GenericOperator("Random", 1, 1, 1); + std::shared_ptr<Node> GEMM1 = GenericOperator("GEMM", 1, 1, 1); + std::shared_ptr<Node> ReLU1 = GenericOperator("ReLU", 1, 1, 1); + std::shared_ptr<Node> GEMM2 = GenericOperator("GEMM", 1, 1, 1); + std::shared_ptr<Node> GEMM3 = GenericOperator("GEMM", 1, 1, 1); + std::shared_ptr<Node> ReLU2 = GenericOperator("ReLU", 1, 1, 1); + std::shared_ptr<Node> Random = GenericOperator("Random", 1, 1, 1); + + graphToMatch->add(Random0); + graphToMatch->addChild(GEMM1, Random0); + graphToMatch->addChild(ReLU1, GEMM1); + graphToMatch->addChild(GEMM2, ReLU1); + graphToMatch->addChild(GEMM3, GEMM2); + graphToMatch->addChild(ReLU2, GEMM3); + graphToMatch->addChild(Random, ReLU2); + + + //std::vector<std::pair<std::vector<NodeTmp>,std::set<NodeTmp>>> matchs = GReg.match(graphToMatch); + //std::set<std::pair<std::vector<NodeTmp>,std::set<NodeTmp>>> matchs = GReg.match(graphToMatch); + Match matches = GReg.match(graphToMatch); + + size_t nb = matches.getNbMatch(); + std::vector<std::vector<NodeTmp>> gm_startnodes = matches.getStartNodes(); + std::vector<std::set<NodeTmp>> gm_matchnodes = matches.getMatchNodes(); + + std::set<std::pair<std::vector<NodeTmp>,std::set<NodeTmp>>> matchs; + + for (size_t i = 0; i < nb; ++i) { + matchs.insert(std::make_pair(gm_startnodes[i], gm_matchnodes[i])); + } + + //std::vector<std::pair<std::vector<NodeTmp>,std::set<NodeTmp>>> toMatchs ; + std::set<std::pair<std::vector<NodeTmp>,std::set<NodeTmp>>> toMatchs ; + // Carefull : as the assert is on a vector, the Order of match matters + std::vector<NodeTmp> startNode = {GEMM1}; + std::set<NodeTmp> matchNode = {GEMM1}; + //toMatchs.push_back(std::make_pair(startNode,matchNode)); + toMatchs.insert(std::make_pair(startNode,matchNode)); + + std::vector<NodeTmp> startNode2 = {GEMM2}; + std::set<NodeTmp> matchNode2 = {GEMM2}; + //toMatchs.push_back(std::make_pair(startNode2,matchNode2)); + toMatchs.insert(std::make_pair(startNode2,matchNode2)); + + std::vector<NodeTmp> startNode3 = {GEMM3}; + std::set<NodeTmp> matchNode3 = {GEMM3}; + //toMatchs.push_back(std::make_pair(startNode3,matchNode3)); + toMatchs.insert(std::make_pair(startNode3,matchNode3)); + + REQUIRE(matchs == toMatchs); + REQUIRE(nb == 3); +} + + diff --git a/aidge/_Core/tests/Test_GenericOperator.cpp b/aidge/_Core/tests/Test_GenericOperator.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9c78cd23df6cb224a44ca630525e88591f21022e --- /dev/null +++ b/aidge/_Core/tests/Test_GenericOperator.cpp @@ -0,0 +1,66 @@ +#include <catch2/catch_test_macros.hpp> + +#include "operator/GenericOperator.hpp" +#include "graph/GraphView.hpp" +#include <cstddef> + +using namespace Aidge; + +TEST_CASE("[core/operators] GenericOp(add & get parameters)", "[Operator]") { + SECTION("INT") { + GenericOperator_Op Testop("TestOp", 1, 1, 1); + int value = 5; + const char* key = "intParam"; + Testop.addParameter<int>(key, value); + REQUIRE(Testop.getParameter<int>(key) == value); + } + SECTION("LONG") { + GenericOperator_Op Testop("TestOp", 1, 1, 1); + long value = 3; + const char* key = "longParam"; + Testop.addParameter<long>(key, value); + REQUIRE(Testop.getParameter<long>(key) == value); + } + SECTION("FLOAT") { + GenericOperator_Op Testop("TestOp", 1, 1, 1); + float value = 2.0; + const char* key = "floatParam"; + Testop.addParameter<float>(key, value); + REQUIRE(Testop.getParameter<float>(key) == value); + } + SECTION("VECTOR<INT>") { + GenericOperator_Op Testop("TestOp", 1, 1, 1); + std::vector<int> value = {1, 2}; + const char* key = "vect"; + Testop.addParameter<std::vector<int>>(key, value); + + REQUIRE(Testop.getParameter<std::vector<int>>(key).size() == value.size()); + for (std::size_t i=0; i < value.size(); ++i){ + REQUIRE(Testop.getParameter<std::vector<int>>(key)[i] == value[i]); + } + } + SECTION("MULTIPLE PARAMS") { + /* + Goal : Test that the offsets are well done by adding different parameters with different size. + */ + GenericOperator_Op Testop("TestOp", 1, 1, 1); + Testop.addParameter<long>("longParam", 3); + Testop.addParameter<float>("floatParam", 2.0); + Testop.addParameter<uint8_t>("uint8Param", 5); + Testop.addParameter<long long>("llParam", 10); + REQUIRE(Testop.getParameter<long>("longParam") == 3); + REQUIRE(Testop.getParameter<float>("floatParam") == 2.0); + REQUIRE(Testop.getParameter<uint8_t>("uint8Param") == 5); + REQUIRE(Testop.getParameter<long long>("llParam") == 10); + } +} + +TEST_CASE("[core/operator] GenericOp(type check)", "[.ass]") { + SECTION("WRONG TYPE FOR GETTER") { + GenericOperator_Op Testop("TestOp", 1, 1, 1); + Testop.addParameter<long>("longParam", 3); + + // This line should raise a failled assert + REQUIRE_THROWS(Testop.getParameter<int>("longParameter")); + } +} diff --git a/aidge/_Core/tests/Test_GraphView.cpp b/aidge/_Core/tests/Test_GraphView.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3e2d1553867d3e6a326f0c99ac8fc88d1495b460 --- /dev/null +++ b/aidge/_Core/tests/Test_GraphView.cpp @@ -0,0 +1,311 @@ +#include <cassert> +#include <map> +#include <memory> +#include <string> + +#include <catch2/catch_test_macros.hpp> + +#include "backend/OperatorImpl.hpp" +#include "data/Tensor.hpp" +#include "graph/GraphView.hpp" +#include "operator/Conv.hpp" +#include "operator/GenericOperator.hpp" +#include "operator/Producer.hpp" + +using namespace Aidge; + +TEST_CASE("[core/graph] GraphView(Constructor)") { + std::shared_ptr<GraphView> g0 = std::make_shared<GraphView>(); + std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("G1"); + REQUIRE(g0 != nullptr); + REQUIRE(g1 != nullptr); +} + +TEST_CASE("[core/graph] GraphView(add)") { + SECTION("Node alone") { + std::shared_ptr<GraphView> g = std::make_shared<GraphView>("TestGraph"); + std::shared_ptr<Node> GOp1 = GenericOperator("Fictive", 0, 0, 0, "Gop1"); + g->add(GOp1); + std::shared_ptr<Node> GOp2 = GenericOperator("Fictive", 0, 0, 1, "Gop2"); + g->add(GOp2); + std::shared_ptr<Node> GOp3 = GenericOperator("Fictive", 1, 1, 0, "Gop3"); + g->add(GOp3); + std::shared_ptr<Node> GOp4 = GenericOperator("Fictive", 0, 1, 0, "Gop4"); + g->add(GOp4); + std::shared_ptr<Node> GOp5 = GenericOperator("Fictive", 1, 1, 1, "Gop5"); + g->add(GOp5); + std::shared_ptr<Node> GOp6 = GenericOperator("Fictive", 1, 2, 1, "Gop6"); + g->add(GOp6); + } + + SECTION("Several Nodes") { + std::shared_ptr<GraphView> g = std::make_shared<GraphView>("TestGraph"); + // should automaticaly add parents for learnable parameters + std::shared_ptr<Node> GOp1 = GenericOperator("Fictive", 0, 1, 1, "Gop1"); + std::shared_ptr<Node> GOp1parent = GenericOperator("Fictive", 0, 0, 1, "Gop1parent"); + GOp1parent->addChild(GOp1, 0, 0); + g->add(GOp1); + REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({GOp1, GOp1parent})); + + // there should be no deplicates + g->add(GOp1); + REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({GOp1, GOp1parent})); + } + + SECTION("Initializer list ofr Node") { + std::shared_ptr<GraphView> g = std::make_shared<GraphView>("TestGraph"); + std::shared_ptr<Node> GOp1 = GenericOperator("Fictive", 0, 0, 0, "Gop1"); + std::shared_ptr<Node> GOp2 = GenericOperator("Fictive", 0, 0, 0, "Gop2"); + g->add({GOp1, GOp1, GOp2}); + REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({GOp1, GOp2})); + } + + SECTION("another GraphView") { + std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph-1"); + std::shared_ptr<GraphView> g2 = std::make_shared<GraphView>("TestGraph-2"); + auto conv = GenericOperator("Conv", 1, 1, 1, "c"); + auto conv1 = GenericOperator("Conv", 1, 1, 1, "c1"); + auto conv2 = GenericOperator("Conv", 1, 1, 1, "c2"); + auto conv3 = GenericOperator("Conv", 1, 1, 1, "c3"); + auto conv4 = GenericOperator("Conv", 1, 1, 1, "c4"); + conv->addChild(conv1); + conv1->addChild(conv2); + conv2->addChild(conv3); + conv3->addChild(conv4); + g1->add({conv, conv1, conv2, conv3, conv4}); + g2->add(g1); + REQUIRE(((g1->getNodes() == g2->getNodes()) && (g2->getNodes() == std::set<std::shared_ptr<Node>>({conv, conv1, conv2, conv3, conv4})))); + REQUIRE(((g1->inputNodes() == g2->inputNodes()) && + (g2->inputNodes() == std::set<std::shared_ptr<Node>>({conv})))); + REQUIRE(((g1->outputNodes() == g2->outputNodes()) && + (g2->outputNodes() == std::set<std::shared_ptr<Node>>({conv4})))); + } +} + +TEST_CASE("[core/graph] GraphView(addChild)") { + std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); + std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 1, 1, "c"); + std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 1, 1, "c1"); + std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 1, 1, "c2"); + std::shared_ptr<Node> conv3 = GenericOperator("Conv", 1, 1, 1, "c3"); + std::shared_ptr<Node> conv3_5 = GenericOperator("Conv", 1, 1, 1, "c3.5"); + std::shared_ptr<Node> conv4 = GenericOperator("Conv", 1, 1, 1, "c4"); + std::shared_ptr<Node> conv5 = GenericOperator("Conv", 1, 1, 1, "c5"); + + g1->add(conv); + SECTION("add(node)") { + REQUIRE(g1->inputNodes() == std::set<std::shared_ptr<Node>>{conv}); + REQUIRE(g1->outputNodes() == std::set<std::shared_ptr<Node>>{conv}); + } + g1->addChild(conv1, "c"); + SECTION("add(node, outputNodeName)") { + REQUIRE(g1->inputNodes() == std::set<std::shared_ptr<Node>>{conv}); + REQUIRE(g1->outputNodes() == std::set<std::shared_ptr<Node>>{conv1}); + REQUIRE(conv->getChildren() == std::set<std::shared_ptr<Node>>({conv1})); + REQUIRE(conv1->getParents() == std::vector<std::shared_ptr<Node>>({conv})); + } + g1->addChild(conv2, "c1", 0); + SECTION("add(node, pair<outputNodeName, outID>)") { + REQUIRE(g1->inputNodes() == std::set<std::shared_ptr<Node>>{conv}); + REQUIRE(g1->outputNodes() == std::set<std::shared_ptr<Node>>{conv2}); + REQUIRE(conv1->getChildren() == std::set<std::shared_ptr<Node>>({conv2})); + REQUIRE(conv2->getParents() == std::vector<std::shared_ptr<Node>>({conv1})); + } + g1->addChild(conv3, "c2", 0, 0); + SECTION("add(node, list(outputNodeName))") { + REQUIRE(g1->inputNodes() == std::set<std::shared_ptr<Node>>{conv}); + REQUIRE(g1->outputNodes() == std::set<std::shared_ptr<Node>>{conv3}); + REQUIRE(conv2->getChildren() == std::set<std::shared_ptr<Node>>({conv3})); + REQUIRE(conv3->getParents() == std::vector<std::shared_ptr<Node>>({conv2})); + } + g1->addChild(conv3_5, conv3); + SECTION("add(node, list(outputNodeName))") { + REQUIRE(g1->inputNodes() == std::set<std::shared_ptr<Node>>{conv}); + REQUIRE(g1->outputNodes() == std::set<std::shared_ptr<Node>>{conv3_5}); + REQUIRE(conv3->getChildren() == std::set<std::shared_ptr<Node>>({conv3_5})); + REQUIRE(conv3_5->getParents() == + std::vector<std::shared_ptr<Node>>({conv3})); + } + g1->addChild(conv4, conv3_5, 0); + SECTION("add(node, vector<pair<outputNodeName, outID>>)") { + REQUIRE(g1->inputNodes() == std::set<std::shared_ptr<Node>>{conv}); + REQUIRE(g1->outputNodes() == std::set<std::shared_ptr<Node>>{conv4}); + REQUIRE(conv3_5->getChildren() == std::set<std::shared_ptr<Node>>({conv4})); + REQUIRE(conv4->getParents() == + std::vector<std::shared_ptr<Node>>({conv3_5})); + } + g1->addChild(conv5, conv4, 0, 0); + SECTION("add(node, vector<pair<outputNodeName, outID>>)") { + REQUIRE(g1->inputNodes() == std::set<std::shared_ptr<Node>>{conv}); + REQUIRE(g1->outputNodes() == std::set<std::shared_ptr<Node>>{conv5}); + REQUIRE(conv4->getChildren() == std::set<std::shared_ptr<Node>>({conv5})); + REQUIRE(conv5->getParents() == std::vector<std::shared_ptr<Node>>({conv4})); + } + std::set<std::shared_ptr<Node>> requiredNodes = {conv, conv1, conv2, conv3, + conv3_5, conv4, conv5}; + REQUIRE(g1->getNodes() == requiredNodes); + REQUIRE(g1->getChildren(conv3) == std::set<std::shared_ptr<Node>>({conv3_5})); +} + +TEST_CASE("[core/graph] GraphView(inputs)") { + auto g1 = std::make_shared<GraphView>("TestGraph"); + std::shared_ptr<Node> conv = Conv(3, 32, {3, 3}); + g1->add(conv); + + REQUIRE(g1->inputs() == conv->inputs()); +} + +TEST_CASE("[core/graph] GraphView(outputs)") { + std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); + std::shared_ptr<Node> conv = Conv(3, 32, {3, 3}); + g1->add(conv); + + REQUIRE(g1->outputs() == conv->outputs()); +} + +TEST_CASE("[core/graph] GraphView(save)") { + std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); + std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 1, 1, "c"); + std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 1, 1, "c1"); + std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 1, 1, "c2"); + std::shared_ptr<Node> conv3 = GenericOperator("Conv", 1, 1, 1, "c3"); + std::shared_ptr<Node> conv4 = GenericOperator("Conv", 1, 1, 1, "c4"); + std::shared_ptr<Node> conv5 = GenericOperator("Conv", 1, 1, 1, "c5"); + + g1->add(conv); + g1->addChild(conv1, "c"); + g1->addChild(conv2, "c1", 0); + g1->addChild(conv3, "c2"); + g1->addChild(conv4, "c3", 0); + g1->addChild(conv5, "c4", 0, 0); + + g1->save("./graphExample"); + printf("File saved in ./graphExample.md\n"); +} + +TEST_CASE("[core/graph] GraphView(resetConnections)") { + SECTION("disconnect data iput") { + std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 1, 1, "c"); + std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 3, 1, "c1"); + std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 1, 1, "c2"); + std::shared_ptr<Node> prod1 = GenericOperator("Prod", 0, 0, 1, "p1"); + std::shared_ptr<Node> prod2 = GenericOperator("Prod", 0, 0, 1, "p2"); + conv->addChild(conv1); + prod1->addChild(conv1,0,1); + prod2->addChild(conv1,0,2); + conv1->addChild(conv2); + + conv1->resetConnections(false); + + REQUIRE(conv->output(0).size() == 0); + for (std::size_t i = 0; i < conv1->nbDataInputs(); ++i) { + REQUIRE((conv1->input(i) == std::pair<std::shared_ptr<Node>, IOIndex_t>(nullptr, gk_IODefaultIndex))); + } + REQUIRE((conv1->input(1) == std::pair<std::shared_ptr<Node>, IOIndex_t>(prod1, 0))); + REQUIRE((conv1->input(2) == std::pair<std::shared_ptr<Node>, IOIndex_t>(prod2, 0))); + REQUIRE((conv2->input(0) == std::pair<std::shared_ptr<Node>, IOIndex_t>(nullptr, gk_IODefaultIndex))); + for (std::size_t i = 0; i < conv1->nbOutputs(); ++i) { + REQUIRE(conv->output(i).size() == 0U); + } + } + + SECTION("disconnect data iput + learnable parameters") { + std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 1, 1, "c"); + std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 3, 1, "c1"); + std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 1, 1, "c2"); + std::shared_ptr<Node> prod1 = GenericOperator("Prod", 0, 0, 1, "p1"); + std::shared_ptr<Node> prod2 = GenericOperator("Prod", 0, 0, 1, "p2"); + conv->addChild(conv1); + prod1->addChild(conv1,0,1); + prod2->addChild(conv1,0,2); + conv1->addChild(conv2); + + conv1->resetConnections(true); + + REQUIRE(conv->output(0).size() == 0); + for (std::size_t i = 0; i < conv1->nbInputs(); ++i) { + REQUIRE((conv1->input(i) == std::pair<std::shared_ptr<Node>, IOIndex_t>(nullptr, gk_IODefaultIndex))); + } + REQUIRE((conv2->input(0) == std::pair<std::shared_ptr<Node>, IOIndex_t>(nullptr, gk_IODefaultIndex))); + for (std::size_t i = 0; i < conv1->nbOutputs(); ++i) { + REQUIRE(conv->output(i).size() == 0U); + } + } +} + +TEST_CASE("Graph Forward dims", "[GraphView]") { + auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider"); + auto conv1 = Conv(3, 32, {3, 3}, "conv1"); + auto conv2 = Conv(32, 64, {3, 3}, "conv2"); + auto conv3 = Conv(64, 10, {1, 1}, "conv3"); + auto g = std::make_shared<GraphView>("TestGraph"); + dataProvider->addChild(conv1, 0); + g->add(conv1); + g->addChild(conv2, conv1, 0); + g->addChild(conv3, conv2, 0); + g->save("graphForwardDims"); + g->forwardDims(); + + SECTION("Check input-output connections") { + REQUIRE(dataProvider->getOperator()->getOutput(0) == + conv1->getOperator()->getInput(0)); + REQUIRE(conv1->getOperator()->getInput(1) == + g->getNode("conv1_w")->getOperator()->getOutput(0)); + REQUIRE(conv1->getOperator()->getInput(2) == + g->getNode("conv1_b")->getOperator()->getOutput(0)); + REQUIRE(conv1->getOperator()->getOutput(0) == + conv2->getOperator()->getInput(0)); + REQUIRE(conv2->getOperator()->getInput(1) == + g->getNode("conv2_w")->getOperator()->getOutput(0)); + REQUIRE(conv2->getOperator()->getInput(2) == + g->getNode("conv2_b")->getOperator()->getOutput(0)); + REQUIRE(conv2->getOperator()->getOutput(0) == + conv3->getOperator()->getInput(0)); + REQUIRE(conv3->getOperator()->getInput(1) == + g->getNode("conv3_w")->getOperator()->getOutput(0)); + REQUIRE(conv3->getOperator()->getInput(2) == + g->getNode("conv3_b")->getOperator()->getOutput(0)); + } + + SECTION("Check forwarded dims") { + REQUIRE(std::static_pointer_cast<Tensor>(conv1->getOperator()->getOutput(0)) + ->dims() == std::vector<DimSize_t>({16, 32, 222, 222})); + REQUIRE(std::static_pointer_cast<Tensor>(conv2->getOperator()->getOutput(0)) + ->dims() == std::vector<DimSize_t>({16, 64, 220, 220})); + } +} + +TEST_CASE("[core/graph] GraphView(replaceWith)") { + // create original graph + std::shared_ptr<GraphView> g = std::make_shared<GraphView>("TestGraph"); + auto otherInput = GenericOperator("Producer", 0, 0, 1, "other_input"); + auto matmulWeight = GenericOperator("Producer", 0, 0, 1, "matmul_w"); + auto addBias = GenericOperator("Producer", 0, 0, 1, "add_b"); + auto other1 = GenericOperator("Other", 1, 1, 1, "other1"); + auto other2 = GenericOperator("Other", 1, 1, 1, "other2"); + auto matmul = GenericOperator("MatMul", 1, 2, 1, "matmul"); + auto add = GenericOperator("Add", 1, 2, 1, "add"); + otherInput->addChild(other1); + other1->addChild(matmul); + matmul->addChild(add); + add->addChild(other2); + matmulWeight->addChild(matmul, 0, 1); + addBias->addChild(add, 0, 1); + g->add({other1, matmul, add, other2}); + REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({matmulWeight, addBias, other1, other2, matmul, add})); + + // create graph to replace + std::shared_ptr<GraphView> nodeToReplace = std::make_shared<GraphView>(); + nodeToReplace->add({matmul, add}, false); + + // create replacing graph + std::shared_ptr<Node> newNode = GenericOperator("FC", 1, 3, 1, "fc"); + other1->addChild(newNode); + matmulWeight->addChild(newNode, 0, 1); + addBias->addChild(newNode, 0, 2); + + // replace + nodeToReplace->replaceWith({newNode}); + + REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({matmulWeight, addBias, other1, other2, newNode})); +} \ No newline at end of file diff --git a/aidge/_Core/tests/Test_NodeRegex.cpp b/aidge/_Core/tests/Test_NodeRegex.cpp new file mode 100644 index 0000000000000000000000000000000000000000..72bcc08f79a899090799400d4b2fc23f25aa41d4 --- /dev/null +++ b/aidge/_Core/tests/Test_NodeRegex.cpp @@ -0,0 +1,33 @@ +#include <iostream> +#include <map> +#include <memory> +#include <cassert> + +#include <catch2/catch_test_macros.hpp> + +#include "backend/OperatorImpl.hpp" +#include "graphmatching/NodeRegex.hpp" +#include "operator/GenericOperator.hpp" + + +using namespace Aidge; + +TEST_CASE("Create Noderegex", "[Noderegex]") { + std::shared_ptr<NodeRegex> nr = std::make_shared<NodeRegex>("conv"); +} + +TEST_CASE("Test _is function", "[Noderegex]") { + // Create Noderegex with only condition on the name of the Node + // Create several operators to pass into Noderegex _is function + // Assert Noderegex._is(operators) are correct + std::shared_ptr<NodeRegex> nr = std::make_shared<NodeRegex>("Conv"); + + std::shared_ptr<Node> Conv = GenericOperator("Conv", 1, 1, 1); + std::shared_ptr<Node> FC = GenericOperator("FC", 1, 1, 1); + + REQUIRE(nr->_is(Conv) == true); + REQUIRE(nr->_is(FC) == false); + REQUIRE(nr->isA("Conv") == true); + REQUIRE(nr->isA("FC") == false); + +} \ No newline at end of file diff --git a/aidge/_Core/tests/Test_SeqStm.cpp b/aidge/_Core/tests/Test_SeqStm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9517a6b5eb92f7cf1d2f77937cd5390850052d21 --- /dev/null +++ b/aidge/_Core/tests/Test_SeqStm.cpp @@ -0,0 +1,148 @@ +#include <iostream> +#include <map> +#include <memory> +#include <vector> +#include <utility> +#include <cassert> + +#include <catch2/catch_test_macros.hpp> +//test +#include "graphmatching/SeqStm.hpp" +#include "graphmatching/NodeRegex.hpp" +//use +#include "backend/OperatorImpl.hpp" +#include "operator/GenericOperator.hpp" +#include "operator/Producer.hpp" + +using namespace Aidge; + +TEST_CASE("Create good init SeqStm", "[SeqStm]") { + //init all iniput for SeqStm + + + int stmIdx = 0; + //matrix that in B->C + std::vector<std::vector<int>> transitionMatrix { + { -1, 1, -1 }, + { -1, -1, 2 }, + { -1, -1, -1 } }; + + //std::cout << transitionMatrix.size() << "\n"; + // init the nodes Regex map + std::map<std::string,NodeRegex*> nodesRegex ; + std::vector<std::string> nodeTypeKey {"A","B","C"}; + for (const std::string& key : nodeTypeKey) { + nodesRegex[key] = new NodeRegex(key); + } + // + + std::map<NodeTypeKey,int> typeToIdxTransition; + std::vector<NodeTypeKey> nodeTypeCommonTag {{"A",""},{"B",""},{"C",""}}; + //init nodeTypeCommonTag + int idx = 0; + for (const NodeTypeKey& key : nodeTypeCommonTag) { + typeToIdxTransition[key] = idx; + idx += 1; + } + + int actSt = 0; + std::set<NodeTmp> allNodeValidated; + std::set<NodeTmp> allNodeTested; + std::set<std::pair<NodeTmp,std::string>> allCommonNode; + bool stmIsValid =false; + + + SeqStm stm( + stmIdx, + transitionMatrix, + nodesRegex, + typeToIdxTransition, + actSt, + allNodeValidated, + allNodeTested, + allCommonNode, + stmIsValid); + + REQUIRE(stm.getStmIdx() == 0); + REQUIRE(stm.isValid() == false); + REQUIRE(stm.getAllCommonNode().size() == 0); + REQUIRE(stm.getAllNodeTested().size() == 0); + REQUIRE(stm.getAllNodeValidated().size() == 0); +} + +TEST_CASE("Test testNode function", "[SeqStm]") { + + int stmIdx = 0; + std::map<NodeTypeKey,int> typeToIdxTransition; + std::vector<NodeTypeKey> nodeTypeCommonTag {{"A",""},{"B",""},{"C",""}}; + //init nodeTypeCommonTag + int idx = 0; + for (const NodeTypeKey& key : nodeTypeCommonTag) { + typeToIdxTransition[key] = idx; + idx += 1; + } + //matrix that in B->C + std::vector<std::vector<int>> transitionMatrix { + { -1, 1, -1 }, + { -1, -1, 2 }, + { -1, -1, -1 } }; + + //std::cout << transitionMatrix.size() << "\n"; + // init the nodes Regex map + std::map<std::string,NodeRegex*> nodesRegex ; + std::vector<std::string> nodeTypeKey {"A","B","C"}; + for (const std::string& key : nodeTypeKey) { + nodesRegex[key] = new NodeRegex(key); + } + // + int actSt = 0; + std::set<NodeTmp> allNodeValidated; + std::set<NodeTmp> allNodeTested; + std::set<std::pair<NodeTmp,std::string>> allCommonNode; + bool stmIsValid =false; + + SeqStm stm( + stmIdx, + transitionMatrix, + nodesRegex, + typeToIdxTransition, + actSt, + allNodeValidated, + allNodeTested, + allCommonNode, + stmIsValid); + REQUIRE(stm.getStmIdx() == 0); + //test a node + std::shared_ptr<Node> nodeB = GenericOperator("B", 1, 1, 1); + std::shared_ptr<Node> nodeC = GenericOperator("C", 1, 1, 1); + + //set use to test the state of the smt + std::set<NodeTmp> testAllNodeTested; + std::set<NodeTmp> testAllNodeValidated; + + stm.testNode(nodeB); + REQUIRE(stm.isValid() == false); + REQUIRE(stm.getState() == 1); + REQUIRE(stm.isStmBlocked() == false); + testAllNodeTested.insert(nodeB); + testAllNodeValidated.insert(nodeB); + REQUIRE(stm.getAllNodeTested() == testAllNodeTested); + REQUIRE(stm.getAllNodeValidated() == testAllNodeValidated); + + + stm.testNode(nodeC); + REQUIRE(stm.isValid() == true); + REQUIRE(stm.getState() == 2); + REQUIRE(stm.isStmBlocked() == false); + testAllNodeTested.insert(nodeC); + testAllNodeValidated.insert(nodeC); + REQUIRE(stm.getAllNodeTested() == testAllNodeTested); + REQUIRE(stm.getAllNodeValidated() == testAllNodeValidated); + + stm.testNode(nodeC); + REQUIRE(stm.isValid() == true); + REQUIRE(stm.getState() == -1); + REQUIRE(stm.isStmBlocked() == true); + REQUIRE(stm.getAllNodeTested() == testAllNodeTested); + REQUIRE(stm.getAllNodeValidated() == testAllNodeValidated); +} \ No newline at end of file diff --git a/aidge/_Core/tests/Test_StmFactory.cpp b/aidge/_Core/tests/Test_StmFactory.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ca17978053996f7e004bbd12a3db7f4261b8487a --- /dev/null +++ b/aidge/_Core/tests/Test_StmFactory.cpp @@ -0,0 +1,178 @@ +#include <iostream> +#include <map> +#include <memory> +#include <vector> +#include <utility> +#include <cassert> + +#include <catch2/catch_test_macros.hpp> +//test +#include "graphmatching/StmFactory.hpp" +#include "graphmatching/NodeRegex.hpp" +//use +#include "backend/OperatorImpl.hpp" +#include "operator/GenericOperator.hpp" +#include "operator/Producer.hpp" + +using namespace Aidge; + +TEST_CASE("Create good init StmFactory", "[StmFactory]") { + // init the nodes Regex map + std::map<std::string,NodeRegex*> nodesRegex ; + std::vector<std::string> nodeTypeKey {"A","B","C"}; + for (const std::string& key : nodeTypeKey) { + nodesRegex[key] = new NodeRegex(key); + } + StmFactory stmF(nodesRegex); + REQUIRE(stmF.getNumberOfStm() == 0); +} + +TEST_CASE("Test in makeNewStm the getStmIdx StmFactory", "[SeqStm]") { + + std::map<std::string,NodeRegex*> nodesRegex ; + std::vector<std::string> nodeTypeKey {"A","B","C"}; + for (const std::string& key : nodeTypeKey) { + nodesRegex[key] = new NodeRegex(key); + } + + StmFactory stmF(nodesRegex); + + std::string seq1 = "A->B+->A#;"; + SeqStm* stm = stmF.makeNewStm(seq1); + REQUIRE(stm->getStmIdx() == 0); + REQUIRE(stm->isValid() == false); + REQUIRE(stm->getAllCommonNode().size() == 0); + REQUIRE(stm->getAllNodeTested().size() == 0); + REQUIRE(stm->getAllNodeValidated().size() == 0); + + std::string seq2 = "A->B;"; + SeqStm* stm2 = stmF.makeNewStm(seq2); + REQUIRE(stm2->getStmIdx() == 1); + REQUIRE(stm2->isValid() == false); + REQUIRE(stm2->getAllCommonNode().size() == 0); + REQUIRE(stm2->getAllNodeTested().size() == 0); + REQUIRE(stm2->getAllNodeValidated().size() == 0); + + //test the number of stm + REQUIRE(stmF.getNumberOfStm() == 2); +} + +TEST_CASE("Test in makeNewStm the stm StmFactory", "[SeqStm]") { + + std::map<std::string,NodeRegex*> nodesRegex ; + std::vector<std::string> nodeTypeKey {"A","B","C"}; + for (const std::string& key : nodeTypeKey) { + nodesRegex[key] = new NodeRegex(key); + } + + + StmFactory stmF(nodesRegex); + std::string seq1 = "B->C;"; + SeqStm* stm = stmF.makeNewStm(seq1); + //test the number of stm + REQUIRE(stmF.getNumberOfStm() == 1); + + //std::shared_ptr<Node> nodeB = GenericOperator("B",1,1,1); + //std::shared_ptr<Node> nodeC = GenericiOperator("C",1,1,1); + std::shared_ptr<Node> nodeB = GenericOperator("B", 1, 1, 1); + std::shared_ptr<Node> nodeC = GenericOperator("C", 1, 1, 1); + //set use to test the state of the smt + std::set<NodeTmp> testAllNodeTested; + std::set<NodeTmp> testAllNodeValidated; + + REQUIRE(stm->isValid() == false); + REQUIRE(stm->getState() == 0); + REQUIRE(stm->isStmBlocked() == false); + REQUIRE(stm->getAllNodeTested() == testAllNodeTested); + REQUIRE(stm->getAllNodeValidated() == testAllNodeValidated); + + stm->testNode(nodeB); + REQUIRE(stm->isValid() == false); + REQUIRE(stm->getState() == 1); + REQUIRE(stm->isStmBlocked() == false); + testAllNodeTested.insert(nodeB); + testAllNodeValidated.insert(nodeB); + REQUIRE(stm->getAllNodeTested() == testAllNodeTested); + REQUIRE(stm->getAllNodeValidated() == testAllNodeValidated); + + + stm->testNode(nodeC); + REQUIRE(stm->isValid() == true); + REQUIRE(stm->getState() == 2); + REQUIRE(stm->isStmBlocked() == false); + testAllNodeTested.insert(nodeC); + testAllNodeValidated.insert(nodeC); + REQUIRE(stm->getAllNodeTested() == testAllNodeTested); + REQUIRE(stm->getAllNodeValidated() == testAllNodeValidated); + + stm->testNode(nodeC); + REQUIRE(stm->isValid() == true); + REQUIRE(stm->getState() == -1); + REQUIRE(stm->isStmBlocked() == true); + REQUIRE(stm->getAllNodeTested() == testAllNodeTested); + REQUIRE(stm->getAllNodeValidated() == testAllNodeValidated); + + +} + +TEST_CASE("Test in duplicateStm StmFactory", "[SeqStm]") { + + std::map<std::string,NodeRegex*> nodesRegex ; + std::vector<std::string> nodeTypeKey {"A","B","C"}; + for (const std::string& key : nodeTypeKey) { + nodesRegex[key] = new NodeRegex(key); + } + + + StmFactory stmF(nodesRegex); + std::string seq1 = "B->C;"; + SeqStm* stm = stmF.makeNewStm(seq1); + SeqStm* stmD = stmF.duplicateStm(stm); + + std::shared_ptr<Node> nodeB = GenericOperator("B", 1, 1, 1); + std::shared_ptr<Node> nodeC = GenericOperator("C", 1, 1, 1); + //set use to test the state of the smt + std::set<NodeTmp> testAllNodeTested; + std::set<NodeTmp> testAllNodeValidated; + + //run the stm + REQUIRE(stm->isValid() == false); + REQUIRE(stm->getState() == 0); + REQUIRE(stm->isStmBlocked() == false); + REQUIRE(stm->getAllNodeTested() == testAllNodeTested); + REQUIRE(stm->getAllNodeValidated() == testAllNodeValidated); + + stm->testNode(nodeB); + REQUIRE(stm->isValid() == false); + REQUIRE(stm->getState() == 1); + REQUIRE(stm->isStmBlocked() == false); + testAllNodeTested.insert(nodeB); + testAllNodeValidated.insert(nodeB); + REQUIRE(stm->getAllNodeTested() == testAllNodeTested); + REQUIRE(stm->getAllNodeValidated() == testAllNodeValidated); + + + stm->testNode(nodeC); + REQUIRE(stm->isValid() == true); + REQUIRE(stm->getState() == 2); + REQUIRE(stm->isStmBlocked() == false); + testAllNodeTested.insert(nodeC); + testAllNodeValidated.insert(nodeC); + REQUIRE(stm->getAllNodeTested() == testAllNodeTested); + REQUIRE(stm->getAllNodeValidated() == testAllNodeValidated); + + stm->testNode(nodeC); + REQUIRE(stm->isValid() == true); + REQUIRE(stm->getState() == -1); + REQUIRE(stm->isStmBlocked() == true); + REQUIRE(stm->getAllNodeTested() == testAllNodeTested); + REQUIRE(stm->getAllNodeValidated() == testAllNodeValidated); + + //check if stmD not move + REQUIRE(stmD->isValid() == false); + REQUIRE(stmD->getState() == 0); + REQUIRE(stmD->isStmBlocked() == false); + REQUIRE(stmD->getAllNodeTested().size() == 0); + REQUIRE(stmD->getAllNodeValidated().size() == 0); +} +