diff --git a/include/aidge/operator/Operator.hpp b/include/aidge/operator/Operator.hpp index 81a54620a6f325eba04e9055e12d73d6a5d64163..5a12cfea28455b75ef527d6b4351f157fa087136 100644 --- a/include/aidge/operator/Operator.hpp +++ b/include/aidge/operator/Operator.hpp @@ -237,6 +237,13 @@ public: */ void setBackend(const std::vector<std::pair<std::string, DeviceIdx_t>>& backends); + /** + * @brief gets the data type of the operator's tensors. + * @return a pair whose first object contains inputs' data types + * and second object outputs' data types + */ + virtual std::pair<std::vector<Aidge::DataType>, std::vector<Aidge::DataType>> + getDataType() const = 0; virtual void setDataType(const DataType& dataType) const = 0; virtual void setDataFormat(const DataFormat& dataFormat) const = 0; /** diff --git a/include/aidge/operator/OperatorTensor.hpp b/include/aidge/operator/OperatorTensor.hpp index a515ecb5ba4843c1338f526d05cdef327bfd1ba0..0e3d275eb73c8024af233a7c4406b0dbaf0e43ca 100644 --- a/include/aidge/operator/OperatorTensor.hpp +++ b/include/aidge/operator/OperatorTensor.hpp @@ -179,8 +179,18 @@ public: virtual bool dimsForwarded() const; /////////////////////////////////////////////////// + /** + * @brief gets the data type of the operator's tensors. + * @return a pair whose first object contains inputs' data types + * and second object outputs' data types + */ + std::pair<std::vector<Aidge::DataType>, std::vector<Aidge::DataType>> + getDataType() const; + /** * @brief Sets the data type of the operator's tensors. + * @warning Sets all outputs but only inputs of category + * @code InputCategory::Param @endcode & @code InputCategory::OptionnalParam @endcode * @param dataType Data type to set. */ virtual void setDataType(const DataType& dataType) const override; diff --git a/src/operator/OperatorTensor.cpp b/src/operator/OperatorTensor.cpp index cac1ad2264af0e1b687c44eb220ee2c6080e9e73..1c5b0a0d396dd7dcab98ccae685d747274165463 100644 --- a/src/operator/OperatorTensor.cpp +++ b/src/operator/OperatorTensor.cpp @@ -10,13 +10,18 @@ ********************************************************************************/ #include <memory> +#include <utility> +#include <vector> #include "aidge/operator/OperatorTensor.hpp" #include "aidge/data/Data.hpp" +#include "aidge/data/DataType.hpp" #include "aidge/data/Tensor.hpp" #include "aidge/utils/Types.h" #include "aidge/utils/ErrorHandling.hpp" +namespace Aidge{ +using std::make_pair; Aidge::OperatorTensor::OperatorTensor(const std::string& type, const std::vector<InputCategory>& inputsCategory, @@ -180,6 +185,19 @@ bool Aidge::OperatorTensor::dimsForwarded() const { return forwarded; } +std::pair<std::vector<DataType>, std::vector<DataType>> +OperatorTensor::getDataType() const { + auto res = make_pair(std::vector<DataType>(nbInputs()), + std::vector<DataType>(nbOutputs())); + for (IOIndex_t i = 0; i < nbInputs(); ++i) { + res.first[i] = getOutput(i)->dataType(); + } + for (IOIndex_t i = 0; i < nbOutputs(); ++i) { + res.second[i] = getOutput(i)->dataType(); + } + return res; +} + void Aidge::OperatorTensor::setDataType(const DataType& dataType) const { for (IOIndex_t i = 0; i < nbOutputs(); ++i) { getOutput(i)->setDataType(dataType); @@ -222,3 +240,4 @@ void Aidge::OperatorTensor::forward() { Operator::forward(); } +} // namespace Aidge