From 1f2d196d6037d56bd5748d9f51a4eca9c1398140 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Tue, 28 Nov 2023 23:10:23 +0100 Subject: [PATCH] Working version: node ordering is now well defined --- include/aidge/graph/GraphView.hpp | 16 ++- src/graph/GraphView.cpp | 152 ++++++++++++++++++---------- unit_tests/graph/Test_GraphView.cpp | 34 +++++-- 3 files changed, 137 insertions(+), 65 deletions(-) diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp index 5462935be..df8352bcf 100644 --- a/include/aidge/graph/GraphView.hpp +++ b/include/aidge/graph/GraphView.hpp @@ -35,6 +35,9 @@ private: /// @brief Name of the graphview std::string mName; + /// @brief GraphView root node + NodePtr mRootNode; + /// @brief Set of nodes included in the GraphView std::set<NodePtr> mNodes; @@ -99,6 +102,10 @@ public: return mNodes.find(nodePtr) != mNodes.end(); } + NodePtr getRootNode() { + return mRootNode; + } + /////////////////////////////////////////////////////// // TENSOR MANAGEMENT /////////////////////////////////////////////////////// @@ -263,8 +270,9 @@ public: * @brief Include a set of Nodes to the current GraphView object. * @param otherNodes * @param includeLearnableParam + * @return true if graph ordering is unique (meaning inputs/outputs order is well defined). */ - void add(std::set<NodePtr> otherNodes, + bool add(std::set<NodePtr> otherNodes, bool includeLearnableParam = true); /** @@ -272,16 +280,18 @@ public: * The second element in the otherNodes pair is the start node. * @param otherNodes * @param includeLearnableParam + * @return true if graph ordering is unique (meaning inputs/outputs order is well defined). */ - void add(std::pair<NodePtr, std::set<NodePtr>> otherNodes, + bool add(std::pair<NodePtr, std::set<NodePtr>> otherNodes, bool includeLearnableParam = true); /** * @brief Include every Node inside another GraphView to the current * GraphView. * @param other_graph GraphView containing the Nodes to include. + * @return true if graph ordering is unique (meaning inputs/outputs order is well defined). */ - void add(std::shared_ptr<GraphView> otherGraph); + bool add(std::shared_ptr<GraphView> otherGraph); /** * @brief Include a Node in the current GraphView and link it to another diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index 499cdcf2c..f3095702b 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -75,15 +75,22 @@ void Aidge::GraphView::save(std::string path, bool verbose) const { : node_ptr->name(); namePtrTable[node_ptr] = (currentType + "_" + std::to_string(typeCounter[currentType])); - std::fprintf(fp, "%s(%s)\n", namePtrTable[node_ptr].c_str(), - givenName.c_str()); + + if (node_ptr == mRootNode) { + std::fprintf(fp, "%s(%s):::rootCls\n", namePtrTable[node_ptr].c_str(), + givenName.c_str()); + } + else { + std::fprintf(fp, "%s(%s)\n", namePtrTable[node_ptr].c_str(), + givenName.c_str()); + } } // Write every link for (const std::shared_ptr<Node> &node_ptr : mNodes) { IOIndex_t outputIdx = 0; for (auto childs : node_ptr->getOrderedChildren()) { for (auto child : childs) { - if (child) { + if (child != nullptr && mNodes.find(child) != mNodes.end()) { IOIndex_t inputIdx = 0; for (auto pa_ptr : child->getParents()) { if (pa_ptr == node_ptr) { @@ -116,6 +123,7 @@ void Aidge::GraphView::save(std::string path, bool verbose) const { std::fprintf(fp, "classDef inputCls fill:#afa\n"); std::fprintf(fp, "classDef outputCls fill:#ffa\n"); + std::fprintf(fp, "classDef rootCls stroke:#f00\n"); if (verbose) { for (const auto &c : typeCounter) { @@ -382,6 +390,11 @@ void Aidge::GraphView::setInputId(Aidge::IOIndex_t /*inID*/, } void Aidge::GraphView::add(std::shared_ptr<Node> node, bool includeLearnableParam) { + // first node to be added to the graph is the root node by default + if (mRootNode == nullptr) { + mRootNode = node; + } + // add to the GraphView nodes node->addView(shared_from_this()); mNodes.insert(node); @@ -407,80 +420,117 @@ void Aidge::GraphView::add(std::shared_ptr<Node> node, bool includeLearnablePara } } -void Aidge::GraphView::add(std::set<std::shared_ptr<Node>> otherNodes, bool includeLearnableParam) { +bool Aidge::GraphView::add(std::set<std::shared_ptr<Node>> otherNodes, bool includeLearnableParam) { + if (otherNodes.empty()) { + return true; + } + + bool orderUnicity = true; + // List only the nodes that are not already present in current graph std::set<NodePtr> nodesToAdd; std::set_difference(otherNodes.begin(), otherNodes.end(), mNodes.begin(), mNodes.end(), std::inserter(nodesToAdd, nodesToAdd.begin())); - do { - std::set<NodePtr> nextNodesToAdd; - - // Find nodes that are direct parent of current GraphView and add them first - // such that the obtained GraphView inputs list will be the same, regardless - // of the evaluation order of those nodes - // (i.e. one of their child is in current GraphView) - for (auto it = nodesToAdd.begin(); it != nodesToAdd.end(); ++it) { - for (auto child : (*it)->getChildren()) { - if (mNodes.find(child) != mNodes.end()) { - nextNodesToAdd.insert(*it); - it = nodesToAdd.erase(it); + // List the nodes to rank, initially all the nodes in the GraphView + std::set<NodePtr> nodesToRank(mNodes); + nodesToRank.insert(nodesToAdd.begin(), nodesToAdd.end()); + std::vector<NodePtr> rankedNodesToAdd; + + if (mRootNode == nullptr) { + std::set<NodePtr> noParentNodes; + + // If no root node is defined, check nodes without parents + for (auto node : nodesToRank) { + bool noParent = true; + for (auto parent : node->getParents()) { + if (parent != nullptr && nodesToRank.find(parent) != nodesToRank.end()) { + noParent = false; break; } } - if (it == nodesToAdd.end()) { - break; + + if (noParent) { + noParentNodes.insert(node); } } - // If there is no more parent, find nodes that are direct children of current GraphView, - // such that the obtained GraphView outputs list will be the same, regardless - // of the evaluation order of those nodes - // (i.e. one of their parent is in current GraphView) - // TODO: this might be done simultaneously with direct parents, by removing - // the empty() condition, but there might be edge cases that may change - // the resulting inputs/outputs order depending on evaluation order (???) - if (nextNodesToAdd.empty()) { - for (auto it = nodesToAdd.begin(); it != nodesToAdd.end(); ++it) { - for (auto parent : (*it)->getParents()) { - if (mNodes.find(parent) != mNodes.end()) { - nextNodesToAdd.insert(*it); - it = nodesToAdd.erase(it); - break; + // Take the first one found (this is an arbitrary choice) + mRootNode = *noParentNodes.begin(); + + if (noParentNodes.size() > 1) { + // If there is more than one, order unicity cannot be garanteed! + orderUnicity = false; + } + + rankedNodesToAdd.push_back(mRootNode); + } + + nodesToRank.erase(mRootNode); + std::vector<NodePtr> rankedNodes; + rankedNodes.push_back(mRootNode); + + for (size_t curNodeIdx = 0; curNodeIdx < rankedNodes.size(); ++curNodeIdx) { + NodePtr curNode = rankedNodes[curNodeIdx]; + + for (auto childs : curNode->getOrderedChildren()) { + for (auto child : childs) { + if (nodesToRank.find(child) != nodesToRank.end()) { + rankedNodes.push_back(child); + nodesToRank.erase(child); + + if (nodesToAdd.find(child) != nodesToAdd.end()) { + rankedNodesToAdd.push_back(child); + nodesToAdd.erase(child); } } - if (it == nodesToAdd.end()) { - break; - } } } - // If no node if found, there might be remaining nodes that form an independant sub-graph - // In this case, additionnal inputs/outputs will be added at the end of - // the GraphView inputs/outputs list, in no particular order. - // TODO: we might try to preserve the initial inputs/ouputs relative order of those nodes - // if they actually comes from a GraphView, but I think that would be a far-fetched expectation - // from the users... - if (nextNodesToAdd.empty()) { - nodesToAdd.swap(nextNodesToAdd); + for (auto parent : curNode->getParents()) { + if (nodesToRank.find(parent) != nodesToRank.end()) { + rankedNodes.push_back(parent); + nodesToRank.erase(parent); + + if (nodesToAdd.find(parent) != nodesToAdd.end()) { + rankedNodesToAdd.push_back(parent); + nodesToAdd.erase(parent); + } + } } + } - // Add selected nodes in the current GraphView, in no particular order - for (auto node_ptr : nextNodesToAdd) { - add(node_ptr, includeLearnableParam); + if (!nodesToAdd.empty()) { + // There are remaining nodes without path to the root node + orderUnicity = false; + + while (!nodesToAdd.empty()) { + const auto it = nodesToAdd.begin(); + rankedNodesToAdd.push_back(*it); + nodesToAdd.erase(it); } } - while (!nodesToAdd.empty()); + + for (auto node_ptr : rankedNodesToAdd) { + add(node_ptr, includeLearnableParam); + } + + return orderUnicity; } -void Aidge::GraphView::add(std::pair<NodePtr, std::set<NodePtr>> nodes, bool includeLearnableParam) { +bool Aidge::GraphView::add(std::pair<NodePtr, std::set<NodePtr>> nodes, bool includeLearnableParam) { if (nodes.first != nullptr) { + mRootNode = nodes.first; add(nodes.first, includeLearnableParam); } - add(nodes.second, includeLearnableParam); + return add(nodes.second, includeLearnableParam); } -void Aidge::GraphView::add(std::shared_ptr<GraphView> graph) { - add(graph->getNodes(), false); +bool Aidge::GraphView::add(std::shared_ptr<GraphView> graph) { + if (mRootNode == nullptr) { + mRootNode = graph->getRootNode(); + } + + return add(graph->getNodes(), false); } void Aidge::GraphView::addChild(std::shared_ptr<Node> toOtherNode, diff --git a/unit_tests/graph/Test_GraphView.cpp b/unit_tests/graph/Test_GraphView.cpp index a80c7a7aa..8da67e784 100644 --- a/unit_tests/graph/Test_GraphView.cpp +++ b/unit_tests/graph/Test_GraphView.cpp @@ -88,20 +88,32 @@ std::vector<std::pair<std::string, IOIndex_t>> nodePtrToName(const std::vector<s TEST_CASE("genRandomDAG") { - std::random_device rd; - const std::mt19937::result_type seed(rd()); + const size_t nbTests = 100; + size_t nbUnicity = 0; - auto g1 = std::make_shared<GraphView>(); - g1->add(genRandomDAG(seed, 10, 0.5)); - auto g2 = std::make_shared<GraphView>(); - g2->add(genRandomDAG(seed, 10, 0.5)); + for (int test = 0; test < nbTests; ++test) { + std::random_device rd; + const std::mt19937::result_type seed(rd()); - g1->save("./genRandomDAG1"); - g2->save("./genRandomDAG2"); + const auto g1 = std::make_shared<GraphView>("g1"); + const bool unicity1 = g1->add(genRandomDAG(seed, 10, 0.5)); + const auto g2 = std::make_shared<GraphView>("g2"); + const bool unicity2 = g2->add(genRandomDAG(seed, 10, 0.5)); - REQUIRE(nodePtrToName(g1->getNodes()) == nodePtrToName(g2->getNodes())); - REQUIRE(nodePtrToName(g1->getOrderedInputs()) == nodePtrToName(g2->getOrderedInputs())); - REQUIRE(nodePtrToName(g1->getOrderedOutputs()) == nodePtrToName(g2->getOrderedOutputs())); + g1->save("./genRandomDAG1"); + g2->save("./genRandomDAG2"); + + REQUIRE(unicity1 == unicity2); + + if (unicity1) { + REQUIRE(nodePtrToName(g1->getNodes()) == nodePtrToName(g2->getNodes())); + REQUIRE(nodePtrToName(g1->getOrderedInputs()) == nodePtrToName(g2->getOrderedInputs())); + REQUIRE(nodePtrToName(g1->getOrderedOutputs()) == nodePtrToName(g2->getOrderedOutputs())); + ++nbUnicity; + } + } + + printf("nbUnicity = %zu/%zu\n", nbUnicity, nbTests); } TEST_CASE("[core/graph] GraphView(Constructor)") { -- GitLab