Forked from
Eclipse Projects / aidge / aidge_core
2048 commits behind the upstream repository.
-
Olivier BICHLER authoredOlivier BICHLER authored
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
GraphView.cpp 43.34 KiB
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <algorithm>
#include <cassert>
#include <cstdio>
#include <cstring>
#include <iterator>
#include <map>
#include <memory>
#include <numeric>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "aidge/utils/Types.h"
#include "aidge/graph/GraphView.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/utils/ErrorHandling.hpp"
///////////////////////////////////////////////////////
// FUNCTIONAL DESCRIPTION
///////////////////////////////////////////////////////
/**
 * @brief Functional-style call: connect each given Connector to the
 * successive inputs of the (single) input Node of this GraphView.
 * @param ctors One Connector per data input of the input Node.
 * @return A Connector built from the first output Node of this GraphView.
 */
Aidge::Connector Aidge::GraphView::operator()(
    const std::vector<Aidge::Connector> ctors) {
  // TODO: allow for multiple inputNodes?
  assert((inputNodes().size() == 1U) && "Too many input Nodes for the GraphView, undefined behaviour");
  std::shared_ptr<Node> inNode = *inputNodes().begin();
  assert((ctors.size() == static_cast<std::size_t>(inNode->nbData())) && "Wrong number of arguments.\n");

  // Every input of the entry Node must still be unconnected.
  for (const auto& input : inNode->inputs()) {
    assert((gk_IODefaultIndex == input.second) && "At least one input connection is not free.\n");
    (void)input; // avoid unused warning
  }

  // The i-th Connector feeds the i-th input of the entry Node.
  IOIndex_t nextInputIdx = 0;
  for (const Connector& ctor : ctors) {
    assert((ctor.node() != nullptr) &&
           "Input Connector must be associated with a node");
    const auto sourceOutputIdx = static_cast<std::size_t>(ctor.index());
    ctor.node()->addChild(shared_from_this(), sourceOutputIdx, {inNode, nextInputIdx});
    ++nextInputIdx;
  }

  const std::shared_ptr<Node> exitNode = *(outputNodes().begin());
  return Connector(exitNode);
}
///////////////////////////////////////////////////////
// INNER
///////////////////////////////////////////////////////
/// @brief Returns a copy of this GraphView's name.
std::string Aidge::GraphView::name() const {
  return mName;
}
/// @brief Replaces this GraphView's name with @p name.
void Aidge::GraphView::setName(const std::string &name) {
  mName = name;
}
void Aidge::GraphView::save(std::string path, bool verbose) const {
FILE *fp = std::fopen((path + ".mmd").c_str(), "w");
std::fprintf(fp,
"%%%%{init: {'flowchart': { 'curve': 'monotoneY'}, "
"'fontFamily': 'Verdana' } }%%%%\nflowchart TB\n\n");
std::map<const std::string, std::size_t> typeCounter;
std::map<std::shared_ptr<Node>, std::string> namePtrTable;
// Start by creating every node
for (const std::shared_ptr<Node> &node_ptr : mNodes) {
const std::string currentType = node_ptr->type();
if (typeCounter.find(currentType) == typeCounter.end())
typeCounter[currentType] = 0;
++typeCounter[currentType];
std::string givenName =
(node_ptr->name().empty())
? "<em>" + currentType + "#" + std::to_string(typeCounter[currentType]) + "</em>"
: node_ptr->name() + " <sub><em>" + currentType + "</em></sub>";
namePtrTable[node_ptr] =
(currentType + "_" + std::to_string(typeCounter[currentType]));
if (node_ptr == mRootNode) {
std::fprintf(fp, "%s(%s):::rootCls\n", namePtrTable[node_ptr].c_str(),
givenName.c_str());
}
else {
std::fprintf(fp, "%s(%s)\n", namePtrTable[node_ptr].c_str(),
givenName.c_str());
}
}
// Write every link
for (const std::shared_ptr<Node> &node_ptr : mNodes) {
IOIndex_t outputIdx = 0;
for (auto childs : node_ptr->getOrderedChildren()) {
for (auto child : childs) {
if (child != nullptr) {
IOIndex_t inputIdx = 0;
for (auto parent : child->inputs()) {
if (parent.first == node_ptr && parent.second == outputIdx) {
if (mNodes.find(child) != mNodes.end()) {
std::fprintf(fp, "%s-->|%u→%u|%s\n", namePtrTable[node_ptr].c_str(),
outputIdx, inputIdx, namePtrTable[child].c_str());
}
else if (verbose) {
std::fprintf(fp, "%s-->|%u→%u|%p:::externalCls\n", namePtrTable[node_ptr].c_str(),
outputIdx, inputIdx, static_cast<void*>(child.get()));
}
break;
}
++inputIdx;
}
}
}
++outputIdx;
}
}
size_t inputIdx = 0;
for (auto input : mInputNodes) {
std::fprintf(fp, "input%lu((in#%lu)):::inputCls-->|→%u|%s\n", inputIdx, inputIdx,
input.second, namePtrTable[input.first].c_str());
++inputIdx;
}
size_t outputIdx = 0;
for (auto output : mOutputNodes) {
std::fprintf(fp, "%s-->|%u→|output%lu((out#%lu)):::outputCls\n",
namePtrTable[output.first].c_str(), output.second,
outputIdx, outputIdx);
++outputIdx;
}
std::fprintf(fp, "classDef inputCls fill:#afa\n");
std::fprintf(fp, "classDef outputCls fill:#ffa\n");
std::fprintf(fp, "classDef externalCls fill:#ccc\n");
std::fprintf(fp, "classDef rootCls stroke:#f00\n");
if (verbose) {
for (const auto &c : typeCounter) {
std::printf("%s - %zu\n", c.first.c_str(), c.second);
}
}
std::fprintf(fp, "\n");
std::fclose(fp);
}
///////////////////////////////////////////////////////
// TENSOR MANAGEMENT
///////////////////////////////////////////////////////
void Aidge::GraphView::setOrderedInputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& inputs) {
AIDGE_ASSERT(inputs.size() <= mInputNodes.size(), "too many specified number of inputs");
std::vector<std::pair<NodePtr, IOIndex_t>> ignoredInputs(mInputNodes);
for (auto input : inputs) {
auto it = std::find(ignoredInputs.begin(), ignoredInputs.end(), input);
AIDGE_ASSERT(it != ignoredInputs.end(), "unknown or duplicate input");
ignoredInputs.erase(it);
}
mInputNodes = inputs;
mInputNodes.insert(mInputNodes.end(), ignoredInputs.begin(), ignoredInputs.end());
}
void Aidge::GraphView::setOrderedOutputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& outputs) {
AIDGE_ASSERT(outputs.size() <= mOutputNodes.size(), "too many specified number of outputs");
std::vector<std::pair<NodePtr, IOIndex_t>> ignoredOutputs(mOutputNodes);
for (auto output : outputs) {
auto it = std::find(ignoredOutputs.begin(), ignoredOutputs.end(), output);
AIDGE_ASSERT(it != ignoredOutputs.end(), "unknown or duplicate output");
ignoredOutputs.erase(it);
}
mOutputNodes = outputs;
mOutputNodes.insert(mOutputNodes.end(), ignoredOutputs.begin(), ignoredOutputs.end());
}
/**
 * @brief Count the data inputs of the input Nodes that are fed from outside
 * this GraphView, or that are not connected at all.
 */
Aidge::IOIndex_t Aidge::GraphView::getNbDataInputs() const {
  IOIndex_t nbDataInput = 0;
  for (const std::shared_ptr<Node>& inNode : inputNodes()) {
    // nbDataInputs() alone is not enough: some data inputs of an input Node
    // may already be connected inside the GraphView and thus are not inputs
    // of the GraphView seen from outside.
    for (const auto& input : inNode->dataInputs()) {
      const bool fedFromOutside = !input.first || (mNodes.count(input.first) == 0);
      if (fedFromOutside) {
        ++nbDataInput;
      }
    }
  }
  return nbDataInput;
}
/**
 * @brief Sum of the free data inputs of every input Node.
 *
 * A free input inside the GraphView is logically also a free input from
 * outside the GraphView, so the per-node counts are simply added up.
 */
Aidge::IOIndex_t Aidge::GraphView::getNbFreeDataInputs() const {
  const auto entryNodes = inputNodes();
  return std::accumulate(entryNodes.begin(), entryNodes.end(), IOIndex_t(0),
      [](IOIndex_t total, const std::shared_ptr<Node>& node) -> IOIndex_t {
        return static_cast<IOIndex_t>(total + node->getNbFreeDataInputs());
      });
}
/**
 * @brief Data-input connections of the input Nodes whose parent is missing
 * (nullptr) or lies outside this GraphView.
 */
std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>
Aidge::GraphView::dataInputs() const {
  std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res;
  const auto isExternal = [this](const std::pair<std::shared_ptr<Node>, IOIndex_t>& connection) {
    return connection.first == nullptr || mNodes.find(connection.first) == mNodes.end();
  };
  for (const std::shared_ptr<Node>& inputNode : inputNodes()) {
    const auto nodeInputs = inputNode->dataInputs();
    std::copy_if(nodeInputs.begin(), nodeInputs.end(), std::back_inserter(res), isExternal);
  }
  return res;
}
/**
 * @brief All input connections (data and parameters) of the input Nodes
 * whose parent is missing (nullptr) or lies outside this GraphView.
 */
std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>
Aidge::GraphView::inputs() const {
  std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res;
  const auto isExternal = [this](const std::pair<std::shared_ptr<Node>, IOIndex_t>& connection) {
    return connection.first == nullptr || mNodes.find(connection.first) == mNodes.end();
  };
  for (const std::shared_ptr<Node>& inputNode : inputNodes()) {
    const auto nodeInputs = inputNode->inputs();
    std::copy_if(nodeInputs.begin(), nodeInputs.end(), std::back_inserter(res), isExternal);
  }
  return res;
}
/**
 * @brief Inputs of the registered Node called @p name.
 * Lookup uses std::map::at, so an unknown name throws std::out_of_range.
 */
std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>
Aidge::GraphView::inputs(std::string name) const {
  const std::shared_ptr<Node>& namedNode = mNodeRegistry.at(name);
  return namedNode->inputs();
}
/**
 * @brief Prepare the graph for execution.
 *
 * Sets the backend, then the data type of every operator, then propagates
 * the tensor dimensions through the graph.
 */
void Aidge::GraphView::compile(const std::string& backend, const Aidge::DataType datatype) {
  // 1. Backend. TODO: add Backend attribute to Operator
  setBackend(backend);
  // 2. Data type. TODO: manage Datatype attribute in OperatorImpl
  setDataType(datatype);
  // 3. Data format: nothing is done yet.
  //    TODO: check actual parent output data format and the needed one. Add a Transpose Operator if necessary
  // 4. Dimensions are forwarded last, after backend and data type are set.
  forwardDims();
}
void Aidge::GraphView::forwardDims() {
// setInputs
// Link every tensor to the right pointer
// following parent - children informations
for (std::shared_ptr<Node> nodePtr : getNodes()) {
for (IOIndex_t i = 0; i < nodePtr->nbInputs(); ++i) {
// assess if the input was not already set and is a Tensor then link it to parent output
std::pair<std::shared_ptr<Node>, IOIndex_t> inputI = nodePtr->input(i);
if (inputI.first) {
if ( std::static_pointer_cast<Tensor>(nodePtr->getOperator()->getRawInput(i)) != inputI.first->getOperator()->getRawOutput(inputI.second)) {
if ((strcmp(nodePtr->getOperator()->getRawInput(i)->type(), Tensor::Type) == 0) && (strcmp(inputI.first->getOperator()->getRawOutput(inputI.second)->type(), Tensor::Type)==0)) {
// assert provided Data is of "Tensor" type
nodePtr->getOperator()->associateInput(i, inputI.first->getOperator()->getRawOutput(inputI.second));
}
else {
assert(false && "Non-tensor entries not handled yet.\n");
}
}
} else
{
assert(!std::static_pointer_cast<Tensor>(nodePtr->getOperator()->getRawInput(i))->empty());
}
}
}
// Compute dimensions of every node
_forwardDims(inputNodes());
}
void Aidge::GraphView::_forwardDims(std::set<std::shared_ptr<Node>> listNodes) {
// TODO: support multi-inputs/outputs
std::set<std::shared_ptr<Node>> nextList = std::set<std::shared_ptr<Node>>();
for (std::shared_ptr<Node> nodePtr : listNodes) {
if (nodePtr->getOperator()->operatorType() == OperatorType::Tensor) {
const auto op = std::static_pointer_cast<OperatorTensor>(nodePtr->getOperator());
if (!op->outputDimsForwarded()) {
op->computeOutputDims();
}
if (!op->outputDimsForwarded()) { // try to compute output dimensions again later
nextList.insert(nodePtr);
} else { // compute output dimensions of children
std::set<std::shared_ptr<Node>> children = nodePtr->getChildren();
nextList.insert(children.begin(), children.end());
}
}
}
if (nextList.empty()) {
for (std::shared_ptr<Node> nodePtr : getNodes()) {
if (nodePtr->getOperator()->operatorType() == OperatorType::Tensor) {
if (!std::static_pointer_cast<OperatorTensor>(nodePtr->getOperator())->outputDimsForwarded()) {
nextList.insert(nodePtr);
}
}
}
}
if (!nextList.empty()) {
_forwardDims(nextList);
}
}
void Aidge::GraphView::setBackend(const std::string &backend) {
for (auto node : getNodes()) {
node->getOperator()->setBackend(backend);
}
}
/// @brief Forward @p datatype to the operator of every Node of the view.
void Aidge::GraphView::setDataType(const Aidge::DataType &datatype) {
  for (const std::shared_ptr<Node>& member : getNodes()) {
    member->getOperator()->setDataType(datatype);
  }
}
/**
 * @brief For each output position of each output Node, the connections that
 * leave the GraphView.
 *
 * One inner vector is produced per output position. It may be empty when all
 * of that output's children are inside the view.
 */
std::vector<
    std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>>
Aidge::GraphView::outputs() const {
  std::vector<std::vector<std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t>>>
      outsideOutputs;
  const auto leavesView = [this](const std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t>& connection) {
    return mNodes.find(connection.first) == mNodes.end();
  };
  for (const std::shared_ptr<Node>& outputNode : outputNodes()) {
    for (const auto& outputPos : outputNode->outputs()) {
      std::vector<std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t>> kept;
      std::copy_if(outputPos.begin(), outputPos.end(), std::back_inserter(kept), leavesView);
      outsideOutputs.push_back(std::move(kept));
    }
  }
  return outsideOutputs;
}
/**
 * @brief Outputs of the registered Node called @p nodeName.
 * Lookup uses std::map::at, so an unknown name throws std::out_of_range.
 */
std::vector<
    std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>>
Aidge::GraphView::outputs(std::string nodeName) const {
  const std::shared_ptr<Node>& namedNode = mNodeRegistry.at(nodeName);
  return namedNode->outputs();
}
/// @brief Placeholder: not supported yet. Only reports it on stdout.
void Aidge::GraphView::setInputId(Aidge::IOIndex_t /*inID*/,
                                  Aidge::IOIndex_t /*newNodeOutID*/) {
  printf("Not implemented yet.\n");
}
void Aidge::GraphView::add(std::shared_ptr<Node> node, bool includeLearnableParam) {
// first node to be added to the graph is the root node by default
if (mRootNode == nullptr) {
mRootNode = node;
}
// add to the GraphView nodes
node->addView(shared_from_this());
mNodes.insert(node);
if (!(node->name()).empty())
mNodeRegistry.insert(std::make_pair(node->name(), node));
// check if the node is an input/output node
updateInputsOutputsNew(node);
// add learnable parameters to the graph
if (includeLearnableParam) {
for (IOIndex_t i = node->nbData(); i < node->nbInputs(); ++i) {
std::shared_ptr<Node> parentNode = node->getParent(static_cast<IOIndex_t>(i));
if (parentNode) {
parentNode->addView(shared_from_this());
mNodes.insert(parentNode);
if (!(parentNode->name()).empty())
mNodeRegistry.insert(std::make_pair(parentNode->name(), parentNode));
// check if the parentNode is an input/output node
updateInputsOutputsNew(parentNode);
}
}
}
}
bool Aidge::GraphView::add(std::set<std::shared_ptr<Node>> otherNodes, bool includeLearnableParam) {
if (otherNodes.empty()) {
return true;
}
bool orderUnicity = true;
// List only the nodes that are not already present in current graph
std::set<NodePtr> nodesToAdd;
std::set_difference(otherNodes.begin(), otherNodes.end(), mNodes.begin(), mNodes.end(), std::inserter(nodesToAdd, nodesToAdd.begin()));
// List the nodes to rank, initially all the nodes in the GraphView
std::set<NodePtr> nodesToRank(mNodes);
nodesToRank.insert(nodesToAdd.begin(), nodesToAdd.end());
std::vector<NodePtr> rankedNodesToAdd;
if (mRootNode == nullptr) {
std::set<NodePtr> noParentNodes;
// If no root node is defined, check nodes without parents
for (auto node : nodesToRank) {
bool noParent = true;
for (auto parent : node->getParents()) {
if (parent != nullptr && nodesToRank.find(parent) != nodesToRank.end()) {
noParent = false;
break;
}
}
if (noParent) {
noParentNodes.insert(node);
}
}
// Take the first one found (this is an arbitrary choice)
mRootNode = *noParentNodes.begin();
if (noParentNodes.size() > 1) {
// If there is more than one, order unicity cannot be garanteed!
orderUnicity = false;
}
rankedNodesToAdd.push_back(mRootNode);
}
nodesToRank.erase(mRootNode);
std::vector<NodePtr> rankedNodes;
rankedNodes.push_back(mRootNode);
for (size_t curNodeIdx = 0; curNodeIdx < rankedNodes.size(); ++curNodeIdx) {
NodePtr curNode = rankedNodes[curNodeIdx];
for (auto childs : curNode->getOrderedChildren()) {
for (auto child : childs) {
if (nodesToRank.find(child) != nodesToRank.end()) {
rankedNodes.push_back(child);
nodesToRank.erase(child);
if (nodesToAdd.find(child) != nodesToAdd.end()) {
rankedNodesToAdd.push_back(child);
nodesToAdd.erase(child);
}
}
}
}
for (auto parent : curNode->getParents()) {
if (nodesToRank.find(parent) != nodesToRank.end()) {
rankedNodes.push_back(parent);
nodesToRank.erase(parent);
if (nodesToAdd.find(parent) != nodesToAdd.end()) {
rankedNodesToAdd.push_back(parent);
nodesToAdd.erase(parent);
}
}
}
}
if (!nodesToAdd.empty()) {
// There are remaining nodes without path to the root node
orderUnicity = false;
while (!nodesToAdd.empty()) {
const auto it = nodesToAdd.begin();
rankedNodesToAdd.push_back(*it);
nodesToAdd.erase(it);
}
}
for (auto node_ptr : rankedNodesToAdd) {
add(node_ptr, includeLearnableParam);
}
return orderUnicity;
}
bool Aidge::GraphView::add(std::pair<NodePtr, std::set<NodePtr>> nodes, bool includeLearnableParam) {
if (nodes.first != nullptr) {
mRootNode = nodes.first;
add(nodes.first, includeLearnableParam);
}
return add(nodes.second, includeLearnableParam);
}
/**
 * @brief Add every Node of @p graph (without pulling extra parameters).
 * @p graph's root Node is inherited only when this view has none yet.
 */
bool Aidge::GraphView::add(std::shared_ptr<GraphView> graph) {
  if (!mRootNode) {
    mRootNode = graph->getRootNode();
  }
  return add(graph->getNodes(), false);
}
/**
 * @brief Connect @p toOtherNode as a child of @p fromOutNode, then add it to
 * the view.
 *
 * When @p fromOutNode is null, the view must have exactly one output Node,
 * which is used instead.
 */
void Aidge::GraphView::addChild(std::shared_ptr<Node> toOtherNode,
                                std::shared_ptr<Node> fromOutNode,
                                const Aidge::IOIndex_t fromTensor,
                                Aidge::IOIndex_t toTensor) {
  if (!fromOutNode) {
    assert((outputNodes().size() == 1U) &&
           "Must specify an outputNode or have only one.");
    fromOutNode = *(outputNodes().begin());
  } else {
    assert(inView(fromOutNode) && "Output Node not found in the GraphView.");
  }
  fromOutNode->addChild(toOtherNode, fromTensor, toTensor);
  add(toOtherNode);
}
/**
 * @brief Connect an output of this view to an input of @p toOtherView, then
 * merge @p toOtherView into this view.
 *
 * A null Node in @p fromOutNode or @p toNode means "the unique output/input
 * Node of the corresponding view".
 */
void Aidge::GraphView::addChild(
    std::shared_ptr<GraphView> toOtherView,
    std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t> fromOutNode,
    std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t> toNode) {
  // Resolve / validate the source output Node
  if (fromOutNode.first) {
    assert(inView(fromOutNode.first));
  } else {
    assert(outputNodes().size() == 1U &&
           "If no output node is provided, the graph should have only one to "
           "make the choice explicit.");
    fromOutNode.first = *(outputNodes().begin());
  }

  // Resolve / validate the destination input Node
  if (toNode.first) {
    assert(toOtherView->inView(toNode.first));
  } else {
    assert(toOtherView->inputNodes().size() == 1U &&
           "If no intput node is provided, the other graph should have only "
           "one to make the choice explicit.");
    toNode.first = *(toOtherView->inputNodes().begin());
  }

  // Tensor compatibility is checked by Node::addChild
  fromOutNode.first->addChild(toNode.first, fromOutNode.second, toNode.second);

  // Once linked, the other graph becomes part of this one
  add(toOtherView);
}
/**
 * @brief Union of the parents of every input Node of the view.
 * Unconnected inputs may contribute a nullptr entry, as in Node::getParents().
 */
std::set<std::shared_ptr<Aidge::Node>> Aidge::GraphView::getParents() const {
  // TODO: choose if we return a set or a vector
  std::set<std::shared_ptr<Node>> parents;
  for (const std::shared_ptr<Node>& inputNode : inputNodes()) {
    // Keep one copy: calling getParents() twice (once for begin(), once for
    // end()) yields iterators into two different temporaries, which is UB.
    const auto nodeParents = inputNode->getParents();
    parents.insert(nodeParents.begin(), nodeParents.end());
  }
  return parents;
}
/**
 * @brief Parents of the registered Node called @p nodeName.
 * An unknown name prints a message and terminates the process with exit(-1).
 */
std::vector<std::shared_ptr<Aidge::Node>> Aidge::GraphView::getParents(const std::string nodeName) const {
  const auto found = mNodeRegistry.find(nodeName);
  if (found != mNodeRegistry.end()) {
    return found->second->getParents();
  }
  printf("No such node a %s in %s graph.\n", nodeName.c_str(), name().c_str());
  exit(-1);
}
/// @brief Parents of each input Node, one vector per input Node.
std::vector<std::vector<std::shared_ptr<Aidge::Node>>>
Aidge::GraphView::getOrderedParents() const {
  const auto entryNodes = inputNodes();
  std::vector<std::vector<std::shared_ptr<Node>>> parents;
  parents.reserve(entryNodes.size());
  std::transform(entryNodes.begin(), entryNodes.end(), std::back_inserter(parents),
                 [](const std::shared_ptr<Node>& node) { return node->getParents(); });
  return parents;
}
/// @brief Union of the children of every output Node of the view.
std::set<std::shared_ptr<Aidge::Node>> Aidge::GraphView::getChildren() const {
  std::set<std::shared_ptr<Node>> children;
  for (const std::shared_ptr<Node>& outputNode : outputNodes()) {
    // Keep one copy: calling getChildren() twice (once for begin(), once for
    // end()) yields iterators into two different temporaries, which is UB.
    const auto nodeChildren = outputNode->getChildren();
    children.insert(nodeChildren.begin(), nodeChildren.end());
  }
  return children;
}
/**
 * @brief Ordered children (one vector per output) of the registered Node
 * called @p nodeName.
 * An unknown name prints a message and terminates the process with exit(-1).
 */
std::vector<std::vector<std::shared_ptr<Aidge::Node>>>
Aidge::GraphView::getChildren(const std::string nodeName) const {
  const auto found = mNodeRegistry.find(nodeName);
  if (found != mNodeRegistry.end()) {
    return found->second->getOrderedChildren();
  }
  printf("No such node a %s in %s graph.\n", nodeName.c_str(),
         name().c_str());
  exit(-1);
}
/**
 * @brief Children of @p otherNode, which must belong to the view.
 * Otherwise a message is printed and the process exits with exit(-1).
 */
std::set<std::shared_ptr<Aidge::Node>>
Aidge::GraphView::getChildren(const std::shared_ptr<Node> otherNode) const {
  if (mNodes.count(otherNode) == 0) {
    printf("No such node in graph.\n");
    exit(-1);
  }
  return otherNode->getChildren();
}
/**
 * @brief Registered Node called @p nodeName.
 * An unknown name prints a message and terminates the process with exit(-1).
 */
