diff --git a/unit_tests/recipes/Test_removeFlatten.cpp b/unit_tests/recipes/Test_removeFlatten.cpp index 6c805e4cfc64ad3dcfbf020e74926dd0aeca5f9f..15f2a27a526af37fe0bae072ce7721add36d5266 100644 --- a/unit_tests/recipes/Test_removeFlatten.cpp +++ b/unit_tests/recipes/Test_removeFlatten.cpp @@ -10,40 +10,64 @@ ********************************************************************************/ #include <catch2/catch_test_macros.hpp> +#include <memory> #include <set> #include "aidge/data/Tensor.hpp" #include "aidge/graph/GraphView.hpp" -#include "aidge/operator/GenericOperator.hpp" +#include "aidge/graph/OpArgs.hpp" #include "aidge/operator/FC.hpp" +#include "aidge/operator/GenericOperator.hpp" #include "aidge/recipes/Recipes.hpp" namespace Aidge { +TEST_CASE("[cpu/recipies] RemoveFlatten", "[RemoveFlatten][recipies]") { + std::shared_ptr<Node> flatten = + GenericOperator("Flatten", 1, 0, 1, "myFlatten"); + std::shared_ptr<Node> fc0 = FC(10, 10, "FC_1"); + std::shared_ptr<Node> fc1 = FC(10, 10, "FC_2"); -TEST_CASE("[cpu/recipes] RemoveFlatten", "[RemoveFlatten][recipes]") { - // generate the original GraphView - auto flatten = GenericOperator("Flatten", 1, 0, 1, "myFlatten"); - auto fc = FC(10, 50, "myFC"); - - flatten -> addChild(fc); + SECTION("flatten last layer") { + std::shared_ptr<Aidge::GraphView> g = Sequential({fc0, flatten}); - auto g = std::make_shared<GraphView>(); - g->add({fc, flatten}); + removeFlatten(flatten); - // Check original graph - // g -> save("before_remove_flatten"); + CHECK(g->getOrderedInputs().size() == 1); + CHECK(g->getOrderedOutputs().size() == 1); + CHECK(g->getOrderedInputs()[0].first == fc0); + CHECK(g->getOrderedOutputs()[0].first == fc0); + CHECK(fc0->getParent(0) == nullptr); + CHECK(fc0->getChildren(0).size() == 0); + CHECK(g->getRootNode() == fc0); + } + SECTION("flatten first layer") { + auto g = Sequential({flatten, fc0}); - // use recipie removeFlatten(g); - // Check transformed graph - // g -> save("after_remove_flatten"); + CHECK(g->getOrderedInputs().size() == 1); + CHECK(g->getOrderedOutputs().size() == 1); + CHECK(g->getOrderedInputs()[0].first == fc0); + CHECK(g->getOrderedOutputs()[0].first == fc0); + CHECK(fc0->getParent(0) == nullptr); + CHECK(fc0->getChildren(0).size() == 0); + CHECK(g->getRootNode() == fc0); + } + SECTION("flatten middle layer") { + + auto g = Sequential({fc0, flatten, fc1}); + + removeFlatten(g); - REQUIRE(g->getOrderedInputs().size() == 1); - REQUIRE(g->getOrderedOutputs().size() == 1); - REQUIRE(g->getOrderedInputs()[0].first == fc); - REQUIRE(g->getOrderedOutputs()[0].first == fc); + CHECK(g->getOrderedInputs().size() == 1); + CHECK(g->getOrderedOutputs().size() == 1); + CHECK(g->getOrderedInputs()[0].first == fc0); + CHECK(g->getOrderedOutputs()[0].first == fc1); + CHECK(fc1->getParent(0) == fc0); + CHECK(fc0->getChildren(0)[0] == fc1); + CHECK(g->getRootNode() == fc0); + } } -} // namespace Aidge \ No newline at end of file +} // namespace Aidge