diff --git a/src/recipes/FuseMulAdd.cpp b/src/recipes/FuseMulAdd.cpp index bb4b0e3db1974ccf106699b25fd71fc9cc09654c..2be6f10060bf8164c99ecba05457bf4cd3ac162a 100644 --- a/src/recipes/FuseMulAdd.cpp +++ b/src/recipes/FuseMulAdd.cpp @@ -52,6 +52,13 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr< && matmulNode->getParent(0) && matmulNode->getParent(0)->getOperator()->type() != Producer_Op::Type)) { weight = matmulNode->getParent(1); + // Transpose weights because weight Tensor is in first input + auto weightOpTensor = std::static_pointer_cast<OperatorTensor>(weight->getOperator()); + const std::shared_ptr<Aidge::Tensor>& weightTensor = weightOpTensor->getOutput(0); + std::vector<DimSize_t> shape = weightTensor->dims(); + std::reverse(shape.begin(), shape.end()); + weightTensor->copyTranspose(*weightTensor, std::vector<Aidge::DimSize_t>({1ul, 0ul})); + // weightOpTensor->setOutput(0, std::make_shared<Aidge::Tensor>(weightTensor->transpose(shape))); } else if ((matmulNode->getParent(0) && !matmulNode->getParent(1)) || (matmulNode->getParent(0) && matmulNode->getParent(0)->getOperator()->type() == Producer_Op::Type @@ -82,7 +89,7 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr< break; } } - AIDGE_ASSERT(outSize, "Couldnt get output number of channels for FC operator."); + AIDGE_ASSERT(outSize, "Could not get output number of channels for FC operator."); // Instanciate FC std::string fcName = matmulNode->name(); @@ -138,4 +145,4 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::GraphView> graphView){ } -} \ No newline at end of file +}