diff --git a/unit_tests/recipes/Test_removeFlatten.cpp b/unit_tests/recipes/Test_removeFlatten.cpp index 13e83f985a9fc153ecbd9e9836a09aef7f13189a..feb982fdf7f77ca368ee52b9ed0fe72dab85ea28 100644 --- a/unit_tests/recipes/Test_removeFlatten.cpp +++ b/unit_tests/recipes/Test_removeFlatten.cpp @@ -31,7 +31,7 @@ TEST_CASE("[cpu/recipies] RemoveFlatten", "[RemoveFlatten][recipies]") { std::shared_ptr<Node> fc1 = FC(10, 10, "FC_2"); std::shared_ptr<Node> prod = Producer(std::array<DimSize_t, 10>(), "myProd"); - SECTION("flatten last layer") { + SECTION("flatten last layer : nothing removed because pattern searched is Flatten=>FC") { std::shared_ptr<Aidge::GraphView> g = Sequential({fc0, flatten}); removeFlatten(g); @@ -39,12 +39,12 @@ TEST_CASE("[cpu/recipies] RemoveFlatten", "[RemoveFlatten][recipies]") { CHECK(g->getOrderedInputs().size() == 1); CHECK(g->getOrderedOutputs().size() == 1); CHECK(g->getOrderedInputs()[0].first == fc0); - CHECK(g->getOrderedOutputs()[0].first == fc0); + CHECK(g->getOrderedOutputs()[0].first == flatten); CHECK(fc0->getParent(0) == nullptr); - CHECK(fc0->getChildren(0).size() == 0); + CHECK(fc0->getChildren(0).size() == 1); CHECK(g->getRootNode() == fc0); } - SECTION("flatten first layer") { + SECTION("flatten first layer : flatten removed") { auto g = Sequential({flatten, fc0}); removeFlatten(g);