diff --git a/include/aidge/recipies/LabelGraph.hpp b/include/aidge/recipies/LabelGraph.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8827df0b7a8dc6d679c9bd63240f76350140816a --- /dev/null +++ b/include/aidge/recipies/LabelGraph.hpp @@ -0,0 +1,23 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_RECIPIES_LABELGRAPH_H_ +#define AIDGE_RECIPIES_LABELGRAPH_H_ + +#include "aidge/graph/GraphView.hpp" +#include "aidge/graph/Node.hpp" + +namespace Aidge { +NodePtr nodeLabel(NodePtr node); +std::shared_ptr<GraphView> labelGraph(std::shared_ptr<GraphView> graph); +} // namespace Aidge + +#endif /* AIDGE_RECIPIES_LABELGRAPH_H_ */ diff --git a/src/recipies/LabelGraph.cpp b/src/recipies/LabelGraph.cpp index a42a9fdaedb0137d94f2b9924cd03adaa8132096..f85e61b5485bd7109842ec2d5dc7ebfd78a6c9f7 100644 --- a/src/recipies/LabelGraph.cpp +++ b/src/recipies/LabelGraph.cpp @@ -11,14 +11,11 @@ #include <memory> +#include "aidge/recipies/LabelGraph.hpp" #include "aidge/operator/Conv.hpp" #include "aidge/operator/GenericOperator.hpp" -#include "aidge/graph/GraphView.hpp" -#include "aidge/graph/Node.hpp" -using namespace Aidge; - -NodePtr nodeLabel(NodePtr node) { +Aidge::NodePtr Aidge::nodeLabel(NodePtr node) { // TODO: this is just a proof of concept right now! if (node->type() == Conv_Op<2>::Type) { auto op = std::dynamic_pointer_cast<Conv_Op<2>>(node->getOperator()); @@ -32,6 +29,6 @@ NodePtr nodeLabel(NodePtr node) { return nullptr; } -std::shared_ptr<GraphView> labelGraph(std::shared_ptr<GraphView> graph) { +std::shared_ptr<Aidge::GraphView> Aidge::labelGraph(std::shared_ptr<GraphView> graph) { return graph->clone(&nodeLabel); } diff --git a/unit_tests/graph/Test_GraphView.cpp b/unit_tests/graph/Test_GraphView.cpp index 71c3ea3aedb941b94d016c13ee169a15dad9af55..db57989b16b70d451aef81b3fa41d72880dd99b2 100644 --- a/unit_tests/graph/Test_GraphView.cpp +++ b/unit_tests/graph/Test_GraphView.cpp @@ -459,3 +459,67 @@ TEST_CASE("[GraphView] cloneSharedProducers") { REQUIRE(g2->getNode("conv3")->getOperator()->getInput(2) == g2->getNode("conv3_b")->getOperator()->getOutput(0)); } } + +TEST_CASE("[GraphView] cloneSharedOperators") { + auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider"); + auto conv1 = Conv(3, 32, {3, 3}, "conv1"); + auto conv2 = Conv(32, 64, {3, 3}, "conv2"); + auto conv3 = Conv(64, 10, {1, 1}, "conv3"); + auto g1 = std::make_shared<GraphView>("TestGraph"); + dataProvider->addChild(conv1, 0); + g1->add(conv1); + g1->addChild(conv2, conv1, 0); + g1->addChild(conv3, conv2, 0); + g1->save("cloneSharedOperators_g1"); + + SECTION("Check input-output connections") { + REQUIRE(dataProvider->getOperator()->getOutput(0) == conv1->getOperator()->getInput(0)); + REQUIRE(conv1->getOperator()->getInput(1) == g1->getNode("conv1_w")->getOperator()->getOutput(0)); + REQUIRE(conv1->getOperator()->getInput(2) == g1->getNode("conv1_b")->getOperator()->getOutput(0)); + REQUIRE(conv1->getOperator()->getOutput(0) == conv2->getOperator()->getInput(0)); + REQUIRE(conv2->getOperator()->getInput(1) == g1->getNode("conv2_w")->getOperator()->getOutput(0)); + REQUIRE(conv2->getOperator()->getInput(2) == g1->getNode("conv2_b")->getOperator()->getOutput(0)); + REQUIRE(conv2->getOperator()->getOutput(0) == conv3->getOperator()->getInput(0)); + REQUIRE(conv3->getOperator()->getInput(1) == g1->getNode("conv3_w")->getOperator()->getOutput(0)); + REQUIRE(conv3->getOperator()->getInput(2) == g1->getNode("conv3_b")->getOperator()->getOutput(0)); + } + + auto g2 = g1->cloneSharedOperators(); + g2->save("cloneSharedOperators_g2"); + + SECTION("Check node cloning") { + REQUIRE(g1->getNode("conv1") != g2->getNode("conv1")); + REQUIRE(g1->getNode("conv1_w") != g2->getNode("conv1_w")); + REQUIRE(g1->getNode("conv1_b") != g2->getNode("conv1_b")); + REQUIRE(g1->getNode("conv2") != g2->getNode("conv2")); + REQUIRE(g1->getNode("conv2_w") != g2->getNode("conv2_w")); + REQUIRE(g1->getNode("conv2_b") != g2->getNode("conv2_b")); + REQUIRE(g1->getNode("conv3") != g2->getNode("conv3")); + REQUIRE(g1->getNode("conv3_w") != g2->getNode("conv3_w")); + REQUIRE(g1->getNode("conv3_b") != g2->getNode("conv3_b")); + } + + SECTION("Check operator cloning") { + REQUIRE(g1->getNode("conv1")->getOperator() == g2->getNode("conv1")->getOperator()); + REQUIRE(g1->getNode("conv1_w")->getOperator() == g2->getNode("conv1_w")->getOperator()); + REQUIRE(g1->getNode("conv1_b")->getOperator() == g2->getNode("conv1_b")->getOperator()); + REQUIRE(g1->getNode("conv2")->getOperator() == g2->getNode("conv2")->getOperator()); + REQUIRE(g1->getNode("conv2_w")->getOperator() == g2->getNode("conv2_w")->getOperator()); + REQUIRE(g1->getNode("conv2_b")->getOperator() == g2->getNode("conv2_b")->getOperator()); + REQUIRE(g1->getNode("conv3")->getOperator() == g2->getNode("conv3")->getOperator()); + REQUIRE(g1->getNode("conv3_w")->getOperator() == g2->getNode("conv3_w")->getOperator()); + REQUIRE(g1->getNode("conv3_b")->getOperator() == g2->getNode("conv3_b")->getOperator()); + } + + SECTION("Check input-output connections") { + REQUIRE(dataProvider->getOperator()->getOutput(0) == g2->getNode("conv1")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv1")->getOperator()->getInput(1) == g2->getNode("conv1_w")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv1")->getOperator()->getInput(2) == g2->getNode("conv1_b")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->getInput(1) == g2->getNode("conv2_w")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->getInput(2) == g2->getNode("conv2_b")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv3")->getOperator()->getInput(1) == g2->getNode("conv3_w")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv3")->getOperator()->getInput(2) == g2->getNode("conv3_b")->getOperator()->getOutput(0)); + } +} diff --git a/unit_tests/recipies/Test_LabelGraph.cpp b/unit_tests/recipies/Test_LabelGraph.cpp new file mode 100644 index 0000000000000000000000000000000000000000..40a7d6270f3e17c566d867627ef80d8edf26eb4c --- /dev/null +++ b/unit_tests/recipies/Test_LabelGraph.cpp @@ -0,0 +1,43 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <catch2/catch_test_macros.hpp> + +#include "aidge/recipies/LabelGraph.hpp" +#include "aidge/operator/Conv.hpp" +#include "aidge/operator/GenericOperator.hpp" +#include "aidge/operator/Producer.hpp" +#include <cstddef> + +using namespace Aidge; + +TEST_CASE("[LabelGraph]") { + auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider"); + auto conv1 = Conv(3, 32, {3, 3}, "conv1"); + auto conv2 = Conv(32, 64, {3, 3}, "conv2"); + auto conv3 = Conv(64, 10, {1, 1}, "conv3"); + auto g1 = std::make_shared<GraphView>("TestGraph"); + dataProvider->addChild(conv1, 0); + g1->add(conv1); + g1->addChild(conv2, conv1, 0); + g1->addChild(conv3, conv2, 0); + g1->save("LabelGraph_graph"); + + auto g2 = labelGraph(g1); + g2->save("LabelGraph_label"); + + SECTION("Check resulting nodes") { + REQUIRE(g2->getNodes().size() == 3); + REQUIRE(g2->getNode("conv1")->getOperator()->type() == "CenterCropPad"); + REQUIRE(g2->getNode("conv2")->getOperator()->type() == "CenterCropPad"); + REQUIRE(g2->getNode("conv3")->getOperator()->type() == "CenterCropPad"); + } +}