/******************************************************************************** * Copyright (c) 2023 CEA-List * * This program and the accompanying materials are made available under the * terms of the Eclipse Public License 2.0 which is available at * http://www.eclipse.org/legal/epl-2.0. * * SPDX-License-Identifier: EPL-2.0 * ********************************************************************************/ #include <algorithm> #include <cassert> #include <iterator> #include <utility> #include "aidge/utils/Types.h" #include "aidge/graph/GraphView.hpp" #include "aidge/data/Tensor.hpp" #include "aidge/utils/ErrorHandling.hpp" /////////////////////////////////////////////////////// // FUNCTIONAL DESCRIPTION /////////////////////////////////////////////////////// Aidge::Connector Aidge::GraphView::operator()( const std::vector<Aidge::Connector> ctors) { // TODO: allow for multiple inputNodes? assert((inputNodes().size() == 1U) && "Too many input Nodes for the GraphView, undefined behaviour"); std::shared_ptr<Node> inNode = *inputNodes().begin(); assert((ctors.size() == static_cast<std::size_t>(inNode->nbDataInputs())) && "Wrong number of arguments.\n"); for (std::pair<std::shared_ptr<Node>, IOIndex_t> &input : inNode->inputs()) { assert((gk_IODefaultIndex == input.second) && "At least one input connection is not free.\n"); (void)input; // avoid unused warning } IOIndex_t inID = 0; for (const Connector &ctor : ctors) { assert((ctor.node() != nullptr) && "Input Connector must be associated with a node"); ctor.node()->addChild(shared_from_this(), static_cast<std::size_t>(ctor.index()), {inNode, inID++}); } return Connector(*(outputNodes().begin())); } /////////////////////////////////////////////////////// // INNER /////////////////////////////////////////////////////// std::string Aidge::GraphView::name() const { return mName; } void Aidge::GraphView::setName(const std::string &name) { mName = name; } void Aidge::GraphView::save(std::string path, bool verbose) const { FILE *fp = std::fopen((path + ".mmd").c_str(), "w"); std::fprintf(fp, "%%%%{init: {'flowchart': { 'curve': 'monotoneY'}, " "'fontFamily': 'Verdana' } }%%%%\nflowchart TB\n\n"); std::map<const std::string, std::size_t> typeCounter; std::map<std::shared_ptr<Node>, std::string> namePtrTable; // Start by creating every node for (const std::shared_ptr<Node> &node_ptr : mNodes) { const std::string currentType = node_ptr->type(); if (typeCounter.find(currentType) == typeCounter.end()) typeCounter[currentType] = 0; ++typeCounter[currentType]; const std::string givenName = (node_ptr->name().empty()) ? currentType + std::to_string(typeCounter[currentType]) : node_ptr->name(); namePtrTable[node_ptr] = (currentType + "_" + std::to_string(typeCounter[currentType])); std::fprintf(fp, "%s(%s)\n", namePtrTable[node_ptr].c_str(), givenName.c_str()); } // Write every link std::size_t emptyInputCounter = 0; for (const std::shared_ptr<Node> &node_ptr : mNodes) { for (const std::shared_ptr<Node> &pa_ptr : node_ptr->getParents()) { if ((pa_ptr == nullptr) || !inView(pa_ptr)) { std::fprintf(fp, "input%zu((in - %zu))-->%s\n", emptyInputCounter, emptyInputCounter, namePtrTable[node_ptr].c_str()); ++emptyInputCounter; } else { std::fprintf(fp, "%s-->%s\n", namePtrTable[pa_ptr].c_str(), namePtrTable[node_ptr].c_str()); } } } if (verbose) { for (const auto &c : typeCounter) { std::printf("%s - %zu\n", c.first.c_str(), c.second); } } std::fprintf(fp, "\n"); std::fclose(fp); } /////////////////////////////////////////////////////// // TENSOR MANAGEMENT /////////////////////////////////////////////////////// Aidge::IOIndex_t Aidge::GraphView::getNbDataInputs() const { IOIndex_t nbDataInput = 0; for (const std::shared_ptr<Node> &inNode : inputNodes()) { // We cannot simply add inNode->nbDataInputs(), as input nodes may already // have some inputs connected within the GraphView, which would therefore not // constitue inputs (from outside) for the GraphView! const std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs = inNode->dataInputs(); for (const auto& input : inputNodeinputs) { if (input.first == nullptr || mNodes.find(input.first) == mNodes.end()) { ++nbDataInput; } } } return nbDataInput; } Aidge::IOIndex_t Aidge::GraphView::getNbFreeDataInputs() const { IOIndex_t nbIn = 0; // Free inputs within the GraphView are logically also free inputs from outside // the GraphView. for (const std::shared_ptr<Node>& inputNode : mInputNodes) { nbIn += inputNode->getNbFreeDataInputs(); } return nbIn; } std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::GraphView::dataInputs() const { std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res; for (const std::shared_ptr<Node>& inputNode : mInputNodes) { const std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs = inputNode->dataInputs(); for (const auto& input : inputNodeinputs) { if (input.first == nullptr || mNodes.find(input.first) == mNodes.end()) { res.push_back(input); } } } return res; } std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::GraphView::inputs() const { std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res; for (const std::shared_ptr<Node>& inputNode : mInputNodes) { const std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs = inputNode->inputs(); for (const auto& input : inputNodeinputs) { if (input.first == nullptr || mNodes.find(input.first) == mNodes.end()) { res.push_back(input); } } } return res; } std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::GraphView::inputs(std::string name) const { return mNodeRegistry.at(name)->inputs(); } void Aidge::GraphView::forwardDims() { // setInputs // Link every tensor to the right pointer // following parent - children informations for (std::shared_ptr<Node> nodePtr : getNodes()) { for (IOIndex_t i = 0; i < nodePtr->nbInputs(); ++i) { // assess if the input was not already set and is a Tensor then link it to parent output std::pair<std::shared_ptr<Node>, IOIndex_t> inputI = nodePtr->input(i); if (inputI.first) { if ( std::static_pointer_cast<Tensor>(nodePtr->getOperator()->getRawInput(i)) != inputI.first->getOperator()->getRawOutput(inputI.second)) { if ((strcmp(nodePtr->getOperator()->getRawInput(i)->type(), Tensor::Type) == 0) && (strcmp(inputI.first->getOperator()->getRawOutput(inputI.second)->type(), Tensor::Type)==0)) { // assert provided Data is of "Tensor" type nodePtr->getOperator()->associateInput(i, inputI.first->getOperator()->getRawOutput(inputI.second)); } else { assert(false && "Non-tensor entries not handled yet.\n"); } } } else { assert(!std::static_pointer_cast<Tensor>(nodePtr->getOperator()->getRawInput(i))->empty()); } } } // Compute dimensions of every node _forwardDims(inputNodes()); } void Aidge::GraphView::_forwardDims(std::set<std::shared_ptr<Node>> listNodes) { // TODO: support multi-inputs/outputs std::set<std::shared_ptr<Node>> nextList = std::set<std::shared_ptr<Node>>(); for (std::shared_ptr<Node> nodePtr : listNodes) { if (!nodePtr->getOperator()->outputDimsForwarded()) { nodePtr->getOperator()->computeOutputDims(); } if (!nodePtr->getOperator()->outputDimsForwarded()) { nextList.insert(nodePtr); } else { std::set<std::shared_ptr<Node>> children = nodePtr->getChildren(); nextList.insert(children.begin(), children.end()); } } if (nextList.empty()) { for (std::shared_ptr<Node> nodePtr : getNodes()) { if (!nodePtr->getOperator()->outputDimsForwarded()) { nextList.insert(nodePtr); } } } if (!nextList.empty()) { _forwardDims(nextList); } } void Aidge::GraphView::setBackend(const std::string &backend) { for (auto node : getNodes()) { node->getOperator()->setBackend(backend); } } void Aidge::GraphView::setDatatype(const DataType &datatype) { for (auto node : getNodes()) { node->getOperator()->setDatatype(datatype); } } void Aidge::GraphView::updateOutputNodes() { mOutputNodes.clear(); for (const std::shared_ptr<Node>& go_it : mNodes) { if (go_it->nbOutputs() != go_it->nbValidOutputs()) { // an output linked to nothing mOutputNodes.insert(go_it); continue; } for (const std::shared_ptr<Node>& ch_ptr : go_it->getChildren()) { if (mNodes.find(ch_ptr) == mNodes.end()) { // Child not in the graph mOutputNodes.insert(go_it); break; } } } } void Aidge::GraphView::updateOutputNodes(std::shared_ptr<Node> node) { if (node->nbOutputs() != node->nbValidOutputs()) { // an output linked to nothing mOutputNodes.insert(node); } else { // don't enter if was already added to outputNodes for (const std::shared_ptr<Node> &ch_ptr : node->getChildren()) { if (mNodes.find(ch_ptr) == mNodes.end()) { // Child not in the graph mOutputNodes.insert(node); break; } } } // update other outputNodes for (const std::shared_ptr<Node> &pa_ptr : node->getParents()) { // check if any parent is in OutputNodes too if ((pa_ptr != nullptr) && (mOutputNodes.find(pa_ptr) != mOutputNodes.end())) { // it's a match! Must check if the outputNode // found is still an outputNode bool remove = (pa_ptr->nbOutputs() == pa_ptr->nbValidOutputs()); for (const std::shared_ptr<Node>& ch_ptr : pa_ptr->getChildren()) { if (mNodes.find(ch_ptr) == mNodes.end()) { // Child not in the graph remove = false; break; } } if (remove) { mOutputNodes.erase(pa_ptr); } } } } std::vector< std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>> Aidge::GraphView::outputs() const { std::vector<std::vector<std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t>>> outsideOutputs; for (const std::shared_ptr<Node>& outputNode : mOutputNodes) { const std::vector<std::vector<std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t>>> outputNodeOutputs = outputNode->outputs(); for (const auto& outputPos : outputNodeOutputs) { // Keep only the nodes connected at this output position that are outside the GraphView std::vector<std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t>> outsideOutputPos; for (const auto& output : outputPos) { if (mNodes.find(output.first) == mNodes.end()) { outsideOutputPos.push_back(output); } } outsideOutputs.push_back(outsideOutputPos); } } return outsideOutputs; } std::vector< std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>> Aidge::GraphView::outputs(std::string nodeName) const { return mNodeRegistry.at(nodeName)->outputs(); } void Aidge::GraphView::setInputId(Aidge::IOIndex_t /*inID*/, Aidge::IOIndex_t /*newNodeOutID*/) { printf("Not implemented yet.\n"); } void Aidge::GraphView::add(std::shared_ptr<Node> node, bool includeLearnableParam) { // add to the GraphView nodes node->addView(shared_from_this()); mNodes.insert(node); if (!(node->name()).empty()) mNodeRegistry.insert(std::make_pair(node->name(), node)); // add learnable parameters to the graph if (includeLearnableParam) { for (IOIndex_t i = node->nbDataInputs(); i < node->nbInputs(); ++i) { std::shared_ptr<Node> parentNode = node->getParent(static_cast<IOIndex_t>(i)); if (parentNode) { parentNode->addView(shared_from_this()); mNodes.insert(parentNode); if (!(parentNode->name()).empty()) mNodeRegistry.insert(std::make_pair(parentNode->name(), parentNode)); // check if the Node is an input node updateInputNodes(parentNode); } } } // check if the Node is an input node updateInputNodes(node); // check if the Node is an input node updateOutputNodes(node); } void Aidge::GraphView::add(std::set<std::shared_ptr<Node>> otherNodes, bool includeLearnableParam) { for (auto& nodePtr : otherNodes) { add(nodePtr, includeLearnableParam); } } void Aidge::GraphView::add(std::shared_ptr<GraphView> graph) { for (const std::shared_ptr<Node> &node_ptr : graph->getNodes()) { node_ptr->addView(shared_from_this()); mNodes.insert(node_ptr); if (!(node_ptr->name()).empty()) mNodeRegistry.insert(std::make_pair(node_ptr->name(), node_ptr)); // if node_ptr is part of graph inputNodes or outputNodes // if (graph->isInputNode(node_ptr) || graph->isOutputNode(node_ptr)) { // Update OutputNodes/inputNodes updateInputNodes(); updateOutputNodes(); } } void Aidge::GraphView::addChild(std::shared_ptr<Node> toOtherNode, std::shared_ptr<Node> fromOutNode, const Aidge::IOIndex_t fromTensor, Aidge::IOIndex_t toTensor) { if (fromOutNode) assert(inView(fromOutNode) && "Output Node not found in the GraphView."); else { assert((outputNodes().size() == 1U) && "Must specify an outputNode or have only one."); fromOutNode = *(outputNodes().begin()); } fromOutNode->addChild(toOtherNode, fromTensor, toTensor); add(toOtherNode); } void Aidge::GraphView::addChild( std::shared_ptr<GraphView> toOtherView, std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t> fromOutNode, std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t> toNode) { // assert output node is valid if (!fromOutNode.first) { assert(outputNodes().size() == 1U && "If no output node is provided, the graph should have only one to " "make the choice explicit."); fromOutNode.first = *(outputNodes().begin()); } else assert(inView(fromOutNode.first)); // assert input node is valid if (!toNode.first) { assert(toOtherView->inputNodes().size() == 1U && "If no intput node is provided, the other graph should have only " "one to make the choice explicit."); toNode.first = *(toOtherView->inputNodes().begin()); } else { assert(toOtherView->inView(toNode.first)); } // Tensor assertions are performed in the Node adChild method fromOutNode.first->addChild(toNode.first, fromOutNode.second, toNode.second); // once linking performed, add other graph to current graph add(toOtherView); } std::set<std::shared_ptr<Aidge::Node>> Aidge::GraphView::getParents() const { // TODO: choose if we return a set or a vector std::set<std::shared_ptr<Node>> parents; for (const std::shared_ptr<Node>& inputNode : mInputNodes) { parents.insert(inputNode->getParents().begin(), inputNode->getParents().end()); } return parents; } std::vector<std::shared_ptr<Aidge::Node>> Aidge::GraphView::getParents(const std::string nodeName) const { std::map<std::string, std::shared_ptr<Node>>::const_iterator it = mNodeRegistry.find(nodeName); if (it == mNodeRegistry.end()) { printf("No such node a %s in %s graph.\n", nodeName.c_str(), name().c_str()); exit(-1); } return (it->second)->getParents(); } std::vector<std::vector<std::shared_ptr<Aidge::Node>>> Aidge::GraphView::getOrderedParents() const { std::vector<std::vector<std::shared_ptr<Node>>> parents; for (const std::shared_ptr<Node>& inputNode : mInputNodes) { parents.push_back(inputNode->getParents()); } return parents; } std::set<std::shared_ptr<Aidge::Node>> Aidge::GraphView::getChildren() const { std::set<std::shared_ptr<Node>> children; for (const std::shared_ptr<Node>& outputNode : mOutputNodes) { children.insert((outputNode->getChildren()).begin(), (outputNode->getChildren()).end()); } return children; } std::vector<std::vector<std::shared_ptr<Aidge::Node>>> Aidge::GraphView::getChildren(const std::string nodeName) const { std::map<std::string, std::shared_ptr<Node>>::const_iterator it = mNodeRegistry.find(nodeName); if (it == mNodeRegistry.end()) { printf("No such node a %s in %s graph.\n", nodeName.c_str(), name().c_str()); exit(-1); } return (it->second)->getOrderedChildren(); } std::set<std::shared_ptr<Aidge::Node>> Aidge::GraphView::getChildren(const std::shared_ptr<Node> otherNode) const { std::set<std::shared_ptr<Node>>::const_iterator it = mNodes.find(otherNode); if (it == mNodes.end()) { printf("No such node in graph.\n"); exit(-1); } return (*it)->getChildren(); } std::shared_ptr<Aidge::Node> Aidge::GraphView::getNode(const std::string& nodeName) const { std::map<std::string, std::shared_ptr<Node>>::const_iterator it = mNodeRegistry.find(nodeName); if (it != mNodeRegistry.end()) { return it->second; } else { printf("No Node named %s in the current GraphView.\n", nodeName.c_str()); exit(-1); } } void Aidge::GraphView::remove(std::shared_ptr<Node> nodePtr, bool includeLearnableParam) { if (mNodes.find(nodePtr) != mNodes.end()) { mNodes.erase(nodePtr); nodePtr->removeView(shared_from_this()); } if (!nodePtr->name().empty()) { mNodeRegistry.erase(nodePtr->name()); } // same for learnable params if (includeLearnableParam) { for (IOIndex_t i = nodePtr->nbDataInputs(); i < nodePtr->nbInputs(); ++i) { auto inputI = nodePtr->input(i); bool removeNode = true; for (const auto& parentOutput : inputI.first->outputs()) { for (const auto& childOfParentOutput : parentOutput) { // only remove the learnable parameter if not related to any other Node in the GraphView if (childOfParentOutput.first != nodePtr) { removeNode = false; break; } } } if (removeNode) { // assert Learnable Parameter in the GraphView scope if (mNodes.find(inputI.first) != mNodes.end()) { mNodes.erase(inputI.first); inputI.first->removeView(shared_from_this()); } if (!inputI.first->name().empty()) { mNodeRegistry.erase(inputI.first->name()); } } } } updateInputNodes(); updateOutputNodes(); } bool Aidge::GraphView::swap(Node & /*node*/, Node & /*otherNode*/) { printf("Swap() not implementated yet. Return false.\n"); return false; } void Aidge::GraphView::link(std::string /*name1_inID*/, std::string /*name2_outID*/) { printf("Not implemented yet.\n"); } void Aidge::GraphView::insertParent(NodePtr childNode, NodePtr newParentNode, IOIndex_t childInputTensorIdx, IOIndex_t newParentInputTensorIdx, IOIndex_t newParentOutputTensorIdx){ NodePtr currentParentNode = childNode->getParent(childInputTensorIdx); const IOIndex_t currentParentOutputTensorIdx = childNode->input(childInputTensorIdx).second; // Remove child from current parent & current Parent from child currentParentNode->removeChild(childNode, currentParentOutputTensorIdx); // Add child currentParentNode->addChild(newParentNode,currentParentOutputTensorIdx, newParentInputTensorIdx); newParentNode->addChild(childNode, newParentOutputTensorIdx, childInputTensorIdx); add(newParentNode); } bool Aidge::GraphView::replace(const std::set<Aidge::NodePtr>& oldNodes, const std::set<Aidge::NodePtr>& newNodes) { // TODO: handle case where an oldNodes parameter does not come from a Producer but another Node (not included in oldNodes) // How to distinguish it from data input? // TODO: Parameter Tensors could be identified with their dimensions // TODO: Take GraphView as input parameters since new Nodes should be connected whatever. // It also avoids specifying each producer since they are automatically included auto oldG = std::make_shared<GraphView>("oldG"); oldG->add(oldNodes, false); auto newG = std::make_shared<GraphView>("newG"); newG->add(newNodes, false); if ((oldG->inputNodes().size() == 0) || (oldG->outputNodes().size() != 1)) { return false; } if (!(newNodes.empty()) && ((newG->inputNodes().size() == 0) || (newG->outputNodes().size() != 1))) { return false; } // there is at least one inputNode in the old/new GraphView std::shared_ptr<Node> firstPreviousInputNode = (*(oldG->inputNodes()).begin()); std::shared_ptr<Node> firstPreviousOutputNode = (*(oldG->outputNodes()).begin()); // find Node to link to new input Node //compute number of input for firstPreviousInputNode not in oldNodes set std::size_t nbExternalInputs = 0; std::shared_ptr<Node> externalInput = nullptr; IOIndex_t externalInputId = gk_IODefaultIndex; for (const auto& input : firstPreviousInputNode->inputs()) { if (oldNodes.find(input.first) == oldNodes.end()) { // Node connected to another Node outside of oldG nbExternalInputs++; externalInput = input.first; externalInputId = input.second; } } if (nbExternalInputs > 1) { AIDGE_INTERNAL_ASSERT("To many input to link for oldNodes set"); } if (oldG->inputNodes().size() > 1){ // one or no input has been identified. Checking every input points to the same source for (const auto& previousInputNode : oldG->inputNodes()) { for (const auto& input : previousInputNode->inputs()) { if (oldNodes.find(input.first) == oldNodes.end()) { if ( (externalInput != input.first) || (externalInputId != input.second) ) { return false; // an inputNode points to an external Node different from the registered one } } } } } if (firstPreviousOutputNode->nbOutputs() != 1) { return false; } // find Node to replicate output connections std::shared_ptr<Node> newOutputNode = newNodes.empty() ? externalInput : *(newG->outputNodes().begin()); auto copyOutputs = firstPreviousOutputNode->outputs(); // manage Views for newNodes // only keep common views to each node for the new set std::set<std::shared_ptr<GraphView>> commonGraphViews = (*oldNodes.begin())->views(); for (const auto& nodePtr : oldNodes) { const auto nodeView = nodePtr->views(); std::set<std::shared_ptr<GraphView>> intersection; std::set_intersection(commonGraphViews.begin(), commonGraphViews.end(), nodeView.begin(), nodeView.end(), std::inserter(intersection, intersection.begin())); commonGraphViews = intersection; } commonGraphViews.erase(oldG); commonGraphViews.erase(newG); // clean Nodes to replace // Do not include common Nodes to avoid cleaning Producers linked to newNodes std::set<std::shared_ptr<Node>> nodesToClean; std::set_difference(oldNodes.begin(), oldNodes.end(), newNodes.begin(), newNodes.end(), std::inserter(nodesToClean, nodesToClean.begin())); for (auto& nodePtr : nodesToClean) { nodePtr->resetConnections(true); } // copy output connections if (newOutputNode) { for (IOIndex_t o = 0; o < firstPreviousOutputNode->nbOutputs(); ++o) { auto outputPairs = copyOutputs[o]; for (const auto& onePair : outputPairs) { newOutputNode->addChild(onePair.first, o, onePair.second); } } } // copy input connections if (!newNodes.empty() && externalInput) { for (const auto& newInputNode : newG->inputNodes()) { IOIndex_t inputId = 0; for (const auto& input : newInputNode->inputs()) { if (newNodes.find(input.first) == newNodes.end()) { externalInput->addChild(newInputNode, externalInputId, inputId); } inputId++; } } } // insert new Nodes in the right GraphViews for (const auto& graphPtr : commonGraphViews) { graphPtr->add(newNodes, false); if (newNodes.empty()) { graphPtr->updateInputNodes(); graphPtr->updateOutputNodes(); } } for (const auto& node : oldNodes) { node->removeView(oldG); } for (const auto& node : newNodes) { node->removeView(newG); } return true; } void Aidge::GraphView::updateInputNodes() { mInputNodes.clear(); for (const std::shared_ptr<Node>& go_ptr : mNodes) { for (const std::shared_ptr<Node>& pa_ptr : go_ptr->getParents()) { if ((pa_ptr == nullptr) || (mNodes.find(pa_ptr) == mNodes.end())) { // Parent doesn't exist || Parent not in the graph mInputNodes.insert(go_ptr); break; } } } } void Aidge::GraphView::updateInputNodes(std::shared_ptr<Node> node) { // add node_ptr to inputNode if it can std::size_t filledWithKnownInputs = 0U; bool wasAdded = mInputNodes.find(node) != mInputNodes.end(); for (const std::shared_ptr<Node>& pa_ptr : node->getParents()) { if ((pa_ptr == nullptr) || (mNodes.find(pa_ptr) == mNodes.end())) { // Parent doesn't exist || Parent not in the graph mInputNodes.insert(node); wasAdded = true; break; } ++filledWithKnownInputs; } if (filledWithKnownInputs == node->nbInputs() && wasAdded) { mInputNodes.erase(node); } // update other inputNodes for (const std::shared_ptr<Node>& ch_ptr : node->getChildren()) { // check if any child is in InputNodes too if (mInputNodes.find(ch_ptr) != mInputNodes.end()) { // it's a match! Must check if the inputNode found // is still an inputNode // change here bool remove = true; for (const std::shared_ptr<Node>& pa_ptr : ch_ptr->getParents()) { if (pa_ptr == nullptr || mNodes.find(pa_ptr) == mNodes .end()) { // Parent doesn't exist || Parent not in the graph remove = false; break; } } if (remove) { mInputNodes.erase(ch_ptr); } } } } void Aidge::GraphView::removeInputNode(const std::string nodeName) { std::map<std::string, std::shared_ptr<Node>>::iterator it = mNodeRegistry.find(nodeName); if (it != mNodeRegistry.end()) { const std::shared_ptr<Node> val = (*it).second; if (mInputNodes.find(val) != mInputNodes.end()) { mInputNodes.erase(val); } } } void Aidge::GraphView::removeOutputNode(const std::string nodeName) { std::map<std::string, std::shared_ptr<Node>>::iterator it = mNodeRegistry.find(nodeName); if (it != mNodeRegistry.end()) { const std::shared_ptr<Node> val = (*it).second; if (mOutputNodes.find(val) != mOutputNodes.end()) { mOutputNodes.erase(val); } } } std::shared_ptr<Aidge::GraphView> Aidge::GraphView::cloneCallback(NodePtr(*cloneNode)(NodePtr)) const { std::shared_ptr<GraphView> newGraph = std::make_shared<GraphView>(mName); // Map for old node -> new node correspondance std::map<NodePtr, NodePtr> oldToNewNodes; for (const std::shared_ptr<Node> &node_ptr : mNodes) { oldToNewNodes[node_ptr] = cloneNode(node_ptr); } // For each node, convert old node -> new node connections for (auto &oldToNewNode : oldToNewNodes) { if (oldToNewNode.second == nullptr) continue; // deleted node // Add new node to new GraphView newGraph->add(oldToNewNode.second, false); // Connect parent nodes. Nodes that were removed with cloneNode() are set to nullptr size_t parentId = 0; for (auto parent : oldToNewNode.first->inputs()) { while (oldToNewNodes[parent.first] == nullptr) { // Find next valid parent in line, going backward in the graph assert(parent.first->nbDataInputs() <= 1 && "deleted nodes in GraphView::clone() cannot have multiple data inputs"); const auto& parents = parent.first->inputs(); if (!parents.empty() && parents[0].first != nullptr // a valid parent exists && oldToNewNodes.find(parents[0].first) != oldToNewNodes.end()) // parent is in the GraphView { parent = parents[0]; } else { break; } } if (oldToNewNodes[parent.first]) { oldToNewNodes[parent.first]->addChild(oldToNewNode.second, parent.second, parentId); } ++parentId; } } // Update OutputNodes/inputNodes newGraph->updateInputNodes(); newGraph->updateOutputNodes(); return newGraph; }