Newer
Older
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <algorithm>
#include <cassert>
#include <cstdio>
#include <cstring>
#include <iterator>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>

#include "aidge/utils/Types.h"
#include "aidge/graph/GraphView.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/utils/ErrorHandling.hpp"
///////////////////////////////////////////////////////
// FUNCTIONAL DESCRIPTION
///////////////////////////////////////////////////////
Aidge::Connector Aidge::GraphView::operator()(
const std::vector<Aidge::Connector> ctors) {
// TODO: allow for multiple inputNodes?
assert((inputNodes().size() == 1U) && "Too many input Nodes for the GraphView, undefined behaviour");
std::shared_ptr<Node> inNode = *inputNodes().begin();
assert((ctors.size() == static_cast<std::size_t>(inNode->nbDataInputs())) && "Wrong number of arguments.\n");
for (std::pair<std::shared_ptr<Node>, IOIndex_t> &input : inNode->inputs()) {
assert((gk_IODefaultIndex == input.second) && "At least one input connection is not free.\n");
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
assert((ctor.node() != nullptr) &&
"Input Connector must be associated with a node");
ctor.node()->addChild(shared_from_this(), static_cast<std::size_t>(ctor.index()),
{inNode, inID++});
}
return Connector(*(outputNodes().begin()));
}
///////////////////////////////////////////////////////
// INNER
///////////////////////////////////////////////////////
std::string Aidge::GraphView::name() const { return mName; }
void Aidge::GraphView::setName(const std::string &name) { mName = name; }
/// Writes this GraphView as a Mermaid flowchart into the file "<path>.mmd".
///
/// Each node becomes a vertex identified by "<type>_<k>". Here k counts nodes
/// of the same type in iteration order, starting at 1. The vertex is labelled
/// with the node name, or "<type><k>" when the node is unnamed.
///
/// Each parent link becomes an edge. A parent that is null or outside the view
/// is drawn as a numbered "in - <n>" input vertex.
///
/// @param path    Output path without extension (".mmd" is appended).
/// @param verbose When true, prints "<type> - <count>" per node type on stdout.
/// @throws std::runtime_error if the output file cannot be opened.
void Aidge::GraphView::save(std::string path, bool verbose) const {
    // Closes the file on every exit path, including exceptions thrown while writing.
    struct FileCloser {
        void operator()(std::FILE *file) const { std::fclose(file); }
    };
    const std::string fileName = path + ".mmd";
    const std::unique_ptr<std::FILE, FileCloser> fp(std::fopen(fileName.c_str(), "w"));
    if (!fp) {
        // Writing through a null FILE* is undefined behavior: fail loudly instead.
        throw std::runtime_error("GraphView::save(): cannot open file '" + fileName +
                                 "' for writing");
    }
    // "%%%%" is printf-escaped and prints "%%" around the init directive.
    std::fprintf(fp.get(),
                 "%%%%{init: {'flowchart': { 'curve': 'monotoneY'}, "
                 "'fontFamily': 'Verdana' } }%%%%\nflowchart TB\n\n");

    std::map<std::string, std::size_t> typeCounter;
    std::map<std::shared_ptr<Node>, std::string> namePtrTable;

    // Start by creating every node
    for (const std::shared_ptr<Node> &node_ptr : mNodes) {
        const std::string currentType = node_ptr->type();
        // operator[] value-initializes a new counter to 0 before the increment.
        const std::size_t typeIndex = ++typeCounter[currentType];
        const std::string givenName =
            node_ptr->name().empty() ? currentType + std::to_string(typeIndex)
                                     : node_ptr->name();
        namePtrTable[node_ptr] = currentType + "_" + std::to_string(typeIndex);
        std::fprintf(fp.get(), "%s(%s)\n", namePtrTable[node_ptr].c_str(),
                     givenName.c_str());
    }

    // Write every link
    std::size_t emptyInputCounter = 0;
    for (const std::shared_ptr<Node> &node_ptr : mNodes) {
        for (const std::shared_ptr<Node> &pa_ptr : node_ptr->getParents()) {
            if ((pa_ptr == nullptr) || !inView(pa_ptr)) {
                // Missing or external parent: draw a dedicated input vertex.
                std::fprintf(fp.get(), "input%zu((in - %zu))-->%s\n", emptyInputCounter,
                             emptyInputCounter, namePtrTable[node_ptr].c_str());
                ++emptyInputCounter;
            } else {
                std::fprintf(fp.get(), "%s-->%s\n", namePtrTable[pa_ptr].c_str(),
                             namePtrTable[node_ptr].c_str());
            }
        }
    }

    if (verbose) {
        for (const auto &c : typeCounter) {
            std::printf("%s - %zu\n", c.first.c_str(), c.second);
        }
    }

    std::fprintf(fp.get(), "\n");
    // The file is closed by FileCloser when fp goes out of scope.
}
///////////////////////////////////////////////////////
// TENSOR MANAGEMENT
///////////////////////////////////////////////////////
Aidge::IOIndex_t Aidge::GraphView::getNbDataInputs() const {
IOIndex_t nbDataInput = 0;
for (const std::shared_ptr<Node> &inNode : inputNodes()) {
Olivier BICHLER
committed
// We cannot simply add inNode->nbDataInputs(), as input nodes may already
// have some inputs connected within the GraphView, which would therefore not
// constitue inputs (from outside) for the GraphView!
const std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs =
inNode->dataInputs();
for (const auto& input : inputNodeinputs) {
if (input.first == nullptr || mNodes.find(input.first) == mNodes.end()) {
++nbDataInput;
}
}
/// Returns the number of free (unconnected) data inputs of this GraphView.
///
/// Free inputs within the GraphView are logically also free inputs from
/// outside the GraphView. The result is therefore the sum of
/// getNbFreeDataInputs() over the input nodes.
Aidge::IOIndex_t Aidge::GraphView::getNbFreeDataInputs() const {
    IOIndex_t nbIn = 0;
    for (const std::shared_ptr<Node>& inputNode : mInputNodes) {
        nbIn += inputNode->getNbFreeDataInputs();
    }
    return nbIn;
}
std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>
Aidge::GraphView::dataInputs() const {
std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res;
for (const std::shared_ptr<Node>& inputNode : mInputNodes) {
Olivier BICHLER
committed
const std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs =
for (const auto& input : inputNodeinputs) {
Olivier BICHLER
committed
if (input.first == nullptr || mNodes.find(input.first) == mNodes.end()) {
res.push_back(input);
}
}
}
return res;
}
std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>
Aidge::GraphView::inputs() const {
std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res;
for (const std::shared_ptr<Node>& inputNode : mInputNodes) {
Olivier BICHLER
committed
const std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs =
for (const auto& input : inputNodeinputs) {
Olivier BICHLER
committed
if (input.first == nullptr || mNodes.find(input.first) == mNodes.end()) {
res.push_back(input);
}
}
}
return res;
}
std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>
Aidge::GraphView::inputs(std::string name) const {
    // Inputs of the node registered under `name`.
    // std::map::at throws std::out_of_range when the name is unknown.
    const auto& registeredNode = mNodeRegistry.at(name);
    return registeredNode->inputs();
}
/// Links every node input to its parent's output tensor, then propagates
/// output dimensions through the graph from the input nodes.
///
/// For each input that has a parent, if the node's raw input is not already
/// the parent's raw output, the parent's output is associated to it.
/// Only Tensor-to-Tensor links are supported; anything else trips an assert.
/// Inputs without a parent are asserted to hold a non-empty Tensor.
void Aidge::GraphView::forwardDims() {
    // setInputs
    // Link every tensor to the right pointer
    // following parent - children informations
    for (const std::shared_ptr<Node>& nodePtr : getNodes()) {
        for (IOIndex_t i = 0; i < nodePtr->nbInputs(); ++i) {
            // assess if the input was not already set and is a Tensor then link it to parent output
            std::pair<std::shared_ptr<Node>, IOIndex_t> inputI = nodePtr->input(i);
            if (inputI.first) {
                auto currentInput = nodePtr->getOperator()->getRawInput(i);
                auto parentOutput = inputI.first->getOperator()->getRawOutput(inputI.second);
                if (std::static_pointer_cast<Tensor>(currentInput) != parentOutput) {
                    // assert provided Data is of "Tensor" type on both ends of the link
                    if ((std::strcmp(currentInput->type(), Tensor::Type) == 0) &&
                        (std::strcmp(parentOutput->type(), Tensor::Type) == 0)) {
                        nodePtr->getOperator()->associateInput(i, parentOutput);
                    } else {
                        assert(false && "Non-tensor entries not handled yet.\n");
                    }
                }
            } else {
                // No parent: the input must already hold a non-empty Tensor.
                assert(!std::static_pointer_cast<Tensor>(nodePtr->getOperator()->getRawInput(i))->empty());
            }
        }
    }
    // Compute dimensions of every node
    _forwardDims(inputNodes());
}
/// Propagates output dimensions starting from the nodes in `listNodes`.
///
/// Each listed node computes its output dims if not already forwarded.
/// Nodes whose dims are forwarded pass their children to the next round.
/// Nodes whose dims are still not forwarded are retried in the next round.
/// When a round yields nothing, every node of the view still lacking forwarded
/// dims is retried. Recursion stops when no node remains to process.
///
/// NOTE(review): a node that can never forward its dims is re-queued on every
/// call, so the recursion does not terminate in that case.
void Aidge::GraphView::_forwardDims(std::set<std::shared_ptr<Node>> listNodes) {
    // TODO: support multi-inputs/outputs
    std::set<std::shared_ptr<Node>> nextList;
    for (const std::shared_ptr<Node>& nodePtr : listNodes) {
        const auto op = nodePtr->getOperator();
        if (!op->outputDimsForwarded()) {
            op->computeOutputDims();
        }
        if (op->outputDimsForwarded()) {
            // Dims known: continue the propagation with the children.
            const auto children = nodePtr->getChildren();
            nextList.insert(children.begin(), children.end());
        } else {
            // Dims still unknown: retry this node in the next round.
            nextList.insert(nodePtr);
        }
    }

    // Propagation exhausted: fall back to every node of the view still lacking dims.
    if (nextList.empty()) {
        for (const std::shared_ptr<Node>& nodePtr : getNodes()) {
            if (!nodePtr->getOperator()->outputDimsForwarded()) {
                nextList.insert(nodePtr);
            }
        }
    }

    if (!nextList.empty()) {
        _forwardDims(std::move(nextList));
    }
}
/// Sets `backend` on the operator of every node of the view.
void Aidge::GraphView::setBackend(const std::string &backend) {
    for (const auto& graphNode : getNodes()) {
        graphNode->getOperator()->setBackend(backend);
    }
}
/// Sets `datatype` on the operator of every node of the view.
void Aidge::GraphView::setDatatype(const DataType &datatype) {
    for (const auto& graphNode : getNodes()) {
        graphNode->getOperator()->setDatatype(datatype);
    }
}
/// Recomputes the output node set from scratch.
///
/// A node of the view is an output node when at least one of its outputs is
/// not connected (nbOutputs() != nbValidOutputs()), or when one of its
/// children lies outside the view.
void Aidge::GraphView::updateOutputNodes() {
    mOutputNodes.clear();
    for (const std::shared_ptr<Node>& go_it : mNodes) {
        if (go_it->nbOutputs() != go_it->nbValidOutputs()) { // an output linked to nothing
            mOutputNodes.insert(go_it);
            continue;
        }
        for (const std::shared_ptr<Node>& ch_ptr : go_it->getChildren()) {
            if (mNodes.find(ch_ptr) == mNodes.end()) { // Child not in the graph
                mOutputNodes.insert(go_it);
                break;
            }
        }
    }
}
void Aidge::GraphView::updateOutputNodes(std::shared_ptr<Node> node) {
if (node->nbOutputs() !=
node->nbValidOutputs()) { // an output linked to nothing
mOutputNodes.insert(node);
} else { // don't enter if was already added to outputNodes
for (const std::shared_ptr<Node> &ch_ptr : node->getChildren()) {
if (mNodes.find(ch_ptr) == mNodes.end()) { // Child not in the graph
mOutputNodes.insert(node);
break;
}
}
}
// update other outputNodes
for (const std::shared_ptr<Node> &pa_ptr :
node->getParents()) { // check if any parent is in OutputNodes too
if ((pa_ptr != nullptr) &&
(mOutputNodes.find(pa_ptr) !=
mOutputNodes.end())) { // it's a match! Must check if the outputNode
// found is still an outputNode
bool remove = (pa_ptr->nbOutputs() == pa_ptr->nbValidOutputs());
for (const std::shared_ptr<Node>& ch_ptr : pa_ptr->getChildren()) {
if (mNodes.find(ch_ptr) == mNodes.end()) { // Child not in the graph
remove = false;
break;
}
}
if (remove) {
mOutputNodes.erase(pa_ptr);
}
}
}
}
std::vector<
std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>>
Aidge::GraphView::outputs() const {
std::vector<std::vector<std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t>>>
Olivier BICHLER
committed
outsideOutputs;
for (const std::shared_ptr<Node>& outputNode : mOutputNodes) {
Olivier BICHLER
committed
const std::vector<std::vector<std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t>>>
outputNodeOutputs = outputNode->outputs();
for (const auto& outputPos : outputNodeOutputs) {
// Keep only the nodes connected at this output position that are outside the GraphView
std::vector<std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t>> outsideOutputPos;
for (const auto& output : outputPos) {
if (mNodes.find(output.first) == mNodes.end()) {
outsideOutputPos.push_back(output);
}
}
outsideOutputs.push_back(outsideOutputPos);
}
Olivier BICHLER
committed
return outsideOutputs;
}
std::vector<
    std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>>
Aidge::GraphView::outputs(std::string nodeName) const {
    // Outputs of the node registered under `nodeName`.
    // std::map::at throws std::out_of_range when the name is unknown.
    const auto& registeredNode = mNodeRegistry.at(nodeName);
    return registeredNode->outputs();
}
/// Stub: only prints a "not implemented" notice; the arguments are ignored.
void Aidge::GraphView::setInputId(Aidge::IOIndex_t /*inID*/,
                                  Aidge::IOIndex_t /*newNodeOutID*/)
{
    printf("Not implemented yet.\n");
}
/// Adds `node` to this GraphView.
///
/// The view registers itself on the node and inserts it in the node set. A
/// named node is also recorded in the name registry; std::map::insert keeps any
/// node already registered under the same name.
///
/// When `includeLearnableParam` is true, the non-null parents connected to the
/// node's non-data inputs (indices nbDataInputs() .. nbInputs()-1) are added
/// in the same way. The input and output node sets are then updated
/// incrementally.
void Aidge::GraphView::add(std::shared_ptr<Node> node, bool includeLearnableParam) {
    // add to the GraphView nodes
    node->addView(shared_from_this());
    mNodes.insert(node);
    if (!(node->name()).empty())
        mNodeRegistry.insert(std::make_pair(node->name(), node));

    // add learnable parameters to the graph
    if (includeLearnableParam) {
        for (IOIndex_t i = node->nbDataInputs(); i < node->nbInputs(); ++i) {
            std::shared_ptr<Node> parentNode = node->getParent(i);
            if (parentNode) {
                parentNode->addView(shared_from_this());
                mNodes.insert(parentNode);
                if (!(parentNode->name()).empty())
                    mNodeRegistry.insert(std::make_pair(parentNode->name(), parentNode));
                // check if the parameter Node is an input node
                updateInputNodes(parentNode);
            }
        }
    }

    // check if the Node is an input node
    updateInputNodes(node);
    // check if the Node is an output node
    updateOutputNodes(node);
}
void Aidge::GraphView::add(std::set<std::shared_ptr<Node>> otherNodes, bool includeLearnableParam) {
for (auto& nodePtr : otherNodes) { add(nodePtr, includeLearnableParam); }
}
/// Adds every node of `graph` to this GraphView.
///
/// The view registers itself on each node, inserts it in the node set and
/// records named nodes in the name registry. The input and output node sets
/// are then recomputed once.
void Aidge::GraphView::add(std::shared_ptr<GraphView> graph) {
    bool addedAny = false;
    for (const std::shared_ptr<Node> &node_ptr : graph->getNodes()) {
        node_ptr->addView(shared_from_this());
        mNodes.insert(node_ptr);
        if (!(node_ptr->name()).empty())
            mNodeRegistry.insert(std::make_pair(node_ptr->name(), node_ptr));
        addedAny = true;
    }
    // Full recomputation depends only on the final node set. Doing it once
    // after the loop avoids an O(n^2) rebuild per added node. It is skipped when
    // nothing was added, as before, so manual input/output edits are preserved.
    if (addedAny) {
        updateInputNodes();
        updateOutputNodes();
    }
}
/// Connects output `fromTensor` of `fromOutNode` to input `toTensor` of
/// `toOtherNode`, then adds `toOtherNode` to this GraphView.
/// When `fromOutNode` is null, the view must have exactly one output node,
/// which is then used as the source.
void Aidge::GraphView::addChild(std::shared_ptr<Node> toOtherNode,
                                std::shared_ptr<Node> fromOutNode,
                                const Aidge::IOIndex_t fromTensor,
                                Aidge::IOIndex_t toTensor) {
    if (!fromOutNode) {
        // No explicit source: use the view's unique output node.
        assert((outputNodes().size() == 1U) &&
               "Must specify an outputNode or have only one.");
        fromOutNode = *(outputNodes().begin());
    } else {
        assert(inView(fromOutNode) && "Output Node not found in the GraphView.");
    }
    fromOutNode->addChild(toOtherNode, fromTensor, toTensor);
    add(toOtherNode);
}
/// Links `fromOutNode` (a node of this view) to `toNode` (a node of
/// `toOtherView`), then merges `toOtherView` into this view.
///
/// A null node in either pair selects the unique output node of this view,
/// or the unique input node of `toOtherView`, respectively.
void Aidge::GraphView::addChild(
    std::shared_ptr<GraphView> toOtherView,
    std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t> fromOutNode,
    std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t> toNode) {
    // Resolve the source node.
    if (fromOutNode.first) {
        assert(inView(fromOutNode.first));
    } else {
        assert(outputNodes().size() == 1U &&
               "If no output node is provided, the graph should have only one to "
               "make the choice explicit.");
        fromOutNode.first = *(outputNodes().begin());
    }

    // Resolve the destination node.
    if (toNode.first) {
        assert(toOtherView->inView(toNode.first));
    } else {
        assert(toOtherView->inputNodes().size() == 1U &&
               "If no intput node is provided, the other graph should have only "
               "one to make the choice explicit.");
        toNode.first = *(toOtherView->inputNodes().begin());
    }

    // Tensor assertions are performed in the Node addChild method.
    fromOutNode.first->addChild(toNode.first, fromOutNode.second, toNode.second);
    // Once linked, merge the other graph into the current one.
    add(toOtherView);
}
/// Returns the union of the parents of every input node of this view.
/// Null parent entries, if any, are included as-is.
std::set<std::shared_ptr<Aidge::Node>> Aidge::GraphView::getParents() const {
    // TODO: choose if we return a set or a vector
    std::set<std::shared_ptr<Node>> parents;
    for (const std::shared_ptr<Node>& inputNode : mInputNodes) {
        // Bind the list once. Node::getParents() returns a container (by value
        // elsewhere in this file), and calling it twice would pair begin()/end()
        // iterators from two different temporaries, which is undefined behavior.
        const auto& inputParents = inputNode->getParents();
        parents.insert(inputParents.begin(), inputParents.end());
    }
    return parents;
}
/// Returns the parents of the node registered under `nodeName`.
/// Prints an error and terminates the program when no such node exists.
std::vector<std::shared_ptr<Aidge::Node>> Aidge::GraphView::getParents(const std::string nodeName) const {
    const auto found = mNodeRegistry.find(nodeName);
    if (found != mNodeRegistry.end()) {
        return found->second->getParents();
    }
    printf("No such node a %s in %s graph.\n", nodeName.c_str(), name().c_str());
    exit(-1);
}
/// Returns, for each input node (in input-node set order), its ordered list
/// of parents.
std::vector<std::vector<std::shared_ptr<Aidge::Node>>>
Aidge::GraphView::getOrderedParents() const {
    std::vector<std::vector<std::shared_ptr<Node>>> orderedParents;
    std::transform(mInputNodes.begin(), mInputNodes.end(),
                   std::back_inserter(orderedParents),
                   [](const std::shared_ptr<Node>& inputNode) {
                       return inputNode->getParents();
                   });
    return orderedParents;
}
/// Returns the union of the children of every output node of this view.
std::set<std::shared_ptr<Aidge::Node>> Aidge::GraphView::getChildren() const {
    std::set<std::shared_ptr<Node>> children;
    for (const std::shared_ptr<Node>& outputNode : mOutputNodes) {
        // Bind the set once. Node::getChildren() returns a set by value
        // elsewhere in this file, and calling it twice would pair begin()/end()
        // iterators from two different temporaries, which is undefined behavior.
        const auto& outputChildren = outputNode->getChildren();
        children.insert(outputChildren.begin(), outputChildren.end());
    }
    return children;
}
/// Returns the ordered children (per output) of the node registered under
/// `nodeName`. Prints an error and terminates the program when no such node exists.
std::vector<std::vector<std::shared_ptr<Aidge::Node>>>
Aidge::GraphView::getChildren(const std::string nodeName) const {
    const auto found = mNodeRegistry.find(nodeName);
    if (found != mNodeRegistry.end()) {
        return found->second->getOrderedChildren();
    }
    printf("No such node a %s in %s graph.\n", nodeName.c_str(),
           name().c_str());
    exit(-1);
}
/// Returns the children of `otherNode`, which must belong to this view.
/// Prints an error and terminates the program otherwise.
std::set<std::shared_ptr<Aidge::Node>>
Aidge::GraphView::getChildren(const std::shared_ptr<Node> otherNode) const {
    if (mNodes.count(otherNode) == 0) {
        printf("No such node in graph.\n");
        exit(-1);
    }
    return otherNode->getChildren();
}
std::shared_ptr<Aidge::Node>
Aidge::GraphView::getNode(const std::string& nodeName) const {
std::map<std::string, std::shared_ptr<Node>>::const_iterator it =
if (it != mNodeRegistry.end()) {
return it->second;
} else {
printf("No Node named %s in the current GraphView.\n", nodeName.c_str());
exit(-1);
}
}
/// Removes `nodePtr` from this GraphView.
///
/// The node leaves the node set (and the view unregisters itself from it).
/// Its name, if any, is erased from the name registry.
///
/// When `includeLearnableParam` is true, the parent connected to each non-data
/// input (indices nbDataInputs() .. nbInputs()-1) is removed the same way,
/// unless that parent also feeds another node. Input and output node sets are
/// then recomputed.
void Aidge::GraphView::remove(std::shared_ptr<Node> nodePtr, bool includeLearnableParam) {
    if (mNodes.find(nodePtr) != mNodes.end()) {
        mNodes.erase(nodePtr);
        nodePtr->removeView(shared_from_this());
    }
    if (!nodePtr->name().empty()) { mNodeRegistry.erase(nodePtr->name()); }

    // same for learnable params
    if (includeLearnableParam) {
        for (IOIndex_t i = nodePtr->nbDataInputs(); i < nodePtr->nbInputs(); ++i) {
            auto inputI = nodePtr->input(i);
            // An unconnected parameter input has no parent to remove.
            // Dereferencing the null parent below would be undefined behavior.
            if (!inputI.first) {
                continue;
            }
            // Only remove the learnable parameter if it feeds no Node other than nodePtr.
            // NOTE(review): any other child counts, even one outside this GraphView — confirm intent.
            bool removeNode = true;
            for (const auto& parentOutput : inputI.first->outputs()) {
                for (const auto& childOfParentOutput : parentOutput) {
                    if (childOfParentOutput.first != nodePtr) {
                        removeNode = false;
                        break;
                    }
                }
                if (!removeNode) {
                    break;
                }
            }
            if (removeNode) {
                // assert Learnable Parameter in the GraphView scope
                if (mNodes.find(inputI.first) != mNodes.end()) {
                    mNodes.erase(inputI.first);
                    inputI.first->removeView(shared_from_this());
                }
                if (!inputI.first->name().empty()) { mNodeRegistry.erase(inputI.first->name()); }
            }
        }
    }
    updateInputNodes();
    updateOutputNodes();
}
/// Stub: swapping nodes is not supported. Prints a notice and reports failure.
bool Aidge::GraphView::swap(Node & /*node*/, Node & /*otherNode*/)
{
    printf("Swap() not implementated yet. Return false.\n");
    const bool swapped = false;
    return swapped;
}
/// Stub: only prints a "not implemented" notice; the arguments are ignored.
void Aidge::GraphView::link(std::string /*name1_inID*/,
                            std::string /*name2_outID*/)
{
    printf("Not implemented yet.\n");
}
void Aidge::GraphView::insertParent(NodePtr childNode,
NodePtr newParentNode,
IOIndex_t childInputTensorIdx,
IOIndex_t newParentInputTensorIdx,
IOIndex_t newParentOutputTensorIdx){
NodePtr currentParentNode = childNode->getParent(childInputTensorIdx);
const IOIndex_t currentParentOutputTensorIdx = childNode->input(childInputTensorIdx).second;
// Remove child from current parent & current Parent from child
currentParentNode->removeChild(childNode, currentParentOutputTensorIdx);
currentParentNode->addChild(newParentNode,currentParentOutputTensorIdx, newParentInputTensorIdx);
newParentNode->addChild(childNode, newParentOutputTensorIdx, childInputTensorIdx);
add(newParentNode);
bool Aidge::GraphView::replace(const std::set<Aidge::NodePtr>& oldNodes, const std::set<Aidge::NodePtr>& newNodes) {
// TODO: handle case where an oldNodes parameter does not come from a Producer but another Node (not included in oldNodes)
// How to distinguish it from data input?
// TODO: Parameter Tensors could be identified with their dimensions
// TODO: Take GraphView as input parameters since new Nodes should be connected whatever.
// It also avoids specifying each producer since they are automatically included
auto oldG = std::make_shared<GraphView>("oldG");
oldG->add(oldNodes, false);
auto newG = std::make_shared<GraphView>("newG");
newG->add(newNodes, false);
if ((oldG->inputNodes().size() == 0) || (oldG->outputNodes().size() != 1)) {
return false;
}
if (!(newNodes.empty()) && ((newG->inputNodes().size() == 0) ||
(newG->outputNodes().size() != 1))) {
return false;
}
// there is at least one inputNode in the old/new GraphView
std::shared_ptr<Node> firstPreviousInputNode = (*(oldG->inputNodes()).begin());
std::shared_ptr<Node> firstPreviousOutputNode = (*(oldG->outputNodes()).begin());
// find Node to link to new input Node
//compute number of input for firstPreviousInputNode not in oldNodes set
std::size_t nbExternalInputs = 0;
std::shared_ptr<Node> externalInput = nullptr;
IOIndex_t externalInputId = gk_IODefaultIndex;
for (const auto& input : firstPreviousInputNode->inputs()) {
if (oldNodes.find(input.first) == oldNodes.end()) { // Node connected to another Node outside of oldG
nbExternalInputs++;
externalInput = input.first;
externalInputId = input.second;
}
}
if (nbExternalInputs > 1) {
AIDGE_INTERNAL_ASSERT("To many input to link for oldNodes set");
}
if (oldG->inputNodes().size() > 1){
// one or no input has been identified. Checking every input points to the same source
for (const auto& previousInputNode : oldG->inputNodes()) {
for (const auto& input : previousInputNode->inputs()) {
if (oldNodes.find(input.first) == oldNodes.end()) {
if ( (externalInput != input.first) || (externalInputId != input.second) ) {
return false; // an inputNode points to an external Node different from the registered one
}
}
}
}
}
if (firstPreviousOutputNode->nbOutputs() != 1) {
return false;
}
// find Node to replicate output connections
std::shared_ptr<Node> newOutputNode = newNodes.empty() ? externalInput : *(newG->outputNodes().begin());
auto copyOutputs = firstPreviousOutputNode->outputs();
// manage Views for newNodes
// only keep common views to each node for the new set
std::set<std::shared_ptr<GraphView>> commonGraphViews = (*oldNodes.begin())->views();
for (const auto& nodePtr : oldNodes) {
const auto nodeView = nodePtr->views();
std::set<std::shared_ptr<GraphView>> intersection;
std::set_intersection(commonGraphViews.begin(), commonGraphViews.end(),
nodeView.begin(), nodeView.end(),
std::inserter(intersection, intersection.begin()));
commonGraphViews = intersection;
}
commonGraphViews.erase(oldG);
commonGraphViews.erase(newG);
// clean Nodes to replace
// Do not include common Nodes to avoid cleaning Producers linked to newNodes
std::set<std::shared_ptr<Node>> nodesToClean;
std::set_difference(oldNodes.begin(), oldNodes.end(),
newNodes.begin(), newNodes.end(),
std::inserter(nodesToClean, nodesToClean.begin()));
for (auto& nodePtr : nodesToClean) { nodePtr->resetConnections(true); }
// copy output connections
if (newOutputNode) {
for (IOIndex_t o = 0; o < firstPreviousOutputNode->nbOutputs(); ++o) {
auto outputPairs = copyOutputs[o];
for (const auto& onePair : outputPairs) {
newOutputNode->addChild(onePair.first, o, onePair.second);
}
// copy input connections
if (!newNodes.empty() && externalInput) {
for (const auto& newInputNode : newG->inputNodes()) {
IOIndex_t inputId = 0;
for (const auto& input : newInputNode->inputs()) {
if (newNodes.find(input.first) == newNodes.end()) {
externalInput->addChild(newInputNode, externalInputId, inputId);
}
inputId++;
}
}
}
// insert new Nodes in the right GraphViews
for (const auto& graphPtr : commonGraphViews) {
graphPtr->add(newNodes, false);
if (newNodes.empty()) {
graphPtr->updateInputNodes();
graphPtr->updateOutputNodes();
}
}
for (const auto& node : oldNodes) {
node->removeView(oldG);
}
for (const auto& node : newNodes) {
node->removeView(newG);
}
return true;
}
void Aidge::GraphView::updateInputNodes() {
mInputNodes.clear();
for (const std::shared_ptr<Node>& go_ptr : mNodes) {
for (const std::shared_ptr<Node>& pa_ptr : go_ptr->getParents()) {
if ((pa_ptr == nullptr) ||
(mNodes.find(pa_ptr) ==
mNodes.end())) { // Parent doesn't exist || Parent not in the graph
mInputNodes.insert(go_ptr);
break;
}
}
}
}
/// Incrementally updates the input node set after `node` joined the view.
///
/// If one of `node`'s parents is null or outside the view, `node` is inserted
/// as an input node. If instead every parent is in the view (counted against
/// nbInputs(), which presumes getParents() has one entry per input — TODO
/// confirm) and `node` was an input node, it is removed.
///
/// Children of `node` that are input nodes are then dropped from the set once
/// none of their parents is null or outside the view.
void Aidge::GraphView::updateInputNodes(std::shared_ptr<Node> node) {
    // add node_ptr to inputNode if it can
    std::size_t filledWithKnownInputs = 0U;
    bool wasAdded = mInputNodes.find(node) != mInputNodes.end();
    for (const std::shared_ptr<Node>& pa_ptr : node->getParents()) {
        if ((pa_ptr == nullptr) ||
            (mNodes.find(pa_ptr) == mNodes.end())) { // Parent doesn't exist || Parent not in the graph
            mInputNodes.insert(node);
            wasAdded = true;
            break;
        }
        ++filledWithKnownInputs;
    }
    if (filledWithKnownInputs == node->nbInputs() && wasAdded) {
        mInputNodes.erase(node);
    }

    // update other inputNodes
    for (const std::shared_ptr<Node>& ch_ptr :
         node->getChildren()) { // check if any child is in InputNodes too
        if (mInputNodes.find(ch_ptr) != mInputNodes.end()) {
            // it's a match! Must check if the inputNode found is still an inputNode
            bool remove = true;
            for (const std::shared_ptr<Node>& pa_ptr : ch_ptr->getParents()) {
                if (pa_ptr == nullptr ||
                    mNodes.find(pa_ptr) == mNodes.end()) { // Parent doesn't exist || Parent not in the graph
                    remove = false;
                    break;
                }
            }
            if (remove) {
                mInputNodes.erase(ch_ptr);
            }
        }
    }
}
/// Removes the node registered under `nodeName` from the input node set.
/// Does nothing if the name is unknown or the node is not an input node.
void Aidge::GraphView::removeInputNode(const std::string nodeName) {
    const auto found = mNodeRegistry.find(nodeName);
    if (found == mNodeRegistry.end()) {
        return;
    }
    // Erasing by key is a no-op when the node is not in the set.
    mInputNodes.erase(found->second);
}
/// Removes the node registered under `nodeName` from the output node set.
/// Does nothing if the name is unknown or the node is not an output node.
void Aidge::GraphView::removeOutputNode(const std::string nodeName) {
    const auto found = mNodeRegistry.find(nodeName);
    if (found == mNodeRegistry.end()) {
        return;
    }
    // Erasing by key is a no-op when the node is not in the set.
    mOutputNodes.erase(found->second);
}
std::shared_ptr<Aidge::GraphView> Aidge::GraphView::cloneCallback(NodePtr(*cloneNode)(NodePtr)) const {
std::shared_ptr<GraphView> newGraph = std::make_shared<GraphView>(mName);
// Map for old node -> new node correspondance
std::map<NodePtr, NodePtr> oldToNewNodes;
for (const std::shared_ptr<Node> &node_ptr : mNodes) {
oldToNewNodes[node_ptr] = cloneNode(node_ptr);
}
// For each node, convert old node -> new node connections
for (auto &oldToNewNode : oldToNewNodes) {
if (oldToNewNode.second == nullptr)
continue; // deleted node
// Add new node to new GraphView
newGraph->add(oldToNewNode.second, false);
// Connect parent nodes. Nodes that were removed with cloneNode() are set to nullptr
size_t parentId = 0;
for (auto parent : oldToNewNode.first->inputs()) {
while (oldToNewNodes[parent.first] == nullptr) {
// Find next valid parent in line, going backward in the graph
assert(parent.first->nbDataInputs() <= 1 && "deleted nodes in GraphView::clone() cannot have multiple data inputs");
const auto& parents = parent.first->inputs();
if (!parents.empty() && parents[0].first != nullptr // a valid parent exists
&& oldToNewNodes.find(parents[0].first) != oldToNewNodes.end()) // parent is in the GraphView
{
parent = parents[0];
if (oldToNewNodes[parent.first]) {
oldToNewNodes[parent.first]->addChild(oldToNewNode.second, parent.second, parentId);
}
// Update OutputNodes/inputNodes
newGraph->updateInputNodes();
newGraph->updateOutputNodes();
return newGraph;
}