diff --git a/src/operator/MatMul.cpp b/src/operator/MatMul.cpp
index fa63608c74bfa39e7e53f5a1fc9c775e80d4b777..4c7195af28f15add22e82149091e94943c7f460d 100644
--- a/src/operator/MatMul.cpp
+++ b/src/operator/MatMul.cpp
@@ -32,6 +32,11 @@ void Aidge::MatMul_Op::computeOutputDims() {
         std::vector<std::size_t> dims0 = getInput(0)->dims();
         std::vector<std::size_t> dims1 = getInput(1)->dims();
 
+        // keep second-to-last dimension of dims0
+        const bool keepDim0 = dims0.size() > 1;
+        // keep last dimension of dims1
+        const bool keepDim1 = dims1.size() > 1;
+
         if (dims0.size() == 1) {
             dims0.insert(dims0.cbegin(), 1);
         }
@@ -42,10 +47,10 @@ void Aidge::MatMul_Op::computeOutputDims() {
 
 
         if (dims0.size() > dims1.size()) {
-                dims1.insert(dims1.cbegin(), dims0.begin(), dims0.end() - dims1.size());
+            dims1.insert(dims1.cbegin(), dims0.begin(), dims0.end() - dims1.size());
         }
         else if (dims1.size() > dims0.size()) {
-                dims0.insert(dims0.cbegin(), dims1.begin(), dims1.end() - dims0.size());
+            dims0.insert(dims0.cbegin(), dims1.begin(), dims1.end() - dims0.size());
         }
 
         AIDGE_ASSERT(dims0[dims_size-1] == dims1[dims_size-2], "Incompatible matrices sizes.");
@@ -56,11 +61,10 @@ void Aidge::MatMul_Op::computeOutputDims() {
             outDims[i] = std::max(dims0[i], dims1[i]);
         }
 
-        // keep second-to-last dimension of dims0
-        if (dims0.size() > 1)
+        // use keepDim0 instead of dims0.size() because dims0 has been modified
+        if (keepDim0)
             outDims.push_back(dims0[dims_size-2]);
-        // keep last dimension of dims1
-        if (dims1.size() > 1)
+        if (keepDim1)
             outDims.push_back(dims1[dims_size-1]);
 
         mOutputs[0]->resize(outDims);