Skip to content
Snippets Groups Projects
Commit 504077b9 authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

add axis attr to Softmax

parent b05a6d63
No related branches found
No related tags found
No related merge requests found
...@@ -16,18 +16,24 @@ ...@@ -16,18 +16,24 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "aidge/utils/Registrar.hpp"
#include "aidge/operator/Operator.hpp"
#include "aidge/backend/OperatorImpl.hpp" #include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Tensor.hpp" #include "aidge/data/Tensor.hpp"
#include "aidge/data/Data.hpp" #include "aidge/data/Data.hpp"
#include "aidge/graph/Node.hpp" #include "aidge/graph/Node.hpp"
#include "aidge/operator/Operator.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/StaticAttributes.hpp"
#include "aidge/utils/Types.h" #include "aidge/utils/Types.h"
namespace Aidge { namespace Aidge {
enum class SoftmaxAttr { AxisIdx };
class Softmax_Op : public Operator, class Softmax_Op : public Operator,
public Registrable<Softmax_Op, std::string, std::unique_ptr<OperatorImpl>(const Softmax_Op&)> { public Registrable<Softmax_Op,
std::string,
std::unique_ptr<OperatorImpl>(const Softmax_Op&)>,
public StaticAttributes<SoftmaxAttr, int> {
public: public:
// FIXME: change accessibility // FIXME: change accessibility
std::shared_ptr<Tensor> mInput = std::make_shared<Tensor>(); std::shared_ptr<Tensor> mInput = std::make_shared<Tensor>();
...@@ -36,8 +42,14 @@ public: ...@@ -36,8 +42,14 @@ public:
public: public:
static constexpr const char* Type = "Softmax"; static constexpr const char* Type = "Softmax";
Softmax_Op() Softmax_Op() = delete;
: Operator(Type)
using Attributes_ = StaticAttributes<SoftmaxAttr, int>;
template <SoftmaxAttr e> using attr = typename Attributes_::template attr<e>;
Softmax_Op(int axis)
: Operator(Type),
Attributes_(
attr<SoftmaxAttr::AxisIdx>(axis))
{ {
setDatatype(DataType::Float32); setDatatype(DataType::Float32);
} }
...@@ -48,6 +60,7 @@ public: ...@@ -48,6 +60,7 @@ public:
*/ */
Softmax_Op(const Softmax_Op& op) Softmax_Op(const Softmax_Op& op)
: Operator(Type), : Operator(Type),
Attributes_(op),
mOutput(std::make_shared<Tensor>(*op.mOutput)) mOutput(std::make_shared<Tensor>(*op.mOutput))
{ {
// cpy-ctor // cpy-ctor
...@@ -64,7 +77,7 @@ public: ...@@ -64,7 +77,7 @@ public:
} }
void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final {
assert(inputIdx == 0 && "operator supports only 1 input"); assert(inputIdx == 0 && "Softmax operator supports only 1 input");
(void) inputIdx; // avoid unused warning (void) inputIdx; // avoid unused warning
assert(strcmp(data->type(), Tensor::Type)==0 && "input data must be of Tensor type"); assert(strcmp(data->type(), Tensor::Type)==0 && "input data must be of Tensor type");
mInput = std::dynamic_pointer_cast<Tensor>(data); mInput = std::dynamic_pointer_cast<Tensor>(data);
...@@ -85,24 +98,23 @@ public: ...@@ -85,24 +98,23 @@ public:
inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final { inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final {
assert((inputIdx == 0) && "Softmax Operator has only 1 input"); assert((inputIdx == 0) && "Softmax operator has only 1 input");
(void) inputIdx; // avoid unused warning (void) inputIdx; // avoid unused warning
return mInput; return mInput;
} }
inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final { inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final {
assert((outputIdx == 0) && "Softmax Operator has only 1 output"); assert((outputIdx == 0) && "Softmax operator has only 1 output");
(void) outputIdx; // avoid unused warning (void) outputIdx; // avoid unused warning
return mOutput; return mOutput;
} }
std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final { std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final {
assert(inputIdx == 0 && "operator supports only 1 input"); assert(inputIdx == 0 && "Softmax operator supports only 1 input");
(void) inputIdx; // avoid unused warning (void) inputIdx; // avoid unused warning
return std::static_pointer_cast<Data>(mInput); return std::static_pointer_cast<Data>(mInput);
} }
std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final { std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final {
assert(outputIdx == 0 && "operator supports only 1 output"); assert(outputIdx == 0 && "Softmax operator supports only 1 output");
(void) outputIdx; // avoid unused warning (void) outputIdx; // avoid unused warning
return std::static_pointer_cast<Data>(mOutput); return std::static_pointer_cast<Data>(mOutput);
} }
...@@ -133,9 +145,14 @@ public: ...@@ -133,9 +145,14 @@ public:
} }
}; };
inline std::shared_ptr<Node> Softmax(const std::string& name = "") { inline std::shared_ptr<Node> Softmax(int axis, const std::string& name = "") {
return std::make_shared<Node>(std::make_shared<Softmax_Op>(), name); return std::make_shared<Node>(std::make_shared<Softmax_Op>(axis), name);
} }
} // namespace Aidge
namespace {
template <>
const char *const EnumStrings<Aidge::SoftmaxAttr>::data[] = {"Axis"};
} }
#endif /* AIDGE_CORE_OPERATOR_SOFTMAX_H_ */ #endif /* AIDGE_CORE_OPERATOR_SOFTMAX_H_ */
...@@ -19,10 +19,10 @@ namespace py = pybind11; ...@@ -19,10 +19,10 @@ namespace py = pybind11;
namespace Aidge { namespace Aidge {
void init_Softmax(py::module& m) { void init_Softmax(py::module& m) {
py::class_<Softmax_Op, std::shared_ptr<Softmax_Op>, Operator>(m, "SoftmaxOp", py::multiple_inheritance()) py::class_<Softmax_Op, std::shared_ptr<Softmax_Op>, Operator, Attributes>(m, "SoftmaxOp", py::multiple_inheritance())
.def("get_inputs_name", &Softmax_Op::getInputsName) .def("get_inputs_name", &Softmax_Op::getInputsName)
.def("get_outputs_name", &Softmax_Op::getOutputsName); .def("get_outputs_name", &Softmax_Op::getOutputsName);
m.def("Softmax", &Softmax, py::arg("name") = ""); m.def("Softmax", &Softmax, py::arg("axis"), py::arg("name") = "");
} }
} // namespace Aidge } // namespace Aidge
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment