Skip to content
Snippets Groups Projects
Commit d59933a7 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Added LabelGraph unit test

parent 499196e8
No related branches found
No related tags found
1 merge request!8GraphView cloning proposal + labelGraph proof of concept
Pipeline #31307 passed
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_RECIPIES_LABELGRAPH_H_
#define AIDGE_RECIPIES_LABELGRAPH_H_
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp"
namespace Aidge {
NodePtr nodeLabel(NodePtr node);
std::shared_ptr<GraphView> labelGraph(std::shared_ptr<GraphView> graph);
} // namespace Aidge
#endif /* AIDGE_RECIPIES_LABELGRAPH_H_ */
...@@ -11,14 +11,11 @@ ...@@ -11,14 +11,11 @@
#include <memory> #include <memory>
#include "aidge/recipies/LabelGraph.hpp"
#include "aidge/operator/Conv.hpp" #include "aidge/operator/Conv.hpp"
#include "aidge/operator/GenericOperator.hpp" #include "aidge/operator/GenericOperator.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp"
using namespace Aidge; Aidge::NodePtr Aidge::nodeLabel(NodePtr node) {
NodePtr nodeLabel(NodePtr node) {
// TODO: this is just a proof of concept right now! // TODO: this is just a proof of concept right now!
if (node->type() == Conv_Op<2>::Type) { if (node->type() == Conv_Op<2>::Type) {
auto op = std::dynamic_pointer_cast<Conv_Op<2>>(node->getOperator()); auto op = std::dynamic_pointer_cast<Conv_Op<2>>(node->getOperator());
...@@ -32,6 +29,6 @@ NodePtr nodeLabel(NodePtr node) { ...@@ -32,6 +29,6 @@ NodePtr nodeLabel(NodePtr node) {
return nullptr; return nullptr;
} }
std::shared_ptr<GraphView> labelGraph(std::shared_ptr<GraphView> graph) { std::shared_ptr<Aidge::GraphView> Aidge::labelGraph(std::shared_ptr<GraphView> graph) {
return graph->clone(&nodeLabel); return graph->clone(&nodeLabel);
} }
...@@ -459,3 +459,67 @@ TEST_CASE("[GraphView] cloneSharedProducers") { ...@@ -459,3 +459,67 @@ TEST_CASE("[GraphView] cloneSharedProducers") {
REQUIRE(g2->getNode("conv3")->getOperator()->getInput(2) == g2->getNode("conv3_b")->getOperator()->getOutput(0)); REQUIRE(g2->getNode("conv3")->getOperator()->getInput(2) == g2->getNode("conv3_b")->getOperator()->getOutput(0));
} }
} }
TEST_CASE("[GraphView] cloneSharedOperators") {
auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider");
auto conv1 = Conv(3, 32, {3, 3}, "conv1");
auto conv2 = Conv(32, 64, {3, 3}, "conv2");
auto conv3 = Conv(64, 10, {1, 1}, "conv3");
auto g1 = std::make_shared<GraphView>("TestGraph");
dataProvider->addChild(conv1, 0);
g1->add(conv1);
g1->addChild(conv2, conv1, 0);
g1->addChild(conv3, conv2, 0);
g1->save("cloneSharedOperators_g1");
SECTION("Check input-output connections") {
REQUIRE(dataProvider->getOperator()->getOutput(0) == conv1->getOperator()->getInput(0));
REQUIRE(conv1->getOperator()->getInput(1) == g1->getNode("conv1_w")->getOperator()->getOutput(0));
REQUIRE(conv1->getOperator()->getInput(2) == g1->getNode("conv1_b")->getOperator()->getOutput(0));
REQUIRE(conv1->getOperator()->getOutput(0) == conv2->getOperator()->getInput(0));
REQUIRE(conv2->getOperator()->getInput(1) == g1->getNode("conv2_w")->getOperator()->getOutput(0));
REQUIRE(conv2->getOperator()->getInput(2) == g1->getNode("conv2_b")->getOperator()->getOutput(0));
REQUIRE(conv2->getOperator()->getOutput(0) == conv3->getOperator()->getInput(0));
REQUIRE(conv3->getOperator()->getInput(1) == g1->getNode("conv3_w")->getOperator()->getOutput(0));
REQUIRE(conv3->getOperator()->getInput(2) == g1->getNode("conv3_b")->getOperator()->getOutput(0));
}
auto g2 = g1->cloneSharedOperators();
g2->save("cloneSharedOperators_g2");
SECTION("Check node cloning") {
REQUIRE(g1->getNode("conv1") != g2->getNode("conv1"));
REQUIRE(g1->getNode("conv1_w") != g2->getNode("conv1_w"));
REQUIRE(g1->getNode("conv1_b") != g2->getNode("conv1_b"));
REQUIRE(g1->getNode("conv2") != g2->getNode("conv2"));
REQUIRE(g1->getNode("conv2_w") != g2->getNode("conv2_w"));
REQUIRE(g1->getNode("conv2_b") != g2->getNode("conv2_b"));
REQUIRE(g1->getNode("conv3") != g2->getNode("conv3"));
REQUIRE(g1->getNode("conv3_w") != g2->getNode("conv3_w"));
REQUIRE(g1->getNode("conv3_b") != g2->getNode("conv3_b"));
}
SECTION("Check operator cloning") {
REQUIRE(g1->getNode("conv1")->getOperator() == g2->getNode("conv1")->getOperator());
REQUIRE(g1->getNode("conv1_w")->getOperator() == g2->getNode("conv1_w")->getOperator());
REQUIRE(g1->getNode("conv1_b")->getOperator() == g2->getNode("conv1_b")->getOperator());
REQUIRE(g1->getNode("conv2")->getOperator() == g2->getNode("conv2")->getOperator());
REQUIRE(g1->getNode("conv2_w")->getOperator() == g2->getNode("conv2_w")->getOperator());
REQUIRE(g1->getNode("conv2_b")->getOperator() == g2->getNode("conv2_b")->getOperator());
REQUIRE(g1->getNode("conv3")->getOperator() == g2->getNode("conv3")->getOperator());
REQUIRE(g1->getNode("conv3_w")->getOperator() == g2->getNode("conv3_w")->getOperator());
REQUIRE(g1->getNode("conv3_b")->getOperator() == g2->getNode("conv3_b")->getOperator());
}
SECTION("Check input-output connections") {
REQUIRE(dataProvider->getOperator()->getOutput(0) == g2->getNode("conv1")->getOperator()->getInput(0));
REQUIRE(g2->getNode("conv1")->getOperator()->getInput(1) == g2->getNode("conv1_w")->getOperator()->getOutput(0));
REQUIRE(g2->getNode("conv1")->getOperator()->getInput(2) == g2->getNode("conv1_b")->getOperator()->getOutput(0));
REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0));
REQUIRE(g2->getNode("conv2")->getOperator()->getInput(1) == g2->getNode("conv2_w")->getOperator()->getOutput(0));
REQUIRE(g2->getNode("conv2")->getOperator()->getInput(2) == g2->getNode("conv2_b")->getOperator()->getOutput(0));
REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0));
REQUIRE(g2->getNode("conv3")->getOperator()->getInput(1) == g2->getNode("conv3_w")->getOperator()->getOutput(0));
REQUIRE(g2->getNode("conv3")->getOperator()->getInput(2) == g2->getNode("conv3_b")->getOperator()->getOutput(0));
}
}
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <catch2/catch_test_macros.hpp>
#include "aidge/recipies/LabelGraph.hpp"
#include "aidge/operator/Conv.hpp"
#include "aidge/operator/GenericOperator.hpp"
#include "aidge/operator/Producer.hpp"
#include <cstddef>
using namespace Aidge;
TEST_CASE("[LabelGraph]") {
auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider");
auto conv1 = Conv(3, 32, {3, 3}, "conv1");
auto conv2 = Conv(32, 64, {3, 3}, "conv2");
auto conv3 = Conv(64, 10, {1, 1}, "conv3");
auto g1 = std::make_shared<GraphView>("TestGraph");
dataProvider->addChild(conv1, 0);
g1->add(conv1);
g1->addChild(conv2, conv1, 0);
g1->addChild(conv3, conv2, 0);
g1->save("LabelGraph_graph");
auto g2 = labelGraph(g1);
g2->save("LabelGraph_label");
SECTION("Check resulting nodes") {
REQUIRE(g2->getNodes().size() == 3);
REQUIRE(g2->getNode("conv1")->getOperator()->type() == "CenterCropPad");
REQUIRE(g2->getNode("conv2")->getOperator()->type() == "CenterCropPad");
REQUIRE(g2->getNode("conv3")->getOperator()->type() == "CenterCropPad");
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment