From e9fdb3b61f7476f48bd83a47e211a35639d076d7 Mon Sep 17 00:00:00 2001
From: hrouis <houssemeddine.rouis92@gmail.com>
Date: Wed, 22 May 2024 18:27:23 +0200
Subject: [PATCH] add Step attribute to Slice

---
 include/aidge/operator/Slice.hpp         |  38 +++--
 python_binding/operator/pybind_Slice.cpp |   1 +
 src/operator/Slice.cpp                   | 188 ++++++++++++++---------
 src/recipes/HorizontalTiling.cpp         |   6 +-
 unit_tests/operator/Test_SliceImpl.cpp   |   8 +-
 5 files changed, 154 insertions(+), 87 deletions(-)

diff --git a/include/aidge/operator/Slice.hpp b/include/aidge/operator/Slice.hpp
index 3e46ca6c6..01dc85504 100644
--- a/include/aidge/operator/Slice.hpp
+++ b/include/aidge/operator/Slice.hpp
@@ -30,25 +30,26 @@ public:
     void forward() override;
 };
 
-enum class SliceAttr { Starts, Ends, Axes };
+enum class SliceAttr { Starts, Ends, Axes, Steps };
 
 class Slice_Op
     : public OperatorTensor,
       public Registrable<Slice_Op, std::string, std::shared_ptr<OperatorImpl>(const Slice_Op &)>,
-      public StaticAttributes<SliceAttr, std::vector<std::int64_t>, std::vector<std::int64_t>, std::vector<std::int8_t>> {
+      public StaticAttributes<SliceAttr, std::vector<std::int64_t>, std::vector<std::int64_t>, std::vector<std::int8_t>, std::vector<std::int64_t>> {
 
 public:
     static const std::string Type;
 
     Slice_Op() = delete;
 
-    using Attributes_ = StaticAttributes<SliceAttr, std::vector<std::int64_t>, std::vector<std::int64_t>, std::vector<std::int8_t>>;
+    using Attributes_ = StaticAttributes<SliceAttr, std::vector<std::int64_t>, std::vector<std::int64_t>, std::vector<std::int8_t>, std::vector<std::int64_t>>;
     template <SliceAttr e> using attr = typename Attributes_::template attr<e>;
-    Slice_Op(const std::vector<std::int64_t>& starts, const std::vector<std::int64_t>& ends, const std::vector<std::int8_t>& axes)
-        : OperatorTensor(Type, 4, 0, 1),
+    Slice_Op(const std::vector<std::int64_t>& starts, const std::vector<std::int64_t>& ends, const std::vector<std::int8_t>& axes, const std::vector<std::int64_t>& steps)
+        : OperatorTensor(Type, 5, 0, 1),
           Attributes_(attr<SliceAttr::Starts>(starts),
                       attr<SliceAttr::Ends>(ends),
-                      attr<SliceAttr::Axes>(axes))
+                      attr<SliceAttr::Axes>(axes),
+                      attr<SliceAttr::Steps>(steps))
     {
         mImpl = std::make_shared<Slice_OpImpl>(*this);
     }
@@ -83,11 +84,12 @@ public:
     void setBackend(const std::string &name, DeviceIdx_t device = 0) override;
 
     static const std::vector<std::string> getInputsName(){
-        return {"data_input", "starts", "ends", "axes"};
+        return {"data_input", "starts", "ends", "axes", "steps"};
     }
     static const std::vector<std::string> getOutputsName(){
         return {"data_output"};
     }
+
 };
 
 /**
@@ -98,14 +100,32 @@ public:
 inline std::shared_ptr<Node> Slice(const std::vector<std::int64_t>& starts = {},
                                    const std::vector<std::int64_t>& ends = {},
                                    const std::vector<std::int8_t>& axes = {},
+                                   const std::vector<std::int64_t>& steps = {},
                                    const std::string &name = "") {
-    return std::make_shared<Node>(std::make_shared<Slice_Op>(starts, ends, axes), name);
+    return std::make_shared<Node>(std::make_shared<Slice_Op>(starts, ends, axes, steps), name);
 }
 }  // namespace Aidge
 
 namespace {
 template <>
-const char *const EnumStrings<Aidge::SliceAttr>::data[] = { "Starts", "Ends", "Axes" };
+const char *const EnumStrings<Aidge::SliceAttr>::data[] = { "Starts", "Ends", "Axes", "Steps" };
 }
 
+namespace Aidge {
+/// Registry of Slice forward kernels, keyed by a DataType.
+class SliceImplForward
+    : public Registrable<SliceImplForward, std::tuple<DataType>,
+                         void(const Slice_Op::Attrs &, const std::vector<DimSize_t>&, const void *, void *)> {};
+
+/// Slice kernel for element type I: reads `input_` (dimensions `inputDims`) and writes `output_`.
+template <typename I>
+void Slice_forward_kernel(const Slice_Op::Attrs &attrs, const std::vector<DimSize_t>&inputDims, const void *input_, void *output_);
+
+namespace {
+// One registrar per supported element type; each name matches the DataType key it registers.
+static Registrar<SliceImplForward> registrarSliceImplForward_Float32({DataType::Float32}, Slice_forward_kernel<float>);
+static Registrar<SliceImplForward> registrarSliceImplForward_Int32({DataType::Int32}, Slice_forward_kernel<int>);
+static Registrar<SliceImplForward> registrarSliceImplForward_Float64({DataType::Float64}, Slice_forward_kernel<double>);
+}  // namespace
+}  // namespace Aidge
 #endif /* AIDGE_CORE_OPERATOR_RELU_H_ */
diff --git a/python_binding/operator/pybind_Slice.cpp b/python_binding/operator/pybind_Slice.cpp
index 68124262c..5270e2d92 100644
--- a/python_binding/operator/pybind_Slice.cpp
+++ b/python_binding/operator/pybind_Slice.cpp
@@ -30,6 +30,7 @@ void init_Slice(py::module& m) {
           py::arg("starts") = std::vector<std::int64_t>(),
           py::arg("ends") = std::vector<std::int64_t>(),
           py::arg("axes") = std::vector<std::int8_t>(),
+          py::arg("steps") = std::vector<std::int64_t>(),
           py::arg("name") = "");
 }
 }  // namespace Aidge
diff --git a/src/operator/Slice.cpp b/src/operator/Slice.cpp
index 76cf64119..166442bd8 100644
--- a/src/operator/Slice.cpp
+++ b/src/operator/Slice.cpp
@@ -11,6 +11,7 @@
 
 #include "aidge/operator/Slice.hpp"
 
+#include <algorithm>
 #include <cassert>
 #include <cstddef>
 #include <cstdint>
@@ -25,88 +26,91 @@
 #include "aidge/data/Tensor.hpp"
 #include "aidge/utils/ErrorHandling.hpp"
 #include "aidge/utils/Types.h"
+#include "aidge/data/Data.hpp"
+#include "aidge/utils/Registrar.hpp"
 
-void Aidge::Slice_OpImpl::forward() {
-    const Slice_Op& op = dynamic_cast<const Slice_Op&>(mOp);
-
-    if (!op.getInput(0)) {
-        AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", op.Type);
-    }
-    AIDGE_ASSERT((op.template getAttr<SliceAttr::Starts>().size() == op.template getAttr<SliceAttr::Ends>().size()) &&
-                 (op.template getAttr<SliceAttr::Starts>().size() == op.template getAttr<SliceAttr::Axes>().size()),
-                 "start, end and axes arguments should be the same size.");
-
-    const std::size_t nbDims = op.getInput(0)->nbDims();
-
-    const std::vector<std::size_t>& inputDims = op.getInput(0)->dims();
-    auto outputDims = op.getInput(0)->dims();
+/**
+ * @brief Reference kernel of the Slice operator for element type I.
+ * @details Slices the listed axes one after another; on each of them the elements
+ * start, start + step, ... are kept, (end - start) / |step| of them (truncating).
+ * @param attrs Starts, Ends, Axes and Steps attributes, in that tuple order.
+ * @param inputDims Dimensions of the input tensor.
+ * @param input_ Input elements, read as a contiguous row-major array of I.
+ * @param output_ Destination array of I; must be large enough for the sliced tensor.
+ */
+template<class I>
+void Aidge::Slice_forward_kernel(const Slice_Op::Attrs &attrs, const std::vector<DimSize_t>&inputDims, const void *input_, void *output_){
+    const I* src = static_cast<const I*>(input_);  // data read by the current pass
+    I* const output = static_cast<I*>(output_);
 
-    // compute index of the output's first element
-    // compute output dimension at the same time (may change between two forward calls)
-    std::size_t beginning = 0;
-    const std::size_t nbAxes = op.template getAttr<SliceAttr::Axes>().size();
+    std::vector<DimSize_t> dims = inputDims;  // dimensions of the data behind src
+    // Intermediate results live in vectors so that they are released on every exit path.
+    std::vector<I> scratchSrc;
+    std::vector<I> scratchDst;
+    const std::size_t nbAxes = std::get<0>(attrs).size();
     for (std::size_t i = 0; i < nbAxes; ++i) {
-        // For each slice operation get the params and cast them to size_t
-        DimIdx_t axis = op.template getAttr<SliceAttr::Axes>()[i] >= 0 ?
-                            static_cast<DimIdx_t>(op.template getAttr<SliceAttr::Axes>()[i]) :
-                            static_cast<DimIdx_t>(op.template getAttr<SliceAttr::Axes>()[i] + static_cast<DimIdx_t>(inputDims.size()));
-        DimSize_t start = op.template getAttr<SliceAttr::Starts>()[i] >= 0 ?
-                            static_cast<DimSize_t>(op.template getAttr<SliceAttr::Starts>()[i]) :
-                            static_cast<DimSize_t>(op.template getAttr<SliceAttr::Starts>()[i] + static_cast<DimSize_t>(inputDims[axis]));
-        DimSize_t end = op.template getAttr<SliceAttr::Ends>()[i] >= 0 ?
-                        static_cast<DimSize_t>(op.template getAttr<SliceAttr::Ends>()[i]) :
-                        static_cast<DimSize_t>(op.template getAttr<SliceAttr::Ends>()[i] + static_cast<DimSize_t>(inputDims[axis]));
-        const std::size_t stridePostAxis = std::accumulate(inputDims.cbegin()+axis+1, inputDims.cend(), std::size_t(1), std::multiplies<std::size_t>());
-        beginning += start * stridePostAxis;
-        const std::size_t sliceLength = end - start;
-        outputDims[axis] = sliceLength;
-    }
-    op.getOutput(0)->resize(outputDims);
+        // Negative axis, start and end values are counted from the back.
+        const int rawAxis = std::get<2>(attrs)[i];
+        const DimIdx_t axis = static_cast<DimIdx_t>(rawAxis >= 0 ? rawAxis : rawAxis + static_cast<int>(inputDims.size()));
+        const std::int64_t axisLen = static_cast<std::int64_t>(inputDims[axis]);
+        const std::int64_t start = std::get<0>(attrs)[i] + (std::get<0>(attrs)[i] >= 0 ? 0 : axisLen);
+        const std::int64_t end = std::get<1>(attrs)[i] + (std::get<1>(attrs)[i] >= 0 ? 0 : axisLen);
+        // A missing step defaults to 1; a zero step would divide by zero below.
+        const std::int64_t step = i < std::get<3>(attrs).size() ? std::get<3>(attrs)[i] : 1;
+        AIDGE_ASSERT(step != 0, "Slice: step must be a non-zero value.");
 
+        // NOTE(review): truncating division as in Slice_Op::forwardDims; a range that is not a multiple of |step| loses its last element.
+        const std::size_t sliceSize = end > start ? static_cast<std::size_t>((end - start) / std::abs(step)) : 0;
 
-    // for inputDims = {4,5,5,3} & outputDims = {3,2,2,1}: substractDims = {1,5,5,3}
-    std::vector<std::size_t> substractedDims = std::vector<std::size_t>(nbDims);
-    for (std::size_t i = 0; i < nbDims; ++i) {
-        substractedDims[i] = inputDims[i] - outputDims[i];
-    }
-
-    // for outputDims = {3,2,2,1}: prodOutputDims = {12,4,2,1}
-    std::vector<std::size_t> prodOutputDims = std::vector<std::size_t>(nbDims);
-    std::vector<std::size_t> prodInputDims = std::vector<std::size_t>(nbDims + 1);
-    prodOutputDims[nbDims - 1] = outputDims[nbDims - 1];
-    prodInputDims[nbDims - 1] = inputDims[nbDims - 1];
-    prodInputDims[nbDims] = 1;
-    for (std::size_t i = 2; i <= nbDims; ++i) {
-        prodOutputDims[nbDims - i] = prodOutputDims[nbDims - i + 1] * outputDims[nbDims - i];
-        prodInputDims[nbDims - i] = prodInputDims[nbDims - i + 1] * inputDims[nbDims - i];
-    }
+        const std::size_t stride_pre = std::accumulate(dims.cbegin(), dims.cbegin() + axis, std::size_t(1), std::multiplies<std::size_t>());
+        const std::size_t stride_post = std::accumulate(dims.cbegin() + axis + 1, dims.cend(), std::size_t(1), std::multiplies<std::size_t>());
+        // The last pass writes straight into the output; earlier passes go through a scratch buffer.
+        const bool lastPass = (i + 1 == nbAxes);
+        if (!lastPass) { scratchDst.resize(stride_pre * sliceSize * stride_post); }
+        I* const dst = lastPass ? output : scratchDst.data();
 
-    std::size_t i = beginning;
-    std::size_t size = 0; // number of elements to copy
-    std::size_t offset = 0;
-    for (std::size_t j = 0; j < prodOutputDims[0];) {
-        ++size;
-        ++i;
-        ++j;
-        bool newChunk = false;
-        for (std::size_t idx = nbDims - 1; idx > 0; --idx) {
-            if (j % prodOutputDims[idx] == 0) {
-                i += substractedDims[idx] * prodInputDims[idx + 1];
-                newChunk = true;
+        for (std::size_t outer = 0; outer < stride_pre; ++outer) {
+            const I* const srcRow = src + outer * dims[axis] * stride_post;
+            I* const dstRow = dst + outer * sliceSize * stride_post;
+            for (std::size_t k = 0; k < sliceSize; ++k) {
+                const std::int64_t inner = start + static_cast<std::int64_t>(k) * step;
+                // Positions outside the axis are skipped rather than read from a neighbouring row.
+                if (inner >= 0 && inner < static_cast<std::int64_t>(dims[axis])) {
+                    std::copy_n(srcRow + static_cast<std::size_t>(inner) * stride_post, stride_post, dstRow + k * stride_post);
+                }
             }
         }
-
-        if (newChunk) {
-            op.getOutput(0)->getImpl()->copy(op.getInput(0)->getImpl()->rawPtr(beginning), size, offset);
-            beginning = i;
-            offset += size;
-            size = 0;
+        dims[axis] = sliceSize;
+        if (!lastPass) {
+            // The result of this pass becomes the input of the next one.
+            scratchSrc.swap(scratchDst);
+            src = scratchSrc.data();
         }
+    }
+    if (nbAxes == 0) {
+        // No axis to slice: the output is a plain copy of the input.
+        std::copy_n(src, std::accumulate(dims.cbegin(), dims.cend(), std::size_t(1), std::multiplies<std::size_t>()), output);
     }
+}
 
-    if (size > 0) {
-        op.getOutput(0)->getImpl()->copy(op.getInput(0)->getImpl()->rawPtr(beginning), size, offset);
+/// Runs the Slice kernel registered for the data type of input #0 on the operator's tensors.
+void Aidge::Slice_OpImpl::forward() {
+    const auto& sliceOp = dynamic_cast<const Slice_Op&>(mOp);
+    if (!sliceOp.getInput(0)) {
+        AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", sliceOp.Type);
     }
+    const std::size_t nbStarts = sliceOp.template getAttr<SliceAttr::Starts>().size();
+    AIDGE_ASSERT((nbStarts == sliceOp.template getAttr<SliceAttr::Ends>().size()) &&
+                 (nbStarts == sliceOp.template getAttr<SliceAttr::Axes>().size()),
+                 "start, end and axes arguments should be the same size.");
+
+    // Select the kernel registered for the data type of input #0, then run it on the host buffers.
+    const auto dataInput = std::static_pointer_cast<Tensor>(sliceOp.getRawInput(0));
+    const auto kernel = Registrar<SliceImplForward>::create({dataInput->dataType()});
+    kernel(sliceOp.getStaticAttributes(),
+           dataInput->dims(),
+           sliceOp.getInput(0)->getImpl()->hostPtr(),
+           sliceOp.getOutput(0)->getImpl()->hostPtr());
 }
 
 const std::string Aidge::Slice_Op::Type = "Slice";
@@ -127,7 +131,7 @@ bool Aidge::Slice_Op::forwardDims(bool /*allowDataDependency*/) {
 
             AIDGE_ASSERT((mInputs[1]->dataType() == mInputs[2]->dataType()) && (mInputs[1]->dataType() == mInputs[3]->dataType()), "Slice inputs must have the same dataType.");
 
-            this->template getAttr<SliceAttr::Starts>().clear(); // If both are provided input would override attrs
+            this->template getAttr<SliceAttr::Starts>().clear();
             this->template getAttr<SliceAttr::Starts>().reserve(getInput(1)->size());
             this->template getAttr<SliceAttr::Ends>().clear();
             this->template getAttr<SliceAttr::Ends>().reserve(getInput(1)->size());
@@ -179,11 +183,46 @@ bool Aidge::Slice_Op::forwardDims(bool /*allowDataDependency*/) {
                                 std::back_inserter(this->template getAttr<SliceAttr::Axes>()));                                
                     break;
                 default:
-                    AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Indices input DataType is not supported.", type());
+                    AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Input DataType is not supported.", type());
                     break;
             }
         }
 
+        // Fill Steps attr if empty
+        if(this->template getAttr<SliceAttr::Steps>().empty()) {
+            if (getInput(4) && !getInput(4)->empty()) {
+                // Steps come from input #4: its buffer is decoded with that input's own data type.
+                this->template getAttr<SliceAttr::Steps>().reserve(getInput(4)->size());
+                switch (mInputs[4]->dataType()) {
+                    case DataType::Float64:
+                        std::copy_n(static_cast<double*>(mInputs[4]->getImpl()->rawPtr()),
+                                    getInput(4)->size(),
+                                    std::back_inserter(this->template getAttr<SliceAttr::Steps>()));
+                        break;
+                    case DataType::Float32:
+                        std::copy_n(static_cast<float*>(mInputs[4]->getImpl()->rawPtr()),
+                                    getInput(4)->size(),
+                                    std::back_inserter(this->template getAttr<SliceAttr::Steps>()));
+                        break;
+                    case DataType::Int64:
+                        std::copy_n(static_cast<std::int64_t*>(mInputs[4]->getImpl()->rawPtr()),
+                                    getInput(4)->size(),
+                                    std::back_inserter(this->template getAttr<SliceAttr::Steps>()));
+                        break;
+                    case DataType::Int32:
+                        std::copy_n(static_cast<std::int32_t*>(mInputs[4]->getImpl()->rawPtr()),
+                                    getInput(4)->size(),
+                                    std::back_inserter(this->template getAttr<SliceAttr::Steps>()));
+                        break;
+                    default:
+                        AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Indices input DataType is not supported.", type());
+                        break;
+                }
+            } else {
+                // No Steps input: default to a unit step for each entry of Axes (the loop below reads one step per axis entry).
+                this->template getAttr<SliceAttr::Steps>() = std::vector<std::int64_t>(this->template getAttr<SliceAttr::Axes>().size(), 1);
+            }
+        }
         DimSize_t nbAxes = this->template getAttr<SliceAttr::Axes>().size();
         std::vector<DimSize_t> outDims = getInput(0)->dims();
         for (std::size_t i = 0; i < nbAxes; ++i) {
@@ -197,7 +236,10 @@ bool Aidge::Slice_Op::forwardDims(bool /*allowDataDependency*/) {
                             static_cast<DimSize_t>(this->template getAttr<SliceAttr::Ends>()[i]) :
                             static_cast<DimSize_t>(this->template getAttr<SliceAttr::Ends>()[i] + static_cast<DimSize_t>(getInput(0)->dims()[axis]));
 
-            const std::size_t sliceLength = end - start;
+            if(this->template getAttr<SliceAttr::Steps>()[i] == 0) {
+                AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Step must be a non-zero value", type());
+            }
+            const std::size_t sliceLength = (end - start) / static_cast<DimSize_t>(std::abs(this->template getAttr<SliceAttr::Steps>()[i]));
             // Check if slice length is valid
             if (sliceLength > getInput(0)->dims()[axis])
             {
diff --git a/src/recipes/HorizontalTiling.cpp b/src/recipes/HorizontalTiling.cpp
index dbd954d1b..e0ce58d31 100644
--- a/src/recipes/HorizontalTiling.cpp
+++ b/src/recipes/HorizontalTiling.cpp
@@ -106,7 +106,11 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::getConvHorizontalTiling(const std:
         std::vector<std::int8_t> usedDims(inputDimsEnd.size());
         std::iota(usedDims.begin(), usedDims.end(), static_cast<std::int8_t>(0));
 
-        auto slice = Slice(inputDimsStart, inputDimsEnd, usedDims, "Slice_" + std::to_string(currentFirstDims[axis]));
+        // Create Slice's Steps attribute: every dimension is sliced with a step of 1,
+        // i.e. no element between start and end is skipped.
+        const std::vector<std::int64_t> steps(inputDimsEnd.size(), 1);
+
+        auto slice = Slice(inputDimsStart, inputDimsEnd, usedDims, steps, "Slice_" + std::to_string(currentFirstDims[axis]));
         slice -> addChild(newNode, 0, 0);
         newNode -> addChild(concat, 0, i);
 
diff --git a/unit_tests/operator/Test_SliceImpl.cpp b/unit_tests/operator/Test_SliceImpl.cpp
index b0fc2bc9b..a9a20c3fd 100644
--- a/unit_tests/operator/Test_SliceImpl.cpp
+++ b/unit_tests/operator/Test_SliceImpl.cpp
@@ -69,7 +69,7 @@ TEST_CASE("[cpu/operator] Slice(forward)", "[Slice][CPU]") {
         mySlice->getOperator()->setDataType(DataType::Int32);
         mySlice->getOperator()->setBackend("cpu");
         mySlice->forward();
-        // mySlice->getOperator()->output(0).print();
+        op->getOutput(0)->print();
         REQUIRE(*(op->getOutput(0)) == *expectedOutput);
         REQUIRE(op->getOutput(0)->dims() == expectedOutput->dims());
         REQUIRE(op->getOutput(0)->dataType() == expectedOutput->dataType());
@@ -176,7 +176,7 @@ TEST_CASE("[cpu/operator] Slice(forward)", "[Slice][CPU]") {
         mySlice->getOperator()->setDataType(DataType::Int32);
         mySlice->getOperator()->setBackend("cpu");
         mySlice->forward();
-        // mySlice->getOperator()->output(0).print();
+        // op->getOutput(0)->print();
         REQUIRE(*(op->getOutput(0)) == *expectedOutput);
         REQUIRE(op->getOutput(0)->dims() == expectedOutput->dims());
         REQUIRE(op->getOutput(0)->dataType() == expectedOutput->dataType());
@@ -217,13 +217,13 @@ TEST_CASE("[cpu/operator] Slice(forward)", "[Slice][CPU]") {
             }
         });
 
-        std::shared_ptr<Node> mySlice = Slice({0,0,0,0}, {1,1,1,5}, {0,1,2,3});
+        std::shared_ptr<Node> mySlice = Slice({0,0,0,0}, {1,1,1,5}, {0,1,2,3}, {1,1,1,1});
         auto op = std::static_pointer_cast<OperatorTensor>(mySlice -> getOperator());
         mySlice->getOperator()->associateInput(0,input0);
         mySlice->getOperator()->setDataType(DataType::Int32);
         mySlice->getOperator()->setBackend("cpu");
         mySlice->forward();
-        // mySlice->getOperator()->output(0).print();
+        // op->getOutput(0)->print();
         REQUIRE(*(op->getOutput(0)) == *expectedOutput);
         REQUIRE(op->getOutput(0)->dims() == expectedOutput->dims());
         REQUIRE(op->getOutput(0)->dataType() == expectedOutput->dataType());
-- 
GitLab