Skip to content
Snippets Groups Projects
Commit 30781475 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Fixed issue with fuseMulAdd

parent 837bead8
No related branches found
No related tags found
No related merge requests found
......@@ -39,11 +39,11 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr<
std::shared_ptr<Node> bias = nullptr;
if (addNode->getParent(0) == matmulNode) {
AIDGE_ASSERT(matmulNode->getParent(1), "No bias detected to produce the fuseMulAdd recipe.");
bias = addNode->getParent(1)->cloneSharedOperators();
bias = addNode->getParent(1);
}
else if (addNode->getParent(1) == matmulNode) {
AIDGE_ASSERT(matmulNode->getParent(0), "No bias detected to produce the fuseMulAdd recipe.");
bias = addNode->getParent(0)->cloneSharedOperators();
bias = addNode->getParent(0);
}
std::shared_ptr<Node> weight = nullptr;
......@@ -51,13 +51,13 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr<
|| (matmulNode->getParent(1) && matmulNode->getParent(1)->getOperator()->type() == Producer_Op::Type
&& matmulNode->getParent(0) && matmulNode->getParent(0)->getOperator()->type() != Producer_Op::Type))
{
weight = matmulNode->getParent(1)->cloneSharedOperators();
weight = matmulNode->getParent(1);
}
else if ((matmulNode->getParent(0) && !matmulNode->getParent(1))
|| (matmulNode->getParent(0) && matmulNode->getParent(0)->getOperator()->type() == Producer_Op::Type
&& matmulNode->getParent(1) && matmulNode->getParent(1)->getOperator()->type() != Producer_Op::Type))
{
weight = matmulNode->getParent(0)->cloneSharedOperators();
weight = matmulNode->getParent(0);
}
else if (matmulNode->getParent(0) && matmulNode->getParent(0)->getOperator()->type() == Producer_Op::Type
&& matmulNode->getParent(1) && matmulNode->getParent(1)->getOperator()->type() == Producer_Op::Type)
......@@ -65,7 +65,7 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr<
// If both inputs are producers, there is an ambiguity, but both options
// result in a correct solution.
Log::notice("Notice: both MatMul inputs are Producers, assume data at input#0 and weights at input#1.");
weight = matmulNode->getParent(1)->cloneSharedOperators();
weight = matmulNode->getParent(1);
}
AIDGE_ASSERT(weight != nullptr, "Could not deduce weight input for MatMul operator.");
......@@ -90,9 +90,9 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr<
// Step 2 : Branch existing producers & create the others
// link weights & bias
weight->addChild(fc, 0, 1);
weight->cloneSharedOperators()->addChild(fc, 0, 1);
if (bias) {
bias->addChild(fc, 0, 2);
bias->cloneSharedOperators()->addChild(fc, 0, 2);
}
......@@ -100,8 +100,8 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr<
// Case 1 : If all nodes are in a graph view : delete old nodes & branch input & output
// Case 2 : If not all nodes are in a graph view : only delete the nodes from the graphview
// Maybe create a central mechanism to update automatically all graph views rather than each node have graphview presence memory?
auto newNodes = std::set<std::shared_ptr<Node>>({fc, weight, fc->getParent(2)});
GraphView::replace({matmulNode, addNode, addNode->getParent(1), matmulNode->getParent(1)}, newNodes);
auto newNodes = std::set<std::shared_ptr<Node>>({fc, fc->getParent(1), fc->getParent(2)});
GraphView::replace({matmulNode, addNode, bias, weight}, newNodes);
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment