diff --git a/include/aidge/recipies/LabelGraph.hpp b/include/aidge/recipies/LabelGraph.hpp index 8827df0b7a8dc6d679c9bd63240f76350140816a..9dd77e5e9f397260cf936cf77b15616c17ea33b8 100644 --- a/include/aidge/recipies/LabelGraph.hpp +++ b/include/aidge/recipies/LabelGraph.hpp @@ -17,6 +17,18 @@ namespace Aidge { NodePtr nodeLabel(NodePtr node); + +/** + * @brief Generate the graph for the pixel-wise labels corresponding to a data graph, taking into account the scaling changes (padding, stride, pooling...). + * @details Right now, the behavior is to replace the following operators: + * - Conv: MaxPooling + * - ConvDepthWie: MaxPooling + * - AvgPooling: MaxPooling + * - MaxPooling: MaxPooling + * - all others: identity (removed) + * @param graph Data graph + * @param return Computing graph for the labels derived from the data graph + */ std::shared_ptr<GraphView> labelGraph(std::shared_ptr<GraphView> graph); } // namespace Aidge diff --git a/src/recipies/LabelGraph.cpp b/src/recipies/LabelGraph.cpp index 824f48e1f9370fabedda8810f16544f124cd1822..7ac2cbf6ca65c7ecbced9596efb71c2052405984 100644 --- a/src/recipies/LabelGraph.cpp +++ b/src/recipies/LabelGraph.cpp @@ -13,21 +13,40 @@ #include "aidge/recipies/LabelGraph.hpp" #include "aidge/operator/Conv.hpp" -#include "aidge/operator/GenericOperator.hpp" +#include "aidge/operator/ConvDepthWise.hpp" +#include "aidge/operator/AvgPooling.hpp" +#include "aidge/operator/MaxPooling.hpp" Aidge::NodePtr Aidge::nodeLabel(NodePtr node) { - // TODO: this is just a proof of concept right now! + // Conv => MaxPooling if (node->type() == Conv_Op<2>::Type) { auto op = std::dynamic_pointer_cast<Conv_Op<2>>(node->getOperator()); - // TODO: adapt the following code. - auto newOp = std::make_shared<GenericOperator_Op>("CenterCropPad", 1, 1, 1); - // TODO: dummy parameter - newOp->addParameter("KernelDims", op->get<ConvParam::KernelDims>()); - // TODO: compute correct output dims - newOp->setComputeOutputDims(GenericOperator_Op::Identity); + + auto newOp = std::make_shared<MaxPooling_Op<2>>(op->get<ConvParam::KernelDims>(), op->get<ConvParam::StrideDims>()); + return std::make_shared<Node>(newOp, node->name()); + } + + // ConvDepthWise => MaxPooling + if (node->type() == ConvDepthWise_Op<2>::Type) { + auto op = std::dynamic_pointer_cast<ConvDepthWise_Op<2>>(node->getOperator()); + + auto newOp = std::make_shared<MaxPooling_Op<2>>(op->get<ConvDepthWiseParam::KernelDims>(), op->get<ConvDepthWiseParam::StrideDims>()); return std::make_shared<Node>(newOp, node->name()); } + // AvgPooling => MaxPooling + if (node->type() == AvgPooling_Op<2>::Type) { + auto op = std::dynamic_pointer_cast<AvgPooling_Op<2>>(node->getOperator()); + + auto newOp = std::make_shared<MaxPooling_Op<2>>(op->get<AvgPoolingParam::KernelDims>(), op->get<AvgPoolingParam::StrideDims>()); + return std::make_shared<Node>(newOp, node->name()); + } + + // MaxPooling => MaxPooling + if (node->type() == MaxPooling_Op<2>::Type) { + return node->clone(); + } + // By default, remove the node from the graph return nullptr; } diff --git a/unit_tests/recipies/Test_LabelGraph.cpp b/unit_tests/recipies/Test_LabelGraph.cpp index 3e05c243767710ac49159a5cd1d74637e61074d5..873ad68f3198c6b6adf44d8c7ae31e667c63a18d 100644 --- a/unit_tests/recipies/Test_LabelGraph.cpp +++ b/unit_tests/recipies/Test_LabelGraph.cpp @@ -13,6 +13,8 @@ #include "aidge/recipies/LabelGraph.hpp" #include "aidge/operator/Conv.hpp" +#include "aidge/operator/AvgPooling.hpp" +#include "aidge/operator/MaxPooling.hpp" #include "aidge/operator/GenericOperator.hpp" #include "aidge/operator/Producer.hpp" #include "aidge/graph/OpArgs.hpp" @@ -42,11 +44,11 @@ TEST_CASE("[LabelGraph] conv") { SECTION("Check resulting nodes") { REQUIRE(g2->getNodes().size() == 3); - REQUIRE(g2->getNode("conv1")->getOperator()->type() == "CenterCropPad"); + REQUIRE(g2->getNode("conv1")->getOperator()->type() == "MaxPooling"); REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0)); - REQUIRE(g2->getNode("conv2")->getOperator()->type() == "CenterCropPad"); + REQUIRE(g2->getNode("conv2")->getOperator()->type() == "MaxPooling"); REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0)); - REQUIRE(g2->getNode("conv3")->getOperator()->type() == "CenterCropPad"); + REQUIRE(g2->getNode("conv3")->getOperator()->type() == "MaxPooling"); } } @@ -56,14 +58,14 @@ TEST_CASE("[LabelGraph] deleted node") { Conv(3, 32, {3, 3}, "conv1"), GenericOperator("Dummy_to_be_removed", 1, 1, 1), Conv(32, 64, {3, 3}, "conv2"), - Conv(64, 10, {1, 1}, "conv3") + Conv(64, 10, {1, 1}, "conv3", {2, 2}) }); g1->save("LabelGraph_deleted_graph"); auto g2 = labelGraph(g1); - auto dataProvider2 = Producer({16, 3, 224, 224}, "dataProvider"); + auto dataProvider2 = Producer({16, 1, 224, 224}, "dataProvider"); dataProvider2->addChild(g2->getNode("conv1"), 0); g2->forwardDims(); @@ -71,11 +73,17 @@ TEST_CASE("[LabelGraph] deleted node") { SECTION("Check resulting nodes") { REQUIRE(g2->getNodes().size() == 3); - REQUIRE(g2->getNode("conv1")->getOperator()->type() == "CenterCropPad"); + REQUIRE(g2->getNode("conv1")->getOperator()->type() == "MaxPooling"); REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0)); - REQUIRE(g2->getNode("conv2")->getOperator()->type() == "CenterCropPad"); + REQUIRE(g2->getNode("conv2")->getOperator()->type() == "MaxPooling"); REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0)); - REQUIRE(g2->getNode("conv3")->getOperator()->type() == "CenterCropPad"); + REQUIRE(g2->getNode("conv3")->getOperator()->type() == "MaxPooling"); + } + + SECTION("Check dimensions") { + REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0)->dims() == std::vector<DimSize_t>({16, 1, 222, 222})); + REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0)->dims() == std::vector<DimSize_t>({16, 1, 220, 220})); + REQUIRE(g2->getNode("conv3")->getOperator()->getOutput(0)->dims() == std::vector<DimSize_t>({16, 1, 110, 110})); } } @@ -103,10 +111,44 @@ TEST_CASE("[LabelGraph] deleted nodes") { SECTION("Check resulting nodes") { REQUIRE(g2->getNodes().size() == 3); - REQUIRE(g2->getNode("conv1")->getOperator()->type() == "CenterCropPad"); + REQUIRE(g2->getNode("conv1")->getOperator()->type() == "MaxPooling"); REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0)); - REQUIRE(g2->getNode("conv2")->getOperator()->type() == "CenterCropPad"); + REQUIRE(g2->getNode("conv2")->getOperator()->type() == "MaxPooling"); REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0)); - REQUIRE(g2->getNode("conv3")->getOperator()->type() == "CenterCropPad"); + REQUIRE(g2->getNode("conv3")->getOperator()->type() == "MaxPooling"); + } +} + +TEST_CASE("[LabelGraph] pooling") { + auto g1 = Sequential({ + Producer({16, 3, 224, 224}, "dataProvider"), + AvgPooling({2, 2}, "pool1"), + MaxPooling({2, 2}, "pool2"), + MaxPooling({2, 2}, "pool3", {2, 2}) + }); + + g1->save("LabelGraph_deleted_graph"); + + auto g2 = labelGraph(g1); + + auto dataProvider2 = Producer({16, 1, 224, 224}, "dataProvider"); + dataProvider2->addChild(g2->getNode("pool1"), 0); + + g2->forwardDims(); + g2->save("LabelGraph_pooling"); + + SECTION("Check resulting nodes") { + REQUIRE(g2->getNodes().size() == 3); + REQUIRE(g2->getNode("pool1")->getOperator()->type() == "MaxPooling"); + REQUIRE(g2->getNode("pool1")->getOperator()->getOutput(0) == g2->getNode("pool2")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("pool2")->getOperator()->type() == "MaxPooling"); + REQUIRE(g2->getNode("pool2")->getOperator()->getOutput(0) == g2->getNode("pool3")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("pool3")->getOperator()->type() == "MaxPooling"); + } + + SECTION("Check dimensions") { + REQUIRE(g2->getNode("pool1")->getOperator()->getOutput(0)->dims() == std::vector<DimSize_t>({16, 1, 223, 223})); + REQUIRE(g2->getNode("pool2")->getOperator()->getOutput(0)->dims() == std::vector<DimSize_t>({16, 1, 222, 222})); + REQUIRE(g2->getNode("pool3")->getOperator()->getOutput(0)->dims() == std::vector<DimSize_t>({16, 1, 111, 111})); } }