From ff65e20b241f5bfe9addf904c0de4067220d4232 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Tue, 30 Apr 2024 11:48:12 +0200 Subject: [PATCH] Fixed multiple outputs support for GraphView::replace() --- include/aidge/graph/GraphView.hpp | 8 ++++++ src/graph/GraphView.cpp | 43 +++++++++++++++++-------------- 2 files changed, 32 insertions(+), 19 deletions(-) diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp index 1a1714272..c9a4c11d7 100644 --- a/include/aidge/graph/GraphView.hpp +++ b/include/aidge/graph/GraphView.hpp @@ -486,6 +486,14 @@ public: */ IOIndex_t getNbFreeDataInputs() const; + /** + * @brief Force update of GraphView inputs/outputs. + * It may be necessary to force the update of GraphView inputs/outputs when + * connections are added or removed inside the GraphView **after** the nodes + * were added. + */ + void updateInputsOutputs(); + private: /////////////////////////////////////////////////////// // TENSOR MANAGEMENT diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index b748bd4bc..273825172 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -910,7 +910,7 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldGraph, const newGraph->getOrderedOutputs(); auto inputParents = std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(oldOIn.size()); - auto outputChildren = std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(oldOOut.size()); + auto outputChildren = std::vector<std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>>(oldOOut.size()); // keep in memory every node related to the node to replace : // Parent @@ -921,19 +921,12 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldGraph, const // inputParent.first -> addChild(newOI[i].first, inputParent.second, newOI[i].second); } // Children - for (std::size_t i = 0; i < oldOOut.size();) { + for (std::size_t i = 0; i < oldOOut.size(); ++i) { std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> outputChild = oldOOut[i].first -> output(oldOOut[i].second); - if (outputChild.empty()) { - outputChildren[i] = std::pair<std::shared_ptr<Node>, IOIndex_t>({nullptr, gk_IODefaultIndex}); - ++i; - } - else { - for (const auto& child : outputChild) { - if (oldNodes.find(child.first) == oldNodes.cend()) { - outputChildren[i] = child; - ++i; - } + for (const auto& child : outputChild) { + if (oldNodes.find(child.first) == oldNodes.cend()) { + outputChildren[i].push_back(child); } } } @@ -971,8 +964,8 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldGraph, const } } for (std::size_t o = 0; o < oldOOut.size(); ++o) { - if (outputChildren[o].first) { - newOOut[o].first -> addChild(outputChildren[o].first, newOOut[o].second, outputChildren[o].second); + for (const auto child : outputChildren[o]) { + newOOut[o].first -> addChild(child.first, newOOut[o].second, child.second); } } } @@ -982,15 +975,21 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldGraph, const if (newNodes.size() == 0) { // Case 3 if (oldOIn.size() == oldOOut.size()) { + // Same number of inputs and outputs: connect each input to the corresponding output for (std::size_t i = 0; i < oldOIn.size(); ++i) { if (inputParents[i].first) { - inputParents[i].first -> addChild(outputChildren[i].first, inputParents[i].second, outputChildren[i].second); + for (const auto child : outputChildren[i]) { + inputParents[i].first -> addChild(child.first, inputParents[i].second, child.second); + } } } } else if ((oldOIn.size() == 1) && (inputParents[0].first)) { - for (std::size_t i = 0; i < oldOIn.size(); ++i) { - inputParents[0].first -> addChild(outputChildren[i].first, inputParents[0].second, outputChildren[i].second); + // Single input: connect the only input to all the outputs + for (std::size_t i = 0; i < oldOOut.size(); ++i) { + for (const auto child : outputChildren[i]) { + inputParents[0].first -> addChild(child.first, inputParents[0].second, child.second); + } } } } @@ -1011,8 +1010,8 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldGraph, const } } for (std::size_t o = 0; o < oldOOut.size(); ++o) { - if (outputChildren[o].first) { - newOOut[o].first -> addChild(outputChildren[o].first, newOOut[o].second, outputChildren[o].second); + for (const auto child : outputChildren[o]) { + newOOut[o].first -> addChild(child.first, newOOut[o].second, child.second); } } } @@ -1061,6 +1060,12 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldGraph, const return true; } +void Aidge::GraphView::updateInputsOutputs() { + for (auto node : mNodes) { + updateInputsOutputsNew(node); + } +} + void Aidge::GraphView::updateInputsOutputsNew(std::shared_ptr<Node> newNode) { // Can be called several times with the same node, e.g. when addChild() is // called on a node already part of the GraphView. In this case, inputs/outputs -- GitLab