Skip to content
Snippets Groups Projects
Commit 1f2d196d authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Working version: node ordering is now well defined

parent a8337fd1
No related branches found
No related tags found
No related merge requests found
......@@ -35,6 +35,9 @@ private:
/// @brief Name of the graphview
std::string mName;
/// @brief GraphView root node
NodePtr mRootNode;
/// @brief Set of nodes included in the GraphView
std::set<NodePtr> mNodes;
......@@ -99,6 +102,10 @@ public:
return mNodes.find(nodePtr) != mNodes.end();
}
NodePtr getRootNode() {
return mRootNode;
}
///////////////////////////////////////////////////////
// TENSOR MANAGEMENT
///////////////////////////////////////////////////////
......@@ -263,8 +270,9 @@ public:
* @brief Include a set of Nodes to the current GraphView object.
* @param otherNodes
* @param includeLearnableParam
* @return true if graph ordering is unique (meaning inputs/outputs order is well defined).
*/
void add(std::set<NodePtr> otherNodes,
bool add(std::set<NodePtr> otherNodes,
bool includeLearnableParam = true);
/**
......@@ -272,16 +280,18 @@ public:
* The second element in the otherNodes pair is the start node.
* @param otherNodes
* @param includeLearnableParam
* @return true if graph ordering is unique (meaning inputs/outputs order is well defined).
*/
void add(std::pair<NodePtr, std::set<NodePtr>> otherNodes,
bool add(std::pair<NodePtr, std::set<NodePtr>> otherNodes,
bool includeLearnableParam = true);
/**
* @brief Include every Node inside another GraphView to the current
* GraphView.
* @param other_graph GraphView containing the Nodes to include.
* @return true if graph ordering is unique (meaning inputs/outputs order is well defined).
*/
void add(std::shared_ptr<GraphView> otherGraph);
bool add(std::shared_ptr<GraphView> otherGraph);
/**
* @brief Include a Node in the current GraphView and link it to another
......
......@@ -75,15 +75,22 @@ void Aidge::GraphView::save(std::string path, bool verbose) const {
: node_ptr->name();
namePtrTable[node_ptr] =
(currentType + "_" + std::to_string(typeCounter[currentType]));
std::fprintf(fp, "%s(%s)\n", namePtrTable[node_ptr].c_str(),
givenName.c_str());
if (node_ptr == mRootNode) {
std::fprintf(fp, "%s(%s):::rootCls\n", namePtrTable[node_ptr].c_str(),
givenName.c_str());
}
else {
std::fprintf(fp, "%s(%s)\n", namePtrTable[node_ptr].c_str(),
givenName.c_str());
}
}
// Write every link
for (const std::shared_ptr<Node> &node_ptr : mNodes) {
IOIndex_t outputIdx = 0;
for (auto childs : node_ptr->getOrderedChildren()) {
for (auto child : childs) {
if (child) {
if (child != nullptr && mNodes.find(child) != mNodes.end()) {
IOIndex_t inputIdx = 0;
for (auto pa_ptr : child->getParents()) {
if (pa_ptr == node_ptr) {
......@@ -116,6 +123,7 @@ void Aidge::GraphView::save(std::string path, bool verbose) const {
std::fprintf(fp, "classDef inputCls fill:#afa\n");
std::fprintf(fp, "classDef outputCls fill:#ffa\n");
std::fprintf(fp, "classDef rootCls stroke:#f00\n");
if (verbose) {
for (const auto &c : typeCounter) {
......@@ -382,6 +390,11 @@ void Aidge::GraphView::setInputId(Aidge::IOIndex_t /*inID*/,
}
void Aidge::GraphView::add(std::shared_ptr<Node> node, bool includeLearnableParam) {
// first node to be added to the graph is the root node by default
if (mRootNode == nullptr) {
mRootNode = node;
}
// add to the GraphView nodes
node->addView(shared_from_this());
mNodes.insert(node);
......@@ -407,80 +420,117 @@ void Aidge::GraphView::add(std::shared_ptr<Node> node, bool includeLearnablePara
}
}
void Aidge::GraphView::add(std::set<std::shared_ptr<Node>> otherNodes, bool includeLearnableParam) {
bool Aidge::GraphView::add(std::set<std::shared_ptr<Node>> otherNodes, bool includeLearnableParam) {
if (otherNodes.empty()) {
return true;
}
bool orderUnicity = true;
// List only the nodes that are not already present in current graph
std::set<NodePtr> nodesToAdd;
std::set_difference(otherNodes.begin(), otherNodes.end(), mNodes.begin(), mNodes.end(), std::inserter(nodesToAdd, nodesToAdd.begin()));
do {
std::set<NodePtr> nextNodesToAdd;
// Find nodes that are direct parent of current GraphView and add them first
// such that the obtained GraphView inputs list will be the same, regardless
// of the evaluation order of those nodes
// (i.e. one of their child is in current GraphView)
for (auto it = nodesToAdd.begin(); it != nodesToAdd.end(); ++it) {
for (auto child : (*it)->getChildren()) {
if (mNodes.find(child) != mNodes.end()) {
nextNodesToAdd.insert(*it);
it = nodesToAdd.erase(it);
// List the nodes to rank, initially all the nodes in the GraphView
std::set<NodePtr> nodesToRank(mNodes);
nodesToRank.insert(nodesToAdd.begin(), nodesToAdd.end());
std::vector<NodePtr> rankedNodesToAdd;
if (mRootNode == nullptr) {
std::set<NodePtr> noParentNodes;
// If no root node is defined, check nodes without parents
for (auto node : nodesToRank) {
bool noParent = true;
for (auto parent : node->getParents()) {
if (parent != nullptr && nodesToRank.find(parent) != nodesToRank.end()) {
noParent = false;
break;
}
}
if (it == nodesToAdd.end()) {
break;
if (noParent) {
noParentNodes.insert(node);
}
}
// If there is no more parent, find nodes that are direct children of current GraphView,
// such that the obtained GraphView outputs list will be the same, regardless
// of the evaluation order of those nodes
// (i.e. one of their parent is in current GraphView)
// TODO: this might be done simultaneously with direct parents, by removing
// the empty() condition, but there might be edge cases that may change
// the resulting inputs/outputs order depending on evaluation order (???)
if (nextNodesToAdd.empty()) {
for (auto it = nodesToAdd.begin(); it != nodesToAdd.end(); ++it) {
for (auto parent : (*it)->getParents()) {
if (mNodes.find(parent) != mNodes.end()) {
nextNodesToAdd.insert(*it);
it = nodesToAdd.erase(it);
break;
// Take the first one found (this is an arbitrary choice)
mRootNode = *noParentNodes.begin();
if (noParentNodes.size() > 1) {
// If there is more than one, order unicity cannot be garanteed!
orderUnicity = false;
}
rankedNodesToAdd.push_back(mRootNode);
}
nodesToRank.erase(mRootNode);
std::vector<NodePtr> rankedNodes;
rankedNodes.push_back(mRootNode);
for (size_t curNodeIdx = 0; curNodeIdx < rankedNodes.size(); ++curNodeIdx) {
NodePtr curNode = rankedNodes[curNodeIdx];
for (auto childs : curNode->getOrderedChildren()) {
for (auto child : childs) {
if (nodesToRank.find(child) != nodesToRank.end()) {
rankedNodes.push_back(child);
nodesToRank.erase(child);
if (nodesToAdd.find(child) != nodesToAdd.end()) {
rankedNodesToAdd.push_back(child);
nodesToAdd.erase(child);
}
}
if (it == nodesToAdd.end()) {
break;
}
}
}
// If no node if found, there might be remaining nodes that form an independant sub-graph
// In this case, additionnal inputs/outputs will be added at the end of
// the GraphView inputs/outputs list, in no particular order.
// TODO: we might try to preserve the initial inputs/ouputs relative order of those nodes
// if they actually comes from a GraphView, but I think that would be a far-fetched expectation
// from the users...
if (nextNodesToAdd.empty()) {
nodesToAdd.swap(nextNodesToAdd);
for (auto parent : curNode->getParents()) {
if (nodesToRank.find(parent) != nodesToRank.end()) {
rankedNodes.push_back(parent);
nodesToRank.erase(parent);
if (nodesToAdd.find(parent) != nodesToAdd.end()) {
rankedNodesToAdd.push_back(parent);
nodesToAdd.erase(parent);
}
}
}
}
// Add selected nodes in the current GraphView, in no particular order
for (auto node_ptr : nextNodesToAdd) {
add(node_ptr, includeLearnableParam);
if (!nodesToAdd.empty()) {
// There are remaining nodes without path to the root node
orderUnicity = false;
while (!nodesToAdd.empty()) {
const auto it = nodesToAdd.begin();
rankedNodesToAdd.push_back(*it);
nodesToAdd.erase(it);
}
}
while (!nodesToAdd.empty());
for (auto node_ptr : rankedNodesToAdd) {
add(node_ptr, includeLearnableParam);
}
return orderUnicity;
}
void Aidge::GraphView::add(std::pair<NodePtr, std::set<NodePtr>> nodes, bool includeLearnableParam) {
bool Aidge::GraphView::add(std::pair<NodePtr, std::set<NodePtr>> nodes, bool includeLearnableParam) {
if (nodes.first != nullptr) {
mRootNode = nodes.first;
add(nodes.first, includeLearnableParam);
}
add(nodes.second, includeLearnableParam);
return add(nodes.second, includeLearnableParam);
}
void Aidge::GraphView::add(std::shared_ptr<GraphView> graph) {
add(graph->getNodes(), false);
bool Aidge::GraphView::add(std::shared_ptr<GraphView> graph) {
if (mRootNode == nullptr) {
mRootNode = graph->getRootNode();
}
return add(graph->getNodes(), false);
}
void Aidge::GraphView::addChild(std::shared_ptr<Node> toOtherNode,
......
......@@ -88,20 +88,32 @@ std::vector<std::pair<std::string, IOIndex_t>> nodePtrToName(const std::vector<s
TEST_CASE("genRandomDAG") {
std::random_device rd;
const std::mt19937::result_type seed(rd());
const size_t nbTests = 100;
size_t nbUnicity = 0;
auto g1 = std::make_shared<GraphView>();
g1->add(genRandomDAG(seed, 10, 0.5));
auto g2 = std::make_shared<GraphView>();
g2->add(genRandomDAG(seed, 10, 0.5));
for (int test = 0; test < nbTests; ++test) {
std::random_device rd;
const std::mt19937::result_type seed(rd());
g1->save("./genRandomDAG1");
g2->save("./genRandomDAG2");
const auto g1 = std::make_shared<GraphView>("g1");
const bool unicity1 = g1->add(genRandomDAG(seed, 10, 0.5));
const auto g2 = std::make_shared<GraphView>("g2");
const bool unicity2 = g2->add(genRandomDAG(seed, 10, 0.5));
REQUIRE(nodePtrToName(g1->getNodes()) == nodePtrToName(g2->getNodes()));
REQUIRE(nodePtrToName(g1->getOrderedInputs()) == nodePtrToName(g2->getOrderedInputs()));
REQUIRE(nodePtrToName(g1->getOrderedOutputs()) == nodePtrToName(g2->getOrderedOutputs()));
g1->save("./genRandomDAG1");
g2->save("./genRandomDAG2");
REQUIRE(unicity1 == unicity2);
if (unicity1) {
REQUIRE(nodePtrToName(g1->getNodes()) == nodePtrToName(g2->getNodes()));
REQUIRE(nodePtrToName(g1->getOrderedInputs()) == nodePtrToName(g2->getOrderedInputs()));
REQUIRE(nodePtrToName(g1->getOrderedOutputs()) == nodePtrToName(g2->getOrderedOutputs()));
++nbUnicity;
}
}
printf("nbUnicity = %zu/%zu\n", nbUnicity, nbTests);
}
TEST_CASE("[core/graph] GraphView(Constructor)") {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment