From fbd368946f2a2cf86c832d04991ca1f2b830a36e Mon Sep 17 00:00:00 2001
From: hrouis <houssemeddine.rouis92@gmail.com>
Date: Thu, 30 Nov 2023 16:34:07 +0100
Subject: [PATCH] move computeDims to cpp

---
 include/aidge/operator/Gather.hpp    | 12 +------
 include/aidge/operator/Reshape.hpp   | 14 +-------
 include/aidge/operator/Slice.hpp     | 18 +---------
 include/aidge/operator/Transpose.hpp |  1 -
 src/operator/Gather.cpp              | 38 +++++++++++++++++++++
 src/operator/Reshape.cpp             | 47 ++++++++++++++++++++++++++
 src/operator/Slice.cpp               | 49 ++++++++++++++++++++++++++++
 7 files changed, 137 insertions(+), 42 deletions(-)
 create mode 100644 src/operator/Gather.cpp
 create mode 100644 src/operator/Reshape.cpp
 create mode 100644 src/operator/Slice.cpp

diff --git a/include/aidge/operator/Gather.hpp b/include/aidge/operator/Gather.hpp
index 6579331ca..ba7d745fa 100644
--- a/include/aidge/operator/Gather.hpp
+++ b/include/aidge/operator/Gather.hpp
@@ -68,17 +68,7 @@ public:
         return std::make_shared<Gather_Op>(*this);
     }
 
-    void computeOutputDims() override final {
-        if (!mInputs.empty() && !mInputs[0]->empty() && mInputs[1]->nbDims()==2)
-        {
-            std::vector<DimSize_t> outDims = mInputs[0]->dims();
-            std::vector<DimSize_t> indexesDims = mInputs[1]->dims();
-            int axisIdx = this->template getAttr<GatherAttr::Axis>()>=0?this->template getAttr<GatherAttr::Axis>():this->template getAttr<GatherAttr::Axis>()+outDims.size();
-            outDims.erase(outDims.begin() + static_cast<std::size_t>(axisIdx));
-            outDims.insert(outDims.begin() + static_cast<std::size_t>(axisIdx), indexesDims.begin(),indexesDims.end());
-            mOutputs[0]->resize(outDims);
-        }
-    }
+    void computeOutputDims() override final;
 
     void setBackend(const std::string& name) override {
         mImpl = Registrar<Gather_Op>::create(name)(*this);
diff --git a/include/aidge/operator/Reshape.hpp b/include/aidge/operator/Reshape.hpp
index 81cc7cd19..2d9372c4e 100644
--- a/include/aidge/operator/Reshape.hpp
+++ b/include/aidge/operator/Reshape.hpp
@@ -52,19 +52,7 @@ public:
         return std::make_shared<Reshape_Op>(*this);
     }
 
-    void computeOutputDims() override final {
-        if (!mInputs[0]->empty() && !mInputs[1]->empty())
-        {
-            std::vector<DimSize_t> outDims;
-            int* shapeElem = static_cast<int*>(mInputs[1]->getImpl()->rawPtr());
-            for(std::size_t i=0; i<mInputs[1]->size(); ++i)
-            {
-                outDims.push_back(shapeElem[i]);
-            }
-            mOutputs[0]->resize(outDims);
-        }
-    }
-
+    void computeOutputDims() override final;
 
     void setBackend(const std::string& name) override {
         mImpl = Registrar<Reshape_Op>::create(name)(*this);
diff --git a/include/aidge/operator/Slice.hpp b/include/aidge/operator/Slice.hpp
index d1e000723..e98714b02 100644
--- a/include/aidge/operator/Slice.hpp
+++ b/include/aidge/operator/Slice.hpp
@@ -50,23 +50,7 @@ public:
      */
     std::shared_ptr<Operator> clone() const override { return std::make_shared<Slice_Op>(*this); }
 
-    void computeOutputDims() override final {
-        if (!mInputs[0]->empty() && !mInputs[1]->empty() && !mInputs[2]->empty()&& !mInputs[3]->empty())
-        {
-            DimSize_t nbAxes = mInputs[1]->dims()[0];
-            const int* axes = static_cast<const int*>(mInputs[1]->getImpl()->rawPtr());
-            const int* starts = static_cast<const int*>(mInputs[2]->getImpl()->rawPtr());
-            const int* ends = static_cast<const int*>(mInputs[3]->getImpl()->rawPtr());
-            std::vector<DimSize_t> outDims = mInputs[0]->dims();
-            for(std::size_t i=0; i<nbAxes;++i)
-            {
-                std::size_t axis = axes[i]>=0?axes[i]:axes[i]+mInputs[0]->nbDims();
-                outDims[axis] = ends[i] - starts[i] + 1;
-            }
-            mOutputs[0]->resize(outDims);
-        }
-    }
-
+    void computeOutputDims() override final;
 
     void setBackend(const std::string& name) override {
         mImpl = Registrar<Slice_Op>::create(name)(*this);
diff --git a/include/aidge/operator/Transpose.hpp b/include/aidge/operator/Transpose.hpp
index 6248dcfc5..8bf5f17ab 100644
--- a/include/aidge/operator/Transpose.hpp
+++ b/include/aidge/operator/Transpose.hpp
@@ -68,7 +68,6 @@ class Transpose_Op : public OperatorTensor,
     }
 
     void computeOutputDims() override final {
-        printf("************** nbIn %d \n", this->nbInputs());
         if (!getInput(0)->empty()) {
             auto attr = (this)->getStaticAttributes();
             const std::array<DimSize_t, DIM>& outDimsOrder = static_cast<const std::array<DimSize_t, DIM>&>(std::get<0>(attr));
diff --git a/src/operator/Gather.cpp b/src/operator/Gather.cpp
new file mode 100644
index 000000000..26a334bb0
--- /dev/null
+++ b/src/operator/Gather.cpp
@@ -0,0 +1,38 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <cassert>
+#include <cstddef>
+#include <vector>
+#include <utility>
+
+#include "aidge/backend/OperatorImpl.hpp"
+#include "aidge/operator/Gather.hpp"
+#include "aidge/utils/Types.h"
+#include "aidge/utils/ErrorHandling.hpp"
+
+void Aidge::Gather_Op::computeOutputDims() {
+    // check inputs have been associated
+    if (!getInput(0) || !getInput(1)) {
+        AIDGE_THROW_OR_ABORT(std::runtime_error, "At least one input was not connected");
+    }
+
+    if (getInput(1)->nbDims()!=2){
+        AIDGE_THROW_OR_ABORT(std::runtime_error, "Indices input must be a 2D Tensor");
+    }
+
+    std::vector<DimSize_t> outDims = getInput(0)->dims();
+    std::vector<DimSize_t> indexesDims = getInput(1)->dims();
+    int axisIdx = this->template getAttr<GatherAttr::Axis>()>=0?this->template getAttr<GatherAttr::Axis>():this->template getAttr<GatherAttr::Axis>()+outDims.size();
+    outDims.erase(outDims.begin() + static_cast<std::size_t>(axisIdx));
+    outDims.insert(outDims.begin() + static_cast<std::size_t>(axisIdx), indexesDims.begin(),indexesDims.end());
+    mOutputs[0]->resize(outDims);
+}
\ No newline at end of file
diff --git a/src/operator/Reshape.cpp b/src/operator/Reshape.cpp
new file mode 100644
index 000000000..f32e8b5af
--- /dev/null
+++ b/src/operator/Reshape.cpp
@@ -0,0 +1,47 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <cassert>
+#include <cstddef>
+#include <vector>
+#include <utility>
+
+#include "aidge/backend/OperatorImpl.hpp"
+#include "aidge/operator/Reshape.hpp"
+#include "aidge/utils/Types.h"
+#include "aidge/utils/ErrorHandling.hpp"
+
+void Aidge::Reshape_Op::computeOutputDims() {
+    // check inputs have been associated
+    if (!getInput(0) || !getInput(1)) {
+        AIDGE_THROW_OR_ABORT(std::runtime_error, "At least one input was not connected");
+    }
+
+    std::vector<DimSize_t> outDims;
+    std::size_t outSize = 1;
+    int* shapeElem = static_cast<int*>(getInput(1)->getImpl()->rawPtr());
+    for(std::size_t i=0; i<mInputs[1]->size(); ++i)
+    {
+        int dimSize = shapeElem[i];
+        if (dimSize < 1)
+        {
+            AIDGE_THROW_OR_ABORT(std::runtime_error, "Output shape must give the same size as input");
+        }
+        outDims.push_back(dimSize);
+        outSize *= dimSize;
+    }
+
+    if (getInput(0)->size() != outSize){
+        AIDGE_THROW_OR_ABORT(std::runtime_error, "Output shape must give the same size as input");
+    }
+
+    mOutputs[0]->resize(outDims);
+}
\ No newline at end of file
diff --git a/src/operator/Slice.cpp b/src/operator/Slice.cpp
new file mode 100644
index 000000000..0495f96c5
--- /dev/null
+++ b/src/operator/Slice.cpp
@@ -0,0 +1,49 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <cassert>
+#include <cstddef>
+#include <vector>
+#include <utility>
+
+#include "aidge/backend/OperatorImpl.hpp"
+#include "aidge/operator/Slice.hpp"
+#include "aidge/utils/Types.h"
+#include "aidge/utils/ErrorHandling.hpp"
+
+void Aidge::Slice_Op::computeOutputDims() {
+    // check inputs have been associated
+    if (!getInput(0) || !getInput(1) || !getInput(2) || !getInput(3)) {
+        AIDGE_THROW_OR_ABORT(std::runtime_error, "At least one input was not connected");
+    }
+
+    if (getInput(1)->nbDims()!=1){
+        AIDGE_THROW_OR_ABORT(std::runtime_error, "Indices input must be a 1D Tensor");
+    }
+    if (getInput(2)->nbDims()!=1){
+        AIDGE_THROW_OR_ABORT(std::runtime_error, "Starts input must be a 1D Tensor");
+    }
+    if (getInput(3)->nbDims()!=1){
+        AIDGE_THROW_OR_ABORT(std::runtime_error, "Ends input must be a 1D Tensor");
+    }
+
+    DimSize_t nbAxes = getInput(1)->dims()[0];
+    const int* axes = static_cast<const int*>(getInput(1)->getImpl()->rawPtr());
+    const int* starts = static_cast<const int*>(getInput(2)->getImpl()->rawPtr());
+    const int* ends = static_cast<const int*>(getInput(3)->getImpl()->rawPtr());
+    std::vector<DimSize_t> outDims = getInput(0)->dims();
+    for(std::size_t i=0; i<nbAxes;++i)
+    {
+        std::size_t axis = axes[i]>=0?axes[i]:axes[i]+getInput(0)->nbDims();
+        outDims[axis] = ends[i] - starts[i] + 1;
+    }
+    mOutputs[0]->resize(outDims);
+}
\ No newline at end of file
-- 
GitLab