From da73529994ec233066c200fd41444209391ba41e Mon Sep 17 00:00:00 2001 From: NAUD Maxence <maxence.naud@cea.fr> Date: Mon, 11 Dec 2023 01:34:40 +0000 Subject: [PATCH] [Upd] the replace function to use graph ordering --- include/aidge/graph/GraphView.hpp | 23 ++- src/graph/GraphView.cpp | 292 ++++++++++++++++-------------- 2 files changed, 167 insertions(+), 148 deletions(-) diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp index bf23ef9f0..de25e99d0 100644 --- a/include/aidge/graph/GraphView.hpp +++ b/include/aidge/graph/GraphView.hpp @@ -391,17 +391,24 @@ public: IOIndex_t newParentInputTensorIdx, IOIndex_t newParentOutputTensorIdx); - /** * @brief Replace a set of Nodes in every available GraphView with a new set of Nodes if possible. * Both sets should include all the necessary Producers. - * @details Replaced Nodes are removed from any GraphView pointing at them all. - * The oldNodes set should have only one input/output - * Tensor for automatic connections of newNodes set. - * @param oldNodes actual set of shared_ptr<Node> to replace. - * @param newNodes new set of shared_ptr<Node>. - * @return true - * @return false + * @details There are 3 cases of replacement: + * Case 1: same number of input/output connections for oldNodes and newNodes sets. + * - input/output connections are replacated according to their IDs. + * Case 2: different number of input/output connections for oldNodes and newNodes sets. + * - only a single parent/child node for the newNodes set, every input/output is + * connected to it. + * - several parents/children nodes for newNodes set => impossible to know, return false + * Case 3: newNodes set is empty + * - same number of input/output connections in oldNodes, parents and children are linked according + * to these connections IDs + * - different number of input/output connections in oldNodes => return false + * @param oldNodes + * @param newNodes + * @return true replacement has been performed + * @return false no replacement has been performed */ static bool replace(const std::set<NodePtr>& oldNodes, const std::set<NodePtr>& newNodes); diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index 2716de300..064cb3a04 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -74,7 +74,7 @@ void Aidge::GraphView::save(std::string path, bool verbose) const { std::string givenName = (node_ptr->name().empty()) ? "<em>" + currentType + "#" + std::to_string(typeCounter[currentType]) + "</em>" - : node_ptr->name() + " <sub><em>" + currentType + "</em></sub>"; + : "\"" + node_ptr->name() + "\\n<sub><em>( " + currentType + " )</em></sub>\""; namePtrTable[node_ptr] = (currentType + "_" + std::to_string(typeCounter[currentType])); @@ -117,14 +117,14 @@ void Aidge::GraphView::save(std::string path, bool verbose) const { size_t inputIdx = 0; for (auto input : mInputNodes) { - std::fprintf(fp, "input%lu((in#%lu)):::inputCls-->|→%u|%s\n", inputIdx, inputIdx, + std::fprintf(fp, "input%lu((in#%lu)):::inputCls--->|→%u|%s\n", inputIdx, inputIdx, input.second, namePtrTable[input.first].c_str()); ++inputIdx; } size_t outputIdx = 0; for (auto output : mOutputNodes) { - std::fprintf(fp, "%s-->|%u→|output%lu((out#%lu)):::outputCls\n", + std::fprintf(fp, "%s--->|%u→|output%lu((out#%lu)):::outputCls\n", namePtrTable[output.first].c_str(), output.second, outputIdx, outputIdx); ++outputIdx; @@ -694,128 +694,150 @@ void Aidge::GraphView::insertParent(NodePtr childNode, add(newParentNode); } - bool Aidge::GraphView::replace(const std::set<Aidge::NodePtr>& oldNodes, const std::set<Aidge::NodePtr>& newNodes) { - // TODO: handle case where an oldNodes parameter does not come from a Producer but another Node (not included in oldNodes) // How to distinguish it from data input? // TODO: Parameter Tensors could be identified with their dimensions // TODO: Take GraphView as input parameters since new Nodes should be connected whatever. // It also avoids specifying each producer since they are automatically included + // (1) create GraphViews from both sets of Nodes auto oldG = std::make_shared<GraphView>("oldG"); oldG->add(oldNodes, false); auto newG = std::make_shared<GraphView>("newG"); newG->add(newNodes, false); - if ((oldG->inputNodes().size() == 0) || (oldG->outputNodes().size() != 1)) { - return false; - } - if (!(newNodes.empty()) && ((newG->inputNodes().size() == 0) || - (newG->outputNodes().size() != 1))) { - return false; + const auto oldOI = oldG->getOrderedInputs(); + const auto oldOO = oldG->getOrderedOutputs(); + const auto newOI = newG->getOrderedInputs(); + const auto newOO = newG->getOrderedOutputs(); + std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputParents = std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(oldOI.size()); + std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> outputChildren = std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(oldOO.size()); + + // keep in memory every parent + for (std::size_t i = 0; i < oldOI.size(); ++i) { + auto inputParent = oldOI[i].first -> input(oldOI[i].second); + inputParents[i]= inputParent; + // inputParent.first -> addChild(newOI[i].first, inputParent.second, newOI[i].second); } - - // there is at least one inputNode in the old/new GraphView - std::shared_ptr<Node> firstPreviousInputNode = (*(oldG->inputNodes()).begin()); - std::shared_ptr<Node> firstPreviousOutputNode = (*(oldG->outputNodes()).begin()); - - // find Node to link to new input Node - //compute number of input for firstPreviousInputNode not in oldNodes set - std::size_t nbExternalInputs = 0; - std::shared_ptr<Node> externalInput = nullptr; - IOIndex_t externalInputId = gk_IODefaultIndex; - for (const auto& input : firstPreviousInputNode->inputs()) { - if (oldNodes.find(input.first) == oldNodes.end()) { // Node connected to another Node outside of oldG - nbExternalInputs++; - externalInput = input.first; - externalInputId = input.second; + for (std::size_t i = 0; i < oldOO.size();) { + auto outputChildList = oldOO[i].first -> output(oldOO[i].second); + if (outputChildList.empty()) { + outputChildren[i] = std::pair<std::shared_ptr<Node>, IOIndex_t>({nullptr, gk_IODefaultIndex}); + ++i; } - } - if (nbExternalInputs > 1) { - AIDGE_INTERNAL_ASSERT("To many input to link for oldNodes set"); - } - - if (oldG->inputNodes().size() > 1){ - // one or no input has been identified. Checking every input points to the same source - for (const auto& previousInputNode : oldG->inputNodes()) { - for (const auto& input : previousInputNode->inputs()) { - if (oldNodes.find(input.first) == oldNodes.end()) { - if ( (externalInput != input.first) || (externalInputId != input.second) ) { - return false; // an inputNode points to an external Node different from the registered one - } + else { + for (const auto& child : outputChildList) { + if (oldNodes.find(child.first) == oldNodes.cend()) { + outputChildren[i] = child; + ++i; } } } } - if (firstPreviousOutputNode->nbOutputs() != 1) { - return false; - } - - // find Node to replicate output connections - std::shared_ptr<Node> newOutputNode = newNodes.empty() ? externalInput : *(newG->outputNodes().begin()); - - auto copyOutputs = firstPreviousOutputNode->outputs(); - // manage Views for newNodes // only keep common views to each node for the new set + // set of common GraphView for oldNodes' Nodes std::set<std::shared_ptr<GraphView>> commonGraphViews = (*oldNodes.begin())->views(); for (const auto& nodePtr : oldNodes) { - const auto nodeView = nodePtr->views(); - std::set<std::shared_ptr<GraphView>> intersection; - std::set_intersection(commonGraphViews.begin(), commonGraphViews.end(), - nodeView.begin(), nodeView.end(), - std::inserter(intersection, intersection.begin())); - commonGraphViews = intersection; + const auto nodeView = nodePtr->views(); + std::set<std::shared_ptr<GraphView>> intersection; + std::set_intersection(commonGraphViews.begin(), commonGraphViews.end(), + nodeView.begin(), nodeView.end(), + std::inserter(intersection, intersection.begin())); + commonGraphViews = intersection; } commonGraphViews.erase(oldG); commonGraphViews.erase(newG); - // clean Nodes to replace - // Do not include common Nodes to avoid cleaning Producers linked to newNodes - std::set<std::shared_ptr<Node>> nodesToClean; - std::set_difference(oldNodes.begin(), oldNodes.end(), - newNodes.begin(), newNodes.end(), - std::inserter(nodesToClean, nodesToClean.begin())); - for (auto& nodePtr : nodesToClean) { nodePtr->resetConnections(true); } - - // copy output connections - if (newOutputNode) { - for (IOIndex_t o = 0; o < firstPreviousOutputNode->nbOutputs(); ++o) { - auto outputPairs = copyOutputs[o]; - for (const auto& onePair : outputPairs) { - newOutputNode->addChild(onePair.first, o, onePair.second); - } + if ((newNodes.size() > 0) && (oldOI.size() != newOI.size()) && (oldOO.size() != newOO.size())) { + for (const auto& nodePtr : oldNodes) { + nodePtr->removeView(oldG); + } + for (const auto& nodePtr : newNodes) { + nodePtr->removeView(newG); + } + return false; + } + + for (const auto& nodePtr : oldNodes) { + for (const auto& g : commonGraphViews) { + g -> remove(nodePtr, false); + g -> updateInputsOutputsDelete(nodePtr); } + nodePtr -> resetConnections(true); } - // copy input connections - if (!newNodes.empty() && externalInput) { - for (const auto& newInputNode : newG->inputNodes()) { - IOIndex_t inputId = 0; - for (const auto& input : newInputNode->inputs()) { - if (newNodes.find(input.first) == newNodes.end()) { - externalInput->addChild(newInputNode, externalInputId, inputId); + if ((oldOI.size() == newOI.size()) && + (oldOO.size() == newOO.size())) { + // Case 1 + for (std::size_t i = 0; i < oldOI.size(); ++i) { + if (inputParents[i].first) { + inputParents[i].first -> addChild(newOI[i].first, inputParents[i].second, newOI[i].second); + } + } + for (std::size_t o = 0; o < oldOO.size(); ++o) { + if (outputChildren[o].first) { + newOO[o].first -> addChild(outputChildren[o].first, newOO[o].second, outputChildren[o].second); + } + } + } + else { + // get the number of Parents for oldG->inputNodes() + // get the number of Children for oldg->outputNodes() + if (newNodes.size() == 0) { + // Case 3 + if (oldOI.size() == oldOO.size()) { + for (std::size_t i = 0; i < oldOI.size(); ++i) { + inputParents[i].first -> addChild(outputChildren[i].first, inputParents[i].second, outputChildren[i].second); + } + } + else if (oldOI.size() == 1) { + for (std::size_t i = 0; i < oldOI.size(); ++i) { + inputParents[0].first -> addChild(outputChildren[i].first, inputParents[0].second, outputChildren[i].second); + } + } + } + else if ( // for tiling-like cases. The number of inputNodes changes but not outputNodes + ((oldOI.size() == 1) || (newOI.size() == 1)) && // (oldOI.size() == newOI.size()) already handled in Case 1 + ((oldOO.size() == newOO.size())) + ) { + // Case 2 + if ((oldOI.size() == 1)) { + for (std::size_t i = 0; i < newOI.size(); ++i) { + inputParents[0].first -> addChild(newOI[i].first, inputParents[0].second, newOI[i].second); + } + } else { + for (std::size_t i = 0; i < oldOI.size(); ++i) { + inputParents[i].first -> addChild(newOI[0].first, inputParents[i].second, newOI[0].second); + } + } + for (std::size_t o = 0; o < oldOO.size(); ++o) { + if (outputChildren[o].first) { + newOO[o].first -> addChild(outputChildren[o].first, newOO[o].second, outputChildren[o].second); } - inputId++; } } + else { + for (const auto& nodePtr : oldNodes) { + nodePtr->removeView(oldG); + } + for (const auto& nodePtr : newNodes) { + nodePtr->removeView(newG); + } + return false; + } } - - // insert new Nodes in the right GraphViews - for (const auto& graphPtr : commonGraphViews) { - graphPtr->add(newNodes, false); - if (newNodes.empty()) { - // TODO: FIXME: this function should not be called anymore! - graphPtr->updateInputsOutputsNodes_DEPRECATED(); + for (const auto& nodePtr : newNodes) { + for (const auto& g : commonGraphViews) { + g -> add(nodePtr); } } - - for (const auto& node : oldNodes) { - node->removeView(oldG); + for (const auto& nodePtr : oldNodes) { + nodePtr -> removeView(oldG); } - for (const auto& node : newNodes) { - node->removeView(newG); + for (const auto& nodePtr : newNodes) { + nodePtr -> removeView(newG); } return true; } @@ -824,22 +846,22 @@ void Aidge::GraphView::updateInputsOutputsNew(std::shared_ptr<Node> newNode) { // Can be called several times with the same node, e.g. when addChild() is // called on a node already part of the GraphView. In this case, inputs/outputs // need to be updated! - std::vector<std::pair<NodePtr, IOIndex_t>>::iterator newInputsInsertionPoint = mInputNodes.end(); + std::vector<std::pair<NodePtr, IOIndex_t>>::const_iterator newInputsInsertionPoint = mInputNodes.cend(); // Remove inputs that are not input anymore because connected to newNode for (auto orderedChilds : newNode->getOrderedChildren()) { for (auto ch_ptr : orderedChilds) { // Check that newNode child is in current GraphView - if (mNodes.find(ch_ptr) != mNodes.end()) { + if (mNodes.find(ch_ptr) != mNodes.cend()) { IOIndex_t inputIdx = 0; for (const std::shared_ptr<Node>& pa_ptr : ch_ptr->getParents()) { // If newNode is connected to it if (pa_ptr == newNode) { const auto val = std::make_pair(ch_ptr, inputIdx); - const auto iter = std::find(mInputNodes.begin(), mInputNodes.end(), val); + const auto iter = std::find(mInputNodes.cbegin(), mInputNodes.cend(), val); // Check that it was not already the case (if node UPDATE) - if (iter != mInputNodes.end()) { + if (iter != mInputNodes.cend()) { // newNode is linked to an actual inputNode to an input connection // The first old (removed) input becomes the insertion point for newNode GraphView inputs if (std::distance(newInputsInsertionPoint, iter) <= 0) { newInputsInsertionPoint = mInputNodes.erase(iter); @@ -855,55 +877,45 @@ void Aidge::GraphView::updateInputsOutputsNew(std::shared_ptr<Node> newNode) { } } - // Check if node inputs are inputs for the GraphView and add them to the input list if so - // Inputs addition order follows node inputs order - // Inputs are inserted at the position of the first input removed - IOIndex_t inputIdx = 0U; - for (const std::shared_ptr<Node>& pa_ptr : newNode->getParents()) { - if ((pa_ptr == nullptr) || - (mNodes.find(pa_ptr) == - mNodes.end())) { // Parent doesn't exist || Parent not in the graph - const auto val = std::make_pair(newNode, inputIdx); - if (std::find(mInputNodes.begin(), mInputNodes.end(), val) == mInputNodes.end()) { - newInputsInsertionPoint = mInputNodes.insert(newInputsInsertionPoint, val); - newInputsInsertionPoint = std::next(newInputsInsertionPoint); - } - } - ++inputIdx; - } - - // (if node UPDATE) - // newNode may already exists in the graph and may have been updated - // Check and remove inputs that are not inputs anymore - inputIdx = 0U; - for (const std::shared_ptr<Node>& pa_ptr : newNode->getParents()) { - if ((pa_ptr != nullptr) && - (mNodes.find(pa_ptr) != - mNodes.end())) { - const auto val = std::make_pair(newNode, inputIdx); - auto it = std::find(mInputNodes.begin(), mInputNodes.end(), val); - if (it != mInputNodes.end()) { - mInputNodes.erase(it); - } + // Manage newNode parents + // Check if any input connection is an input for the GraphView + IOIndex_t inputIdx = 0U; + for (const std::shared_ptr<Node>& pa_ptr : newNode->getParents()) { + const auto val = std::make_pair(newNode, inputIdx); + const auto it = std::find(mInputNodes.cbegin(), mInputNodes.cend(), val); + if ((pa_ptr == nullptr) || + (mNodes.find(pa_ptr) == mNodes.cend())) { + // Parent doesn't exist || Parent not in the graph + if (it == mInputNodes.cend()) { + // If node's inputs are inputs for the GraphView: add them to the input list + // Addition rule: + // - Inputs addition order follows node inputs order + // - Inputs are inserted at the position of the first input removed + newInputsInsertionPoint = mInputNodes.insert(newInputsInsertionPoint, val); + newInputsInsertionPoint = std::next(newInputsInsertionPoint); + } + } else if (it != mInputNodes.cend()) { + // Parent already in the graph SO edge is not an input anymore for the graph + mInputNodes.erase(it); + } + ++inputIdx; } - ++inputIdx; - } - std::vector<std::pair<NodePtr, IOIndex_t>>::iterator newOutputsInsertionPoint = mOutputNodes.end(); + std::vector<std::pair<NodePtr, IOIndex_t>>::const_iterator newOutputsInsertionPoint = mOutputNodes.cend(); // Remove outputs that are not output anymore because connected to newNode for (const std::shared_ptr<Node>& parent : newNode->getParents()) { // Check that newNode parent is in current GraphView - if (mNodes.find(parent) != mNodes.end()) { + if (mNodes.find(parent) != mNodes.cend()) { IOIndex_t outputIdx = 0; for (auto orderedChilds : parent->getOrderedChildren()) { for (auto ch_ptr : orderedChilds) { // If newNode is connected to it if (ch_ptr == newNode) { const auto val = std::make_pair(parent, outputIdx); - const auto iter = std::find(mOutputNodes.begin(), mOutputNodes.end(), val); + const auto iter = std::find(mOutputNodes.cbegin(), mOutputNodes.cend(), val); - if (iter != mOutputNodes.end()) { + if (iter != mOutputNodes.cend()) { // The first old (removed) output becomes the insertion point for newNode GraphView outputs if (std::distance(newOutputsInsertionPoint, iter) <= 0) { newOutputsInsertionPoint = mOutputNodes.erase(iter); @@ -943,14 +955,14 @@ void Aidge::GraphView::updateInputsOutputsNew(std::shared_ptr<Node> newNode) { } void Aidge::GraphView::updateInputsOutputsDelete(std::shared_ptr<Node> deletedNode) { - std::vector<std::pair<NodePtr, IOIndex_t>>::iterator newInputsInsertionPoint = mInputNodes.end(); + std::vector<std::pair<NodePtr, IOIndex_t>>::const_iterator newInputsInsertionPoint = mInputNodes.cend(); // Check if node inputs were inputs for the GraphView and remove them from the list if so for (IOIndex_t inputIdx = 0; inputIdx < deletedNode->getParents().size(); ++inputIdx) { const auto val = std::make_pair(deletedNode, inputIdx); - const auto iter = std::find(mInputNodes.begin(), mInputNodes.end(), val); + const auto iter = std::find(mInputNodes.cbegin(), mInputNodes.cend(), val); - if (iter != mInputNodes.end()) { + if (iter != mInputNodes.cend()) { // The first old (removed) input becomes the insertion point for new GraphView inputs if (std::distance(newInputsInsertionPoint, iter) <= 0) { newInputsInsertionPoint = mInputNodes.erase(iter); @@ -966,13 +978,13 @@ void Aidge::GraphView::updateInputsOutputsDelete(std::shared_ptr<Node> deletedNo for (auto orderedChilds : deletedNode->getOrderedChildren()) { for (auto ch_ptr : orderedChilds) { // Check that deletedNode child is in current GraphView - if (mNodes.find(ch_ptr) != mNodes.end()) { + if (mNodes.find(ch_ptr) != mNodes.cend()) { IOIndex_t inputIdx = 0; for (const std::shared_ptr<Node>& pa_ptr : ch_ptr->getParents()) { // If newNode was connected to it if (pa_ptr == deletedNode) { const auto val = std::make_pair(ch_ptr, inputIdx); - if (std::find(mInputNodes.begin(), mInputNodes.end(), val) == mInputNodes.end()) { + if (std::find(mInputNodes.cbegin(), mInputNodes.cend(), val) == mInputNodes.cend()) { newInputsInsertionPoint = mInputNodes.insert(newInputsInsertionPoint, val); newInputsInsertionPoint = std::next(newInputsInsertionPoint); } @@ -982,15 +994,15 @@ void Aidge::GraphView::updateInputsOutputsDelete(std::shared_ptr<Node> deletedNo } } } - - std::vector<std::pair<NodePtr, IOIndex_t>>::iterator newOutputsInsertionPoint = mOutputNodes.end(); + + std::vector<std::pair<NodePtr, IOIndex_t>>::const_iterator newOutputsInsertionPoint = mOutputNodes.cend(); // Check if node outputs were outputs for the GraphView and remove them from the list if so for (IOIndex_t outputIdx = 0; outputIdx < deletedNode->getOrderedChildren().size(); ++outputIdx) { const auto val = std::make_pair(deletedNode, outputIdx); - const auto iter = std::find(mOutputNodes.begin(), mOutputNodes.end(), val); + const auto iter = std::find(mOutputNodes.cbegin(), mOutputNodes.cend(), val); - if (iter != mOutputNodes.end()) { + if (iter != mOutputNodes.cend()) { // The first old (removed) output becomes the insertion point for newNode GraphView outputs if (std::distance(newOutputsInsertionPoint, iter) <= 0) { newOutputsInsertionPoint = mOutputNodes.erase(iter); @@ -1004,7 +1016,7 @@ void Aidge::GraphView::updateInputsOutputsDelete(std::shared_ptr<Node> deletedNo // Add parent node outputs that become GraphView output following the removal of the node // Outputs addition order follows deletedNode inputs order for (const std::shared_ptr<Node>& parent : deletedNode->getParents()) { - if (parent != nullptr && mNodes.find(parent) != mNodes.end()) { + if (mNodes.find(parent) != mNodes.end()) { IOIndex_t outputIdx = 0; for (auto orderedChilds : parent->getOrderedChildren()) { bool noInsideConnection = true; @@ -1017,7 +1029,7 @@ void Aidge::GraphView::updateInputsOutputsDelete(std::shared_ptr<Node> deletedNo if (noInsideConnection) { const auto val = std::make_pair(parent, outputIdx); - if (std::find(mOutputNodes.begin(), mOutputNodes.end(), val) == mOutputNodes.end()) { + if (std::find(mOutputNodes.cbegin(), mOutputNodes.cend(), val) == mOutputNodes.cend()) { newOutputsInsertionPoint = mOutputNodes.insert(newOutputsInsertionPoint, val); newOutputsInsertionPoint = std::next(newOutputsInsertionPoint); } -- GitLab