From 7aec113a6869024a0e2e46835ddc30703b83d7d9 Mon Sep 17 00:00:00 2001
From: hrouis <houssemeddine.rouis92@gmail.com>
Date: Tue, 30 Jan 2024 16:23:37 +0100
Subject: [PATCH] fix computOutputDims

---
 include/aidge/operator/MatMul.hpp | 18 +----------
 src/operator/MatMul.cpp           | 50 ++++++++++++++++++++++++++++++-
 2 files changed, 50 insertions(+), 18 deletions(-)

diff --git a/include/aidge/operator/MatMul.hpp b/include/aidge/operator/MatMul.hpp
index 5f06e8c2a..a6904740f 100644
--- a/include/aidge/operator/MatMul.hpp
+++ b/include/aidge/operator/MatMul.hpp
@@ -55,23 +55,7 @@ public:
     }
 
 
-    void computeOutputDims() override final {
-        if (!getInput(0)->empty() && !getInput(1)->empty())
-        {
-            std::vector<std::size_t> outDims;
-            for (std::size_t i = 0; i < getInput(0)->nbDims()-1; i++)
-            {
-                outDims.push_back(getInput(0)->dims()[i]);
-            }
-            size_t secondToLastIdx = getInput(1)->nbDims() > 1 ? getInput(1)->nbDims() - 2 : 0;
-            for (std::size_t i = 0; i < getInput(1)->nbDims(); i++)
-            {
-                if(i != secondToLastIdx)
-                   outDims.push_back(getInput(1)->dims()[i]);
-            }     
-            mOutputs[0]->resize(outDims);
-        }
-    }
+    void computeOutputDims() override final;
 
 
     void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
diff --git a/src/operator/MatMul.cpp b/src/operator/MatMul.cpp
index 666ed3921..4bb54e83b 100644
--- a/src/operator/MatMul.cpp
+++ b/src/operator/MatMul.cpp
@@ -9,8 +9,56 @@
  *
  ********************************************************************************/
 
+#include <algorithm>
 #include <string>
+#include <vector>
 
 #include "aidge/operator/MatMul.hpp"
+#include "aidge/utils/Types.h"
+#include "aidge/utils/ErrorHandling.hpp"
 
-const std::string Aidge::MatMul_Op::Type = "MatMul";
\ No newline at end of file
+const std::string Aidge::MatMul_Op::Type = "MatMul";
+
+void Aidge::MatMul_Op::computeOutputDims() {
+    if (!getInput(0)->empty() && !getInput(1)->empty())
+    {
+        const auto dims0 = getInput(0)->dims();
+        const auto dims1 = getInput(1)->dims();
+
+        if (dims0.size() > 2 && dims1.size() > 2)
+        {
+            bool supportedSizes = true;
+            std::size_t d0 = dims0.size()-3, d1 = dims1.size()-3;
+            while(d0>0 && d1>0 && supportedSizes)
+            {
+                if(dims0[d0] != dims1[d1])
+                    supportedSizes = false;
+
+                d0--;
+                d1--;
+            }
+            if(!supportedSizes)
+                AIDGE_THROW_OR_ABORT(std::runtime_error, "Unsupported sizes for MatMul!");
+        }
+
+        std::size_t secondToLastIdx2 = dims1.size()>1 ? dims1.size() - 2 : dims1.size() - 1;
+        if(dims0[dims0.size() - 1]  != dims1[secondToLastIdx2])
+            AIDGE_THROW_OR_ABORT(std::runtime_error, "Inner dimension missmatch for MatMul!");
+
+        std::vector<std::size_t> outDims;
+        if(dims0.size() > 2 || dims1.size() > 2)
+        {
+            if(dims0.size() > dims1.size())
+                std::copy_n(dims0.begin(), dims0.size()-2, std::back_inserter(outDims));
+            else
+                std::copy_n(dims1.begin(), dims1.size()-2, std::back_inserter(outDims));
+        }
+
+        if(dims0.size() > 1)
+            outDims.push_back(dims0[dims0.size()-2]);
+        if(dims1.size() > 1)
+            outDims.push_back(dims1[dims1.size() - 1]);
+
+        mOutputs[0]->resize(outDims);
+    }
+}
\ No newline at end of file
-- 
GitLab