From 19f7875c4c486badf9aa9194eff48a27d44b6199 Mon Sep 17 00:00:00 2001
From: hrouis <houssemeddine.rouis92@gmail.com>
Date: Tue, 21 Nov 2023 15:25:28 +0100
Subject: [PATCH] fix Slice outputDims

---
 include/aidge/operator/Slice.hpp | 32 ++++++++++----------------------
 1 file changed, 10 insertions(+), 22 deletions(-)

diff --git a/include/aidge/operator/Slice.hpp b/include/aidge/operator/Slice.hpp
index ccdb66a1e..4a99045e2 100644
--- a/include/aidge/operator/Slice.hpp
+++ b/include/aidge/operator/Slice.hpp
@@ -67,7 +67,7 @@ public:
     }
 
     void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final {
-        assert(inputIdx < 4 && "operator Slice supports only 4 inputs");
+        assert(inputIdx < 4 && "Slice operator supports only 4 inputs");
         assert(strcmp(data->type(), Tensor::Type)==0 && "input data must be of Tensor type");
         mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data);
     }
@@ -75,27 +75,15 @@ public:
     void computeOutputDims() override final {
         if (!mInputs[0]->empty() && !mInputs[1]->empty() && !mInputs[2]->empty()&& !mInputs[3]->empty())
         {
+            DimSize_t nbAxes = mInputs[1]->dims()[0];
             const int* axes = static_cast<const int*>(mInputs[1]->getImpl()->rawPtr());
             const int* starts = static_cast<const int*>(mInputs[2]->getImpl()->rawPtr());
             const int* ends = static_cast<const int*>(mInputs[3]->getImpl()->rawPtr());
-            DimSize_t nbAxes = mInputs[1]->dims()[0];
-            std::vector<DimSize_t> outDims;
-            for(std::size_t i=0; i<mInputs[0]->dims().size();++i)
+            std::vector<DimSize_t> outDims = mInputs[0]->dims();
+            for(std::size_t i=0; i<nbAxes;++i)
             {
-
-                const int* idxPos = std::find(axes, axes + nbAxes, static_cast<int>(i));
-                if(idxPos != (axes + nbAxes))
-                {
-                    // TODO make sure all indxes are positive before this
-                    size_t idx = static_cast<size_t>(*idxPos);
-                    int startVal = starts[idx];
-                    int endVal = ends[idx];
-                    outDims.push_back(endVal - startVal);
-                }
-                else
-                {
-                    outDims.push_back(mInputs[0]->dims()[i]);
-                }
+                std::size_t axis = axes[i]>=0?axes[i]:axes[i]+mInputs[0]->nbDims();
+                outDims[axis] = ends[i] - starts[i] + 1;
             }
             mOutput->resize(outDims);
         }
@@ -114,22 +102,22 @@ public:
 
 
     inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final {
-        assert((inputIdx < 4) && "Slice Operator has 4 inputs");
+        assert((inputIdx < 4) && "Slice operator has 4 inputs");
         return mInputs[inputIdx];
     }
     inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final {
-        assert((outputIdx == 0) && "Slice Operator has only 1 output");
+        assert((outputIdx == 0) && "Slice operator has only 1 output");
         (void) outputIdx; // avoid unused warning
         return mOutput;
     }
 
 
     std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final {
-        assert(inputIdx < 4 && "operator supports only 4 inputs");
+        assert(inputIdx < 4 && "Slice operator supports only 4 inputs");
         return std::static_pointer_cast<Data>(mInputs[inputIdx]);
     }
     std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final {
-        assert(outputIdx == 0 && "operator supports only 1 output");
+        assert(outputIdx == 0 && "Slice operator supports only 1 output");
         (void) outputIdx; // avoid unused warning
         return std::static_pointer_cast<Data>(mOutput);
     }
-- 
GitLab