From 97dc21c52b1949812b8332a76ba917d31e08d77f Mon Sep 17 00:00:00 2001
From: cmoineau <cyril.moineau@cea.fr>
Date: Thu, 13 Jun 2024 07:41:24 +0000
Subject: [PATCH] Update fuseMulAdd to transpose the weight matrice if it is
 the first input of the MatMul.

---
 src/recipes/FuseMulAdd.cpp | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/src/recipes/FuseMulAdd.cpp b/src/recipes/FuseMulAdd.cpp
index bb4b0e3db..2be6f1006 100644
--- a/src/recipes/FuseMulAdd.cpp
+++ b/src/recipes/FuseMulAdd.cpp
@@ -52,6 +52,13 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr<
             && matmulNode->getParent(0) && matmulNode->getParent(0)->getOperator()->type() != Producer_Op::Type))
     {
         weight = matmulNode->getParent(1);
+        // Transpose weights because weight Tensor is in first input
+        auto weightOpTensor = std::static_pointer_cast<OperatorTensor>(weight->getOperator());
+        const std::shared_ptr<Aidge::Tensor>& weightTensor = weightOpTensor->getOutput(0);
+        std::vector<DimSize_t> shape =  weightTensor->dims();
+        std::reverse(shape.begin(), shape.end());
+        weightTensor->copyTranspose(*weightTensor, std::vector<Aidge::DimSize_t>({1ul, 0ul}));
+        // weightOpTensor->setOutput(0, std::make_shared<Aidge::Tensor>(weightTensor->transpose(shape)));
     }
     else if ((matmulNode->getParent(0) && !matmulNode->getParent(1))
         || (matmulNode->getParent(0) && matmulNode->getParent(0)->getOperator()->type() == Producer_Op::Type
@@ -82,7 +89,7 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr<
             break;
         }
     }
-    AIDGE_ASSERT(outSize, "Couldnt get output number of channels for FC operator.");
+    AIDGE_ASSERT(outSize, "Could not get output number of channels for FC operator.");
 
     // Instanciate FC
     std::string fcName = matmulNode->name();
@@ -138,4 +145,4 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::GraphView> graphView){
 
 
     }
-}
\ No newline at end of file
+}
-- 
GitLab