diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp index 4ed9cd106615be6d664e06dbf933e1d4342f95c4..eaf23d88efa0dd367047d5fd3881a0f52240f3ee 100644 --- a/include/aidge/graph/GraphView.hpp +++ b/include/aidge/graph/GraphView.hpp @@ -120,7 +120,7 @@ public: /////////////////////////////////////////////////////// public: /** @brief Get reference to the set of input Nodes. */ - inline const std::set<NodePtr>& inputNodes() const noexcept { + inline std::set<NodePtr> inputNodes() const noexcept { std::set<NodePtr> nodes; for (auto node : mInputNodes) { nodes.insert(node.first); @@ -128,7 +128,7 @@ public: return nodes; } /** @brief Get reference to the set of output Nodes. */ - inline const std::set<NodePtr>& outputNodes() const noexcept { + inline std::set<NodePtr> outputNodes() const noexcept { std::set<NodePtr> nodes; for (auto node : mOutputNodes) { nodes.insert(node.first); diff --git a/include/aidge/operator/MetaOperator.hpp b/include/aidge/operator/MetaOperator.hpp index e74ae55e676c87ee129b8b490bf0cdd9d12859b3..87a90c74174cc8c4a3df4aa0e45e6b565aa80c59 100644 --- a/include/aidge/operator/MetaOperator.hpp +++ b/include/aidge/operator/MetaOperator.hpp @@ -47,7 +47,7 @@ public: mGraph(op.mGraph->clone()) { // cpy-ctor - // TODO: FIXME: mInputOps and mOutputOps are not populated! + // TODO: FIXME: mInputNodes and mOutputNodes are not populated! // Issue: how to map new (cloned) nodes with old nodes? getNodes() does // not garantee any order! Check issue #52. } @@ -71,8 +71,8 @@ public: void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { assert(strcmp(data->type(), Tensor::Type) == 0 && "input data must be of Tensor type"); - const auto& inputOp = mInputOps[inputIdx]; - inputOp.first->associateInput(inputOp.second, data); + const auto& inputOp = mInputNodes[inputIdx]; + inputOp.first->getOperator()->associateInput(inputOp.second, data); // Associate inputs for custom implementation mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data); @@ -83,9 +83,9 @@ public: mGraph->forwardDims(); // Associate outputs to micro-graph outputs for custom implementation - for (size_t outputIdx = 0; outputIdx < mOutputOps.size(); ++outputIdx) { - const auto& outputOp = mOutputOps[outputIdx]; - mOutputs[outputIdx] = outputOp.first->getOutput(outputOp.second); + for (size_t outputIdx = 0; outputIdx < mOutputNodes.size(); ++outputIdx) { + const auto& outputOp = mOutputNodes[outputIdx]; + mOutputs[outputIdx] = outputOp.first->getOperator()->getOutput(outputOp.second); } } diff --git a/src/graph/Connector.cpp b/src/graph/Connector.cpp index cd2ceff8b58076a5054269e4676120b94c8b5beb..cf3b18dc5f17a4213a32066a3d244d5bed6f77e5 100644 --- a/src/graph/Connector.cpp +++ b/src/graph/Connector.cpp @@ -26,7 +26,15 @@ Aidge::Connector::Connector(std::shared_ptr<Aidge::Node> node) { Aidge::IOIndex_t Aidge::Connector::size() const { return mNode->nbOutputs(); } -std::shared_ptr<Aidge::GraphView> Aidge::generateGraph(std::vector<Connector> ctors) { +std::shared_ptr<Aidge::GraphView> Aidge::generateGraph(std::vector<Connector> ctors) { + std::set<NodePtr> nodesToAdd; + for (const Connector& ctor : ctors) { + nodesToAdd.insert(ctor.node()); + } + return std::make_shared<GraphView>(nodesToAdd, ctors.back().node()); + + // TODO: FIXME: don't understand the following code! +/* std::shared_ptr<GraphView> graph = std::make_shared<GraphView>(); std::vector<std::shared_ptr<Node>> nodesToAdd = std::vector<std::shared_ptr<Node>>(); for (const Connector& ctor : ctors) { @@ -51,4 +59,5 @@ std::shared_ptr<Aidge::GraphView> Aidge::generateGraph(std::vector<Connector> ct buffer = {}; } return graph; +*/ } \ No newline at end of file diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index 3854e5b05aded73c92f5acc99d1e30ef6306fd23..3c6a39921a3eb326a1c276ea4733f2e99b88f2e4 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -385,7 +385,7 @@ void Aidge::GraphView::add(std::shared_ptr<Node> node, bool includeLearnablePara void Aidge::GraphView::add(std::set<std::shared_ptr<Node>> otherNodes, bool includeLearnableParam) { // List only the nodes that are not already present in current graph std::set<NodePtr> nodesToAdd; - std::set_difference(otherNodes.begin(), otherNodes.end(), mNodes.begin(), mNodes.end(), std::back_inserter(nodesToAdd)); + std::set_difference(otherNodes.begin(), otherNodes.end(), mNodes.begin(), mNodes.end(), std::inserter(nodesToAdd, nodesToAdd.begin())); do { std::set<NodePtr> nextNodesToAdd; @@ -394,14 +394,17 @@ void Aidge::GraphView::add(std::set<std::shared_ptr<Node>> otherNodes, bool incl // such that the obtained GraphView inputs list will be the same, regardless // of the evaluation order of those nodes // (i.e. one of their child is in current GraphView) - for (const std::shared_ptr<Node> &node_ptr : nodesToAdd) { - for (auto child : node_ptr->getChildren()) { + for (auto it = nodesToAdd.begin(); it != nodesToAdd.end(); ++it) { + for (auto child : (*it)->getChildren()) { if (mNodes.find(child) != mNodes.end()) { - nextNodesToAdd.insert(node_ptr); - nodesToAdd.erase(node_ptr); + nextNodesToAdd.insert(*it); + it = nodesToAdd.erase(it); break; } } + if (it == nodesToAdd.end()) { + break; + } } // If there is no more parent, find nodes that are direct children of current GraphView, @@ -412,14 +415,17 @@ void Aidge::GraphView::add(std::set<std::shared_ptr<Node>> otherNodes, bool incl // the empty() condition, but there might be edge cases that may change // the resulting inputs/outputs order depending on evaluation order (???) if (nextNodesToAdd.empty()) { - for (const std::shared_ptr<Node> &node_ptr : nodesToAdd) { - for (auto parent : node_ptr->getParents()) { + for (auto it = nodesToAdd.begin(); it != nodesToAdd.end(); ++it) { + for (auto parent : (*it)->getParents()) { if (mNodes.find(parent) != mNodes.end()) { - nextNodesToAdd.insert(node_ptr); - nodesToAdd.erase(node_ptr); + nextNodesToAdd.insert(*it); + it = nodesToAdd.erase(it); break; } } + if (it == nodesToAdd.end()) { + break; + } } } @@ -783,6 +789,9 @@ void Aidge::GraphView::updateInputNodes() { */ void Aidge::GraphView::updateInputsOutputsNew(std::shared_ptr<Node> newNode) { + // Can be called several times with the same node, e.g. when addChild() is + // called on a node already part of the GraphView. In this case, inputs/outputs + // need to be updated! std::vector<std::pair<NodePtr, IOIndex_t>>::iterator newInputsInsertionPoint = mInputNodes.end(); // Remove inputs that are not input anymore because connected to newNode @@ -790,19 +799,22 @@ void Aidge::GraphView::updateInputsOutputsNew(std::shared_ptr<Node> newNode) { for (auto ch_ptr : orderedChilds) { // Check that newNode child is in current GraphView if (mNodes.find(ch_ptr) != mNodes.end()) { - std::size_t inputIdx = 0; + IOIndex_t inputIdx = 0; for (const std::shared_ptr<Node>& pa_ptr : ch_ptr->getParents()) { // If newNode is connected to it if (pa_ptr == newNode) { const auto val = std::make_pair(ch_ptr, inputIdx); const auto iter = std::find(mInputNodes.begin(), mInputNodes.end(), val); - // The first old (removed) input becomes the insertion point for newNode GraphView inputs - if (std::distance(newInputsInsertionPoint, iter) <= 0) { - newInputsInsertionPoint = mInputNodes.erase(iter); - } - else { - mInputNodes.erase(iter); + // Check that it was not already the case (if node UPDATE) + if (iter != mInputNodes.end()) { + // The first old (removed) input becomes the insertion point for newNode GraphView inputs + if (std::distance(newInputsInsertionPoint, iter) <= 0) { + newInputsInsertionPoint = mInputNodes.erase(iter); + } + else { + mInputNodes.erase(iter); + } } } ++inputIdx; @@ -814,16 +826,30 @@ void Aidge::GraphView::updateInputsOutputsNew(std::shared_ptr<Node> newNode) { // Check if node inputs are inputs for the GraphView and add them to the input list if so // Inputs addition order follows node inputs order // Inputs are inserted at the position of the first input removed - std::size_t inputIdx = 0U; + IOIndex_t inputIdx = 0U; for (const std::shared_ptr<Node>& pa_ptr : newNode->getParents()) { if ((pa_ptr == nullptr) || (mNodes.find(pa_ptr) == mNodes.end())) { // Parent doesn't exist || Parent not in the graph const auto val = std::make_pair(newNode, inputIdx); - // Make sure to not add this input twice, as updateInputsNew() may be - // called several times for the same node. - if (std::find(mInputNodes.begin(), mInputNodes.end(), val) == mInputNodes.end()) { - newInputsInsertionPoint = mInputNodes.insert(newInputsInsertionPoint, val); + AIDGE_INTERNAL_ASSERT(std::find(mInputNodes.begin(), mInputNodes.end(), val) == mInputNodes.end()); + newInputsInsertionPoint = mInputNodes.insert(newInputsInsertionPoint, val); + } + ++inputIdx; + } + + // (if node UPDATE) + // newNode may already exists in the graph and may have been updated + // Check and remove inputs that are not inputs anymore + inputIdx = 0U; + for (const std::shared_ptr<Node>& pa_ptr : newNode->getParents()) { + if ((pa_ptr != nullptr) && + (mNodes.find(pa_ptr) != + mNodes.end())) { + const auto val = std::make_pair(newNode, inputIdx); + auto it = std::find(mInputNodes.begin(), mInputNodes.end(), val); + if (it != mInputNodes.end()) { + mInputNodes.erase(it); } } ++inputIdx; @@ -832,33 +858,35 @@ void Aidge::GraphView::updateInputsOutputsNew(std::shared_ptr<Node> newNode) { std::vector<std::pair<NodePtr, IOIndex_t>>::iterator newOutputsInsertionPoint = mOutputNodes.end(); // Remove outputs that are not output anymore because connected to newNode - std::size_t outputIdx = 0; for (const std::shared_ptr<Node>& parent : newNode->getParents()) { // Check that newNode parent is in current GraphView if (mNodes.find(parent) != mNodes.end()) { for (auto orderedChilds : parent->getOrderedChildren()) { + IOIndex_t outputIdx = 0; for (auto ch_ptr : orderedChilds) { // If newNode is connected to it if (ch_ptr == newNode) { const auto val = std::make_pair(parent, outputIdx); const auto iter = std::find(mOutputNodes.begin(), mOutputNodes.end(), val); - // The first old (removed) output becomes the insertion point for newNode GraphView outputs - if (std::distance(newOutputsInsertionPoint, iter) <= 0) { - newOutputsInsertionPoint = mOutputNodes.erase(iter); - } - else { - mOutputNodes.erase(iter); + if (iter != mOutputNodes.end()) { + // The first old (removed) output becomes the insertion point for newNode GraphView outputs + if (std::distance(newOutputsInsertionPoint, iter) <= 0) { + newOutputsInsertionPoint = mOutputNodes.erase(iter); + } + else { + mOutputNodes.erase(iter); + } } } } + ++outputIdx; } } - ++outputIdx; } // Check if node outputs are outputs for the GraphView and add them to the output list if so - outputIdx = 0; + IOIndex_t outputIdx = 0; for (auto orderedChilds : newNode->getOrderedChildren()) { bool noInsideConnection = true; for (auto ch_ptr : orderedChilds) { @@ -870,6 +898,7 @@ void Aidge::GraphView::updateInputsOutputsNew(std::shared_ptr<Node> newNode) { if (noInsideConnection) { const auto val = std::make_pair(newNode, outputIdx); + // Output may be already be present (see addChild() with a node already in GraphView) if (std::find(mOutputNodes.begin(), mOutputNodes.end(), val) == mOutputNodes.end()) { newOutputsInsertionPoint = mOutputNodes.insert(newOutputsInsertionPoint, val); } @@ -882,19 +911,19 @@ void Aidge::GraphView::updateInputsOutputsDelete(std::shared_ptr<Node> deletedNo std::vector<std::pair<NodePtr, IOIndex_t>>::iterator newInputsInsertionPoint = mInputNodes.end(); // Check if node inputs were inputs for the GraphView and remove them from the list if so - std::size_t inputIdx = 0; - for (const std::shared_ptr<Node>& parent : deletedNode->getParents()) { + for (IOIndex_t inputIdx = 0; inputIdx < deletedNode->getParents().size(); ++inputIdx) { const auto val = std::make_pair(deletedNode, inputIdx); const auto iter = std::find(mInputNodes.begin(), mInputNodes.end(), val); - // The first old (removed) input becomes the insertion point for newNode GraphView inputs - if (std::distance(newInputsInsertionPoint, iter) <= 0) { - newInputsInsertionPoint = mInputNodes.erase(iter); - } - else { - mInputNodes.erase(iter); + if (iter != mInputNodes.end()) { + // The first old (removed) input becomes the insertion point for newNode GraphView inputs + if (std::distance(newInputsInsertionPoint, iter) <= 0) { + newInputsInsertionPoint = mInputNodes.erase(iter); + } + else { + mInputNodes.erase(iter); + } } - ++inputIdx; } // Add child node inputs that become GraphView input following the removal of the node @@ -903,16 +932,13 @@ void Aidge::GraphView::updateInputsOutputsDelete(std::shared_ptr<Node> deletedNo for (auto ch_ptr : orderedChilds) { // Check that deletedNode child is in current GraphView if (mNodes.find(ch_ptr) != mNodes.end()) { - inputIdx = 0; + IOIndex_t inputIdx = 0; for (const std::shared_ptr<Node>& pa_ptr : ch_ptr->getParents()) { // If newNode was connected to it if (pa_ptr == deletedNode) { const auto val = std::make_pair(ch_ptr, inputIdx); - // Make sure to not add this input twice, as updateInputsNew() may be - // called several times for the same node. - if (std::find(mInputNodes.begin(), mInputNodes.end(), val) == mInputNodes.end()) { - newInputsInsertionPoint = mInputNodes.insert(newInputsInsertionPoint, val); - } + AIDGE_INTERNAL_ASSERT(std::find(mInputNodes.begin(), mInputNodes.end(), val) == mInputNodes.end()); + newInputsInsertionPoint = mInputNodes.insert(newInputsInsertionPoint, val); } ++inputIdx; } @@ -923,25 +949,29 @@ void Aidge::GraphView::updateInputsOutputsDelete(std::shared_ptr<Node> deletedNo std::vector<std::pair<NodePtr, IOIndex_t>>::iterator newOutputsInsertionPoint = mOutputNodes.end(); // Check if node outputs were outputs for the GraphView and remove them from the list if so - std::size_t outputIdx = 0; - for (auto orderedChilds : deletedNode->getOrderedChildren()) { + for (IOIndex_t outputIdx = 0; outputIdx < deletedNode->getOrderedChildren().size(); ++outputIdx) { const auto val = std::make_pair(deletedNode, outputIdx); const auto iter = std::find(mOutputNodes.begin(), mOutputNodes.end(), val); - // The first old (removed) output becomes the insertion point for newNode GraphView outputs - if (std::distance(newOutputsInsertionPoint, iter) <= 0) { - newOutputsInsertionPoint = mOutputNodes.erase(iter); - } - else { - mOutputNodes.erase(iter); + if (iter != mOutputNodes.end()) { + // The first old (removed) output becomes the insertion point for newNode GraphView outputs + if (std::distance(newOutputsInsertionPoint, iter) <= 0) { + newOutputsInsertionPoint = mOutputNodes.erase(iter); + } + else { + mOutputNodes.erase(iter); + } } - ++outputIdx; } // Add parent node outputs that become GraphView output following the removal of the node // Outputs addition order follows deletedNode inputs order for (const std::shared_ptr<Node>& parent : deletedNode->getParents()) { - std::size_t outputIdx = 0; + if (parent == nullptr) { + continue; + } + + IOIndex_t outputIdx = 0; for (auto orderedChilds : parent->getOrderedChildren()) { bool noInsideConnection = true; for (auto ch_ptr : orderedChilds) { @@ -953,10 +983,42 @@ void Aidge::GraphView::updateInputsOutputsDelete(std::shared_ptr<Node> deletedNo if (noInsideConnection) { const auto val = std::make_pair(parent, outputIdx); - if (std::find(mOutputNodes.begin(), mOutputNodes.end(), val) == mOutputNodes.end()) { - newOutputsInsertionPoint = mOutputNodes.insert(newOutputsInsertionPoint, val); + AIDGE_INTERNAL_ASSERT(std::find(mOutputNodes.begin(), mOutputNodes.end(), val) == mOutputNodes.end()); + newOutputsInsertionPoint = mOutputNodes.insert(newOutputsInsertionPoint, val); + } + ++outputIdx; + } + } +} + +void Aidge::GraphView::updateInputsOutputsNodes() { + mInputNodes.clear(); + for (const std::shared_ptr<Node>& go_ptr : mNodes) { + IOIndex_t inputIdx = 0; + for (const std::shared_ptr<Node>& pa_ptr : go_ptr->getParents()) { + if ((pa_ptr == nullptr) || + (mNodes.find(pa_ptr) == + mNodes.end())) { // Parent doesn't exist || Parent not in the graph + mInputNodes.push_back(std::make_pair(go_ptr, inputIdx)); + } + + ++inputIdx; + } + } + + mOutputNodes.clear(); + for (const std::shared_ptr<Node>& go_ptr : mNodes) { + IOIndex_t outputIdx = 0; + for (auto orderedChilds : go_ptr->getOrderedChildren()) { + for (auto ch_ptr : orderedChilds) { + if (mNodes.find(ch_ptr) == mNodes.end()) { // Child not in the graph + mOutputNodes.push_back(std::make_pair(go_ptr, outputIdx)); } } + if (orderedChilds.empty()) { + // an output linked to nothing + mOutputNodes.push_back(std::make_pair(go_ptr, outputIdx)); + } ++outputIdx; } } diff --git a/unit_tests/graph/Test_GraphView.cpp b/unit_tests/graph/Test_GraphView.cpp index dbba1a7d698641d0858f6c3d2f15c4c7ff610261..7da7cf164bd1dfdcd12a7788d7904b33d07187f2 100644 --- a/unit_tests/graph/Test_GraphView.cpp +++ b/unit_tests/graph/Test_GraphView.cpp @@ -14,6 +14,7 @@ #include <memory> #include <set> #include <string> +#include <random> #include <catch2/catch_test_macros.hpp> @@ -26,6 +27,39 @@ using namespace Aidge; +std::set<NodePtr> genRandomDAG(size_t nbNodes, float density = 0.5, size_t maxIn = 5, float avgIn = 1.5, size_t maxOut = 2, float avgOut = 1.1) { + std::random_device rd; + std::mt19937 gen(rd()); + std::binomial_distribution<> dIn(maxIn, avgIn/maxIn); + std::binomial_distribution<> dOut(maxOut, avgOut/maxOut); + std::binomial_distribution<> dLink(1, density); + + std::vector<NodePtr> nodes; + for (size_t i = 0; i < nbNodes; ++i) { + nodes.push_back(GenericOperator("Fictive", dIn(gen), dIn(gen), dOut(gen))); + } + + for (size_t i = 0; i < nbNodes; ++i) { + for (size_t j = i + 1; j < nbNodes; ++j) { + for (size_t outId = 0; outId < nodes[i]->nbOutputs(); ++outId) { + for (size_t inId = 0; inId < nodes[j]->nbInputs(); ++inId) { + if (dLink(gen)) { + nodes[i]->addChild(nodes[j], outId, inId); + } + } + } + } + } + return std::set<NodePtr>(nodes.begin(), nodes.end()); +} + + +TEST_CASE("genRandomDAG") { + auto g = std::make_shared<GraphView>(genRandomDAG(10)); + REQUIRE(g->getNodes().size() == 10); + g->save("./genRandomDAG"); +} + TEST_CASE("[core/graph] GraphView(Constructor)") { std::shared_ptr<GraphView> g0 = std::make_shared<GraphView>(); std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("G1");