diff --git a/src/recipies/LabelGraph.cpp b/src/recipies/LabelGraph.cpp index af09baeb7a6461cb252182c06ea96cb888fac1fc..824f48e1f9370fabedda8810f16544f124cd1822 100644 --- a/src/recipies/LabelGraph.cpp +++ b/src/recipies/LabelGraph.cpp @@ -21,7 +21,10 @@ Aidge::NodePtr Aidge::nodeLabel(NodePtr node) { auto op = std::dynamic_pointer_cast<Conv_Op<2>>(node->getOperator()); // TODO: adapt the following code. auto newOp = std::make_shared<GenericOperator_Op>("CenterCropPad", 1, 1, 1); + // TODO: dummy parameter newOp->addParameter("KernelDims", op->get<ConvParam::KernelDims>()); + // TODO: compute correct output dims + newOp->setComputeOutputDims(GenericOperator_Op::Identity); return std::make_shared<Node>(newOp, node->name()); } diff --git a/unit_tests/recipies/Test_LabelGraph.cpp b/unit_tests/recipies/Test_LabelGraph.cpp index 40a7d6270f3e17c566d867627ef80d8edf26eb4c..3e05c243767710ac49159a5cd1d74637e61074d5 100644 --- a/unit_tests/recipies/Test_LabelGraph.cpp +++ b/unit_tests/recipies/Test_LabelGraph.cpp @@ -15,11 +15,12 @@ #include "aidge/operator/Conv.hpp" #include "aidge/operator/GenericOperator.hpp" #include "aidge/operator/Producer.hpp" +#include "aidge/graph/OpArgs.hpp" #include <cstddef> using namespace Aidge; -TEST_CASE("[LabelGraph]") { +TEST_CASE("[LabelGraph] conv") { auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider"); auto conv1 = Conv(3, 32, {3, 3}, "conv1"); auto conv2 = Conv(32, 64, {3, 3}, "conv2"); @@ -29,15 +30,83 @@ TEST_CASE("[LabelGraph]") { g1->add(conv1); g1->addChild(conv2, conv1, 0); g1->addChild(conv3, conv2, 0); - g1->save("LabelGraph_graph"); + g1->save("LabelGraph_conv_graph"); auto g2 = labelGraph(g1); - g2->save("LabelGraph_label"); + + auto dataProvider2 = Producer({16, 3, 224, 224}, "dataProvider"); + dataProvider2->addChild(g2->getNode("conv1"), 0); + + g2->forwardDims(); + g2->save("LabelGraph_conv_label"); + + SECTION("Check resulting nodes") { + REQUIRE(g2->getNodes().size() == 3); + REQUIRE(g2->getNode("conv1")->getOperator()->type() == "CenterCropPad"); + REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->type() == "CenterCropPad"); + REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv3")->getOperator()->type() == "CenterCropPad"); + } +} + +TEST_CASE("[LabelGraph] deleted node") { + auto g1 = Sequential({ + Producer({16, 3, 224, 224}, "dataProvider"), + Conv(3, 32, {3, 3}, "conv1"), + GenericOperator("Dummy_to_be_removed", 1, 1, 1), + Conv(32, 64, {3, 3}, "conv2"), + Conv(64, 10, {1, 1}, "conv3") + }); + + g1->save("LabelGraph_deleted_graph"); + + auto g2 = labelGraph(g1); + + auto dataProvider2 = Producer({16, 3, 224, 224}, "dataProvider"); + dataProvider2->addChild(g2->getNode("conv1"), 0); + + g2->forwardDims(); + g2->save("LabelGraph_deleted_label"); + + SECTION("Check resulting nodes") { + REQUIRE(g2->getNodes().size() == 3); + REQUIRE(g2->getNode("conv1")->getOperator()->type() == "CenterCropPad"); + REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->type() == "CenterCropPad"); + REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv3")->getOperator()->type() == "CenterCropPad"); + } +} + +TEST_CASE("[LabelGraph] deleted nodes") { + auto g1 = Sequential({ + Producer({16, 3, 224, 224}, "dataProvider"), + Conv(3, 32, {3, 3}, "conv1"), + GenericOperator("Dummy_to_be_removed", 1, 1, 1), + GenericOperator("Dummy_to_be_removed", 1, 1, 1), + GenericOperator("Dummy_to_be_removed", 1, 1, 1), + Conv(32, 64, {3, 3}, "conv2"), + GenericOperator("Dummy_to_be_removed", 1, 1, 1), + Conv(64, 10, {1, 1}, "conv3") + }); + + g1->save("LabelGraph_deleteds_graph"); + + auto g2 = labelGraph(g1); + + auto dataProvider2 = Producer({16, 3, 224, 224}, "dataProvider"); + dataProvider2->addChild(g2->getNode("conv1"), 0); + + g2->forwardDims(); + g2->save("LabelGraph_deleteds_label"); SECTION("Check resulting nodes") { REQUIRE(g2->getNodes().size() == 3); REQUIRE(g2->getNode("conv1")->getOperator()->type() == "CenterCropPad"); + REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0)); REQUIRE(g2->getNode("conv2")->getOperator()->type() == "CenterCropPad"); + REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0)); REQUIRE(g2->getNode("conv3")->getOperator()->type() == "CenterCropPad"); } }