-
vincent lorrain authored
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
GraphView.cpp 25.90 KiB
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <algorithm>
#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iterator>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "aidge/utils/Types.h"
#include "aidge/graph/GraphView.hpp"
#include "aidge/data/Tensor.hpp"
///////////////////////////////////////////////////////
// FUNCTIONAL DESCRIPTION
///////////////////////////////////////////////////////
/**
 * Functional-style call: connect each Connector of `ctors`, in order, to the
 * successive data inputs of the single input Node of this view.
 * @param ctors One Connector per data input of the input Node.
 * @return A Connector on the first output Node of the view.
 */
Aidge::Connector Aidge::GraphView::operator()(
    const std::vector<Aidge::Connector> ctors) {
    // TODO: allow for multiple inputNodes?
    assert((inputNodes().size() == 1U) && "Too many input Nodes for the GraphView, undefined behaviour");
    const std::shared_ptr<Node> entryNode = *inputNodes().begin();
    assert((ctors.size() == static_cast<std::size_t>(entryNode->nbDataInputs())) && "Wrong number of arguments.\n");

    // Every input slot of the entry Node must still be unconnected.
    for (const auto &slot : entryNode->inputs()) {
        assert((gk_IODefaultIndex == slot.second) && "At least one input connection is not free.\n");
        (void)slot; // avoid unused warning
    }

    for (std::size_t k = 0; k < ctors.size(); ++k) {
        const Connector &ctor = ctors[k];
        assert((ctor.node() != nullptr) &&
               "Input Connector must be associated with a node");
        ctor.node()->addChild(shared_from_this(), static_cast<std::size_t>(ctor.index()),
                              {entryNode, static_cast<IOIndex_t>(k)});
    }
    return Connector(*(outputNodes().begin()));
}
///////////////////////////////////////////////////////
// INNER
///////////////////////////////////////////////////////
/// @brief Name of this view (returned by copy).
std::string Aidge::GraphView::name() const {
    return mName;
}
/// @brief Replace the name of this view.
void Aidge::GraphView::setName(const std::string &name) {
    mName = name;
}
/**
 * Export the view as a Mermaid flowchart in "<path>.mmd".
 *
 * Every Node of the view is declared once with the id "<type>_<k>" (k counts
 * the Nodes of that type) and a label equal to the Node name, or "<type><k>"
 * when the Node is unnamed. Each parent link then becomes an edge. Parents
 * that are nullptr or outside the view are drawn as numbered "input" circles.
 *
 * @param path    Output path, without the ".mmd" extension.
 * @param verbose If true, print the number of Nodes per type to stdout.
 */
void Aidge::GraphView::save(std::string path, bool verbose) const {
    // RAII handle: the file is closed on every exit path, exceptions included.
    auto closeFile = [](std::FILE *f) { std::fclose(f); };
    std::unique_ptr<std::FILE, decltype(closeFile)> fp(
        std::fopen((path + ".mmd").c_str(), "w"), closeFile);
    if (!fp) {
        // A failed fopen used to be passed straight to fprintf (null FILE*).
        std::fprintf(stderr, "Could not open file %s.mmd for writing.\n", path.c_str());
        return;
    }
    std::fprintf(fp.get(),
                 "%%%%{init: {'flowchart': { 'curve': 'monotoneY'}, "
                 "'fontFamily': 'Verdana' } }%%%%\nflowchart TB\n\n");

    std::map<std::string, std::size_t> typeCounter;
    std::map<std::shared_ptr<Node>, std::string> namePtrTable;
    // Start by creating every node
    for (const std::shared_ptr<Node> &node_ptr : mNodes) {
        const std::string currentType = node_ptr->type();
        // operator[] value-initialises a new counter to 0 before the increment.
        const std::size_t typeIndex = ++typeCounter[currentType];
        const std::string givenName =
            (node_ptr->name().empty())
                ? currentType + std::to_string(typeIndex)
                : node_ptr->name();
        const std::string nodeId = currentType + "_" + std::to_string(typeIndex);
        std::fprintf(fp.get(), "%s(%s)\n", nodeId.c_str(), givenName.c_str());
        namePtrTable[node_ptr] = nodeId;
    }

    // Write every link
    std::size_t emptyInputCounter = 0;
    for (const std::shared_ptr<Node> &node_ptr : mNodes) {
        const std::string &childId = namePtrTable[node_ptr];
        for (const std::shared_ptr<Node> &pa_ptr : node_ptr->getParents()) {
            if ((pa_ptr == nullptr) || !inView(pa_ptr)) {
                std::fprintf(fp.get(), "input%zu((in - %zu))-->%s\n", emptyInputCounter,
                             emptyInputCounter, childId.c_str());
                ++emptyInputCounter;
            } else {
                std::fprintf(fp.get(), "%s-->%s\n", namePtrTable[pa_ptr].c_str(),
                             childId.c_str());
            }
        }
    }

    if (verbose) {
        for (const auto &c : typeCounter) {
            std::printf("%s - %zu\n", c.first.c_str(), c.second);
        }
    }
    std::fprintf(fp.get(), "\n");
    // fp is closed by its deleter when it goes out of scope.
}
///////////////////////////////////////////////////////
// TENSOR MANAGEMENT
///////////////////////////////////////////////////////
/// @brief Total number of data inputs over all input Nodes of the view.
Aidge::IOIndex_t Aidge::GraphView::getNbDataInputs() const {
    IOIndex_t total = 0;
    const auto &entryNodes = inputNodes();
    std::for_each(entryNodes.begin(), entryNodes.end(),
                  [&total](const std::shared_ptr<Node> &n) { total += n->nbDataInputs(); });
    return total;
}
/// @brief Sum of Node::getNbFreeDataInputs() over the view's input Nodes.
Aidge::IOIndex_t Aidge::GraphView::getNbFreeDataInputs() const {
    IOIndex_t nbFree = 0;
    for (auto it = mInputNodes.cbegin(); it != mInputNodes.cend(); ++it) {
        nbFree += (*it)->getNbFreeDataInputs();
    }
    return nbFree;
}
/**
 * Concatenate, in mInputNodes iteration order, the pairs returned by
 * Node::dataInputs() for every input Node of the view.
 */
std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>
Aidge::GraphView::dataInputs() const {
    // Count in std::size_t: summing into IOIndex_t could wrap around.
    std::size_t nbDataIn = 0U;
    for (const std::shared_ptr<Node>& inputNode : mInputNodes) {
        nbDataIn += inputNode->nbDataInputs();
    }
    std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res;
    res.reserve(nbDataIn);
    for (const std::shared_ptr<Node>& inputNode : mInputNodes) {
        // Append what the Node actually returns. The previous version wrote
        // at an offset derived from nbDataInputs() into a pre-sized vector,
        // which overflowed it if the two counts ever disagreed.
        std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeInputs =
            inputNode->dataInputs();
        res.insert(res.end(),
                   std::make_move_iterator(inputNodeInputs.begin()),
                   std::make_move_iterator(inputNodeInputs.end()));
    }
    return res;
}
/**
 * Concatenate, in mInputNodes iteration order, the pairs returned by
 * Node::inputs() (data and parameter inputs) for every input Node of the view.
 */
std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>
Aidge::GraphView::inputs() const {
    std::size_t nbIn = 0U;
    for (const std::shared_ptr<Node>& inputNode : mInputNodes) {
        nbIn += inputNode->nbInputs();
    }
    std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res;
    res.reserve(nbIn);
    for (const std::shared_ptr<Node>& inputNode : mInputNodes) {
        // Append instead of writing at a precomputed offset: a pre-sized
        // vector overflowed if inputs().size() differed from nbInputs().
        std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeInputs =
            inputNode->inputs();
        res.insert(res.end(),
                   std::make_move_iterator(inputNodeInputs.begin()),
                   std::make_move_iterator(inputNodeInputs.end()));
    }
    return res;
}
/**
 * Inputs of the Node registered under `name`.
 * std::map::at throws std::out_of_range when no such name is registered.
 */
std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>
Aidge::GraphView::inputs(std::string name) const {
    const std::shared_ptr<Node> &target = mNodeRegistry.at(name);
    return target->inputs();
}
void Aidge::GraphView::forwardDims() {
// setInputs
// Link every tensor to the right pointer
// following parent - children informations
for (std::shared_ptr<Node> nodePtr : getNodes()) {
for (IOIndex_t i = 0; i < nodePtr->nbInputs(); ++i) {
// assess if the input was not already set and is a Tensor then link it to parent output
std::pair<std::shared_ptr<Node>, IOIndex_t> inputI = nodePtr->input(i);
if (inputI.first) {
if ( std::static_pointer_cast<Tensor>(nodePtr->getOperator()->getRawInput(i)) != inputI.first->getOperator()->getRawOutput(inputI.second)) {
if ((strcmp(nodePtr->getOperator()->getRawInput(i)->type(), Tensor::Type) == 0) && (strcmp(inputI.first->getOperator()->getRawOutput(inputI.second)->type(), Tensor::Type)==0)) {
// assert provided Data is of "Tensor" type
nodePtr->getOperator()->associateInput(i, inputI.first->getOperator()->getRawOutput(inputI.second));
}
else {
assert(false && "Non-tensor entries not handled yet.\n");
}
}
} else
{
assert(!std::static_pointer_cast<Tensor>(nodePtr->getOperator()->getRawInput(i))->empty());
}
}
}
// Compute dimensions of every node
_forwardDims(inputNodes());
}
void Aidge::GraphView::_forwardDims(std::set<std::shared_ptr<Node>> listNodes) {
// TODO: support multi-inputs/outputs
std::set<std::shared_ptr<Node>> nextList = std::set<std::shared_ptr<Node>>();
for (std::shared_ptr<Node> nodePtr : listNodes) {
if (!nodePtr->getOperator()->outputDimsForwarded()) {
nodePtr->getOperator()->computeOutputDims();
}
if (!nodePtr->getOperator()->outputDimsForwarded()) {
nextList.insert(nodePtr);
} else {
std::set<std::shared_ptr<Node>> children = nodePtr->getChildren();
nextList.insert(children.begin(), children.end());
}
}
if (nextList.empty()) {
for (std::shared_ptr<Node> nodePtr : getNodes()) {
if (!nodePtr->getOperator()->outputDimsForwarded()) {
nextList.insert(nodePtr);
}
}
}
if (!nextList.empty()) {
_forwardDims(nextList);
}
}
void Aidge::GraphView::setBackend(const std::string &backend) {
for (auto node : getNodes()) {
node->getOperator()->setBackend(backend);
}
}
void Aidge::GraphView::setDatatype(const DataType &datatype) {
for (auto node : getNodes()) {
node->getOperator()->setDatatype(datatype);
}
}
void Aidge::GraphView::updateOutputNodes() {
mOutputNodes.clear();
for (const std::shared_ptr<Node>& go_it : mNodes) {
if (go_it->nbOutputs() !=
go_it->nbValidOutputs()) { // an output linked to nothing
mOutputNodes.insert(go_it);
continue;
}
for (const std::shared_ptr<Node>& ch_ptr : go_it->getChildren()) {
if (mNodes.find(ch_ptr) == mNodes.end()) { // Child not in the graph
mOutputNodes.insert(go_it);
break;
}
}
}
}
void Aidge::GraphView::updateOutputNodes(std::shared_ptr<Node> node) {
if (node->nbOutputs() !=
node->nbValidOutputs()) { // an output linked to nothing
mOutputNodes.insert(node);
} else { // don't enter if was already added to outputNodes
for (const std::shared_ptr<Node> &ch_ptr : node->getChildren()) {
if (mNodes.find(ch_ptr) == mNodes.end()) { // Child not in the graph
mOutputNodes.insert(node);
break;
}
}
}
// update other outputNodes
for (const std::shared_ptr<Node> &pa_ptr :
node->getParents()) { // check if any parent is in OutputNodes too
if ((pa_ptr != nullptr) &&
(mOutputNodes.find(pa_ptr) !=
mOutputNodes.end())) { // it's a match! Must check if the outputNode
// found is still an outputNode
bool remove = (pa_ptr->nbOutputs() == pa_ptr->nbValidOutputs());
for (const std::shared_ptr<Node>& ch_ptr : pa_ptr->getChildren()) {
if (mNodes.find(ch_ptr) == mNodes.end()) { // Child not in the graph
remove = false;
break;
}
}
if (remove) {
mOutputNodes.erase(pa_ptr);
}
}
}
}
/**
 * Concatenate, in mOutputNodes iteration order, the Node::outputs() lists of
 * every output Node of the view.
 */
std::vector<
    std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>>
Aidge::GraphView::outputs() const {
    std::vector<std::vector<std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t>>>
        outputTensors;
    for (auto it = mOutputNodes.cbegin(); it != mOutputNodes.cend(); ++it) {
        const auto nodeOutputs = (*it)->outputs();
        std::copy(nodeOutputs.begin(), nodeOutputs.end(),
                  std::back_inserter(outputTensors));
    }
    return outputTensors;
}
/**
 * Outputs of the Node registered under `nodeName`.
 * std::map::at throws std::out_of_range when the name is unknown.
 */
std::vector<
    std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>>
Aidge::GraphView::outputs(std::string nodeName) const {
    const std::shared_ptr<Node> &target = mNodeRegistry.at(nodeName);
    return target->outputs();
}
/// @brief Placeholder: only reports that the feature is missing.
void Aidge::GraphView::setInputId(Aidge::IOIndex_t /*inID*/,
                                  Aidge::IOIndex_t /*newNodeOutID*/) {
    std::printf("Not implemented yet.\n");
}
/**
 * Add `node` to the view (registering it by name when it has one) and update
 * the input/output Node sets incrementally.
 * @param includeLearnableParam When true, the parents connected to the
 *        non-data inputs of `node` (indices nbDataInputs()..nbInputs()-1) are
 *        added too.
 */
void Aidge::GraphView::add(std::shared_ptr<Node> node, bool includeLearnableParam) {
    const auto self = shared_from_this();
    // add to the GraphView nodes
    node->addView(self);
    mNodes.insert(node);
    if (!(node->name()).empty())
        mNodeRegistry.insert(std::make_pair(node->name(), node));
    // add learnable parameters to the graph
    if (includeLearnableParam) {
        for (IOIndex_t i = node->nbDataInputs(); i < node->nbInputs(); ++i) {
            std::shared_ptr<Node> parentNode = node->getParent(i);
            if (!parentNode) {
                continue;
            }
            parentNode->addView(self);
            mNodes.insert(parentNode);
            if (!(parentNode->name()).empty())
                mNodeRegistry.insert(std::make_pair(parentNode->name(), parentNode));
            // check if the parameter Node is an input node
            updateInputNodes(parentNode);
            // ... and an output node: a parameter shared with Nodes outside
            // the view must be flagged, as the full updateOutputNodes() does.
            updateOutputNodes(parentNode);
        }
    }
    // check if the Node is an input node
    updateInputNodes(node);
    // check if the Node is an output node
    updateOutputNodes(node);
}
void Aidge::GraphView::add(std::set<std::shared_ptr<Node>> otherNodes, bool includeLearnableParam) {
for (auto& nodePtr : otherNodes) { add(nodePtr, includeLearnableParam); }
}
/**
 * Merge every Node of `graph` into this view, then recompute the input and
 * output Node sets.
 */
void Aidge::GraphView::add(std::shared_ptr<GraphView> graph) {
    const auto self = shared_from_this();
    for (const std::shared_ptr<Node> &node_ptr : graph->getNodes()) {
        node_ptr->addView(self);
        mNodes.insert(node_ptr);
        if (!(node_ptr->name()).empty())
            mNodeRegistry.insert(std::make_pair(node_ptr->name(), node_ptr));
    }
    // Full recomputation once, after every Node is in mNodes. It used to run
    // inside the loop, rebuilding both sets once per added Node (quadratic).
    updateInputNodes();
    updateOutputNodes();
}
/**
 * Connect output `fromTensor` of `fromOutNode` to input `toTensor` of
 * `toOtherNode`, then add `toOtherNode` to the view.
 * A null `fromOutNode` means "the single output Node of this view".
 */
void Aidge::GraphView::addChild(std::shared_ptr<Node> toOtherNode,
                                std::shared_ptr<Node> fromOutNode,
                                const Aidge::IOIndex_t fromTensor,
                                Aidge::IOIndex_t toTensor) {
    if (!fromOutNode) {
        assert((outputNodes().size() == 1U) &&
               "Must specify an outputNode or have only one.");
        fromOutNode = *(outputNodes().begin());
    } else {
        assert(inView(fromOutNode) && "Output Node not found in the GraphView.");
    }
    fromOutNode->addChild(toOtherNode, fromTensor, toTensor);
    add(toOtherNode);
}
/**
 * Link an output of this view to an input of `toOtherView`, then merge
 * `toOtherView` into this view. A null Node in either pair selects the unique
 * output Node of this view, or the unique input Node of the other view.
 */
void Aidge::GraphView::addChild(
    std::shared_ptr<GraphView> toOtherView,
    std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t> fromOutNode,
    std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t> toNode) {
    // Source side
    if (fromOutNode.first) {
        assert(inView(fromOutNode.first));
    } else {
        assert(outputNodes().size() == 1U &&
               "If no output node is provided, the graph should have only one to "
               "make the choice explicit.");
        fromOutNode.first = *(outputNodes().begin());
    }
    // Destination side
    if (toNode.first) {
        assert(toOtherView->inView(toNode.first));
    } else {
        assert(toOtherView->inputNodes().size() == 1U &&
               "If no intput node is provided, the other graph should have only "
               "one to make the choice explicit.");
        toNode.first = *(toOtherView->inputNodes().begin());
    }
    // Tensor assertions are performed in the Node addChild method
    fromOutNode.first->addChild(toNode.first, fromOutNode.second, toNode.second);
    // once linked, merge the other view into this one
    add(toOtherView);
}
/**
 * Union of the parents of every input Node of the view.
 * Entries come straight from Node::getParents(), so nullptr (unconnected
 * input) and Nodes inside the view may appear.
 */
std::set<std::shared_ptr<Aidge::Node>> Aidge::GraphView::getParents() const {
    // TODO: choose if we return a set or a vector
    std::set<std::shared_ptr<Node>> parents;
    for (const std::shared_ptr<Node>& inputNode : mInputNodes) {
        // Bind the result once: begin()/end() taken from two separate
        // temporaries (as before) is undefined behaviour.
        const auto inputParents = inputNode->getParents();
        parents.insert(inputParents.begin(), inputParents.end());
    }
    return parents;
}
/// @brief Parents of the Node registered as `nodeName`; unknown names terminate the program.
std::vector<std::shared_ptr<Aidge::Node>> Aidge::GraphView::getParents(const std::string nodeName) const {
    const auto found = mNodeRegistry.find(nodeName);
    if (found != mNodeRegistry.end()) {
        return found->second->getParents();
    }
    printf("No such node a %s in %s graph.\n", nodeName.c_str(), name().c_str());
    exit(-1);
}
/// @brief One Node::getParents() list per input Node, in mInputNodes iteration order.
std::vector<std::vector<std::shared_ptr<Aidge::Node>>>
Aidge::GraphView::getOrderedParents() const {
    std::vector<std::vector<std::shared_ptr<Node>>> parents;
    std::transform(mInputNodes.begin(), mInputNodes.end(), std::back_inserter(parents),
                   [](const std::shared_ptr<Node> &n) { return n->getParents(); });
    return parents;
}
/**
 * Union of the children of every output Node of the view, as reported by
 * Node::getChildren() (children inside the view may appear).
 */
std::set<std::shared_ptr<Aidge::Node>> Aidge::GraphView::getChildren() const {
    std::set<std::shared_ptr<Node>> children;
    for (const std::shared_ptr<Node>& outputNode : mOutputNodes) {
        // Bind the result once: begin()/end() taken from two separate
        // temporaries (as before) is undefined behaviour.
        const auto outChildren = outputNode->getChildren();
        children.insert(outChildren.begin(), outChildren.end());
    }
    return children;
}
/// @brief Node::getOrderedChildren() of the Node named `nodeName`; unknown names terminate the program.
std::vector<std::vector<std::shared_ptr<Aidge::Node>>>
Aidge::GraphView::getChildren(const std::string nodeName) const {
    const auto found = mNodeRegistry.find(nodeName);
    if (found == mNodeRegistry.end()) {
        printf("No such node a %s in %s graph.\n", nodeName.c_str(),
               name().c_str());
        exit(-1);
    }
    return found->second->getOrderedChildren();
}
/// @brief Children of `otherNode`, which must belong to the view (else the program terminates).
std::set<std::shared_ptr<Aidge::Node>>
Aidge::GraphView::getChildren(const std::shared_ptr<Node> otherNode) const {
    if (mNodes.find(otherNode) == mNodes.end()) {
        printf("No such node in graph.\n");
        exit(-1);
    }
    return otherNode->getChildren();
}
/// @brief Node registered as `nodeName`; unknown names terminate the program.
std::shared_ptr<Aidge::Node>
Aidge::GraphView::getNode(const std::string& nodeName) const {
    const auto found = mNodeRegistry.find(nodeName);
    if (found == mNodeRegistry.end()) {
        printf("No Node named %s in the current GraphView.\n", nodeName.c_str());
        exit(-1);
    }
    return found->second;
}
void Aidge::GraphView::remove(std::shared_ptr<Node> nodePtr, bool includeLearnableParam) {
if (mNodes.find(nodePtr) != mNodes.end()) {
mNodes.erase(nodePtr);
nodePtr->removeView(shared_from_this());
}
if (!nodePtr->name().empty()) { mNodeRegistry.erase(nodePtr->name()); }
// same for learnable params
if (includeLearnableParam) {
for (IOIndex_t i = nodePtr->nbDataInputs(); i < nodePtr->nbInputs(); ++i) {
auto inputI = nodePtr->input(i);
bool removeNode = true;
for (const auto& parentOutput : inputI.first->outputs()) {
for (const auto& childOfParentOutput : parentOutput) {
// only remove the learnable parameter if not related to any other Node in the GraphView
if (childOfParentOutput.first != nodePtr) {
removeNode = false;
break;
}
}
}
if (removeNode) {
// assert Learnable Parameter in the GraphView scope
if (mNodes.find(inputI.first) != mNodes.end()) {
mNodes.erase(inputI.first);
inputI.first->removeView(shared_from_this());
}
if (!inputI.first->name().empty()) { mNodeRegistry.erase(inputI.first->name()); }
}
}
}
updateInputNodes();
updateOutputNodes();
}
/// @brief Not supported yet: prints a notice and always returns false.
bool Aidge::GraphView::swap(Node & /*node*/, Node & /*otherNode*/) {
    std::printf("Swap() not implementated yet. Return false.\n");
    return false;
}
/// @brief Placeholder: only reports that the feature is missing.
void Aidge::GraphView::link(std::string /*name1_inID*/,
                            std::string /*name2_outID*/) {
    std::printf("Not implemented yet.\n");
}
/**
 * Insert `newParentNode` in front of input `childInputTensorIdx` of `childNode`.
 *
 * If that input currently has a parent, the parent's output is re-routed into
 * input `newParentInputTensorIdx` of `newParentNode`. In every case, output
 * `newParentOutputTensorIdx` of `newParentNode` then feeds the child, and
 * `newParentNode` is added to the view.
 */