std::shared_ptr<Aidge::Node>
Aidge::GraphView::getNode(const std::string& nodeName) const {
  const auto found = mNodeRegistry.find(nodeName);
  if (found == mNodeRegistry.end()) {
    printf("No Node named %s in the current GraphView.\n", nodeName.c_str());
    exit(-1);
  }
  return found->second;
}
void Aidge::GraphView::remove(std::shared_ptr<Node> nodePtr, bool includeLearnableParam) {
// remove learnable params
if (includeLearnableParam) {
for (IOIndex_t i = nodePtr->nbData(); i < nodePtr->nbInputs(); ++i) {
auto inputI = nodePtr->input(i);
if (inputI.first != nullptr) {
bool removeNode = true;
for (const auto& parentOutput : inputI.first->outputs()) {
for (const auto& childOfParentOutput : parentOutput) {
// only remove the learnable parameter if not related to any other Node in the GraphView
if (childOfParentOutput.first != nodePtr) {
removeNode = false;
break;
}
}
}
if (removeNode) {
// assert Learnable Parameter in the GraphView scope
if (mNodes.find(inputI.first) != mNodes.end()) {
mNodes.erase(inputI.first);
inputI.first->removeView(shared_from_this());
}
if (!inputI.first->name().empty()) { mNodeRegistry.erase(inputI.first->name()); }
// check if the node was an input/output node
updateInputsOutputsDelete(inputI.first);
}
}
}
}
if (mNodes.find(nodePtr) != mNodes.end()) {
mNodes.erase(nodePtr);
nodePtr->removeView(shared_from_this());
// check if the nodePtr was an input/output node
updateInputsOutputsDelete(nodePtr);
}
if (!nodePtr->name().empty()) { mNodeRegistry.erase(nodePtr->name()); }
}
/// @brief Placeholder: not supported yet. Reports it and always returns false.
bool Aidge::GraphView::swap(Node & /*node*/, Node & /*otherNode*/) {
  printf("Swap() not implementated yet. Return false.\n");
  const bool swapped = false;
  return swapped;
}
/// @brief Placeholder: not supported yet. Only reports it on stdout.
void Aidge::GraphView::link(std::string /*name1_inID*/,
                            std::string /*name2_outID*/) {
  printf("Not implemented yet.\n");
}
/**
 * @brief Insert @p newParentNode on input @p childInputTensorIdx of
 * @p childNode, between the child and its current parent, then add
 * @p newParentNode to the view.
 *
 * If that input had a parent, the old link is removed. The old parent then
 * feeds input @p newParentInputTensorIdx of @p newParentNode. If the input
 * was unconnected, @p newParentNode is simply connected to the child.
 * Previously this dereferenced a null parent and crashed.
 */
void Aidge::GraphView::insertParent(NodePtr childNode,
                  NodePtr newParentNode,
                  IOIndex_t childInputTensorIdx,
                  IOIndex_t newParentInputTensorIdx,
                  IOIndex_t newParentOutputTensorIdx){
  const NodePtr currentParentNode = childNode->getParent(childInputTensorIdx);
  if (currentParentNode) {
    const IOIndex_t currentParentOutputTensorIdx = childNode->input(childInputTensorIdx).second;
    // Remove child from current parent & current Parent from child
    currentParentNode->removeChild(childNode, currentParentOutputTensorIdx);
    // The former parent now feeds the inserted node
    currentParentNode->addChild(newParentNode, currentParentOutputTensorIdx, newParentInputTensorIdx);
  }
  newParentNode->addChild(childNode, newParentOutputTensorIdx, childInputTensorIdx);

  add(newParentNode);
}
/**
 * @brief Replace the Nodes of @p oldNodes by those of @p newNodes in every
 * GraphView that contains all of @p oldNodes.
 *
 * Requirements checked below (returns false if any fails):
 *  - the old set forms a graph with at least one input Node and exactly one
 *    output Node, and that output Node has exactly one output;
 *  - if @p newNodes is non-empty, it forms a graph with at least one input
 *    Node and exactly one output Node;
 *  - when the old set has several input Nodes, all their inputs coming from
 *    outside the set point to the same (Node, output index).
 * On success, the old output connections are copied onto the new output Node
 * and the external input is connected to the new input Nodes.
 * If @p newNodes is empty, the external input Node takes over the output
 * connections. @p newNodes is then added to the common views.
 * NOTE(review): from the visible code, the old Nodes are disconnected
 * (resetConnections) but not explicitly removed from the common GraphViews;
 * confirm whether resetConnections handles this.
 */
bool Aidge::GraphView::replace(const std::set<Aidge::NodePtr>& oldNodes, const std::set<Aidge::NodePtr>& newNodes) {
// TODO: handle case where an oldNodes parameter does not come from a Producer but another Node (not included in oldNodes)
// How to distinguish it from data input?
// TODO: Parameter Tensors could be identified with their dimensions
// TODO: Take GraphView as input parameters since new Nodes should be connected whatever.
// It also avoids specifying each producer since they are automatically included
// Temporary views, only used to compute the input/output Nodes of each set
auto oldG = std::make_shared<GraphView>("oldG");
oldG->add(oldNodes, false);
auto newG = std::make_shared<GraphView>("newG");
newG->add(newNodes, false);
if ((oldG->inputNodes().size() == 0) || (oldG->outputNodes().size() != 1)) {
return false;
}
if (!(newNodes.empty()) && ((newG->inputNodes().size() == 0) ||
(newG->outputNodes().size() != 1))) {
return false;
}
// there is at least one inputNode in the old/new GraphView
std::shared_ptr<Node> firstPreviousInputNode = (*(oldG->inputNodes()).begin());
std::shared_ptr<Node> firstPreviousOutputNode = (*(oldG->outputNodes()).begin());
// find Node to link to new input Node
//compute number of input for firstPreviousInputNode not in oldNodes set
// NOTE(review): an input whose parent is nullptr (if unconnected inputs are
// represented that way) is also counted as external here.
std::size_t nbExternalInputs = 0;
std::shared_ptr<Node> externalInput = nullptr;
IOIndex_t externalInputId = gk_IODefaultIndex;
for (const auto& input : firstPreviousInputNode->inputs()) {
if (oldNodes.find(input.first) == oldNodes.end()) { // Node connected to another Node outside of oldG
nbExternalInputs++;
externalInput = input.first;
externalInputId = input.second;
}
}
if (nbExternalInputs > 1) {
// NOTE(review): the argument is a string literal, which is always truthy,
// so this check presumably never fires (the external input kept is simply
// the last one found). Confirm against the AIDGE_INTERNAL_ASSERT definition.
AIDGE_INTERNAL_ASSERT("To many input to link for oldNodes set");
}
if (oldG->inputNodes().size() > 1){
// one or no input has been identified. Checking every input points to the same source
for (const auto& previousInputNode : oldG->inputNodes()) {
for (const auto& input : previousInputNode->inputs()) {
if (oldNodes.find(input.first) == oldNodes.end()) {
if ( (externalInput != input.first) || (externalInputId != input.second) ) {
return false; // an inputNode points to an external Node different from the registered one
}
}
}
}
}
if (firstPreviousOutputNode->nbOutputs() != 1) {
return false;
}
// find Node to replicate output connections
// (with no new Node, the external input is directly wired to the old consumers)
std::shared_ptr<Node> newOutputNode = newNodes.empty() ? externalInput : *(newG->outputNodes().begin());
// Saved before resetConnections() below clears the old output connections
auto copyOutputs = firstPreviousOutputNode->outputs();
// manage Views for newNodes
// only keep common views to each node for the new set
std::set<std::shared_ptr<GraphView>> commonGraphViews = (*oldNodes.begin())->views();
for (const auto& nodePtr : oldNodes) {
const auto nodeView = nodePtr->views();
std::set<std::shared_ptr<GraphView>> intersection;
std::set_intersection(commonGraphViews.begin(), commonGraphViews.end(),
nodeView.begin(), nodeView.end(),
std::inserter(intersection, intersection.begin()));
commonGraphViews = intersection;
}
// The temporary views must not receive the new Nodes
commonGraphViews.erase(oldG);
commonGraphViews.erase(newG);
// clean Nodes to replace
// Do not include common Nodes to avoid cleaning Producers linked to newNodes
std::set<std::shared_ptr<Node>> nodesToClean;
std::set_difference(oldNodes.begin(), oldNodes.end(),
newNodes.begin(), newNodes.end(),
std::inserter(nodesToClean, nodesToClean.begin()));
for (auto& nodePtr : nodesToClean) { nodePtr->resetConnections(true); }
// copy output connections
if (newOutputNode) {
for (IOIndex_t o = 0; o < firstPreviousOutputNode->nbOutputs(); ++o) {
auto outputPairs = copyOutputs[o];
for (const auto& onePair : outputPairs) {
newOutputNode->addChild(onePair.first, o, onePair.second);
}
}
}
// copy input connections
// every input of a new input Node whose parent is outside newNodes is fed by the external input
if (!newNodes.empty() && externalInput) {
for (const auto& newInputNode : newG->inputNodes()) {
IOIndex_t inputId = 0;
for (const auto& input : newInputNode->inputs()) {
if (newNodes.find(input.first) == newNodes.end()) {
externalInput->addChild(newInputNode, externalInputId, inputId);
}
inputId++;
}
}
}
// insert new Nodes in the right GraphViews
for (const auto& graphPtr : commonGraphViews) {
graphPtr->add(newNodes, false);
if (newNodes.empty()) {
// TODO: FIXME: this function should not be called anymore!
graphPtr->updateInputsOutputsNodes_DEPRECATED();
}
}
// Detach the Nodes from the temporary views
for (const auto& node : oldNodes) {
node->removeView(oldG);
}
for (const auto& node : newNodes) {
node->removeView(newG);
}
return true;
}
/**
 * @brief Update mInputNodes / mOutputNodes after @p newNode was inserted in
 * mNodes (or after its connections changed while already in the view).
 *
 * Inputs:
 *  1. inputs of Nodes of the view fed by @p newNode are removed;
 *  2. inputs of @p newNode whose parent is nullptr or outside the view are
 *     inserted, following @p newNode's input order, at the position of the
 *     first input removed in step 1 (or at the end if none was removed);
 *  3. inputs of @p newNode now fed from inside the view are removed.
 * Outputs, symmetrically:
 *  4. outputs of parents of @p newNode (in the view) that feed @p newNode
 *     are removed;
 *  5. outputs of @p newNode with no child inside the view are inserted at
 *     the position of the first output removed in step 4.
 * Entries are identified by (Node, input/output index) pairs.
 */
void Aidge::GraphView::updateInputsOutputsNew(std::shared_ptr<Node> newNode) {
// Can be called several times with the same node, e.g. when addChild() is
// called on a node already part of the GraphView. In this case, inputs/outputs
// need to be updated!
std::vector<std::pair<NodePtr, IOIndex_t>>::iterator newInputsInsertionPoint = mInputNodes.end();
// Remove inputs that are not input anymore because connected to newNode
for (auto orderedChilds : newNode->getOrderedChildren()) {
for (auto ch_ptr : orderedChilds) {
// Check that newNode child is in current GraphView
if (mNodes.find(ch_ptr) != mNodes.end()) {
IOIndex_t inputIdx = 0;
for (const std::shared_ptr<Node>& pa_ptr : ch_ptr->getParents()) {
// If newNode is connected to it
if (pa_ptr == newNode) {
const auto val = std::make_pair(ch_ptr, inputIdx);
const auto iter = std::find(mInputNodes.begin(), mInputNodes.end(), val);
// Check that it was not already the case (if node UPDATE)
if (iter != mInputNodes.end()) {
// The first old (removed) input becomes the insertion point for newNode GraphView inputs
// (iter at or before the current insertion point: track the earliest position;
// erase() returns a valid iterator to the element after the removed one)
if (std::distance(newInputsInsertionPoint, iter) <= 0) {
newInputsInsertionPoint = mInputNodes.erase(iter);
}
else {
// iter is after the insertion point, so erasing it leaves the insertion point valid
mInputNodes.erase(iter);
}
}
}
++inputIdx;
}
}
}
}
// Check if node inputs are inputs for the GraphView and add them to the input list if so
// Inputs addition order follows node inputs order
// Inputs are inserted at the position of the first input removed
IOIndex_t inputIdx = 0U;
for (const std::shared_ptr<Node>& pa_ptr : newNode->getParents()) {
if ((pa_ptr == nullptr) ||
(mNodes.find(pa_ptr) ==
mNodes.end())) { // Parent doesn't exist || Parent not in the graph
const auto val = std::make_pair(newNode, inputIdx);
if (std::find(mInputNodes.begin(), mInputNodes.end(), val) == mInputNodes.end()) {
// insert() may reallocate: refresh the insertion point from its return value,
// then move past the inserted element to keep the node input order
newInputsInsertionPoint = mInputNodes.insert(newInputsInsertionPoint, val);
newInputsInsertionPoint = std::next(newInputsInsertionPoint);
}
}
++inputIdx;
}
// (if node UPDATE)
// newNode may already exists in the graph and may have been updated
// Check and remove inputs that are not inputs anymore
inputIdx = 0U;
for (const std::shared_ptr<Node>& pa_ptr : newNode->getParents()) {
if ((pa_ptr != nullptr) &&
(mNodes.find(pa_ptr) !=
mNodes.end())) {
const auto val = std::make_pair(newNode, inputIdx);
auto it = std::find(mInputNodes.begin(), mInputNodes.end(), val);
if (it != mInputNodes.end()) {
mInputNodes.erase(it);
}
}
++inputIdx;
}
std::vector<std::pair<NodePtr, IOIndex_t>>::iterator newOutputsInsertionPoint = mOutputNodes.end();
// Remove outputs that are not output anymore because connected to newNode
for (const std::shared_ptr<Node>& parent : newNode->getParents()) {
// Check that newNode parent is in current GraphView
if (mNodes.find(parent) != mNodes.end()) {
IOIndex_t outputIdx = 0;
for (auto orderedChilds : parent->getOrderedChildren()) {
for (auto ch_ptr : orderedChilds) {
// If newNode is connected to it
if (ch_ptr == newNode) {
const auto val = std::make_pair(parent, outputIdx);
const auto iter = std::find(mOutputNodes.begin(), mOutputNodes.end(), val);
if (iter != mOutputNodes.end()) {
// The first old (removed) output becomes the insertion point for newNode GraphView outputs
if (std::distance(newOutputsInsertionPoint, iter) <= 0) {
newOutputsInsertionPoint = mOutputNodes.erase(iter);
}
else {
mOutputNodes.erase(iter);
}
}
}
}
++outputIdx;
}
}
}
// Check if node outputs are outputs for the GraphView and add them to the output list if so
IOIndex_t outputIdx = 0;
for (auto orderedChilds : newNode->getOrderedChildren()) {
// An output is a GraphView output when none of its children is in the view
bool noInsideConnection = true;
for (auto ch_ptr : orderedChilds) {
if (mNodes.find(ch_ptr) != mNodes.end()) {
noInsideConnection = false;
break;
}
}
if (noInsideConnection) {
const auto val = std::make_pair(newNode, outputIdx);
// Output may be already be present (see addChild() with a node already in GraphView)
if (std::find(mOutputNodes.begin(), mOutputNodes.end(), val) == mOutputNodes.end()) {
newOutputsInsertionPoint = mOutputNodes.insert(newOutputsInsertionPoint, val);
newOutputsInsertionPoint = std::next(newOutputsInsertionPoint);
}
}
++outputIdx;
}
}
/**
 * @brief Update mInputNodes / mOutputNodes after @p deletedNode was erased
 * from mNodes.
 *
 * Inputs:
 *  1. every (deletedNode, i) entry is removed;
 *  2. inputs of children (still in the view) that were fed by @p deletedNode
 *     become GraphView inputs. They are inserted, following @p deletedNode's
 *     output order, at the position of the first entry removed in step 1.
 * Outputs, symmetrically:
 *  3. every (deletedNode, o) entry is removed;
 *  4. outputs of parents (still in the view) that no longer have any child
 *     in the view become GraphView outputs. They are inserted at the position
 *     of the first entry removed in step 3.
 * Expected to be called after @p deletedNode was removed from mNodes: the
 * membership checks below rely on it being absent (assumption based on the
 * callers visible in remove()).
 */
