From 1520a9a083afd2b5bf2d33025d4e16d4cbbce53a Mon Sep 17 00:00:00 2001
From: NAUD Maxence <maxence.naud@cea.fr>
Date: Fri, 10 Nov 2023 09:31:02 +0000
Subject: [PATCH] Propagate the new class hierarchy consequences to Conv and
 AvgPooling

---
 include/aidge/operator/AvgPooling.hpp | 79 ++++-----------------------
 include/aidge/operator/Conv.hpp       | 69 +++--------------------
 src/backend/OperatorImpl.cpp          | 12 ++--
 3 files changed, 25 insertions(+), 135 deletions(-)

diff --git a/include/aidge/operator/AvgPooling.hpp b/include/aidge/operator/AvgPooling.hpp
index dfcd0d5b3..490782331 100644
--- a/include/aidge/operator/AvgPooling.hpp
+++ b/include/aidge/operator/AvgPooling.hpp
@@ -19,7 +19,7 @@
 
 #include "aidge/data/Tensor.hpp"
 #include "aidge/graph/Node.hpp"
-#include "aidge/operator/Operator.hpp"
+#include "aidge/operator/OperatorTensor.hpp"
 #include "aidge/operator/Producer.hpp"
 #include "aidge/utils/StaticAttributes.hpp"
 #include "aidge/utils/Registrar.hpp"
@@ -29,15 +29,11 @@ namespace Aidge {
 enum class AvgPoolingAttr { StrideDims, KernelDims };
 
 template <DimIdx_t DIM>
-class AvgPooling_Op : public Operator,
+class AvgPooling_Op : public OperatorTensor,
                 public Registrable<AvgPooling_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const AvgPooling_Op<DIM> &)>,
                 public StaticAttributes<AvgPoolingAttr,
                                        std::array<DimSize_t, DIM>,
                                        std::array<DimSize_t, DIM>> {
-private:
-    // FIXME: change accessibility
-    std::shared_ptr<Tensor> mInput = std::make_shared<Tensor>();
-    const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>();
 
 public:
     static constexpr const char *Type = "AvgPooling";
@@ -52,10 +48,10 @@ public:
 
     constexpr AvgPooling_Op(const std::array<DimSize_t, DIM> &kernel_dims,
                             const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1))
-        : Operator(Type),
+        : OperatorTensor(Type, 1, 0, 1),
           Attributes_(attr<AvgPoolingAttr::StrideDims>(stride_dims),
                       attr<AvgPoolingAttr::KernelDims>(kernel_dims)) {
-        setDatatype(DataType::Float32);
+        setDataType(DataType::Float32);
     }
 
     /**
@@ -63,12 +59,12 @@ public:
      * @param op Operator to copy.
      */
     AvgPooling_Op(const AvgPooling_Op<DIM>& op)
-        : Operator(Type),
+        : OperatorTensor(Type, 1, 0, 1),
           Attributes_(op),
           mOutput(std::make_shared<Tensor>(*op.mOutput))
     {
         // cpy-ctor
-        setDatatype(op.mOutput->dataType());
+        setDataType(op.mOutput->dataType());
         mImpl = op.mImpl ? Registrar<AvgPooling_Op<DIM>>::create(mOutput->getImpl()->backend())(*this) : nullptr;
     }
 
@@ -80,64 +76,23 @@ public:
         return std::make_shared<AvgPooling_Op<DIM>>(*this);
     }
 
-    void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final {
-        assert(inputIdx < 1 && "operators supports only 3 inputs");
-        (void) inputIdx; // avoid unused warning
-        assert(strcmp(data->type(), Tensor::Type) == 0 && "input data must be of Tensor type");
-
-        mInput = std::dynamic_pointer_cast<Tensor>(data);
-    }
 
     void computeOutputDims() override final {
-        if (!mInput->empty()) {
+        if (!*mInputs[0]->empty()) {
             std::array<DimSize_t, DIM + 2> outputDims = {};
 
             for (std::size_t dim = 0; dim < this->template getAttr<AvgPoolingAttr::KernelDims>().size() ; ++dim) {
                 outputDims[dim+2] = 1 + static_cast<DimSize_t>(
-                                            std::floor(static_cast<float>(mInput->dims()[dim+2] -
+                                            std::floor(static_cast<float>(*mInputs[0]->dims()[dim+2] -
                                                                     this->template getAttr<AvgPoolingAttr::KernelDims>()[dim]) /
                                             static_cast<float>(this->template getAttr<AvgPoolingAttr::StrideDims>()[dim])));
             }
-            outputDims[1] = mInput->dims()[1];
-            outputDims[0] = mInput->dims()[0];
-            mOutput->resize(outputDims);
+            outputDims[1] = *mInputs[0]->dims()[1];
+            outputDims[0] = *mInputs[0]->dims()[0];
+            mOutputs[0]->resize(outputDims);
         }
     }
 
-    bool outputDimsForwarded() const override final { return !(mOutput->empty()); }
-
-
-    inline Tensor& input(const IOIndex_t inputIdx) const override final {
-        assert(inputIdx == 0 && "operators supports only 1 inputs");
-        (void) inputIdx; // avoid unused warning
-        return *(mInput.get());
-    }
-    inline Tensor& output(const IOIndex_t /*outputIdx*/) const override final { return *(mOutput.get()); }
-
-
-    inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final {
-        assert(inputIdx == 0 && "AvgPooling Operators supports only 1 inputs");
-        (void) inputIdx; // avoid unused warning
-        return mInput;
-    }
-    inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final {
-        assert(outputIdx == 0 && "AvgPooling Operators has only 1 outputs");
-        (void) outputIdx; // avoid unused warning
-        return mOutput;
-    }
-
-
-    std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final {
-        assert(inputIdx == 0 && "operators supports only 1 inputs");
-        (void) inputIdx; // avoid unused warning
-        return std::static_pointer_cast<Data>(mInput);
-    }
-    std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final {
-        assert(outputIdx == 0 && "operator supports only 1 output");
-        (void) outputIdx; // avoid unused warning
-        return std::static_pointer_cast<Data>(mOutput);
-    }
-
 
     void setBackend(const std::string &name) override {
         mImpl = Registrar<AvgPooling_Op<DIM>>::create(name)(*this);
@@ -147,16 +102,6 @@ public:
         mInput->setBackend(name);
     }
 
-    void setDatatype(const DataType &datatype) override {
-        mOutput->setDatatype(datatype);
-
-        // FIXME: temporary workaround
-        mInput->setDatatype(datatype);
-    }
-
-    inline IOIndex_t nbInputs() const noexcept override final { return 1; }
-    inline IOIndex_t nbDataInputs() const noexcept override final { return 1; }
-    inline IOIndex_t nbOutputs() const noexcept override final { return 1; }
     static const std::vector<std::string> getInputsName(){
         return {"data_input"};
     }
@@ -190,4 +135,4 @@ const char *const EnumStrings<Aidge::AvgPoolingAttr>::data[] = {"StrideDims",
                                                           "KernelDims"};
 }
 
-#endif /* AIDGE_CORE_OPERATOR_AVGPOOLING_H_ */
+#endif /* AIDGE_CORE_OPERATOR_AVGPOOLING_H_ */
\ No newline at end of file
diff --git a/include/aidge/operator/Conv.hpp b/include/aidge/operator/Conv.hpp
index b1e3e34b0..62f2446f3 100644
--- a/include/aidge/operator/Conv.hpp
+++ b/include/aidge/operator/Conv.hpp
@@ -19,7 +19,7 @@
 
 #include "aidge/data/Tensor.hpp"
 #include "aidge/graph/Node.hpp"
-#include "aidge/operator/Operator.hpp"
+#include "aidge/operator/OperatorTensor.hpp"
 #include "aidge/operator/Producer.hpp"
 #include "aidge/utils/StaticAttributes.hpp"
 #include "aidge/utils/Registrar.hpp"
@@ -29,17 +29,12 @@ namespace Aidge {
 enum class ConvAttr { StrideDims, DilationDims, InChannels, OutChannels, KernelDims };
 
 template <DimIdx_t DIM>
-class Conv_Op : public Operator,
+class Conv_Op : public OperatorTensor,
                 public Registrable<Conv_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const Conv_Op<DIM> &)>,
                 public StaticAttributes<ConvAttr, std::array<DimSize_t, DIM>, std::array<DimSize_t, DIM>, DimSize_t,
                                        DimSize_t, std::array<DimSize_t, DIM>> {
-public:
-    // FIXME: change accessibility
-    std::array<std::shared_ptr<Tensor>, 3> mInputs = {std::make_shared<Tensor>(), std::make_shared<Tensor>(),
-                                                      std::make_shared<Tensor>()};
-    const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>();
 
-   public:
+public:
     static constexpr const char *Type = "Conv";
 
     Conv_Op() = delete;
@@ -54,13 +49,13 @@ public:
                       const std::array<DimSize_t, DIM> &kernel_dims,
                       const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1),
                       const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1))
-        : Operator(Type),
+        : OperatorTensor(Type, 1, 2, 1),
           Attributes_(attr<ConvAttr::StrideDims>(stride_dims),
                       attr<ConvAttr::DilationDims>(dilation_dims),
                       attr<ConvAttr::InChannels>(in_channels),
                       attr<ConvAttr::OutChannels>(out_channels),
                       attr<ConvAttr::KernelDims>(kernel_dims)) {
-        setDatatype(DataType::Float32);
+        setDataType(DataType::Float32);
     }
 
     /**
@@ -68,12 +63,12 @@ public:
      * @param op Operator to copy.
      */
     Conv_Op(const Conv_Op<DIM>& op)
-        : Operator(Type),
+        : OperatorTensor(Type, 1, 2, 1),
           Attributes_(op),
           mOutput(std::make_shared<Tensor>(*op.mOutput))
     {
         // cpy-ctor
-        setDatatype(op.mOutput->dataType());
+        setDataType(op.mOutput->dataType());
         mImpl = op.mImpl ? Registrar<Conv_Op<DIM>>::create(mOutput->getImpl()->backend())(*this) : nullptr;
     }
 
@@ -98,13 +93,6 @@ public:
 
     // }
 
-    void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final {
-        assert(inputIdx < 3 && "operators supports only 3 inputs");
-        assert(strcmp(data->type(), Tensor::Type) == 0 && "input data must be of Tensor type");
-
-        mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data);
-    }
-
     void computeOutputDims() override final {
         if (!mInputs[0]->empty()) {
             std::array<DimSize_t, DIM + 2> outputDims = {};
@@ -125,37 +113,6 @@ public:
         }
     }
 
-    bool outputDimsForwarded() const override final { return !(mOutput->empty()); }
-
-
-    inline Tensor& input(const IOIndex_t inputIdx) const override final {
-        assert(inputIdx < 3 && "operators supports only 3 inputs");
-        return *(mInputs[inputIdx].get()); }
-    inline Tensor& output(const IOIndex_t /*outputIdx*/) const override final { return *(mOutput.get()); }
-
-
-    inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final {
-        assert(inputIdx < 3 && "Conv Operators supports only 3 inputs");
-        return mInputs[inputIdx];
-    }
-    inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final {
-        assert((outputIdx == 0) && "Conv Operator has only 1 output");
-        (void) outputIdx; // avoid unused warning
-        return mOutput;
-    }
-
-
-    std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final {
-        assert(inputIdx < 3 && "operators supports only 3 inputs");
-        return std::static_pointer_cast<Data>(mInputs[inputIdx]);
-    }
-    std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final {
-        assert(outputIdx == 0 && "operator supports only 1 output");
-        (void) outputIdx; // avoid unused warning
-        return std::static_pointer_cast<Data>(mOutput);
-    }
-
-
     void setBackend(const std::string &name) override {
         mImpl = Registrar<Conv_Op<DIM>>::create(name)(*this);
         mOutput->setBackend(name);
@@ -165,18 +122,6 @@ public:
         mInputs[2]->setBackend(name);
     }
 
-    void setDatatype(const DataType &datatype) override {
-        mOutput->setDatatype(datatype);
-
-        // FIXME: temporary workaround
-        mInputs[0]->setDatatype(datatype);
-        mInputs[1]->setDatatype(datatype);
-        mInputs[2]->setDatatype(datatype);
-    }
-
-    inline IOIndex_t nbInputs() const noexcept override final { return 3; }
-    inline IOIndex_t nbDataInputs() const noexcept override final { return 1; }
-    inline IOIndex_t nbOutputs() const noexcept override final { return 1; }
     static const std::vector<std::string> getInputsName(){
         return {"data_input", "weight", "bias"};
     }
diff --git a/src/backend/OperatorImpl.cpp b/src/backend/OperatorImpl.cpp
index 166754cc9..b76bf3336 100644
--- a/src/backend/OperatorImpl.cpp
+++ b/src/backend/OperatorImpl.cpp
@@ -25,25 +25,25 @@ Aidge::OperatorImpl::OperatorImpl(const Operator& op):
 }
 
 Aidge::NbElts_t Aidge::OperatorImpl::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const {
-    assert(mOp.getInput(inputIdx) && "requires valid input");
+    assert(mOp.getRawInput(inputIdx) && "requires valid input");
 
     // Requires the whole tensor by default
-    return std::static_pointer_cast<Tensor>(mOp.getInput(inputIdx))->size();
+    return std::static_pointer_cast<Tensor>(mOp.getRawInput(inputIdx))->size();
 }
 
 Aidge::NbElts_t Aidge::OperatorImpl::getNbRequiredProtected(IOIndex_t inputIdx) const {
-    assert(mOp.getInput(inputIdx) && "requires valid input");
+    assert(mOp.getRawInput(inputIdx) && "requires valid input");
 
     // Protect the whole tensor by default
-    return std::static_pointer_cast<Tensor>(mOp.getInput(inputIdx))->size();
+    return std::static_pointer_cast<Tensor>(mOp.getRawInput(inputIdx))->size();
 }
 
 Aidge::NbElts_t Aidge::OperatorImpl::getRequiredMemory(const Aidge::IOIndex_t outputIdx,
                                                          const std::vector<Aidge::DimSize_t> &/*inputsSize*/) const {
-    assert(mOp.getOutput(outputIdx) && "requires valid output");
+    assert(mOp.getRawOutput(outputIdx) && "requires valid output");
 
     // Requires the whole tensor by default, regardless of available data on inputs
-    return std::static_pointer_cast<Tensor>(mOp.getOutput(outputIdx))->size();
+    return std::static_pointer_cast<Tensor>(mOp.getRawOutput(outputIdx))->size();
 }
 
 Aidge::NbElts_t Aidge::OperatorImpl::getNbConsumedData(Aidge::IOIndex_t inputIdx) const {
-- 
GitLab