From 0e89cde059a5f3e3d0da1582b2af9b3f9e2c7120 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gr=C3=A9goire=20KUBLER?= <gregoire.kubler@proton.me> Date: Wed, 5 Mar 2025 11:56:13 +0000 Subject: [PATCH] feat : added getDataType() function to OperatorTensor --- include/aidge/operator/Operator.hpp | 7 +++++++ include/aidge/operator/OperatorTensor.hpp | 10 ++++++++++ src/operator/OperatorTensor.cpp | 19 +++++++++++++++++++ 3 files changed, 36 insertions(+) diff --git a/include/aidge/operator/Operator.hpp b/include/aidge/operator/Operator.hpp index 81a54620a..5a12cfea2 100644 --- a/include/aidge/operator/Operator.hpp +++ b/include/aidge/operator/Operator.hpp @@ -237,6 +237,13 @@ public: */ void setBackend(const std::vector<std::pair<std::string, DeviceIdx_t>>& backends); + /** + * @brief gets the data type of the operator's tensors. + * @return a pair whose first object contains inputs' data types + * and second object outputs' data types + */ + virtual std::pair<std::vector<Aidge::DataType>, std::vector<Aidge::DataType>> + getDataType() const = 0; virtual void setDataType(const DataType& dataType) const = 0; virtual void setDataFormat(const DataFormat& dataFormat) const = 0; /** diff --git a/include/aidge/operator/OperatorTensor.hpp b/include/aidge/operator/OperatorTensor.hpp index a515ecb5b..0e3d275eb 100644 --- a/include/aidge/operator/OperatorTensor.hpp +++ b/include/aidge/operator/OperatorTensor.hpp @@ -179,8 +179,18 @@ public: virtual bool dimsForwarded() const; /////////////////////////////////////////////////// + /** + * @brief gets the data type of the operator's tensors. + * @return a pair whose first object contains inputs' data types + * and second object outputs' data types + */ + std::pair<std::vector<Aidge::DataType>, std::vector<Aidge::DataType>> + getDataType() const; + /** * @brief Sets the data type of the operator's tensors. + * @warning Sets all outputs but only inputs of category + * @code InputCategory::Param @endcode & @code InputCategory::OptionnalParam @endcode * @param dataType Data type to set. */ virtual void setDataType(const DataType& dataType) const override; diff --git a/src/operator/OperatorTensor.cpp b/src/operator/OperatorTensor.cpp index cac1ad226..1c5b0a0d3 100644 --- a/src/operator/OperatorTensor.cpp +++ b/src/operator/OperatorTensor.cpp @@ -10,13 +10,18 @@ ********************************************************************************/ #include <memory> +#include <utility> +#include <vector> #include "aidge/operator/OperatorTensor.hpp" #include "aidge/data/Data.hpp" +#include "aidge/data/DataType.hpp" #include "aidge/data/Tensor.hpp" #include "aidge/utils/Types.h" #include "aidge/utils/ErrorHandling.hpp" +namespace Aidge{ +using std::make_pair; Aidge::OperatorTensor::OperatorTensor(const std::string& type, const std::vector<InputCategory>& inputsCategory, @@ -180,6 +185,19 @@ bool Aidge::OperatorTensor::dimsForwarded() const { return forwarded; } +std::pair<std::vector<DataType>, std::vector<DataType>> +OperatorTensor::getDataType() const { + auto res = make_pair(std::vector<DataType>(nbInputs()), + std::vector<DataType>(nbOutputs())); + for (IOIndex_t i = 0; i < nbInputs(); ++i) { + res.first[i] = getOutput(i)->dataType(); + } + for (IOIndex_t i = 0; i < nbOutputs(); ++i) { + res.second[i] = getOutput(i)->dataType(); + } + return res; +} + void Aidge::OperatorTensor::setDataType(const DataType& dataType) const { for (IOIndex_t i = 0; i < nbOutputs(); ++i) { getOutput(i)->setDataType(dataType); @@ -222,3 +240,4 @@ void Aidge::OperatorTensor::forward() { Operator::forward(); } +} // namespace Aidge -- GitLab