Skip to content
Snippets Groups Projects
Commit d64f9665 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Untested PoC

parent bc289a20
No related branches found
No related tags found
1 merge request!53GraphView inputs/outputs ordering
Pipeline #34725 failed
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
#include <map> #include <map>
#include <memory> #include <memory>
#include <set> #include <unordered_set>
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector> #include <vector>
...@@ -41,11 +41,11 @@ private: ...@@ -41,11 +41,11 @@ private:
/// @brief Set of nodes included in the graphview with names /// @brief Set of nodes included in the graphview with names
std::map<std::string, NodePtr> mNodeRegistry; std::map<std::string, NodePtr> mNodeRegistry;
/// @brief Nodes without input link (computable, cached) /// @brief GraphView inputs
std::set<NodePtr> mInputNodes; std::vector<std::pair<NodePtr, IOIndex_t>> mInputNodes;
/// @brief Nodes without output link (computable, cached) /// @brief GraphView outputs
std::set<NodePtr> mOutputNodes; std::vector<std::pair<NodePtr, IOIndex_t>> mOutputNodes;
public: public:
GraphView(std::string name="") GraphView(std::string name="")
...@@ -54,11 +54,21 @@ public: ...@@ -54,11 +54,21 @@ public:
// ctor // ctor
} }
// GraphView(std::set<NodePtr> nodes, std::string name="") /**
// : mName(name) * Construct a GraphView from a set of nodes. The startNode parameters
// { * allows to define a default inputs/ouputs order relative to this node.
// add(nodes); * For two topologically identical graphs, using the same topological node
// } * as starting node will lead to the same topologically ordered inputs/outputs list.
* Otherwise, inputs/outputs order will be arbitrary.
*/
GraphView(std::set<NodePtr> nodes, NodePtr startNode = nullptr, std::string name="")
: mName(name)
{
if (startNode != nullptr) {
add(startNode, false);
}
add(nodes);
}
bool operator==(const GraphView &gv) const bool operator==(const GraphView &gv) const
{ {
...@@ -110,19 +120,35 @@ public: ...@@ -110,19 +120,35 @@ public:
/////////////////////////////////////////////////////// ///////////////////////////////////////////////////////
public: public:
/** @brief Get reference to the set of input Nodes. */ /** @brief Get reference to the set of input Nodes. */
inline const std::set<NodePtr>& inputNodes() const noexcept { return mInputNodes; } inline const std::set<NodePtr>& inputNodes() const noexcept {
std::set<NodePtr> nodes;
for (auto node : mInputNodes) {
nodes.insert(node.first);
}
return nodes;
}
/** @brief Get reference to the set of output Nodes. */ /** @brief Get reference to the set of output Nodes. */
inline const std::set<NodePtr>& outputNodes() const noexcept { return mOutputNodes; } inline const std::set<NodePtr>& outputNodes() const noexcept {
std::set<NodePtr> nodes;
for (auto node : mOutputNodes) {
nodes.insert(node.first);
}
return nodes;
}
/** @brief Assess if the given Node is an input Node of the GraphView object. */ /** @brief Assess if the given Node is an input Node of the GraphView object. */
inline bool isInputNode(NodePtr nodePtr) const { inline bool isInputNode(NodePtr nodePtr) const {
return (mInputNodes.find(nodePtr) != mInputNodes.end()) ? true : false; const auto nodes = inputNodes();
return (nodes.find(nodePtr) != nodes.end()) ? true : false;
} }
/** @brief Assess if the given Node is an output Node of the GraphView object. */ /** @brief Assess if the given Node is an output Node of the GraphView object. */
inline bool isOutputNode(NodePtr nodePtr) const { inline bool isOutputNode(NodePtr nodePtr) const {
return (mOutputNodes.find(nodePtr) != mOutputNodes.end()) ? true : false; const auto nodes = outputNodes();
return (nodes.find(nodePtr) != nodes.end()) ? true : false;
} }
void setOrderedInputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& inputs);
void setOrderedOutputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& outputs);
/** /**
* @brief List outside data input connections of the GraphView. * @brief List outside data input connections of the GraphView.
* Data inputs exclude inputs expecting parameters (weights or bias). * Data inputs exclude inputs expecting parameters (weights or bias).
...@@ -357,11 +383,10 @@ public: ...@@ -357,11 +383,10 @@ public:
*/ */
static bool replace(const std::set<NodePtr>& oldNodes, const std::set<NodePtr>& newNodes); static bool replace(const std::set<NodePtr>& oldNodes, const std::set<NodePtr>& newNodes);
void updateInputNodes();
/** /**
* @brief Process from zero the set of output Nodes. * @brief Replace a set of Nodes in every available GraphView with a new set of Nodes with ordered inputs/outputs in a GraphView if possible.
*/ */
void updateOutputNodes(); static bool replace(const std::shared_ptr<GraphView>& oldNodes, const std::shared_ptr<GraphView>& newNodes);
/** /**
* @brief Clone the GraphView with shared Operators. It is a new GraphView, with cloned Nodes, but the new Nodes refer to the same Operators as the original ones. * @brief Clone the GraphView with shared Operators. It is a new GraphView, with cloned Nodes, but the new Nodes refer to the same Operators as the original ones.
...@@ -415,27 +440,33 @@ private: ...@@ -415,27 +440,33 @@ private:
IOIndex_t getNbDataInputs() const; IOIndex_t getNbDataInputs() const;
/** /**
* @brief Update the set of inputNodes with a new Node, checking if it can be * @brief Update inputs/outputs of the GraphView, with no particular order.
* added and removing any Node not part of mInputNode anymore. * This function DOES NOT preserve inputs/outputs order and should NOT BE USED.
* It is here only to leave time to adapt the replace() function.
*/
[[deprecated]] void updateInputsOutputsNodes();
/**
* @brief Automatically update GraphView inputs/outputs with a new Node, checking if
* it this Node becomes an input/output for the graph and if previous inputs are still
* inputs/outputs after adding this node.
* @param nodePtr * @param nodePtr
*/ */
void updateInputNodes(NodePtr node); void updateInputsOutputsNew(NodePtr newNode);
/** /**
* @brief Update the set of outputNodes with a new Node, checking if it can be * @brief Automatically update GraphView inputs/outputs with a Node removed, checking if
* added and removing any Node not part of mOutputNode anymore. * it this Node was an input/output for the graph and if this node childs become new inputs/outputs
* for the graph.
* @param nodePtr * @param nodePtr
*/ */
void updateOutputNodes(NodePtr node); void updateInputsOutputsDelete(NodePtr deletedNode);
/////////////////////////////////////////////////////// ///////////////////////////////////////////////////////
// TOPOLOGY // TOPOLOGY
/////////////////////////////////////////////////////// ///////////////////////////////////////////////////////
void _forwardDims(std::set<NodePtr> listNodes); void _forwardDims(std::set<NodePtr> listNodes);
void removeInputNode(const std::string nodeName);
void removeOutputNode(const std::string nodeName);
}; };
} // namespace Aidge } // namespace Aidge
......
...@@ -106,6 +106,34 @@ void Aidge::GraphView::save(std::string path, bool verbose) const { ...@@ -106,6 +106,34 @@ void Aidge::GraphView::save(std::string path, bool verbose) const {
// TENSOR MANAGEMENT // TENSOR MANAGEMENT
/////////////////////////////////////////////////////// ///////////////////////////////////////////////////////
void Aidge::GraphView::setOrderedInputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& inputs) {
AIDGE_ASSERT(inputs.size() > mInputNodes.size(), "too many specified number of inputs");
std::vector<std::pair<NodePtr, IOIndex_t>> ignoredInputs(mInputNodes);
for (auto input : inputs) {
auto it = std::find(ignoredInputs.begin(), ignoredInputs.end(), input);
AIDGE_ASSERT(it != ignoredInputs.end(), "unknown or duplicate input");
ignoredInputs.erase(it);
}
mInputNodes = inputs;
mInputNodes.insert(mInputNodes.end(), ignoredInputs.begin(), ignoredInputs.end());
}
void Aidge::GraphView::setOrderedOutputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& outputs) {
AIDGE_ASSERT(outputs.size() > mOutputNodes.size(), "too many specified number of outputs");
std::vector<std::pair<NodePtr, IOIndex_t>> ignoredOutputs(mOutputNodes);
for (auto output : outputs) {
auto it = std::find(ignoredOutputs.begin(), ignoredOutputs.end(), output);
AIDGE_ASSERT(it != ignoredOutputs.end(), "unknown or duplicate output");
ignoredOutputs.erase(it);
}
mOutputNodes = outputs;
mOutputNodes.insert(mOutputNodes.end(), ignoredOutputs.begin(), ignoredOutputs.end());
}
Aidge::IOIndex_t Aidge::GraphView::getNbDataInputs() const { Aidge::IOIndex_t Aidge::GraphView::getNbDataInputs() const {
IOIndex_t nbDataInput = 0; IOIndex_t nbDataInput = 0;
for (const std::shared_ptr<Node> &inNode : inputNodes()) { for (const std::shared_ptr<Node> &inNode : inputNodes()) {
...@@ -128,7 +156,7 @@ Aidge::IOIndex_t Aidge::GraphView::getNbFreeDataInputs() const { ...@@ -128,7 +156,7 @@ Aidge::IOIndex_t Aidge::GraphView::getNbFreeDataInputs() const {
IOIndex_t nbIn = 0; IOIndex_t nbIn = 0;
// Free inputs within the GraphView are logically also free inputs from outside // Free inputs within the GraphView are logically also free inputs from outside
// the GraphView. // the GraphView.
for (const std::shared_ptr<Node>& inputNode : mInputNodes) { for (const std::shared_ptr<Node>& inputNode : inputNodes()) {
nbIn += inputNode->getNbFreeDataInputs(); nbIn += inputNode->getNbFreeDataInputs();
} }
return nbIn; return nbIn;
...@@ -139,7 +167,7 @@ std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> ...@@ -139,7 +167,7 @@ std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>
Aidge::GraphView::dataInputs() const { Aidge::GraphView::dataInputs() const {
std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res; std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res;
for (const std::shared_ptr<Node>& inputNode : mInputNodes) { for (const std::shared_ptr<Node>& inputNode : inputNodes()) {
const std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs = const std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs =
inputNode->dataInputs(); inputNode->dataInputs();
...@@ -157,7 +185,7 @@ std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> ...@@ -157,7 +185,7 @@ std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>
Aidge::GraphView::inputs() const { Aidge::GraphView::inputs() const {
std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res; std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res;
for (const std::shared_ptr<Node>& inputNode : mInputNodes) { for (const std::shared_ptr<Node>& inputNode : inputNodes()) {
const std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs = const std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs =
inputNode->inputs(); inputNode->inputs();
...@@ -242,7 +270,7 @@ void Aidge::GraphView::setDatatype(const DataType &datatype) { ...@@ -242,7 +270,7 @@ void Aidge::GraphView::setDatatype(const DataType &datatype) {
node->getOperator()->setDatatype(datatype); node->getOperator()->setDatatype(datatype);
} }
} }
/*
void Aidge::GraphView::updateOutputNodes() { void Aidge::GraphView::updateOutputNodes() {
mOutputNodes.clear(); mOutputNodes.clear();
for (const std::shared_ptr<Node>& go_it : mNodes) { for (const std::shared_ptr<Node>& go_it : mNodes) {
...@@ -292,13 +320,13 @@ void Aidge::GraphView::updateOutputNodes(std::shared_ptr<Node> node) { ...@@ -292,13 +320,13 @@ void Aidge::GraphView::updateOutputNodes(std::shared_ptr<Node> node) {
} }
} }
} }
*/
std::vector< std::vector<
std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>> std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>>
Aidge::GraphView::outputs() const { Aidge::GraphView::outputs() const {
std::vector<std::vector<std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t>>> std::vector<std::vector<std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t>>>
outsideOutputs; outsideOutputs;
for (const std::shared_ptr<Node>& outputNode : mOutputNodes) { for (const std::shared_ptr<Node>& outputNode : outputNodes()) {
const std::vector<std::vector<std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t>>> const std::vector<std::vector<std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t>>>
outputNodeOutputs = outputNode->outputs(); outputNodeOutputs = outputNode->outputs();
...@@ -334,6 +362,10 @@ void Aidge::GraphView::add(std::shared_ptr<Node> node, bool includeLearnablePara ...@@ -334,6 +362,10 @@ void Aidge::GraphView::add(std::shared_ptr<Node> node, bool includeLearnablePara
mNodes.insert(node); mNodes.insert(node);
if (!(node->name()).empty()) if (!(node->name()).empty())
mNodeRegistry.insert(std::make_pair(node->name(), node)); mNodeRegistry.insert(std::make_pair(node->name(), node));
// check if the node is an input/output node
updateInputsOutputsNew(node);
// add learnable parameters to the graph // add learnable parameters to the graph
if (includeLearnableParam) { if (includeLearnableParam) {
for (IOIndex_t i = node->nbDataInputs(); i < node->nbInputs(); ++i) { for (IOIndex_t i = node->nbDataInputs(); i < node->nbInputs(); ++i) {
...@@ -343,33 +375,74 @@ void Aidge::GraphView::add(std::shared_ptr<Node> node, bool includeLearnablePara ...@@ -343,33 +375,74 @@ void Aidge::GraphView::add(std::shared_ptr<Node> node, bool includeLearnablePara
mNodes.insert(parentNode); mNodes.insert(parentNode);
if (!(parentNode->name()).empty()) if (!(parentNode->name()).empty())
mNodeRegistry.insert(std::make_pair(parentNode->name(), parentNode)); mNodeRegistry.insert(std::make_pair(parentNode->name(), parentNode));
// check if the Node is an input node // check if the parentNode is an input/output node
updateInputNodes(parentNode); updateInputsOutputsNew(parentNode);
} }
} }
} }
// check if the Node is an input node
updateInputNodes(node);
// check if the Node is an input node
updateOutputNodes(node);
} }
void Aidge::GraphView::add(std::set<std::shared_ptr<Node>> otherNodes, bool includeLearnableParam) { void Aidge::GraphView::add(std::set<std::shared_ptr<Node>> otherNodes, bool includeLearnableParam) {
for (auto& nodePtr : otherNodes) { add(nodePtr, includeLearnableParam); } // List only the nodes that are not already present in current graph
std::set<NodePtr> nodesToAdd;
std::set_difference(otherNodes.begin(), otherNodes.end(), mNodes.begin(), mNodes.end(), std::back_inserter(nodesToAdd));
do {
std::set<NodePtr> nextNodesToAdd;
// Find nodes that are direct parent of current GraphView and add them first
// such that the obtained GraphView inputs list will be the same, regardless
// of the evaluation order of those nodes
// (i.e. one of their child is in current GraphView)
for (const std::shared_ptr<Node> &node_ptr : nodesToAdd) {
for (auto child : node_ptr->getChildren()) {
if (mNodes.find(child) != mNodes.end()) {
nextNodesToAdd.insert(node_ptr);
nodesToAdd.erase(node_ptr);
break;
}
}
}
// If there is no more parent, find nodes that are direct children of current GraphView,
// such that the obtained GraphView outputs list will be the same, regardless
// of the evaluation order of those nodes
// (i.e. one of their parent is in current GraphView)
// TODO: this might be done simultaneously with direct parents, by removing
// the empty() condition, but there might be edge cases that may change
// the resulting inputs/outputs order depending on evaluation order (???)
if (nextNodesToAdd.empty()) {
for (const std::shared_ptr<Node> &node_ptr : nodesToAdd) {
for (auto parent : node_ptr->getParents()) {
if (mNodes.find(parent) != mNodes.end()) {
nextNodesToAdd.insert(node_ptr);
nodesToAdd.erase(node_ptr);
break;
}
}
}
}
// If no node if found, there might be remaining nodes that form an independant sub-graph
// In this case, additionnal inputs/outputs will be added at the end of
// the GraphView inputs/outputs list, in no particular order.
// TODO: we might try to preserve the initial inputs/ouputs relative order of those nodes
// if they actually comes from a GraphView, but I think that would be a far-fetched expectation
// from the users...
if (nextNodesToAdd.empty()) {
nodesToAdd.swap(nextNodesToAdd);
}
// Add selected nodes in the current GraphView, in no particular order
for (auto node_ptr : nextNodesToAdd) {
add(node_ptr, includeLearnableParam);
}
}
while (!nodesToAdd.empty());
} }
void Aidge::GraphView::add(std::shared_ptr<GraphView> graph) { void Aidge::GraphView::add(std::shared_ptr<GraphView> graph) {
for (const std::shared_ptr<Node> &node_ptr : graph->getNodes()) { add(graph->getNodes(), false);
node_ptr->addView(shared_from_this());
mNodes.insert(node_ptr);
if (!(node_ptr->name()).empty())
mNodeRegistry.insert(std::make_pair(node_ptr->name(), node_ptr));
// if node_ptr is part of graph inputNodes or outputNodes
// if (graph->isInputNode(node_ptr) || graph->isOutputNode(node_ptr)) {
// Update OutputNodes/inputNodes
updateInputNodes();
updateOutputNodes();
}
} }
void Aidge::GraphView::addChild(std::shared_ptr<Node> toOtherNode, void Aidge::GraphView::addChild(std::shared_ptr<Node> toOtherNode,
...@@ -417,7 +490,7 @@ void Aidge::GraphView::addChild( ...@@ -417,7 +490,7 @@ void Aidge::GraphView::addChild(
std::set<std::shared_ptr<Aidge::Node>> Aidge::GraphView::getParents() const { std::set<std::shared_ptr<Aidge::Node>> Aidge::GraphView::getParents() const {
// TODO: choose if we return a set or a vector // TODO: choose if we return a set or a vector
std::set<std::shared_ptr<Node>> parents; std::set<std::shared_ptr<Node>> parents;
for (const std::shared_ptr<Node>& inputNode : mInputNodes) { for (const std::shared_ptr<Node>& inputNode : inputNodes()) {
parents.insert(inputNode->getParents().begin(), parents.insert(inputNode->getParents().begin(),
inputNode->getParents().end()); inputNode->getParents().end());
} }
...@@ -436,7 +509,7 @@ std::vector<std::shared_ptr<Aidge::Node>> Aidge::GraphView::getParents(const std ...@@ -436,7 +509,7 @@ std::vector<std::shared_ptr<Aidge::Node>> Aidge::GraphView::getParents(const std
std::vector<std::vector<std::shared_ptr<Aidge::Node>>> std::vector<std::vector<std::shared_ptr<Aidge::Node>>>
Aidge::GraphView::getOrderedParents() const { Aidge::GraphView::getOrderedParents() const {
std::vector<std::vector<std::shared_ptr<Node>>> parents; std::vector<std::vector<std::shared_ptr<Node>>> parents;
for (const std::shared_ptr<Node>& inputNode : mInputNodes) { for (const std::shared_ptr<Node>& inputNode : inputNodes()) {
parents.push_back(inputNode->getParents()); parents.push_back(inputNode->getParents());
} }
return parents; return parents;
...@@ -444,7 +517,7 @@ Aidge::GraphView::getOrderedParents() const { ...@@ -444,7 +517,7 @@ Aidge::GraphView::getOrderedParents() const {
std::set<std::shared_ptr<Aidge::Node>> Aidge::GraphView::getChildren() const { std::set<std::shared_ptr<Aidge::Node>> Aidge::GraphView::getChildren() const {
std::set<std::shared_ptr<Node>> children; std::set<std::shared_ptr<Node>> children;
for (const std::shared_ptr<Node>& outputNode : mOutputNodes) { for (const std::shared_ptr<Node>& outputNode : outputNodes()) {
children.insert((outputNode->getChildren()).begin(), children.insert((outputNode->getChildren()).begin(),
(outputNode->getChildren()).end()); (outputNode->getChildren()).end());
} }
...@@ -488,13 +561,7 @@ Aidge::GraphView::getNode(const std::string& nodeName) const { ...@@ -488,13 +561,7 @@ Aidge::GraphView::getNode(const std::string& nodeName) const {
void Aidge::GraphView::remove(std::shared_ptr<Node> nodePtr, bool includeLearnableParam) { void Aidge::GraphView::remove(std::shared_ptr<Node> nodePtr, bool includeLearnableParam) {
if (mNodes.find(nodePtr) != mNodes.end()) { // remove learnable params
mNodes.erase(nodePtr);
nodePtr->removeView(shared_from_this());
}
if (!nodePtr->name().empty()) { mNodeRegistry.erase(nodePtr->name()); }
// same for learnable params
if (includeLearnableParam) { if (includeLearnableParam) {
for (IOIndex_t i = nodePtr->nbDataInputs(); i < nodePtr->nbInputs(); ++i) { for (IOIndex_t i = nodePtr->nbDataInputs(); i < nodePtr->nbInputs(); ++i) {
auto inputI = nodePtr->input(i); auto inputI = nodePtr->input(i);
...@@ -515,11 +582,21 @@ void Aidge::GraphView::remove(std::shared_ptr<Node> nodePtr, bool includeLearnab ...@@ -515,11 +582,21 @@ void Aidge::GraphView::remove(std::shared_ptr<Node> nodePtr, bool includeLearnab
inputI.first->removeView(shared_from_this()); inputI.first->removeView(shared_from_this());
} }
if (!inputI.first->name().empty()) { mNodeRegistry.erase(inputI.first->name()); } if (!inputI.first->name().empty()) { mNodeRegistry.erase(inputI.first->name()); }
// check if the node was an input/output node
updateInputsOutputsDelete(inputI.first);
} }
} }
} }
updateInputNodes();
updateOutputNodes(); if (mNodes.find(nodePtr) != mNodes.end()) {
mNodes.erase(nodePtr);
nodePtr->removeView(shared_from_this());
}
if (!nodePtr->name().empty()) { mNodeRegistry.erase(nodePtr->name()); }
// check if the nodePtr was an input/output node
updateInputsOutputsDelete(nodePtr);
} }
...@@ -662,8 +739,8 @@ bool Aidge::GraphView::replace(const std::set<Aidge::NodePtr>& oldNodes, const s ...@@ -662,8 +739,8 @@ bool Aidge::GraphView::replace(const std::set<Aidge::NodePtr>& oldNodes, const s
for (const auto& graphPtr : commonGraphViews) { for (const auto& graphPtr : commonGraphViews) {
graphPtr->add(newNodes, false); graphPtr->add(newNodes, false);
if (newNodes.empty()) { if (newNodes.empty()) {
graphPtr->updateInputNodes(); // TODO: FIXME: this function should not be called anymore!
graphPtr->updateOutputNodes(); graphPtr->updateInputsOutputsNodes();
} }
} }
...@@ -676,21 +753,216 @@ bool Aidge::GraphView::replace(const std::set<Aidge::NodePtr>& oldNodes, const s ...@@ -676,21 +753,216 @@ bool Aidge::GraphView::replace(const std::set<Aidge::NodePtr>& oldNodes, const s
return true; return true;
} }
/*
void Aidge::GraphView::updateInputNodes() { void Aidge::GraphView::updateInputNodes() {
mInputNodes.clear(); std::set<std::pair<NodePtr, IOIndex_t>> inputNodes;
for (const std::shared_ptr<Node>& go_ptr : mNodes) { for (const std::shared_ptr<Node>& go_ptr : mNodes) {
size_t inputIdx = 0;
for (const std::shared_ptr<Node>& pa_ptr : go_ptr->getParents()) { for (const std::shared_ptr<Node>& pa_ptr : go_ptr->getParents()) {
if ((pa_ptr == nullptr) || if ((pa_ptr == nullptr) ||
(mNodes.find(pa_ptr) == (mNodes.find(pa_ptr) ==
mNodes.end())) { // Parent doesn't exist || Parent not in the graph mNodes.end())) { // Parent doesn't exist || Parent not in the graph
mInputNodes.insert(go_ptr); inputNodes.insert(std::make_pair(go_ptr, inputIdx));
}
++inputIdx;
}
}
// Remove inputs that are not input anymore (deleted node or input connected internally)
for (auto it = mInputNodes.begin(); it != mInputNodes.end(); ++it) {
if (inputNodes.find(*it) == inputNodes.end()) {
it = mInputNodes.erase(it);
}
}
// Add remaining new inputs
for (auto inputNode : inputNodes) {
mInputNodes.push_back(inputNode);
}
}
*/
void Aidge::GraphView::updateInputsOutputsNew(std::shared_ptr<Node> newNode) {
std::vector<std::pair<NodePtr, IOIndex_t>>::iterator newInputsInsertionPoint = mInputNodes.end();
// Remove inputs that are not input anymore because connected to newNode
for (auto orderedChilds : newNode->getOrderedChildren()) {
for (auto ch_ptr : orderedChilds) {
// Check that newNode child is in current GraphView
if (mNodes.find(ch_ptr) != mNodes.end()) {
std::size_t inputIdx = 0;
for (const std::shared_ptr<Node>& pa_ptr : ch_ptr->getParents()) {
// If newNode is connected to it
if (pa_ptr == newNode) {
const auto val = std::make_pair(ch_ptr, inputIdx);
const auto iter = std::find(mInputNodes.begin(), mInputNodes.end(), val);
// The first old (removed) input becomes the insertion point for newNode GraphView inputs
if (std::distance(newInputsInsertionPoint, iter) <= 0) {
newInputsInsertionPoint = mInputNodes.erase(iter);
}
else {
mInputNodes.erase(iter);
}
}
++inputIdx;
}
}
}
}
// Check if node inputs are inputs for the GraphView and add them to the input list if so
// Inputs addition order follows node inputs order
// Inputs are inserted at the position of the first input removed
std::size_t inputIdx = 0U;
for (const std::shared_ptr<Node>& pa_ptr : newNode->getParents()) {
if ((pa_ptr == nullptr) ||
(mNodes.find(pa_ptr) ==
mNodes.end())) { // Parent doesn't exist || Parent not in the graph
const auto val = std::make_pair(newNode, inputIdx);
// Make sure to not add this input twice, as updateInputsNew() may be
// called several times for the same node.
if (std::find(mInputNodes.begin(), mInputNodes.end(), val) == mInputNodes.end()) {
newInputsInsertionPoint = mInputNodes.insert(newInputsInsertionPoint, val);
}
}
++inputIdx;
}
std::vector<std::pair<NodePtr, IOIndex_t>>::iterator newOutputsInsertionPoint = mOutputNodes.end();
// Remove outputs that are not output anymore because connected to newNode
std::size_t outputIdx = 0;
for (const std::shared_ptr<Node>& parent : newNode->getParents()) {
// Check that newNode parent is in current GraphView
if (mNodes.find(parent) != mNodes.end()) {
for (auto orderedChilds : parent->getOrderedChildren()) {
for (auto ch_ptr : orderedChilds) {
// If newNode is connected to it
if (ch_ptr == newNode) {
const auto val = std::make_pair(parent, outputIdx);
const auto iter = std::find(mOutputNodes.begin(), mOutputNodes.end(), val);
// The first old (removed) output becomes the insertion point for newNode GraphView outputs
if (std::distance(newOutputsInsertionPoint, iter) <= 0) {
newOutputsInsertionPoint = mOutputNodes.erase(iter);
}
else {
mOutputNodes.erase(iter);
}
}
}
}
}
++outputIdx;
}
// Check if node outputs are outputs for the GraphView and add them to the output list if so
outputIdx = 0;
for (auto orderedChilds : newNode->getOrderedChildren()) {
bool noInsideConnection = true;
for (auto ch_ptr : orderedChilds) {
if (mNodes.find(ch_ptr) != mNodes.end()) {
noInsideConnection = false;
break; break;
} }
} }
if (noInsideConnection) {
const auto val = std::make_pair(newNode, outputIdx);
if (std::find(mOutputNodes.begin(), mOutputNodes.end(), val) == mOutputNodes.end()) {
newOutputsInsertionPoint = mOutputNodes.insert(newOutputsInsertionPoint, val);
}
}
++outputIdx;
}
}
void Aidge::GraphView::updateInputsOutputsDelete(std::shared_ptr<Node> deletedNode) {
std::vector<std::pair<NodePtr, IOIndex_t>>::iterator newInputsInsertionPoint = mInputNodes.end();
// Check if node inputs were inputs for the GraphView and remove them from the list if so
std::size_t inputIdx = 0;
for (const std::shared_ptr<Node>& parent : deletedNode->getParents()) {
const auto val = std::make_pair(deletedNode, inputIdx);
const auto iter = std::find(mInputNodes.begin(), mInputNodes.end(), val);
// The first old (removed) input becomes the insertion point for newNode GraphView inputs
if (std::distance(newInputsInsertionPoint, iter) <= 0) {
newInputsInsertionPoint = mInputNodes.erase(iter);
}
else {
mInputNodes.erase(iter);
}
++inputIdx;
}
// Add child node inputs that become GraphView input following the removal of the node
// Inputs addition order follows deletedNode outputs order
for (auto orderedChilds : deletedNode->getOrderedChildren()) {
for (auto ch_ptr : orderedChilds) {
// Check that deletedNode child is in current GraphView
if (mNodes.find(ch_ptr) != mNodes.end()) {
inputIdx = 0;
for (const std::shared_ptr<Node>& pa_ptr : ch_ptr->getParents()) {
// If newNode was connected to it
if (pa_ptr == deletedNode) {
const auto val = std::make_pair(ch_ptr, inputIdx);
// Make sure to not add this input twice, as updateInputsNew() may be
// called several times for the same node.
if (std::find(mInputNodes.begin(), mInputNodes.end(), val) == mInputNodes.end()) {
newInputsInsertionPoint = mInputNodes.insert(newInputsInsertionPoint, val);
}
}
++inputIdx;
}
}
}
}
std::vector<std::pair<NodePtr, IOIndex_t>>::iterator newOutputsInsertionPoint = mOutputNodes.end();
// Check if node outputs were outputs for the GraphView and remove them from the list if so
std::size_t outputIdx = 0;
for (auto orderedChilds : deletedNode->getOrderedChildren()) {
const auto val = std::make_pair(deletedNode, outputIdx);
const auto iter = std::find(mOutputNodes.begin(), mOutputNodes.end(), val);
// The first old (removed) output becomes the insertion point for newNode GraphView outputs
if (std::distance(newOutputsInsertionPoint, iter) <= 0) {
newOutputsInsertionPoint = mOutputNodes.erase(iter);
}
else {
mOutputNodes.erase(iter);
}
++outputIdx;
}
// Add parent node outputs that become GraphView output following the removal of the node
// Outputs addition order follows deletedNode inputs order
for (const std::shared_ptr<Node>& parent : deletedNode->getParents()) {
std::size_t outputIdx = 0;
for (auto orderedChilds : parent->getOrderedChildren()) {
bool noInsideConnection = true;
for (auto ch_ptr : orderedChilds) {
if (mNodes.find(ch_ptr) != mNodes.end()) {
noInsideConnection = false;
break;
}
}
if (noInsideConnection) {
const auto val = std::make_pair(parent, outputIdx);
if (std::find(mOutputNodes.begin(), mOutputNodes.end(), val) == mOutputNodes.end()) {
newOutputsInsertionPoint = mOutputNodes.insert(newOutputsInsertionPoint, val);
}
}
++outputIdx;
}
} }
} }
/*
void Aidge::GraphView::updateInputNodes(std::shared_ptr<Node> node) { void Aidge::GraphView::updateInputNodes(std::shared_ptr<Node> node) {
// add node_ptr to inputNode if it can // add node_ptr to inputNode if it can
std::size_t filledWithKnownInputs = 0U; std::size_t filledWithKnownInputs = 0U;
...@@ -731,8 +1003,8 @@ void Aidge::GraphView::updateInputNodes(std::shared_ptr<Node> node) { ...@@ -731,8 +1003,8 @@ void Aidge::GraphView::updateInputNodes(std::shared_ptr<Node> node) {
} }
} }
} }
*/
/*
void Aidge::GraphView::removeInputNode(const std::string nodeName) { void Aidge::GraphView::removeInputNode(const std::string nodeName) {
std::map<std::string, std::shared_ptr<Node>>::iterator it = std::map<std::string, std::shared_ptr<Node>>::iterator it =
mNodeRegistry.find(nodeName); mNodeRegistry.find(nodeName);
...@@ -754,7 +1026,7 @@ void Aidge::GraphView::removeOutputNode(const std::string nodeName) { ...@@ -754,7 +1026,7 @@ void Aidge::GraphView::removeOutputNode(const std::string nodeName) {
} }
} }
} }
*/
std::shared_ptr<Aidge::GraphView> Aidge::GraphView::cloneCallback(NodePtr(*cloneNode)(NodePtr)) const { std::shared_ptr<Aidge::GraphView> Aidge::GraphView::cloneCallback(NodePtr(*cloneNode)(NodePtr)) const {
std::shared_ptr<GraphView> newGraph = std::make_shared<GraphView>(mName); std::shared_ptr<GraphView> newGraph = std::make_shared<GraphView>(mName);
...@@ -770,16 +1042,14 @@ std::shared_ptr<Aidge::GraphView> Aidge::GraphView::cloneCallback(NodePtr(*clone ...@@ -770,16 +1042,14 @@ std::shared_ptr<Aidge::GraphView> Aidge::GraphView::cloneCallback(NodePtr(*clone
if (oldToNewNode.second == nullptr) if (oldToNewNode.second == nullptr)
continue; // deleted node continue; // deleted node
// Add new node to new GraphView
newGraph->add(oldToNewNode.second, false);
// Connect parent nodes. Nodes that were removed with cloneNode() are set to nullptr // Connect parent nodes. Nodes that were removed with cloneNode() are set to nullptr
size_t parentId = 0; size_t parentId = 0;
for (auto parent : oldToNewNode.first->inputs()) { for (auto parent : oldToNewNode.first->inputs()) {
while (oldToNewNodes[parent.first] == nullptr) { while (oldToNewNodes[parent.first] == nullptr) {
// Find next valid parent in line, going backward in the graph // Find next valid parent in line, going backward in the graph
assert(parent.first->nbDataInputs() <= 1 && "deleted nodes in GraphView::clone() cannot have multiple data inputs"); AIDGE_ASSERT(parent.first->getChildren().size() == 1, "deleted nodes in GraphView::clone() cannot have multiple children");
const auto& parents = parent.first->inputs(); AIDGE_ASSERT(parent.first->nbDataInputs() <= 1, "deleted nodes in GraphView::clone() cannot have multiple data input parents");
const auto& parents = parent.first->dataInputs();
if (!parents.empty() && parents[0].first != nullptr // a valid parent exists if (!parents.empty() && parents[0].first != nullptr // a valid parent exists
&& oldToNewNodes.find(parents[0].first) != oldToNewNodes.end()) // parent is in the GraphView && oldToNewNodes.find(parents[0].first) != oldToNewNodes.end()) // parent is in the GraphView
...@@ -792,6 +1062,8 @@ std::shared_ptr<Aidge::GraphView> Aidge::GraphView::cloneCallback(NodePtr(*clone ...@@ -792,6 +1062,8 @@ std::shared_ptr<Aidge::GraphView> Aidge::GraphView::cloneCallback(NodePtr(*clone
} }
if (oldToNewNodes[parent.first]) { if (oldToNewNodes[parent.first]) {
AIDGE_ASSERT(oldToNewNodes[parent.first]->nbOutputs() == parent.first->nbOutputs(),
"next valid parent after deleted nodes in GraphView::clone() has wrong number of outputs");
oldToNewNodes[parent.first]->addChild(oldToNewNode.second, parent.second, parentId); oldToNewNodes[parent.first]->addChild(oldToNewNode.second, parent.second, parentId);
} }
...@@ -799,9 +1071,64 @@ std::shared_ptr<Aidge::GraphView> Aidge::GraphView::cloneCallback(NodePtr(*clone ...@@ -799,9 +1071,64 @@ std::shared_ptr<Aidge::GraphView> Aidge::GraphView::cloneCallback(NodePtr(*clone
} }
} }
// Update OutputNodes/inputNodes // Once connected, add each new nodes to new GraphView
newGraph->updateInputNodes(); // This has to be done in a second step to ensure that new GraphView inputs/outputs
newGraph->updateOutputNodes(); // are properly set (otherwise, some node's inputs/outputs may be wrongly registered as
// GraphView inputs/outputs because not yet connected to other nodes)
for (auto &oldToNewNode : oldToNewNodes) {
if (oldToNewNode.second == nullptr)
continue; // deleted node
newGraph->add(oldToNewNode.second, false);
}
// Update cloned graph inputs/outputs order to match initial graph order
auto newInputNodes = mInputNodes;
for (auto it = newInputNodes.begin(); it != newInputNodes.end(); ++it) {
// If input node was removed, find next valid input
while (oldToNewNodes[it->first] == nullptr) {
// Removed node should have only one connected output, otherwise cloning is invalid
AIDGE_INTERNAL_ASSERT(it->first->getChildren().size() == 1);
auto child = *it->first->getChildren().begin();
bool found = false;
std::size_t inputIdx = 0;
for (auto parent : child->getParents()) {
if (parent == it->first) {
it->first = child;
it->second = inputIdx;
found = true;
break;
}
++inputIdx;
}
if (!found) {
it = newInputNodes.erase(it);
break;
}
}
}
newGraph->setOrderedInputs(newInputNodes);
auto newOutputNodes = mOutputNodes;
for (auto it = newOutputNodes.begin(); it != newOutputNodes.end(); ++it) {
// If output node was removed, find previous valid output
while (oldToNewNodes[it->first] == nullptr) {
// Removed node should have only one connected data input, otherwise cloning is invalid
AIDGE_INTERNAL_ASSERT(it->first->nbDataInputs() <= 1);
auto parents = it->first->dataInputs();
if (!parents.empty()) {
*it = parents[0];
}
else {
it = newOutputNodes.erase(it);
break;
}
}
}
newGraph->setOrderedOutputs(newOutputNodes);
return newGraph; return newGraph;
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment