#include <catch2/catch_test_macros.hpp> #include "aidge/graphRegex/GraphRegex.hpp" #include "aidge/data/Tensor.hpp" #include "aidge/graph/GraphView.hpp" #include "aidge/operator/Add.hpp" #include "aidge/operator/FC.hpp" #include "aidge/operator/MatMul.hpp" #include "aidge/operator/Producer.hpp" #include "aidge/recipies/Recipies.hpp" #include "aidge/operator/Conv.hpp" #include "aidge/operator/GenericOperator.hpp" using namespace Aidge; TEST_CASE("GraphRegexUser") { SECTION("INIT") { const std::string query = "Conv->FC"; std::shared_ptr<GraphRegex> sut = std::make_shared<GraphRegex>(); std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 0, 1, "c"); std::shared_ptr<Node> fc = GenericOperator("FC", 1, 0, 1, "c1"); std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 0, 1, "c2"); std::shared_ptr<Node> fc2 = GenericOperator("FC", 1, 0, 1, "c3"); g1->add(conv); g1->addChild(fc, "c"); g1->addChild(conv2, "c1"); g1->addChild(fc2, "c2"); sut->setKeyFromGraph(g1); sut->addQuery(query); for (const auto& solution : sut->match(g1)) { REQUIRE(solution->getQuery() == query); if(solution->getStartNode() == std::vector<NodePtr>{conv}){ REQUIRE(solution->at("Conv") == std::set<NodePtr>{conv} ); REQUIRE(solution->at("FC") == std::set<NodePtr>{fc} ); }else if (solution->getStartNode() == std::vector<NodePtr>{conv2}) { REQUIRE(solution->at("Conv") == std::set<NodePtr>{conv2} ); REQUIRE(solution->at("FC") == std::set<NodePtr>{fc2} ); } } //REQUIRE( sut->match(g1)[1]->getAll() == std::set<NodePtr>{conv,fc}); } SECTION("2 query") { std::shared_ptr<GraphRegex> sut = std::make_shared<GraphRegex>(); std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 0, 1, "c"); std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 0, 1, "c1"); std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 0, 1, "c2"); std::shared_ptr<Node> conv3 = GenericOperator("Conv", 1, 0, 1, "c3"); g1->add(conv); g1->addChild(conv1, "c"); g1->addChild(conv2, "c1"); g1->addChild(conv3, "c2"); sut->setKeyFromGraph(g1); const std::string query = "Conv->Conv"; const std::string query2 = "Conv->FC"; sut->setNodeKey("FC","getType($) =='FC'"); sut->addQuery(query); sut->addQuery(query2); for (const auto& solution : sut->match(g1)) { REQUIRE(solution->getQuery() == query); } } SECTION("Not define node Test") { //test if the FC is not define only match query not query2 std::shared_ptr<GraphRegex> sut = std::make_shared<GraphRegex>(); std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 0, 1, "c"); std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 0, 1, "c1"); std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 0, 1, "c2"); std::shared_ptr<Node> conv3 = GenericOperator("FC", 1, 0, 1, "c3"); g1->add(conv); g1->addChild(conv1, "c"); g1->addChild(conv2, "c1"); g1->addChild(conv3, "c2"); //sut->setKeyFromGraph(g1); const std::string query = "Conv->Conv"; const std::string query2 = "Conv->FC"; sut->setNodeKey("Conv","getType($) =='Conv'"); sut->addQuery(query); sut->addQuery(query2); for (const auto& solution : sut->match(g1)) { REQUIRE(solution->getQuery() == query); } } SECTION("Applied Recipes"){ // generate the original GraphView auto matmul0 = MatMul("matmul0"); auto add0 = Add(2, "add0"); auto matmul1 = MatMul("matmul1"); auto add1 = Add(2, "add1"); auto b0 = Producer({5}, "B0"); auto w0 = Producer({5, 5}, "W0"); auto b1 = Producer({5}, "B1"); auto w1 = Producer({5,5},"W1"); auto input = Producer({2,5}, "input"); input->addChild(matmul0, 0, 0); w0->addChild(matmul0, 0, 1); matmul0->addChild(add0, 0, 0); b0->addChild(add0, 0, 1); add0->addChild(matmul1, 0, 0); w1->addChild(matmul1, 0, 1); matmul1->addChild(add1, 0, 0); b1->addChild(add1, 0, 1); auto fc = GenericOperator("FC", 1, 0, 1, "c"); auto fl = GenericOperator("Flatten", 1, 0, 1, "c"); auto g = std::make_shared<GraphView>(); g->add({w0, matmul0, b0, add0, w1, matmul1, b1, add1,fl,fc}); std::shared_ptr<GraphRegex> kitchenBook = std::make_shared<GraphRegex>(); kitchenBook->setNodeKey("Add","getType($) =='Add'"); kitchenBook->setNodeKey("MatMul","getType($) =='MatMul'"); kitchenBook->setNodeKey("Flatten","getType($) =='Flatten'"); kitchenBook->setNodeKey("FC","getType($) =='FC'"); kitchenBook->addQuery("MatMul->Add",static_cast<void(*)(std::shared_ptr<MatchSolution>)>(fuseMulAdd)); kitchenBook->addQuery("Flatten->FC",static_cast<void(*)(std::shared_ptr<MatchSolution>)>(removeFlatten)); kitchenBook->appliedRecipes(g); std::set<std::shared_ptr<Node>> newNodes = g->getNodes(); REQUIRE(newNodes != std::set<std::shared_ptr<Node>>({w0, matmul0, b0, add0, w1, matmul1, b1, add1,fc})); //REQUIRE(newNodes.size() == 6); } }