From ade64013ff2f18e155e211ad6b2065c487547875 Mon Sep 17 00:00:00 2001
From: hrouis <houssemeddine.rouis92@gmail.com>
Date: Thu, 7 Mar 2024 14:33:40 +0100
Subject: [PATCH] switch slice attrs into inputs

---
 include/aidge/operator/Slice.hpp          | 45 ++-------------
 python_binding/operator/pybind_Gather.cpp |  2 +-
 python_binding/operator/pybind_Slice.cpp  |  3 +-
 src/operator/Slice.cpp                    | 69 +++++++++++++++--------
 src/recipes/HorizontalTiling.cpp          |  5 +-
 5 files changed, 58 insertions(+), 66 deletions(-)

diff --git a/include/aidge/operator/Slice.hpp b/include/aidge/operator/Slice.hpp
index 363c3c2b4..e71eaf40f 100644
--- a/include/aidge/operator/Slice.hpp
+++ b/include/aidge/operator/Slice.hpp
@@ -20,31 +20,16 @@
 #include "aidge/graph/Node.hpp"
 #include "aidge/operator/OperatorTensor.hpp"
 #include "aidge/utils/Registrar.hpp"
-#include "aidge/utils/StaticAttributes.hpp"
 #include "aidge/utils/Types.h"
 
 namespace Aidge {
-enum class SliceAttr { Starts, Ends, Axes };
-
 class Slice_Op
     : public OperatorTensor,
-      public Registrable<Slice_Op, std::string, std::shared_ptr<OperatorImpl>(const Slice_Op &)>,
-      public StaticAttributes<SliceAttr, std::vector<std::int64_t>, std::vector<std::int64_t>, std::vector<std::int64_t>> {
+      public Registrable<Slice_Op, std::string, std::unique_ptr<OperatorImpl>(const Slice_Op &)>{
 public:
     static const std::string Type;
 
-    Slice_Op() = delete;
-
-    using Attributes_ = StaticAttributes<SliceAttr, std::vector<std::int64_t>, std::vector<std::int64_t>, std::vector<std::int64_t>>;
-    template <SliceAttr e>
-    using attr = typename Attributes_::template attr<e>;
-
-    Slice_Op(const std::vector<std::int64_t>& starts, const std::vector<std::int64_t>&  ends, const std::vector<std::int64_t>& axes)
-        : OperatorTensor(Type, 1, 0, 1),
-          Attributes_(attr<SliceAttr::Starts>(starts),
-                      attr<SliceAttr::Ends>(ends),
-                      attr<SliceAttr::Axes>(axes))
-    {}
+    Slice_Op() : OperatorTensor(Type, 4, 0, 1) {}
 
     /**
      * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its
@@ -52,8 +37,7 @@ public:
      * @param op Operator to copy.
      */
     Slice_Op(const Slice_Op &op)
-        : OperatorTensor(op),
-          Attributes_(op)
+        : OperatorTensor(op)
     {
         if (op.mImpl){
             SET_IMPL_MACRO(Slice_Op, *this, op.mOutputs[0]->getImpl()->backend());
@@ -77,7 +61,7 @@ public:
     }
 
     static const std::vector<std::string> getInputsName(){
-        return {"data_input"};
+        return {"data_input", "starts", "ends", "axes"};
     }
     static const std::vector<std::string> getOutputsName(){
         return {"data_output"};
@@ -86,29 +70,12 @@ public:
 
 /**
  * @brief Exract a sub-Tensor from a bigger original Tensor.
- * @param starts Indexes for each dimension of the first element.
- * Can be a negative value. Negative values start their reference from the last index.
- * ``-1`` referes to the last index of a dimension.
- * @param ends Indexes for each dimension of the last element.
- * Can be a negative value. Negative values start their reference from the last index.
- * ``-1`` referes to the last index of a dimension.
- * @param axes Dimensions for which start/end indexes apply. Not specifying a dimensions
- * means the whole dimensions is extracted.
  * @param name Name of the Operator.
  * @return std::shared_ptr<Node> A Node containing the Operator.
  */
-inline std::shared_ptr<Node> Slice(const std::vector<std::int64_t> starts,
-                                   const std::vector<std::int64_t> ends,
-                                   const std::vector<std::int64_t> axes,
-                                   const std::string &name = "") {
-    // FIXME: properly handle default w&b initialization in every cases
-    return std::make_shared<Node>(std::make_shared<Slice_Op>(starts, ends, axes), name);
+inline std::shared_ptr<Node> Slice(const std::string &name = "") {
+    return std::make_shared<Node>(std::make_shared<Slice_Op>(), name);
 }
 }  // namespace Aidge
 
-namespace {
-template <>
-const char *const EnumStrings<Aidge::SliceAttr>::data[] = { "Starts", "Ends", "Axes" };
-}
-
 #endif /* AIDGE_CORE_OPERATOR_RELU_H_ */
diff --git a/python_binding/operator/pybind_Gather.cpp b/python_binding/operator/pybind_Gather.cpp
index 493c5c118..e999aa6ab 100644
--- a/python_binding/operator/pybind_Gather.cpp
+++ b/python_binding/operator/pybind_Gather.cpp
@@ -25,6 +25,6 @@ void init_Gather(py::module& m) {
     .def("attributes_name", &Gather_Op::staticGetAttrsName);
     declare_registrable<Gather_Op>(m, "GatherOp");
 
-    m.def("Gather", &Gather, py::arg("axis")=0, py::arg("name") = "");
+    m.def("Gather", &Gather, py::arg("axis") = 0, py::arg("name") = "");
 }
 }  // namespace Aidge
diff --git a/python_binding/operator/pybind_Slice.cpp b/python_binding/operator/pybind_Slice.cpp
index 3bb1b082c..45baa1d9a 100644
--- a/python_binding/operator/pybind_Slice.cpp
+++ b/python_binding/operator/pybind_Slice.cpp
@@ -22,6 +22,7 @@ void init_Slice(py::module& m) {
     .def("get_inputs_name", &Slice_Op::getInputsName)
     .def("get_outputs_name", &Slice_Op::getOutputsName);
     declare_registrable<Slice_Op>(m, "SliceOp");
-    m.def("Slice", &Slice, py::arg("starts"), py::arg("ends"), py::arg("axes"), py::arg("name") = "");
+
+    m.def("Slice", &Slice, py::arg("name") = "");
 }
 }  // namespace Aidge
diff --git a/src/operator/Slice.cpp b/src/operator/Slice.cpp
index 6d2670695..3062895b7 100644
--- a/src/operator/Slice.cpp
+++ b/src/operator/Slice.cpp
@@ -8,17 +8,16 @@
  * SPDX-License-Identifier: EPL-2.0
  *
  ********************************************************************************/
-#include "aidge/operator/Slice.hpp"
-#include "aidge/utils/Types.h"
-#include "aidge/utils/ErrorHandling.hpp"
 
 #include <cassert>
 #include <cstddef>
+#include <cstdint>
 #include <string>
 #include <utility>
 #include <vector>
 
 #include "aidge/backend/OperatorImpl.hpp"
+#include "aidge/operator/Slice.hpp"
 #include "aidge/utils/ErrorHandling.hpp"
 #include "aidge/utils/Types.h"
 
@@ -26,28 +25,50 @@ const std::string Aidge::Slice_Op::Type = "Slice";
 
 void Aidge::Slice_Op::computeOutputDims() {
     // check input have been associated
-    if (!getInput(0) || (getInput(0)->empty())) {
-        AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", type());
+    if (!getInput(0) || !getInput(1) || !getInput(2) || !getInput(3)) {
+        AIDGE_THROW_OR_ABORT(std::runtime_error, "Every input should be associated with a Tensor");
     }
 
-    const DimSize_t nbAxes = this->template getAttr<SliceAttr::Axes>().size();
-    std::vector<DimSize_t> outDims = getInput(0)->dims();
-    for (std::size_t i = 0; i < nbAxes; ++i) {
-        // For each slice operation get the params and cast them to size_t
-        const std::int64_t axis_ = this->template getAttr<SliceAttr::Axes>()[i];
-        const std::int64_t start_ = this->template getAttr<SliceAttr::Starts>()[i];
-        const std::int64_t end_ = this->template getAttr<SliceAttr::Ends>()[i];
-        const std::size_t axis = axis_ >= 0 ? static_cast<std::size_t>(axis_) : static_cast<std::size_t>(axis_) + getInput(0)->nbDims();
-        const std::size_t start = start_ >= 0 ? static_cast<std::size_t>(start_) : static_cast<std::size_t>(start_) + getInput(0)->dims()[axis];
-        const std::size_t end = end_ >= 0 ? static_cast<std::size_t>(end_) : static_cast<std::size_t>(end_) + getInput(0)->dims()[axis];
-
-        const std::size_t sliceLength = end - start + 1;
-        // Check if slice length is valid
-        if (sliceLength > getInput(0)->dims()[axis])
-        {
-            AIDGE_THROW_OR_ABORT(std::runtime_error, "ROI of Slice operator out of bounds");
+    if((!getInput(0)->empty()) && (!getInput(1)->empty()) && (!getInput(2)->empty()) && (!getInput(3)->empty()))
+    {
+        const auto starts = mInputs[1]->getImpl()->rawPtr();
+        const auto ends = mInputs[2]->getImpl()->rawPtr();
+        const auto axes = mInputs[3]->getImpl()->rawPtr();
+        DimSize_t nbAxes = mInputs[1]->size();
+        std::vector<DimSize_t> outDims = getInput(0)->dims();
+        for (std::size_t i = 0; i < nbAxes; ++i) {
+            // For each slice operation get the params and cast them to size_t
+            std::size_t axis, start, end; //TODO find a better way to cast "starts", "ends" and "axes"
+            if (mInputs[1]->dataType() == DataType::Float32 && mInputs[2]->dataType() == DataType::Float32 && mInputs[3]->dataType() == DataType::Float32)
+            {
+                const float* axes_ = static_cast<float*>(axes);
+                axis = axes_[i] >= 0 ? static_cast<std::size_t>(axes_[i]) : static_cast<std::size_t>(axes_[i]) + getInput(0)->nbDims();
+                const float* starts_ = static_cast<float*>(starts);
+                start = starts_[i] >= 0 ? static_cast<std::size_t>(starts_[i]) : static_cast<std::size_t>(starts_[i]) + getInput(0)->dims()[axis];
+                const float* ends_ = static_cast<float*>(ends);
+                end = ends_[i] >= 0 ? static_cast<std::size_t>(ends_[i]) : static_cast<std::size_t>(ends_[i]) + getInput(0)->dims()[ends_[i]];
+            }
+            else if(mInputs[1]->dataType() == DataType::Int32 && mInputs[2]->dataType() == DataType::Int32 && mInputs[3]->dataType() == DataType::Int32)
+            {
+                const std::int32_t* axes_ = static_cast<std::int32_t*>(axes);
+                axis = axes_[i] >= 0 ? static_cast<std::size_t>(axes_[i]) : static_cast<std::size_t>(axes_[i]) + getInput(0)->nbDims();
+                const std::int32_t* starts_ = static_cast<std::int32_t*>(starts);
+                start = starts_[i] >= 0 ? static_cast<std::size_t>(starts_[i]) : static_cast<std::size_t>(starts_[i]) + getInput(0)->dims()[axis];
+                const std::int32_t* ends_ = static_cast<std::int32_t*>(ends);
+                end = ends_[i] >= 0 ? static_cast<std::size_t>(ends_[i]) : static_cast<std::size_t>(ends_[i]) + getInput(0)->dims()[ends_[i]];
+            }
+            else
+            {
+                AIDGE_THROW_OR_ABORT(std::runtime_error, "Slice inputs type is not supported yet");
+            }
+            const std::size_t sliceLength = end - start;
+            // Check if slice length is valid
+            if (sliceLength > getInput(0)->dims()[axis])
+            {
+                AIDGE_THROW_OR_ABORT(std::runtime_error, "ROI of Slice operator out of bounds");
+            }
+            outDims[axis] = sliceLength;
         }
-        outDims[axis] = sliceLength;
+        mOutputs[0]->resize(outDims);
     }
-    mOutputs[0]->resize(outDims);
-}
+}
\ No newline at end of file
diff --git a/src/recipes/HorizontalTiling.cpp b/src/recipes/HorizontalTiling.cpp
index 8e27fea58..7e08457bc 100644
--- a/src/recipes/HorizontalTiling.cpp
+++ b/src/recipes/HorizontalTiling.cpp
@@ -93,7 +93,10 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::getConvHorizontalTiling(const std:
         }
         std::vector<std::int64_t> usedDims(inputDimsEnd.size());
         std::iota(usedDims.begin(), usedDims.end(), static_cast<std::int64_t>(0));
-        auto slice = Slice(inputDimsStart, inputDimsEnd, usedDims, "Slice_" + std::to_string(currentFirstDims[axis]));
+        Tensor(std::vector<std::size_t>({inputDimsStart.size()})); 
+        // TODO create producer nodes for the attributes
+        // auto slice = Slice(inputDimsStart, inputDimsEnd, usedDims, "Slice_" + std::to_string(currentFirstDims[axis]));
+        auto slice = Slice("Slice_" + std::to_string(currentFirstDims[axis]));
         slice -> addChild(newNode, 0, 0);
         newNode -> addChild(concat, 0, i);
 
-- 
GitLab