Skip to content
Snippets Groups Projects
Commit 61810442 authored by Maxence Naud's avatar Maxence Naud
Browse files

Improve 'GraphView::replace()' for cases where removed operator had multiple...

Improve 'GraphView::replace()' for cases where removed operator had multiple outputs also in the GraphView
parent c9386277
No related branches found
No related tags found
No related merge requests found
......@@ -257,7 +257,7 @@ public:
* @brief Get the operator with the corresponding name if it is in the
* GraphView.
* @param nodeName Name of the node.
* @return NodePtr returns a new empty node if the one asked for
* @return NodePtr returns a nullptr if the one asked for
* was not found.
*/
NodePtr getNode(const std::string& nodeName) const;
......
......@@ -616,11 +616,11 @@ std::shared_ptr<Aidge::Node>
Aidge::GraphView::getNode(const std::string& nodeName) const {
std::map<std::string, std::shared_ptr<Node>>::const_iterator it =
mNodeRegistry.find(nodeName);
if (it != mNodeRegistry.end()) {
if (it != mNodeRegistry.cend()) {
return it->second;
} else {
printf("No Node named %s in the current GraphView.\n", nodeName.c_str());
exit(-1);
return nullptr;
}
}
......@@ -760,14 +760,6 @@ bool Aidge::GraphView::replace(const std::set<Aidge::NodePtr>& oldNodes, const s
return false;
}
for (const auto& nodePtr : oldNodes) {
for (const auto& g : commonGraphViews) {
g -> remove(nodePtr, false);
g -> updateInputsOutputsDelete(nodePtr);
}
nodePtr -> resetConnections(true);
}
if ((oldOI.size() == newOI.size()) &&
(oldOO.size() == newOO.size())) {
// Case 1
......@@ -793,7 +785,7 @@ bool Aidge::GraphView::replace(const std::set<Aidge::NodePtr>& oldNodes, const s
inputParents[i].first -> addChild(outputChildren[i].first, inputParents[i].second, outputChildren[i].second);
}
}
else if (oldOI.size() == 1) {
else if ((oldOI.size() == 1) && (inputParents[0].first)) {
for (std::size_t i = 0; i < oldOI.size(); ++i) {
inputParents[0].first -> addChild(outputChildren[i].first, inputParents[0].second, outputChildren[i].second);
}
......@@ -804,13 +796,15 @@ bool Aidge::GraphView::replace(const std::set<Aidge::NodePtr>& oldNodes, const s
((oldOO.size() == newOO.size()))
) {
// Case 2
if ((oldOI.size() == 1)) {
if ((oldOI.size() == 1) && (inputParents[0].first)) {
for (std::size_t i = 0; i < newOI.size(); ++i) {
inputParents[0].first -> addChild(newOI[i].first, inputParents[0].second, newOI[i].second);
}
} else {
for (std::size_t i = 0; i < oldOI.size(); ++i) {
inputParents[i].first -> addChild(newOI[0].first, inputParents[i].second, newOI[0].second);
if (inputParents[i].first) {
inputParents[i].first -> addChild(newOI[0].first, inputParents[i].second, newOI[0].second);
}
}
}
for (std::size_t o = 0; o < oldOO.size(); ++o) {
......@@ -829,6 +823,27 @@ bool Aidge::GraphView::replace(const std::set<Aidge::NodePtr>& oldNodes, const s
return false;
}
}
auto oldGOutputs = oldG->outputNodes();
for (const auto& nodePtr : oldNodes) {
bool removeFromGraphs = true;
if (std::find(oldGOutputs.cbegin(), oldGOutputs.cend(), nodePtr) == oldGOutputs.cend()) {
for (const auto& chPtr : nodePtr->getChildren()) {
if (oldNodes.find(chPtr) == oldNodes.cend()) {
removeFromGraphs = false;
}
}
}
if (removeFromGraphs) {
for (const auto& g : commonGraphViews) {
g -> remove(nodePtr, false);
g -> updateInputsOutputsDelete(nodePtr);
}
nodePtr -> resetConnections(true);
}
}
for (const auto& nodePtr : newNodes) {
for (const auto& g : commonGraphViews) {
g -> add(nodePtr);
......@@ -934,10 +949,10 @@ void Aidge::GraphView::updateInputsOutputsNew(std::shared_ptr<Node> newNode) {
// Check if node outputs are outputs for the GraphView and add them to the output list if so
IOIndex_t outputIdx = 0;
for (auto orderedChilds : newNode->getOrderedChildren()) {
for (const auto& orderedChilds : newNode->getOrderedChildren()) {
bool noInsideConnection = true;
for (auto ch_ptr : orderedChilds) {
if (mNodes.find(ch_ptr) != mNodes.end()) {
for (const auto& ch_ptr : orderedChilds) {
if (mNodes.find(ch_ptr) != mNodes.cend()) {
noInsideConnection = false;
break;
}
......@@ -946,7 +961,7 @@ void Aidge::GraphView::updateInputsOutputsNew(std::shared_ptr<Node> newNode) {
if (noInsideConnection) {
const auto val = std::make_pair(newNode, outputIdx);
// Output may be already be present (see addChild() with a node already in GraphView)
if (std::find(mOutputNodes.begin(), mOutputNodes.end(), val) == mOutputNodes.end()) {
if (std::find(mOutputNodes.cbegin(), mOutputNodes.cend(), val) == mOutputNodes.cend()) {
newOutputsInsertionPoint = mOutputNodes.insert(newOutputsInsertionPoint, val);
newOutputsInsertionPoint = std::next(newOutputsInsertionPoint);
}
......
......@@ -23,6 +23,8 @@
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Testing.hpp"
#include "aidge/operator/Conv.hpp"
#include "aidge/operator/ReLU.hpp"
#include "aidge/graph/OpArgs.hpp"
#include "aidge/operator/GenericOperator.hpp"
#include "aidge/operator/Producer.hpp"
......@@ -589,6 +591,56 @@ TEST_CASE("[core/graph] GraphView(replace)", "[GraphView][replace]") {
REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({newMatmulWeight0, newAddBias0, newAddBias1, newMatmulWeight1, fc1, fc0}));
}
SECTION("Nodes with shared parameters") {
auto myConv1 = Conv(1, 5, {1,1}, "conv1");
auto myConv2 = Conv(5, 5, {1,1}, "conv2");
auto myConv3 = Conv(5, 5, {1,1}, "conv3");
auto myConv4 = Conv(5, 5, {1,1}, "conv4");
auto myConv5 = Conv(5, 5, {1,1}, "conv5");
auto sharedWeightTensor = std::make_shared<Tensor>();
sharedWeightTensor->resize({5,5,1,1});
auto sharedWeight = Producer(sharedWeightTensor, "sharedWeight");
sharedWeight -> addChild(myConv2, 0, 1);
sharedWeight -> addChild(myConv3, 0, 1);
sharedWeight -> addChild(myConv4, 0, 1);
auto sharedBiasTensor = std::make_shared<Tensor>();
sharedBiasTensor->resize({5});
auto sharedBias = Producer(sharedBiasTensor, "sharedBias");
sharedBias -> addChild(myConv2, 0, 2);
sharedBias -> addChild(myConv3, 0, 2);
sharedBias -> addChild(myConv4, 0, 2);
auto g = Sequential({
myConv1,
myConv2,
myConv3,
myConv4,
myConv5
});
REQUIRE(g->getNode("sharedWeight") != nullptr);
REQUIRE(g->getNode("sharedBias") != nullptr);
auto newReLU4 = ReLU("relu4");
GraphView::replace({myConv4, myConv4->getParent(1), myConv4->getParent(2)}, {newReLU4});
REQUIRE(g->getNode("sharedWeight") != nullptr);
REQUIRE(g->getNode("sharedBias") != nullptr);
auto newReLU3 = ReLU("relu3");
GraphView::replace({myConv3, myConv3->getParent(1), myConv3->getParent(2)}, {newReLU3});
REQUIRE(g->getNode("sharedWeight") != nullptr);
REQUIRE(g->getNode("sharedBias") != nullptr);
auto newReLU2 = ReLU("relu2");
GraphView::replace({myConv2, myConv2->getParent(1), myConv2->getParent(2)}, {newReLU2});
REQUIRE(g->getNode("sharedWeight") == nullptr);
REQUIRE(g->getNode("sharedBias") == nullptr);
}
}
TEST_CASE("[GraphView] clone") {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment