Newer
Older
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <algorithm>
#include <cassert>
#include <iterator>
#include <utility>
#include <numeric>
#include "aidge/utils/Formatting.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/utils/ErrorHandling.hpp"
///////////////////////////////////////////////////////
// FUNCTIONAL DESCRIPTION
///////////////////////////////////////////////////////
/**
 * @brief Functional-style connection: plugs each Connector of `ctors` into the
 * single input Node of this GraphView and returns a Connector on the first
 * element of outputNodes().
 *
 * Preconditions (checked with `assert`, i.e. only in builds without NDEBUG):
 *  - the GraphView has exactly one input Node;
 *  - `ctors.size()` equals that Node's nbData();
 *  - none of that Node's inputs is already connected.
 *
 * NOTE(review): `ctor` and `inID` are used below without a visible declaration;
 * a loop over `ctors` and an input-index counter appear to be missing from this
 * view — confirm against the repository.
 */
Aidge::Connector Aidge::GraphView::operator()(
const std::vector<Aidge::Connector> ctors) {
// TODO: allow for multiple inputNodes?
assert((inputNodes().size() == 1U) && "Too many input Nodes for the GraphView, undefined behaviour");
std::shared_ptr<Node> inNode = *inputNodes().begin();
assert((ctors.size() == static_cast<std::size_t>(inNode->nbData())) && "Wrong number of arguments.\n");
// Every input of the input Node must still be unconnected.
for (std::pair<std::shared_ptr<Node>, IOIndex_t> &input : inNode->inputs()) {
assert((gk_IODefaultIndex == input.second) && "At least one input connection is not free.\n");
// Each Connector's Node receives this GraphView as a child, fed into input
// `inID` of `inNode`.
assert((ctor.node() != nullptr) &&
"Input Connector must be associated with a node");
ctor.node()->addChild(shared_from_this(), static_cast<std::size_t>(ctor.index()),
{inNode, inID++});
}
return Connector(*(outputNodes().begin()));
}
///////////////////////////////////////////////////////
// INNER
///////////////////////////////////////////////////////
/// @brief Current name of this GraphView, returned by value.
std::string Aidge::GraphView::name() const {
  std::string currentName(mName);
  return currentName;
}
/// @brief Overwrite the stored GraphView name with a copy of `name`.
void Aidge::GraphView::setName(const std::string &name) {
  mName.assign(name);
}

Maxence Naud
committed
/**
 * @brief Write a Mermaid flowchart describing this GraphView to `path + ".mmd"`.
 *
 * One Mermaid node is emitted per Node (identifier `<type>_<counter>`), one edge
 * per parent-output -> child-input link, entries for the GraphView's inputs
 * (mInputNodes) and outputs (mOutputNodes), then `classDef` style lines.
 *
 * @param path          Output file path, without the ".mmd" extension.
 * @param verbose       Presumably selects richer node labels and the classDef
 *                      block; the code using it is not visible in this view —
 *                      TODO confirm.
 * @param showProducers When false, non-root Nodes of type "Producer" are not
 *                      drawn and their outgoing links are skipped.
 *
 * NOTE(review): the body shown here looks incomplete (`givenName` and the
 * `inputIdx` used for edge labels are referenced without a visible declaration,
 * and several closing braces/branches are missing); verify against the
 * repository before relying on these comments.
 */
void Aidge::GraphView::save(std::string path, bool verbose, bool showProducers) const {
FILE *fp = std::fopen((path + ".mmd").c_str(), "w");
// NOTE(review): the fopen result is not checked; a failed open would lead to
// fprintf/fclose on a null FILE*.
std::fprintf(fp,
"%%%%{init: {'flowchart': { 'curve': 'monotoneY'}, "
"'fontFamily': 'Verdana' } }%%%%\nflowchart TB\n\n");
// Per-type counters used to build unique Mermaid identifiers.
std::map<const std::string, std::size_t> typeCounter;
// Node -> Mermaid identifier, reused when writing edges.
std::map<std::shared_ptr<Node>, std::string> namePtrTable;
// Start by creating every node
for (const std::shared_ptr<Node> &node_ptr : mNodes) {
const std::string currentType = node_ptr->type();
if (typeCounter.find(currentType) == typeCounter.end())

Maxence Naud
committed
typeCounter[currentType] = 0;
// NOTE(review): the two ternary branches below appear to belong to a
// `givenName` label declaration whose first line is not visible here.
? "<em>" + currentType + "#" + std::to_string(typeCounter[currentType]) + "</em>"
: "\"" + node_ptr->name() + "\\n<sub><em>( " + currentType + "#" + std::to_string(typeCounter[currentType]) + " )</em></sub>\"";
namePtrTable[node_ptr] =
(currentType + "_" + std::to_string(typeCounter[currentType]));
// The root Node is always drawn and tagged with the rootCls style.
if (node_ptr == mRootNode) {
std::fprintf(fp, "%s(%s):::rootCls\n", namePtrTable[node_ptr].c_str(),
givenName.c_str());
}
else {

Maxence Naud
committed
if ((currentType != "Producer") || showProducers) {
std::fprintf(fp, "%s(%s)\n", namePtrTable[node_ptr].c_str(),
givenName.c_str());
}
// Write every link
for (const std::shared_ptr<Node> &node_ptr : mNodes) {

Maxence Naud
committed
if ((node_ptr -> type() == "Producer") && !showProducers) {
continue;
}
IOIndex_t outputIdx = 0;
for (auto childs : node_ptr->getOrderedChildren()) {
for (auto child : childs) {
for (auto parent : child->inputs()) {
if (parent.first == node_ptr && parent.second == outputIdx) {
// Add-on to display the operator's output dimensions
std::string dims = "";
const auto op = std::dynamic_pointer_cast<OperatorTensor>(node_ptr->getOperator());
if (op && !op->getOutput(outputIdx)->dims().empty()) {
dims += " " + print(op->getOutput(outputIdx)->dims(), "%u");
}
// Edge to a child inside the GraphView. The second fprintf (externalCls,
// keyed by pointer) presumably belongs to an else-branch for children
// outside the GraphView that is not visible here.
if (mNodes.find(child) != mNodes.end()) {
std::fprintf(fp, "%s-->|\"%u%s→%u\"|%s\n", namePtrTable[node_ptr].c_str(),
outputIdx, dims.c_str(), inputIdx, namePtrTable[child].c_str());
std::fprintf(fp, "%s-->|\"%u%s→%u\"|%p:::externalCls\n", namePtrTable[node_ptr].c_str(),
outputIdx, dims.c_str(), inputIdx, static_cast<void*>(child.get()));
++outputIdx;
}
}
// GraphView inputs, drawn as "input<k>" circles in mInputNodes order.
// NOTE(review): `%lu` is used for size_t; `%zu` is the portable conversion.
size_t inputIdx = 0;
for (auto input : mInputNodes) {
std::fprintf(fp, "input%lu((in#%lu)):::inputCls--->|→%u|%s\n", inputIdx, inputIdx,
input.second, namePtrTable[input.first].c_str());
++inputIdx;
// GraphView outputs, drawn as "output<k>" circles in mOutputNodes order.
size_t outputIdx = 0;
for (auto output : mOutputNodes) {
// Add-on to display the operator's output dimensions
std::string dims = "";
const auto op = std::dynamic_pointer_cast<OperatorTensor>(output.first->getOperator());
if (op && !op->getOutput(output.second)->dims().empty()) {
dims += " " + print(op->getOutput(output.second)->dims(), "%u");
}
std::fprintf(fp, "%s--->|\"%u%s→\"|output%lu((out#%lu)):::outputCls\n",
namePtrTable[output.first].c_str(), output.second,
dims.c_str(), outputIdx, outputIdx);
++outputIdx;
}
// Mermaid style classes referenced by the :::xxxCls suffixes above.
std::fprintf(fp, "classDef inputCls fill:#afa\n");
std::fprintf(fp, "classDef outputCls fill:#ffa\n");
std::fprintf(fp, "classDef externalCls fill:#ccc\n");
std::fprintf(fp, "classDef rootCls stroke:#f00\n");
}
std::fprintf(fp, "\n");
std::fclose(fp);
}
///////////////////////////////////////////////////////
// TENSOR MANAGEMENT
///////////////////////////////////////////////////////
void Aidge::GraphView::setOrderedInputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& inputs) {
AIDGE_ASSERT(inputs.size() <= mInputNodes.size(), "too many specified number of inputs");
std::vector<std::pair<NodePtr, IOIndex_t>> ignoredInputs(mInputNodes);
for (auto input : inputs) {
auto it = std::find(ignoredInputs.begin(), ignoredInputs.end(), input);
AIDGE_ASSERT(it != ignoredInputs.end(), "unknown or duplicate input");
ignoredInputs.erase(it);
}
mInputNodes = inputs;
mInputNodes.insert(mInputNodes.end(), ignoredInputs.begin(), ignoredInputs.end());
}
void Aidge::GraphView::setOrderedOutputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& outputs) {
AIDGE_ASSERT(outputs.size() <= mOutputNodes.size(), "too many specified number of outputs");
std::vector<std::pair<NodePtr, IOIndex_t>> ignoredOutputs(mOutputNodes);
for (auto output : outputs) {
auto it = std::find(ignoredOutputs.begin(), ignoredOutputs.end(), output);
AIDGE_ASSERT(it != ignoredOutputs.end(), "unknown or duplicate output");
ignoredOutputs.erase(it);
}
mOutputNodes = outputs;
mOutputNodes.insert(mOutputNodes.end(), ignoredOutputs.begin(), ignoredOutputs.end());
}
/**
 * @brief Count the data inputs of the GraphView as seen from outside.
 *
 * For every input Node, each entry of its dataInputs() is counted when it is
 * unconnected (null parent) or connected to a Node that is not part of this
 * GraphView.
 *
 * @return Number of such external data input slots.
 */
Aidge::IOIndex_t Aidge::GraphView::getNbDataInputs() const {
  IOIndex_t nbDataInput = 0;
  for (const std::shared_ptr<Node>& inNode : inputNodes()) {
    // We cannot simply add inNode->nbDataInputs(), as input nodes may already
    // have some inputs connected within the GraphView, which would therefore not
    // constitute inputs (from outside) for the GraphView!
    const std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeInputs =
        inNode->dataInputs();
    for (const auto& input : inputNodeInputs) {
      if (input.first == nullptr || mNodes.find(input.first) == mNodes.end()) {
        ++nbDataInput;
      }
    }
  }
  return nbDataInput;
}
/**
 * @brief Number of free (unconnected) data inputs of the GraphView.
 *
 * @return Sum of Node::getNbFreeDataInputs() over every input Node.
 */
Aidge::IOIndex_t Aidge::GraphView::getNbFreeDataInputs() const {
  IOIndex_t nbIn = 0;
  // Free inputs within the GraphView are logically also free inputs from outside
  // the GraphView.
  for (const std::shared_ptr<Node>& inputNode : inputNodes()) {
    nbIn += inputNode->getNbFreeDataInputs();
  }
  return nbIn;
}
/**
 * @brief Data input connections of the input Nodes that come from outside the
 * GraphView.
 *
 * @return For every input Node, each entry of Node::dataInputs() whose parent
 *         is null (unconnected) or is not part of this GraphView. Each pair is
 *         what Node::dataInputs() yields (apparently the parent Node and its
 *         output index).
 */
std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>
Aidge::GraphView::dataInputs() const {
  std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res;
  for (const std::shared_ptr<Node>& inputNode : inputNodes()) {
    const std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs =
        inputNode->dataInputs();
    for (const auto& input : inputNodeinputs) {
      if (input.first == nullptr || mNodes.find(input.first) == mNodes.end()) {
        res.push_back(input);
      }
    }
  }
  return res;
}
/**
 * @brief Input connections (data and parameters) of the input Nodes that come
 * from outside the GraphView.
 *
 * @return For every input Node, each entry of Node::inputs() whose parent is
 *         null (unconnected) or is not part of this GraphView.
 */
std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>
Aidge::GraphView::inputs() const {
  std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res;
  for (const std::shared_ptr<Node>& inputNode : inputNodes()) {
    const std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs =
        inputNode->inputs();
    for (const auto& input : inputNodeinputs) {
      if (input.first == nullptr || mNodes.find(input.first) == mNodes.end()) {
        res.push_back(input);
      }
    }
  }
  return res;
}
/**
 * @brief All input connections of the Node registered under `name`.
 *
 * Lookup goes through mNodeRegistry.at(), so an unknown name throws
 * std::out_of_range.
 */
std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>
Aidge::GraphView::inputs(std::string name) const {
  const auto& namedNode = mNodeRegistry.at(name);
  return namedNode->inputs();
}
/**
 * @brief Prepare the GraphView for execution.
 *
 * Sets the backend/device and the data type of every Node's Operator, then
 * forwards tensor dimensions through the graph.
 *
 * @param backend  Backend name forwarded to setBackend().
 * @param datatype Data type forwarded to setDataType().
 * @param device   Device index forwarded to setBackend().
 */
void Aidge::GraphView::compile(const std::string& backend, const Aidge::DataType datatype, DeviceIdx_t device) {
  // Backend
  // TODO: add Backend attribute to Operator
  setBackend(backend, device);

  // Data type
  // TODO: manage Datatype attribute in OperatorImpl
  // Apply the requested data type to every Operator; otherwise the
  // `datatype` argument would be silently ignored.
  setDataType(datatype);

  // Data Format
  // TODO: check actual parent output data format and the needed one. Add a Transpose Operator if necessary

  // Forward dimensions
  forwardDims();
}
// NOTE(review): the function header for the statements below is not visible in
// this view (the `}` just above closes compile()). Since compile() calls
// forwardDims(), this is presumably its body; the end of the function is not
// visible either — confirm against the repository.
// It links Tensor inputs to their parent's output Tensor and collects the start
// Nodes for dimension propagation: all input Nodes, plus Producer Nodes found
// with an unconnected input.
std::set<NodePtr> startNodes = inputNodes();
// setInputs
// Link every tensor to the right pointer
// following parent - children information
for (std::shared_ptr<Node> nodePtr : getNodes()) {
for (IOIndex_t i = 0; i < nodePtr->nbInputs(); ++i) {
// assess if the input was not already set and is a Tensor then link it to parent output
std::pair<std::shared_ptr<Node>, IOIndex_t> inputI = nodePtr->input(i);
if (inputI.first) {
// Re-associate only when the input does not already point to the parent's output.
if ( std::static_pointer_cast<Tensor>(nodePtr->getOperator()->getRawInput(i)) != inputI.first->getOperator()->getRawOutput(inputI.second)) {
if (nodePtr->getOperator()->operatorType() == OperatorType::Tensor) {
// assert provided Data is of "Tensor" type
nodePtr->getOperator()->associateInput(i, inputI.first->getOperator()->getRawOutput(inputI.second));
}
else {
assert(false && "Non-tensor entries not handled yet.\n");
}
}
} else {
// Unconnected input: it must already hold a non-empty Tensor.
// NOTE(review): if getRawInput(i) is null this dereferences a null pointer,
// and the check disappears entirely under NDEBUG.
assert(!std::static_pointer_cast<Tensor>(nodePtr->getOperator()->getRawInput(i))->empty());
if (nodePtr->type() == Producer_Op::Type) {
startNodes.insert(nodePtr);
}
}
/**
 * @brief One pass of output-dimension propagation over `listNodes`.
 *
 * For each Tensor-type Operator in `listNodes`, output dimensions are computed
 * if not already forwarded. Nodes that still fail are queued in `nextList` for
 * a later attempt; for the others, children whose dimensions are not yet
 * forwarded are queued instead. When `nextList` is empty, every Tensor Node of
 * the GraphView whose dimensions are still not forwarded is queued. An
 * AIDGE_ASSERT then fails if `nextList` equals `listNodes` (no progress).
 *
 * NOTE(review): the end of this function (presumably a recursive call on
 * `nextList`) and several closing braces are not visible in this view, so the
 * exact nesting of the later statements cannot be confirmed here.
 */
void Aidge::GraphView::_forwardDims(std::set<std::shared_ptr<Node>> listNodes) {
// TODO: support multi-inputs/outputs
std::set<std::shared_ptr<Node>> nextList = std::set<std::shared_ptr<Node>>();
for (std::shared_ptr<Node> nodePtr : listNodes) {
if (nodePtr->getOperator()->operatorType() == OperatorType::Tensor) {
const auto op = std::static_pointer_cast<OperatorTensor>(nodePtr->getOperator());
if (!op->outputDimsForwarded()) {
op->computeOutputDims();
}
if (!op->outputDimsForwarded()) { // try to compute output dimensions again later
nextList.insert(nodePtr);
} else { // compute output dimensions of children
std::set<std::shared_ptr<Node>> children = nodePtr->getChildren();
for (auto child : children) {
// NOTE(review): children are cast to OperatorTensor without checking
// operatorType(), unlike their parent above.
const auto childOp = std::static_pointer_cast<OperatorTensor>(child->getOperator());
if (!childOp->outputDimsForwarded()) {
nextList.insert(child);
}
}
if (nextList.empty()) {
for (std::shared_ptr<Node> nodePtr : getNodes()) {
if (nodePtr->getOperator()->operatorType() == OperatorType::Tensor) {
if (!std::static_pointer_cast<OperatorTensor>(nodePtr->getOperator())->outputDimsForwarded()) {
nextList.insert(nodePtr);
}
}
}
// Internal check to make sure we won't enter in an infinite loop!
AIDGE_ASSERT(nextList != listNodes, "Unable to forward dimensions (circular dependency and/or wrong dimensions?)");
void Aidge::GraphView::setBackend(const std::string &backend, DeviceIdx_t device) {
node->getOperator()->setBackend(backend, device);
/**
 * @brief Propagate `datatype` to the Operator of every Node in this GraphView.
 */
void Aidge::GraphView::setDataType(const Aidge::DataType &datatype) {
  for (const auto& nodePtr : getNodes()) {
    const auto nodeOp = nodePtr->getOperator();
    nodeOp->setDataType(datatype);
  }
}
/**
 * @brief Output connections of the output Nodes that leave the GraphView.
 *
 * For every output Node and every output position, keeps the connected
 * entries whose Node is null or not part of this GraphView. A position is
 * reported when it has no connection at all (empty list) or when at least one
 * of its connections leaves the GraphView; positions fully consumed inside the
 * GraphView are omitted.
 *
 * @return One inner vector per reported output position.
 */
std::vector<
    std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>>
Aidge::GraphView::outputs() const {
  std::vector<std::vector<std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t>>>
      outsideOutputs;
  for (const std::shared_ptr<Node>& outputNode : outputNodes()) {
    const std::vector<std::vector<std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t>>>
        outputNodeOutputs = outputNode->outputs();
    for (const auto& outputPos : outputNodeOutputs) {
      // Keep only the nodes connected at this output position that are outside the GraphView
      std::vector<std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t>> outsideOutputPos;
      for (const auto& output : outputPos) {
        if (output.first == nullptr || mNodes.find(output.first) == mNodes.end()) {
          outsideOutputPos.push_back(output);
        }
      }
      if (outputPos.empty() || !outsideOutputPos.empty()) {
        outsideOutputs.push_back(std::move(outsideOutputPos));
      }
    }
  }
  return outsideOutputs;
}
/**
 * @brief All output connections of the Node registered under `nodeName`.
 *
 * Lookup goes through mNodeRegistry.at(), so an unknown name throws
 * std::out_of_range.
 */
std::vector<
    std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>>
Aidge::GraphView::outputs(std::string nodeName) const {
  const auto& namedNode = mNodeRegistry.at(nodeName);
  return namedNode->outputs();
}
/// @brief Placeholder: only prints a notice; both arguments are ignored.
void Aidge::GraphView::setInputId(Aidge::IOIndex_t /*inID*/,
                                  Aidge::IOIndex_t /*newNodeOutID*/) {
  const char* const notice = "Not implemented yet.\n";
  printf("%s", notice);
}
void Aidge::GraphView::add(std::shared_ptr<Node> node, bool includeLearnableParam) {
// first node to be added to the graph is the root node by default
if (mRootNode == nullptr) {
mRootNode = node;
}
// add to the GraphView nodes
node->addView(shared_from_this());
mNodes.insert(node);
if (!(node->name()).empty())
mNodeRegistry.insert(std::make_pair(node->name(), node));
// check if the node is an input/output node
updateInputsOutputsNew(node);
// add learnable parameters to the graph
if (includeLearnableParam) {
for (IOIndex_t i = node->nbData(); i < node->nbInputs(); ++i) {
std::shared_ptr<Node> parentNode = node->getParent(static_cast<IOIndex_t>(i));
if (parentNode) {
parentNode->addView(shared_from_this());
mNodes.insert(parentNode);
if (!(parentNode->name()).empty())
mNodeRegistry.insert(std::make_pair(parentNode->name(), parentNode));
// check if the parentNode is an input/output node
updateInputsOutputsNew(parentNode);
/**
 * @brief Add a set of Nodes to the GraphView in a rank order derived from the graph.
 *
 * Nodes already present are skipped. When no root Node is set, a Node with no
 * parent inside the combined set is chosen as root. Nodes are then ranked by a
 * queue-style walk from the root over children (in output order) and parents,
 * and added with add(NodePtr, bool) in that order; Nodes unreachable from the
 * root are appended afterwards.
 *
 * @param otherNodes            Nodes to add.
 * @param includeLearnableParam Forwarded to add(NodePtr, bool) for each Node.
 * @return false when the order could not be made unique (several parentless
 *         root candidates, or Nodes with no path to the root); true otherwise,
 *         including for an empty `otherNodes`.
 *
 * NOTE(review): several closing braces and possibly some statements are missing
 * from this view (the bare numbers below look like scraped line numbers, not
 * code); verify the exact control flow against the repository.
 */
bool Aidge::GraphView::add(std::set<std::shared_ptr<Node>> otherNodes, bool includeLearnableParam) {
if (otherNodes.empty()) {
return true;
}
bool orderUnicity = true;
// List only the nodes that are not already present in current graph
// (set_difference is valid here because both operands are std::set with the
// same default ordering).
std::set<NodePtr> nodesToAdd;
std::set_difference(otherNodes.begin(), otherNodes.end(), mNodes.begin(), mNodes.end(), std::inserter(nodesToAdd, nodesToAdd.begin()));
// List the nodes to rank, initially all the nodes in the GraphView
std::set<NodePtr> nodesToRank(mNodes);
nodesToRank.insert(nodesToAdd.begin(), nodesToAdd.end());
std::vector<NodePtr> rankedNodesToAdd;
if (mRootNode == nullptr) {
std::set<NodePtr> noParentNodes;
// If no root node is defined, check nodes without parents
for (auto node : nodesToRank) {
bool noParent = true;
for (auto parent : node->getParents()) {
if (parent != nullptr && nodesToRank.find(parent) != nodesToRank.end()) {
noParent = false;
if (noParent) {
noParentNodes.insert(node);
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
// Take the first one found (this is an arbitrary choice)
// NOTE(review): dereferencing begin() is undefined if noParentNodes is empty
// (e.g. every candidate has a parent inside the set) — confirm whether the full
// source guards against it.
mRootNode = *noParentNodes.begin();
if (noParentNodes.size() > 1) {
// If there is more than one, order unicity cannot be guaranteed!
orderUnicity = false;
}
rankedNodesToAdd.push_back(mRootNode);
}
nodesToRank.erase(mRootNode);
// Worklist of ranked Nodes, processed in insertion order starting at the root.
std::vector<NodePtr> rankedNodes;
rankedNodes.push_back(mRootNode);
for (size_t curNodeIdx = 0; curNodeIdx < rankedNodes.size(); ++curNodeIdx) {
NodePtr curNode = rankedNodes[curNodeIdx];
for (auto childs : curNode->getOrderedChildren()) {
for (auto child : childs) {
if (nodesToRank.find(child) != nodesToRank.end()) {
rankedNodes.push_back(child);
nodesToRank.erase(child);
if (nodesToAdd.find(child) != nodesToAdd.end()) {
rankedNodesToAdd.push_back(child);
nodesToAdd.erase(child);
for (auto parent : curNode->getParents()) {
if (nodesToRank.find(parent) != nodesToRank.end()) {
rankedNodes.push_back(parent);
nodesToRank.erase(parent);
if (nodesToAdd.find(parent) != nodesToAdd.end()) {
rankedNodesToAdd.push_back(parent);
nodesToAdd.erase(parent);
}
}
if (!nodesToAdd.empty()) {
// There are remaining nodes without path to the root node
orderUnicity = false;
while (!nodesToAdd.empty()) {
const auto it = nodesToAdd.begin();
rankedNodesToAdd.push_back(*it);
nodesToAdd.erase(it);
// Add the Nodes in their computed rank order.
for (auto node_ptr : rankedNodesToAdd) {
add(node_ptr, includeLearnableParam);
}
return orderUnicity;
bool Aidge::GraphView::add(std::pair<NodePtr, std::set<NodePtr>> nodes, bool includeLearnableParam) {
mRootNode = nodes.first;
add(nodes.first, includeLearnableParam);
}
return add(nodes.second, includeLearnableParam);
/**
 * @brief Add every Node of `graph` to this GraphView.
 *
 * If this GraphView has no root Node yet, it adopts `graph`'s root. Nodes are
 * added through add(std::set<...>, bool) with the learnable-parameter flag set
 * to false, and that call's result is returned.
 *
 * NOTE(review): the closing brace of this function is not visible in this view.
 */
bool Aidge::GraphView::add(std::shared_ptr<GraphView> graph) {
if (mRootNode == nullptr) {
mRootNode = graph->getRootNode();
}
return add(graph->getNodes(), false);
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
}
/**
 * @brief Connect output `fromTensor` of `fromOutNode` to input `toTensor` of
 * `toOtherNode`, then add `toOtherNode` to this GraphView.
 *
 * @param toOtherNode Node to connect as a child and add to the GraphView.
 * @param fromOutNode Parent Node; must belong to this GraphView. If null, the
 *                    GraphView must have exactly one output Node, which is used.
 * @param fromTensor  Output index on the parent.
 * @param toTensor    Input index on the child.
 */
void Aidge::GraphView::addChild(std::shared_ptr<Node> toOtherNode,
                                std::shared_ptr<Node> fromOutNode,
                                const Aidge::IOIndex_t fromTensor,
                                Aidge::IOIndex_t toTensor) {
  // AIDGE_ASSERT (used elsewhere in this file) instead of assert: assert is
  // compiled out under NDEBUG, which would let *outputNodes().begin() run on a
  // possibly empty set.
  if (fromOutNode) {
    AIDGE_ASSERT(inView(fromOutNode), "Output Node not found in the GraphView.");
  } else {
    AIDGE_ASSERT(outputNodes().size() == 1U,
                 "Must specify an outputNode or have only one.");
    fromOutNode = *(outputNodes().begin());
  }
  fromOutNode->addChild(toOtherNode, fromTensor, toTensor);
  add(toOtherNode);
}
void Aidge::GraphView::addChild(
std::shared_ptr<GraphView> toOtherView,
std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t> fromOutNode,
std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t> toNode) {
// assert output node is valid
if (!fromOutNode.first) {
assert(outputNodes().size() == 1U &&
"If no output node is provided, the graph should have only one to "
"make the choice explicit.");
fromOutNode.first = *(outputNodes().begin());
} else
assert(inView(fromOutNode.first));
// assert input node is valid
if (!toNode.first) {
assert(toOtherView->inputNodes().size() == 1U &&
"If no intput node is provided, the other graph should have only "
"one to make the choice explicit.");
toNode.first = *(toOtherView->inputNodes().begin());
} else {
assert(toOtherView->inView(toNode.first));
}
// Tensor assertions are performed in the Node adChild method
fromOutNode.first->addChild(toNode.first, fromOutNode.second, toNode.second);
// once linking performed, add other graph to current graph
add(toOtherView);
}
/**
 * @brief Union of the parents of every input Node of this GraphView.
 *
 * @return Set of parent Nodes, as returned by Node::getParents(). Entries may
 *         include Nodes inside the GraphView and, if Node::getParents() reports
 *         unconnected inputs as null, a null entry.
 */
std::set<std::shared_ptr<Aidge::Node>> Aidge::GraphView::getParents() const {
  // TODO: choose if we return a set or a vector
  std::set<std::shared_ptr<Node>> parents;
  for (const std::shared_ptr<Node>& inputNode : inputNodes()) {
    // Bind the result once: if getParents() returns a temporary container,
    // calling it separately for begin() and end() yields iterators into two
    // different objects (undefined behaviour).
    const auto& inputNodeParents = inputNode->getParents();
    parents.insert(inputNodeParents.begin(), inputNodeParents.end());
  }
  return parents;
}
/**
 * @brief Parents of the Node registered under `nodeName`.
 *
 * An unknown name is reported on stdout and the process exits with -1.
 */
std::vector<std::shared_ptr<Aidge::Node>> Aidge::GraphView::getParents(const std::string nodeName) const {
  const auto found = mNodeRegistry.find(nodeName);
  if (found != mNodeRegistry.end()) {
    return found->second->getParents();
  }
  printf("No such node a %s in %s graph.\n", nodeName.c_str(), name().c_str());
  exit(-1);
}
/**
 * @brief Parents of each input Node, one inner vector per input Node in the
 * iteration order of inputNodes().
 */
std::vector<std::vector<std::shared_ptr<Aidge::Node>>>
Aidge::GraphView::getOrderedParents() const {
  std::vector<std::vector<std::shared_ptr<Node>>> orderedParents;
  for (const auto& inNode : inputNodes()) {
    orderedParents.emplace_back(inNode->getParents());
  }
  return orderedParents;
}
/**
 * @brief Union of the children of every output Node of this GraphView.
 *
 * @return Set of child Nodes, as returned by Node::getChildren() (may include
 *         Nodes inside the GraphView).
 */
std::set<std::shared_ptr<Aidge::Node>> Aidge::GraphView::getChildren() const {
  std::set<std::shared_ptr<Node>> children;
  for (const std::shared_ptr<Node>& outputNode : outputNodes()) {
    // Bind the result once: if getChildren() returns a temporary set, calling it
    // separately for begin() and end() yields iterators into two different
    // objects (undefined behaviour).
    const auto& outputNodeChildren = outputNode->getChildren();
    children.insert(outputNodeChildren.begin(), outputNodeChildren.end());
  }
  return children;
}
/**
 * @brief Ordered children (per output) of the Node registered under `nodeName`.
 *
 * An unknown name is reported on stdout and the process exits with -1.
 */
std::vector<std::vector<std::shared_ptr<Aidge::Node>>>
Aidge::GraphView::getChildren(const std::string nodeName) const {
  const auto found = mNodeRegistry.find(nodeName);
  if (found == mNodeRegistry.end()) {
    printf("No such node a %s in %s graph.\n", nodeName.c_str(),
           name().c_str());
    exit(-1);
  }
  const std::shared_ptr<Node>& namedNode = found->second;
  return namedNode->getOrderedChildren();
}
/**
 * @brief Children of `otherNode`, which must belong to this GraphView.
 *
 * A Node outside the GraphView is reported on stdout and the process exits
 * with -1.
 */
std::set<std::shared_ptr<Aidge::Node>>
Aidge::GraphView::getChildren(const std::shared_ptr<Node> otherNode) const {
  const bool isMember = mNodes.find(otherNode) != mNodes.end();
  if (!isMember) {
    printf("No such node in graph.\n");
    exit(-1);
  }
  // The stored element compares equal to otherNode, i.e. it is the same Node.
  return otherNode->getChildren();
}
std::shared_ptr<Aidge::Node>
Aidge::GraphView::getNode(const std::string& nodeName) const {
std::map<std::string, std::shared_ptr<Node>>::const_iterator it =

Maxence Naud
committed
if (it != mNodeRegistry.cend()) {
printf("No Node named %s in the current GraphView.\n", nodeName.c_str());

Maxence Naud
committed
return nullptr;
}
}
/**
 * @brief Remove `nodePtr` from this GraphView, along with learnable-parameter
 * parents that only feed `nodePtr`.
 *
 * For each input index in nbData() .. nbInputs()-1, the connected parent is
 * removed too when every child of every one of its outputs is `nodePtr`.
 * Removed Nodes are erased from mNodes and from the name registry, detached
 * from this view, and the GraphView's input/output lists are updated.
 *
 * NOTE(review): `inputI` is used without a visible declaration, and
 * `includeLearnableParam` is not referenced in the visible code; a declaration
 * and presumably a guarding `if (includeLearnableParam)` are missing from this
 * view — confirm against the repository.
 */
void Aidge::GraphView::remove(std::shared_ptr<Node> nodePtr, bool includeLearnableParam) {
for (IOIndex_t i = nodePtr->nbData(); i < nodePtr->nbInputs(); ++i) {
if (inputI.first != nullptr) {
bool removeNode = true;
for (const auto& parentOutput : inputI.first->outputs()) {
for (const auto& childOfParentOutput : parentOutput) {
// only remove the learnable parameter if not related to any other Node in the GraphView
// NOTE(review): the check below considers children anywhere, not only those in
// this GraphView, and `break` only leaves the inner loop.
if (childOfParentOutput.first != nodePtr) {
removeNode = false;
break;
}
if (removeNode) {
// assert Learnable Parameter in the GraphView scope
if (mNodes.find(inputI.first) != mNodes.end()) {
mNodes.erase(inputI.first);
inputI.first->removeView(shared_from_this());
}
if (!inputI.first->name().empty()) { mNodeRegistry.erase(inputI.first->name()); }
// check if the node was an input/output node
updateInputsOutputsDelete(inputI.first);
}
if (mNodes.find(nodePtr) != mNodes.end()) {
mNodes.erase(nodePtr);
nodePtr->removeView(shared_from_this());
// check if the nodePtr was an input/output node
updateInputsOutputsDelete(nodePtr);
}
// The name entry is erased even if nodePtr was not in mNodes.
if (!nodePtr->name().empty()) { mNodeRegistry.erase(nodePtr->name()); }
}
/// @brief Placeholder: prints a notice and always reports failure (false).
bool Aidge::GraphView::swap(Node & /*node*/, Node & /*otherNode*/) {
  constexpr bool swapped = false;
  printf("Swap() not implementated yet. Return false.\n");
  return swapped;
}
/// @brief Placeholder: only prints a notice; no link is created.
void Aidge::GraphView::link(std::string /*name1_inID*/,
                            std::string /*name2_outID*/) {
  const char* const notice = "Not implemented yet.\n";
  printf("%s", notice);
}
/**
 * @brief Insert `newParentNode` in front of input `childInputTensorIdx` of
 * `childNode`, then add it to this GraphView.
 *
 * If that input is currently fed by a parent, the parent is disconnected from
 * `childNode` and connected to input `newParentInputTensorIdx` of
 * `newParentNode` instead. Output `newParentOutputTensorIdx` of `newParentNode`
 * is then connected to input `childInputTensorIdx` of `childNode`.
 *
 * @param childNode                Node receiving the new parent.
 * @param newParentNode            Node to insert.
 * @param childInputTensorIdx      Input of `childNode` where the insertion happens.
 * @param newParentInputTensorIdx  Input of `newParentNode` fed by the former parent.
 * @param newParentOutputTensorIdx Output of `newParentNode` feeding `childNode`.
 */
void Aidge::GraphView::insertParent(NodePtr childNode,
                                    NodePtr newParentNode,
                                    IOIndex_t childInputTensorIdx,
                                    IOIndex_t newParentInputTensorIdx,
                                    IOIndex_t newParentOutputTensorIdx) {
  NodePtr currentParentNode = childNode->getParent(childInputTensorIdx);
  // getParent() can return a null Node for an unconnected input; only rewire
  // the former parent when there is one.
  if (currentParentNode) {
    const IOIndex_t currentParentOutputTensorIdx = childNode->input(childInputTensorIdx).second;
    // Remove child from current parent & current Parent from child
    currentParentNode->removeChild(childNode, currentParentOutputTensorIdx);
    // Feed the new parent from the output that used to feed the child
    currentParentNode->addChild(newParentNode, currentParentOutputTensorIdx, newParentInputTensorIdx);
  }
  newParentNode->addChild(childNode, newParentOutputTensorIdx, childInputTensorIdx);
  add(newParentNode);
}
/**
 * @brief Replace the Nodes of `oldNodes` by those of `newNodes` in every
 * GraphView that contains all of `oldNodes`.
 *
 * Temporary GraphViews ("oldG"/"newG") are built over both sets to obtain their
 * ordered inputs and outputs. The external parents of the old inputs and the
 * external children of the old outputs are recorded, then reconnected:
 *  - Case 1: same number of inputs and of outputs -> one-to-one reconnection;
 *  - Case 3: `newNodes` empty -> old parents are linked directly to old children;
 *  - Case 2: a single input on either side and the same number of outputs ->
 *    fan-out/fan-in reconnection of the inputs, one-to-one for the outputs.
 * Any other configuration detaches the temporary views and returns false.
 * Otherwise old Nodes are removed from the common GraphViews (except non-output
 * ones that still feed a Node outside `oldNodes`) and have their connections
 * reset, new Nodes are added to those GraphViews, and true is returned.
 *
 * NOTE(review): some lines and closing braces are missing from this view (the
 * bare numbers below look like scraped line numbers, not code); verify the
 * control flow against the repository.
 */
bool Aidge::GraphView::replace(const std::set<Aidge::NodePtr>& oldNodes, const std::set<Aidge::NodePtr>& newNodes) {
// TODO: handle case where an oldNodes parameter does not come from a Producer but another Node (not included in oldNodes)
// How to distinguish it from data input?
// TODO: Parameter Tensors could be identified with their dimensions
// TODO: Take GraphView as input parameters since new Nodes should be connected whatever.
// It also avoids specifying each producer since they are automatically included
// (1) create GraphViews from both sets of Nodes
auto oldG = std::make_shared<GraphView>("oldG");
oldG->add(oldNodes, false);
auto newG = std::make_shared<GraphView>("newG");
newG->add(newNodes, false);
const auto oldOI = oldG->getOrderedInputs();
const auto oldOO = oldG->getOrderedOutputs();
const auto newOI = newG->getOrderedInputs();
const auto newOO = newG->getOrderedOutputs();
std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputParents = std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(oldOI.size());
std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> outputChildren = std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(oldOO.size());
// keep in memory every parent
for (std::size_t i = 0; i < oldOI.size(); ++i) {
auto inputParent = oldOI[i].first -> input(oldOI[i].second);
inputParents[i]= inputParent;
// inputParent.first -> addChild(newOI[i].first, inputParent.second, newOI[i].second);
// NOTE(review): in this loop `i` advances once per external child, so an output
// with several external children can push `i` past oldOO.size() (out-of-bounds
// write into outputChildren), and an output whose children are all in oldNodes
// never advances `i` (infinite loop). As shown, it also appears nested in the
// previous loop; a closing brace may be missing.
for (std::size_t i = 0; i < oldOO.size();) {
auto outputChildList = oldOO[i].first -> output(oldOO[i].second);
if (outputChildList.empty()) {
outputChildren[i] = std::pair<std::shared_ptr<Node>, IOIndex_t>({nullptr, gk_IODefaultIndex});
++i;
else {
for (const auto& child : outputChildList) {
if (oldNodes.find(child.first) == oldNodes.cend()) {
outputChildren[i] = child;
++i;
}
}
}
}
// only keep common views to each node for the new set
// set of common GraphView for oldNodes' Nodes
// NOTE(review): *oldNodes.begin() is undefined when oldNodes is empty.
std::set<std::shared_ptr<GraphView>> commonGraphViews = (*oldNodes.begin())->views();
for (const auto& nodePtr : oldNodes) {
const auto nodeView = nodePtr->views();
std::set<std::shared_ptr<GraphView>> intersection;
std::set_intersection(commonGraphViews.begin(), commonGraphViews.end(),
nodeView.begin(), nodeView.end(),
std::inserter(intersection, intersection.begin()));
commonGraphViews = intersection;
// The temporary views must not be treated as real owners.
commonGraphViews.erase(oldG);
commonGraphViews.erase(newG);
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
// Reject when both input and output counts differ (with a non-empty replacement).
if ((newNodes.size() > 0) && (oldOI.size() != newOI.size()) && (oldOO.size() != newOO.size())) {
for (const auto& nodePtr : oldNodes) {
nodePtr->removeView(oldG);
}
for (const auto& nodePtr : newNodes) {
nodePtr->removeView(newG);
}
return false;
}
if ((oldOI.size() == newOI.size()) &&
(oldOO.size() == newOO.size())) {
// Case 1
for (std::size_t i = 0; i < oldOI.size(); ++i) {
if (inputParents[i].first) {
inputParents[i].first -> addChild(newOI[i].first, inputParents[i].second, newOI[i].second);
}
}
for (std::size_t o = 0; o < oldOO.size(); ++o) {
if (outputChildren[o].first) {
newOO[o].first -> addChild(outputChildren[o].first, newOO[o].second, outputChildren[o].second);
}
}
}
else {
// get the number of Parents for oldG->inputNodes()
// get the number of Children for oldg->outputNodes()
if (newNodes.size() == 0) {
// Case 3
if (oldOI.size() == oldOO.size()) {
for (std::size_t i = 0; i < oldOI.size(); ++i) {
if (inputParents[i].first)
inputParents[i].first -> addChild(outputChildren[i].first, inputParents[i].second, outputChildren[i].second);

Maxence Naud
committed
// NOTE(review): this loop is bounded by oldOI.size() (== 1 in this branch), so
// only outputChildren[0] is reconnected; oldOO.size() may have been intended —
// confirm.
else if ((oldOI.size() == 1) && (inputParents[0].first)) {
for (std::size_t i = 0; i < oldOI.size(); ++i) {
inputParents[0].first -> addChild(outputChildren[i].first, inputParents[0].second, outputChildren[i].second);
}
}
}
else if ( // for tiling-like cases. The number of inputNodes changes but not outputNodes
((oldOI.size() == 1) || (newOI.size() == 1)) && // (oldOI.size() == newOI.size()) already handled in Case 1
((oldOO.size() == newOO.size()))
) {
// Case 2

Maxence Naud
committed
if ((oldOI.size() == 1) && (inputParents[0].first)) {
for (std::size_t i = 0; i < newOI.size(); ++i) {
inputParents[0].first -> addChild(newOI[i].first, inputParents[0].second, newOI[i].second);
}
} else {
for (std::size_t i = 0; i < oldOI.size(); ++i) {

Maxence Naud
committed
if (inputParents[i].first) {
inputParents[i].first -> addChild(newOI[0].first, inputParents[i].second, newOI[0].second);
}
}
}
for (std::size_t o = 0; o < oldOO.size(); ++o) {
if (outputChildren[o].first) {
newOO[o].first -> addChild(outputChildren[o].first, newOO[o].second, outputChildren[o].second);
else {
for (const auto& nodePtr : oldNodes) {
nodePtr->removeView(oldG);
}
for (const auto& nodePtr : newNodes) {
nodePtr->removeView(newG);
}
return false;
}

Maxence Naud
committed
// Remove old Nodes from the common views, unless a non-output old Node still
// feeds a Node outside oldNodes.
auto oldGOutputs = oldG->outputNodes();
for (const auto& nodePtr : oldNodes) {
bool removeFromGraphs = true;
if (std::find(oldGOutputs.cbegin(), oldGOutputs.cend(), nodePtr) == oldGOutputs.cend()) {
for (const auto& chPtr : nodePtr->getChildren()) {
if (oldNodes.find(chPtr) == oldNodes.cend()) {
removeFromGraphs = false;
}
}
}
if (removeFromGraphs) {
for (const auto& g : commonGraphViews) {
g -> remove(nodePtr, false);
g -> updateInputsOutputsDelete(nodePtr);
}
nodePtr -> resetConnections(true);
}
}
// Add the new Nodes to every common view, then drop the temporary views.
for (const auto& nodePtr : newNodes) {
for (const auto& g : commonGraphViews) {
g -> add(nodePtr);
for (const auto& nodePtr : oldNodes) {
nodePtr -> removeView(oldG);
for (const auto& nodePtr : newNodes) {
nodePtr -> removeView(newG);
return true;
}
void Aidge::GraphView::updateInputsOutputsNew(std::shared_ptr<Node> newNode) {
// Can be called several times with the same node, e.g. when addChild() is
// called on a node already part of the GraphView. In this case, inputs/outputs
// need to be updated!
std::vector<std::pair<NodePtr, IOIndex_t>>::const_iterator newInputsInsertionPoint = mInputNodes.cend();
// Remove inputs that are not input anymore because connected to newNode
for (auto orderedChilds : newNode->getOrderedChildren()) {
for (auto ch_ptr : orderedChilds) {
// Check that newNode child is in current GraphView
if (mNodes.find(ch_ptr) != mNodes.cend()) {
for (const std::shared_ptr<Node>& pa_ptr : ch_ptr->getParents()) {
// If newNode is connected to it
if (pa_ptr == newNode) {
const auto val = std::make_pair(ch_ptr, inputIdx);
const auto iter = std::find(mInputNodes.cbegin(), mInputNodes.cend(), val);
// Check that it was not already the case (if node UPDATE)
if (iter != mInputNodes.cend()) { // newNode is linked to an actual inputNode to an input connection
// The first old (removed) input becomes the insertion point for newNode GraphView inputs
if (std::distance(newInputsInsertionPoint, iter) <= 0) {
newInputsInsertionPoint = mInputNodes.erase(iter);
}
else {
mInputNodes.erase(iter);
}
// Manage newNode parents
// Check if any input connection is an input for the GraphView
IOIndex_t inputIdx = 0U;
for (const std::shared_ptr<Node>& pa_ptr : newNode->getParents()) {
const auto val = std::make_pair(newNode, inputIdx);
const auto it = std::find(mInputNodes.cbegin(), mInputNodes.cend(), val);
if ((pa_ptr == nullptr) ||
(mNodes.find(pa_ptr) == mNodes.cend())) {
// Parent doesn't exist || Parent not in the graph
if (it == mInputNodes.cend()) {
// If node's inputs are inputs for the GraphView: add them to the input list
// Addition rule:
// - Inputs addition order follows node inputs order
// - Inputs are inserted at the position of the first input removed
newInputsInsertionPoint = mInputNodes.insert(newInputsInsertionPoint, val);
newInputsInsertionPoint = std::next(newInputsInsertionPoint);
}
} else if (it != mInputNodes.cend()) {
// Parent already in the graph SO edge is not an input anymore for the graph
mInputNodes.erase(it);
}
++inputIdx;
std::vector<std::pair<NodePtr, IOIndex_t>>::const_iterator newOutputsInsertionPoint = mOutputNodes.cend();
// Remove outputs that are not output anymore because connected to newNode
for (const std::shared_ptr<Node>& parent : newNode->getParents()) {
// Check that newNode parent is in current GraphView
if (mNodes.find(parent) != mNodes.cend()) {
for (auto orderedChilds : parent->getOrderedChildren()) {
for (auto ch_ptr : orderedChilds) {
// If newNode is connected to it
if (ch_ptr == newNode) {
const auto val = std::make_pair(parent, outputIdx);
const auto iter = std::find(mOutputNodes.cbegin(), mOutputNodes.cend(), val);
if (iter != mOutputNodes.cend()) {
// The first old (removed) output becomes the insertion point for newNode GraphView outputs
if (std::distance(newOutputsInsertionPoint, iter) <= 0) {
newOutputsInsertionPoint = mOutputNodes.erase(iter);
}
else {
mOutputNodes.erase(iter);
}
}
}
}
// Check if node outputs are outputs for the GraphView and add them to the output list if so

Maxence Naud
committed
for (const auto& orderedChilds : newNode->getOrderedChildren()) {

Maxence Naud
committed
for (const auto& ch_ptr : orderedChilds) {
if (mNodes.find(ch_ptr) != mNodes.cend()) {