From 64a3132bbe1f4c71a8f06a59fbf3a4dd12e77098 Mon Sep 17 00:00:00 2001 From: vl241552 <vincent.lorrain@cea.fr> Date: Tue, 14 Nov 2023 13:57:43 +0000 Subject: [PATCH] add CI --- unit_tests/graphRegex/Test_GraphRegex.cpp | 67 +++++++++++++++++++++-- 1 file changed, 63 insertions(+), 4 deletions(-) diff --git a/unit_tests/graphRegex/Test_GraphRegex.cpp b/unit_tests/graphRegex/Test_GraphRegex.cpp index b30560ea3..c9feded0b 100644 --- a/unit_tests/graphRegex/Test_GraphRegex.cpp +++ b/unit_tests/graphRegex/Test_GraphRegex.cpp @@ -2,6 +2,15 @@ #include <catch2/catch_test_macros.hpp> #include "aidge/graphRegex/GraphRegex.hpp" + +#include "aidge/data/Tensor.hpp" +#include "aidge/graph/GraphView.hpp" +#include "aidge/operator/Add.hpp" +#include "aidge/operator/FC.hpp" +#include "aidge/operator/MatMul.hpp" +#include "aidge/operator/Producer.hpp" +#include "aidge/utils/Recipies.hpp" + #include "aidge/operator/Conv.hpp" #include "aidge/operator/GenericOperator.hpp" @@ -47,12 +56,8 @@ TEST_CASE("GraphRegexUser") { } SECTION("CC") { - - - std::shared_ptr<GraphRegex> sut = std::make_shared<GraphRegex>(); - std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 1, 1, "c"); std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 1, 1, "c1"); @@ -81,4 +86,58 @@ TEST_CASE("GraphRegexUser") { } } + + + SECTION("Applied Recipes"){ + + // generate the original GraphView + auto matmul0 = MatMul(5, "matmul0"); + auto add0 = Add<2>("add0"); + auto matmul1 = MatMul(5, "matmul1"); + auto add1 = Add<2>("add1"); + + auto b0 = Producer({5}, "B0"); + auto w0 = Producer({5, 5}, "W0"); + auto b1 = Producer({5}, "B1"); + auto w1 = Producer({5,5},"W1"); + auto input = Producer({2,5}, "input"); + + input->addChild(matmul0, 0, 0); + w0->addChild(matmul0, 0, 1); + + matmul0->addChild(add0, 0, 0); + b0->addChild(add0, 0, 1); + + add0->addChild(matmul1, 0, 0); + w1->addChild(matmul1, 0, 1); + + matmul1->addChild(add1, 0, 0); + b1->addChild(add1, 0, 1); + + auto fc = GenericOperator("FC", 1, 1, 1, "c"); + auto fl = GenericOperator("Flatten", 1, 1, 1, "c"); + + + auto g = std::make_shared<GraphView>(); + g->add({matmul0, add0, matmul1, add1, b0, b1,fl,fc}); + + std::shared_ptr<GraphRegex> kitchenBook = std::make_shared<GraphRegex>(); + + kitchenBook->setNodeKey("Add","getType($) =='Add'"); + kitchenBook->setNodeKey("MatMul","getType($) =='MatMul'"); + kitchenBook->setNodeKey("Flatten","getType($) =='Flatten'"); + kitchenBook->setNodeKey("FC","getType($) =='FC'"); + + kitchenBook->addQuery("MatMul->Add",static_cast<void(*)(std::shared_ptr<MatchSolution>)>(fuseMulAdd)); + kitchenBook->addQuery("Flatten->FC",static_cast<void(*)(std::shared_ptr<MatchSolution>)>(removeFlatten)); + + kitchenBook->appliedRecipes(g); + + std::set<std::shared_ptr<Node>> newNodes = g->getNodes(); + REQUIRE(newNodes != std::set<std::shared_ptr<Node>>({w0, matmul0, b0, add0, w1, matmul1, b1, add1,fc})); + //REQUIRE(newNodes.size() == 6); + + + } + } \ No newline at end of file -- GitLab