From 97dc21c52b1949812b8332a76ba917d31e08d77f Mon Sep 17 00:00:00 2001 From: cmoineau <cyril.moineau@cea.fr> Date: Thu, 13 Jun 2024 07:41:24 +0000 Subject: [PATCH] Update fuseMulAdd to transpose the weight matrice if it is the first input of the MatMul. --- src/recipes/FuseMulAdd.cpp | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/recipes/FuseMulAdd.cpp b/src/recipes/FuseMulAdd.cpp index bb4b0e3db..2be6f1006 100644 --- a/src/recipes/FuseMulAdd.cpp +++ b/src/recipes/FuseMulAdd.cpp @@ -52,6 +52,13 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr< && matmulNode->getParent(0) && matmulNode->getParent(0)->getOperator()->type() != Producer_Op::Type)) { weight = matmulNode->getParent(1); + // Transpose weights because weight Tensor is in first input + auto weightOpTensor = std::static_pointer_cast<OperatorTensor>(weight->getOperator()); + const std::shared_ptr<Aidge::Tensor>& weightTensor = weightOpTensor->getOutput(0); + std::vector<DimSize_t> shape = weightTensor->dims(); + std::reverse(shape.begin(), shape.end()); + weightTensor->copyTranspose(*weightTensor, std::vector<Aidge::DimSize_t>({1ul, 0ul})); + // weightOpTensor->setOutput(0, std::make_shared<Aidge::Tensor>(weightTensor->transpose(shape))); } else if ((matmulNode->getParent(0) && !matmulNode->getParent(1)) || (matmulNode->getParent(0) && matmulNode->getParent(0)->getOperator()->type() == Producer_Op::Type @@ -82,7 +89,7 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr< break; } } - AIDGE_ASSERT(outSize, "Couldnt get output number of channels for FC operator."); + AIDGE_ASSERT(outSize, "Could not get output number of channels for FC operator."); // Instanciate FC std::string fcName = matmulNode->name(); @@ -138,4 +145,4 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::GraphView> graphView){ } -} \ No newline at end of file +} -- GitLab