Skip to content
Snippets Groups Projects
Commit 0e89cde0 authored by Grégoire Kubler's avatar Grégoire Kubler
Browse files

feat : added getDataType() function to OperatorTensor

parent c1c35a65
No related branches found
No related tags found
3 merge requests!414Update version 0.5.1 -> 0.6.0,!408[Add] Dropout Operator,!319feat_operator_convtranspose
Pipeline #67337 waiting for manual action
...@@ -237,6 +237,13 @@ public: ...@@ -237,6 +237,13 @@ public:
*/ */
void setBackend(const std::vector<std::pair<std::string, DeviceIdx_t>>& backends); void setBackend(const std::vector<std::pair<std::string, DeviceIdx_t>>& backends);
/**
* @brief gets the data type of the operator's tensors.
* @return a pair whose first object contains inputs' data types
* and second object outputs' data types
*/
virtual std::pair<std::vector<Aidge::DataType>, std::vector<Aidge::DataType>>
getDataType() const = 0;
virtual void setDataType(const DataType& dataType) const = 0; virtual void setDataType(const DataType& dataType) const = 0;
virtual void setDataFormat(const DataFormat& dataFormat) const = 0; virtual void setDataFormat(const DataFormat& dataFormat) const = 0;
/** /**
......
...@@ -179,8 +179,18 @@ public: ...@@ -179,8 +179,18 @@ public:
virtual bool dimsForwarded() const; virtual bool dimsForwarded() const;
/////////////////////////////////////////////////// ///////////////////////////////////////////////////
/**
* @brief gets the data type of the operator's tensors.
* @return a pair whose first object contains inputs' data types
* and second object outputs' data types
*/
std::pair<std::vector<Aidge::DataType>, std::vector<Aidge::DataType>>
getDataType() const;
/** /**
* @brief Sets the data type of the operator's tensors. * @brief Sets the data type of the operator's tensors.
* @warning Sets all outputs but only inputs of category
* @code InputCategory::Param @endcode & @code InputCategory::OptionnalParam @endcode
* @param dataType Data type to set. * @param dataType Data type to set.
*/ */
virtual void setDataType(const DataType& dataType) const override; virtual void setDataType(const DataType& dataType) const override;
......
...@@ -10,13 +10,18 @@ ...@@ -10,13 +10,18 @@
********************************************************************************/ ********************************************************************************/
#include <memory> #include <memory>
#include <utility>
#include <vector>
#include "aidge/operator/OperatorTensor.hpp" #include "aidge/operator/OperatorTensor.hpp"
#include "aidge/data/Data.hpp" #include "aidge/data/Data.hpp"
#include "aidge/data/DataType.hpp"
#include "aidge/data/Tensor.hpp" #include "aidge/data/Tensor.hpp"
#include "aidge/utils/Types.h" #include "aidge/utils/Types.h"
#include "aidge/utils/ErrorHandling.hpp" #include "aidge/utils/ErrorHandling.hpp"
namespace Aidge{
using std::make_pair;
Aidge::OperatorTensor::OperatorTensor(const std::string& type, Aidge::OperatorTensor::OperatorTensor(const std::string& type,
const std::vector<InputCategory>& inputsCategory, const std::vector<InputCategory>& inputsCategory,
...@@ -180,6 +185,19 @@ bool Aidge::OperatorTensor::dimsForwarded() const { ...@@ -180,6 +185,19 @@ bool Aidge::OperatorTensor::dimsForwarded() const {
return forwarded; return forwarded;
} }
std::pair<std::vector<DataType>, std::vector<DataType>>
OperatorTensor::getDataType() const {
auto res = make_pair(std::vector<DataType>(nbInputs()),
std::vector<DataType>(nbOutputs()));
for (IOIndex_t i = 0; i < nbInputs(); ++i) {
res.first[i] = getOutput(i)->dataType();
}
for (IOIndex_t i = 0; i < nbOutputs(); ++i) {
res.second[i] = getOutput(i)->dataType();
}
return res;
}
void Aidge::OperatorTensor::setDataType(const DataType& dataType) const { void Aidge::OperatorTensor::setDataType(const DataType& dataType) const {
for (IOIndex_t i = 0; i < nbOutputs(); ++i) { for (IOIndex_t i = 0; i < nbOutputs(); ++i) {
getOutput(i)->setDataType(dataType); getOutput(i)->setDataType(dataType);
...@@ -222,3 +240,4 @@ void Aidge::OperatorTensor::forward() { ...@@ -222,3 +240,4 @@ void Aidge::OperatorTensor::forward() {
Operator::forward(); Operator::forward();
} }
} // namespace Aidge
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment