Skip to content
Snippets Groups Projects
Commit 593de6f3 authored by Maxence Naud's avatar Maxence Naud
Browse files

[Enhance] add TargetType paramter to Cast Operator

Move StaticAttribute from Base class to member attribute
parent af845fe7
No related branches found
No related tags found
2 merge requests!152Update Aidge export to take a graph view has an argument instead of a...,!145Improve UI for Operator/Node/GraphView/Tensor
......@@ -21,6 +21,7 @@
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/utils/StaticAttributes.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
......@@ -30,13 +31,29 @@ public:
void forward() override;
};
enum class CastAttr { TargetType };
class Cast_Op : public OperatorTensor,
public Registrable<Cast_Op, std::string, std::unique_ptr<OperatorImpl>(const Cast_Op&)> {
public:
static const std::string Type;
Cast_Op() : OperatorTensor(Type, 1, 0, 1) {
private:
using Attributes_ = StaticAttributes<CastAttr, DataType>;
template <CastAttr e>
using attr = typename Attributes_::template attr<e>;
const std::shared_ptr<Attributes_> mAttributes;
public:
Cast_Op() = delete;
Cast_Op(const DataType targetType)
: OperatorTensor(Type, 1, 0, 1),
mAttributes(std::make_shared<Attributes_>(
attr<CastAttr::TargetType>(targetType)))
{
mImpl = std::make_shared<Cast_OpImpl>(*this);
mOutputs[0]->setDataType(targetType);
}
/**
......@@ -44,7 +61,8 @@ public:
* @param op Operator to copy.
*/
Cast_Op(const Cast_Op& op)
: OperatorTensor(op)
: OperatorTensor(op),
mAttributes(op.mAttributes)
{
if (!op.backend().empty()) {
SET_IMPL_MACRO(Cast_Op, *this, op.backend());
......@@ -64,6 +82,9 @@ public:
void setBackend(const std::string& name, DeviceIdx_t device = 0) override;
inline std::shared_ptr<Attributes> attributes() const override { return mAttributes; }
inline DataType& targetType() const { return mAttributes->template getAttr<CastAttr::TargetType>(); }
static const std::vector<std::string> getInputsName(){
return {"data_input"};
}
......@@ -72,9 +93,16 @@ public:
}
};
inline std::shared_ptr<Node> Cast(const std::string& name = "") {
return std::make_shared<Node>(std::make_shared<Cast_Op>(), name);
inline std::shared_ptr<Node> Cast(const DataType targetType, const std::string& name = "") {
return std::make_shared<Node>(std::make_shared<Cast_Op>(targetType), name);
}
namespace {
template <>
const char* const EnumStrings<Aidge::CastAttr>::data[] = { "TargetType" };
}
} // namespace Aidge
#endif /* AIDGE_CORE_OPERATOR_CAST_H_ */
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment