From 6fdf2ae65b25572b73a08d9213c39f1840c7396c Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Sat, 2 Mar 2024 19:29:49 +0100
Subject: [PATCH] More generic FuseMulAdd

Accept the MatMul output on either input of Add, and the weight Producer
on either input of MatMul. The presence of the bias is asserted on the Add
node (the node that is actually dereferenced), and unconnected MatMul
inputs are treated as non-Producer inputs instead of being dereferenced.
---
 src/recipes/FuseMulAdd.cpp | 32 +++++++++++++++++++++++++++++---
 1 file changed, 29 insertions(+), 3 deletions(-)

diff --git a/src/recipes/FuseMulAdd.cpp b/src/recipes/FuseMulAdd.cpp
index a09c27c2b..3047ad715 100644
--- a/src/recipes/FuseMulAdd.cpp
+++ b/src/recipes/FuseMulAdd.cpp
@@ -36,11 +36,37 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr<
 
     // Step 1 : Create FC
     // Fetch the output dimension throught the bias size
-    std::shared_ptr<Node> bias = (addNode->getParent(1)) ? addNode->getParent(1)->cloneSharedOperators() : nullptr;
+    std::shared_ptr<Node> bias = nullptr;
+    if (addNode->getParent(0) == matmulNode) {
+        AIDGE_ASSERT(addNode->getParent(1), "No bias detected to produce the fuseMulAdd recipe.");
+        bias = addNode->getParent(1)->cloneSharedOperators();
+    }
+    else if (addNode->getParent(1) == matmulNode) {
+        AIDGE_ASSERT(addNode->getParent(0), "No bias detected to produce the fuseMulAdd recipe.");
+        bias = addNode->getParent(0)->cloneSharedOperators();
+    }
 
-    AIDGE_ASSERT(matmulNode->getParent(1), "No weight detected to produce the fuseMulAdd recipe.");
+    // A missing parent counts as a non-Producer input, so it is never dereferenced.
+    const auto isProducer = [](const std::shared_ptr<Node>& node) {
+        return (node != nullptr) && (node->getOperator()->type() == Producer_Op::Type);
+    };
+    const std::shared_ptr<Node> matmulIn0 = matmulNode->getParent(0);
+    const std::shared_ptr<Node> matmulIn1 = matmulNode->getParent(1);
+    std::shared_ptr<Node> weight = nullptr;
+    if (isProducer(matmulIn1) && !isProducer(matmulIn0)) {
+        weight = matmulIn1->cloneSharedOperators();
+    }
+    else if (isProducer(matmulIn0) && !isProducer(matmulIn1)) {
+        weight = matmulIn0->cloneSharedOperators();
+    }
+    else if (isProducer(matmulIn0) && isProducer(matmulIn1)) {
+        // If both inputs are producers, there is an ambiguity, but both options
+        // result in a correct solution.
+        fmt::print("Warning: both MatMul inputs are Producers, assume data at input#0 and weights at input#1.\n");
+        weight = matmulIn1->cloneSharedOperators();
+    }
+    AIDGE_ASSERT(weight != nullptr, "Could not deduce weight input for MatMul operator.");
 
-    std::shared_ptr<Node> weight = matmulNode->getParent(1)->cloneSharedOperators();
     // TODO: find another way to get OutChannels for FC operator.
     // This poor fix supposes that one of Add inputs is a const and has the same outChannels as the output
     DimSize_t outSize = 0;
-- 
GitLab