From 0e89cde059a5f3e3d0da1582b2af9b3f9e2c7120 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Gr=C3=A9goire=20KUBLER?= <gregoire.kubler@proton.me>
Date: Wed, 5 Mar 2025 11:56:13 +0000
Subject: [PATCH] feat : added getDataType() function to OperatorTensor

---
 include/aidge/operator/Operator.hpp       |  7 +++++++
 include/aidge/operator/OperatorTensor.hpp | 10 ++++++++++
 src/operator/OperatorTensor.cpp           | 19 +++++++++++++++++++
 3 files changed, 36 insertions(+)

diff --git a/include/aidge/operator/Operator.hpp b/include/aidge/operator/Operator.hpp
index 81a54620a..5a12cfea2 100644
--- a/include/aidge/operator/Operator.hpp
+++ b/include/aidge/operator/Operator.hpp
@@ -237,6 +237,13 @@ public:
      */
     void setBackend(const std::vector<std::pair<std::string, DeviceIdx_t>>& backends);
 
+    /**
+     * @brief gets the data type of the operator's tensors.
+     * @return a pair whose first object contains inputs' data types
+     * and second object outputs' data types
+     */
+    virtual std::pair<std::vector<Aidge::DataType>, std::vector<Aidge::DataType>>
+    getDataType() const = 0;
     virtual void setDataType(const DataType& dataType) const = 0;
     virtual void setDataFormat(const DataFormat& dataFormat) const = 0;
     /**
diff --git a/include/aidge/operator/OperatorTensor.hpp b/include/aidge/operator/OperatorTensor.hpp
index a515ecb5b..0e3d275eb 100644
--- a/include/aidge/operator/OperatorTensor.hpp
+++ b/include/aidge/operator/OperatorTensor.hpp
@@ -179,8 +179,18 @@ public:
     virtual bool dimsForwarded() const;
     ///////////////////////////////////////////////////
 
+    /**
+     * @brief gets the data type of the operator's tensors.
+     * @return a pair whose first object contains inputs' data types
+     * and second object outputs' data types
+     */
+    std::pair<std::vector<Aidge::DataType>, std::vector<Aidge::DataType>>
+    getDataType() const;
+
     /**
      * @brief Sets the data type of the operator's tensors.
+     * @warning Sets all outputs but only inputs of category 
+     * @code InputCategory::Param @endcode & @code InputCategory::OptionnalParam @endcode
      * @param dataType Data type to set.
      */
     virtual void setDataType(const DataType& dataType) const override;
diff --git a/src/operator/OperatorTensor.cpp b/src/operator/OperatorTensor.cpp
index cac1ad226..1c5b0a0d3 100644
--- a/src/operator/OperatorTensor.cpp
+++ b/src/operator/OperatorTensor.cpp
@@ -10,13 +10,18 @@
  ********************************************************************************/
 
 #include <memory>
+#include <utility>
+#include <vector>
 
 #include "aidge/operator/OperatorTensor.hpp"
 #include "aidge/data/Data.hpp"
+#include "aidge/data/DataType.hpp"
 #include "aidge/data/Tensor.hpp"
 #include "aidge/utils/Types.h"
 #include "aidge/utils/ErrorHandling.hpp"
 
+namespace Aidge{
+using std::make_pair;
 
 Aidge::OperatorTensor::OperatorTensor(const std::string& type,
                                       const std::vector<InputCategory>& inputsCategory,
@@ -180,6 +185,19 @@ bool Aidge::OperatorTensor::dimsForwarded() const {
     return forwarded;
 }
 
+std::pair<std::vector<DataType>, std::vector<DataType>>
+OperatorTensor::getDataType() const {
+    auto res = make_pair(std::vector<DataType>(nbInputs()),
+                         std::vector<DataType>(nbOutputs()));
+    for (IOIndex_t i = 0; i < nbInputs(); ++i) {
+        res.first[i] = getOutput(i)->dataType();
+    }
+    for (IOIndex_t i = 0; i < nbOutputs(); ++i) {
+        res.second[i] = getOutput(i)->dataType();
+    }
+    return res;
+}
+
 void Aidge::OperatorTensor::setDataType(const DataType& dataType) const {
     for (IOIndex_t i = 0; i < nbOutputs(); ++i) {
         getOutput(i)->setDataType(dataType);
@@ -222,3 +240,4 @@ void Aidge::OperatorTensor::forward() {
 
     Operator::forward();
 }
+}  // namespace Aidge
-- 
GitLab