Skip to content
Snippets Groups Projects
Commit 30781475 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Fixed issue with fuseMulAdd

parent 837bead8
No related branches found
No related tags found
No related merge requests found
...@@ -39,11 +39,11 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr< ...@@ -39,11 +39,11 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr<
std::shared_ptr<Node> bias = nullptr; std::shared_ptr<Node> bias = nullptr;
if (addNode->getParent(0) == matmulNode) { if (addNode->getParent(0) == matmulNode) {
AIDGE_ASSERT(matmulNode->getParent(1), "No bias detected to produce the fuseMulAdd recipe."); AIDGE_ASSERT(matmulNode->getParent(1), "No bias detected to produce the fuseMulAdd recipe.");
bias = addNode->getParent(1)->cloneSharedOperators(); bias = addNode->getParent(1);
} }
else if (addNode->getParent(1) == matmulNode) { else if (addNode->getParent(1) == matmulNode) {
AIDGE_ASSERT(matmulNode->getParent(0), "No bias detected to produce the fuseMulAdd recipe."); AIDGE_ASSERT(matmulNode->getParent(0), "No bias detected to produce the fuseMulAdd recipe.");
bias = addNode->getParent(0)->cloneSharedOperators(); bias = addNode->getParent(0);
} }
std::shared_ptr<Node> weight = nullptr; std::shared_ptr<Node> weight = nullptr;
...@@ -51,13 +51,13 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr< ...@@ -51,13 +51,13 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr<
|| (matmulNode->getParent(1) && matmulNode->getParent(1)->getOperator()->type() == Producer_Op::Type || (matmulNode->getParent(1) && matmulNode->getParent(1)->getOperator()->type() == Producer_Op::Type
&& matmulNode->getParent(0) && matmulNode->getParent(0)->getOperator()->type() != Producer_Op::Type)) && matmulNode->getParent(0) && matmulNode->getParent(0)->getOperator()->type() != Producer_Op::Type))
{ {
weight = matmulNode->getParent(1)->cloneSharedOperators(); weight = matmulNode->getParent(1);
} }
else if ((matmulNode->getParent(0) && !matmulNode->getParent(1)) else if ((matmulNode->getParent(0) && !matmulNode->getParent(1))
|| (matmulNode->getParent(0) && matmulNode->getParent(0)->getOperator()->type() == Producer_Op::Type || (matmulNode->getParent(0) && matmulNode->getParent(0)->getOperator()->type() == Producer_Op::Type
&& matmulNode->getParent(1) && matmulNode->getParent(1)->getOperator()->type() != Producer_Op::Type)) && matmulNode->getParent(1) && matmulNode->getParent(1)->getOperator()->type() != Producer_Op::Type))
{ {
weight = matmulNode->getParent(0)->cloneSharedOperators(); weight = matmulNode->getParent(0);
} }
else if (matmulNode->getParent(0) && matmulNode->getParent(0)->getOperator()->type() == Producer_Op::Type else if (matmulNode->getParent(0) && matmulNode->getParent(0)->getOperator()->type() == Producer_Op::Type
&& matmulNode->getParent(1) && matmulNode->getParent(1)->getOperator()->type() == Producer_Op::Type) && matmulNode->getParent(1) && matmulNode->getParent(1)->getOperator()->type() == Producer_Op::Type)
...@@ -65,7 +65,7 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr< ...@@ -65,7 +65,7 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr<
// If both inputs are producers, there is an ambiguity, but both options // If both inputs are producers, there is an ambiguity, but both options
// result in a correct solution. // result in a correct solution.
Log::notice("Notice: both MatMul inputs are Producers, assume data at input#0 and weights at input#1."); Log::notice("Notice: both MatMul inputs are Producers, assume data at input#0 and weights at input#1.");
weight = matmulNode->getParent(1)->cloneSharedOperators(); weight = matmulNode->getParent(1);
} }
AIDGE_ASSERT(weight != nullptr, "Could not deduce weight input for MatMul operator."); AIDGE_ASSERT(weight != nullptr, "Could not deduce weight input for MatMul operator.");
...@@ -90,9 +90,9 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr< ...@@ -90,9 +90,9 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr<
// Step 2 : Branch existing producers & create the others // Step 2 : Branch existing producers & create the others
// link weights & bias // link weights & bias
weight->addChild(fc, 0, 1); weight->cloneSharedOperators()->addChild(fc, 0, 1);
if (bias) { if (bias) {
bias->addChild(fc, 0, 2); bias->cloneSharedOperators()->addChild(fc, 0, 2);
} }
...@@ -100,8 +100,8 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr< ...@@ -100,8 +100,8 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr<
// Case 1 : If all nodes are in a graph view : delete old nodes & branch input & output // Case 1 : If all nodes are in a graph view : delete old nodes & branch input & output
// Case 2 : If not all nodes are in a graph view : only delete the nodes from the graphview // Case 2 : If not all nodes are in a graph view : only delete the nodes from the graphview
// Maybe create a central mechanism to update automatically all graph views rather than each node have graphview presence memory? // Maybe create a central mechanism to update automatically all graph views rather than each node have graphview presence memory?
auto newNodes = std::set<std::shared_ptr<Node>>({fc, weight, fc->getParent(2)}); auto newNodes = std::set<std::shared_ptr<Node>>({fc, fc->getParent(1), fc->getParent(2)});
GraphView::replace({matmulNode, addNode, addNode->getParent(1), matmulNode->getParent(1)}, newNodes); GraphView::replace({matmulNode, addNode, bias, weight}, newNodes);
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment