Skip to content
Snippets Groups Projects
Commit 19ef24a0 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Extended LabelGraph to test deleted nodes

parent a2aac52e
No related branches found
No related tags found
1 merge request!8GraphView cloning proposal + labelGraph proof of concept
Pipeline #31358 passed
......@@ -21,7 +21,10 @@ Aidge::NodePtr Aidge::nodeLabel(NodePtr node) {
auto op = std::dynamic_pointer_cast<Conv_Op<2>>(node->getOperator());
// TODO: adapt the following code.
auto newOp = std::make_shared<GenericOperator_Op>("CenterCropPad", 1, 1, 1);
// TODO: dummy parameter
newOp->addParameter("KernelDims", op->get<ConvParam::KernelDims>());
// TODO: compute correct output dims
newOp->setComputeOutputDims(GenericOperator_Op::Identity);
return std::make_shared<Node>(newOp, node->name());
}
......
......@@ -15,11 +15,12 @@
#include "aidge/operator/Conv.hpp"
#include "aidge/operator/GenericOperator.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/graph/OpArgs.hpp"
#include <cstddef>
using namespace Aidge;
TEST_CASE("[LabelGraph]") {
TEST_CASE("[LabelGraph] conv") {
auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider");
auto conv1 = Conv(3, 32, {3, 3}, "conv1");
auto conv2 = Conv(32, 64, {3, 3}, "conv2");
......@@ -29,15 +30,83 @@ TEST_CASE("[LabelGraph]") {
g1->add(conv1);
g1->addChild(conv2, conv1, 0);
g1->addChild(conv3, conv2, 0);
g1->save("LabelGraph_graph");
g1->save("LabelGraph_conv_graph");
auto g2 = labelGraph(g1);
g2->save("LabelGraph_label");
auto dataProvider2 = Producer({16, 3, 224, 224}, "dataProvider");
dataProvider2->addChild(g2->getNode("conv1"), 0);
g2->forwardDims();
g2->save("LabelGraph_conv_label");
SECTION("Check resulting nodes") {
REQUIRE(g2->getNodes().size() == 3);
REQUIRE(g2->getNode("conv1")->getOperator()->type() == "CenterCropPad");
REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0));
REQUIRE(g2->getNode("conv2")->getOperator()->type() == "CenterCropPad");
REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0));
REQUIRE(g2->getNode("conv3")->getOperator()->type() == "CenterCropPad");
}
}
TEST_CASE("[LabelGraph] deleted node") {
auto g1 = Sequential({
Producer({16, 3, 224, 224}, "dataProvider"),
Conv(3, 32, {3, 3}, "conv1"),
GenericOperator("Dummy_to_be_removed", 1, 1, 1),
Conv(32, 64, {3, 3}, "conv2"),
Conv(64, 10, {1, 1}, "conv3")
});
g1->save("LabelGraph_deleted_graph");
auto g2 = labelGraph(g1);
auto dataProvider2 = Producer({16, 3, 224, 224}, "dataProvider");
dataProvider2->addChild(g2->getNode("conv1"), 0);
g2->forwardDims();
g2->save("LabelGraph_deleted_label");
SECTION("Check resulting nodes") {
REQUIRE(g2->getNodes().size() == 3);
REQUIRE(g2->getNode("conv1")->getOperator()->type() == "CenterCropPad");
REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0));
REQUIRE(g2->getNode("conv2")->getOperator()->type() == "CenterCropPad");
REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0));
REQUIRE(g2->getNode("conv3")->getOperator()->type() == "CenterCropPad");
}
}
TEST_CASE("[LabelGraph] deleted nodes") {
auto g1 = Sequential({
Producer({16, 3, 224, 224}, "dataProvider"),
Conv(3, 32, {3, 3}, "conv1"),
GenericOperator("Dummy_to_be_removed", 1, 1, 1),
GenericOperator("Dummy_to_be_removed", 1, 1, 1),
GenericOperator("Dummy_to_be_removed", 1, 1, 1),
Conv(32, 64, {3, 3}, "conv2"),
GenericOperator("Dummy_to_be_removed", 1, 1, 1),
Conv(64, 10, {1, 1}, "conv3")
});
g1->save("LabelGraph_deleteds_graph");
auto g2 = labelGraph(g1);
auto dataProvider2 = Producer({16, 3, 224, 224}, "dataProvider");
dataProvider2->addChild(g2->getNode("conv1"), 0);
g2->forwardDims();
g2->save("LabelGraph_deleteds_label");
SECTION("Check resulting nodes") {
REQUIRE(g2->getNodes().size() == 3);
REQUIRE(g2->getNode("conv1")->getOperator()->type() == "CenterCropPad");
REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0));
REQUIRE(g2->getNode("conv2")->getOperator()->type() == "CenterCropPad");
REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0));
REQUIRE(g2->getNode("conv3")->getOperator()->type() == "CenterCropPad");
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment