diff --git a/include/aidge/operator/Cast.hpp b/include/aidge/operator/Cast.hpp index fd05c96a296fcbde390e328bf0a5fe4b3ee52119..cd37e47d84c9a03d047b38d6e5cd4f8e84423d2d 100644 --- a/include/aidge/operator/Cast.hpp +++ b/include/aidge/operator/Cast.hpp @@ -123,6 +123,13 @@ public: */ inline DataType& targetType() const { return mAttributes->template getAttr<CastAttr::TargetType>(); } + /** + * @brief Sets the data type of the operator's tensors. + * @details This method needs to be overwritten because Cast_Op's output type needs to be set from targetType + * @param dataType Data type to set. + */ + virtual void setDataType(const DataType& dataType) const override; + /** * @brief Get the input tensor names for the Cast operator. * @return A vector containing the input tensor names. diff --git a/src/operator/Cast.cpp b/src/operator/Cast.cpp index 3c22939581353afe3df62084999e1d0ce968ea9a..128868dcd3f39c16316b67a09abe471c47b2df33 100644 --- a/src/operator/Cast.cpp +++ b/src/operator/Cast.cpp @@ -49,6 +49,11 @@ Cast_Op::Cast_Op(const Cast_Op& op) } } +void Aidge::Cast_Op::setDataType(const DataType& dataType) const { + if (targetType() != dataType) { + Log::warn("Cast::setDataType(): Cannot setDataType for cast operator."); + } +} void Cast_Op::setBackend(const std::string& name, DeviceIdx_t device) { if (Registrar<Cast_Op>::exists({name})) {