From 19ef24a0ecfbd6f0fe61da7400635b6cfe8b824e Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Fri, 8 Sep 2023 16:38:53 +0200
Subject: [PATCH] Extended LabelGraph to test deleted nodes

---
 src/recipies/LabelGraph.cpp             |  3 +
 unit_tests/recipies/Test_LabelGraph.cpp | 75 ++++++++++++++++++++++++-
 2 files changed, 75 insertions(+), 3 deletions(-)

diff --git a/src/recipies/LabelGraph.cpp b/src/recipies/LabelGraph.cpp
index af09baeb7..824f48e1f 100644
--- a/src/recipies/LabelGraph.cpp
+++ b/src/recipies/LabelGraph.cpp
@@ -21,7 +21,10 @@ Aidge::NodePtr Aidge::nodeLabel(NodePtr node) {
         auto op = std::dynamic_pointer_cast<Conv_Op<2>>(node->getOperator());
         // TODO: adapt the following code.
         auto newOp = std::make_shared<GenericOperator_Op>("CenterCropPad", 1, 1, 1);
+        // TODO: dummy parameter
         newOp->addParameter("KernelDims", op->get<ConvParam::KernelDims>());
+        // TODO: compute correct output dims
+        newOp->setComputeOutputDims(GenericOperator_Op::Identity);
         return std::make_shared<Node>(newOp, node->name());
     }
 
diff --git a/unit_tests/recipies/Test_LabelGraph.cpp b/unit_tests/recipies/Test_LabelGraph.cpp
index 40a7d6270..3e05c2437 100644
--- a/unit_tests/recipies/Test_LabelGraph.cpp
+++ b/unit_tests/recipies/Test_LabelGraph.cpp
@@ -15,11 +15,12 @@
 #include "aidge/operator/Conv.hpp"
 #include "aidge/operator/GenericOperator.hpp"
 #include "aidge/operator/Producer.hpp"
+#include "aidge/graph/OpArgs.hpp"
 #include <cstddef>
 
 using namespace Aidge;
 
-TEST_CASE("[LabelGraph]") {
+TEST_CASE("[LabelGraph] conv") {
     auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider");
     auto conv1 = Conv(3, 32, {3, 3}, "conv1");
     auto conv2 = Conv(32, 64, {3, 3}, "conv2");
@@ -29,15 +30,83 @@ TEST_CASE("[LabelGraph]") {
     g1->add(conv1);
     g1->addChild(conv2, conv1, 0);
     g1->addChild(conv3, conv2, 0);
-    g1->save("LabelGraph_graph");
+    g1->save("LabelGraph_conv_graph");
 
     auto g2 = labelGraph(g1);
-    g2->save("LabelGraph_label");
+
+    auto dataProvider2 = Producer({16, 3, 224, 224}, "dataProvider");
+    dataProvider2->addChild(g2->getNode("conv1"), 0);
+
+    g2->forwardDims();
+    g2->save("LabelGraph_conv_label");
+
+    SECTION("Check resulting nodes") {
+        REQUIRE(g2->getNodes().size() == 3);
+        REQUIRE(g2->getNode("conv1")->getOperator()->type() == "CenterCropPad");
+        REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0));
+        REQUIRE(g2->getNode("conv2")->getOperator()->type() == "CenterCropPad");
+        REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0));
+        REQUIRE(g2->getNode("conv3")->getOperator()->type() == "CenterCropPad");
+    }
+}
+
+TEST_CASE("[LabelGraph] deleted node") {
+    auto g1 = Sequential({
+        Producer({16, 3, 224, 224}, "dataProvider"),
+        Conv(3, 32, {3, 3}, "conv1"),
+        GenericOperator("Dummy_to_be_removed", 1, 1, 1),
+        Conv(32, 64, {3, 3}, "conv2"),
+        Conv(64, 10, {1, 1}, "conv3")
+    });
+
+    g1->save("LabelGraph_deleted_graph");
+
+    auto g2 = labelGraph(g1);
+
+    auto dataProvider2 = Producer({16, 3, 224, 224}, "dataProvider");
+    dataProvider2->addChild(g2->getNode("conv1"), 0);
+
+    g2->forwardDims();
+    g2->save("LabelGraph_deleted_label");
+
+    SECTION("Check resulting nodes") {
+        REQUIRE(g2->getNodes().size() == 3);
+        REQUIRE(g2->getNode("conv1")->getOperator()->type() == "CenterCropPad");
+        REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0));
+        REQUIRE(g2->getNode("conv2")->getOperator()->type() == "CenterCropPad");
+        REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0));
+        REQUIRE(g2->getNode("conv3")->getOperator()->type() == "CenterCropPad");
+    }
+}
+
+TEST_CASE("[LabelGraph] deleted nodes") {
+    auto g1 = Sequential({
+        Producer({16, 3, 224, 224}, "dataProvider"),
+        Conv(3, 32, {3, 3}, "conv1"),
+        GenericOperator("Dummy_to_be_removed", 1, 1, 1),
+        GenericOperator("Dummy_to_be_removed", 1, 1, 1),
+        GenericOperator("Dummy_to_be_removed", 1, 1, 1),
+        Conv(32, 64, {3, 3}, "conv2"),
+        GenericOperator("Dummy_to_be_removed", 1, 1, 1),
+        Conv(64, 10, {1, 1}, "conv3")
+    });
+
+    g1->save("LabelGraph_deleteds_graph");
+
+    auto g2 = labelGraph(g1);
+
+    auto dataProvider2 = Producer({16, 3, 224, 224}, "dataProvider");
+    dataProvider2->addChild(g2->getNode("conv1"), 0);
+
+    g2->forwardDims();
+    g2->save("LabelGraph_deleteds_label");
 
     SECTION("Check resulting nodes") {
         REQUIRE(g2->getNodes().size() == 3);
         REQUIRE(g2->getNode("conv1")->getOperator()->type() == "CenterCropPad");
+        REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0));
         REQUIRE(g2->getNode("conv2")->getOperator()->type() == "CenterCropPad");
+        REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0));
         REQUIRE(g2->getNode("conv3")->getOperator()->type() == "CenterCropPad");
     }
 }
-- 
GitLab