Skip to content
Snippets Groups Projects
Commit 0c0bdc63 authored by Grégoire Kubler's avatar Grégoire Kubler
Browse files

feat: new test cases + tests to ensure root node value + refactor REQUIRE -> CHECK

parent 0a77f4e1
No related branches found
No related tags found
2 merge requests!105version 0.2.0,!91Feat/operator global average pooling
...@@ -10,40 +10,64 @@ ...@@ -10,40 +10,64 @@
********************************************************************************/ ********************************************************************************/
#include <catch2/catch_test_macros.hpp> #include <catch2/catch_test_macros.hpp>
#include <memory>
#include <set> #include <set>
#include "aidge/data/Tensor.hpp" #include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp" #include "aidge/graph/GraphView.hpp"
#include "aidge/operator/GenericOperator.hpp" #include "aidge/graph/OpArgs.hpp"
#include "aidge/operator/FC.hpp" #include "aidge/operator/FC.hpp"
#include "aidge/operator/GenericOperator.hpp"
#include "aidge/recipes/Recipes.hpp" #include "aidge/recipes/Recipes.hpp"
namespace Aidge { namespace Aidge {
TEST_CASE("[cpu/recipies] RemoveFlatten", "[RemoveFlatten][recipies]") {
std::shared_ptr<Node> flatten =
GenericOperator("Flatten", 1, 0, 1, "myFlatten");
std::shared_ptr<Node> fc0 = FC(10, 10, "FC_1");
std::shared_ptr<Node> fc1 = FC(10, 10, "FC_2");
TEST_CASE("[cpu/recipes] RemoveFlatten", "[RemoveFlatten][recipes]") { SECTION("flatten last layer") {
// generate the original GraphView std::shared_ptr<Aidge::GraphView> g = Sequential({fc0, flatten});
auto flatten = GenericOperator("Flatten", 1, 0, 1, "myFlatten");
auto fc = FC(10, 50, "myFC");
flatten -> addChild(fc);
auto g = std::make_shared<GraphView>(); removeFlatten(flatten);
g->add({fc, flatten});
// Check original graph CHECK(g->getOrderedInputs().size() == 1);
// g -> save("before_remove_flatten"); CHECK(g->getOrderedOutputs().size() == 1);
CHECK(g->getOrderedInputs()[0].first == fc0);
CHECK(g->getOrderedOutputs()[0].first == fc0);
CHECK(fc0->getParent(0) == nullptr);
CHECK(fc0->getChildren(0).size() == 0);
CHECK(g->getRootNode() == fc0);
}
SECTION("flatten first layer") {
auto g = Sequential({flatten, fc0});
// use recipie
removeFlatten(g); removeFlatten(g);
// Check transformed graph CHECK(g->getOrderedInputs().size() == 1);
// g -> save("after_remove_flatten"); CHECK(g->getOrderedOutputs().size() == 1);
CHECK(g->getOrderedInputs()[0].first == fc0);
CHECK(g->getOrderedOutputs()[0].first == fc0);
CHECK(fc0->getParent(0) == nullptr);
CHECK(fc0->getChildren(0).size() == 0);
CHECK(g->getRootNode() == fc0);
}
SECTION("flatten middle layer") {
auto g = Sequential({fc0, flatten, fc1});
removeFlatten(g);
REQUIRE(g->getOrderedInputs().size() == 1); CHECK(g->getOrderedInputs().size() == 1);
REQUIRE(g->getOrderedOutputs().size() == 1); CHECK(g->getOrderedOutputs().size() == 1);
REQUIRE(g->getOrderedInputs()[0].first == fc); CHECK(g->getOrderedInputs()[0].first == fc0);
REQUIRE(g->getOrderedOutputs()[0].first == fc); CHECK(g->getOrderedOutputs()[0].first == fc1);
CHECK(fc1->getParent(0) == fc0);
CHECK(fc0->getChildren(0)[0] == fc1);
CHECK(g->getRootNode() == fc0);
}
} }
} // namespace Aidge } // namespace Aidge
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment