diff --git a/include/aidge/graph/Node.hpp b/include/aidge/graph/Node.hpp index cc2da0fd836c2013b6b7ec5a611737a8fbd451a0..9b7c1b0eae19c1984d9297ee6e92cba8b622b6b9 100644 --- a/include/aidge/graph/Node.hpp +++ b/include/aidge/graph/Node.hpp @@ -33,8 +33,7 @@ class GraphView; * @brief Object carrying the topological information of the computational graph. */ class Node : public std::enable_shared_from_this<Node> { -//private: -public: // TODO: workaround to make GraphView:clone() work, friend doesn't work because of forward declaration +private: struct weakCompare { bool operator()(const std::weak_ptr<Aidge::GraphView>& a, const std::weak_ptr<Aidge::GraphView>& b) const { // Compare the content of the weak_ptrs @@ -400,9 +399,6 @@ public: return node->clone(); } - // TODO: does not work, friend requires full definition, but there is a circular dependency between Node and GraphView - //friend std::shared_ptr<GraphView> GraphView::clone(NodePtr(*cloneNode)(NodePtr)) const; - private: /////////////////////////////////////////////////////// // OPERATORS diff --git a/include/aidge/utils/CParameter.hpp b/include/aidge/utils/CParameter.hpp index 0f4c74ab8bccb7bc134e035a5f12d31d51663e5d..9387177d4a00ea9829ef41bda217afd811c0aed9 100644 --- a/include/aidge/utils/CParameter.hpp +++ b/include/aidge/utils/CParameter.hpp @@ -15,6 +15,7 @@ #include <assert.h> #include <map> #include <vector> +#include <string> namespace Aidge { @@ -30,11 +31,6 @@ private: struct is_vector<std::vector<T, Alloc>> : std::true_type {}; public: - // not copyable, not movable - CParameter(CParameter const &) = delete; - CParameter(CParameter &&) = delete; - CParameter &operator=(CParameter const &) = delete; - CParameter &operator=(CParameter &&) = delete; CParameter() : m_Params({}){}; ~CParameter() = default; diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index 4d8692d2d69e9177b63bfdd36fe4c63be0f8a769..44d8fbb1804ade5a758e1deba29448bd5f6069b1 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -676,69 +676,52 @@ void Aidge::GraphView::removeOutputNode(const std::string nodeName) { } std::shared_ptr<Aidge::GraphView> Aidge::GraphView::clone(NodePtr(*cloneNode)(NodePtr)) const { - std::shared_ptr<GraphView> newGraph; - std::map<NodePtr, NodePtr> oldToNewNodes; + std::shared_ptr<GraphView> newGraph = std::make_shared<GraphView>(mName); // Map for old node -> new node correspondance + std::map<NodePtr, NodePtr> oldToNewNodes; + for (const std::shared_ptr<Node> &node_ptr : mNodes) { oldToNewNodes[node_ptr] = cloneNode(node_ptr); } - // For each node, convert old node -> new node references + // For each node, convert old node -> new node connections for (auto &oldToNewNode : oldToNewNodes) { - // The only GraphView for the new graph is the one we return at the end - oldToNewNode.second->addView(newGraph); + if (oldToNewNode.second == nullptr) + continue; // deleted node - if (!(oldToNewNode.second->name()).empty()) - newGraph->mNodeRegistry.insert(std::make_pair(oldToNewNode.second->name(), oldToNewNode.second)); + // Add new node to new GraphView + newGraph->add(oldToNewNode.second, false); // Connect parent nodes. Nodes that were removed with cloneNode() are set to nullptr size_t parentId = 0; - for (auto& parent : oldToNewNode.first->mParents) { - oldToNewNode.second->mParents.push_back(oldToNewNodes[parent]); - - if (oldToNewNodes[parent]) - oldToNewNode.second->mIdOutParents.push_back(oldToNewNode.first->mIdOutParents[parentId]); - else - oldToNewNode.second->mIdOutParents.push_back(gk_IODefaultIndex); - - ++parentId; - } - - // Connect child nodes. - size_t childId = 0; - for (const auto &childrenOfOneOutput : oldToNewNode.first->mChildren) { - oldToNewNode.second->mChildren.push_back(std::vector<std::weak_ptr<Node>>()); - oldToNewNode.second->mIdInChildren.push_back(std::vector<IOIndex_t>()); - - size_t j = 0; - for (const auto &oneChild : childrenOfOneOutput) { - if (oldToNewNodes[oneChild.lock()]) { - oldToNewNode.second->mChildren.back().push_back(std::weak_ptr<Node>(oldToNewNodes[oneChild.lock()])); - oldToNewNode.second->mIdInChildren.back().push_back(oldToNewNode.first->mIdInChildren[childId][j]); - } - - ++j; + for (auto parent : oldToNewNode.first->inputs()) { + while (oldToNewNodes[parent.first] == nullptr) { + // Find next valid parent in line, going backward in the graph + assert(parent.first->nbDataInputs() <= 1 && "deleted nodes in GraphView::clone() cannot have multiple data inputs"); + const auto& parents = parent.first->inputs(); + + if (!parents.empty() && parents[0].first != nullptr // a valid parent exists + && oldToNewNodes.find(parents[0].first) != oldToNewNodes.end()) // parent is in the GraphView + { + parent = parents[0]; } + else { + break; + } + } - ++childId; - } - } - - // Copy remaining GraphView informations - newGraph->setName(mName); + if (oldToNewNodes[parent.first]) { + oldToNewNodes[parent.first]->addChild(oldToNewNode.second, parent.second, parentId); + } - for (const auto& inputNode : mInputNodes) { - if (oldToNewNodes[inputNode]) { - newGraph->mInputNodes.insert(oldToNewNodes[inputNode]); + ++parentId; } } - for (const auto& outputNode : mOutputNodes) { - if (oldToNewNodes[outputNode]) { - newGraph->mOutputNodes.insert(oldToNewNodes[outputNode]); - } - } + // Update OutputNodes/inputNodes + newGraph->updateInputNodes(); + newGraph->updateOutputNodes(); return newGraph; } diff --git a/unit_tests/graph/Test_GraphView.cpp b/unit_tests/graph/Test_GraphView.cpp index dc693193c6606c99b1628d23ad253015f8f8dbe6..71c3ea3aedb941b94d016c13ee169a15dad9af55 100644 --- a/unit_tests/graph/Test_GraphView.cpp +++ b/unit_tests/graph/Test_GraphView.cpp @@ -330,4 +330,132 @@ TEST_CASE("[core/graph] GraphView(replaceWith)") { REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({r1, r4})); REQUIRE((r1->output(0))[0].first == r4); } -} \ No newline at end of file +} + +TEST_CASE("[GraphView] clone") { + auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider"); + auto conv1 = Conv(3, 32, {3, 3}, "conv1"); + auto conv2 = Conv(32, 64, {3, 3}, "conv2"); + auto conv3 = Conv(64, 10, {1, 1}, "conv3"); + auto g1 = std::make_shared<GraphView>("TestGraph"); + dataProvider->addChild(conv1, 0); + g1->add(conv1); + g1->addChild(conv2, conv1, 0); + g1->addChild(conv3, conv2, 0); + g1->save("clone_g1"); + + SECTION("Check input-output connections") { + REQUIRE(dataProvider->getOperator()->getOutput(0) == conv1->getOperator()->getInput(0)); + REQUIRE(conv1->getOperator()->getInput(1) == g1->getNode("conv1_w")->getOperator()->getOutput(0)); + REQUIRE(conv1->getOperator()->getInput(2) == g1->getNode("conv1_b")->getOperator()->getOutput(0)); + REQUIRE(conv1->getOperator()->getOutput(0) == conv2->getOperator()->getInput(0)); + REQUIRE(conv2->getOperator()->getInput(1) == g1->getNode("conv2_w")->getOperator()->getOutput(0)); + REQUIRE(conv2->getOperator()->getInput(2) == g1->getNode("conv2_b")->getOperator()->getOutput(0)); + REQUIRE(conv2->getOperator()->getOutput(0) == conv3->getOperator()->getInput(0)); + REQUIRE(conv3->getOperator()->getInput(1) == g1->getNode("conv3_w")->getOperator()->getOutput(0)); + REQUIRE(conv3->getOperator()->getInput(2) == g1->getNode("conv3_b")->getOperator()->getOutput(0)); + } + + auto g2 = g1->clone(); + g2->save("clone_g2"); + + SECTION("Check node cloning") { + REQUIRE(g1->getNode("conv1") != g2->getNode("conv1")); + REQUIRE(g1->getNode("conv1_w") != g2->getNode("conv1_w")); + REQUIRE(g1->getNode("conv1_b") != g2->getNode("conv1_b")); + REQUIRE(g1->getNode("conv2") != g2->getNode("conv2")); + REQUIRE(g1->getNode("conv2_w") != g2->getNode("conv2_w")); + REQUIRE(g1->getNode("conv2_b") != g2->getNode("conv2_b")); + REQUIRE(g1->getNode("conv3") != g2->getNode("conv3")); + REQUIRE(g1->getNode("conv3_w") != g2->getNode("conv3_w")); + REQUIRE(g1->getNode("conv3_b") != g2->getNode("conv3_b")); + } + + SECTION("Check operator cloning") { + REQUIRE(g1->getNode("conv1")->getOperator() != g2->getNode("conv1")->getOperator()); + REQUIRE(g1->getNode("conv1_w")->getOperator() != g2->getNode("conv1_w")->getOperator()); + REQUIRE(g1->getNode("conv1_b")->getOperator() != g2->getNode("conv1_b")->getOperator()); + REQUIRE(g1->getNode("conv2")->getOperator() != g2->getNode("conv2")->getOperator()); + REQUIRE(g1->getNode("conv2_w")->getOperator() != g2->getNode("conv2_w")->getOperator()); + REQUIRE(g1->getNode("conv2_b")->getOperator() != g2->getNode("conv2_b")->getOperator()); + REQUIRE(g1->getNode("conv3")->getOperator() != g2->getNode("conv3")->getOperator()); + REQUIRE(g1->getNode("conv3_w")->getOperator() != g2->getNode("conv3_w")->getOperator()); + REQUIRE(g1->getNode("conv3_b")->getOperator() != g2->getNode("conv3_b")->getOperator()); + } + + SECTION("Check input-output connections") { + REQUIRE(dataProvider->getOperator()->getOutput(0) == g2->getNode("conv1")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv1")->getOperator()->getInput(1) == g2->getNode("conv1_w")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv1")->getOperator()->getInput(2) == g2->getNode("conv1_b")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->getInput(1) == g2->getNode("conv2_w")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->getInput(2) == g2->getNode("conv2_b")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv3")->getOperator()->getInput(1) == g2->getNode("conv3_w")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv3")->getOperator()->getInput(2) == g2->getNode("conv3_b")->getOperator()->getOutput(0)); + } +} + +TEST_CASE("[GraphView] cloneSharedProducers") { + auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider"); + auto conv1 = Conv(3, 32, {3, 3}, "conv1"); + auto conv2 = Conv(32, 64, {3, 3}, "conv2"); + auto conv3 = Conv(64, 10, {1, 1}, "conv3"); + auto g1 = std::make_shared<GraphView>("TestGraph"); + dataProvider->addChild(conv1, 0); + g1->add(conv1); + g1->addChild(conv2, conv1, 0); + g1->addChild(conv3, conv2, 0); + g1->save("cloneSharedProducers_g1"); + + SECTION("Check input-output connections") { + REQUIRE(dataProvider->getOperator()->getOutput(0) == conv1->getOperator()->getInput(0)); + REQUIRE(conv1->getOperator()->getInput(1) == g1->getNode("conv1_w")->getOperator()->getOutput(0)); + REQUIRE(conv1->getOperator()->getInput(2) == g1->getNode("conv1_b")->getOperator()->getOutput(0)); + REQUIRE(conv1->getOperator()->getOutput(0) == conv2->getOperator()->getInput(0)); + REQUIRE(conv2->getOperator()->getInput(1) == g1->getNode("conv2_w")->getOperator()->getOutput(0)); + REQUIRE(conv2->getOperator()->getInput(2) == g1->getNode("conv2_b")->getOperator()->getOutput(0)); + REQUIRE(conv2->getOperator()->getOutput(0) == conv3->getOperator()->getInput(0)); + REQUIRE(conv3->getOperator()->getInput(1) == g1->getNode("conv3_w")->getOperator()->getOutput(0)); + REQUIRE(conv3->getOperator()->getInput(2) == g1->getNode("conv3_b")->getOperator()->getOutput(0)); + } + + auto g2 = g1->cloneSharedProducers(); + g2->save("cloneSharedProducers_g2"); + + SECTION("Check node cloning") { + REQUIRE(g1->getNode("conv1") != g2->getNode("conv1")); + REQUIRE(g1->getNode("conv1_w") != g2->getNode("conv1_w")); + REQUIRE(g1->getNode("conv1_b") != g2->getNode("conv1_b")); + REQUIRE(g1->getNode("conv2") != g2->getNode("conv2")); + REQUIRE(g1->getNode("conv2_w") != g2->getNode("conv2_w")); + REQUIRE(g1->getNode("conv2_b") != g2->getNode("conv2_b")); + REQUIRE(g1->getNode("conv3") != g2->getNode("conv3")); + REQUIRE(g1->getNode("conv3_w") != g2->getNode("conv3_w")); + REQUIRE(g1->getNode("conv3_b") != g2->getNode("conv3_b")); + } + + SECTION("Check operator cloning") { + REQUIRE(g1->getNode("conv1")->getOperator() != g2->getNode("conv1")->getOperator()); + REQUIRE(g1->getNode("conv1_w")->getOperator() == g2->getNode("conv1_w")->getOperator()); + REQUIRE(g1->getNode("conv1_b")->getOperator() == g2->getNode("conv1_b")->getOperator()); + REQUIRE(g1->getNode("conv2")->getOperator() != g2->getNode("conv2")->getOperator()); + REQUIRE(g1->getNode("conv2_w")->getOperator() == g2->getNode("conv2_w")->getOperator()); + REQUIRE(g1->getNode("conv2_b")->getOperator() == g2->getNode("conv2_b")->getOperator()); + REQUIRE(g1->getNode("conv3")->getOperator() != g2->getNode("conv3")->getOperator()); + REQUIRE(g1->getNode("conv3_w")->getOperator() == g2->getNode("conv3_w")->getOperator()); + REQUIRE(g1->getNode("conv3_b")->getOperator() == g2->getNode("conv3_b")->getOperator()); + } + + SECTION("Check input-output connections") { + REQUIRE(dataProvider->getOperator()->getOutput(0) == g2->getNode("conv1")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv1")->getOperator()->getInput(1) == g2->getNode("conv1_w")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv1")->getOperator()->getInput(2) == g2->getNode("conv1_b")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->getInput(1) == g2->getNode("conv2_w")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->getInput(2) == g2->getNode("conv2_b")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv3")->getOperator()->getInput(1) == g2->getNode("conv3_w")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv3")->getOperator()->getInput(2) == g2->getNode("conv3_b")->getOperator()->getOutput(0)); + } +}