diff --git a/include/aidge/operator/Cast.hpp b/include/aidge/operator/Cast.hpp index 6efbc0a214dde3ca969226f734b5ee903fe5ab50..5a9ba31a00089c162e5fa933bf10e13779c508df 100644 --- a/include/aidge/operator/Cast.hpp +++ b/include/aidge/operator/Cast.hpp @@ -21,6 +21,7 @@ #include "aidge/backend/OperatorImpl.hpp" #include "aidge/data/Tensor.hpp" #include "aidge/graph/Node.hpp" +#include "aidge/utils/StaticAttributes.hpp" #include "aidge/utils/Types.h" namespace Aidge { @@ -30,13 +31,29 @@ public: void forward() override; }; +enum class CastAttr { TargetType }; + class Cast_Op : public OperatorTensor, public Registrable<Cast_Op, std::string, std::unique_ptr<OperatorImpl>(const Cast_Op&)> { public: static const std::string Type; - Cast_Op() : OperatorTensor(Type, 1, 0, 1) { +private: + using Attributes_ = StaticAttributes<CastAttr, DataType>; + template <CastAttr e> + using attr = typename Attributes_::template attr<e>; + const std::shared_ptr<Attributes_> mAttributes; + +public: + Cast_Op() = delete; + + Cast_Op(const DataType targetType) + : OperatorTensor(Type, 1, 0, 1), + mAttributes(std::make_shared<Attributes_>( + attr<CastAttr::TargetType>(targetType))) + { mImpl = std::make_shared<Cast_OpImpl>(*this); + mOutputs[0]->setDataType(targetType); } /** @@ -44,7 +61,8 @@ public: * @param op Operator to copy. */ Cast_Op(const Cast_Op& op) - : OperatorTensor(op) + : OperatorTensor(op), + mAttributes(op.mAttributes) { if (!op.backend().empty()) { SET_IMPL_MACRO(Cast_Op, *this, op.backend()); @@ -64,6 +82,9 @@ public: void setBackend(const std::string& name, DeviceIdx_t device = 0) override; + inline std::shared_ptr<Attributes> attributes() const override { return mAttributes; } + inline DataType& targetType() const { return mAttributes->template getAttr<CastAttr::TargetType>(); } + static const std::vector<std::string> getInputsName(){ return {"data_input"}; } @@ -72,9 +93,16 @@ public: } }; -inline std::shared_ptr<Node> Cast(const std::string& name = "") { - return std::make_shared<Node>(std::make_shared<Cast_Op>(), name); + +inline std::shared_ptr<Node> Cast(const DataType targetType, const std::string& name = "") { + return std::make_shared<Node>(std::make_shared<Cast_Op>(targetType), name); } + +namespace { +template <> +const char* const EnumStrings<Aidge::CastAttr>::data[] = { "TargetType" }; } +} // namespace Aidge + #endif /* AIDGE_CORE_OPERATOR_CAST_H_ */ \ No newline at end of file