From 7f30404552de7a4e9be9015efaf4807094a093a9 Mon Sep 17 00:00:00 2001 From: cmoineau <cyril.moineau@cea.fr> Date: Mon, 2 Sep 2024 08:16:50 +0000 Subject: [PATCH] Update remove flatten tests. --- unit_tests/graphRegex/Test_GraphRegex.cpp | 13 ++++++------- unit_tests/recipes/Test_removeFlatten.cpp | 12 ++++++------ 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/unit_tests/graphRegex/Test_GraphRegex.cpp b/unit_tests/graphRegex/Test_GraphRegex.cpp index 7e4027139..68ac509e7 100644 --- a/unit_tests/graphRegex/Test_GraphRegex.cpp +++ b/unit_tests/graphRegex/Test_GraphRegex.cpp @@ -175,18 +175,17 @@ TEST_CASE("GraphRegexUser") { matmul1->addChild(add1, 0, 0); b1->addChild(add1, 0, 1); - auto fc = GenericOperator("FC", 1, 0, 1, "c"); - auto fl = GenericOperator("Flatten", 1, 0, 1, "c"); - - + auto fc = GenericOperator("FC", 1, 0, 1, "fc1"); + auto fl = GenericOperator("Flatten", 1, 0, 1, "flatten0"); + add1->addChild(fl, 0, 0); + fl->addChild(fc, 0, 0); auto g = std::make_shared<GraphView>(); - g->add({w0, matmul0, b0, add0, w1, matmul1, b1, add1,fl,fc}); + g->add({w0, matmul0, b0, add0, w1, matmul1, b1, add1, fl, fc}); matMulToFC(g); removeFlatten(g); - std::set<std::shared_ptr<Node>> newNodes = g->getNodes(); - REQUIRE(newNodes != std::set<std::shared_ptr<Node>>({w0, matmul0, b0, add0, w1, matmul1, b1, add1,fc})); + REQUIRE(newNodes != std::set<std::shared_ptr<Node>>({w0, matmul0, b0, add0, w1, matmul1, b1, add1,fl,fc})); //REQUIRE(newNodes.size() == 6); diff --git a/unit_tests/recipes/Test_removeFlatten.cpp b/unit_tests/recipes/Test_removeFlatten.cpp index 24f5aa2e2..c3b4c08d9 100644 --- a/unit_tests/recipes/Test_removeFlatten.cpp +++ b/unit_tests/recipes/Test_removeFlatten.cpp @@ -42,7 +42,7 @@ TEST_CASE("[cpu/recipies] RemoveFlatten", "[RemoveFlatten][recipies]") { CHECK(g->getOrderedInputs().size() == 1); CHECK(g->getOrderedInputs()[0].first == fc0); - + CHECK(fc0->getParent(0) == nullptr); CHECK(fc0->getChildren(0).size() == 1); CHECK(g->rootNode() == fc0); @@ -54,10 +54,10 @@ TEST_CASE("[cpu/recipies] RemoveFlatten", "[RemoveFlatten][recipies]") { CHECK(g->getOrderedInputs().size() == 1); CHECK(g->getOrderedInputs()[0].first == fc0); - + CHECK(g->getOrderedOutputs().size() == 1); CHECK(g->getOrderedOutputs()[0].first == fc0); - + CHECK(fc0->getParent(0) == nullptr); CHECK(fc0->getChildren(0).size() == 0); CHECK(g->rootNode() == fc0); @@ -73,7 +73,7 @@ TEST_CASE("[cpu/recipies] RemoveFlatten", "[RemoveFlatten][recipies]") { CHECK(g->getOrderedOutputs().size() == 1); CHECK(g->getOrderedOutputs()[0].first == fc1); - + CHECK(fc1->getParent(0) == fc0); CHECK(fc0->getChildren(0)[0] == fc1); CHECK(g->rootNode() == fc0); @@ -87,10 +87,10 @@ TEST_CASE("[cpu/recipies] RemoveFlatten", "[RemoveFlatten][recipies]") { removeFlatten(g); CHECK(g->getOrderedInputs().size() == 0); - + CHECK(g->getOrderedOutputs().size() == 1); CHECK(g->getOrderedOutputs()[0].first == fc0); - + CHECK(fc0->getParent(0) == prod); CHECK(fc0->getChildren(0).size() == 0); -- GitLab