void Aidge::GraphView::updateInputsOutputsDelete(std::shared_ptr<Node> deletedNode) {
std::vector<std::pair<NodePtr, IOIndex_t>>::iterator newInputsInsertionPoint = mInputNodes.end();
// Check if node inputs were inputs for the GraphView and remove them from the list if so
for (IOIndex_t inputIdx = 0; inputIdx < deletedNode->getParents().size(); ++inputIdx) {
const auto val = std::make_pair(deletedNode, inputIdx);
const auto iter = std::find(mInputNodes.begin(), mInputNodes.end(), val);
if (iter != mInputNodes.end()) {
// The first old (removed) input becomes the insertion point for new GraphView inputs
// (erase() returns a valid iterator to the element following the removed one)
if (std::distance(newInputsInsertionPoint, iter) <= 0) {
newInputsInsertionPoint = mInputNodes.erase(iter);
}
else {
mInputNodes.erase(iter);
}
}
}
// Add child node inputs that become GraphView input following the removal of the node
// Inputs addition order follows deletedNode outputs order
for (auto orderedChilds : deletedNode->getOrderedChildren()) {
for (auto ch_ptr : orderedChilds) {
// Check that deletedNode child is in current GraphView
if (mNodes.find(ch_ptr) != mNodes.end()) {
IOIndex_t inputIdx = 0;
for (const std::shared_ptr<Node>& pa_ptr : ch_ptr->getParents()) {
// If newNode was connected to it
if (pa_ptr == deletedNode) {
const auto val = std::make_pair(ch_ptr, inputIdx);
if (std::find(mInputNodes.begin(), mInputNodes.end(), val) == mInputNodes.end()) {
// insert() may reallocate: refresh the insertion point, then step past the new element
newInputsInsertionPoint = mInputNodes.insert(newInputsInsertionPoint, val);
newInputsInsertionPoint = std::next(newInputsInsertionPoint);
}
}
++inputIdx;
}
}
}
}
std::vector<std::pair<NodePtr, IOIndex_t>>::iterator newOutputsInsertionPoint = mOutputNodes.end();
// Check if node outputs were outputs for the GraphView and remove them from the list if so
for (IOIndex_t outputIdx = 0; outputIdx < deletedNode->getOrderedChildren().size(); ++outputIdx) {
const auto val = std::make_pair(deletedNode, outputIdx);
const auto iter = std::find(mOutputNodes.begin(), mOutputNodes.end(), val);
if (iter != mOutputNodes.end()) {
// The first old (removed) output becomes the insertion point for newNode GraphView outputs
if (std::distance(newOutputsInsertionPoint, iter) <= 0) {
newOutputsInsertionPoint = mOutputNodes.erase(iter);
}
else {
mOutputNodes.erase(iter);
}
}
}
// Add parent node outputs that become GraphView output following the removal of the node
// Outputs addition order follows deletedNode inputs order
for (const std::shared_ptr<Node>& parent : deletedNode->getParents()) {
if (parent != nullptr && mNodes.find(parent) != mNodes.end()) {
IOIndex_t outputIdx = 0;
for (auto orderedChilds : parent->getOrderedChildren()) {
// An output is a GraphView output when none of its children is in the view
bool noInsideConnection = true;
for (auto ch_ptr : orderedChilds) {
if (mNodes.find(ch_ptr) != mNodes.end()) {
noInsideConnection = false;
break;
}
}
if (noInsideConnection) {
const auto val = std::make_pair(parent, outputIdx);
if (std::find(mOutputNodes.begin(), mOutputNodes.end(), val) == mOutputNodes.end()) {
newOutputsInsertionPoint = mOutputNodes.insert(newOutputsInsertionPoint, val);
newOutputsInsertionPoint = std::next(newOutputsInsertionPoint);
}
}
++outputIdx;
}
}
}
}
void Aidge::GraphView::updateInputsOutputsNodes_DEPRECATED() {
mInputNodes.clear();
for (const std::shared_ptr<Node>& go_ptr : mNodes) {
IOIndex_t inputIdx = 0;
for (const std::shared_ptr<Node>& pa_ptr : go_ptr->getParents()) {
if ((pa_ptr == nullptr) ||
(mNodes.find(pa_ptr) ==
mNodes.end())) { // Parent doesn't exist || Parent not in the graph
mInputNodes.push_back(std::make_pair(go_ptr, inputIdx));
}
++inputIdx;
}
}
mOutputNodes.clear();
for (const std::shared_ptr<Node>& go_ptr : mNodes) {
IOIndex_t outputIdx = 0;
for (auto orderedChilds : go_ptr->getOrderedChildren()) {
bool noInsideConnection = true;
for (auto ch_ptr : orderedChilds) {
if (mNodes.find(ch_ptr) != mNodes.end()) {
noInsideConnection = false;
break;
}
}
if (noInsideConnection) {
mOutputNodes.push_back(std::make_pair(go_ptr, outputIdx));
}
++outputIdx;
}
}
}
/**
 * @brief Clone this GraphView, letting a callback clone (or drop) each node.
 *
 * @param cloneNode Called once per node of the view. It returns the node to use in the new
 *        GraphView, or nullptr to remove that node from the clone. A removed node must have
 *        at most one child and at most one data input (asserted). Connections going through
 *        it are bypassed by following its first data input backward.
 * @return A new GraphView with the same name. Its nodes are connected like the originals,
 *         and its ordered inputs/outputs are remapped from this view's ordering. Inputs and
 *         outputs that cannot be mapped to a kept node are dropped.
 */
