Skip to content
Snippets Groups Projects
Commit 6fdf2ae6 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

More generic FuseMulAdd

parent 0350b632
No related branches found
No related tags found
2 merge requests!105version 0.2.0,!77Support for recurrent networks
Pipeline #40396 failed
......@@ -36,11 +36,37 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr<
// Step 1 : Create FC
// Fetch the output dimension throught the bias size
std::shared_ptr<Node> bias = (addNode->getParent(1)) ? addNode->getParent(1)->cloneSharedOperators() : nullptr;
std::shared_ptr<Node> bias = nullptr;
if (addNode->getParent(0) == matmulNode) {
AIDGE_ASSERT(matmulNode->getParent(1), "No bias detected to produce the fuseMulAdd recipe.");
bias = addNode->getParent(1)->cloneSharedOperators();
}
else if (addNode->getParent(1) == matmulNode) {
AIDGE_ASSERT(matmulNode->getParent(0), "No bias detected to produce the fuseMulAdd recipe.");
bias = addNode->getParent(0)->cloneSharedOperators();
}
AIDGE_ASSERT(matmulNode->getParent(1), "No weight detected to produce the fuseMulAdd recipe.");
std::shared_ptr<Node> weight = nullptr;
if (matmulNode->getParent(1)->getOperator()->type() == Producer_Op::Type
&& matmulNode->getParent(0)->getOperator()->type() != Producer_Op::Type)
{
weight = matmulNode->getParent(1)->cloneSharedOperators();
}
else if (matmulNode->getParent(0)->getOperator()->type() == Producer_Op::Type
&& matmulNode->getParent(1)->getOperator()->type() != Producer_Op::Type)
{
weight = matmulNode->getParent(0)->cloneSharedOperators();
}
else if (matmulNode->getParent(0)->getOperator()->type() == Producer_Op::Type
&& matmulNode->getParent(1)->getOperator()->type() == Producer_Op::Type)
{
// If both inputs are producers, there is an ambiguity, but both options
// result in a correct solution.
fmt::print("Warning: both MatMul inputs are Producers, assume data at input#0 and weights at input#1.\n");
weight = matmulNode->getParent(1)->cloneSharedOperators();
}
AIDGE_ASSERT(weight != nullptr, "Could not deduce weight input for MatMul operator.");
std::shared_ptr<Node> weight = matmulNode->getParent(1)->cloneSharedOperators();
// TODO: find another way to get OutChannels for FC operator.
// This poor fix supposes that one of Add inputs is a const and has the same outChannels as the output
DimSize_t outSize = 0;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment