diff --git a/src/recipes/MatMulTiling.cpp b/src/recipes/MatMulTiling.cpp index d21c7d406d164721aaef4017db4211dd5d0d3bd8..a8cb8b955c5e470879ca24d04c4728514dc5945f 100644 --- a/src/recipes/MatMulTiling.cpp +++ b/src/recipes/MatMulTiling.cpp @@ -56,6 +56,8 @@ void Aidge::matMulTiling(NodePtr matMul, const std::vector<DimSize_t>& maxDims) slice00_ends->addChild(slice00, 0, 2); auto slice00_axes = Producer(std::make_shared<Tensor>(Vector<DimSize_t>{{0, 1}}), "", true); slice00_axes->addChild(slice00, 0, 3); + auto slice00_steps = Producer(std::make_shared<Tensor>(Vector<DimSize_t>{{1, 1}}), "", true); + slice00_steps->addChild(slice00, 0, 4); auto matMul00 = MatMul(); auto identity1 = Identity(); auto slice01 = Slice(); @@ -65,6 +67,8 @@ void Aidge::matMulTiling(NodePtr matMul, const std::vector<DimSize_t>& maxDims) slice01_ends->addChild(slice01, 0, 2); auto slice01_axes = Producer(std::make_shared<Tensor>(Vector<DimSize_t>{{0, 1}}), "", true); slice01_axes->addChild(slice01, 0, 3); + auto slice01_steps = Producer(std::make_shared<Tensor>(Vector<DimSize_t>{{1, 1}}), "", true); + slice01_steps->addChild(slice01, 0, 4); auto matMul01 = MatMul(); auto concat0 = Concat(2, axis); @@ -82,7 +86,7 @@ void Aidge::matMulTiling(NodePtr matMul, const std::vector<DimSize_t>& maxDims) auto g = std::make_shared<GraphView>(); g->add({identity0, identity1}); - g->add({slice00, matMul00, matMul01, slice01, concat0}); + g->add({slice00, slice00_starts, slice00_ends, slice00_axes, slice00_steps, matMul00, matMul01, slice01, slice01_starts, slice01_ends, slice01_axes, slice01_steps, concat0}); g->save("micrograph"); auto replaced = GraphView::replace(gMatMul, g); @@ -91,8 +95,8 @@ void Aidge::matMulTiling(NodePtr matMul, const std::vector<DimSize_t>& maxDims) g->forwardDims(); // Recursive tiling - MatMulTiling(matMul00, maxDims); - MatMulTiling(matMul01, maxDims); + matMulTiling(matMul00, maxDims); + matMulTiling(matMul01, maxDims); } else { Log::warn("Unable to split MatMul {}", matMul->name());