Skip to content
Snippets Groups Projects
Commit ff65e20b authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Fixed multiple outputs support for GraphView::replace()

parent bbdf4db5
No related branches found
No related tags found
No related merge requests found
...@@ -486,6 +486,14 @@ public: ...@@ -486,6 +486,14 @@ public:
*/ */
IOIndex_t getNbFreeDataInputs() const; IOIndex_t getNbFreeDataInputs() const;
/**
* @brief Force update of GraphView inputs/outputs.
* It may be necessary to force the update of GraphView inputs/outputs when
* connections are added or removed inside the GraphView **after** the nodes
* were added.
*/
void updateInputsOutputs();
private: private:
/////////////////////////////////////////////////////// ///////////////////////////////////////////////////////
// TENSOR MANAGEMENT // TENSOR MANAGEMENT
......
...@@ -910,7 +910,7 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldGraph, const ...@@ -910,7 +910,7 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldGraph, const
newGraph->getOrderedOutputs(); newGraph->getOrderedOutputs();
auto inputParents = std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(oldOIn.size()); auto inputParents = std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(oldOIn.size());
auto outputChildren = std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(oldOOut.size()); auto outputChildren = std::vector<std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>>(oldOOut.size());
// keep in memory every node related to the node to replace : // keep in memory every node related to the node to replace :
// Parent // Parent
...@@ -921,19 +921,12 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldGraph, const ...@@ -921,19 +921,12 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldGraph, const
// inputParent.first -> addChild(newOI[i].first, inputParent.second, newOI[i].second); // inputParent.first -> addChild(newOI[i].first, inputParent.second, newOI[i].second);
} }
// Children // Children
for (std::size_t i = 0; i < oldOOut.size();) { for (std::size_t i = 0; i < oldOOut.size(); ++i) {
std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> outputChild = std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> outputChild =
oldOOut[i].first -> output(oldOOut[i].second); oldOOut[i].first -> output(oldOOut[i].second);
if (outputChild.empty()) { for (const auto& child : outputChild) {
outputChildren[i] = std::pair<std::shared_ptr<Node>, IOIndex_t>({nullptr, gk_IODefaultIndex}); if (oldNodes.find(child.first) == oldNodes.cend()) {
++i; outputChildren[i].push_back(child);
}
else {
for (const auto& child : outputChild) {
if (oldNodes.find(child.first) == oldNodes.cend()) {
outputChildren[i] = child;
++i;
}
} }
} }
} }
...@@ -971,8 +964,8 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldGraph, const ...@@ -971,8 +964,8 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldGraph, const
} }
} }
for (std::size_t o = 0; o < oldOOut.size(); ++o) { for (std::size_t o = 0; o < oldOOut.size(); ++o) {
if (outputChildren[o].first) { for (const auto child : outputChildren[o]) {
newOOut[o].first -> addChild(outputChildren[o].first, newOOut[o].second, outputChildren[o].second); newOOut[o].first -> addChild(child.first, newOOut[o].second, child.second);
} }
} }
} }
...@@ -982,15 +975,21 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldGraph, const ...@@ -982,15 +975,21 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldGraph, const
if (newNodes.size() == 0) { if (newNodes.size() == 0) {
// Case 3 // Case 3
if (oldOIn.size() == oldOOut.size()) { if (oldOIn.size() == oldOOut.size()) {
// Same number of inputs and outputs: connect each input to the corresponding output
for (std::size_t i = 0; i < oldOIn.size(); ++i) { for (std::size_t i = 0; i < oldOIn.size(); ++i) {
if (inputParents[i].first) { if (inputParents[i].first) {
inputParents[i].first -> addChild(outputChildren[i].first, inputParents[i].second, outputChildren[i].second); for (const auto child : outputChildren[i]) {
inputParents[i].first -> addChild(child.first, inputParents[i].second, child.second);
}
} }
} }
} }
else if ((oldOIn.size() == 1) && (inputParents[0].first)) { else if ((oldOIn.size() == 1) && (inputParents[0].first)) {
for (std::size_t i = 0; i < oldOIn.size(); ++i) { // Single input: connect the only input to all the outputs
inputParents[0].first -> addChild(outputChildren[i].first, inputParents[0].second, outputChildren[i].second); for (std::size_t i = 0; i < oldOOut.size(); ++i) {
for (const auto child : outputChildren[i]) {
inputParents[0].first -> addChild(child.first, inputParents[0].second, child.second);
}
} }
} }
} }
...@@ -1011,8 +1010,8 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldGraph, const ...@@ -1011,8 +1010,8 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldGraph, const
} }
} }
for (std::size_t o = 0; o < oldOOut.size(); ++o) { for (std::size_t o = 0; o < oldOOut.size(); ++o) {
if (outputChildren[o].first) { for (const auto child : outputChildren[o]) {
newOOut[o].first -> addChild(outputChildren[o].first, newOOut[o].second, outputChildren[o].second); newOOut[o].first -> addChild(child.first, newOOut[o].second, child.second);
} }
} }
} }
...@@ -1061,6 +1060,12 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldGraph, const ...@@ -1061,6 +1060,12 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldGraph, const
return true; return true;
} }
void Aidge::GraphView::updateInputsOutputs() {
for (auto node : mNodes) {
updateInputsOutputsNew(node);
}
}
void Aidge::GraphView::updateInputsOutputsNew(std::shared_ptr<Node> newNode) { void Aidge::GraphView::updateInputsOutputsNew(std::shared_ptr<Node> newNode) {
// Can be called several times with the same node, e.g. when addChild() is // Can be called several times with the same node, e.g. when addChild() is
// called on a node already part of the GraphView. In this case, inputs/outputs // called on a node already part of the GraphView. In this case, inputs/outputs
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment