From e1dbf501bf7343d17a4c517201ee80212bf0999e Mon Sep 17 00:00:00 2001
From: NAUD Maxence <maxence.naud@cea.fr>
Date: Wed, 22 Nov 2023 16:54:57 +0000
Subject: [PATCH] [Fix] Add, Concat and Slice implementations
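
Remove the per-implementation mOp, mNbConsumedData and mNbProducedData
members from AddImpl_cpu, ConcatImpl_cpu and the SliceImpl_cpu
specializations, and delegate this bookkeeping to the OperatorImpl
base-class constructor instead, e.g.:

    AddImpl_cpu(const Add_Op& op) : OperatorImpl(op) {}

Since the implementations no longer hold a typed reference to their
operator, SliceImpl_cpu now reaches its tensors through getRawInput() /
getRawOutput() with an explicit cast to Tensor, and downcasts mOp to
Slice_Op to read its static attributes.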

---
 .../aidge/backend/cpu/operator/AddImpl.hpp    |  7 +--
 .../aidge/backend/cpu/operator/ConcatImpl.hpp |  7 +--
 .../aidge/backend/cpu/operator/SliceImpl.hpp  | 57 ++++++-------------
 3 files changed, 18 insertions(+), 53 deletions(-)

diff --git a/include/aidge/backend/cpu/operator/AddImpl.hpp b/include/aidge/backend/cpu/operator/AddImpl.hpp
index 5ec33e97..806bbb02 100644
--- a/include/aidge/backend/cpu/operator/AddImpl.hpp
+++ b/include/aidge/backend/cpu/operator/AddImpl.hpp
@@ -31,13 +31,8 @@ class AddImplBackward_cpu
 
 
 class AddImpl_cpu : public OperatorImpl {
-private:
-    const Add_Op& mOp;
-    std::vector<NbElts_t> mNbConsumedData;
-    std::array<NbElts_t, 1> mNbProducedData = {};
-
 public:
-    AddImpl_cpu(const Add_Op& op) : mOp(op), mNbConsumedData(std::vector<NbElts_t>(op.nbInputs())) {}
+    AddImpl_cpu(const Add_Op& op) : OperatorImpl(op) {}
 
     static std::unique_ptr<AddImpl_cpu> create(const Add_Op& op) {
         return std::make_unique<AddImpl_cpu>(op);
diff --git a/include/aidge/backend/cpu/operator/ConcatImpl.hpp b/include/aidge/backend/cpu/operator/ConcatImpl.hpp
index 880a2e66..a5e0c56e 100644
--- a/include/aidge/backend/cpu/operator/ConcatImpl.hpp
+++ b/include/aidge/backend/cpu/operator/ConcatImpl.hpp
@@ -39,13 +39,8 @@ class ConcatImplBackward_cpu
 
 
 class ConcatImpl_cpu : public OperatorImpl {
-private:
-    const Concat_Op& mOp;
-    std::vector<NbElts_t> mNbConsumedData;
-    std::array<NbElts_t, 1> mNbProducedData = {};
-
 public:
-    ConcatImpl_cpu(const Concat_Op& op) : mOp(op), mNbConsumedData(std::vector<NbElts_t>(op.nbInputs())) {}
+    ConcatImpl_cpu(const Concat_Op& op) : OperatorImpl(op) {}
 
     static std::unique_ptr<ConcatImpl_cpu> create(const Concat_Op& op) {
         return std::make_unique<ConcatImpl_cpu>(op);
diff --git a/include/aidge/backend/cpu/operator/SliceImpl.hpp b/include/aidge/backend/cpu/operator/SliceImpl.hpp
index dddab386..69dd88bb 100644
--- a/include/aidge/backend/cpu/operator/SliceImpl.hpp
+++ b/include/aidge/backend/cpu/operator/SliceImpl.hpp
@@ -43,13 +43,8 @@ class SliceImplBackward_cpu
 
 template <DimIdx_t DIM>
 class SliceImpl_cpu : public OperatorImpl {
-   private:
-    const Slice_Op<DIM>& mOp;
-    std::array<NbElts_t, 1> mNbConsumedData;
-    std::array<NbElts_t, 1> mNbProducedData;
-
    public:
-    SliceImpl_cpu(const Slice_Op<DIM>& op) : mOp(op), mNbConsumedData({0}), mNbProducedData({0}) {}
+    SliceImpl_cpu(const Slice_Op<DIM>& op) : OperatorImpl(op) {}
 
     static std::unique_ptr<SliceImpl_cpu<DIM>> create(const Slice_Op<DIM>& op) {
         return std::make_unique<SliceImpl_cpu<DIM>>(op);
@@ -57,10 +52,10 @@ class SliceImpl_cpu : public OperatorImpl {
 
    public:
     NbElts_t getNbRequiredData(const IOIndex_t /*inputIdx*/) const override final {
-        assert(mOp.getInput(0) && "requires valid input");
+        assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "requires valid input");
 
         // Requires the whole tensors
-        const auto& inputDims = mOp.getInput(0)->dims();
+        const auto& inputDims = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dims();
 
         return std::accumulate(inputDims.begin(), inputDims.end(), static_cast<NbElts_t>(1),
                                std::multiplies<NbElts_t>());
@@ -70,7 +65,7 @@ class SliceImpl_cpu : public OperatorImpl {
                                const std::vector<DimSize_t>& inputsSize) const override final {
         (void)outputIdx;
         (void)inputsSize;
-        const auto& outputDims = mOp.getOutput(0)->dims();
+        const auto& outputDims = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dims();
         return std::accumulate(outputDims.begin(), outputDims.end(), static_cast<NbElts_t>(1),
                                std::multiplies<NbElts_t>());
     }
@@ -89,17 +84,17 @@ class SliceImpl_cpu : public OperatorImpl {
 
     void forward() {
         // FIXME: uncomment the following code once memory handling will work
-        assert(mOp.getInput(0) && "missing input #0");
+        assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input #0");
 
         // Find the correct kernel type
         auto kernelFunc = Registrar<SliceImplForward_cpu<DIM>>::create(
-                {mOp.getInput(0)->dataType()});
+                {std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType()});
 
         // Call kernel
-        kernelFunc(mOp.getInput(0)->template dims<DIM>(),
-                   std::get<1>(mOp.getStaticAttributes()),
-                   mOp.getInput(0)->getImpl()->rawPtr(),
-                   mOp.getOutput(0)->getImpl()->rawPtr()
+        kernelFunc(std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->template dims<DIM>(),
+                   std::get<1>(static_cast<const Slice_Op<DIM>&>(mOp).getStaticAttributes()),
+                   std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->getImpl()->rawPtr(),
+                   std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr()
                 );
 
         // each input is consumed by the minimum amount for a forward pass
@@ -115,19 +110,14 @@ class SliceImpl_cpu : public OperatorImpl {
 
 template <>
 class SliceImpl_cpu<1> : public OperatorImpl {
-   private:
-    const Slice_Op<1>& mOp;
-    std::array<NbElts_t, 1> mNbConsumedData;
-    std::array<NbElts_t, 1> mNbProducedData;
-
-   public:
-    SliceImpl_cpu(const Slice_Op<1>& op) : mOp(op), mNbConsumedData({0}), mNbProducedData({0}) {}
+   public:
+    SliceImpl_cpu(const Slice_Op<1>& op) : OperatorImpl(op) {}
 
     static std::unique_ptr<SliceImpl_cpu<1>> create(const Slice_Op<1>& op) {
         return std::make_unique<SliceImpl_cpu<1>>(op);
     }
 
-   public:
+   public:
     NbElts_t getNbRequiredData(const IOIndex_t /*inputIdx*/) const override final;
     NbElts_t getNbRequiredProtected(const IOIndex_t /*inputIdx*/) const override final;
     NbElts_t getRequiredMemory(const IOIndex_t outputIdx,
@@ -144,13 +134,8 @@ class SliceImpl_cpu<1> : public OperatorImpl {
 
 template <>
 class SliceImpl_cpu<2> : public OperatorImpl {
-   private:
-    const Slice_Op<2>& mOp;
-    std::array<NbElts_t, 1> mNbConsumedData;
-    std::array<NbElts_t, 1> mNbProducedData;
-
    public:
-    SliceImpl_cpu(const Slice_Op<2>& op) : mOp(op), mNbConsumedData({0}), mNbProducedData({0}) {}
+    SliceImpl_cpu(const Slice_Op<2>& op) : OperatorImpl(op) {}
 
     static std::unique_ptr<SliceImpl_cpu<2>> create(const Slice_Op<2>& op) {
         return std::make_unique<SliceImpl_cpu<2>>(op);
@@ -173,13 +158,8 @@ class SliceImpl_cpu<2> : public OperatorImpl {
 
 template <>
 class SliceImpl_cpu<3> : public OperatorImpl {
-   private:
-    const Slice_Op<3>& mOp;
-    std::array<NbElts_t, 1> mNbConsumedData;
-    std::array<NbElts_t, 1> mNbProducedData;
-
    public:
-    SliceImpl_cpu(const Slice_Op<3>& op) : mOp(op), mNbConsumedData({0}), mNbProducedData({0}) {}
+    SliceImpl_cpu(const Slice_Op<3>& op) : OperatorImpl(op) {}
 
     static std::unique_ptr<SliceImpl_cpu<3>> create(const Slice_Op<3>& op) {
         return std::make_unique<SliceImpl_cpu<3>>(op);
@@ -202,13 +182,8 @@ class SliceImpl_cpu<3> : public OperatorImpl {
 
 template <>
 class SliceImpl_cpu<4> : public OperatorImpl {
-   private:
-    const Slice_Op<4>& mOp;
-    std::array<NbElts_t, 1> mNbConsumedData;
-    std::array<NbElts_t, 1> mNbProducedData;
-
    public:
-    SliceImpl_cpu(const Slice_Op<4>& op) : mOp(op), mNbConsumedData({0}), mNbProducedData({0}) {}
+    SliceImpl_cpu(const Slice_Op<4>& op) : OperatorImpl(op) {}
 
     static std::unique_ptr<SliceImpl_cpu<4>> create(const Slice_Op<4>& op) {
         return std::make_unique<SliceImpl_cpu<4>>(op);
-- 
GitLab