diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index 34a6e8c5526804ced3b6ff0f0340a219998d87d4..8e83dff098d3cc355f1d15cc6f0faed5ce563f7a 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -12,12 +12,14 @@ #include <algorithm> #include <cassert> #include <iterator> +#include <memory> #include <utility> #include <numeric> #include <fmt/format.h> #include <fmt/ranges.h> +#include "aidge/graph/Connector.hpp" #include "aidge/utils/Types.h" #include "aidge/graph/GraphView.hpp" #include "aidge/data/Tensor.hpp" @@ -827,36 +829,45 @@ bool Aidge::GraphView::replace(const std::set<Aidge::NodePtr>& oldNodes, const s return GraphView::replace(oldG, newG); } -bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldG, const std::shared_ptr<GraphView>& newG) { +bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldGraph, const std::shared_ptr<GraphView>& newGraph) { // TODO: handle case where an oldNodes parameter does not come from a Producer but another Node (not included in oldNodes) // How to distinguish it from data input? // TODO: Parameter Tensors could be identified with their dimensions // TODO: Take GraphView as input parameters since new Nodes should be connected whatever. // It also avoids specifying each producer since they are automatically included - const auto& oldNodes = oldG->getNodes(); - const auto& newNodes = newG->getNodes(); - - const auto oldOI = oldG->getOrderedInputs(); - const auto oldOO = oldG->getOrderedOutputs(); - const auto newOI = newG->getOrderedInputs(); - const auto newOO = newG->getOrderedOutputs(); - std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputParents = std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(oldOI.size()); - std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> outputChildren = std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(oldOO.size()); - - // keep in memory every parent - for (std::size_t i = 0; i < oldOI.size(); ++i) { - auto inputParent = oldOI[i].first -> input(oldOI[i].second); + const std::set<NodePtr>& oldNodes = oldGraph->getNodes(); + const std::set<NodePtr>& newNodes = newGraph->getNodes(); + + const std::vector<std::pair<NodePtr, IOIndex_t>> oldOIn = + oldGraph->getOrderedInputs(); + const std::vector<std::pair<NodePtr, IOIndex_t>> oldOOut = + oldGraph->getOrderedOutputs(); + const std::vector<std::pair<NodePtr, IOIndex_t>> newOIn = + newGraph->getOrderedInputs(); + const std::vector<std::pair<NodePtr, IOIndex_t>> newOOut = + newGraph->getOrderedOutputs(); + + auto inputParents = std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(oldOIn.size()); + auto outputChildren = std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(oldOOut.size()); + + // keep in memory every node related to the node to replace : + // Parent + for (std::size_t i = 0; i < oldOIn.size(); ++i) { + std::pair<NodePtr, IOIndex_t> inputParent = + oldOIn[i].first -> input(oldOIn[i].second); inputParents[i]= inputParent; // inputParent.first -> addChild(newOI[i].first, inputParent.second, newOI[i].second); } - for (std::size_t i = 0; i < oldOO.size();) { - auto outputChildList = oldOO[i].first -> output(oldOO[i].second); - if (outputChildList.empty()) { + // Children + for (std::size_t i = 0; i < oldOOut.size();) { + std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> outputChild = + oldOOut[i].first -> output(oldOOut[i].second); + if (outputChild.empty()) { outputChildren[i] = std::pair<std::shared_ptr<Node>, IOIndex_t>({nullptr, gk_IODefaultIndex}); ++i; } else { - for (const auto& child : outputChildList) { + for (const auto& child : outputChild) { if (oldNodes.find(child.first) == oldNodes.cend()) { outputChildren[i] = child; ++i; @@ -869,37 +880,37 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldG, const std // set of common GraphView for oldNodes' Nodes std::set<std::shared_ptr<GraphView>> commonGraphViews = (*oldNodes.begin())->views(); for (const auto& nodePtr : oldNodes) { - const auto nodeView = nodePtr->views(); + const std::set<std::shared_ptr<GraphView>> nodeView = nodePtr->views(); std::set<std::shared_ptr<GraphView>> intersection; std::set_intersection(commonGraphViews.begin(), commonGraphViews.end(), nodeView.begin(), nodeView.end(), std::inserter(intersection, intersection.begin())); commonGraphViews = intersection; } - commonGraphViews.erase(oldG); - commonGraphViews.erase(newG); + commonGraphViews.erase(oldGraph); + commonGraphViews.erase(newGraph); - if ((newNodes.size() > 0) && (oldOI.size() != newOI.size()) && (oldOO.size() != newOO.size())) { + if ((newNodes.size() > 0) && (oldOIn.size() != newOIn.size()) && (oldOOut.size() != newOOut.size())) { for (const auto& nodePtr : oldNodes) { - nodePtr->removeView(oldG); + nodePtr->removeView(oldGraph); } for (const auto& nodePtr : newNodes) { - nodePtr->removeView(newG); + nodePtr->removeView(newGraph); } return false; } - if ((oldOI.size() == newOI.size()) && - (oldOO.size() == newOO.size())) { + if ((oldOIn.size() == newOIn.size()) && + (oldOOut.size() == newOOut.size())) { // Case 1 - for (std::size_t i = 0; i < oldOI.size(); ++i) { + for (std::size_t i = 0; i < oldOIn.size(); ++i) { if (inputParents[i].first) { - inputParents[i].first -> addChild(newOI[i].first, inputParents[i].second, newOI[i].second); + inputParents[i].first -> addChild(newOIn[i].first, inputParents[i].second, newOIn[i].second); } } - for (std::size_t o = 0; o < oldOO.size(); ++o) { + for (std::size_t o = 0; o < oldOOut.size(); ++o) { if (outputChildren[o].first) { - newOO[o].first -> addChild(outputChildren[o].first, newOO[o].second, outputChildren[o].second); + newOOut[o].first -> addChild(outputChildren[o].first, newOOut[o].second, outputChildren[o].second); } } } @@ -908,10 +919,8 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldG, const std // get the number of Children for oldg->outputNodes() if (newNodes.size() == 0) { // Case 3 - if (oldOI.size() == oldOO.size()) { - for (std::size_t i = 0; i < oldOI.size(); ++i) { - if (inputParents[i].first) - inputParents[i].first -> addChild(outputChildren[i].first, inputParents[i].second, outputChildren[i].second); + if (oldOIn.size() == oldOOut.size()) { + for (std::size_t i = 0; i < oldOIn.size(); ++i) { if (inputParents[i].first) { inputParents[i].first -> addChild(outputChildren[i].first, inputParents[i].second, outputChildren[i].second); } else { @@ -921,46 +930,46 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldG, const std } } } - else if ((oldOI.size() == 1) && (inputParents[0].first)) { - for (std::size_t i = 0; i < oldOI.size(); ++i) { + else if ((oldOIn.size() == 1) && (inputParents[0].first)) { + for (std::size_t i = 0; i < oldOIn.size(); ++i) { inputParents[0].first -> addChild(outputChildren[i].first, inputParents[0].second, outputChildren[i].second); } } } else if ( // for tiling-like cases. The number of inputNodes changes but not outputNodes - ((oldOI.size() == 1) || (newOI.size() == 1)) && // (oldOI.size() == newOI.size()) already handled in Case 1 - ((oldOO.size() == newOO.size())) + ((oldOIn.size() == 1) || (newOIn.size() == 1)) && // (oldOI.size() == newOI.size()) already handled in Case 1 + ((oldOOut.size() == newOOut.size())) ) { // Case 2 - if ((oldOI.size() == 1) && (inputParents[0].first)) { - for (std::size_t i = 0; i < newOI.size(); ++i) { - inputParents[0].first -> addChild(newOI[i].first, inputParents[0].second, newOI[i].second); + if ((oldOIn.size() == 1) && (inputParents[0].first)) { + for (std::size_t i = 0; i < newOIn.size(); ++i) { + inputParents[0].first -> addChild(newOIn[i].first, inputParents[0].second, newOIn[i].second); } } else { - for (std::size_t i = 0; i < oldOI.size(); ++i) { + for (std::size_t i = 0; i < oldOIn.size(); ++i) { if (inputParents[i].first) { - inputParents[i].first -> addChild(newOI[0].first, inputParents[i].second, newOI[0].second); + inputParents[i].first -> addChild(newOIn[0].first, inputParents[i].second, newOIn[0].second); } } } - for (std::size_t o = 0; o < oldOO.size(); ++o) { + for (std::size_t o = 0; o < oldOOut.size(); ++o) { if (outputChildren[o].first) { - newOO[o].first -> addChild(outputChildren[o].first, newOO[o].second, outputChildren[o].second); + newOOut[o].first -> addChild(outputChildren[o].first, newOOut[o].second, outputChildren[o].second); } } } else { for (const auto& nodePtr : oldNodes) { - nodePtr->removeView(oldG); + nodePtr->removeView(oldGraph); } for (const auto& nodePtr : newNodes) { - nodePtr->removeView(newG); + nodePtr->removeView(newGraph); } return false; } } - auto oldGOutputs = oldG->outputNodes(); + auto oldGOutputs = oldGraph->outputNodes(); for (const auto& nodePtr : oldNodes) { bool removeFromGraphs = true; if (std::find(oldGOutputs.cbegin(), oldGOutputs.cend(), nodePtr) == oldGOutputs.cend()) { @@ -986,10 +995,10 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldG, const std } } for (const auto& nodePtr : oldNodes) { - nodePtr -> removeView(oldG); + nodePtr -> removeView(oldGraph); } for (const auto& nodePtr : newNodes) { - nodePtr -> removeView(newG); + nodePtr -> removeView(newGraph); } return true; } diff --git a/src/graph/Node.cpp b/src/graph/Node.cpp index 14e166402039230a283ce617e4997c9ad099eed9..374c3895936d657007cae4db386cf2b0187f05aa 100644 --- a/src/graph/Node.cpp +++ b/src/graph/Node.cpp @@ -283,7 +283,7 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::Node::getChildren() const { } std::vector<std::vector<std::shared_ptr<Aidge::Node>>> Aidge::Node::getOrderedChildren() const { - std::vector<std::vector<std::shared_ptr<Node>>> children = + auto children = std::vector<std::vector<std::shared_ptr<Node>>>(mChildren.size()); for (std::size_t outId = 0; outId < mChildren.size(); ++outId) { children[outId] = getChildren(outId);