From b2298d78e15ffbe670f75de06716e4058ebd4eec Mon Sep 17 00:00:00 2001
From: hrouis <houssemeddine.rouis92@gmail.com>
Date: Tue, 28 May 2024 14:14:25 +0200
Subject: [PATCH] fix Slice node creation in HorizontalTiling

---
 src/recipes/HorizontalTiling.cpp | 47 ++++++++++++++++++++++++++------
 1 file changed, 38 insertions(+), 9 deletions(-)

diff --git a/src/recipes/HorizontalTiling.cpp b/src/recipes/HorizontalTiling.cpp
index 342a4afe6..989754930 100644
--- a/src/recipes/HorizontalTiling.cpp
+++ b/src/recipes/HorizontalTiling.cpp
@@ -90,30 +90,59 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::getConvHorizontalTiling(const std:
         clonedInputs[1] -> addChild(newNode, 0, 1);
         clonedInputs[2] -> addChild(newNode, 0, 2);
 
+        auto slice = Slice();
         auto backend = outTensor->getImpl()->backend();
-        // Create Slice's Starts attribute
+        // Create Slice's Starts producer node
         std::vector<std::int64_t> inputDimsStart(inputDims[0].first.size());
         for (std::size_t dim = 0; dim < inputDimsStart.size(); ++dim) {
             inputDimsStart[dim] = static_cast<std::int64_t>(inputDims[0].first[dim]);
         }
-        // Create Slice's Ends attribute
+        const std::shared_ptr<Tensor> starts = std::make_shared<Tensor>();
+        starts -> setDataType(DataType::Int64);
+        starts -> setBackend(backend);
+        starts -> resize(std::vector<std::size_t>({inputDimsStart.size()}));
+        starts -> getImpl() -> copyFromHost(inputDimsStart.data(), inputDimsStart.size());
+        auto startsNode = Producer(starts, slice->name() + sliceInputsNames[1]);
+        startsNode -> addChild(slice, 0, 1);
+
+        // Create Slice's Ends producer node
         std::vector<std::int64_t> inputDimsEnd(inputDims[0].first.size());
         for (std::size_t dim = 0; dim < inputDimsEnd.size(); ++dim) {
             inputDimsEnd[dim] = static_cast<std::int64_t>(inputDims[0].first[dim] + inputDims[0].second[dim]);
         }
-
-        // Create Slice's Axes attribute
+        const std::shared_ptr<Tensor> ends = std::make_shared<Tensor>();
+        ends -> setDataType(DataType::Int64);
+        ends -> setBackend(backend);
+        ends -> resize(std::vector<std::size_t>({inputDimsEnd.size()}));
+        ends -> getImpl() -> copyFromHost(inputDimsEnd.data(), inputDimsEnd.size());
+        auto endsNode = Producer(ends, slice->name() + sliceInputsNames[2]);
+        endsNode -> addChild(slice, 0, 2);
+
+        // Create Slice's Axes producer node
         std::vector<std::int8_t> usedDims(inputDimsEnd.size());
         std::iota(usedDims.begin(), usedDims.end(), static_cast<std::int8_t>(0));
+        const std::shared_ptr<Tensor> axes = std::make_shared<Tensor>();
+        axes -> setDataType(DataType::Int8);
+        axes -> setBackend(backend);
+        axes -> resize(std::vector<std::size_t>({usedDims.size()}));
+        axes -> getImpl() -> copyFromHost(usedDims.data(), usedDims.size());
+        auto axesNode = Producer(axes, slice->name() + sliceInputsNames[3]);
+        axesNode -> addChild(slice, 0, 3);
+
+        // Create Slice's Steps producer node
+        std::vector<std::int64_t> inputDimsSteps(inputDimsEnd.size(), static_cast<std::int64_t>(1));
+        const std::shared_ptr<Tensor> steps = std::make_shared<Tensor>();
+        steps -> setDataType(DataType::Int64);
+        steps -> setBackend(backend);
+        steps -> resize(std::vector<std::size_t>({inputDimsSteps.size()}));
+        steps -> getImpl() -> copyFromHost(inputDimsSteps.data(), inputDimsSteps.size());
+        auto stepsNode = Producer(steps, slice->name() + sliceInputsNames[4]);
+        stepsNode -> addChild(slice, 0, 4);
 
-        // Create Slice's Steps attribute
-        std::vector<std::int64_t> steps(inputDimsEnd.size(), static_cast<std::int64_t>(1));
-
-        auto slice = Slice(inputDimsStart, inputDimsEnd, usedDims, steps, "Slice_" + std::to_string(currentFirstDims[axis]));
         slice -> addChild(newNode, 0, 0);
         newNode -> addChild(concat, 0, i);
 
-        tiledOperator.insert({slice, newNode});
+        tiledOperator.insert({slice, newNode, startsNode, endsNode, axesNode, stepsNode});
     }
 
     return tiledOperator;
-- 
GitLab