Skip to content
Snippets Groups Projects
Commit 1f2d196d authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Working version: node ordering is now well defined

parent a8337fd1
No related branches found
No related tags found
No related merge requests found
...@@ -35,6 +35,9 @@ private: ...@@ -35,6 +35,9 @@ private:
/// @brief Name of the graphview /// @brief Name of the graphview
std::string mName; std::string mName;
/// @brief GraphView root node
NodePtr mRootNode;
/// @brief Set of nodes included in the GraphView /// @brief Set of nodes included in the GraphView
std::set<NodePtr> mNodes; std::set<NodePtr> mNodes;
...@@ -99,6 +102,10 @@ public: ...@@ -99,6 +102,10 @@ public:
return mNodes.find(nodePtr) != mNodes.end(); return mNodes.find(nodePtr) != mNodes.end();
} }
NodePtr getRootNode() {
return mRootNode;
}
/////////////////////////////////////////////////////// ///////////////////////////////////////////////////////
// TENSOR MANAGEMENT // TENSOR MANAGEMENT
/////////////////////////////////////////////////////// ///////////////////////////////////////////////////////
...@@ -263,8 +270,9 @@ public: ...@@ -263,8 +270,9 @@ public:
* @brief Include a set of Nodes to the current GraphView object. * @brief Include a set of Nodes to the current GraphView object.
* @param otherNodes * @param otherNodes
* @param includeLearnableParam * @param includeLearnableParam
* @return true if graph ordering is unique (meaning inputs/outputs order is well defined).
*/ */
void add(std::set<NodePtr> otherNodes, bool add(std::set<NodePtr> otherNodes,
bool includeLearnableParam = true); bool includeLearnableParam = true);
/** /**
...@@ -272,16 +280,18 @@ public: ...@@ -272,16 +280,18 @@ public:
* The second element in the otherNodes pair is the start node. * The second element in the otherNodes pair is the start node.
* @param otherNodes * @param otherNodes
* @param includeLearnableParam * @param includeLearnableParam
* @return true if graph ordering is unique (meaning inputs/outputs order is well defined).
*/ */
void add(std::pair<NodePtr, std::set<NodePtr>> otherNodes, bool add(std::pair<NodePtr, std::set<NodePtr>> otherNodes,
bool includeLearnableParam = true); bool includeLearnableParam = true);
/** /**
* @brief Include every Node inside another GraphView to the current * @brief Include every Node inside another GraphView to the current
* GraphView. * GraphView.
* @param other_graph GraphView containing the Nodes to include. * @param other_graph GraphView containing the Nodes to include.
* @return true if graph ordering is unique (meaning inputs/outputs order is well defined).
*/ */
void add(std::shared_ptr<GraphView> otherGraph); bool add(std::shared_ptr<GraphView> otherGraph);
/** /**
* @brief Include a Node in the current GraphView and link it to another * @brief Include a Node in the current GraphView and link it to another
......
...@@ -75,15 +75,22 @@ void Aidge::GraphView::save(std::string path, bool verbose) const { ...@@ -75,15 +75,22 @@ void Aidge::GraphView::save(std::string path, bool verbose) const {
: node_ptr->name(); : node_ptr->name();
namePtrTable[node_ptr] = namePtrTable[node_ptr] =
(currentType + "_" + std::to_string(typeCounter[currentType])); (currentType + "_" + std::to_string(typeCounter[currentType]));
std::fprintf(fp, "%s(%s)\n", namePtrTable[node_ptr].c_str(),
givenName.c_str()); if (node_ptr == mRootNode) {
std::fprintf(fp, "%s(%s):::rootCls\n", namePtrTable[node_ptr].c_str(),
givenName.c_str());
}
else {
std::fprintf(fp, "%s(%s)\n", namePtrTable[node_ptr].c_str(),
givenName.c_str());
}
} }
// Write every link // Write every link
for (const std::shared_ptr<Node> &node_ptr : mNodes) { for (const std::shared_ptr<Node> &node_ptr : mNodes) {
IOIndex_t outputIdx = 0; IOIndex_t outputIdx = 0;
for (auto childs : node_ptr->getOrderedChildren()) { for (auto childs : node_ptr->getOrderedChildren()) {
for (auto child : childs) { for (auto child : childs) {
if (child) { if (child != nullptr && mNodes.find(child) != mNodes.end()) {
IOIndex_t inputIdx = 0; IOIndex_t inputIdx = 0;
for (auto pa_ptr : child->getParents()) { for (auto pa_ptr : child->getParents()) {
if (pa_ptr == node_ptr) { if (pa_ptr == node_ptr) {
...@@ -116,6 +123,7 @@ void Aidge::GraphView::save(std::string path, bool verbose) const { ...@@ -116,6 +123,7 @@ void Aidge::GraphView::save(std::string path, bool verbose) const {
std::fprintf(fp, "classDef inputCls fill:#afa\n"); std::fprintf(fp, "classDef inputCls fill:#afa\n");
std::fprintf(fp, "classDef outputCls fill:#ffa\n"); std::fprintf(fp, "classDef outputCls fill:#ffa\n");
std::fprintf(fp, "classDef rootCls stroke:#f00\n");
if (verbose) { if (verbose) {
for (const auto &c : typeCounter) { for (const auto &c : typeCounter) {
...@@ -382,6 +390,11 @@ void Aidge::GraphView::setInputId(Aidge::IOIndex_t /*inID*/, ...@@ -382,6 +390,11 @@ void Aidge::GraphView::setInputId(Aidge::IOIndex_t /*inID*/,
} }
void Aidge::GraphView::add(std::shared_ptr<Node> node, bool includeLearnableParam) { void Aidge::GraphView::add(std::shared_ptr<Node> node, bool includeLearnableParam) {
// first node to be added to the graph is the root node by default
if (mRootNode == nullptr) {
mRootNode = node;
}
// add to the GraphView nodes // add to the GraphView nodes
node->addView(shared_from_this()); node->addView(shared_from_this());
mNodes.insert(node); mNodes.insert(node);
...@@ -407,80 +420,117 @@ void Aidge::GraphView::add(std::shared_ptr<Node> node, bool includeLearnablePara ...@@ -407,80 +420,117 @@ void Aidge::GraphView::add(std::shared_ptr<Node> node, bool includeLearnablePara
} }
} }
void Aidge::GraphView::add(std::set<std::shared_ptr<Node>> otherNodes, bool includeLearnableParam) { bool Aidge::GraphView::add(std::set<std::shared_ptr<Node>> otherNodes, bool includeLearnableParam) {
if (otherNodes.empty()) {
return true;
}
bool orderUnicity = true;
// List only the nodes that are not already present in current graph // List only the nodes that are not already present in current graph
std::set<NodePtr> nodesToAdd; std::set<NodePtr> nodesToAdd;
std::set_difference(otherNodes.begin(), otherNodes.end(), mNodes.begin(), mNodes.end(), std::inserter(nodesToAdd, nodesToAdd.begin())); std::set_difference(otherNodes.begin(), otherNodes.end(), mNodes.begin(), mNodes.end(), std::inserter(nodesToAdd, nodesToAdd.begin()));
do { // List the nodes to rank, initially all the nodes in the GraphView
std::set<NodePtr> nextNodesToAdd; std::set<NodePtr> nodesToRank(mNodes);
nodesToRank.insert(nodesToAdd.begin(), nodesToAdd.end());
// Find nodes that are direct parent of current GraphView and add them first std::vector<NodePtr> rankedNodesToAdd;
// such that the obtained GraphView inputs list will be the same, regardless
// of the evaluation order of those nodes if (mRootNode == nullptr) {
// (i.e. one of their child is in current GraphView) std::set<NodePtr> noParentNodes;
for (auto it = nodesToAdd.begin(); it != nodesToAdd.end(); ++it) {
for (auto child : (*it)->getChildren()) { // If no root node is defined, check nodes without parents
if (mNodes.find(child) != mNodes.end()) { for (auto node : nodesToRank) {
nextNodesToAdd.insert(*it); bool noParent = true;
it = nodesToAdd.erase(it); for (auto parent : node->getParents()) {
if (parent != nullptr && nodesToRank.find(parent) != nodesToRank.end()) {
noParent = false;
break; break;
} }
} }
if (it == nodesToAdd.end()) {
break; if (noParent) {
noParentNodes.insert(node);
} }
} }
// If there is no more parent, find nodes that are direct children of current GraphView, // Take the first one found (this is an arbitrary choice)
// such that the obtained GraphView outputs list will be the same, regardless mRootNode = *noParentNodes.begin();
// of the evaluation order of those nodes
// (i.e. one of their parent is in current GraphView) if (noParentNodes.size() > 1) {
// TODO: this might be done simultaneously with direct parents, by removing // If there is more than one, order unicity cannot be garanteed!
// the empty() condition, but there might be edge cases that may change orderUnicity = false;
// the resulting inputs/outputs order depending on evaluation order (???) }
if (nextNodesToAdd.empty()) {
for (auto it = nodesToAdd.begin(); it != nodesToAdd.end(); ++it) { rankedNodesToAdd.push_back(mRootNode);
for (auto parent : (*it)->getParents()) { }
if (mNodes.find(parent) != mNodes.end()) {
nextNodesToAdd.insert(*it); nodesToRank.erase(mRootNode);
it = nodesToAdd.erase(it); std::vector<NodePtr> rankedNodes;
break; rankedNodes.push_back(mRootNode);
for (size_t curNodeIdx = 0; curNodeIdx < rankedNodes.size(); ++curNodeIdx) {
NodePtr curNode = rankedNodes[curNodeIdx];
for (auto childs : curNode->getOrderedChildren()) {
for (auto child : childs) {
if (nodesToRank.find(child) != nodesToRank.end()) {
rankedNodes.push_back(child);
nodesToRank.erase(child);
if (nodesToAdd.find(child) != nodesToAdd.end()) {
rankedNodesToAdd.push_back(child);
nodesToAdd.erase(child);
} }
} }
if (it == nodesToAdd.end()) {
break;
}
} }
} }
// If no node if found, there might be remaining nodes that form an independant sub-graph for (auto parent : curNode->getParents()) {
// In this case, additionnal inputs/outputs will be added at the end of if (nodesToRank.find(parent) != nodesToRank.end()) {
// the GraphView inputs/outputs list, in no particular order. rankedNodes.push_back(parent);
// TODO: we might try to preserve the initial inputs/ouputs relative order of those nodes nodesToRank.erase(parent);
// if they actually comes from a GraphView, but I think that would be a far-fetched expectation
// from the users... if (nodesToAdd.find(parent) != nodesToAdd.end()) {
if (nextNodesToAdd.empty()) { rankedNodesToAdd.push_back(parent);
nodesToAdd.swap(nextNodesToAdd); nodesToAdd.erase(parent);
}
}
} }
}
// Add selected nodes in the current GraphView, in no particular order if (!nodesToAdd.empty()) {
for (auto node_ptr : nextNodesToAdd) { // There are remaining nodes without path to the root node
add(node_ptr, includeLearnableParam); orderUnicity = false;
while (!nodesToAdd.empty()) {
const auto it = nodesToAdd.begin();
rankedNodesToAdd.push_back(*it);
nodesToAdd.erase(it);
} }
} }
while (!nodesToAdd.empty());
for (auto node_ptr : rankedNodesToAdd) {
add(node_ptr, includeLearnableParam);
}
return orderUnicity;
} }
void Aidge::GraphView::add(std::pair<NodePtr, std::set<NodePtr>> nodes, bool includeLearnableParam) { bool Aidge::GraphView::add(std::pair<NodePtr, std::set<NodePtr>> nodes, bool includeLearnableParam) {
if (nodes.first != nullptr) { if (nodes.first != nullptr) {
mRootNode = nodes.first;
add(nodes.first, includeLearnableParam); add(nodes.first, includeLearnableParam);
} }
add(nodes.second, includeLearnableParam); return add(nodes.second, includeLearnableParam);
} }
void Aidge::GraphView::add(std::shared_ptr<GraphView> graph) { bool Aidge::GraphView::add(std::shared_ptr<GraphView> graph) {
add(graph->getNodes(), false); if (mRootNode == nullptr) {
mRootNode = graph->getRootNode();
}
return add(graph->getNodes(), false);
} }
void Aidge::GraphView::addChild(std::shared_ptr<Node> toOtherNode, void Aidge::GraphView::addChild(std::shared_ptr<Node> toOtherNode,
......
...@@ -88,20 +88,32 @@ std::vector<std::pair<std::string, IOIndex_t>> nodePtrToName(const std::vector<s ...@@ -88,20 +88,32 @@ std::vector<std::pair<std::string, IOIndex_t>> nodePtrToName(const std::vector<s
TEST_CASE("genRandomDAG") { TEST_CASE("genRandomDAG") {
std::random_device rd; const size_t nbTests = 100;
const std::mt19937::result_type seed(rd()); size_t nbUnicity = 0;
auto g1 = std::make_shared<GraphView>(); for (int test = 0; test < nbTests; ++test) {
g1->add(genRandomDAG(seed, 10, 0.5)); std::random_device rd;
auto g2 = std::make_shared<GraphView>(); const std::mt19937::result_type seed(rd());
g2->add(genRandomDAG(seed, 10, 0.5));
g1->save("./genRandomDAG1"); const auto g1 = std::make_shared<GraphView>("g1");
g2->save("./genRandomDAG2"); const bool unicity1 = g1->add(genRandomDAG(seed, 10, 0.5));
const auto g2 = std::make_shared<GraphView>("g2");
const bool unicity2 = g2->add(genRandomDAG(seed, 10, 0.5));
REQUIRE(nodePtrToName(g1->getNodes()) == nodePtrToName(g2->getNodes())); g1->save("./genRandomDAG1");
REQUIRE(nodePtrToName(g1->getOrderedInputs()) == nodePtrToName(g2->getOrderedInputs())); g2->save("./genRandomDAG2");
REQUIRE(nodePtrToName(g1->getOrderedOutputs()) == nodePtrToName(g2->getOrderedOutputs()));
REQUIRE(unicity1 == unicity2);
if (unicity1) {
REQUIRE(nodePtrToName(g1->getNodes()) == nodePtrToName(g2->getNodes()));
REQUIRE(nodePtrToName(g1->getOrderedInputs()) == nodePtrToName(g2->getOrderedInputs()));
REQUIRE(nodePtrToName(g1->getOrderedOutputs()) == nodePtrToName(g2->getOrderedOutputs()));
++nbUnicity;
}
}
printf("nbUnicity = %zu/%zu\n", nbUnicity, nbTests);
} }
TEST_CASE("[core/graph] GraphView(Constructor)") { TEST_CASE("[core/graph] GraphView(Constructor)") {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment