From 697c85330d9e4c85c1036bde8eb60c4dd3c88f5a Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Sun, 26 May 2024 17:00:06 +0200
Subject: [PATCH] Do not resize output in forward() as it is done in
 forwardDims()

---
 src/operator/Slice.cpp | 27 +++++++++------------------
 1 file changed, 9 insertions(+), 18 deletions(-)

diff --git a/src/operator/Slice.cpp b/src/operator/Slice.cpp
index 8a9f5cbbf..e3ac4e774 100644
--- a/src/operator/Slice.cpp
+++ b/src/operator/Slice.cpp
@@ -36,33 +36,24 @@ void Aidge::Slice_OpImpl::forward() {
                  (op.template getAttr<SliceAttr::Starts>().size() == op.template getAttr<SliceAttr::Axes>().size()),
                  "start, end and axes arguments should be the same size.");
 
-    const std::size_t nbDims = op.getInput(0)->nbDims();
-
-    const std::vector<std::size_t>& inputDims = op.getInput(0)->dims();
-    auto outputDims = op.getInput(0)->dims();
+    const auto nbDims = op.getInput(0)->nbDims();
+    const auto& inputDims = op.getInput(0)->dims();
+    const auto& outputDims = op.getOutput(0)->dims();
 
     // compute index of the output's first element
-    // compute output dimension at the same time (may change between two forward calls)
     std::size_t beginning = 0;
     const std::size_t nbAxes = op.template getAttr<SliceAttr::Axes>().size();
     for (std::size_t i = 0; i < nbAxes; ++i) {
         // For each slice operation get the params and cast them to size_t
-        DimIdx_t axis = op.template getAttr<SliceAttr::Axes>()[i] >= 0 ?
+        const DimIdx_t axis = op.template getAttr<SliceAttr::Axes>()[i] >= 0 ?
                             static_cast<DimIdx_t>(op.template getAttr<SliceAttr::Axes>()[i]) :
                             static_cast<DimIdx_t>(op.template getAttr<SliceAttr::Axes>()[i] + static_cast<DimIdx_t>(inputDims.size()));
-        DimSize_t start = op.template getAttr<SliceAttr::Starts>()[i] >= 0 ?
+        const DimSize_t start = op.template getAttr<SliceAttr::Starts>()[i] >= 0 ?
                             static_cast<DimSize_t>(op.template getAttr<SliceAttr::Starts>()[i]) :
                             static_cast<DimSize_t>(op.template getAttr<SliceAttr::Starts>()[i] + static_cast<DimSize_t>(inputDims[axis]));
-        DimSize_t end = op.template getAttr<SliceAttr::Ends>()[i] >= 0 ?
-                        static_cast<DimSize_t>(op.template getAttr<SliceAttr::Ends>()[i]) :
-                        static_cast<DimSize_t>(op.template getAttr<SliceAttr::Ends>()[i] + static_cast<DimSize_t>(inputDims[axis]));
         const std::size_t stridePostAxis = std::accumulate(inputDims.cbegin()+axis+1, inputDims.cend(), std::size_t(1), std::multiplies<std::size_t>());
         beginning += start * stridePostAxis;
-        const std::size_t sliceLength = end - start;
-        outputDims[axis] = sliceLength;
     }
-    op.getOutput(0)->resize(outputDims);
-
 
     // for inputDims = {4,5,5,3} & outputDims = {3,2,2,1}: substractDims = {1,5,5,3}
     std::vector<std::size_t> substractedDims = std::vector<std::size_t>(nbDims);
@@ -195,16 +186,16 @@ bool Aidge::Slice_Op::forwardDims(bool allowDataDependency) {
 
     AIDGE_ASSERT(!this->template getAttr<SliceAttr::Axes>().empty(), "Missing input#3 or Axes attribute");
 
-    DimSize_t nbAxes = this->template getAttr<SliceAttr::Axes>().size();
+    const DimSize_t nbAxes = this->template getAttr<SliceAttr::Axes>().size();
     std::vector<DimSize_t> outDims = getInput(0)->dims();
     for (std::size_t i = 0; i < nbAxes; ++i) {
-        DimIdx_t axis = this->template getAttr<SliceAttr::Axes>()[i] >= 0 ?
+        const DimIdx_t axis = this->template getAttr<SliceAttr::Axes>()[i] >= 0 ?
                         static_cast<DimIdx_t>(this->template getAttr<SliceAttr::Axes>()[i]) :
                         static_cast<DimIdx_t>(this->template getAttr<SliceAttr::Axes>()[i] + static_cast<DimIdx_t>(getInput(0)->nbDims()));
-        DimSize_t start = this->template getAttr<SliceAttr::Starts>()[i] >= 0 ?
+        const DimSize_t start = this->template getAttr<SliceAttr::Starts>()[i] >= 0 ?
                             static_cast<DimSize_t>(this->template getAttr<SliceAttr::Starts>()[i]) :
                             static_cast<DimSize_t>(this->template getAttr<SliceAttr::Starts>()[i] + static_cast<DimSize_t>(getInput(0)->dims()[axis]));
-        DimSize_t end = this->template getAttr<SliceAttr::Ends>()[i] >= 0 ?
+        const DimSize_t end = this->template getAttr<SliceAttr::Ends>()[i] >= 0 ?
                         static_cast<DimSize_t>(this->template getAttr<SliceAttr::Ends>()[i]) :
                         static_cast<DimSize_t>(this->template getAttr<SliceAttr::Ends>()[i] + static_cast<DimSize_t>(getInput(0)->dims()[axis]));
 
-- 
GitLab