diff --git a/include/aidge/operator/Softmax.hpp b/include/aidge/operator/Softmax.hpp index ba6132a5ee00325d0f7de57db117a169d42352e9..db078a6f1677c5dfc09035d384eeb304324cebcb 100644 --- a/include/aidge/operator/Softmax.hpp +++ b/include/aidge/operator/Softmax.hpp @@ -16,18 +16,24 @@ #include <memory> #include <vector> -#include "aidge/utils/Registrar.hpp" -#include "aidge/operator/Operator.hpp" #include "aidge/backend/OperatorImpl.hpp" #include "aidge/data/Tensor.hpp" #include "aidge/data/Data.hpp" #include "aidge/graph/Node.hpp" +#include "aidge/operator/Operator.hpp" +#include "aidge/operator/Producer.hpp" +#include "aidge/utils/Registrar.hpp" +#include "aidge/utils/StaticAttributes.hpp" #include "aidge/utils/Types.h" namespace Aidge { +enum class SoftmaxAttr { AxisIdx }; class Softmax_Op : public Operator, - public Registrable<Softmax_Op, std::string, std::unique_ptr<OperatorImpl>(const Softmax_Op&)> { + public Registrable<Softmax_Op, + std::string, + std::unique_ptr<OperatorImpl>(const Softmax_Op&)>, + public StaticAttributes<SoftmaxAttr, int> { public: // FIXME: change accessibility std::shared_ptr<Tensor> mInput = std::make_shared<Tensor>(); @@ -36,8 +42,14 @@ public: public: static constexpr const char* Type = "Softmax"; - Softmax_Op() - : Operator(Type) + Softmax_Op() = delete; + + using Attributes_ = StaticAttributes<SoftmaxAttr, int>; + template <SoftmaxAttr e> using attr = typename Attributes_::template attr<e>; + Softmax_Op(int axis) + : Operator(Type), + Attributes_( + attr<SoftmaxAttr::AxisIdx>(axis)) { setDatatype(DataType::Float32); } @@ -48,6 +60,7 @@ public: */ Softmax_Op(const Softmax_Op& op) : Operator(Type), + Attributes_(op), mOutput(std::make_shared<Tensor>(*op.mOutput)) { // cpy-ctor @@ -64,7 +77,7 @@ public: } void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { - assert(inputIdx == 0 && "operator supports only 1 input"); + assert(inputIdx == 0 && "Softmax operator supports only 1 input"); (void) inputIdx; // avoid unused warning assert(strcmp(data->type(), Tensor::Type)==0 && "input data must be of Tensor type"); mInput = std::dynamic_pointer_cast<Tensor>(data); @@ -85,24 +98,23 @@ public: inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final { - assert((inputIdx == 0) && "Softmax Operator has only 1 input"); + assert((inputIdx == 0) && "Softmax operator has only 1 input"); (void) inputIdx; // avoid unused warning return mInput; } inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final { - assert((outputIdx == 0) && "Softmax Operator has only 1 output"); + assert((outputIdx == 0) && "Softmax operator has only 1 output"); (void) outputIdx; // avoid unused warning return mOutput; } - std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final { - assert(inputIdx == 0 && "operator supports only 1 input"); + assert(inputIdx == 0 && "Softmax operator supports only 1 input"); (void) inputIdx; // avoid unused warning return std::static_pointer_cast<Data>(mInput); } std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final { - assert(outputIdx == 0 && "operator supports only 1 output"); + assert(outputIdx == 0 && "Softmax operator supports only 1 output"); (void) outputIdx; // avoid unused warning return std::static_pointer_cast<Data>(mOutput); } @@ -133,9 +145,14 @@ public: } }; -inline std::shared_ptr<Node> Softmax(const std::string& name = "") { - return std::make_shared<Node>(std::make_shared<Softmax_Op>(), name); +inline std::shared_ptr<Node> Softmax(int axis, const std::string& name = "") { + return std::make_shared<Node>(std::make_shared<Softmax_Op>(axis), name); } +} // namespace Aidge + +namespace { +template <> +const char *const EnumStrings<Aidge::SoftmaxAttr>::data[] = {"Axis"}; } #endif /* AIDGE_CORE_OPERATOR_SOFTMAX_H_ */ diff --git a/python_binding/operator/pybind_Softmax.cpp b/python_binding/operator/pybind_Softmax.cpp index 8e50ab7c83bf43285b357cb803c0ce3eb42f4cc7..38aaa4dba443e0691bcddeae6f619bb505963163 100644 --- a/python_binding/operator/pybind_Softmax.cpp +++ b/python_binding/operator/pybind_Softmax.cpp @@ -19,10 +19,10 @@ namespace py = pybind11; namespace Aidge { void init_Softmax(py::module& m) { - py::class_<Softmax_Op, std::shared_ptr<Softmax_Op>, Operator>(m, "SoftmaxOp", py::multiple_inheritance()) + py::class_<Softmax_Op, std::shared_ptr<Softmax_Op>, Operator, Attributes>(m, "SoftmaxOp", py::multiple_inheritance()) .def("get_inputs_name", &Softmax_Op::getInputsName) .def("get_outputs_name", &Softmax_Op::getOutputsName); - m.def("Softmax", &Softmax, py::arg("name") = ""); + m.def("Softmax", &Softmax, py::arg("axis"), py::arg("name") = ""); } } // namespace Aidge