std::shared_ptr<Aidge::GraphView> Aidge::GraphView::cloneCallback(NodePtr(*cloneNode)(NodePtr)) const {
    std::shared_ptr<GraphView> newGraph = std::make_shared<GraphView>(mName);

    // Map for old node -> new node correspondence.
    // Keys are exactly the nodes of this GraphView; a nullptr value means "removed by cloneNode()".
    std::map<NodePtr, NodePtr> oldToNewNodes;

    for (const std::shared_ptr<Node> &node_ptr : mNodes) {
        auto clonedNode = cloneNode(node_ptr);
        if (clonedNode == nullptr) {
            AIDGE_ASSERT(node_ptr->getChildren().size() <= 1, "deleted nodes in GraphView::clone() cannot have multiple children");
            AIDGE_ASSERT(node_ptr->nbData() <= 1, "deleted nodes in GraphView::clone() cannot have multiple data input parents");
        }
        oldToNewNodes[node_ptr] = clonedNode;
    }

    // Lookups below must NOT use operator[]: it would insert nodes lying outside this
    // GraphView with a nullptr value, making them indistinguishable from removed nodes.

    // Clone of `node`, or nullptr if it was removed or is not part of this GraphView
    const auto cloneOf = [&oldToNewNodes](const NodePtr& node) -> NodePtr {
        const auto it = oldToNewNodes.find(node);
        if (it == oldToNewNodes.end()) {
            return nullptr;
        }
        return it->second;
    };
    // True if `node` is part of this GraphView (removed or not)
    const auto inGraph = [&oldToNewNodes](const NodePtr& node) {
        return oldToNewNodes.find(node) != oldToNewNodes.end();
    };
    // True if `node` is part of this GraphView but was removed by cloneNode()
    const auto isRemoved = [&oldToNewNodes](const NodePtr& node) {
        const auto it = oldToNewNodes.find(node);
        return it != oldToNewNodes.end() && it->second == nullptr;
    };

    // For each node, convert old node -> new node connections
    for (const auto& oldToNewNode : oldToNewNodes) {
        const NodePtr& newNode = oldToNewNode.second;
        if (newNode == nullptr) {
            continue; // removed node
        }

        // Connect parent nodes, bypassing parents removed by cloneNode()
        std::size_t parentId = 0;
        for (auto parent : oldToNewNode.first->inputs()) {
            if (parent.first != nullptr) {
                while (isRemoved(parent.first)) {
                    // Find next valid parent in line, going backward in the graph
                    AIDGE_INTERNAL_ASSERT(parent.first->getChildren().size() == 1);
                    AIDGE_INTERNAL_ASSERT(parent.first->nbData() <= 1);
                    const auto& parents = parent.first->dataInputs();
                    if (!parents.empty() && parents[0].first != nullptr // a valid parent exists
                        && inGraph(parents[0].first))                   // parent is in the GraphView
                    {
                        parent = parents[0];
                    }
                    else {
                        break;
                    }
                }

                // Parents outside this GraphView (or with no kept ancestor) are not connected
                const NodePtr newParent = cloneOf(parent.first);
                if (newParent) {
                    AIDGE_INTERNAL_ASSERT(newParent->nbOutputs() == parent.first->nbOutputs());
                    newParent->addChild(newNode, parent.second, parentId);
                }
            }
            ++parentId;
        }
    }

    // Once connected, add each new node to the new GraphView.
    // This has to be done in a second step to ensure that new GraphView inputs/outputs
    // are properly set (otherwise, some node's inputs/outputs may be wrongly registered as
    // GraphView inputs/outputs because not yet connected to other nodes)
    if (const NodePtr newRoot = cloneOf(mRootNode)) {
        // Add root node first if it still exists!
        newGraph->add(newRoot, false);
    }
    for (const auto& oldToNewNode : oldToNewNodes) {
        if (oldToNewNode.second == nullptr) {
            continue; // removed node
        }
        newGraph->add(oldToNewNode.second, false);
    }

    // Update cloned graph inputs/outputs order to match initial graph order
    auto newInputNodes = mInputNodes;
    for (auto it = newInputNodes.begin(); it != newInputNodes.end(); ) {
        // If input node was removed, move forward to the matching input of its child
        while (isRemoved(it->first)) {
            // Removed node should have only one connected output, otherwise cloning is invalid
            AIDGE_INTERNAL_ASSERT(it->first->getChildren().size() <= 1);
            bool found = false;
            if (it->first->getChildren().size() == 1) {
                const auto child = *it->first->getChildren().begin();
                IOIndex_t inputIdx = 0;
                for (const auto& childParent : child->getParents()) {
                    if (childParent == it->first) {
                        it->first = child;
                        it->second = inputIdx;
                        found = true;
                        break;
                    }
                    ++inputIdx;
                }
            }
            if (!found) {
                break;
            }
        }

        // Still removed, or walked outside of this GraphView: drop this input
        const NodePtr newInputNode = cloneOf(it->first);
        if (newInputNode == nullptr) {
            it = newInputNodes.erase(it);
        }
        else {
            it->first = newInputNode;
            ++it;
        }
    }
    newGraph->setOrderedInputs(newInputNodes);

    auto newOutputNodes = mOutputNodes;
    for (auto it = newOutputNodes.begin(); it != newOutputNodes.end(); ) {
        // If output node was removed, move backward to its data parent
        while (isRemoved(it->first)) {
            // Removed node should have only one connected data input, otherwise cloning is invalid
            AIDGE_INTERNAL_ASSERT(it->first->nbData() <= 1);
            const auto parents = it->first->dataInputs();
            if (!parents.empty() && parents[0].first != nullptr // a valid parent exists
                && inGraph(parents[0].first))                   // parent is in the GraphView
            {
                *it = parents[0];
            }
            else {
                break;
            }
        }

        // No kept node could be reached: drop this output
        const NodePtr newOutputNode = cloneOf(it->first);
        if (newOutputNode == nullptr) {
            it = newOutputNodes.erase(it);
        }
        else {
            it->first = newOutputNode;
            ++it;
        }
    }
    newGraph->setOrderedOutputs(newOutputNodes);

    return newGraph;
}