diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp index dda3d8ee459e9f089f817f7222d717bf75ede0f5..766c6ba72c44293834f130c76b7c21881ef10752 100644 --- a/include/aidge/graph/GraphView.hpp +++ b/include/aidge/graph/GraphView.hpp @@ -257,7 +257,7 @@ public: * @brief Get the operator with the corresponding name if it is in the * GraphView. * @param nodeName Name of the node. - * @return NodePtr returns a new empty node if the one asked for + * @return NodePtr returns a nullptr if the one asked for * was not found. */ NodePtr getNode(const std::string& nodeName) const; diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index 554c535f229af0ab5b59fa6f57607c7bacd872fa..377f991a7bb0d6c7c2e8a63198218f878da64f13 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -616,11 +616,11 @@ std::shared_ptr<Aidge::Node> Aidge::GraphView::getNode(const std::string& nodeName) const { std::map<std::string, std::shared_ptr<Node>>::const_iterator it = mNodeRegistry.find(nodeName); - if (it != mNodeRegistry.end()) { + if (it != mNodeRegistry.cend()) { return it->second; } else { printf("No Node named %s in the current GraphView.\n", nodeName.c_str()); - exit(-1); + return nullptr; } } @@ -760,14 +760,6 @@ bool Aidge::GraphView::replace(const std::set<Aidge::NodePtr>& oldNodes, const s return false; } - for (const auto& nodePtr : oldNodes) { - for (const auto& g : commonGraphViews) { - g -> remove(nodePtr, false); - g -> updateInputsOutputsDelete(nodePtr); - } - nodePtr -> resetConnections(true); - } - if ((oldOI.size() == newOI.size()) && (oldOO.size() == newOO.size())) { // Case 1 @@ -793,7 +785,7 @@ bool Aidge::GraphView::replace(const std::set<Aidge::NodePtr>& oldNodes, const s inputParents[i].first -> addChild(outputChildren[i].first, inputParents[i].second, outputChildren[i].second); } } - else if (oldOI.size() == 1) { + else if ((oldOI.size() == 1) && (inputParents[0].first)) { for (std::size_t i = 0; i < oldOI.size(); ++i) { inputParents[0].first -> addChild(outputChildren[i].first, inputParents[0].second, outputChildren[i].second); } @@ -804,13 +796,15 @@ bool Aidge::GraphView::replace(const std::set<Aidge::NodePtr>& oldNodes, const s ((oldOO.size() == newOO.size())) ) { // Case 2 - if ((oldOI.size() == 1)) { + if ((oldOI.size() == 1) && (inputParents[0].first)) { for (std::size_t i = 0; i < newOI.size(); ++i) { inputParents[0].first -> addChild(newOI[i].first, inputParents[0].second, newOI[i].second); } } else { for (std::size_t i = 0; i < oldOI.size(); ++i) { - inputParents[i].first -> addChild(newOI[0].first, inputParents[i].second, newOI[0].second); + if (inputParents[i].first) { + inputParents[i].first -> addChild(newOI[0].first, inputParents[i].second, newOI[0].second); + } } } for (std::size_t o = 0; o < oldOO.size(); ++o) { @@ -829,6 +823,27 @@ bool Aidge::GraphView::replace(const std::set<Aidge::NodePtr>& oldNodes, const s return false; } } + + auto oldGOutputs = oldG->outputNodes(); + for (const auto& nodePtr : oldNodes) { + bool removeFromGraphs = true; + if (std::find(oldGOutputs.cbegin(), oldGOutputs.cend(), nodePtr) == oldGOutputs.cend()) { + for (const auto& chPtr : nodePtr->getChildren()) { + if (oldNodes.find(chPtr) == oldNodes.cend()) { + removeFromGraphs = false; + } + } + } + if (removeFromGraphs) { + for (const auto& g : commonGraphViews) { + g -> remove(nodePtr, false); + g -> updateInputsOutputsDelete(nodePtr); + } + nodePtr -> resetConnections(true); + } + + } + for (const auto& nodePtr : newNodes) { for (const auto& g : commonGraphViews) { g -> add(nodePtr); @@ -934,10 +949,10 @@ void Aidge::GraphView::updateInputsOutputsNew(std::shared_ptr<Node> newNode) { // Check if node outputs are outputs for the GraphView and add them to the output list if so IOIndex_t outputIdx = 0; - for (auto orderedChilds : newNode->getOrderedChildren()) { + for (const auto& orderedChilds : newNode->getOrderedChildren()) { bool noInsideConnection = true; - for (auto ch_ptr : orderedChilds) { - if (mNodes.find(ch_ptr) != mNodes.end()) { + for (const auto& ch_ptr : orderedChilds) { + if (mNodes.find(ch_ptr) != mNodes.cend()) { noInsideConnection = false; break; } @@ -946,7 +961,7 @@ void Aidge::GraphView::updateInputsOutputsNew(std::shared_ptr<Node> newNode) { if (noInsideConnection) { const auto val = std::make_pair(newNode, outputIdx); // Output may be already be present (see addChild() with a node already in GraphView) - if (std::find(mOutputNodes.begin(), mOutputNodes.end(), val) == mOutputNodes.end()) { + if (std::find(mOutputNodes.cbegin(), mOutputNodes.cend(), val) == mOutputNodes.cend()) { newOutputsInsertionPoint = mOutputNodes.insert(newOutputsInsertionPoint, val); newOutputsInsertionPoint = std::next(newOutputsInsertionPoint); } diff --git a/unit_tests/graph/Test_GraphView.cpp b/unit_tests/graph/Test_GraphView.cpp index acbea04a27a0b6be22105bb73fda53fedf621235..ebbfb3ad89721eb4f1390c3efca475acbb0b6f46 100644 --- a/unit_tests/graph/Test_GraphView.cpp +++ b/unit_tests/graph/Test_GraphView.cpp @@ -23,6 +23,8 @@ #include "aidge/graph/GraphView.hpp" #include "aidge/graph/Testing.hpp" #include "aidge/operator/Conv.hpp" +#include "aidge/operator/ReLU.hpp" +#include "aidge/graph/OpArgs.hpp" #include "aidge/operator/GenericOperator.hpp" #include "aidge/operator/Producer.hpp" @@ -589,6 +591,56 @@ TEST_CASE("[core/graph] GraphView(replace)", "[GraphView][replace]") { REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({newMatmulWeight0, newAddBias0, newAddBias1, newMatmulWeight1, fc1, fc0})); } + + SECTION("Nodes with shared parameters") { + + auto myConv1 = Conv(1, 5, {1,1}, "conv1"); + auto myConv2 = Conv(5, 5, {1,1}, "conv2"); + auto myConv3 = Conv(5, 5, {1,1}, "conv3"); + auto myConv4 = Conv(5, 5, {1,1}, "conv4"); + auto myConv5 = Conv(5, 5, {1,1}, "conv5"); + + auto sharedWeightTensor = std::make_shared<Tensor>(); + sharedWeightTensor->resize({5,5,1,1}); + auto sharedWeight = Producer(sharedWeightTensor, "sharedWeight"); + sharedWeight -> addChild(myConv2, 0, 1); + sharedWeight -> addChild(myConv3, 0, 1); + sharedWeight -> addChild(myConv4, 0, 1); + + auto sharedBiasTensor = std::make_shared<Tensor>(); + sharedBiasTensor->resize({5}); + auto sharedBias = Producer(sharedBiasTensor, "sharedBias"); + sharedBias -> addChild(myConv2, 0, 2); + sharedBias -> addChild(myConv3, 0, 2); + sharedBias -> addChild(myConv4, 0, 2); + + auto g = Sequential({ + myConv1, + myConv2, + myConv3, + myConv4, + myConv5 + }); + + REQUIRE(g->getNode("sharedWeight") != nullptr); + REQUIRE(g->getNode("sharedBias") != nullptr); + + + auto newReLU4 = ReLU("relu4"); + GraphView::replace({myConv4, myConv4->getParent(1), myConv4->getParent(2)}, {newReLU4}); + REQUIRE(g->getNode("sharedWeight") != nullptr); + REQUIRE(g->getNode("sharedBias") != nullptr); + + auto newReLU3 = ReLU("relu3"); + GraphView::replace({myConv3, myConv3->getParent(1), myConv3->getParent(2)}, {newReLU3}); + REQUIRE(g->getNode("sharedWeight") != nullptr); + REQUIRE(g->getNode("sharedBias") != nullptr); + + auto newReLU2 = ReLU("relu2"); + GraphView::replace({myConv2, myConv2->getParent(1), myConv2->getParent(2)}, {newReLU2}); + REQUIRE(g->getNode("sharedWeight") == nullptr); + REQUIRE(g->getNode("sharedBias") == nullptr); + } } TEST_CASE("[GraphView] clone") {