From b92b7f38a856f37991198b266754a094e708e34f Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Mon, 27 Nov 2023 17:15:10 +0100 Subject: [PATCH] Improved graph visualization --- include/aidge/graph/GraphView.hpp | 13 ++++++ src/graph/GraphView.cpp | 60 ++++++++++++++++++++------ unit_tests/graph/Test_GraphView.cpp | 67 +++++++++++++++++++++++------ 3 files changed, 116 insertions(+), 24 deletions(-) diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp index 62f6ac11c..5462935be 100644 --- a/include/aidge/graph/GraphView.hpp +++ b/include/aidge/graph/GraphView.hpp @@ -133,6 +133,9 @@ public: void setOrderedInputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& inputs); void setOrderedOutputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& outputs); + inline const std::vector<std::pair<NodePtr, IOIndex_t>>& getOrderedInputs() { return mInputNodes; }; + inline const std::vector<std::pair<NodePtr, IOIndex_t>>& getOrderedOutputs() { return mOutputNodes; }; + /** * @brief List outside data input connections of the GraphView. * Data inputs exclude inputs expecting parameters (weights or bias). @@ -255,6 +258,7 @@ public: * in the GraphView automatically. Default: true. */ void add(NodePtr otherNode, bool includeLearnableParam = true); + /** * @brief Include a set of Nodes to the current GraphView object. * @param otherNodes @@ -263,6 +267,15 @@ public: void add(std::set<NodePtr> otherNodes, bool includeLearnableParam = true); + /** + * @brief Include a set of Nodes to the current GraphView object. + * The second element in the otherNodes pair is the start node. + * @param otherNodes + * @param includeLearnableParam + */ + void add(std::pair<NodePtr, std::set<NodePtr>> otherNodes, + bool includeLearnableParam = true); + /** * @brief Include every Node inside another GraphView to the current * GraphView. diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index 2714484eb..7b21cc889 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -79,23 +79,48 @@ void Aidge::GraphView::save(std::string path, bool verbose) const { givenName.c_str()); } // Write every link - std::size_t emptyInputCounter = 0; for (const std::shared_ptr<Node> &node_ptr : mNodes) { - for (const std::shared_ptr<Node> &pa_ptr : node_ptr->getParents()) { - if ((pa_ptr == nullptr) || !inView(pa_ptr)) { - std::fprintf(fp, "input%zu((in - %zu))-->%s\n", emptyInputCounter, - emptyInputCounter, namePtrTable[node_ptr].c_str()); - ++emptyInputCounter; - } else { - std::fprintf(fp, "%s-->%s\n", namePtrTable[pa_ptr].c_str(), - namePtrTable[node_ptr].c_str()); - } + IOIndex_t outputIdx = 0; + for (auto childs : node_ptr->getOrderedChildren()) { + for (auto child : childs) { + if (child) { + IOIndex_t inputIdx = 0; + for (auto pa_ptr : child->getParents()) { + if (pa_ptr == node_ptr) { + std::fprintf(fp, "%s-->|%u..%u|%s\n", namePtrTable[node_ptr].c_str(), + outputIdx, inputIdx, namePtrTable[child].c_str()); + break; + } + ++inputIdx; + } + } } + ++outputIdx; + } + } + + size_t inputIdx = 0; + for (auto input : mInputNodes) { + std::fprintf(fp, "input%lu((in#%lu)):::inputCls-->|..%u|%s\n", inputIdx, inputIdx, + input.second, namePtrTable[input.first].c_str()); + ++inputIdx; } + + size_t outputIdx = 0; + for (auto output : mOutputNodes) { + std::fprintf(fp, "%s-->|%u..|output%lu((out#%lu)):::outputCls\n", + namePtrTable[output.first].c_str(), output.second, + outputIdx, outputIdx); + ++outputIdx; + } + + std::fprintf(fp, "classDef inputCls fill:#afa\n"); + std::fprintf(fp, "classDef outputCls fill:#ffa\n"); + if (verbose) { - for (const auto &c : typeCounter) { + for (const auto &c : typeCounter) { std::printf("%s - %zu\n", c.first.c_str(), c.second); - } + } } std::fprintf(fp, "\n"); @@ -447,6 +472,13 @@ void Aidge::GraphView::add(std::set<std::shared_ptr<Node>> otherNodes, bool incl while (!nodesToAdd.empty()); } +void Aidge::GraphView::add(std::pair<NodePtr, std::set<NodePtr>> nodes, bool includeLearnableParam) { + if (nodes.first != nullptr) { + add(nodes.first, includeLearnableParam); + } + add(nodes.second, includeLearnableParam); +} + void Aidge::GraphView::add(std::shared_ptr<GraphView> graph) { add(graph->getNodes(), false); } @@ -834,6 +866,7 @@ void Aidge::GraphView::updateInputsOutputsNew(std::shared_ptr<Node> newNode) { const auto val = std::make_pair(newNode, inputIdx); if (std::find(mInputNodes.begin(), mInputNodes.end(), val) == mInputNodes.end()) { newInputsInsertionPoint = mInputNodes.insert(newInputsInsertionPoint, val); + newInputsInsertionPoint = std::next(newInputsInsertionPoint); } } ++inputIdx; @@ -902,6 +935,7 @@ void Aidge::GraphView::updateInputsOutputsNew(std::shared_ptr<Node> newNode) { // Output may be already be present (see addChild() with a node already in GraphView) if (std::find(mOutputNodes.begin(), mOutputNodes.end(), val) == mOutputNodes.end()) { newOutputsInsertionPoint = mOutputNodes.insert(newOutputsInsertionPoint, val); + newOutputsInsertionPoint = std::next(newOutputsInsertionPoint); } } ++outputIdx; @@ -940,6 +974,7 @@ void Aidge::GraphView::updateInputsOutputsDelete(std::shared_ptr<Node> deletedNo const auto val = std::make_pair(ch_ptr, inputIdx); AIDGE_INTERNAL_ASSERT(std::find(mInputNodes.begin(), mInputNodes.end(), val) == mInputNodes.end()); newInputsInsertionPoint = mInputNodes.insert(newInputsInsertionPoint, val); + newInputsInsertionPoint = std::next(newInputsInsertionPoint); } ++inputIdx; } @@ -986,6 +1021,7 @@ void Aidge::GraphView::updateInputsOutputsDelete(std::shared_ptr<Node> deletedNo const auto val = std::make_pair(parent, outputIdx); AIDGE_INTERNAL_ASSERT(std::find(mOutputNodes.begin(), mOutputNodes.end(), val) == mOutputNodes.end()); newOutputsInsertionPoint = mOutputNodes.insert(newOutputsInsertionPoint, val); + newOutputsInsertionPoint = std::next(newOutputsInsertionPoint); } ++outputIdx; } diff --git a/unit_tests/graph/Test_GraphView.cpp b/unit_tests/graph/Test_GraphView.cpp index 75f9a47fe..a80c7a7aa 100644 --- a/unit_tests/graph/Test_GraphView.cpp +++ b/unit_tests/graph/Test_GraphView.cpp @@ -15,6 +15,8 @@ #include <set> #include <string> #include <random> +#include <algorithm> +#include <utility> #include <catch2/catch_test_macros.hpp> @@ -27,16 +29,28 @@ using namespace Aidge; -std::set<NodePtr> genRandomDAG(size_t nbNodes, float density = 0.5, size_t maxIn = 5, float avgIn = 1.5, size_t maxOut = 2, float avgOut = 1.1) { - std::random_device rd; - std::mt19937 gen(rd()); - std::binomial_distribution<> dIn(maxIn, avgIn/maxIn); - std::binomial_distribution<> dOut(maxOut, avgOut/maxOut); +std::pair<NodePtr, std::set<NodePtr>> genRandomDAG(std::mt19937::result_type seed, size_t nbNodes, float density = 0.5, size_t maxIn = 5, float avgIn = 1.5, size_t maxOut = 2, float avgOut = 1.1) { + std::mt19937 gen(seed); + std::binomial_distribution<> dIn(maxIn - 1, avgIn/maxIn); + std::binomial_distribution<> dOut(maxOut - 1, avgOut/maxOut); std::binomial_distribution<> dLink(1, density); - std::vector<NodePtr> nodes; + std::vector<std::pair<int, int>> nbIOs; for (size_t i = 0; i < nbNodes; ++i) { - nodes.push_back(GenericOperator("Fictive", dIn(gen), dIn(gen), dOut(gen))); + const auto nbIn = 1 + dIn(gen); + nbIOs.push_back(std::make_pair(nbIn, 1 + dOut(gen))); + } + + std::vector<int> nodesSeq(nbNodes); + std::iota(nodesSeq.begin(), nodesSeq.end(), 0); + // Don't use gen or seed here, must be different each time! + std::shuffle(nodesSeq.begin(), nodesSeq.end(), std::default_random_engine(std::random_device{}())); + + std::vector<NodePtr> nodes(nbNodes, nullptr); + for (auto idx : nodesSeq) { + const std::string type = "Fictive"; + const std::string name = type + std::to_string(idx); + nodes[idx] = GenericOperator(type.c_str(), nbIOs[idx].first, nbIOs[idx].first, nbIOs[idx].second, name.c_str()); } for (size_t i = 0; i < nbNodes; ++i) { @@ -45,20 +59,49 @@ std::set<NodePtr> genRandomDAG(size_t nbNodes, float density = 0.5, size_t maxIn for (size_t inId = 0; inId < nodes[j]->nbInputs(); ++inId) { if (dLink(gen)) { nodes[i]->addChild(nodes[j], outId, inId); + break; } } } } } - return std::set<NodePtr>(nodes.begin(), nodes.end()); + return std::make_pair(nodes[0], std::set<NodePtr>(nodes.begin(), nodes.end())); +} + +std::set<std::string> nodePtrToName(const std::set<NodePtr>& nodes) { + std::set<std::string> nodesName; + std::transform(nodes.begin(), nodes.end(), std::inserter(nodesName, nodesName.begin()), + [](const NodePtr& node) { + return node->name(); + }); + return nodesName; +} + +std::vector<std::pair<std::string, IOIndex_t>> nodePtrToName(const std::vector<std::pair<NodePtr, IOIndex_t>>& nodes) { + std::vector<std::pair<std::string, IOIndex_t>> nodesName; + std::transform(nodes.begin(), nodes.end(), std::back_inserter(nodesName), + [](const std::pair<NodePtr, IOIndex_t>& node) { + return std::make_pair(node.first->name(), node.second); + }); + return nodesName; } TEST_CASE("genRandomDAG") { - auto g = std::make_shared<GraphView>(); - g->add(genRandomDAG(10)); - REQUIRE(g->getNodes().size() == 10); - g->save("./genRandomDAG"); + std::random_device rd; + const std::mt19937::result_type seed(rd()); + + auto g1 = std::make_shared<GraphView>(); + g1->add(genRandomDAG(seed, 10, 0.5)); + auto g2 = std::make_shared<GraphView>(); + g2->add(genRandomDAG(seed, 10, 0.5)); + + g1->save("./genRandomDAG1"); + g2->save("./genRandomDAG2"); + + REQUIRE(nodePtrToName(g1->getNodes()) == nodePtrToName(g2->getNodes())); + REQUIRE(nodePtrToName(g1->getOrderedInputs()) == nodePtrToName(g2->getOrderedInputs())); + REQUIRE(nodePtrToName(g1->getOrderedOutputs()) == nodePtrToName(g2->getOrderedOutputs())); } TEST_CASE("[core/graph] GraphView(Constructor)") { -- GitLab