Skip to content
Snippets Groups Projects
Commit 19ef24a0 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Extended LabelGraph to test deleted nodes

parent a2aac52e
No related branches found
No related tags found
No related merge requests found
...@@ -21,7 +21,10 @@ Aidge::NodePtr Aidge::nodeLabel(NodePtr node) { ...@@ -21,7 +21,10 @@ Aidge::NodePtr Aidge::nodeLabel(NodePtr node) {
auto op = std::dynamic_pointer_cast<Conv_Op<2>>(node->getOperator()); auto op = std::dynamic_pointer_cast<Conv_Op<2>>(node->getOperator());
// TODO: adapt the following code. // TODO: adapt the following code.
auto newOp = std::make_shared<GenericOperator_Op>("CenterCropPad", 1, 1, 1); auto newOp = std::make_shared<GenericOperator_Op>("CenterCropPad", 1, 1, 1);
// TODO: dummy parameter
newOp->addParameter("KernelDims", op->get<ConvParam::KernelDims>()); newOp->addParameter("KernelDims", op->get<ConvParam::KernelDims>());
// TODO: compute correct output dims
newOp->setComputeOutputDims(GenericOperator_Op::Identity);
return std::make_shared<Node>(newOp, node->name()); return std::make_shared<Node>(newOp, node->name());
} }
......
...@@ -15,11 +15,12 @@ ...@@ -15,11 +15,12 @@
#include "aidge/operator/Conv.hpp" #include "aidge/operator/Conv.hpp"
#include "aidge/operator/GenericOperator.hpp" #include "aidge/operator/GenericOperator.hpp"
#include "aidge/operator/Producer.hpp" #include "aidge/operator/Producer.hpp"
#include "aidge/graph/OpArgs.hpp"
#include <cstddef> #include <cstddef>
using namespace Aidge; using namespace Aidge;
TEST_CASE("[LabelGraph]") { TEST_CASE("[LabelGraph] conv") {
auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider"); auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider");
auto conv1 = Conv(3, 32, {3, 3}, "conv1"); auto conv1 = Conv(3, 32, {3, 3}, "conv1");
auto conv2 = Conv(32, 64, {3, 3}, "conv2"); auto conv2 = Conv(32, 64, {3, 3}, "conv2");
...@@ -29,15 +30,83 @@ TEST_CASE("[LabelGraph]") { ...@@ -29,15 +30,83 @@ TEST_CASE("[LabelGraph]") {
g1->add(conv1); g1->add(conv1);
g1->addChild(conv2, conv1, 0); g1->addChild(conv2, conv1, 0);
g1->addChild(conv3, conv2, 0); g1->addChild(conv3, conv2, 0);
g1->save("LabelGraph_graph"); g1->save("LabelGraph_conv_graph");
auto g2 = labelGraph(g1); auto g2 = labelGraph(g1);
g2->save("LabelGraph_label");
auto dataProvider2 = Producer({16, 3, 224, 224}, "dataProvider");
dataProvider2->addChild(g2->getNode("conv1"), 0);
g2->forwardDims();
g2->save("LabelGraph_conv_label");
SECTION("Check resulting nodes") {
REQUIRE(g2->getNodes().size() == 3);
REQUIRE(g2->getNode("conv1")->getOperator()->type() == "CenterCropPad");
REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0));
REQUIRE(g2->getNode("conv2")->getOperator()->type() == "CenterCropPad");
REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0));
REQUIRE(g2->getNode("conv3")->getOperator()->type() == "CenterCropPad");
}
}
TEST_CASE("[LabelGraph] deleted node") {
auto g1 = Sequential({
Producer({16, 3, 224, 224}, "dataProvider"),
Conv(3, 32, {3, 3}, "conv1"),
GenericOperator("Dummy_to_be_removed", 1, 1, 1),
Conv(32, 64, {3, 3}, "conv2"),
Conv(64, 10, {1, 1}, "conv3")
});
g1->save("LabelGraph_deleted_graph");
auto g2 = labelGraph(g1);
auto dataProvider2 = Producer({16, 3, 224, 224}, "dataProvider");
dataProvider2->addChild(g2->getNode("conv1"), 0);
g2->forwardDims();
g2->save("LabelGraph_deleted_label");
SECTION("Check resulting nodes") {
REQUIRE(g2->getNodes().size() == 3);
REQUIRE(g2->getNode("conv1")->getOperator()->type() == "CenterCropPad");
REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0));
REQUIRE(g2->getNode("conv2")->getOperator()->type() == "CenterCropPad");
REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0));
REQUIRE(g2->getNode("conv3")->getOperator()->type() == "CenterCropPad");
}
}
TEST_CASE("[LabelGraph] deleted nodes") {
auto g1 = Sequential({
Producer({16, 3, 224, 224}, "dataProvider"),
Conv(3, 32, {3, 3}, "conv1"),
GenericOperator("Dummy_to_be_removed", 1, 1, 1),
GenericOperator("Dummy_to_be_removed", 1, 1, 1),
GenericOperator("Dummy_to_be_removed", 1, 1, 1),
Conv(32, 64, {3, 3}, "conv2"),
GenericOperator("Dummy_to_be_removed", 1, 1, 1),
Conv(64, 10, {1, 1}, "conv3")
});
g1->save("LabelGraph_deleteds_graph");
auto g2 = labelGraph(g1);
auto dataProvider2 = Producer({16, 3, 224, 224}, "dataProvider");
dataProvider2->addChild(g2->getNode("conv1"), 0);
g2->forwardDims();
g2->save("LabelGraph_deleteds_label");
SECTION("Check resulting nodes") { SECTION("Check resulting nodes") {
REQUIRE(g2->getNodes().size() == 3); REQUIRE(g2->getNodes().size() == 3);
REQUIRE(g2->getNode("conv1")->getOperator()->type() == "CenterCropPad"); REQUIRE(g2->getNode("conv1")->getOperator()->type() == "CenterCropPad");
REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0));
REQUIRE(g2->getNode("conv2")->getOperator()->type() == "CenterCropPad"); REQUIRE(g2->getNode("conv2")->getOperator()->type() == "CenterCropPad");
REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0));
REQUIRE(g2->getNode("conv3")->getOperator()->type() == "CenterCropPad"); REQUIRE(g2->getNode("conv3")->getOperator()->type() == "CenterCropPad");
} }
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment