Skip to content
Snippets Groups Projects
Commit ff19c397 authored by Maxence Naud's avatar Maxence Naud
Browse files

Completely remove replaceWith() member function from GraphView

parent d076d8df
No related branches found
No related tags found
No related merge requests found
...@@ -335,13 +335,6 @@ public: ...@@ -335,13 +335,6 @@ public:
IOIndex_t newParentInputTensorIdx, IOIndex_t newParentInputTensorIdx,
IOIndex_t newParentOutputTensorIdx); IOIndex_t newParentOutputTensorIdx);
/**
* @brief Replace the current GraphView with the set of given Nodes if possible
* @param newNodes Set of Nodes.
* @return true
* @return false
*/
bool replaceWith(std::set<NodePtr> newNodes);
/** /**
* @brief Replace a set of Nodes in every available GraphView with a new set of Nodes if possible. * @brief Replace a set of Nodes in every available GraphView with a new set of Nodes if possible.
......
...@@ -530,10 +530,6 @@ void Aidge::GraphView::insertParent(NodePtr childNode, ...@@ -530,10 +530,6 @@ void Aidge::GraphView::insertParent(NodePtr childNode,
} }
bool Aidge::GraphView::replaceWith(std::set<std::shared_ptr<Node>> newNodes) {
return GraphView::replace(mNodes, newNodes);
}
bool Aidge::GraphView::replace(const std::set<Aidge::NodePtr>& oldNodes, const std::set<Aidge::NodePtr>& newNodes) { bool Aidge::GraphView::replace(const std::set<Aidge::NodePtr>& oldNodes, const std::set<Aidge::NodePtr>& newNodes) {
// TODO: handle case where an oldNodes parameter does not come from a Producer but another Node (not included in oldNodes) // TODO: handle case where an oldNodes parameter does not come from a Producer but another Node (not included in oldNodes)
......
...@@ -72,10 +72,7 @@ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){ ...@@ -72,10 +72,7 @@ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){
// Step 3 : Update all graphviews that contains at least one node to replace // Step 3 : Update all graphviews that contains at least one node to replace
// Case 1 : If all nodes are in a graph view : delete old nodes & branch input & output // Case 1 : If all nodes are in a graph view : delete old nodes & branch input & output
// Case 2 : If not all nodes are in a graph view : only delete the nodes from the graphview // Case 2 : If not all nodes are in a graph view : only delete the nodes from the graphview
// Maybe create a central mechanism to update automatically all graph views rather than each node have graphview presence memory ? // Maybe create a central mechanism to update automatically all graph views rather than each node have graphview presence memory?
// auto nodeToReplace = std::make_shared<GraphView>();
// nodeToReplace->add(nodes, false);
// nodeToReplace->replaceWith({fc});
auto newNodes = std::set<std::shared_ptr<Node>>({fc, weight, fc->getParent(2)}); auto newNodes = std::set<std::shared_ptr<Node>>({fc, weight, fc->getParent(2)});
GraphView::replace({matmul, add, add->getParent(1), matmul->getParent(1)}, newNodes); GraphView::replace({matmul, add, add->getParent(1), matmul->getParent(1)}, newNodes);
......
...@@ -278,62 +278,6 @@ TEST_CASE("Graph Forward dims", "[GraphView]") { ...@@ -278,62 +278,6 @@ TEST_CASE("Graph Forward dims", "[GraphView]") {
} }
} }
TEST_CASE("[core/graph] GraphView(replaceWith)", "[replaceWith]") {
SECTION("replace small pattern") {
// create original graph
std::shared_ptr<GraphView> g = std::make_shared<GraphView>("TestGraph");
auto otherInput = GenericOperator("Producer", 0, 0, 1, "other_input");
auto matmulWeight = GenericOperator("Producer", 0, 0, 1, "matmul_w");
auto addBias = GenericOperator("Producer", 0, 0, 1, "add_b");
auto other1 = GenericOperator("Other", 1, 1, 1, "other1");
auto other2 = GenericOperator("Other", 1, 1, 1, "other2");
auto matmul = GenericOperator("MatMul", 1, 2, 1, "matmul");
auto add = GenericOperator("Add", 1, 2, 1, "add");
otherInput->addChild(other1);
other1->addChild(matmul);
matmul->addChild(add);
add->addChild(other2);
matmulWeight->addChild(matmul, 0, 1);
addBias->addChild(add, 0, 1);
g->add({other1, matmul, add, other2});
REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({matmulWeight, addBias, other1, other2, matmul, add}));
// create graph to replace
std::shared_ptr<GraphView> nodeToReplace = std::make_shared<GraphView>("NodesToReplace");
nodeToReplace->add({matmul, add}, true);
// create replacing graph
std::shared_ptr<Node> newNode = GenericOperator("FC", 1, 3, 1, "fc");
// other1->addChild(newNode);
auto newMatmulWeight = matmulWeight->cloneSharedOperators();
newMatmulWeight->addChild(newNode, 0, 1);
auto newAddBias = addBias->cloneSharedOperators();
newAddBias->addChild(newNode, 0, 2);
// replace
nodeToReplace->replaceWith({newNode, newMatmulWeight, newAddBias});
REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({newMatmulWeight, newAddBias, other1, other2, newNode}));
}
SECTION("replace with nothing") {
std::shared_ptr<GraphView> g = std::make_shared<GraphView>("TestGraph");
auto r1 = GenericOperator("relu", 0, 0, 1);
auto r2 = GenericOperator("relu", 1, 1, 1);
auto r3 = GenericOperator("relu", 1, 1, 1);
auto r4 = GenericOperator("relu", 1, 1, 0);
r1->addChild(r2);
r2->addChild(r3);
r3->addChild(r4);
g->add({r1, r2, r3, r4});
auto nodesToReplace = std::set<std::shared_ptr<Node>>({r2, r3});
auto graphToReplace = std::make_shared<GraphView>();
graphToReplace->add(nodesToReplace);
graphToReplace->replaceWith({});
REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({r1, r4}));
REQUIRE((r1->output(0))[0].first == r4);
}
}
TEST_CASE("[core/graph] GraphView(replace)", "[GraphView][replace]") { TEST_CASE("[core/graph] GraphView(replace)", "[GraphView][replace]") {
SECTION("replace small pattern") { SECTION("replace small pattern") {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment