Skip to content
Snippets Groups Projects
Commit 6fdf2ae6 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

More generic FuseMulAdd

parent 0350b632
No related branches found
No related tags found
No related merge requests found
......@@ -36,11 +36,37 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr<
// Step 1 : Create FC
// Fetch the output dimension throught the bias size
std::shared_ptr<Node> bias = (addNode->getParent(1)) ? addNode->getParent(1)->cloneSharedOperators() : nullptr;
std::shared_ptr<Node> bias = nullptr;
if (addNode->getParent(0) == matmulNode) {
AIDGE_ASSERT(matmulNode->getParent(1), "No bias detected to produce the fuseMulAdd recipe.");
bias = addNode->getParent(1)->cloneSharedOperators();
}
else if (addNode->getParent(1) == matmulNode) {
AIDGE_ASSERT(matmulNode->getParent(0), "No bias detected to produce the fuseMulAdd recipe.");
bias = addNode->getParent(0)->cloneSharedOperators();
}
AIDGE_ASSERT(matmulNode->getParent(1), "No weight detected to produce the fuseMulAdd recipe.");
std::shared_ptr<Node> weight = nullptr;
if (matmulNode->getParent(1)->getOperator()->type() == Producer_Op::Type
&& matmulNode->getParent(0)->getOperator()->type() != Producer_Op::Type)
{
weight = matmulNode->getParent(1)->cloneSharedOperators();
}
else if (matmulNode->getParent(0)->getOperator()->type() == Producer_Op::Type
&& matmulNode->getParent(1)->getOperator()->type() != Producer_Op::Type)
{
weight = matmulNode->getParent(0)->cloneSharedOperators();
}
else if (matmulNode->getParent(0)->getOperator()->type() == Producer_Op::Type
&& matmulNode->getParent(1)->getOperator()->type() == Producer_Op::Type)
{
// If both inputs are producers, there is an ambiguity, but both options
// result in a correct solution.
fmt::print("Warning: both MatMul inputs are Producers, assume data at input#0 and weights at input#1.\n");
weight = matmulNode->getParent(1)->cloneSharedOperators();
}
AIDGE_ASSERT(weight != nullptr, "Could not deduce weight input for MatMul operator.");
std::shared_ptr<Node> weight = matmulNode->getParent(1)->cloneSharedOperators();
// TODO: find another way to get OutChannels for FC operator.
// This poor fix supposes that one of Add inputs is a const and has the same outChannels as the output
DimSize_t outSize = 0;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment