From 105f3960be4c79a38fcf77d984bae10af59b9d4b Mon Sep 17 00:00:00 2001
From: NAUD Maxence <maxence.naud@cea.fr>
Date: Fri, 10 Nov 2023 09:28:21 +0000
Subject: [PATCH] [Add] intermediate class to handle Operators using Tensors

---
 include/aidge/operator/OperatorTensor.hpp | 80 +++++++++++++++++++++++
 src/operator/Operator.cpp                 |  2 +-
 src/operator/OperatorTensor.cpp           | 79 ++++++++++++++++++++++
 3 files changed, 160 insertions(+), 1 deletion(-)
 create mode 100644 include/aidge/operator/OperatorTensor.hpp
 create mode 100644 src/operator/OperatorTensor.cpp

diff --git a/include/aidge/operator/OperatorTensor.hpp b/include/aidge/operator/OperatorTensor.hpp
new file mode 100644
index 000000000..ecc2e40ee
--- /dev/null
+++ b/include/aidge/operator/OperatorTensor.hpp
@@ -0,0 +1,80 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef AIDGE_CORE_OPERATOR_OPERATORTENSOR_H_
+#define AIDGE_CORE_OPERATOR_OPERATORTENSOR_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "aidge/backend/OperatorImpl.hpp"
+#include "aidge/data/Tensor.hpp"
+#include "aidge/utils/Types.h"
+#include "aidge/operator/Operator.hpp"
+
+namespace Aidge {
+
+class OperatorTensor : public Operator {
+/* TODO: Add an attribute specifying the type of Data used by the Operator.
+ * The same way ``Type`` attribute specifies the type of Operator. Hence this
+ * attribute could be checked in the forwardDims function to assert Operators
+ * being used work with Tensors and cast them to OpertorTensor instead of
+ * Operator.
+ */
+/* TODO: Maybe change type attribute of Data object by an enum instead of an
+ * array of char. Faster comparisons.
+ */
+protected:
+    std::vector<std::shared_ptr<Tensor>*> mInputs;
+    std::vector<std::shared_ptr<Tensor>> mOutputs;
+
+public:
+    OperatorTensor(const char* type, const IOIndex_t nbData, const IOIndex_t nbAttr, const IOIndex_t nbOut)
+    : Operator(type, nbData, nbAttr, nbOut),
+      mInputs(std::vector<std::shared_ptr<Tensor>*>(nbData + nbAttr, nullptr)),
+      mOutputs(std::vector<std::shared_ptr<Tensor>>(nbOut))
+    {
+        for (std::size_t i = 0; i < static_cast<std::size_t>(nbOut); ++i) {
+            mOutputs[i] = std::make_shared<Tensor>();
+        }
+    }
+
+public:
+    ///////////////////////////////////////////////////
+    virtual void associateInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>* data) override;
+    ///////////////////////////////////////////////////
+
+    ///////////////////////////////////////////////////
+    // Tensor access
+    // input management
+    std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const;
+    Tensor& input(const IOIndex_t inputIdx) const;
+    std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final;
+
+    //output management
+    std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const;
+    Tensor& output(const IOIndex_t outputIdx) const;
+    std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final;
+    ///////////////////////////////////////////////////
+
+    ///////////////////////////////////////////////////
+    // Tensor dimensions
+    virtual void computeOutputDims() = 0;
+    virtual bool outputDimsForwarded() const;
+    ///////////////////////////////////////////////////
+
+    virtual void setDataType(const DataType& dataType) const;
+
+};
+} // namespace Aidge
+
+#endif // AIDGE_CORE_OPERATOR_OPERATORTENSOR_H_
\ No newline at end of file
diff --git a/src/operator/Operator.cpp b/src/operator/Operator.cpp
index f6143f125..a8f2fe467 100644
--- a/src/operator/Operator.cpp
+++ b/src/operator/Operator.cpp
@@ -52,7 +52,7 @@ void Aidge::Operator::forward() {
         mImpl->forward();
         runHooks();
     } else {
-        printf("backward: No implementation is linked.\n");
+        printf("forward: No implementation is linked.\n");
     }
 }
 
diff --git a/src/operator/OperatorTensor.cpp b/src/operator/OperatorTensor.cpp
new file mode 100644
index 000000000..e5fdada1d
--- /dev/null
+++ b/src/operator/OperatorTensor.cpp
@@ -0,0 +1,79 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <cassert>
+#include <memory>
+
+#include "aidge/operator/OperatorTensor.hpp"
+#include "aidge/data/Data.hpp"
+#include "aidge/data/Tensor.hpp"
+#include "aidge/utils/Types.h"
+#include "aidge/utils/ErrorHandling.hpp"
+
+
+void Aidge::OperatorTensor::associateInput(const Aidge::IOIndex_t inputIdx, const std::shared_ptr<Aidge::Data>* data) {
+    if (inputIdx >= nbInputs()) {
+        AIDGE_ASSERT("%s Operator has %hu inputs", type().c_str(), nbInputs());
+    }
+    if (strcmp((*data)->type(), Tensor::Type) != 0) {
+        printf("input data must be of Tensor type");
+        exit(-1);
+    }
+    mInputs[inputIdx] = &std::dynamic_pointer_cast<Tensor>(*data);
+}
+
+std::shared_ptr<Aidge::Tensor> Aidge::OperatorTensor::getInput(const Aidge::IOIndex_t inputIdx) const {
+    if (inputIdx >= nbInputs()) {
+        AIDGE_ASSERT("%s Operator has %hu inputs", type().c_str(), nbInputs());
+    }
+    return *mInputs[inputIdx];
+}
+
+Aidge::Tensor& Aidge::OperatorTensor::input(const Aidge::IOIndex_t inputIdx) const {
+    return *getInput(inputIdx);
+}
+
+std::shared_ptr<Aidge::Data> Aidge::OperatorTensor::getRawInput(const Aidge::IOIndex_t inputIdx) const {
+    return std::static_pointer_cast<Data>(getInput(inputIdx));
+}
+
+
+std::shared_ptr<Aidge::Tensor> Aidge::OperatorTensor::getOutput(const Aidge::IOIndex_t outputIdx) const {
+    if (outputIdx >= nbOutputs()) {
+        AIDGE_ASSERT("%s Operator has %hu outputs", type().c_str(), nbOutputs());
+    }
+    return mOutputs[outputIdx];
+}
+
+Aidge::Tensor& Aidge::OperatorTensor::output(const Aidge::IOIndex_t outputIdx) const {
+    return *getOutput(outputIdx);
+}
+
+std::shared_ptr<Aidge::Data> Aidge::OperatorTensor::getRawOutput(const Aidge::IOIndex_t outputIdx) const {
+    return std::static_pointer_cast<Data>(getOutput(outputIdx));
+}
+
+bool Aidge::OperatorTensor::outputDimsForwarded() const {
+    bool forwarded = true;
+    for (IOIndex_t i = 0; i < nbOutputs(); ++i) {
+        forwarded &= !(getOutput(i)->empty());
+    }
+    return forwarded;
+}
+
+void Aidge::OperatorTensor::setDataType(const DataType& dataType) const {
+    for (IOIndex_t i = 0; i < nbOutputs(); ++i) {
+        getOutput(i)->setDatatype(dataType);
+    }
+    for (IOIndex_t i = 0; i < nbInputs(); ++i) {
+        getInput(i)->setDatatype(dataType);
+    }
+}
\ No newline at end of file
-- 
GitLab