From 009d227dc38d6ba69a6be28d21366866b38f482a Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Sun, 9 Jun 2024 16:23:28 +0200
Subject: [PATCH] Fixed minor issues

---
 src/recipes/MatMulTiling.cpp | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/src/recipes/MatMulTiling.cpp b/src/recipes/MatMulTiling.cpp
index d21c7d406..a8cb8b955 100644
--- a/src/recipes/MatMulTiling.cpp
+++ b/src/recipes/MatMulTiling.cpp
@@ -56,6 +56,8 @@ void Aidge::matMulTiling(NodePtr matMul, const std::vector<DimSize_t>& maxDims)
         slice00_ends->addChild(slice00, 0, 2);
         auto slice00_axes = Producer(std::make_shared<Tensor>(Vector<DimSize_t>{{0, 1}}), "", true);
         slice00_axes->addChild(slice00, 0, 3);
+        auto slice00_steps = Producer(std::make_shared<Tensor>(Vector<DimSize_t>{{1, 1}}), "", true);
+        slice00_steps->addChild(slice00, 0, 4);
         auto matMul00 = MatMul();
         auto identity1 = Identity();
         auto slice01 = Slice();
@@ -65,6 +67,8 @@ void Aidge::matMulTiling(NodePtr matMul, const std::vector<DimSize_t>& maxDims)
         slice01_ends->addChild(slice01, 0, 2);
         auto slice01_axes = Producer(std::make_shared<Tensor>(Vector<DimSize_t>{{0, 1}}), "", true);
         slice01_axes->addChild(slice01, 0, 3);
+        auto slice01_steps = Producer(std::make_shared<Tensor>(Vector<DimSize_t>{{1, 1}}), "", true);
+        slice01_steps->addChild(slice01, 0, 4);
         auto matMul01 = MatMul();
         auto concat0 = Concat(2, axis);
 
@@ -82,7 +86,7 @@ void Aidge::matMulTiling(NodePtr matMul, const std::vector<DimSize_t>& maxDims)
 
         auto g = std::make_shared<GraphView>();
         g->add({identity0, identity1});
-        g->add({slice00, matMul00, matMul01, slice01, concat0});
+        g->add({slice00, slice00_starts, slice00_ends, slice00_axes, slice00_steps, matMul00, matMul01, slice01, slice01_starts, slice01_ends, slice01_axes, slice01_steps, concat0});
         g->save("micrograph");
 
         auto replaced = GraphView::replace(gMatMul, g);
@@ -91,8 +95,8 @@ void Aidge::matMulTiling(NodePtr matMul, const std::vector<DimSize_t>& maxDims)
             g->forwardDims();
 
             // Recursive tiling
-            MatMulTiling(matMul00, maxDims);
-            MatMulTiling(matMul01, maxDims);
+            matMulTiling(matMul00, maxDims);
+            matMulTiling(matMul01, maxDims);
         }
         else {
             Log::warn("Unable to split MatMul {}", matMul->name());
-- 
GitLab