Skip to content
Snippets Groups Projects
Commit b92b7f38 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Improved graph visualization

parent c1ab7a5e
No related branches found
No related tags found
No related merge requests found
......@@ -133,6 +133,9 @@ public:
void setOrderedInputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& inputs);
void setOrderedOutputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& outputs);
inline const std::vector<std::pair<NodePtr, IOIndex_t>>& getOrderedInputs() { return mInputNodes; };
inline const std::vector<std::pair<NodePtr, IOIndex_t>>& getOrderedOutputs() { return mOutputNodes; };
/**
* @brief List outside data input connections of the GraphView.
* Data inputs exclude inputs expecting parameters (weights or bias).
......@@ -255,6 +258,7 @@ public:
* in the GraphView automatically. Default: true.
*/
void add(NodePtr otherNode, bool includeLearnableParam = true);
/**
* @brief Include a set of Nodes to the current GraphView object.
* @param otherNodes
......@@ -263,6 +267,15 @@ public:
void add(std::set<NodePtr> otherNodes,
bool includeLearnableParam = true);
/**
* @brief Include a set of Nodes to the current GraphView object.
* The second element in the otherNodes pair is the start node.
* @param otherNodes
* @param includeLearnableParam
*/
void add(std::pair<NodePtr, std::set<NodePtr>> otherNodes,
bool includeLearnableParam = true);
/**
* @brief Include every Node inside another GraphView to the current
* GraphView.
......
......@@ -79,23 +79,48 @@ void Aidge::GraphView::save(std::string path, bool verbose) const {
givenName.c_str());
}
// Write every link
std::size_t emptyInputCounter = 0;
for (const std::shared_ptr<Node> &node_ptr : mNodes) {
for (const std::shared_ptr<Node> &pa_ptr : node_ptr->getParents()) {
if ((pa_ptr == nullptr) || !inView(pa_ptr)) {
std::fprintf(fp, "input%zu((in - %zu))-->%s\n", emptyInputCounter,
emptyInputCounter, namePtrTable[node_ptr].c_str());
++emptyInputCounter;
} else {
std::fprintf(fp, "%s-->%s\n", namePtrTable[pa_ptr].c_str(),
namePtrTable[node_ptr].c_str());
}
IOIndex_t outputIdx = 0;
for (auto childs : node_ptr->getOrderedChildren()) {
for (auto child : childs) {
if (child) {
IOIndex_t inputIdx = 0;
for (auto pa_ptr : child->getParents()) {
if (pa_ptr == node_ptr) {
std::fprintf(fp, "%s-->|%u..%u|%s\n", namePtrTable[node_ptr].c_str(),
outputIdx, inputIdx, namePtrTable[child].c_str());
break;
}
++inputIdx;
}
}
}
++outputIdx;
}
}
size_t inputIdx = 0;
for (auto input : mInputNodes) {
std::fprintf(fp, "input%lu((in#%lu)):::inputCls-->|..%u|%s\n", inputIdx, inputIdx,
input.second, namePtrTable[input.first].c_str());
++inputIdx;
}
size_t outputIdx = 0;
for (auto output : mOutputNodes) {
std::fprintf(fp, "%s-->|%u..|output%lu((out#%lu)):::outputCls\n",
namePtrTable[output.first].c_str(), output.second,
outputIdx, outputIdx);
++outputIdx;
}
std::fprintf(fp, "classDef inputCls fill:#afa\n");
std::fprintf(fp, "classDef outputCls fill:#ffa\n");
if (verbose) {
for (const auto &c : typeCounter) {
for (const auto &c : typeCounter) {
std::printf("%s - %zu\n", c.first.c_str(), c.second);
}
}
}
std::fprintf(fp, "\n");
......@@ -447,6 +472,13 @@ void Aidge::GraphView::add(std::set<std::shared_ptr<Node>> otherNodes, bool incl
while (!nodesToAdd.empty());
}
void Aidge::GraphView::add(std::pair<NodePtr, std::set<NodePtr>> nodes, bool includeLearnableParam) {
if (nodes.first != nullptr) {
add(nodes.first, includeLearnableParam);
}
add(nodes.second, includeLearnableParam);
}
void Aidge::GraphView::add(std::shared_ptr<GraphView> graph) {
add(graph->getNodes(), false);
}
......@@ -834,6 +866,7 @@ void Aidge::GraphView::updateInputsOutputsNew(std::shared_ptr<Node> newNode) {
const auto val = std::make_pair(newNode, inputIdx);
if (std::find(mInputNodes.begin(), mInputNodes.end(), val) == mInputNodes.end()) {
newInputsInsertionPoint = mInputNodes.insert(newInputsInsertionPoint, val);
newInputsInsertionPoint = std::next(newInputsInsertionPoint);
}
}
++inputIdx;
......@@ -902,6 +935,7 @@ void Aidge::GraphView::updateInputsOutputsNew(std::shared_ptr<Node> newNode) {
// Output may be already be present (see addChild() with a node already in GraphView)
if (std::find(mOutputNodes.begin(), mOutputNodes.end(), val) == mOutputNodes.end()) {
newOutputsInsertionPoint = mOutputNodes.insert(newOutputsInsertionPoint, val);
newOutputsInsertionPoint = std::next(newOutputsInsertionPoint);
}
}
++outputIdx;
......@@ -940,6 +974,7 @@ void Aidge::GraphView::updateInputsOutputsDelete(std::shared_ptr<Node> deletedNo
const auto val = std::make_pair(ch_ptr, inputIdx);
AIDGE_INTERNAL_ASSERT(std::find(mInputNodes.begin(), mInputNodes.end(), val) == mInputNodes.end());
newInputsInsertionPoint = mInputNodes.insert(newInputsInsertionPoint, val);
newInputsInsertionPoint = std::next(newInputsInsertionPoint);
}
++inputIdx;
}
......@@ -986,6 +1021,7 @@ void Aidge::GraphView::updateInputsOutputsDelete(std::shared_ptr<Node> deletedNo
const auto val = std::make_pair(parent, outputIdx);
AIDGE_INTERNAL_ASSERT(std::find(mOutputNodes.begin(), mOutputNodes.end(), val) == mOutputNodes.end());
newOutputsInsertionPoint = mOutputNodes.insert(newOutputsInsertionPoint, val);
newOutputsInsertionPoint = std::next(newOutputsInsertionPoint);
}
++outputIdx;
}
......
......@@ -15,6 +15,8 @@
#include <set>
#include <string>
#include <random>
#include <algorithm>
#include <utility>
#include <catch2/catch_test_macros.hpp>
......@@ -27,16 +29,28 @@
using namespace Aidge;
std::set<NodePtr> genRandomDAG(size_t nbNodes, float density = 0.5, size_t maxIn = 5, float avgIn = 1.5, size_t maxOut = 2, float avgOut = 1.1) {
std::random_device rd;
std::mt19937 gen(rd());
std::binomial_distribution<> dIn(maxIn, avgIn/maxIn);
std::binomial_distribution<> dOut(maxOut, avgOut/maxOut);
std::pair<NodePtr, std::set<NodePtr>> genRandomDAG(std::mt19937::result_type seed, size_t nbNodes, float density = 0.5, size_t maxIn = 5, float avgIn = 1.5, size_t maxOut = 2, float avgOut = 1.1) {
std::mt19937 gen(seed);
std::binomial_distribution<> dIn(maxIn - 1, avgIn/maxIn);
std::binomial_distribution<> dOut(maxOut - 1, avgOut/maxOut);
std::binomial_distribution<> dLink(1, density);
std::vector<NodePtr> nodes;
std::vector<std::pair<int, int>> nbIOs;
for (size_t i = 0; i < nbNodes; ++i) {
nodes.push_back(GenericOperator("Fictive", dIn(gen), dIn(gen), dOut(gen)));
const auto nbIn = 1 + dIn(gen);
nbIOs.push_back(std::make_pair(nbIn, 1 + dOut(gen)));
}
std::vector<int> nodesSeq(nbNodes);
std::iota(nodesSeq.begin(), nodesSeq.end(), 0);
// Don't use gen or seed here, must be different each time!
std::shuffle(nodesSeq.begin(), nodesSeq.end(), std::default_random_engine(std::random_device{}()));
std::vector<NodePtr> nodes(nbNodes, nullptr);
for (auto idx : nodesSeq) {
const std::string type = "Fictive";
const std::string name = type + std::to_string(idx);
nodes[idx] = GenericOperator(type.c_str(), nbIOs[idx].first, nbIOs[idx].first, nbIOs[idx].second, name.c_str());
}
for (size_t i = 0; i < nbNodes; ++i) {
......@@ -45,20 +59,49 @@ std::set<NodePtr> genRandomDAG(size_t nbNodes, float density = 0.5, size_t maxIn
for (size_t inId = 0; inId < nodes[j]->nbInputs(); ++inId) {
if (dLink(gen)) {
nodes[i]->addChild(nodes[j], outId, inId);
break;
}
}
}
}
}
return std::set<NodePtr>(nodes.begin(), nodes.end());
return std::make_pair(nodes[0], std::set<NodePtr>(nodes.begin(), nodes.end()));
}
std::set<std::string> nodePtrToName(const std::set<NodePtr>& nodes) {
std::set<std::string> nodesName;
std::transform(nodes.begin(), nodes.end(), std::inserter(nodesName, nodesName.begin()),
[](const NodePtr& node) {
return node->name();
});
return nodesName;
}
std::vector<std::pair<std::string, IOIndex_t>> nodePtrToName(const std::vector<std::pair<NodePtr, IOIndex_t>>& nodes) {
std::vector<std::pair<std::string, IOIndex_t>> nodesName;
std::transform(nodes.begin(), nodes.end(), std::back_inserter(nodesName),
[](const std::pair<NodePtr, IOIndex_t>& node) {
return std::make_pair(node.first->name(), node.second);
});
return nodesName;
}
TEST_CASE("genRandomDAG") {
auto g = std::make_shared<GraphView>();
g->add(genRandomDAG(10));
REQUIRE(g->getNodes().size() == 10);
g->save("./genRandomDAG");
std::random_device rd;
const std::mt19937::result_type seed(rd());
auto g1 = std::make_shared<GraphView>();
g1->add(genRandomDAG(seed, 10, 0.5));
auto g2 = std::make_shared<GraphView>();
g2->add(genRandomDAG(seed, 10, 0.5));
g1->save("./genRandomDAG1");
g2->save("./genRandomDAG2");
REQUIRE(nodePtrToName(g1->getNodes()) == nodePtrToName(g2->getNodes()));
REQUIRE(nodePtrToName(g1->getOrderedInputs()) == nodePtrToName(g2->getOrderedInputs()));
REQUIRE(nodePtrToName(g1->getOrderedOutputs()) == nodePtrToName(g2->getOrderedOutputs()));
}
TEST_CASE("[core/graph] GraphView(Constructor)") {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment