Skip to content
Snippets Groups Projects
Commit 0e89cde0 authored by Grégoire Kubler's avatar Grégoire Kubler
Browse files

feat : added getDataType() function to OperatorTensor

parent c1c35a65
No related branches found
No related tags found
1 merge request!319feat_operator_convtranspose
Pipeline #67337 waiting for manual action
......@@ -237,6 +237,13 @@ public:
*/
void setBackend(const std::vector<std::pair<std::string, DeviceIdx_t>>& backends);
/**
* @brief gets the data type of the operator's tensors.
* @return a pair whose first object contains inputs' data types
* and second object outputs' data types
*/
virtual std::pair<std::vector<Aidge::DataType>, std::vector<Aidge::DataType>>
getDataType() const = 0;
virtual void setDataType(const DataType& dataType) const = 0;
virtual void setDataFormat(const DataFormat& dataFormat) const = 0;
/**
......
......@@ -179,8 +179,18 @@ public:
virtual bool dimsForwarded() const;
///////////////////////////////////////////////////
/**
* @brief gets the data type of the operator's tensors.
* @return a pair whose first object contains inputs' data types
* and second object outputs' data types
*/
std::pair<std::vector<Aidge::DataType>, std::vector<Aidge::DataType>>
getDataType() const;
/**
* @brief Sets the data type of the operator's tensors.
* @warning Sets all outputs but only inputs of category
* @code InputCategory::Param @endcode & @code InputCategory::OptionnalParam @endcode
* @param dataType Data type to set.
*/
virtual void setDataType(const DataType& dataType) const override;
......
......@@ -10,13 +10,18 @@
********************************************************************************/
#include <memory>
#include <utility>
#include <vector>
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/data/Data.hpp"
#include "aidge/data/DataType.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/utils/Types.h"
#include "aidge/utils/ErrorHandling.hpp"
namespace Aidge{
using std::make_pair;
Aidge::OperatorTensor::OperatorTensor(const std::string& type,
const std::vector<InputCategory>& inputsCategory,
......@@ -180,6 +185,19 @@ bool Aidge::OperatorTensor::dimsForwarded() const {
return forwarded;
}
std::pair<std::vector<DataType>, std::vector<DataType>>
OperatorTensor::getDataType() const {
auto res = make_pair(std::vector<DataType>(nbInputs()),
std::vector<DataType>(nbOutputs()));
for (IOIndex_t i = 0; i < nbInputs(); ++i) {
res.first[i] = getOutput(i)->dataType();
}
for (IOIndex_t i = 0; i < nbOutputs(); ++i) {
res.second[i] = getOutput(i)->dataType();
}
return res;
}
void Aidge::OperatorTensor::setDataType(const DataType& dataType) const {
for (IOIndex_t i = 0; i < nbOutputs(); ++i) {
getOutput(i)->setDataType(dataType);
......@@ -222,3 +240,4 @@ void Aidge::OperatorTensor::forward() {
Operator::forward();
}
} // namespace Aidge
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment