Skip to content
Snippets Groups Projects
Commit 7299fadf authored by Maxence Naud's avatar Maxence Naud
Browse files

[Upd] recipies use replace() instead of replaceWith()

parent 422ac1ce
No related branches found
No related tags found
1 merge request!45[Upd] replace() instead of replaceWith() in GraphView
Pipeline #33924 failed
...@@ -116,15 +116,14 @@ void Aidge::fuseBatchNorm(std::set<std::shared_ptr<Node>> nodes){ ...@@ -116,15 +116,14 @@ void Aidge::fuseBatchNorm(std::set<std::shared_ptr<Node>> nodes){
bias->set<float>(output, biasValue); bias->set<float>(output, biasValue);
} }
auto g = std::make_shared<GraphView>();
g->add(std::set<std::shared_ptr<Node>>({ GraphView::replace(std::set<std::shared_ptr<Node>>({
batchnorm, batchnorm,
batchnorm->input(1).first, batchnorm->input(1).first,
batchnorm->input(2).first, batchnorm->input(2).first,
batchnorm->input(3).first, batchnorm->input(3).first,
batchnorm->input(4).first batchnorm->input(4).first
})); }), {});
g->replaceWith({});
} }
......
...@@ -47,8 +47,11 @@ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){ ...@@ -47,8 +47,11 @@ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){
// Step 1 : Create FC // Step 1 : Create FC
// Fetch the output dimension throught the bias size // Fetch the output dimension throught the bias size
auto producer_add_bias = add->input(1); std::shared_ptr<Node> bias = (add->getParent(1)) ? add->getParent(1)->cloneSharedOperators() : nullptr;
Tensor& bias_tensor = (producer_add_bias.first)->getOperator()->output(0);
Tensor& bias_tensor = bias->getOperator()->output(0);
std::shared_ptr<Node> weight = (matmul->getParent(1)) ? matmul->getParent(1)->cloneSharedOperators() : nullptr;
// Instanciate FC // Instanciate FC
//std::shared_ptr<Node> fc = FC(dim[0], false, "Fc"); //std::shared_ptr<Node> fc = FC(dim[0], false, "Fc");
...@@ -56,25 +59,22 @@ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){ ...@@ -56,25 +59,22 @@ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){
// Step 2 : Branch existing producers & create the others // Step 2 : Branch existing producers & create the others
// link weights & bias // link weights & bias
if (matmul->getParent(1)==nullptr) { if (weight) {
matmul->getParent(0)->addChild(fc, 0, 1); weight->addChild(fc, 0, 1);
printf("MatMul out[1] == nullptr !\n"); }
} else { if (bias) {
printf("MatMul out[1] != nullptr !\n"); bias->addChild(fc, 0, 2);
if (matmul->getParent(0)!=nullptr)
matmul->getParent(0)->addChild(fc, 0, 0);
matmul->input(1).first->addChild(fc, 0, 1);
} }
(producer_add_bias.first)->addChild(fc,0,2);
// Step 3 : Update all graphviews that contains at least one node to replace // Step 3 : Update all graphviews that contains at least one node to replace
// Case 1 : If all nodes are in a graph view : delete old nodes & branch input & output // Case 1 : If all nodes are in a graph view : delete old nodes & branch input & output
// Case 2 : If not all nodes are in a graph view : only delete the nodes from the graphview // Case 2 : If not all nodes are in a graph view : only delete the nodes from the graphview
// Maybe create a central mechanism to update automatically all graph views rather than each node have graphview presence memory ? // Maybe create a central mechanism to update automatically all graph views rather than each node have graphview presence memory ?
auto nodeToReplace = std::make_shared<GraphView>(); // auto nodeToReplace = std::make_shared<GraphView>();
nodeToReplace->add(nodes, false); // nodeToReplace->add(nodes, false);
nodeToReplace->replaceWith({fc}); // nodeToReplace->replaceWith({fc});
GraphView::replace({matmul, add, add->getParent(1), matmul->getParent(1)}, {fc, weight, bias});
} }
......
...@@ -30,10 +30,8 @@ namespace Aidge { ...@@ -30,10 +30,8 @@ namespace Aidge {
flatten = element; flatten = element;
} }
} }
auto g = std::make_shared<GraphView>();
// TODO : avoid using replace_with and use a remove method instead GraphView::replace({flatten}, {});
g->add(std::set<std::shared_ptr<Node>>({flatten}));
g->replaceWith({});
} }
void removeFlatten(std::shared_ptr<GraphView> graphView){ void removeFlatten(std::shared_ptr<GraphView> graphView){
......
...@@ -277,7 +277,7 @@ TEST_CASE("Graph Forward dims", "[GraphView]") { ...@@ -277,7 +277,7 @@ TEST_CASE("Graph Forward dims", "[GraphView]") {
} }
} }
TEST_CASE("[core/graph] GraphView(replaceWith)") { TEST_CASE("[core/graph] GraphView(replaceWith)", "[replaceWith]") {
SECTION("replace small pattern") { SECTION("replace small pattern") {
// create original graph // create original graph
std::shared_ptr<GraphView> g = std::make_shared<GraphView>("TestGraph"); std::shared_ptr<GraphView> g = std::make_shared<GraphView>("TestGraph");
...@@ -298,19 +298,21 @@ TEST_CASE("[core/graph] GraphView(replaceWith)") { ...@@ -298,19 +298,21 @@ TEST_CASE("[core/graph] GraphView(replaceWith)") {
REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({matmulWeight, addBias, other1, other2, matmul, add})); REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({matmulWeight, addBias, other1, other2, matmul, add}));
// create graph to replace // create graph to replace
std::shared_ptr<GraphView> nodeToReplace = std::make_shared<GraphView>(); std::shared_ptr<GraphView> nodeToReplace = std::make_shared<GraphView>("NodesToReplace");
nodeToReplace->add({matmul, add}, false); nodeToReplace->add({matmul, add}, true);
// create replacing graph // create replacing graph
std::shared_ptr<Node> newNode = GenericOperator("FC", 1, 3, 1, "fc"); std::shared_ptr<Node> newNode = GenericOperator("FC", 1, 3, 1, "fc");
other1->addChild(newNode); // other1->addChild(newNode);
matmulWeight->addChild(newNode, 0, 1); auto newMatmulWeight = matmulWeight->cloneSharedOperators();
addBias->addChild(newNode, 0, 2); newMatmulWeight->addChild(newNode, 0, 1);
auto newAddBias = addBias->cloneSharedOperators();
newAddBias->addChild(newNode, 0, 2);
// replace // replace
nodeToReplace->replaceWith({newNode}); nodeToReplace->replaceWith({newNode, newMatmulWeight, newAddBias});
REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({matmulWeight, addBias, other1, other2, newNode})); REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({newMatmulWeight, newAddBias, other1, other2, newNode}));
} }
SECTION("replace with nothing") { SECTION("replace with nothing") {
std::shared_ptr<GraphView> g = std::make_shared<GraphView>("TestGraph"); std::shared_ptr<GraphView> g = std::make_shared<GraphView>("TestGraph");
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment