Skip to content
Snippets Groups Projects
Commit 499196e8 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Fixed algo + added unit tests

parent f7b58554
No related branches found
No related tags found
No related merge requests found
...@@ -33,8 +33,7 @@ class GraphView; ...@@ -33,8 +33,7 @@ class GraphView;
* @brief Object carrying the topological information of the computational graph. * @brief Object carrying the topological information of the computational graph.
*/ */
class Node : public std::enable_shared_from_this<Node> { class Node : public std::enable_shared_from_this<Node> {
//private: private:
public: // TODO: workaround to make GraphView:clone() work, friend doesn't work because of forward declaration
struct weakCompare { struct weakCompare {
bool operator()(const std::weak_ptr<Aidge::GraphView>& a, const std::weak_ptr<Aidge::GraphView>& b) const { bool operator()(const std::weak_ptr<Aidge::GraphView>& a, const std::weak_ptr<Aidge::GraphView>& b) const {
// Compare the content of the weak_ptrs // Compare the content of the weak_ptrs
...@@ -400,9 +399,6 @@ public: ...@@ -400,9 +399,6 @@ public:
return node->clone(); return node->clone();
} }
// TODO: does not work, friend requires full definition, but there is a circular dependency between Node and GraphView
//friend std::shared_ptr<GraphView> GraphView::clone(NodePtr(*cloneNode)(NodePtr)) const;
private: private:
/////////////////////////////////////////////////////// ///////////////////////////////////////////////////////
// OPERATORS // OPERATORS
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include <assert.h> #include <assert.h>
#include <map> #include <map>
#include <vector> #include <vector>
#include <string>
namespace Aidge { namespace Aidge {
...@@ -30,11 +31,6 @@ private: ...@@ -30,11 +31,6 @@ private:
struct is_vector<std::vector<T, Alloc>> : std::true_type {}; struct is_vector<std::vector<T, Alloc>> : std::true_type {};
public: public:
// not copyable, not movable
CParameter(CParameter const &) = delete;
CParameter(CParameter &&) = delete;
CParameter &operator=(CParameter const &) = delete;
CParameter &operator=(CParameter &&) = delete;
CParameter() : m_Params({}){}; CParameter() : m_Params({}){};
~CParameter() = default; ~CParameter() = default;
......
...@@ -676,69 +676,52 @@ void Aidge::GraphView::removeOutputNode(const std::string nodeName) { ...@@ -676,69 +676,52 @@ void Aidge::GraphView::removeOutputNode(const std::string nodeName) {
} }
std::shared_ptr<Aidge::GraphView> Aidge::GraphView::clone(NodePtr(*cloneNode)(NodePtr)) const { std::shared_ptr<Aidge::GraphView> Aidge::GraphView::clone(NodePtr(*cloneNode)(NodePtr)) const {
std::shared_ptr<GraphView> newGraph; std::shared_ptr<GraphView> newGraph = std::make_shared<GraphView>(mName);
std::map<NodePtr, NodePtr> oldToNewNodes;
// Map for old node -> new node correspondance // Map for old node -> new node correspondance
std::map<NodePtr, NodePtr> oldToNewNodes;
for (const std::shared_ptr<Node> &node_ptr : mNodes) { for (const std::shared_ptr<Node> &node_ptr : mNodes) {
oldToNewNodes[node_ptr] = cloneNode(node_ptr); oldToNewNodes[node_ptr] = cloneNode(node_ptr);
} }
// For each node, convert old node -> new node references // For each node, convert old node -> new node connections
for (auto &oldToNewNode : oldToNewNodes) { for (auto &oldToNewNode : oldToNewNodes) {
// The only GraphView for the new graph is the one we return at the end if (oldToNewNode.second == nullptr)
oldToNewNode.second->addView(newGraph); continue; // deleted node
if (!(oldToNewNode.second->name()).empty()) // Add new node to new GraphView
newGraph->mNodeRegistry.insert(std::make_pair(oldToNewNode.second->name(), oldToNewNode.second)); newGraph->add(oldToNewNode.second, false);
// Connect parent nodes. Nodes that were removed with cloneNode() are set to nullptr // Connect parent nodes. Nodes that were removed with cloneNode() are set to nullptr
size_t parentId = 0; size_t parentId = 0;
for (auto& parent : oldToNewNode.first->mParents) { for (auto parent : oldToNewNode.first->inputs()) {
oldToNewNode.second->mParents.push_back(oldToNewNodes[parent]); while (oldToNewNodes[parent.first] == nullptr) {
// Find next valid parent in line, going backward in the graph
if (oldToNewNodes[parent]) assert(parent.first->nbDataInputs() <= 1 && "deleted nodes in GraphView::clone() cannot have multiple data inputs");
oldToNewNode.second->mIdOutParents.push_back(oldToNewNode.first->mIdOutParents[parentId]); const auto& parents = parent.first->inputs();
else
oldToNewNode.second->mIdOutParents.push_back(gk_IODefaultIndex); if (!parents.empty() && parents[0].first != nullptr // a valid parent exists
&& oldToNewNodes.find(parents[0].first) != oldToNewNodes.end()) // parent is in the GraphView
++parentId; {
} parent = parents[0];
// Connect child nodes.
size_t childId = 0;
for (const auto &childrenOfOneOutput : oldToNewNode.first->mChildren) {
oldToNewNode.second->mChildren.push_back(std::vector<std::weak_ptr<Node>>());
oldToNewNode.second->mIdInChildren.push_back(std::vector<IOIndex_t>());
size_t j = 0;
for (const auto &oneChild : childrenOfOneOutput) {
if (oldToNewNodes[oneChild.lock()]) {
oldToNewNode.second->mChildren.back().push_back(std::weak_ptr<Node>(oldToNewNodes[oneChild.lock()]));
oldToNewNode.second->mIdInChildren.back().push_back(oldToNewNode.first->mIdInChildren[childId][j]);
}
++j;
} }
else {
break;
}
}
++childId; if (oldToNewNodes[parent.first]) {
} oldToNewNodes[parent.first]->addChild(oldToNewNode.second, parent.second, parentId);
} }
// Copy remaining GraphView informations
newGraph->setName(mName);
for (const auto& inputNode : mInputNodes) { ++parentId;
if (oldToNewNodes[inputNode]) {
newGraph->mInputNodes.insert(oldToNewNodes[inputNode]);
} }
} }
for (const auto& outputNode : mOutputNodes) { // Update OutputNodes/inputNodes
if (oldToNewNodes[outputNode]) { newGraph->updateInputNodes();
newGraph->mOutputNodes.insert(oldToNewNodes[outputNode]); newGraph->updateOutputNodes();
}
}
return newGraph; return newGraph;
} }
...@@ -330,4 +330,132 @@ TEST_CASE("[core/graph] GraphView(replaceWith)") { ...@@ -330,4 +330,132 @@ TEST_CASE("[core/graph] GraphView(replaceWith)") {
REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({r1, r4})); REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({r1, r4}));
REQUIRE((r1->output(0))[0].first == r4); REQUIRE((r1->output(0))[0].first == r4);
} }
} }
\ No newline at end of file
TEST_CASE("[GraphView] clone") {
auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider");
auto conv1 = Conv(3, 32, {3, 3}, "conv1");
auto conv2 = Conv(32, 64, {3, 3}, "conv2");
auto conv3 = Conv(64, 10, {1, 1}, "conv3");
auto g1 = std::make_shared<GraphView>("TestGraph");
dataProvider->addChild(conv1, 0);
g1->add(conv1);
g1->addChild(conv2, conv1, 0);
g1->addChild(conv3, conv2, 0);
g1->save("clone_g1");
SECTION("Check input-output connections") {
REQUIRE(dataProvider->getOperator()->getOutput(0) == conv1->getOperator()->getInput(0));
REQUIRE(conv1->getOperator()->getInput(1) == g1->getNode("conv1_w")->getOperator()->getOutput(0));
REQUIRE(conv1->getOperator()->getInput(2) == g1->getNode("conv1_b")->getOperator()->getOutput(0));
REQUIRE(conv1->getOperator()->getOutput(0) == conv2->getOperator()->getInput(0));
REQUIRE(conv2->getOperator()->getInput(1) == g1->getNode("conv2_w")->getOperator()->getOutput(0));
REQUIRE(conv2->getOperator()->getInput(2) == g1->getNode("conv2_b")->getOperator()->getOutput(0));
REQUIRE(conv2->getOperator()->getOutput(0) == conv3->getOperator()->getInput(0));
REQUIRE(conv3->getOperator()->getInput(1) == g1->getNode("conv3_w")->getOperator()->getOutput(0));
REQUIRE(conv3->getOperator()->getInput(2) == g1->getNode("conv3_b")->getOperator()->getOutput(0));
}
auto g2 = g1->clone();
g2->save("clone_g2");
SECTION("Check node cloning") {
REQUIRE(g1->getNode("conv1") != g2->getNode("conv1"));
REQUIRE(g1->getNode("conv1_w") != g2->getNode("conv1_w"));
REQUIRE(g1->getNode("conv1_b") != g2->getNode("conv1_b"));
REQUIRE(g1->getNode("conv2") != g2->getNode("conv2"));
REQUIRE(g1->getNode("conv2_w") != g2->getNode("conv2_w"));
REQUIRE(g1->getNode("conv2_b") != g2->getNode("conv2_b"));
REQUIRE(g1->getNode("conv3") != g2->getNode("conv3"));
REQUIRE(g1->getNode("conv3_w") != g2->getNode("conv3_w"));
REQUIRE(g1->getNode("conv3_b") != g2->getNode("conv3_b"));
}
SECTION("Check operator cloning") {
REQUIRE(g1->getNode("conv1")->getOperator() != g2->getNode("conv1")->getOperator());
REQUIRE(g1->getNode("conv1_w")->getOperator() != g2->getNode("conv1_w")->getOperator());
REQUIRE(g1->getNode("conv1_b")->getOperator() != g2->getNode("conv1_b")->getOperator());
REQUIRE(g1->getNode("conv2")->getOperator() != g2->getNode("conv2")->getOperator());
REQUIRE(g1->getNode("conv2_w")->getOperator() != g2->getNode("conv2_w")->getOperator());
REQUIRE(g1->getNode("conv2_b")->getOperator() != g2->getNode("conv2_b")->getOperator());
REQUIRE(g1->getNode("conv3")->getOperator() != g2->getNode("conv3")->getOperator());
REQUIRE(g1->getNode("conv3_w")->getOperator() != g2->getNode("conv3_w")->getOperator());
REQUIRE(g1->getNode("conv3_b")->getOperator() != g2->getNode("conv3_b")->getOperator());
}
SECTION("Check input-output connections") {
REQUIRE(dataProvider->getOperator()->getOutput(0) == g2->getNode("conv1")->getOperator()->getInput(0));
REQUIRE(g2->getNode("conv1")->getOperator()->getInput(1) == g2->getNode("conv1_w")->getOperator()->getOutput(0));
REQUIRE(g2->getNode("conv1")->getOperator()->getInput(2) == g2->getNode("conv1_b")->getOperator()->getOutput(0));
REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0));
REQUIRE(g2->getNode("conv2")->getOperator()->getInput(1) == g2->getNode("conv2_w")->getOperator()->getOutput(0));
REQUIRE(g2->getNode("conv2")->getOperator()->getInput(2) == g2->getNode("conv2_b")->getOperator()->getOutput(0));
REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0));
REQUIRE(g2->getNode("conv3")->getOperator()->getInput(1) == g2->getNode("conv3_w")->getOperator()->getOutput(0));
REQUIRE(g2->getNode("conv3")->getOperator()->getInput(2) == g2->getNode("conv3_b")->getOperator()->getOutput(0));
}
}
TEST_CASE("[GraphView] cloneSharedProducers") {
auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider");
auto conv1 = Conv(3, 32, {3, 3}, "conv1");
auto conv2 = Conv(32, 64, {3, 3}, "conv2");
auto conv3 = Conv(64, 10, {1, 1}, "conv3");
auto g1 = std::make_shared<GraphView>("TestGraph");
dataProvider->addChild(conv1, 0);
g1->add(conv1);
g1->addChild(conv2, conv1, 0);
g1->addChild(conv3, conv2, 0);
g1->save("cloneSharedProducers_g1");
SECTION("Check input-output connections") {
REQUIRE(dataProvider->getOperator()->getOutput(0) == conv1->getOperator()->getInput(0));
REQUIRE(conv1->getOperator()->getInput(1) == g1->getNode("conv1_w")->getOperator()->getOutput(0));
REQUIRE(conv1->getOperator()->getInput(2) == g1->getNode("conv1_b")->getOperator()->getOutput(0));
REQUIRE(conv1->getOperator()->getOutput(0) == conv2->getOperator()->getInput(0));
REQUIRE(conv2->getOperator()->getInput(1) == g1->getNode("conv2_w")->getOperator()->getOutput(0));
REQUIRE(conv2->getOperator()->getInput(2) == g1->getNode("conv2_b")->getOperator()->getOutput(0));
REQUIRE(conv2->getOperator()->getOutput(0) == conv3->getOperator()->getInput(0));
REQUIRE(conv3->getOperator()->getInput(1) == g1->getNode("conv3_w")->getOperator()->getOutput(0));
REQUIRE(conv3->getOperator()->getInput(2) == g1->getNode("conv3_b")->getOperator()->getOutput(0));
}
auto g2 = g1->cloneSharedProducers();
g2->save("cloneSharedProducers_g2");
SECTION("Check node cloning") {
REQUIRE(g1->getNode("conv1") != g2->getNode("conv1"));
REQUIRE(g1->getNode("conv1_w") != g2->getNode("conv1_w"));
REQUIRE(g1->getNode("conv1_b") != g2->getNode("conv1_b"));
REQUIRE(g1->getNode("conv2") != g2->getNode("conv2"));
REQUIRE(g1->getNode("conv2_w") != g2->getNode("conv2_w"));
REQUIRE(g1->getNode("conv2_b") != g2->getNode("conv2_b"));
REQUIRE(g1->getNode("conv3") != g2->getNode("conv3"));
REQUIRE(g1->getNode("conv3_w") != g2->getNode("conv3_w"));
REQUIRE(g1->getNode("conv3_b") != g2->getNode("conv3_b"));
}
SECTION("Check operator cloning") {
REQUIRE(g1->getNode("conv1")->getOperator() != g2->getNode("conv1")->getOperator());
REQUIRE(g1->getNode("conv1_w")->getOperator() == g2->getNode("conv1_w")->getOperator());
REQUIRE(g1->getNode("conv1_b")->getOperator() == g2->getNode("conv1_b")->getOperator());
REQUIRE(g1->getNode("conv2")->getOperator() != g2->getNode("conv2")->getOperator());
REQUIRE(g1->getNode("conv2_w")->getOperator() == g2->getNode("conv2_w")->getOperator());
REQUIRE(g1->getNode("conv2_b")->getOperator() == g2->getNode("conv2_b")->getOperator());
REQUIRE(g1->getNode("conv3")->getOperator() != g2->getNode("conv3")->getOperator());
REQUIRE(g1->getNode("conv3_w")->getOperator() == g2->getNode("conv3_w")->getOperator());
REQUIRE(g1->getNode("conv3_b")->getOperator() == g2->getNode("conv3_b")->getOperator());
}
SECTION("Check input-output connections") {
REQUIRE(dataProvider->getOperator()->getOutput(0) == g2->getNode("conv1")->getOperator()->getInput(0));
REQUIRE(g2->getNode("conv1")->getOperator()->getInput(1) == g2->getNode("conv1_w")->getOperator()->getOutput(0));
REQUIRE(g2->getNode("conv1")->getOperator()->getInput(2) == g2->getNode("conv1_b")->getOperator()->getOutput(0));
REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0));
REQUIRE(g2->getNode("conv2")->getOperator()->getInput(1) == g2->getNode("conv2_w")->getOperator()->getOutput(0));
REQUIRE(g2->getNode("conv2")->getOperator()->getInput(2) == g2->getNode("conv2_b")->getOperator()->getOutput(0));
REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0));
REQUIRE(g2->getNode("conv3")->getOperator()->getInput(1) == g2->getNode("conv3_w")->getOperator()->getOutput(0));
REQUIRE(g2->getNode("conv3")->getOperator()->getInput(2) == g2->getNode("conv3_b")->getOperator()->getOutput(0));
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment