Newer
Older
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <algorithm>
#include <cassert>
#include <iterator>
#include <utility>
#include <numeric>
#include "aidge/utils/Types.h"
#include "aidge/graph/GraphView.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/utils/ErrorHandling.hpp"
///////////////////////////////////////////////////////
// FUNCTIONAL DESCRIPTION
///////////////////////////////////////////////////////
/**
 * @brief Connect the given Connectors to the (single) input Node of this GraphView.
 *
 * The k-th Connector of @p ctors is linked to input k of the GraphView input
 * Node, using the Connector's own output index on the producing side.
 *
 * @param ctors One Connector per data input of the input Node; each must be
 *              associated with a Node.
 * @return A Connector on the first output Node of this GraphView.
 */
Aidge::Connector Aidge::GraphView::operator()(
    const std::vector<Aidge::Connector> ctors) {
  // TODO: allow for multiple inputNodes?
  assert((inputNodes().size() == 1U) && "Too many input Nodes for the GraphView, undefined behaviour");
  std::shared_ptr<Node> inNode = *inputNodes().begin();
  assert((ctors.size() == static_cast<std::size_t>(inNode->nbData())) && "Wrong number of arguments.\n");
  for (const auto &input : inNode->inputs()) {
    assert((gk_IODefaultIndex == input.second) && "At least one input connection is not free.\n");
    (void)input; // only read by assert(): avoid unused-variable warnings under NDEBUG
  }

  // Link each Connector, in order, to the next input of inNode.
  IOIndex_t inID = 0;
  for (const Connector &ctor : ctors) {
    assert((ctor.node() != nullptr) &&
           "Input Connector must be associated with a node");
    ctor.node()->addChild(shared_from_this(), static_cast<std::size_t>(ctor.index()),
                          {inNode, inID++});
  }
  return Connector(*(outputNodes().begin()));
}
///////////////////////////////////////////////////////
// INNER
///////////////////////////////////////////////////////
std::string Aidge::GraphView::name() const { return mName; }
/// @brief Rename this GraphView; only mName is updated.
void Aidge::GraphView::setName(const std::string &name) {
    mName.assign(name);
}
void Aidge::GraphView::save(std::string path, bool verbose) const {
FILE *fp = std::fopen((path + ".mmd").c_str(), "w");
std::fprintf(fp,
"%%%%{init: {'flowchart': { 'curve': 'monotoneY'}, "
"'fontFamily': 'Verdana' } }%%%%\nflowchart TB\n\n");
std::map<const std::string, std::size_t> typeCounter;
std::map<std::shared_ptr<Node>, std::string> namePtrTable;
// Start by creating every node
for (const std::shared_ptr<Node> &node_ptr : mNodes) {
const std::string currentType = node_ptr->type();
if (typeCounter.find(currentType) == typeCounter.end())
typeCounter[currentType] = 0;
++typeCounter[currentType];
? "<em>" + currentType + "#" + std::to_string(typeCounter[currentType]) + "</em>"
: "\"" + node_ptr->name() + "\\n<sub><em>( " + currentType + " )</em></sub>\"";
namePtrTable[node_ptr] =
(currentType + "_" + std::to_string(typeCounter[currentType]));
if (node_ptr == mRootNode) {
std::fprintf(fp, "%s(%s):::rootCls\n", namePtrTable[node_ptr].c_str(),
givenName.c_str());
}
else {
std::fprintf(fp, "%s(%s)\n", namePtrTable[node_ptr].c_str(),
givenName.c_str());
}
// Write every link
for (const std::shared_ptr<Node> &node_ptr : mNodes) {
IOIndex_t outputIdx = 0;
for (auto childs : node_ptr->getOrderedChildren()) {
for (auto child : childs) {
for (auto parent : child->inputs()) {
if (parent.first == node_ptr && parent.second == outputIdx) {
if (mNodes.find(child) != mNodes.end()) {
std::fprintf(fp, "%s-->|%u→%u|%s\n", namePtrTable[node_ptr].c_str(),
outputIdx, inputIdx, namePtrTable[child].c_str());
}
else if (verbose) {
std::fprintf(fp, "%s-->|%u→%u|%p:::externalCls\n", namePtrTable[node_ptr].c_str(),
outputIdx, inputIdx, static_cast<void*>(child.get()));
}
++outputIdx;
}
}
size_t inputIdx = 0;
for (auto input : mInputNodes) {
std::fprintf(fp, "input%lu((in#%lu)):::inputCls--->|→%u|%s\n", inputIdx, inputIdx,
input.second, namePtrTable[input.first].c_str());
++inputIdx;
size_t outputIdx = 0;
for (auto output : mOutputNodes) {
std::fprintf(fp, "%s--->|%u→|output%lu((out#%lu)):::outputCls\n",
namePtrTable[output.first].c_str(), output.second,
outputIdx, outputIdx);
++outputIdx;
}
std::fprintf(fp, "classDef inputCls fill:#afa\n");
std::fprintf(fp, "classDef outputCls fill:#ffa\n");
std::fprintf(fp, "classDef externalCls fill:#ccc\n");
std::fprintf(fp, "classDef rootCls stroke:#f00\n");
}
std::fprintf(fp, "\n");
std::fclose(fp);
}
///////////////////////////////////////////////////////
// TENSOR MANAGEMENT
///////////////////////////////////////////////////////
void Aidge::GraphView::setOrderedInputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& inputs) {
AIDGE_ASSERT(inputs.size() <= mInputNodes.size(), "too many specified number of inputs");
std::vector<std::pair<NodePtr, IOIndex_t>> ignoredInputs(mInputNodes);
for (auto input : inputs) {
auto it = std::find(ignoredInputs.begin(), ignoredInputs.end(), input);
AIDGE_ASSERT(it != ignoredInputs.end(), "unknown or duplicate input");
ignoredInputs.erase(it);
}
mInputNodes = inputs;
mInputNodes.insert(mInputNodes.end(), ignoredInputs.begin(), ignoredInputs.end());
}
void Aidge::GraphView::setOrderedOutputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& outputs) {
AIDGE_ASSERT(outputs.size() <= mOutputNodes.size(), "too many specified number of outputs");
std::vector<std::pair<NodePtr, IOIndex_t>> ignoredOutputs(mOutputNodes);
for (auto output : outputs) {
auto it = std::find(ignoredOutputs.begin(), ignoredOutputs.end(), output);
AIDGE_ASSERT(it != ignoredOutputs.end(), "unknown or duplicate output");
ignoredOutputs.erase(it);
}
mOutputNodes = outputs;
mOutputNodes.insert(mOutputNodes.end(), ignoredOutputs.begin(), ignoredOutputs.end());
}
/**
 * @brief Number of data inputs of the GraphView, seen from outside.
 *
 * Counts the data inputs of the input Nodes that are either unconnected or
 * connected to a Node that is not part of this GraphView.
 */
Aidge::IOIndex_t Aidge::GraphView::getNbDataInputs() const {
  IOIndex_t nbDataInput = 0;
  for (const std::shared_ptr<Node> &inNode : inputNodes()) {
    // We cannot simply add inNode->nbDataInputs(), as input nodes may already
    // have some inputs connected within the GraphView, which would therefore not
    // constitute inputs (from outside) for the GraphView!
    const std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs =
        inNode->dataInputs();
    for (const auto& input : inputNodeinputs) {
      if (input.first == nullptr || mNodes.find(input.first) == mNodes.end()) {
        ++nbDataInput;
      }
    }
  }
  return nbDataInput;
}
/**
 * @brief Total number of free (unconnected) data inputs of the GraphView input Nodes.
 */
Aidge::IOIndex_t Aidge::GraphView::getNbFreeDataInputs() const {
  IOIndex_t nbIn = 0;
  // Free inputs within the GraphView are logically also free inputs from outside
  // the GraphView.
  for (const std::shared_ptr<Node>& inputNode : inputNodes()) {
    nbIn += inputNode->getNbFreeDataInputs();
  }
  return nbIn;
}
/**
 * @brief Data input connections of the GraphView, seen from outside.
 *
 * @return The (parent, parent output index) pairs of the input Nodes' data
 *         inputs whose parent is null or not part of this GraphView.
 */
std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>
Aidge::GraphView::dataInputs() const {
  std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res;
  for (const std::shared_ptr<Node>& inputNode : inputNodes()) {
    const std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs =
        inputNode->dataInputs();
    for (const auto& input : inputNodeinputs) {
      // Connections coming from inside the GraphView are not GraphView inputs.
      if (input.first == nullptr || mNodes.find(input.first) == mNodes.end()) {
        res.push_back(input);
      }
    }
  }
  return res;
}
/**
 * @brief Input connections (data and non-data) of the GraphView, seen from outside.
 *
 * @return The (parent, parent output index) pairs of the input Nodes' inputs
 *         whose parent is null or not part of this GraphView.
 */
std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>
Aidge::GraphView::inputs() const {
  std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res;
  for (const std::shared_ptr<Node>& inputNode : inputNodes()) {
    const std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs =
        inputNode->inputs();
    for (const auto& input : inputNodeinputs) {
      // Connections coming from inside the GraphView are not GraphView inputs.
      if (input.first == nullptr || mNodes.find(input.first) == mNodes.end()) {
        res.push_back(input);
      }
    }
  }
  return res;
}
/**
 * @brief Inputs of the Node registered under @p name.
 *
 * The lookup uses std::map::at, which throws std::out_of_range for an unknown name.
 */
std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>
Aidge::GraphView::inputs(std::string name) const {
  const std::shared_ptr<Node>& registered = mNodeRegistry.at(name);
  return registered->inputs();
}

Maxence Naud
committed
void Aidge::GraphView::compile(const std::string& backend, const Aidge::DataType datatype) {
// Backend
// TODO: add Backend attribute to Operator
setBackend(backend);
// Data type
// TODO: manage Datatype attribute in OperatorImpl

Maxence Naud
committed
// Data Format
// TODO: check actual parent output data format and the needed one. Add a Transpose Operator if necessary
// Forward dimensions
forwardDims();
}
/**
 * @brief Link every Node input to its parent output Tensor, then propagate dimensions.
 *
 * For each connected input, the operator input is associated with the
 * parent's output when they are not already the same object. Unconnected
 * inputs are expected to already hold a non-empty Tensor. Dimension
 * propagation then starts from the GraphView input Nodes.
 */
void Aidge::GraphView::forwardDims() {
    // setInputs
    // Link every tensor to the right pointer
    // following parent - children informations
    for (const std::shared_ptr<Node>& nodePtr : getNodes()) {
        for (IOIndex_t i = 0; i < nodePtr->nbInputs(); ++i) {
            // assess if the input was not already set and is a Tensor then link it to parent output
            std::pair<std::shared_ptr<Node>, IOIndex_t> inputI = nodePtr->input(i);
            if (inputI.first) {
                if ( std::static_pointer_cast<Tensor>(nodePtr->getOperator()->getRawInput(i)) != inputI.first->getOperator()->getRawOutput(inputI.second)) {
                    if (nodePtr->getOperator()->operatorType() == OperatorType::Tensor) {
                        // assert provided Data is of "Tensor" type
                        nodePtr->getOperator()->associateInput(i, inputI.first->getOperator()->getRawOutput(inputI.second));
                    }
                    else {
                        assert(false && "Non-tensor entries not handled yet.\n");
                    }
                }
            } else {
                // Unconnected input: its Tensor must have been provided beforehand.
                // Check for null first: dereferencing a missing input would be undefined behaviour.
                const auto inputTensor = std::static_pointer_cast<Tensor>(nodePtr->getOperator()->getRawInput(i));
                assert(inputTensor && !inputTensor->empty() && "Unconnected input must hold a non-empty Tensor");
                (void)inputTensor; // only read by assert()
            }
        }
    }
    // Compute dimensions of every node
    _forwardDims(inputNodes());
}
void Aidge::GraphView::_forwardDims(std::set<std::shared_ptr<Node>> listNodes) {
// TODO: support multi-inputs/outputs
std::set<std::shared_ptr<Node>> nextList = std::set<std::shared_ptr<Node>>();
for (std::shared_ptr<Node> nodePtr : listNodes) {
if (nodePtr->getOperator()->operatorType() == OperatorType::Tensor) {
const auto op = std::static_pointer_cast<OperatorTensor>(nodePtr->getOperator());
if (!op->outputDimsForwarded()) {
op->computeOutputDims();
}
if (!op->outputDimsForwarded()) { // try to compute output dimensions again later
nextList.insert(nodePtr);
} else { // compute output dimensions of children
std::set<std::shared_ptr<Node>> children = nodePtr->getChildren();
nextList.insert(children.begin(), children.end());
}
}
if (nextList.empty()) {
for (std::shared_ptr<Node> nodePtr : getNodes()) {
if (nodePtr->getOperator()->operatorType() == OperatorType::Tensor) {
if (!std::static_pointer_cast<OperatorTensor>(nodePtr->getOperator())->outputDimsForwarded()) {
nextList.insert(nodePtr);
}
}
}
}
}
/// @brief Set the backend of every Node operator in this GraphView.
void Aidge::GraphView::setBackend(const std::string &backend) {
    for (const auto& node : getNodes()) {
        node->getOperator()->setBackend(backend);
    }
}
/// @brief Set the data type of every Node operator in this GraphView.
void Aidge::GraphView::setDataType(const Aidge::DataType &datatype) {
    const auto& nodes = getNodes();
    std::for_each(nodes.begin(), nodes.end(), [&datatype](const std::shared_ptr<Node>& n) {
        n->getOperator()->setDataType(datatype);
    });
}
/**
 * @brief Output connections of the GraphView, seen from outside.
 *
 * @return One entry per output of each GraphView output Node. Each entry lists
 *         the (child, child input index) pairs at that output whose child is
 *         not part of this GraphView (the list may be empty).
 */
std::vector<
    std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>>
Aidge::GraphView::outputs() const {
  std::vector<std::vector<std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t>>>
      outsideOutputs;
  for (const std::shared_ptr<Node>& outputNode : outputNodes()) {
    const std::vector<std::vector<std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t>>>
        outputNodeOutputs = outputNode->outputs();
    for (const auto& outputPos : outputNodeOutputs) {
      // Keep only the nodes connected at this output position that are outside the GraphView
      std::vector<std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t>> outsideOutputPos;
      for (const auto& output : outputPos) {
        if (mNodes.find(output.first) == mNodes.end()) {
          outsideOutputPos.push_back(output);
        }
      }
      outsideOutputs.push_back(std::move(outsideOutputPos));
    }
  }
  return outsideOutputs;
}
/**
 * @brief Outputs of the Node registered under @p nodeName.
 *
 * The lookup uses std::map::at, which throws std::out_of_range for an unknown name.
 */
std::vector<
    std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>>
Aidge::GraphView::outputs(std::string nodeName) const {
  const std::shared_ptr<Node>& registered = mNodeRegistry.at(nodeName);
  return registered->outputs();
}
/// @brief Placeholder: changing a GraphView input id is not supported; only prints a notice.
void Aidge::GraphView::setInputId(Aidge::IOIndex_t /*inID*/, Aidge::IOIndex_t /*newNodeOutID*/) {
    const char *const notice = "Not implemented yet.\n";
    printf("%s", notice);
}
void Aidge::GraphView::add(std::shared_ptr<Node> node, bool includeLearnableParam) {
// first node to be added to the graph is the root node by default
if (mRootNode == nullptr) {
mRootNode = node;
}
// add to the GraphView nodes
node->addView(shared_from_this());
mNodes.insert(node);
if (!(node->name()).empty())
mNodeRegistry.insert(std::make_pair(node->name(), node));
// check if the node is an input/output node
updateInputsOutputsNew(node);
// add learnable parameters to the graph
if (includeLearnableParam) {
for (IOIndex_t i = node->nbData(); i < node->nbInputs(); ++i) {
std::shared_ptr<Node> parentNode = node->getParent(static_cast<IOIndex_t>(i));
if (parentNode) {
parentNode->addView(shared_from_this());
mNodes.insert(parentNode);
if (!(parentNode->name()).empty())
mNodeRegistry.insert(std::make_pair(parentNode->name(), parentNode));
// check if the parentNode is an input/output node
updateInputsOutputsNew(parentNode);
/**
 * @brief Add a set of Nodes, ranking them so that the insertion order is deterministic when possible.
 *
 * Nodes already in the GraphView are skipped. If no root Node is set yet, one
 * is picked among the Nodes that have no parent inside (current Nodes + new
 * Nodes). The Nodes are then ranked by walking outward from the root Node,
 * using rankedNodes as a FIFO work list over children and parents. They are
 * added in that order with add(node, includeLearnableParam). Nodes not reached
 * from the root are appended at the end.
 *
 * @param otherNodes            Nodes to add.
 * @param includeLearnableParam Forwarded to add(node, includeLearnableParam).
 * @return false if an arbitrary choice had to be made (several root candidates,
 *         or Nodes not reachable from the root), true otherwise.
 *
 * NOTE(review): this copy of the function is incomplete. Several closing braces
 * are missing, and a run of bare numbers (diff-viewer residue) stands where
 * code used to be. Restore it from version control before compiling.
 */
bool Aidge::GraphView::add(std::set<std::shared_ptr<Node>> otherNodes, bool includeLearnableParam) {
if (otherNodes.empty()) {
return true;
}
bool orderUnicity = true;
// List only the nodes that are not already present in current graph
std::set<NodePtr> nodesToAdd;
// std::set_difference needs sorted ranges: both operands are std::set of shared_ptr<Node>.
std::set_difference(otherNodes.begin(), otherNodes.end(), mNodes.begin(), mNodes.end(), std::inserter(nodesToAdd, nodesToAdd.begin()));
// List the nodes to rank, initially all the nodes in the GraphView
std::set<NodePtr> nodesToRank(mNodes);
nodesToRank.insert(nodesToAdd.begin(), nodesToAdd.end());
std::vector<NodePtr> rankedNodesToAdd;
if (mRootNode == nullptr) {
std::set<NodePtr> noParentNodes;
// If no root node is defined, check nodes without parents
for (auto node : nodesToRank) {
bool noParent = true;
for (auto parent : node->getParents()) {
if (parent != nullptr && nodesToRank.find(parent) != nodesToRank.end()) {
noParent = false;
if (noParent) {
noParentNodes.insert(node);
// NOTE(review): the bare numbers below are diff-viewer residue, not code.
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
// Take the first one found (this is an arbitrary choice)
// NOTE(review): if every Node has a parent among nodesToRank (e.g. a cycle),
// noParentNodes is empty and this dereference is undefined behaviour.
mRootNode = *noParentNodes.begin();
if (noParentNodes.size() > 1) {
// If there is more than one, order unicity cannot be garanteed!
orderUnicity = false;
}
rankedNodesToAdd.push_back(mRootNode);
}
nodesToRank.erase(mRootNode);
std::vector<NodePtr> rankedNodes;
rankedNodes.push_back(mRootNode);
// rankedNodes grows while being walked; indexing (not iterators) keeps this safe across push_back.
for (size_t curNodeIdx = 0; curNodeIdx < rankedNodes.size(); ++curNodeIdx) {
NodePtr curNode = rankedNodes[curNodeIdx];
for (auto childs : curNode->getOrderedChildren()) {
for (auto child : childs) {
if (nodesToRank.find(child) != nodesToRank.end()) {
rankedNodes.push_back(child);
nodesToRank.erase(child);
if (nodesToAdd.find(child) != nodesToAdd.end()) {
rankedNodesToAdd.push_back(child);
nodesToAdd.erase(child);
for (auto parent : curNode->getParents()) {
if (nodesToRank.find(parent) != nodesToRank.end()) {
rankedNodes.push_back(parent);
nodesToRank.erase(parent);
if (nodesToAdd.find(parent) != nodesToAdd.end()) {
rankedNodesToAdd.push_back(parent);
nodesToAdd.erase(parent);
}
}
if (!nodesToAdd.empty()) {
// There are remaining nodes without path to the root node
orderUnicity = false;
while (!nodesToAdd.empty()) {
const auto it = nodesToAdd.begin();
rankedNodesToAdd.push_back(*it);
nodesToAdd.erase(it);
// Insert in ranked order; each add() also updates the GraphView inputs/outputs.
for (auto node_ptr : rankedNodesToAdd) {
add(node_ptr, includeLearnableParam);
}
return orderUnicity;
bool Aidge::GraphView::add(std::pair<NodePtr, std::set<NodePtr>> nodes, bool includeLearnableParam) {
mRootNode = nodes.first;
add(nodes.first, includeLearnableParam);
}
return add(nodes.second, includeLearnableParam);
/**
 * @brief Add every Node of @p graph to this GraphView.
 *
 * If this GraphView has no root Node yet, the root Node of @p graph becomes
 * its root. The Nodes are added through add(std::set, bool) with
 * includeLearnableParam = false.
 *
 * @return The value returned by add(graph->getNodes(), false).
 *
 * NOTE(review): the bare numbers below are diff-viewer residue. The code that
 * stood there, and this function's own closing brace, appear to be missing.
 * Restore from version control before compiling.
 */
bool Aidge::GraphView::add(std::shared_ptr<GraphView> graph) {
if (mRootNode == nullptr) {
mRootNode = graph->getRootNode();
}
return add(graph->getNodes(), false);
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
}
/**
 * @brief Connect an output of a GraphView Node to @p toOtherNode, then add @p toOtherNode.
 *
 * @param toOtherNode Node receiving the connection; it is added to this GraphView.
 * @param fromOutNode Source Node in this GraphView. If null, the GraphView
 *                    must have exactly one output Node, which is used instead.
 * @param fromTensor  Output index on the source Node.
 * @param toTensor    Input index on @p toOtherNode.
 */
void Aidge::GraphView::addChild(std::shared_ptr<Node> toOtherNode,
                               std::shared_ptr<Node> fromOutNode,
                               const Aidge::IOIndex_t fromTensor,
                               Aidge::IOIndex_t toTensor) {
  if (!fromOutNode) {
    assert((outputNodes().size() == 1U) &&
           "Must specify an outputNode or have only one.");
    fromOutNode = *(outputNodes().begin());
  } else {
    assert(inView(fromOutNode) && "Output Node not found in the GraphView.");
  }

  fromOutNode->addChild(toOtherNode, fromTensor, toTensor);
  add(toOtherNode);
}
/**
 * @brief Link an output of this GraphView to an input of @p toOtherView, then merge @p toOtherView into this GraphView.
 *
 * @param toOtherView GraphView to connect and add.
 * @param fromOutNode (Node, output index) in this GraphView. A null Node means
 *                    "the single output Node of this GraphView".
 * @param toNode      (Node, input index) in @p toOtherView. A null Node means
 *                    "the single input Node of @p toOtherView".
 */
void Aidge::GraphView::addChild(
    std::shared_ptr<GraphView> toOtherView,
    std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t> fromOutNode,
    std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t> toNode) {
  // Resolve the source Node on this side
  if (fromOutNode.first) {
    assert(inView(fromOutNode.first));
  } else {
    assert(outputNodes().size() == 1U &&
           "If no output node is provided, the graph should have only one to "
           "make the choice explicit.");
    fromOutNode.first = *(outputNodes().begin());
  }

  // Resolve the destination Node in the other GraphView
  if (toNode.first) {
    assert(toOtherView->inView(toNode.first));
  } else {
    assert(toOtherView->inputNodes().size() == 1U &&
           "If no intput node is provided, the other graph should have only "
           "one to make the choice explicit.");
    toNode.first = *(toOtherView->inputNodes().begin());
  }

  // Tensor compatibility checks are left to Node::addChild.
  fromOutNode.first->addChild(toNode.first, fromOutNode.second, toNode.second);
  // Once linked, merge the other GraphView into this one.
  add(toOtherView);
}
/**
 * @brief Union of the parents of every GraphView input Node.
 */
std::set<std::shared_ptr<Aidge::Node>> Aidge::GraphView::getParents() const {
  // TODO: choose if we return a set or a vector
  std::set<std::shared_ptr<Node>> parents;
  for (const std::shared_ptr<Node>& inputNode : inputNodes()) {
    // Take begin() and end() from the SAME container. Calling getParents() twice
    // would give iterators into two distinct temporaries if it returns by value.
    const auto& inputParents = inputNode->getParents();
    parents.insert(inputParents.begin(), inputParents.end());
  }
  return parents;
}
/**
 * @brief Parents of the Node registered under @p nodeName.
 *
 * An unknown name prints an error message and terminates the process with exit(-1).
 */
std::vector<std::shared_ptr<Aidge::Node>> Aidge::GraphView::getParents(const std::string nodeName) const {
  const auto found = mNodeRegistry.find(nodeName);
  if (found != mNodeRegistry.cend()) {
    return found->second->getParents();
  }
  printf("No such node a %s in %s graph.\n", nodeName.c_str(), name().c_str());
  exit(-1);
}
/**
 * @brief Parents of each GraphView input Node, one vector per input Node
 *        (in inputNodes() iteration order).
 */
std::vector<std::vector<std::shared_ptr<Aidge::Node>>>
Aidge::GraphView::getOrderedParents() const {
  const auto& inNodes = inputNodes();
  std::vector<std::vector<std::shared_ptr<Node>>> parents;
  std::transform(inNodes.begin(), inNodes.end(), std::back_inserter(parents),
                 [](const std::shared_ptr<Node>& inputNode) { return inputNode->getParents(); });
  return parents;
}
/**
 * @brief Union of the children of every GraphView output Node.
 *
 * NOTE(review): begin() and end() are taken from two separate calls to
 * outputNode->getChildren(). Node::getChildren() result is returned as a
 * std::set by value elsewhere in this file. If it returns by value, the two
 * iterators belong to different temporaries and the insert() is undefined
 * behaviour. Store the result in a local first.
 * NOTE(review): the bare numbers inside the loop are diff-viewer residue, not code.
 */
std::set<std::shared_ptr<Aidge::Node>> Aidge::GraphView::getChildren() const {
std::set<std::shared_ptr<Node>> children;
for (const std::shared_ptr<Node>& outputNode : outputNodes()) {
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
children.insert((outputNode->getChildren()).begin(),
(outputNode->getChildren()).end());
}
return children;
}
/**
 * @brief Ordered children (one vector per output) of the Node registered under @p nodeName.
 *
 * An unknown name prints an error message and terminates the process with exit(-1).
 */
std::vector<std::vector<std::shared_ptr<Aidge::Node>>>
Aidge::GraphView::getChildren(const std::string nodeName) const {
  const auto found = mNodeRegistry.find(nodeName);
  if (found == mNodeRegistry.cend()) {
    printf("No such node a %s in %s graph.\n", nodeName.c_str(),
           name().c_str());
    exit(-1);
  }
  const std::shared_ptr<Node>& registered = found->second;
  return registered->getOrderedChildren();
}
/**
 * @brief Children of @p otherNode, which must belong to this GraphView.
 *
 * A Node outside the GraphView prints an error message and terminates the
 * process with exit(-1).
 */
std::set<std::shared_ptr<Aidge::Node>>
Aidge::GraphView::getChildren(const std::shared_ptr<Node> otherNode) const {
  const auto found = mNodes.find(otherNode);
  if (found != mNodes.cend()) {
    return (*found)->getChildren();
  }
  printf("No such node in graph.\n");
  exit(-1);
}
std::shared_ptr<Aidge::Node>
Aidge::GraphView::getNode(const std::string& nodeName) const {
std::map<std::string, std::shared_ptr<Node>>::const_iterator it =
if (it != mNodeRegistry.end()) {
return it->second;
} else {
printf("No Node named %s in the current GraphView.\n", nodeName.c_str());
exit(-1);
}
}
void Aidge::GraphView::remove(std::shared_ptr<Node> nodePtr, bool includeLearnableParam) {
for (IOIndex_t i = nodePtr->nbData(); i < nodePtr->nbInputs(); ++i) {
if (inputI.first != nullptr) {
bool removeNode = true;
for (const auto& parentOutput : inputI.first->outputs()) {
for (const auto& childOfParentOutput : parentOutput) {
// only remove the learnable parameter if not related to any other Node in the GraphView
if (childOfParentOutput.first != nodePtr) {
removeNode = false;
break;
}
if (removeNode) {
// assert Learnable Parameter in the GraphView scope
if (mNodes.find(inputI.first) != mNodes.end()) {
mNodes.erase(inputI.first);
inputI.first->removeView(shared_from_this());
}
if (!inputI.first->name().empty()) { mNodeRegistry.erase(inputI.first->name()); }
// check if the node was an input/output node
updateInputsOutputsDelete(inputI.first);
}
if (mNodes.find(nodePtr) != mNodes.end()) {
mNodes.erase(nodePtr);
nodePtr->removeView(shared_from_this());
// check if the nodePtr was an input/output node
updateInputsOutputsDelete(nodePtr);
}
if (!nodePtr->name().empty()) { mNodeRegistry.erase(nodePtr->name()); }
}
/// @brief Placeholder: swapping two Nodes is not implemented; prints a notice and returns false.
bool Aidge::GraphView::swap(Node & /*node*/, Node & /*otherNode*/) {
  const bool swapped = false;
  printf("Swap() not implementated yet. Return false.\n");
  return swapped;
}
/// @brief Placeholder: linking by name is not implemented; only prints a notice.
void Aidge::GraphView::link(std::string /*name1_inID*/, std::string /*name2_outID*/) {
  const char *const notice = "Not implemented yet.\n";
  printf("%s", notice);
}
/**
 * @brief Insert @p newParentNode on input @p childInputTensorIdx of @p childNode, then add it to this GraphView.
 *
 * If that input already had a parent, the parent is disconnected from
 * @p childNode and reconnected to @p newParentNode instead.
 *
 * @param childNode                Node whose input is intercepted.
 * @param newParentNode            Node inserted in between.
 * @param childInputTensorIdx      Input index on @p childNode.
 * @param newParentInputTensorIdx  Input of @p newParentNode fed by the former parent.
 * @param newParentOutputTensorIdx Output of @p newParentNode feeding @p childNode.
 */
void Aidge::GraphView::insertParent(NodePtr childNode,
                  NodePtr newParentNode,
                  IOIndex_t childInputTensorIdx,
                  IOIndex_t newParentInputTensorIdx,
                  IOIndex_t newParentOutputTensorIdx){
  NodePtr currentParentNode = childNode->getParent(childInputTensorIdx);
  // An unconnected input has no parent to rewire (dereferencing it would crash).
  if (currentParentNode != nullptr) {
    const IOIndex_t currentParentOutputTensorIdx = childNode->input(childInputTensorIdx).second;
    // Remove child from current parent & current Parent from child
    currentParentNode->removeChild(childNode, currentParentOutputTensorIdx);
    // The former parent now feeds the inserted Node
    currentParentNode->addChild(newParentNode, currentParentOutputTensorIdx, newParentInputTensorIdx);
  }
  newParentNode->addChild(childNode, newParentOutputTensorIdx, childInputTensorIdx);
  add(newParentNode);
}
/**
 * @brief Replace @p oldNodes by @p newNodes in every GraphView that contains all of @p oldNodes.
 *
 * Procedure:
 *  - temporary GraphViews "oldG" and "newG" are built from each set, to get
 *    their ordered inputs/outputs;
 *  - the parents of oldG inputs and the outside children of oldG outputs are
 *    recorded;
 *  - @p oldNodes are removed from the common GraphViews;
 *  - the recorded connections are re-created on the newG inputs/outputs.
 *
 * Handled configurations (labelled in the code):
 *  - Case 1: same number of inputs and same number of outputs;
 *  - Case 2: one side has a single input and the output counts match;
 *  - Case 3: @p newNodes is empty, so parents are connected directly to children.
 *
 * @return true on success; false for unsupported configurations. In that case
 *         the temporary views are detached from the Nodes.
 *
 * NOTE(review): this copy is incomplete. Closing braces are missing and a run of
 * bare numbers (diff-viewer residue) replaces code in the middle. Restore it
 * from version control before compiling.
 * NOTE(review): as far as visible, the final unsupported-case `return false`
 * is reached after oldNodes were already removed from the common GraphViews
 * and had their connections reset. Confirm this is intended.
 */
bool Aidge::GraphView::replace(const std::set<Aidge::NodePtr>& oldNodes, const std::set<Aidge::NodePtr>& newNodes) {
// TODO: handle case where an oldNodes parameter does not come from a Producer but another Node (not included in oldNodes)
// How to distinguish it from data input?
// TODO: Parameter Tensors could be identified with their dimensions
// TODO: Take GraphView as input parameters since new Nodes should be connected whatever.
// It also avoids specifying each producer since they are automatically included
// (1) create GraphViews from both sets of Nodes
auto oldG = std::make_shared<GraphView>("oldG");
oldG->add(oldNodes, false);
auto newG = std::make_shared<GraphView>("newG");
newG->add(newNodes, false);
const auto oldOI = oldG->getOrderedInputs();
const auto oldOO = oldG->getOrderedOutputs();
const auto newOI = newG->getOrderedInputs();
const auto newOO = newG->getOrderedOutputs();
std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputParents = std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(oldOI.size());
std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> outputChildren = std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(oldOO.size());
// keep in memory every parent
for (std::size_t i = 0; i < oldOI.size(); ++i) {
auto inputParent = oldOI[i].first -> input(oldOI[i].second);
inputParents[i]= inputParent;
// inputParent.first -> addChild(newOI[i].first, inputParent.second, newOI[i].second);
// Record, for each oldG output, its child outside oldNodes (or {nullptr, gk_IODefaultIndex} if none).
// NOTE(review): i advances once per outside child. An output feeding several
// Nodes outside oldNodes would push i past oldOO.size() (out-of-bounds write
// to outputChildren). An output whose children all lie inside oldNodes never
// advances i (infinite loop).
for (std::size_t i = 0; i < oldOO.size();) {
auto outputChildList = oldOO[i].first -> output(oldOO[i].second);
if (outputChildList.empty()) {
outputChildren[i] = std::pair<std::shared_ptr<Node>, IOIndex_t>({nullptr, gk_IODefaultIndex});
++i;
else {
for (const auto& child : outputChildList) {
if (oldNodes.find(child.first) == oldNodes.cend()) {
outputChildren[i] = child;
++i;
}
}
}
}
// only keep common views to each node for the new set
// set of common GraphView for oldNodes' Nodes
// NOTE(review): undefined behaviour if oldNodes is empty.
std::set<std::shared_ptr<GraphView>> commonGraphViews = (*oldNodes.begin())->views();
for (const auto& nodePtr : oldNodes) {
const auto nodeView = nodePtr->views();
std::set<std::shared_ptr<GraphView>> intersection;
std::set_intersection(commonGraphViews.begin(), commonGraphViews.end(),
nodeView.begin(), nodeView.end(),
std::inserter(intersection, intersection.begin()));
commonGraphViews = intersection;
// The temporary views must not be modified as if they were user GraphViews.
commonGraphViews.erase(oldG);
commonGraphViews.erase(newG);
// Early exit for an unsupported configuration: detach the temporary views and give up.
if ((newNodes.size() > 0) && (oldOI.size() != newOI.size()) && (oldOO.size() != newOO.size())) {
for (const auto& nodePtr : oldNodes) {
nodePtr->removeView(oldG);
}
for (const auto& nodePtr : newNodes) {
nodePtr->removeView(newG);
}
return false;
}
for (const auto& nodePtr : oldNodes) {
for (const auto& g : commonGraphViews) {
g -> remove(nodePtr, false);
g -> updateInputsOutputsDelete(nodePtr);
nodePtr -> resetConnections(true);
// NOTE(review): the bare numbers below are diff-viewer residue, not code.
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
if ((oldOI.size() == newOI.size()) &&
(oldOO.size() == newOO.size())) {
// Case 1
for (std::size_t i = 0; i < oldOI.size(); ++i) {
if (inputParents[i].first) {
inputParents[i].first -> addChild(newOI[i].first, inputParents[i].second, newOI[i].second);
}
}
for (std::size_t o = 0; o < oldOO.size(); ++o) {
if (outputChildren[o].first) {
newOO[o].first -> addChild(outputChildren[o].first, newOO[o].second, outputChildren[o].second);
}
}
}
else {
// get the number of Parents for oldG->inputNodes()
// get the number of Children for oldg->outputNodes()
if (newNodes.size() == 0) {
// Case 3
// NOTE(review): unlike Case 1, inputParents[i].first is not null-checked here.
if (oldOI.size() == oldOO.size()) {
for (std::size_t i = 0; i < oldOI.size(); ++i) {
inputParents[i].first -> addChild(outputChildren[i].first, inputParents[i].second, outputChildren[i].second);
}
}
// NOTE(review): this loop is bounded by oldOI.size(), which is 1 in this branch,
// so only outputChildren[0] is reconnected. oldOO.size() may have been intended; confirm.
else if (oldOI.size() == 1) {
for (std::size_t i = 0; i < oldOI.size(); ++i) {
inputParents[0].first -> addChild(outputChildren[i].first, inputParents[0].second, outputChildren[i].second);
}
}
}
else if ( // for tiling-like cases. The number of inputNodes changes but not outputNodes
((oldOI.size() == 1) || (newOI.size() == 1)) && // (oldOI.size() == newOI.size()) already handled in Case 1
((oldOO.size() == newOO.size()))
) {
// Case 2
if ((oldOI.size() == 1)) {
for (std::size_t i = 0; i < newOI.size(); ++i) {
inputParents[0].first -> addChild(newOI[i].first, inputParents[0].second, newOI[i].second);
}
} else {
for (std::size_t i = 0; i < oldOI.size(); ++i) {
inputParents[i].first -> addChild(newOI[0].first, inputParents[i].second, newOI[0].second);
}
}
for (std::size_t o = 0; o < oldOO.size(); ++o) {
if (outputChildren[o].first) {
newOO[o].first -> addChild(outputChildren[o].first, newOO[o].second, outputChildren[o].second);
else {
// Unsupported configuration: detach the temporary views and give up.
for (const auto& nodePtr : oldNodes) {
nodePtr->removeView(oldG);
}
for (const auto& nodePtr : newNodes) {
nodePtr->removeView(newG);
}
return false;
}
// Register the new Nodes in every common GraphView, then detach the temporary views.
for (const auto& nodePtr : newNodes) {
for (const auto& g : commonGraphViews) {
g -> add(nodePtr);
for (const auto& nodePtr : oldNodes) {
nodePtr -> removeView(oldG);
for (const auto& nodePtr : newNodes) {
nodePtr -> removeView(newG);
return true;
}
/**
 * @brief Update the ordered GraphView inputs/outputs after @p newNode was added
 *        (or re-connected while already in the GraphView).
 *
 * Inputs:
 *  - connections from newNode to its children inside the GraphView stop being
 *    GraphView inputs;
 *  - inputs of newNode whose parent is null or outside the GraphView become
 *    GraphView inputs. They are inserted at the position of the first removed
 *    input, in newNode input order;
 *  - inputs of newNode whose parent is inside stop being GraphView inputs.
 * Outputs:
 *  - outputs of newNode's parents that now feed newNode are removed;
 *  - each output of newNode with no child inside the GraphView is added, at
 *    the position of the first removed output.
 *
 * NOTE(review): as shown here, inputIdx (first loop) and outputIdx (parent loop
 * and final loop) are used without a visible declaration, and several closing
 * braces are missing. Restore from version control before compiling.
 */
void Aidge::GraphView::updateInputsOutputsNew(std::shared_ptr<Node> newNode) {
// Can be called several times with the same node, e.g. when addChild() is
// called on a node already part of the GraphView. In this case, inputs/outputs
// need to be updated!
std::vector<std::pair<NodePtr, IOIndex_t>>::const_iterator newInputsInsertionPoint = mInputNodes.cend();
// Remove inputs that are not input anymore because connected to newNode
for (auto orderedChilds : newNode->getOrderedChildren()) {
for (auto ch_ptr : orderedChilds) {
// Check that newNode child is in current GraphView
if (mNodes.find(ch_ptr) != mNodes.cend()) {
for (const std::shared_ptr<Node>& pa_ptr : ch_ptr->getParents()) {
// If newNode is connected to it
if (pa_ptr == newNode) {
const auto val = std::make_pair(ch_ptr, inputIdx);
const auto iter = std::find(mInputNodes.cbegin(), mInputNodes.cend(), val);
// Check that it was not already the case (if node UPDATE)
if (iter != mInputNodes.cend()) { // newNode is linked to an actual inputNode to an input connection
// The first old (removed) input becomes the insertion point for newNode GraphView inputs
// (distance <= 0: iter is at or before the current insertion point)
if (std::distance(newInputsInsertionPoint, iter) <= 0) {
newInputsInsertionPoint = mInputNodes.erase(iter);
}
else {
mInputNodes.erase(iter);
}
// Manage newNode parents
// Check if any input connection is an input for the GraphView
IOIndex_t inputIdx = 0U;
for (const std::shared_ptr<Node>& pa_ptr : newNode->getParents()) {
const auto val = std::make_pair(newNode, inputIdx);
const auto it = std::find(mInputNodes.cbegin(), mInputNodes.cend(), val);
if ((pa_ptr == nullptr) ||
(mNodes.find(pa_ptr) == mNodes.cend())) {
// Parent doesn't exist || Parent not in the graph
if (it == mInputNodes.cend()) {
// If node's inputs are inputs for the GraphView: add them to the input list
// Addition rule:
// - Inputs addition order follows node inputs order
// - Inputs are inserted at the position of the first input removed
newInputsInsertionPoint = mInputNodes.insert(newInputsInsertionPoint, val);
newInputsInsertionPoint = std::next(newInputsInsertionPoint);
}
} else if (it != mInputNodes.cend()) {
// Parent already in the graph SO edge is not an input anymore for the graph
// NOTE(review): this erase does not refresh newInputsInsertionPoint. If `it`
// lies before it, the insertion point is invalidated (std::vector::erase) and
// a later insert() would use a dangling iterator.
mInputNodes.erase(it);
}
++inputIdx;
std::vector<std::pair<NodePtr, IOIndex_t>>::const_iterator newOutputsInsertionPoint = mOutputNodes.cend();
// Remove outputs that are not output anymore because connected to newNode
for (const std::shared_ptr<Node>& parent : newNode->getParents()) {
// Check that newNode parent is in current GraphView
if (mNodes.find(parent) != mNodes.cend()) {
for (auto orderedChilds : parent->getOrderedChildren()) {
for (auto ch_ptr : orderedChilds) {
// If newNode is connected to it
if (ch_ptr == newNode) {
const auto val = std::make_pair(parent, outputIdx);
const auto iter = std::find(mOutputNodes.cbegin(), mOutputNodes.cend(), val);
if (iter != mOutputNodes.cend()) {
// The first old (removed) output becomes the insertion point for newNode GraphView outputs
if (std::distance(newOutputsInsertionPoint, iter) <= 0) {
newOutputsInsertionPoint = mOutputNodes.erase(iter);
}
else {
mOutputNodes.erase(iter);
}
}
}
}
// Check if node outputs are outputs for the GraphView and add them to the output list if so
for (auto orderedChilds : newNode->getOrderedChildren()) {
bool noInsideConnection = true;
for (auto ch_ptr : orderedChilds) {
if (mNodes.find(ch_ptr) != mNodes.end()) {
noInsideConnection = false;
if (noInsideConnection) {
const auto val = std::make_pair(newNode, outputIdx);
// Output may be already be present (see addChild() with a node already in GraphView)
if (std::find(mOutputNodes.begin(), mOutputNodes.end(), val) == mOutputNodes.end()) {
newOutputsInsertionPoint = mOutputNodes.insert(newOutputsInsertionPoint, val);
newOutputsInsertionPoint = std::next(newOutputsInsertionPoint);
}
}
++outputIdx;
}
}
void Aidge::GraphView::updateInputsOutputsDelete(std::shared_ptr<Node> deletedNode) {
std::vector<std::pair<NodePtr, IOIndex_t>>::const_iterator newInputsInsertionPoint = mInputNodes.cend();
// Check if node inputs were inputs for the GraphView and remove them from the list if so
for (IOIndex_t inputIdx = 0; inputIdx < deletedNode->getParents().size(); ++inputIdx) {
const auto iter = std::find(mInputNodes.cbegin(), mInputNodes.cend(), val);
if (iter != mInputNodes.cend()) {
// The first old (removed) input becomes the insertion point for new GraphView inputs
if (std::distance(newInputsInsertionPoint, iter) <= 0) {
newInputsInsertionPoint = mInputNodes.erase(iter);
}
else {
mInputNodes.erase(iter);
}
}
}
// Add child node inputs that become GraphView input following the removal of the node
// Inputs addition order follows deletedNode outputs order
for (auto orderedChilds : deletedNode->getOrderedChildren()) {
for (auto ch_ptr : orderedChilds) {
// Check that deletedNode child is in current GraphView
if (mNodes.find(ch_ptr) != mNodes.cend()) {
for (const std::shared_ptr<Node>& pa_ptr : ch_ptr->getParents()) {
// If newNode was connected to it
if (pa_ptr == deletedNode) {
const auto val = std::make_pair(ch_ptr, inputIdx);
if (std::find(mInputNodes.cbegin(), mInputNodes.cend(), val) == mInputNodes.cend()) {
newInputsInsertionPoint = mInputNodes.insert(newInputsInsertionPoint, val);
newInputsInsertionPoint = std::next(newInputsInsertionPoint);
}
std::vector<std::pair<NodePtr, IOIndex_t>>::const_iterator newOutputsInsertionPoint = mOutputNodes.cend();
// Check if node outputs were outputs for the GraphView and remove them from the list if so