diff --git a/src/recipes/FuseMulAdd.cpp b/src/recipes/FuseMulAdd.cpp index a09c27c2be574adbccde8a1392a90a4c50a727a1..3047ad71563cd6d66979a0d70b950784d4b6ee7e 100644 --- a/src/recipes/FuseMulAdd.cpp +++ b/src/recipes/FuseMulAdd.cpp @@ -36,11 +36,37 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr< // Step 1 : Create FC // Fetch the output dimension throught the bias size - std::shared_ptr<Node> bias = (addNode->getParent(1)) ? addNode->getParent(1)->cloneSharedOperators() : nullptr; + std::shared_ptr<Node> bias = nullptr; + if (addNode->getParent(0) == matmulNode) { + AIDGE_ASSERT(matmulNode->getParent(1), "No bias detected to produce the fuseMulAdd recipe."); + bias = addNode->getParent(1)->cloneSharedOperators(); + } + else if (addNode->getParent(1) == matmulNode) { + AIDGE_ASSERT(matmulNode->getParent(0), "No bias detected to produce the fuseMulAdd recipe."); + bias = addNode->getParent(0)->cloneSharedOperators(); + } - AIDGE_ASSERT(matmulNode->getParent(1), "No weight detected to produce the fuseMulAdd recipe."); + std::shared_ptr<Node> weight = nullptr; + if (matmulNode->getParent(1)->getOperator()->type() == Producer_Op::Type + && matmulNode->getParent(0)->getOperator()->type() != Producer_Op::Type) + { + weight = matmulNode->getParent(1)->cloneSharedOperators(); + } + else if (matmulNode->getParent(0)->getOperator()->type() == Producer_Op::Type + && matmulNode->getParent(1)->getOperator()->type() != Producer_Op::Type) + { + weight = matmulNode->getParent(0)->cloneSharedOperators(); + } + else if (matmulNode->getParent(0)->getOperator()->type() == Producer_Op::Type + && matmulNode->getParent(1)->getOperator()->type() == Producer_Op::Type) + { + // If both inputs are producers, there is an ambiguity, but both options + // result in a correct solution. + fmt::print("Warning: both MatMul inputs are Producers, assume data at input#0 and weights at input#1.\n"); + weight = matmulNode->getParent(1)->cloneSharedOperators(); + } + AIDGE_ASSERT(weight != nullptr, "Could not deduce weight input for MatMul operator."); - std::shared_ptr<Node> weight = matmulNode->getParent(1)->cloneSharedOperators(); // TODO: find another way to get OutChannels for FC operator. // This poor fix supposes that one of Add inputs is a const and has the same outChannels as the output DimSize_t outSize = 0;