diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp index 6fc571c66c398c85accb4fd83bbb05f01d64cc53..4ed9cd106615be6d664e06dbf933e1d4342f95c4 100644 --- a/include/aidge/graph/GraphView.hpp +++ b/include/aidge/graph/GraphView.hpp @@ -15,7 +15,7 @@ #include <map> #include <memory> -#include <set> +#include <unordered_set> #include <string> #include <utility> #include <vector> @@ -41,11 +41,11 @@ private: /// @brief Set of nodes included in the graphview with names std::map<std::string, NodePtr> mNodeRegistry; - /// @brief Nodes without input link (computable, cached) - std::set<NodePtr> mInputNodes; + /// @brief GraphView inputs + std::vector<std::pair<NodePtr, IOIndex_t>> mInputNodes; - /// @brief Nodes without output link (computable, cached) - std::set<NodePtr> mOutputNodes; + /// @brief GraphView outputs + std::vector<std::pair<NodePtr, IOIndex_t>> mOutputNodes; public: GraphView(std::string name="") @@ -54,11 +54,21 @@ public: // ctor } - // GraphView(std::set<NodePtr> nodes, std::string name="") - // : mName(name) - // { - // add(nodes); - // } + /** + * Construct a GraphView from a set of nodes. The startNode parameters + * allows to define a default inputs/ouputs order relative to this node. + * For two topologically identical graphs, using the same topological node + * as starting node will lead to the same topologically ordered inputs/outputs list. + * Otherwise, inputs/outputs order will be arbitrary. + */ + GraphView(std::set<NodePtr> nodes, NodePtr startNode = nullptr, std::string name="") + : mName(name) + { + if (startNode != nullptr) { + add(startNode, false); + } + add(nodes); + } bool operator==(const GraphView &gv) const { @@ -110,19 +120,35 @@ public: /////////////////////////////////////////////////////// public: /** @brief Get reference to the set of input Nodes. */ - inline const std::set<NodePtr>& inputNodes() const noexcept { return mInputNodes; } + inline const std::set<NodePtr>& inputNodes() const noexcept { + std::set<NodePtr> nodes; + for (auto node : mInputNodes) { + nodes.insert(node.first); + } + return nodes; + } /** @brief Get reference to the set of output Nodes. */ - inline const std::set<NodePtr>& outputNodes() const noexcept { return mOutputNodes; } - + inline const std::set<NodePtr>& outputNodes() const noexcept { + std::set<NodePtr> nodes; + for (auto node : mOutputNodes) { + nodes.insert(node.first); + } + return nodes; + } /** @brief Assess if the given Node is an input Node of the GraphView object. */ inline bool isInputNode(NodePtr nodePtr) const { - return (mInputNodes.find(nodePtr) != mInputNodes.end()) ? true : false; + const auto nodes = inputNodes(); + return (nodes.find(nodePtr) != nodes.end()) ? true : false; } /** @brief Assess if the given Node is an output Node of the GraphView object. */ inline bool isOutputNode(NodePtr nodePtr) const { - return (mOutputNodes.find(nodePtr) != mOutputNodes.end()) ? true : false; + const auto nodes = outputNodes(); + return (nodes.find(nodePtr) != nodes.end()) ? true : false; } + void setOrderedInputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& inputs); + void setOrderedOutputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& outputs); + /** * @brief List outside data input connections of the GraphView. * Data inputs exclude inputs expecting parameters (weights or bias). @@ -357,11 +383,10 @@ public: */ static bool replace(const std::set<NodePtr>& oldNodes, const std::set<NodePtr>& newNodes); - void updateInputNodes(); /** - * @brief Process from zero the set of output Nodes. - */ - void updateOutputNodes(); + * @brief Replace a set of Nodes in every available GraphView with a new set of Nodes with ordered inputs/outputs in a GraphView if possible. + */ + static bool replace(const std::shared_ptr<GraphView>& oldNodes, const std::shared_ptr<GraphView>& newNodes); /** * @brief Clone the GraphView with shared Operators. It is a new GraphView, with cloned Nodes, but the new Nodes refer to the same Operators as the original ones. @@ -415,27 +440,33 @@ private: IOIndex_t getNbDataInputs() const; /** - * @brief Update the set of inputNodes with a new Node, checking if it can be - * added and removing any Node not part of mInputNode anymore. + * @brief Update inputs/outputs of the GraphView, with no particular order. + * This function DOES NOT preserve inputs/outputs order and should NOT BE USED. + * It is here only to leave time to adapt the replace() function. + */ + [[deprecated]] void updateInputsOutputsNodes(); + + /** + * @brief Automatically update GraphView inputs/outputs with a new Node, checking if + * it this Node becomes an input/output for the graph and if previous inputs are still + * inputs/outputs after adding this node. * @param nodePtr */ - void updateInputNodes(NodePtr node); + void updateInputsOutputsNew(NodePtr newNode); /** - * @brief Update the set of outputNodes with a new Node, checking if it can be - * added and removing any Node not part of mOutputNode anymore. + * @brief Automatically update GraphView inputs/outputs with a Node removed, checking if + * it this Node was an input/output for the graph and if this node childs become new inputs/outputs + * for the graph. * @param nodePtr */ - void updateOutputNodes(NodePtr node); + void updateInputsOutputsDelete(NodePtr deletedNode); /////////////////////////////////////////////////////// // TOPOLOGY /////////////////////////////////////////////////////// void _forwardDims(std::set<NodePtr> listNodes); - - void removeInputNode(const std::string nodeName); - void removeOutputNode(const std::string nodeName); }; } // namespace Aidge diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index e8e10798b583ba87685d3c47a6143bb0d195f8a2..3854e5b05aded73c92f5acc99d1e30ef6306fd23 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -106,6 +106,34 @@ void Aidge::GraphView::save(std::string path, bool verbose) const { // TENSOR MANAGEMENT /////////////////////////////////////////////////////// +void Aidge::GraphView::setOrderedInputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& inputs) { + AIDGE_ASSERT(inputs.size() > mInputNodes.size(), "too many specified number of inputs"); + + std::vector<std::pair<NodePtr, IOIndex_t>> ignoredInputs(mInputNodes); + for (auto input : inputs) { + auto it = std::find(ignoredInputs.begin(), ignoredInputs.end(), input); + AIDGE_ASSERT(it != ignoredInputs.end(), "unknown or duplicate input"); + ignoredInputs.erase(it); + } + + mInputNodes = inputs; + mInputNodes.insert(mInputNodes.end(), ignoredInputs.begin(), ignoredInputs.end()); +} + +void Aidge::GraphView::setOrderedOutputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& outputs) { + AIDGE_ASSERT(outputs.size() > mOutputNodes.size(), "too many specified number of outputs"); + + std::vector<std::pair<NodePtr, IOIndex_t>> ignoredOutputs(mOutputNodes); + for (auto output : outputs) { + auto it = std::find(ignoredOutputs.begin(), ignoredOutputs.end(), output); + AIDGE_ASSERT(it != ignoredOutputs.end(), "unknown or duplicate output"); + ignoredOutputs.erase(it); + } + + mOutputNodes = outputs; + mOutputNodes.insert(mOutputNodes.end(), ignoredOutputs.begin(), ignoredOutputs.end()); +} + Aidge::IOIndex_t Aidge::GraphView::getNbDataInputs() const { IOIndex_t nbDataInput = 0; for (const std::shared_ptr<Node> &inNode : inputNodes()) { @@ -128,7 +156,7 @@ Aidge::IOIndex_t Aidge::GraphView::getNbFreeDataInputs() const { IOIndex_t nbIn = 0; // Free inputs within the GraphView are logically also free inputs from outside // the GraphView. - for (const std::shared_ptr<Node>& inputNode : mInputNodes) { + for (const std::shared_ptr<Node>& inputNode : inputNodes()) { nbIn += inputNode->getNbFreeDataInputs(); } return nbIn; @@ -139,7 +167,7 @@ std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::GraphView::dataInputs() const { std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res; - for (const std::shared_ptr<Node>& inputNode : mInputNodes) { + for (const std::shared_ptr<Node>& inputNode : inputNodes()) { const std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs = inputNode->dataInputs(); @@ -157,7 +185,7 @@ std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::GraphView::inputs() const { std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res; - for (const std::shared_ptr<Node>& inputNode : mInputNodes) { + for (const std::shared_ptr<Node>& inputNode : inputNodes()) { const std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs = inputNode->inputs(); @@ -242,7 +270,7 @@ void Aidge::GraphView::setDatatype(const DataType &datatype) { node->getOperator()->setDatatype(datatype); } } - +/* void Aidge::GraphView::updateOutputNodes() { mOutputNodes.clear(); for (const std::shared_ptr<Node>& go_it : mNodes) { @@ -292,13 +320,13 @@ void Aidge::GraphView::updateOutputNodes(std::shared_ptr<Node> node) { } } } - +*/ std::vector< std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>> Aidge::GraphView::outputs() const { std::vector<std::vector<std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t>>> outsideOutputs; - for (const std::shared_ptr<Node>& outputNode : mOutputNodes) { + for (const std::shared_ptr<Node>& outputNode : outputNodes()) { const std::vector<std::vector<std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t>>> outputNodeOutputs = outputNode->outputs(); @@ -334,6 +362,10 @@ void Aidge::GraphView::add(std::shared_ptr<Node> node, bool includeLearnablePara mNodes.insert(node); if (!(node->name()).empty()) mNodeRegistry.insert(std::make_pair(node->name(), node)); + + // check if the node is an input/output node + updateInputsOutputsNew(node); + // add learnable parameters to the graph if (includeLearnableParam) { for (IOIndex_t i = node->nbDataInputs(); i < node->nbInputs(); ++i) { @@ -343,33 +375,74 @@ void Aidge::GraphView::add(std::shared_ptr<Node> node, bool includeLearnablePara mNodes.insert(parentNode); if (!(parentNode->name()).empty()) mNodeRegistry.insert(std::make_pair(parentNode->name(), parentNode)); - // check if the Node is an input node - updateInputNodes(parentNode); + // check if the parentNode is an input/output node + updateInputsOutputsNew(parentNode); } } } - // check if the Node is an input node - updateInputNodes(node); - // check if the Node is an input node - updateOutputNodes(node); } void Aidge::GraphView::add(std::set<std::shared_ptr<Node>> otherNodes, bool includeLearnableParam) { - for (auto& nodePtr : otherNodes) { add(nodePtr, includeLearnableParam); } + // List only the nodes that are not already present in current graph + std::set<NodePtr> nodesToAdd; + std::set_difference(otherNodes.begin(), otherNodes.end(), mNodes.begin(), mNodes.end(), std::back_inserter(nodesToAdd)); + + do { + std::set<NodePtr> nextNodesToAdd; + + // Find nodes that are direct parent of current GraphView and add them first + // such that the obtained GraphView inputs list will be the same, regardless + // of the evaluation order of those nodes + // (i.e. one of their child is in current GraphView) + for (const std::shared_ptr<Node> &node_ptr : nodesToAdd) { + for (auto child : node_ptr->getChildren()) { + if (mNodes.find(child) != mNodes.end()) { + nextNodesToAdd.insert(node_ptr); + nodesToAdd.erase(node_ptr); + break; + } + } + } + + // If there is no more parent, find nodes that are direct children of current GraphView, + // such that the obtained GraphView outputs list will be the same, regardless + // of the evaluation order of those nodes + // (i.e. one of their parent is in current GraphView) + // TODO: this might be done simultaneously with direct parents, by removing + // the empty() condition, but there might be edge cases that may change + // the resulting inputs/outputs order depending on evaluation order (???) + if (nextNodesToAdd.empty()) { + for (const std::shared_ptr<Node> &node_ptr : nodesToAdd) { + for (auto parent : node_ptr->getParents()) { + if (mNodes.find(parent) != mNodes.end()) { + nextNodesToAdd.insert(node_ptr); + nodesToAdd.erase(node_ptr); + break; + } + } + } + } + + // If no node if found, there might be remaining nodes that form an independant sub-graph + // In this case, additionnal inputs/outputs will be added at the end of + // the GraphView inputs/outputs list, in no particular order. + // TODO: we might try to preserve the initial inputs/ouputs relative order of those nodes + // if they actually comes from a GraphView, but I think that would be a far-fetched expectation + // from the users... + if (nextNodesToAdd.empty()) { + nodesToAdd.swap(nextNodesToAdd); + } + + // Add selected nodes in the current GraphView, in no particular order + for (auto node_ptr : nextNodesToAdd) { + add(node_ptr, includeLearnableParam); + } + } + while (!nodesToAdd.empty()); } void Aidge::GraphView::add(std::shared_ptr<GraphView> graph) { - for (const std::shared_ptr<Node> &node_ptr : graph->getNodes()) { - node_ptr->addView(shared_from_this()); - mNodes.insert(node_ptr); - if (!(node_ptr->name()).empty()) - mNodeRegistry.insert(std::make_pair(node_ptr->name(), node_ptr)); - // if node_ptr is part of graph inputNodes or outputNodes - // if (graph->isInputNode(node_ptr) || graph->isOutputNode(node_ptr)) { - // Update OutputNodes/inputNodes - updateInputNodes(); - updateOutputNodes(); - } + add(graph->getNodes(), false); } void Aidge::GraphView::addChild(std::shared_ptr<Node> toOtherNode, @@ -417,7 +490,7 @@ void Aidge::GraphView::addChild( std::set<std::shared_ptr<Aidge::Node>> Aidge::GraphView::getParents() const { // TODO: choose if we return a set or a vector std::set<std::shared_ptr<Node>> parents; - for (const std::shared_ptr<Node>& inputNode : mInputNodes) { + for (const std::shared_ptr<Node>& inputNode : inputNodes()) { parents.insert(inputNode->getParents().begin(), inputNode->getParents().end()); } @@ -436,7 +509,7 @@ std::vector<std::shared_ptr<Aidge::Node>> Aidge::GraphView::getParents(const std std::vector<std::vector<std::shared_ptr<Aidge::Node>>> Aidge::GraphView::getOrderedParents() const { std::vector<std::vector<std::shared_ptr<Node>>> parents; - for (const std::shared_ptr<Node>& inputNode : mInputNodes) { + for (const std::shared_ptr<Node>& inputNode : inputNodes()) { parents.push_back(inputNode->getParents()); } return parents; @@ -444,7 +517,7 @@ Aidge::GraphView::getOrderedParents() const { std::set<std::shared_ptr<Aidge::Node>> Aidge::GraphView::getChildren() const { std::set<std::shared_ptr<Node>> children; - for (const std::shared_ptr<Node>& outputNode : mOutputNodes) { + for (const std::shared_ptr<Node>& outputNode : outputNodes()) { children.insert((outputNode->getChildren()).begin(), (outputNode->getChildren()).end()); } @@ -488,13 +561,7 @@ Aidge::GraphView::getNode(const std::string& nodeName) const { void Aidge::GraphView::remove(std::shared_ptr<Node> nodePtr, bool includeLearnableParam) { - if (mNodes.find(nodePtr) != mNodes.end()) { - mNodes.erase(nodePtr); - nodePtr->removeView(shared_from_this()); - } - if (!nodePtr->name().empty()) { mNodeRegistry.erase(nodePtr->name()); } - // same for learnable params - + // remove learnable params if (includeLearnableParam) { for (IOIndex_t i = nodePtr->nbDataInputs(); i < nodePtr->nbInputs(); ++i) { auto inputI = nodePtr->input(i); @@ -515,11 +582,21 @@ void Aidge::GraphView::remove(std::shared_ptr<Node> nodePtr, bool includeLearnab inputI.first->removeView(shared_from_this()); } if (!inputI.first->name().empty()) { mNodeRegistry.erase(inputI.first->name()); } + + // check if the node was an input/output node + updateInputsOutputsDelete(inputI.first); } } } - updateInputNodes(); - updateOutputNodes(); + + if (mNodes.find(nodePtr) != mNodes.end()) { + mNodes.erase(nodePtr); + nodePtr->removeView(shared_from_this()); + } + if (!nodePtr->name().empty()) { mNodeRegistry.erase(nodePtr->name()); } + + // check if the nodePtr was an input/output node + updateInputsOutputsDelete(nodePtr); } @@ -662,8 +739,8 @@ bool Aidge::GraphView::replace(const std::set<Aidge::NodePtr>& oldNodes, const s for (const auto& graphPtr : commonGraphViews) { graphPtr->add(newNodes, false); if (newNodes.empty()) { - graphPtr->updateInputNodes(); - graphPtr->updateOutputNodes(); + // TODO: FIXME: this function should not be called anymore! + graphPtr->updateInputsOutputsNodes(); } } @@ -676,21 +753,216 @@ bool Aidge::GraphView::replace(const std::set<Aidge::NodePtr>& oldNodes, const s return true; } - +/* void Aidge::GraphView::updateInputNodes() { - mInputNodes.clear(); + std::set<std::pair<NodePtr, IOIndex_t>> inputNodes; for (const std::shared_ptr<Node>& go_ptr : mNodes) { + size_t inputIdx = 0; for (const std::shared_ptr<Node>& pa_ptr : go_ptr->getParents()) { if ((pa_ptr == nullptr) || (mNodes.find(pa_ptr) == mNodes.end())) { // Parent doesn't exist || Parent not in the graph - mInputNodes.insert(go_ptr); + inputNodes.insert(std::make_pair(go_ptr, inputIdx)); + } + ++inputIdx; + } + } + + // Remove inputs that are not input anymore (deleted node or input connected internally) + for (auto it = mInputNodes.begin(); it != mInputNodes.end(); ++it) { + if (inputNodes.find(*it) == inputNodes.end()) { + it = mInputNodes.erase(it); + } + } + + // Add remaining new inputs + for (auto inputNode : inputNodes) { + mInputNodes.push_back(inputNode); + } +} +*/ + +void Aidge::GraphView::updateInputsOutputsNew(std::shared_ptr<Node> newNode) { + std::vector<std::pair<NodePtr, IOIndex_t>>::iterator newInputsInsertionPoint = mInputNodes.end(); + + // Remove inputs that are not input anymore because connected to newNode + for (auto orderedChilds : newNode->getOrderedChildren()) { + for (auto ch_ptr : orderedChilds) { + // Check that newNode child is in current GraphView + if (mNodes.find(ch_ptr) != mNodes.end()) { + std::size_t inputIdx = 0; + for (const std::shared_ptr<Node>& pa_ptr : ch_ptr->getParents()) { + // If newNode is connected to it + if (pa_ptr == newNode) { + const auto val = std::make_pair(ch_ptr, inputIdx); + const auto iter = std::find(mInputNodes.begin(), mInputNodes.end(), val); + + // The first old (removed) input becomes the insertion point for newNode GraphView inputs + if (std::distance(newInputsInsertionPoint, iter) <= 0) { + newInputsInsertionPoint = mInputNodes.erase(iter); + } + else { + mInputNodes.erase(iter); + } + } + ++inputIdx; + } + } + } + } + + // Check if node inputs are inputs for the GraphView and add them to the input list if so + // Inputs addition order follows node inputs order + // Inputs are inserted at the position of the first input removed + std::size_t inputIdx = 0U; + for (const std::shared_ptr<Node>& pa_ptr : newNode->getParents()) { + if ((pa_ptr == nullptr) || + (mNodes.find(pa_ptr) == + mNodes.end())) { // Parent doesn't exist || Parent not in the graph + const auto val = std::make_pair(newNode, inputIdx); + // Make sure to not add this input twice, as updateInputsNew() may be + // called several times for the same node. + if (std::find(mInputNodes.begin(), mInputNodes.end(), val) == mInputNodes.end()) { + newInputsInsertionPoint = mInputNodes.insert(newInputsInsertionPoint, val); + } + } + ++inputIdx; + } + + std::vector<std::pair<NodePtr, IOIndex_t>>::iterator newOutputsInsertionPoint = mOutputNodes.end(); + + // Remove outputs that are not output anymore because connected to newNode + std::size_t outputIdx = 0; + for (const std::shared_ptr<Node>& parent : newNode->getParents()) { + // Check that newNode parent is in current GraphView + if (mNodes.find(parent) != mNodes.end()) { + for (auto orderedChilds : parent->getOrderedChildren()) { + for (auto ch_ptr : orderedChilds) { + // If newNode is connected to it + if (ch_ptr == newNode) { + const auto val = std::make_pair(parent, outputIdx); + const auto iter = std::find(mOutputNodes.begin(), mOutputNodes.end(), val); + + // The first old (removed) output becomes the insertion point for newNode GraphView outputs + if (std::distance(newOutputsInsertionPoint, iter) <= 0) { + newOutputsInsertionPoint = mOutputNodes.erase(iter); + } + else { + mOutputNodes.erase(iter); + } + } + } + } + } + ++outputIdx; + } + + // Check if node outputs are outputs for the GraphView and add them to the output list if so + outputIdx = 0; + for (auto orderedChilds : newNode->getOrderedChildren()) { + bool noInsideConnection = true; + for (auto ch_ptr : orderedChilds) { + if (mNodes.find(ch_ptr) != mNodes.end()) { + noInsideConnection = false; break; } } + + if (noInsideConnection) { + const auto val = std::make_pair(newNode, outputIdx); + if (std::find(mOutputNodes.begin(), mOutputNodes.end(), val) == mOutputNodes.end()) { + newOutputsInsertionPoint = mOutputNodes.insert(newOutputsInsertionPoint, val); + } + } + ++outputIdx; + } +} + +void Aidge::GraphView::updateInputsOutputsDelete(std::shared_ptr<Node> deletedNode) { + std::vector<std::pair<NodePtr, IOIndex_t>>::iterator newInputsInsertionPoint = mInputNodes.end(); + + // Check if node inputs were inputs for the GraphView and remove them from the list if so + std::size_t inputIdx = 0; + for (const std::shared_ptr<Node>& parent : deletedNode->getParents()) { + const auto val = std::make_pair(deletedNode, inputIdx); + const auto iter = std::find(mInputNodes.begin(), mInputNodes.end(), val); + + // The first old (removed) input becomes the insertion point for newNode GraphView inputs + if (std::distance(newInputsInsertionPoint, iter) <= 0) { + newInputsInsertionPoint = mInputNodes.erase(iter); + } + else { + mInputNodes.erase(iter); + } + ++inputIdx; + } + + // Add child node inputs that become GraphView input following the removal of the node + // Inputs addition order follows deletedNode outputs order + for (auto orderedChilds : deletedNode->getOrderedChildren()) { + for (auto ch_ptr : orderedChilds) { + // Check that deletedNode child is in current GraphView + if (mNodes.find(ch_ptr) != mNodes.end()) { + inputIdx = 0; + for (const std::shared_ptr<Node>& pa_ptr : ch_ptr->getParents()) { + // If newNode was connected to it + if (pa_ptr == deletedNode) { + const auto val = std::make_pair(ch_ptr, inputIdx); + // Make sure to not add this input twice, as updateInputsNew() may be + // called several times for the same node. + if (std::find(mInputNodes.begin(), mInputNodes.end(), val) == mInputNodes.end()) { + newInputsInsertionPoint = mInputNodes.insert(newInputsInsertionPoint, val); + } + } + ++inputIdx; + } + } + } + } + + std::vector<std::pair<NodePtr, IOIndex_t>>::iterator newOutputsInsertionPoint = mOutputNodes.end(); + + // Check if node outputs were outputs for the GraphView and remove them from the list if so + std::size_t outputIdx = 0; + for (auto orderedChilds : deletedNode->getOrderedChildren()) { + const auto val = std::make_pair(deletedNode, outputIdx); + const auto iter = std::find(mOutputNodes.begin(), mOutputNodes.end(), val); + + // The first old (removed) output becomes the insertion point for newNode GraphView outputs + if (std::distance(newOutputsInsertionPoint, iter) <= 0) { + newOutputsInsertionPoint = mOutputNodes.erase(iter); + } + else { + mOutputNodes.erase(iter); + } + ++outputIdx; + } + + // Add parent node outputs that become GraphView output following the removal of the node + // Outputs addition order follows deletedNode inputs order + for (const std::shared_ptr<Node>& parent : deletedNode->getParents()) { + std::size_t outputIdx = 0; + for (auto orderedChilds : parent->getOrderedChildren()) { + bool noInsideConnection = true; + for (auto ch_ptr : orderedChilds) { + if (mNodes.find(ch_ptr) != mNodes.end()) { + noInsideConnection = false; + break; + } + } + + if (noInsideConnection) { + const auto val = std::make_pair(parent, outputIdx); + if (std::find(mOutputNodes.begin(), mOutputNodes.end(), val) == mOutputNodes.end()) { + newOutputsInsertionPoint = mOutputNodes.insert(newOutputsInsertionPoint, val); + } + } + ++outputIdx; + } } } +/* void Aidge::GraphView::updateInputNodes(std::shared_ptr<Node> node) { // add node_ptr to inputNode if it can std::size_t filledWithKnownInputs = 0U; @@ -731,8 +1003,8 @@ void Aidge::GraphView::updateInputNodes(std::shared_ptr<Node> node) { } } } - - +*/ +/* void Aidge::GraphView::removeInputNode(const std::string nodeName) { std::map<std::string, std::shared_ptr<Node>>::iterator it = mNodeRegistry.find(nodeName); @@ -754,7 +1026,7 @@ void Aidge::GraphView::removeOutputNode(const std::string nodeName) { } } } - +*/ std::shared_ptr<Aidge::GraphView> Aidge::GraphView::cloneCallback(NodePtr(*cloneNode)(NodePtr)) const { std::shared_ptr<GraphView> newGraph = std::make_shared<GraphView>(mName); @@ -770,16 +1042,14 @@ std::shared_ptr<Aidge::GraphView> Aidge::GraphView::cloneCallback(NodePtr(*clone if (oldToNewNode.second == nullptr) continue; // deleted node - // Add new node to new GraphView - newGraph->add(oldToNewNode.second, false); - // Connect parent nodes. Nodes that were removed with cloneNode() are set to nullptr size_t parentId = 0; for (auto parent : oldToNewNode.first->inputs()) { while (oldToNewNodes[parent.first] == nullptr) { // Find next valid parent in line, going backward in the graph - assert(parent.first->nbDataInputs() <= 1 && "deleted nodes in GraphView::clone() cannot have multiple data inputs"); - const auto& parents = parent.first->inputs(); + AIDGE_ASSERT(parent.first->getChildren().size() == 1, "deleted nodes in GraphView::clone() cannot have multiple children"); + AIDGE_ASSERT(parent.first->nbDataInputs() <= 1, "deleted nodes in GraphView::clone() cannot have multiple data input parents"); + const auto& parents = parent.first->dataInputs(); if (!parents.empty() && parents[0].first != nullptr // a valid parent exists && oldToNewNodes.find(parents[0].first) != oldToNewNodes.end()) // parent is in the GraphView @@ -792,6 +1062,8 @@ std::shared_ptr<Aidge::GraphView> Aidge::GraphView::cloneCallback(NodePtr(*clone } if (oldToNewNodes[parent.first]) { + AIDGE_ASSERT(oldToNewNodes[parent.first]->nbOutputs() == parent.first->nbOutputs(), + "next valid parent after deleted nodes in GraphView::clone() has wrong number of outputs"); oldToNewNodes[parent.first]->addChild(oldToNewNode.second, parent.second, parentId); } @@ -799,9 +1071,64 @@ std::shared_ptr<Aidge::GraphView> Aidge::GraphView::cloneCallback(NodePtr(*clone } } - // Update OutputNodes/inputNodes - newGraph->updateInputNodes(); - newGraph->updateOutputNodes(); + // Once connected, add each new nodes to new GraphView + // This has to be done in a second step to ensure that new GraphView inputs/outputs + // are properly set (otherwise, some node's inputs/outputs may be wrongly registered as + // GraphView inputs/outputs because not yet connected to other nodes) + for (auto &oldToNewNode : oldToNewNodes) { + if (oldToNewNode.second == nullptr) + continue; // deleted node + + newGraph->add(oldToNewNode.second, false); + } + + // Update cloned graph inputs/outputs order to match initial graph order + auto newInputNodes = mInputNodes; + for (auto it = newInputNodes.begin(); it != newInputNodes.end(); ++it) { + // If input node was removed, find next valid input + while (oldToNewNodes[it->first] == nullptr) { + // Removed node should have only one connected output, otherwise cloning is invalid + AIDGE_INTERNAL_ASSERT(it->first->getChildren().size() == 1); + auto child = *it->first->getChildren().begin(); + + bool found = false; + std::size_t inputIdx = 0; + for (auto parent : child->getParents()) { + if (parent == it->first) { + it->first = child; + it->second = inputIdx; + found = true; + break; + } + ++inputIdx; + } + + if (!found) { + it = newInputNodes.erase(it); + break; + } + } + } + newGraph->setOrderedInputs(newInputNodes); + + auto newOutputNodes = mOutputNodes; + for (auto it = newOutputNodes.begin(); it != newOutputNodes.end(); ++it) { + // If output node was removed, find previous valid output + while (oldToNewNodes[it->first] == nullptr) { + // Removed node should have only one connected data input, otherwise cloning is invalid + AIDGE_INTERNAL_ASSERT(it->first->nbDataInputs() <= 1); + auto parents = it->first->dataInputs(); + + if (!parents.empty()) { + *it = parents[0]; + } + else { + it = newOutputNodes.erase(it); + break; + } + } + } + newGraph->setOrderedOutputs(newOutputNodes); return newGraph; }