diff --git a/include/aidge/operator/Clip.hpp b/include/aidge/operator/Clip.hpp index 3bed06cc01a8d2f2060455aec10b443007e60649..9ecb8396eee8603f85f633bef4d42ff33c9770b7 100644 --- a/include/aidge/operator/Clip.hpp +++ b/include/aidge/operator/Clip.hpp @@ -19,28 +19,38 @@ #include "aidge/graph/Node.hpp" #include "aidge/operator/OperatorTensor.hpp" #include "aidge/utils/Registrar.hpp" +#include "aidge/utils/StaticAttributes.hpp" #include "aidge/utils/Types.h" namespace Aidge { +enum class ClipAttr { Min, Max }; + class Clip_Op : public OperatorTensor, public Registrable<Clip_Op, std::string, std::shared_ptr<OperatorImpl>(const Clip_Op&)> { -public: - // FIXME: change accessibility - std::shared_ptr<Tensor> mInput = std::make_shared<Tensor>(); - const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>(); public: static const std::string Type; - Clip_Op() : OperatorTensor(Type, {InputCategory::Data}, 1) {} +private: + using Attributes_ = StaticAttributes<ClipAttr, float, float>; + template <ClipAttr e> using attr = typename Attributes_::template attr<e>; + const std::shared_ptr<Attributes_> mAttributes; + +public: + + Clip_Op(float min, float max) : + OperatorTensor(Type, {InputCategory::Data}, 1), + mAttributes(std::make_shared<Attributes_>(attr<ClipAttr::Min>(min), attr<ClipAttr::Max>(max))) + {} /** * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). * @param op Operator to copy. */ Clip_Op(const Clip_Op& op) - : OperatorTensor(op) + : OperatorTensor(op), + mAttributes(op.mAttributes) { if (op.mImpl){ SET_IMPL_MACRO(Clip_Op, *this, op.backend()); @@ -59,6 +69,10 @@ public: void setBackend(const std::string& name, DeviceIdx_t device = 0) override final; + inline std::shared_ptr<Attributes> attributes() const override { return mAttributes; } + inline float& min() const noexcept { return mAttributes->getAttr<ClipAttr::Min>(); } + inline float& max() const noexcept { return mAttributes->getAttr<ClipAttr::Max>(); } + static const std::vector<std::string> getInputsName(){ return {"data_input"}; } @@ -67,9 +81,19 @@ public: } }; -inline std::shared_ptr<Node> Clip(const std::string& name = "") { - return std::make_shared<Node>(std::make_shared<Clip_Op>(), name); +inline std::shared_ptr<Node> Clip(float min = 0.0f, + float max = 1.0f, + const std::string& name = "") +{ + return std::make_shared<Node>(std::make_shared<Clip_Op>(min, max), name); +} } + +namespace { +template <> +const char* const EnumStrings<Aidge::ClipAttr>::data[] + = {"Min", "Max"}; } + #endif /* AIDGE_CORE_OPERATOR_CLIP_H_ */ diff --git a/python_binding/operator/pybind_Clip.cpp b/python_binding/operator/pybind_Clip.cpp index b9af498a275b5a2358ae1da633bee0be498a6310..3679297585b1dbee815ebce4f5f814a7f2b0e023 100644 --- a/python_binding/operator/pybind_Clip.cpp +++ b/python_binding/operator/pybind_Clip.cpp @@ -20,10 +20,10 @@ namespace Aidge { void init_Clip(py::module& m) { py::class_<Clip_Op, std::shared_ptr<Clip_Op>, OperatorTensor>(m, "ClipOp", py::multiple_inheritance()) - .def(py::init<>()) + .def(py::init<float, float>(), py::arg("min"), py::arg("max")) .def_static("get_inputs_name", &Clip_Op::getInputsName) .def_static("get_outputs_name", &Clip_Op::getOutputsName); declare_registrable<Clip_Op>(m, "ClipOp"); - m.def("Clip", &Clip, py::arg("name") = ""); + m.def("Clip", &Clip, py::arg("min") = 0.0f, py::arg("max") = 1.0f, py::arg("name") = ""); } } // namespace Aidge