void Aidge::GraphView::insertParent(NodePtr childNode,
                                    NodePtr newParentNode,
                                    IOIndex_t childInputTensorIdx,
                                    IOIndex_t newParentInputTensorIdx,
                                    IOIndex_t newParentOutputTensorIdx){
    NodePtr currentParentNode = childNode->getParent(childInputTensorIdx);
    // An unconnected child input has no parent to re-route (this used to
    // dereference a null pointer).
    if (currentParentNode) {
        const IOIndex_t currentParentOutputTensorIdx = childNode->input(childInputTensorIdx).second;
        // Remove child from current parent & current Parent from child
        currentParentNode->removeChild(childNode, currentParentOutputTensorIdx);
        // Current parent now feeds the new parent
        currentParentNode->addChild(newParentNode, currentParentOutputTensorIdx, newParentInputTensorIdx);
    }
    // New parent feeds the child
    newParentNode->addChild(childNode, newParentOutputTensorIdx, childInputTensorIdx);
    add(newParentNode);
}
// Replace the Nodes of this view by `newNodes` in every GraphView that all of
// the current Nodes belong to. An empty `newNodes` removes the view, reconnecting
// the parent of its input Node's input #0 to the former output consumers.
// Returns true when the replacement was performed, false when the shapes did
// not allow it (nothing is modified in that case).
// NOTE(review): inputNodes()/outputNodes() (and gNew->outputNodes()) are
// dereferenced with *begin() before their sizes are checked; this is UB if
// either set is empty.
bool Aidge::GraphView::replaceWith(std::set<std::shared_ptr<Node>> newNodes) {
// TODO : only supports one input/output node for now
assert(mNodes.size()>0 && "There must be at least one Node to replace");
bool replacable;
std::shared_ptr<Node> previousInputNode = (*inputNodes().begin());
std::shared_ptr<Node> previousOutputNode = (*outputNodes().begin());
std::shared_ptr<Node> newOutputNode;
// Temporary view over the new Nodes, used only to find their output Node.
auto gNew = std::make_shared<GraphView>();
gNew->add(newNodes, false);
if (newNodes.empty()) {
// Removal case: single input/output Nodes, one output, one data input.
replacable = (outputNodes().size() == 1) &&
(inputNodes().size() == 1) &&
((*outputNodes().begin())->nbOutputs() == 1) &&
((*inputNodes().begin())->nbDataInputs() == 1);
// The producer of input #0 takes over the outputs (may be nullptr if unconnected).
newOutputNode = previousInputNode->input(0).first;
} else {
newOutputNode = (*gNew->outputNodes().begin());
replacable = (outputNodes().size() == gNew->outputNodes().size()) &&
(outputNodes().size() == 1) &&
(previousOutputNode->nbOutputs() == newOutputNode->nbOutputs());
}
if (replacable) {
// Snapshot the outgoing links before the old Nodes are disconnected.
auto copyOutputs = previousOutputNode->outputs();
// manage Views for newNodes
// only keep common views to each node for the new set
std::set<std::shared_ptr<GraphView>> commonGraphViews = (*mNodes.begin())->views();
for (const auto& nodePtr : mNodes) {
const auto nodeView = nodePtr->views();
std::set<std::shared_ptr<GraphView>> intersection;
std::set_intersection(commonGraphViews.begin(), commonGraphViews.end(),
nodeView.begin(), nodeView.end(),
std::inserter(intersection, intersection.begin()));
commonGraphViews = intersection;
}
// clean Nodes to replace
// Iterate over a copy: presumably resetConnections(true) may alter mNodes
// (e.g. by detaching views) — TODO confirm against Node::resetConnections.
std::set<std::shared_ptr<Node>> copyNode = mNodes;
for (auto& nodePtr : copyNode) { nodePtr->resetConnections(true); }
// copy output connections
// Skipped when there is no replacement output Node (e.g. removal of a view
// whose input #0 was unconnected).
if (newOutputNode) {
for (IOIndex_t o = 0; o < previousOutputNode->nbOutputs(); ++o) {
auto outputPairs = copyOutputs[o];
for (const auto& onePair : outputPairs) {
newOutputNode->addChild(onePair.first, o, onePair.second);
}
}
}
// insert new Nodes in the right GraphViews
// NOTE(review): the old Nodes are not removed from these views here; confirm
// whether resetConnections(true) takes care of it.
for (auto& graphPtr : commonGraphViews) {
graphPtr->add(newNodes, false);
if (newNodes.empty()) {
// Nothing was added, so refresh input/output sets explicitly.
graphPtr->updateInputNodes();
graphPtr->updateOutputNodes();
}
}
}
return replacable;
}
void Aidge::GraphView::updateInputNodes() {
mInputNodes.clear();
for (const std::shared_ptr<Node>& go_ptr : mNodes) {
for (const std::shared_ptr<Node>& pa_ptr : go_ptr->getParents()) {
if ((pa_ptr == nullptr) ||
(mNodes.find(pa_ptr) ==
mNodes.end())) { // Parent doesn't exist || Parent not in the graph
mInputNodes.insert(go_ptr);
break;
}
}
}
}
void Aidge::GraphView::updateInputNodes(std::shared_ptr<Node> node) {
// add node_ptr to inputNode if it can
std::size_t filledWithKnownInputs = 0U;
bool wasAdded = mInputNodes.find(node) != mInputNodes.end();
for (const std::shared_ptr<Node>& pa_ptr : node->getParents()) {
if ((pa_ptr == nullptr) ||
(mNodes.find(pa_ptr) ==
mNodes.end())) { // Parent doesn't exist || Parent not in the graph
mInputNodes.insert(node);
wasAdded = true;
break;
}
++filledWithKnownInputs;
}
if (filledWithKnownInputs == node->nbInputs() && wasAdded) {
mInputNodes.erase(node);
}
// update other inputNodes
for (const std::shared_ptr<Node>& ch_ptr :
node->getChildren()) { // check if any child is in InputNodes too
if (mInputNodes.find(ch_ptr) !=
mInputNodes.end()) { // it's a match! Must check if the inputNode found
// is still an inputNode
// change here
bool remove = true;
for (const std::shared_ptr<Node>& pa_ptr : ch_ptr->getParents()) {
if (pa_ptr == nullptr ||
mNodes.find(pa_ptr) ==
mNodes
.end()) { // Parent doesn't exist || Parent not in the graph
remove = false;
break;
}
}
if (remove) {
mInputNodes.erase(ch_ptr);
}
}
}
}
/// @brief Unflag the Node registered as `nodeName` as an input Node; unknown names are ignored.
void Aidge::GraphView::removeInputNode(const std::string nodeName) {
    const auto found = mNodeRegistry.find(nodeName);
    if (found == mNodeRegistry.end()) {
        return;
    }
    mInputNodes.erase(found->second);
}
/// @brief Unflag the Node registered as `nodeName` as an output Node; unknown names are ignored.
void Aidge::GraphView::removeOutputNode(const std::string nodeName) {
    const auto found = mNodeRegistry.find(nodeName);
    if (found == mNodeRegistry.end()) {
        return;
    }
    mOutputNodes.erase(found->second);
}
/**
 * Build a new GraphView (same name) whose Nodes are produced by `cloneNode`.
 *
 * `cloneNode` is called once per Node of this view. Returning nullptr drops
 * that Node from the clone. Links between kept Nodes are reproduced. When a
 * parent was dropped, the link is re-attached to the first kept Node found by
 * following input #0 backward; dropped Nodes may have at most one data input.
 * Parents that are unconnected (nullptr) or outside this view are left
 * unconnected in the clone.
 *
 * @param cloneNode Maps an original Node to its clone, or to nullptr to drop it.
 * @return The new GraphView, with its input/output Nodes recomputed.
 */
std::shared_ptr<Aidge::GraphView> Aidge::GraphView::cloneCallback(NodePtr(*cloneNode)(NodePtr)) const {
    std::shared_ptr<GraphView> newGraph = std::make_shared<GraphView>(mName);
    // Map for old node -> new node correspondance
    std::map<NodePtr, NodePtr> oldToNewNodes;
    for (const std::shared_ptr<Node> &node_ptr : mNodes) {
        oldToNewNodes[node_ptr] = cloneNode(node_ptr);
    }
    // Lookups below use find(): operator[] used to insert nullptr entries for
    // null/outside parents into the map being iterated. Those entries were
    // then mistaken for deleted Nodes (and a null parent was dereferenced).
    for (const auto &oldToNewNode : oldToNewNodes) {
        if (oldToNewNode.second == nullptr)
            continue; // deleted node
        // Add new node to new GraphView
        newGraph->add(oldToNewNode.second, false);
        const auto parents = oldToNewNode.first->inputs();
        for (std::size_t parentId = 0; parentId < parents.size(); ++parentId) {
            auto parent = parents[parentId];
            if (parent.first == nullptr) {
                continue; // unconnected input
            }
            auto parentIt = oldToNewNodes.find(parent.first);
            if (parentIt == oldToNewNodes.end()) {
                continue; // parent outside this GraphView
            }
            // Parent was deleted by cloneNode(): walk backward through input #0
            // to the next kept Node inside the GraphView.
            while (parentIt->second == nullptr) {
                assert(parent.first->nbDataInputs() <= 1 && "deleted nodes in GraphView::clone() cannot have multiple data inputs");
                const auto grandParents = parent.first->inputs();
                if (grandParents.empty() || grandParents[0].first == nullptr) {
                    break;
                }
                const auto nextIt = oldToNewNodes.find(grandParents[0].first);
                if (nextIt == oldToNewNodes.end()) {
                    break; // would leave the GraphView
                }
                parent = grandParents[0];
                parentIt = nextIt;
            }
            if (parentIt->second) {
                parentIt->second->addChild(oldToNewNode.second, parent.second, parentId);
            }
        }
    }
    // Update OutputNodes/inputNodes
    newGraph->updateInputNodes();
    newGraph->updateOutputNodes();
    return newGraph;
}