Skip to content
Snippets Groups Projects
Commit 499196e8 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Fixed algo + added unit tests

parent f7b58554
No related branches found
No related tags found
1 merge request!8GraphView cloning proposal + labelGraph proof of concept
Pipeline #31302 passed
......@@ -33,8 +33,7 @@ class GraphView;
* @brief Object carrying the topological information of the computational graph.
*/
class Node : public std::enable_shared_from_this<Node> {
//private:
public: // TODO: workaround to make GraphView:clone() work, friend doesn't work because of forward declaration
private:
struct weakCompare {
bool operator()(const std::weak_ptr<Aidge::GraphView>& a, const std::weak_ptr<Aidge::GraphView>& b) const {
// Compare the content of the weak_ptrs
......@@ -400,9 +399,6 @@ public:
return node->clone();
}
// TODO: does not work, friend requires full definition, but there is a circular dependency between Node and GraphView
//friend std::shared_ptr<GraphView> GraphView::clone(NodePtr(*cloneNode)(NodePtr)) const;
private:
///////////////////////////////////////////////////////
// OPERATORS
......
......@@ -15,6 +15,7 @@
#include <assert.h>
#include <map>
#include <vector>
#include <string>
namespace Aidge {
......@@ -30,11 +31,6 @@ private:
struct is_vector<std::vector<T, Alloc>> : std::true_type {};
public:
// not copyable, not movable
CParameter(CParameter const &) = delete;
CParameter(CParameter &&) = delete;
CParameter &operator=(CParameter const &) = delete;
CParameter &operator=(CParameter &&) = delete;
CParameter() : m_Params({}){};
~CParameter() = default;
......
......@@ -676,69 +676,52 @@ void Aidge::GraphView::removeOutputNode(const std::string nodeName) {
}
std::shared_ptr<Aidge::GraphView> Aidge::GraphView::clone(NodePtr(*cloneNode)(NodePtr)) const {
std::shared_ptr<GraphView> newGraph;
std::map<NodePtr, NodePtr> oldToNewNodes;
std::shared_ptr<GraphView> newGraph = std::make_shared<GraphView>(mName);
// Map for old node -> new node correspondance
std::map<NodePtr, NodePtr> oldToNewNodes;
for (const std::shared_ptr<Node> &node_ptr : mNodes) {
oldToNewNodes[node_ptr] = cloneNode(node_ptr);
}
// For each node, convert old node -> new node references
// For each node, convert old node -> new node connections
for (auto &oldToNewNode : oldToNewNodes) {
// The only GraphView for the new graph is the one we return at the end
oldToNewNode.second->addView(newGraph);
if (oldToNewNode.second == nullptr)
continue; // deleted node
if (!(oldToNewNode.second->name()).empty())
newGraph->mNodeRegistry.insert(std::make_pair(oldToNewNode.second->name(), oldToNewNode.second));
// Add new node to new GraphView
newGraph->add(oldToNewNode.second, false);
// Connect parent nodes. Nodes that were removed with cloneNode() are set to nullptr
size_t parentId = 0;
for (auto& parent : oldToNewNode.first->mParents) {
oldToNewNode.second->mParents.push_back(oldToNewNodes[parent]);
if (oldToNewNodes[parent])
oldToNewNode.second->mIdOutParents.push_back(oldToNewNode.first->mIdOutParents[parentId]);
else
oldToNewNode.second->mIdOutParents.push_back(gk_IODefaultIndex);
++parentId;
}
// Connect child nodes.
size_t childId = 0;
for (const auto &childrenOfOneOutput : oldToNewNode.first->mChildren) {
oldToNewNode.second->mChildren.push_back(std::vector<std::weak_ptr<Node>>());
oldToNewNode.second->mIdInChildren.push_back(std::vector<IOIndex_t>());
size_t j = 0;
for (const auto &oneChild : childrenOfOneOutput) {
if (oldToNewNodes[oneChild.lock()]) {
oldToNewNode.second->mChildren.back().push_back(std::weak_ptr<Node>(oldToNewNodes[oneChild.lock()]));
oldToNewNode.second->mIdInChildren.back().push_back(oldToNewNode.first->mIdInChildren[childId][j]);
}
++j;
for (auto parent : oldToNewNode.first->inputs()) {
while (oldToNewNodes[parent.first] == nullptr) {
// Find next valid parent in line, going backward in the graph
assert(parent.first->nbDataInputs() <= 1 && "deleted nodes in GraphView::clone() cannot have multiple data inputs");
const auto& parents = parent.first->inputs();
if (!parents.empty() && parents[0].first != nullptr // a valid parent exists
&& oldToNewNodes.find(parents[0].first) != oldToNewNodes.end()) // parent is in the GraphView
{
parent = parents[0];
}
else {
break;
}
}
++childId;
}
}
// Copy remaining GraphView informations
newGraph->setName(mName);
if (oldToNewNodes[parent.first]) {
oldToNewNodes[parent.first]->addChild(oldToNewNode.second, parent.second, parentId);
}
for (const auto& inputNode : mInputNodes) {
if (oldToNewNodes[inputNode]) {
newGraph->mInputNodes.insert(oldToNewNodes[inputNode]);
++parentId;
}
}
for (const auto& outputNode : mOutputNodes) {
if (oldToNewNodes[outputNode]) {
newGraph->mOutputNodes.insert(oldToNewNodes[outputNode]);
}
}
// Update OutputNodes/inputNodes
newGraph->updateInputNodes();
newGraph->updateOutputNodes();
return newGraph;
}
......@@ -330,4 +330,132 @@ TEST_CASE("[core/graph] GraphView(replaceWith)") {
REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({r1, r4}));
REQUIRE((r1->output(0))[0].first == r4);
}
}
\ No newline at end of file
}
TEST_CASE("[GraphView] clone") {
auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider");
auto conv1 = Conv(3, 32, {3, 3}, "conv1");
auto conv2 = Conv(32, 64, {3, 3}, "conv2");
auto conv3 = Conv(64, 10, {1, 1}, "conv3");
auto g1 = std::make_shared<GraphView>("TestGraph");
dataProvider->addChild(conv1, 0);
g1->add(conv1);
g1->addChild(conv2, conv1, 0);
g1->addChild(conv3, conv2, 0);
g1->save("clone_g1");
SECTION("Check input-output connections") {
REQUIRE(dataProvider->getOperator()->getOutput(0) == conv1->getOperator()->getInput(0));
REQUIRE(conv1->getOperator()->getInput(1) == g1->getNode("conv1_w")->getOperator()->getOutput(0));
REQUIRE(conv1->getOperator()->getInput(2) == g1->getNode("conv1_b")->getOperator()->getOutput(0));
REQUIRE(conv1->getOperator()->getOutput(0) == conv2->getOperator()->getInput(0));
REQUIRE(conv2->getOperator()->getInput(1) == g1->getNode("conv2_w")->getOperator()->getOutput(0));
REQUIRE(conv2->getOperator()->getInput(2) == g1->getNode("conv2_b")->getOperator()->getOutput(0));
REQUIRE(conv2->getOperator()->getOutput(0) == conv3->getOperator()->getInput(0));
REQUIRE(conv3->getOperator()->getInput(1) == g1->getNode("conv3_w")->getOperator()->getOutput(0));
REQUIRE(conv3->getOperator()->getInput(2) == g1->getNode("conv3_b")->getOperator()->getOutput(0));
}
auto g2 = g1->clone();
g2->save("clone_g2");
SECTION("Check node cloning") {
REQUIRE(g1->getNode("conv1") != g2->getNode("conv1"));
REQUIRE(g1->getNode("conv1_w") != g2->getNode("conv1_w"));
REQUIRE(g1->getNode("conv1_b") != g2->getNode("conv1_b"));
REQUIRE(g1->getNode("conv2") != g2->getNode("conv2"));
REQUIRE(g1->getNode("conv2_w") != g2->getNode("conv2_w"));
REQUIRE(g1->getNode("conv2_b") != g2->getNode("conv2_b"));
REQUIRE(g1->getNode("conv3") != g2->getNode("conv3"));
REQUIRE(g1->getNode("conv3_w") != g2->getNode("conv3_w"));
REQUIRE(g1->getNode("conv3_b") != g2->getNode("conv3_b"));
}
SECTION("Check operator cloning") {
REQUIRE(g1->getNode("conv1")->getOperator() != g2->getNode("conv1")->getOperator());
REQUIRE(g1->getNode("conv1_w")->getOperator() != g2->getNode("conv1_w")->getOperator());
REQUIRE(g1->getNode("conv1_b")->getOperator() != g2->getNode("conv1_b")->getOperator());
REQUIRE(g1->getNode("conv2")->getOperator() != g2->getNode("conv2")->getOperator());
REQUIRE(g1->getNode("conv2_w")->getOperator() != g2->getNode("conv2_w")->getOperator());
REQUIRE(g1->getNode("conv2_b")->getOperator() != g2->getNode("conv2_b")->getOperator());
REQUIRE(g1->getNode("conv3")->getOperator() != g2->getNode("conv3")->getOperator());
REQUIRE(g1->getNode("conv3_w")->getOperator() != g2->getNode("conv3_w")->getOperator());
REQUIRE(g1->getNode("conv3_b")->getOperator() != g2->getNode("conv3_b")->getOperator());
}
SECTION("Check input-output connections") {
REQUIRE(dataProvider->getOperator()->getOutput(0) == g2->getNode("conv1")->getOperator()->getInput(0));
REQUIRE(g2->getNode("conv1")->getOperator()->getInput(1) == g2->getNode("conv1_w")->getOperator()->getOutput(0));
REQUIRE(g2->getNode("conv1")->getOperator()->getInput(2) == g2->getNode("conv1_b")->getOperator()->getOutput(0));
REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0));
REQUIRE(g2->getNode("conv2")->getOperator()->getInput(1) == g2->getNode("conv2_w")->getOperator()->getOutput(0));
REQUIRE(g2->getNode("conv2")->getOperator()->getInput(2) == g2->getNode("conv2_b")->getOperator()->getOutput(0));
REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0));
REQUIRE(g2->getNode("conv3")->getOperator()->getInput(1) == g2->getNode("conv3_w")->getOperator()->getOutput(0));
REQUIRE(g2->getNode("conv3")->getOperator()->getInput(2) == g2->getNode("conv3_b")->getOperator()->getOutput(0));
}
}
TEST_CASE("[GraphView] cloneSharedProducers") {
auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider");
auto conv1 = Conv(3, 32, {3, 3}, "conv1");
auto conv2 = Conv(32, 64, {3, 3}, "conv2");
auto conv3 = Conv(64, 10, {1, 1}, "conv3");
auto g1 = std::make_shared<GraphView>("TestGraph");
dataProvider->addChild(conv1, 0);
g1->add(conv1);
g1->addChild(conv2, conv1, 0);
g1->addChild(conv3, conv2, 0);
g1->save("cloneSharedProducers_g1");
SECTION("Check input-output connections") {
REQUIRE(dataProvider->getOperator()->getOutput(0) == conv1->getOperator()->getInput(0));
REQUIRE(conv1->getOperator()->getInput(1) == g1->getNode("conv1_w")->getOperator()->getOutput(0));
REQUIRE(conv1->getOperator()->getInput(2) == g1->getNode("conv1_b")->getOperator()->getOutput(0));
REQUIRE(conv1->getOperator()->getOutput(0) == conv2->getOperator()->getInput(0));
REQUIRE(conv2->getOperator()->getInput(1) == g1->getNode("conv2_w")->getOperator()->getOutput(0));
REQUIRE(conv2->getOperator()->getInput(2) == g1->getNode("conv2_b")->getOperator()->getOutput(0));
REQUIRE(conv2->getOperator()->getOutput(0) == conv3->getOperator()->getInput(0));
REQUIRE(conv3->getOperator()->getInput(1) == g1->getNode("conv3_w")->getOperator()->getOutput(0));
REQUIRE(conv3->getOperator()->getInput(2) == g1->getNode("conv3_b")->getOperator()->getOutput(0));
}
auto g2 = g1->cloneSharedProducers();
g2->save("cloneSharedProducers_g2");
SECTION("Check node cloning") {
REQUIRE(g1->getNode("conv1") != g2->getNode("conv1"));
REQUIRE(g1->getNode("conv1_w") != g2->getNode("conv1_w"));
REQUIRE(g1->getNode("conv1_b") != g2->getNode("conv1_b"));
REQUIRE(g1->getNode("conv2") != g2->getNode("conv2"));
REQUIRE(g1->getNode("conv2_w") != g2->getNode("conv2_w"));
REQUIRE(g1->getNode("conv2_b") != g2->getNode("conv2_b"));
REQUIRE(g1->getNode("conv3") != g2->getNode("conv3"));
REQUIRE(g1->getNode("conv3_w") != g2->getNode("conv3_w"));
REQUIRE(g1->getNode("conv3_b") != g2->getNode("conv3_b"));
}
SECTION("Check operator cloning") {
REQUIRE(g1->getNode("conv1")->getOperator() != g2->getNode("conv1")->getOperator());
REQUIRE(g1->getNode("conv1_w")->getOperator() == g2->getNode("conv1_w")->getOperator());
REQUIRE(g1->getNode("conv1_b")->getOperator() == g2->getNode("conv1_b")->getOperator());
REQUIRE(g1->getNode("conv2")->getOperator() != g2->getNode("conv2")->getOperator());
REQUIRE(g1->getNode("conv2_w")->getOperator() == g2->getNode("conv2_w")->getOperator());
REQUIRE(g1->getNode("conv2_b")->getOperator() == g2->getNode("conv2_b")->getOperator());
REQUIRE(g1->getNode("conv3")->getOperator() != g2->getNode("conv3")->getOperator());
REQUIRE(g1->getNode("conv3_w")->getOperator() == g2->getNode("conv3_w")->getOperator());
REQUIRE(g1->getNode("conv3_b")->getOperator() == g2->getNode("conv3_b")->getOperator());
}
SECTION("Check input-output connections") {
REQUIRE(dataProvider->getOperator()->getOutput(0) == g2->getNode("conv1")->getOperator()->getInput(0));
REQUIRE(g2->getNode("conv1")->getOperator()->getInput(1) == g2->getNode("conv1_w")->getOperator()->getOutput(0));
REQUIRE(g2->getNode("conv1")->getOperator()->getInput(2) == g2->getNode("conv1_b")->getOperator()->getOutput(0));
REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0));
REQUIRE(g2->getNode("conv2")->getOperator()->getInput(1) == g2->getNode("conv2_w")->getOperator()->getOutput(0));
REQUIRE(g2->getNode("conv2")->getOperator()->getInput(2) == g2->getNode("conv2_b")->getOperator()->getOutput(0));
REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0));
REQUIRE(g2->getNode("conv3")->getOperator()->getInput(1) == g2->getNode("conv3_w")->getOperator()->getOutput(0));
REQUIRE(g2->getNode("conv3")->getOperator()->getInput(2) == g2->getNode("conv3_b")->getOperator()->getOutput(0));
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment