From 46e15f5545bd7d658fedaa44f405b69b3f449c23 Mon Sep 17 00:00:00 2001
From: NAUD Maxence <maxence.naud@cea.fr>
Date: Mon, 20 Nov 2023 10:45:24 +0000
Subject: [PATCH] Remove 'Tensor' from 'Operator' class. Only keep 'Data'

- Add a common parent class, OperatorTensor, for every Operator that works on Tensors
- Gather shared operator functions in OperatorTensor
- Add generic mInputs and mOutputs attributes for OperatorTensor
- Add an enum to identify the type of Data used by each Operator
- Rename the input categories to reduce confusion: inputs are now split into Data and Param (nbAttr becomes nbParam); a short usage sketch of the new API follows this list
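
A minimal usage sketch (illustration only, not part of the patch; the helper
function and its name are hypothetical) showing how the new OperatorType enum
and the OperatorTensor accessors are expected to be used together:

    #include <memory>

    #include "aidge/operator/Operator.hpp"
    #include "aidge/operator/OperatorTensor.hpp"

    // Configure an Operator only if it actually works on Tensors.
    // Assumes every input of `op` has already been associated with a Tensor.
    void setDataTypeIfTensorOperator(const std::shared_ptr<Aidge::Operator>& op) {
        // The new OperatorType enum makes the downcast safe to guard.
        if (op->operatorType() == Aidge::OperatorType::Tensor) {
            const auto opTensor = std::static_pointer_cast<Aidge::OperatorTensor>(op);
            // Forwarded to every input and output Tensor of the operator.
            opTensor->setDataType(Aidge::DataType::Float32);
            // Default implementation: checks that all inputs share the same
            // dimensions and resizes output 0 accordingly.
            opTensor->computeOutputDims();
        }
    }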
---
 include/aidge/operator/Operator.hpp       | 36 +++++++----
 include/aidge/operator/OperatorTensor.hpp | 79 ++++++++++++++---------
 src/operator/OperatorTensor.cpp           | 60 +++++++++--------
 3 files changed, 108 insertions(+), 67 deletions(-)

diff --git a/include/aidge/operator/Operator.hpp b/include/aidge/operator/Operator.hpp
index e6e61ff7e..7fca2a46f 100644
--- a/include/aidge/operator/Operator.hpp
+++ b/include/aidge/operator/Operator.hpp
@@ -20,12 +20,16 @@
 
 #include "aidge/backend/OperatorImpl.hpp"
 #include "aidge/data/Data.hpp"
-#include "aidge/data/Tensor.hpp"
 #include "aidge/utils/Types.h"
 #include "aidge/hook/Hook.hpp"
 
 namespace Aidge {
 
+enum class OperatorType {
+    Data,
+    Tensor
+};
+
 class Operator : public std::enable_shared_from_this<Operator> {
 protected:
     std::shared_ptr<OperatorImpl> mImpl; // implementation of the operator
@@ -33,27 +37,28 @@ protected:
 
 private:
     std::string mType;
+    const OperatorType mOperatorType;
     const IOIndex_t mNbData;
-    const IOIndex_t mNbAttr;
+    const IOIndex_t mNbParam;
     const IOIndex_t mNbOut;
 
 public:
     Operator() = delete;
-    Operator(const char* type, const IOIndex_t nbData, const IOIndex_t nbAttr, const IOIndex_t nbOut)
+    Operator(const char* type, const IOIndex_t nbData, const IOIndex_t nbParam, const IOIndex_t nbOut, const OperatorType operatorType = OperatorType::Data)
     : mType(type),
+      mOperatorType(operatorType),
       mNbData(nbData),
-      mNbAttr(nbAttr),
+      mNbParam(nbParam),
       mNbOut(nbOut)
     {
         // ctor
     }
-    virtual std::shared_ptr<Operator> clone() const = 0;
-    virtual ~Operator();
 
     Operator(const Operator& op):
         std::enable_shared_from_this<Operator>(),
+        mOperatorType(op.mOperatorType),
         mNbData(op.mNbData),
-        mNbAttr(op.mNbAttr),
+        mNbParam(op.mNbParam),
         mNbOut(op.mNbOut)
     {
         mType = op.mType;
@@ -63,9 +68,12 @@ public:
         // Hooks are not copied.
     }
 
+    virtual ~Operator();
+
 public:
+    virtual std::shared_ptr<Operator> clone() const = 0;
 
-    virtual void associateInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>* data) = 0;
+    virtual void associateInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) = 0;
     /**
      * @brief For a given output feature area, compute the associated receptive
      * field for each data input.
@@ -92,7 +100,7 @@ public:
 ///////////////////////////////////////////////////////
 
     virtual void setBackend(const std::string& name) = 0;
-    virtual void setDatatype(const DataType& datatype) = 0;
+    virtual void setDataType(const DataType& dataType) const = 0;
 
     /**
      * @brief Set the a new OperatorImpl to the Operator
@@ -135,13 +143,17 @@ public:
 //        INNER
 ///////////////////////////////////////////////////////
 
-    std::string type() const {
+    inline std::string type() const noexcept {
         return mType;
     }
 
-    inline IOIndex_t nbInputs() const noexcept { return mNbData+mNbAttr; };
+    inline OperatorType operatorType() const noexcept {
+        return mOperatorType;
+    }
+
+    inline IOIndex_t nbInputs() const noexcept { return mNbData+mNbParam; };
     inline IOIndex_t nbData() const noexcept { return mNbData; };
-    inline IOIndex_t nbAttr() const noexcept { return mNbAttr; };
+    inline IOIndex_t nbParam() const noexcept { return mNbParam; };
     inline IOIndex_t nbOutputs() const noexcept { return mNbOut; };
 
     static const std::vector<std::string> getInputsName(){
diff --git a/include/aidge/operator/OperatorTensor.hpp b/include/aidge/operator/OperatorTensor.hpp
index ecc2e40ee..b8c180e6d 100644
--- a/include/aidge/operator/OperatorTensor.hpp
+++ b/include/aidge/operator/OperatorTensor.hpp
@@ -18,63 +18,84 @@
 
 #include "aidge/backend/OperatorImpl.hpp"
 #include "aidge/data/Tensor.hpp"
-#include "aidge/utils/Types.h"
 #include "aidge/operator/Operator.hpp"
+#include "aidge/utils/Types.h"
 
 namespace Aidge {
 
 class OperatorTensor : public Operator {
-/* TODO: Add an attribute specifying the type of Data used by the Operator.
- * The same way ``Type`` attribute specifies the type of Operator. Hence this
- * attribute could be checked in the forwardDims function to assert Operators
- * being used work with Tensors and cast them to OpertorTensor instead of
- * Operator.
- */
-/* TODO: Maybe change type attribute of Data object by an enum instead of an
- * array of char. Faster comparisons.
- */
+    /* TODO: Add an attribute specifying the type of Data used by the Operator,
+     * the same way the ``Type`` attribute specifies the type of Operator. This
+     * attribute could then be checked in the forwardDims function to assert that
+     * the Operators being used work with Tensors and cast them to OperatorTensor
+     * instead of Operator.
+     */
+    /* TODO: Maybe change the type attribute of the Data object to an enum instead
+     * of an array of char, for faster comparisons.
+     */
 protected:
-    std::vector<std::shared_ptr<Tensor>*> mInputs;
+    std::vector<std::shared_ptr<Tensor>> mInputs;
     std::vector<std::shared_ptr<Tensor>> mOutputs;
 
 public:
-    OperatorTensor(const char* type, const IOIndex_t nbData, const IOIndex_t nbAttr, const IOIndex_t nbOut)
-    : Operator(type, nbData, nbAttr, nbOut),
-      mInputs(std::vector<std::shared_ptr<Tensor>*>(nbData + nbAttr, nullptr)),
-      mOutputs(std::vector<std::shared_ptr<Tensor>>(nbOut))
-    {
+    OperatorTensor() = delete;
+
+    OperatorTensor(const char* type, const IOIndex_t nbData, const IOIndex_t nbParam,
+                   const IOIndex_t nbOut)
+        : Operator(type, nbData, nbParam, nbOut, OperatorType::Tensor),
+          mInputs(std::vector<std::shared_ptr<Tensor>>(nbData + nbParam, nullptr)),
+          mOutputs(std::vector<std::shared_ptr<Tensor>>(nbOut)) {
         for (std::size_t i = 0; i < static_cast<std::size_t>(nbOut); ++i) {
             mOutputs[i] = std::make_shared<Tensor>();
+            mOutputs[i]->setDataType(DataType::Float32);
+        }
+    }
+
+    OperatorTensor(const OperatorTensor& other)
+        : Operator(other),
+          mInputs(std::vector<std::shared_ptr<Tensor>>(other.nbInputs(), nullptr)),
+          mOutputs(std::vector<std::shared_ptr<Tensor>>(other.nbOutputs())) {
+        for (std::size_t i = 0; i < static_cast<std::size_t>(nbOutputs()); ++i) {
+            mOutputs[i] = std::make_shared<Tensor>(other.output(i));
+            // datatype already copied
         }
     }
 
+    virtual ~OperatorTensor() = default;
+
 public:
     ///////////////////////////////////////////////////
-    virtual void associateInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>* data) override;
+    virtual void associateInput(const IOIndex_t inputIdx,
+                                const std::shared_ptr<Data>& data) override;
     ///////////////////////////////////////////////////
 
     ///////////////////////////////////////////////////
     // Tensor access
     // input management
-    std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const;
-    Tensor& input(const IOIndex_t inputIdx) const;
-    std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final;
+    const std::shared_ptr<Tensor>& getInput(const IOIndex_t inputIdx) const;
+    inline Tensor& input(const IOIndex_t inputIdx) const { return *getInput(inputIdx); }
+    inline std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final {
+        return std::static_pointer_cast<Data>(getInput(inputIdx));
+    }
 
-    //output management
-    std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const;
-    Tensor& output(const IOIndex_t outputIdx) const;
-    std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final;
+    // output management
+    const std::shared_ptr<Tensor>& getOutput(const IOIndex_t outputIdx) const;
+    inline Tensor& output(const IOIndex_t outputIdx) const {
+        return *getOutput(outputIdx);
+    }
+    inline std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final {
+        return std::static_pointer_cast<Data>(getOutput(outputIdx));
+    }
     ///////////////////////////////////////////////////
 
     ///////////////////////////////////////////////////
     // Tensor dimensions
-    virtual void computeOutputDims() = 0;
+    virtual void computeOutputDims();
     virtual bool outputDimsForwarded() const;
     ///////////////////////////////////////////////////
 
-    virtual void setDataType(const DataType& dataType) const;
-
+    virtual void setDataType(const DataType& dataType) const override;
 };
-} // namespace Aidge
+}  // namespace Aidge
 
-#endif // AIDGE_CORE_OPERATOR_OPERATORTENSOR_H_
\ No newline at end of file
+#endif  // AIDGE_CORE_OPERATOR_OPERATORTENSOR_H_
\ No newline at end of file
diff --git a/src/operator/OperatorTensor.cpp b/src/operator/OperatorTensor.cpp
index e5fdada1d..1594548ce 100644
--- a/src/operator/OperatorTensor.cpp
+++ b/src/operator/OperatorTensor.cpp
@@ -19,50 +19,58 @@
 #include "aidge/utils/ErrorHandling.hpp"
 
 
-void Aidge::OperatorTensor::associateInput(const Aidge::IOIndex_t inputIdx, const std::shared_ptr<Aidge::Data>* data) {
+void Aidge::OperatorTensor::associateInput(const Aidge::IOIndex_t inputIdx, const std::shared_ptr<Aidge::Data>& data) {
     if (inputIdx >= nbInputs()) {
-        AIDGE_ASSERT("%s Operator has %hu inputs", type().c_str(), nbInputs());
+        AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator has %hu inputs", type().c_str(), nbInputs());
     }
-    if (strcmp((*data)->type(), Tensor::Type) != 0) {
-        printf("input data must be of Tensor type");
-        exit(-1);
+    if (strcmp(data->type(), Tensor::Type) != 0) {
+        AIDGE_THROW_OR_ABORT(std::runtime_error, "Input data must be of Tensor type");
     }
-    mInputs[inputIdx] = &std::dynamic_pointer_cast<Tensor>(*data);
+    mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data);
 }
 
-std::shared_ptr<Aidge::Tensor> Aidge::OperatorTensor::getInput(const Aidge::IOIndex_t inputIdx) const {
+const std::shared_ptr<Aidge::Tensor>& Aidge::OperatorTensor::getInput(const Aidge::IOIndex_t inputIdx) const {
     if (inputIdx >= nbInputs()) {
-        AIDGE_ASSERT("%s Operator has %hu inputs", type().c_str(), nbInputs());
+        AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator has %hu inputs", type().c_str(), nbInputs());
     }
-    return *mInputs[inputIdx];
+    return mInputs[inputIdx];
 }
 
-Aidge::Tensor& Aidge::OperatorTensor::input(const Aidge::IOIndex_t inputIdx) const {
-    return *getInput(inputIdx);
-}
-
-std::shared_ptr<Aidge::Data> Aidge::OperatorTensor::getRawInput(const Aidge::IOIndex_t inputIdx) const {
-    return std::static_pointer_cast<Data>(getInput(inputIdx));
-}
 
-
-std::shared_ptr<Aidge::Tensor> Aidge::OperatorTensor::getOutput(const Aidge::IOIndex_t outputIdx) const {
+const std::shared_ptr<Aidge::Tensor>& Aidge::OperatorTensor::getOutput(const Aidge::IOIndex_t outputIdx) const {
     if (outputIdx >= nbOutputs()) {
-        AIDGE_ASSERT("%s Operator has %hu outputs", type().c_str(), nbOutputs());
+        AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator has %hu outputs", type().c_str(), nbOutputs());
     }
     return mOutputs[outputIdx];
 }
 
-Aidge::Tensor& Aidge::OperatorTensor::output(const Aidge::IOIndex_t outputIdx) const {
-    return *getOutput(outputIdx);
-}
 
-std::shared_ptr<Aidge::Data> Aidge::OperatorTensor::getRawOutput(const Aidge::IOIndex_t outputIdx) const {
-    return std::static_pointer_cast<Data>(getOutput(outputIdx));
+void Aidge::OperatorTensor::computeOutputDims() {
+    // check inputs have been associated
+    bool associated = (nbInputs() > 0); // do not compute anything if no input
+    for (IOIndex_t i = 0; i < nbInputs(); ++i) {
+        if (!getInput(i)) {
+            AIDGE_THROW_OR_ABORT(std::runtime_error, "Every input should be associated with a Tensor");
+        }
+        associated &= !(getInput(i)->empty());
+    }
+    if (associated) {
+        const auto expectedDims = getInput(0)->dims();
+        for (std::size_t i = 1; i < nbInputs(); ++i) {
+            if (expectedDims != getInput(i)->dims()) {
+                AIDGE_THROW_OR_ABORT(std::runtime_error, "Operator's inputs should have the same dimensions");
+            }
+        }
+        mOutputs[0]->resize(expectedDims);
+    }
 }
 
 bool Aidge::OperatorTensor::outputDimsForwarded() const {
     bool forwarded = true;
+    // check both inputs and outputs have been filled
+    for (IOIndex_t i = 0; i < nbInputs(); ++i) {
+        forwarded &= !(getInput(i)->empty());
+    }
     for (IOIndex_t i = 0; i < nbOutputs(); ++i) {
         forwarded &= !(getOutput(i)->empty());
     }
@@ -71,9 +79,9 @@ bool Aidge::OperatorTensor::outputDimsForwarded() const {
 
 void Aidge::OperatorTensor::setDataType(const DataType& dataType) const {
     for (IOIndex_t i = 0; i < nbOutputs(); ++i) {
-        getOutput(i)->setDatatype(dataType);
+        getOutput(i)->setDataType(dataType);
     }
     for (IOIndex_t i = 0; i < nbInputs(); ++i) {
-        getInput(i)->setDatatype(dataType);
+        getInput(i)->setDataType(dataType);
     }
 }
\ No newline at end of file
-- 
GitLab