Skip to content
Snippets Groups Projects
Commit 504077b9 authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

add axis attr to Softmax

parent b05a6d63
No related branches found
No related tags found
2 merge requests!59Improvements and fixes,!47Vit operators
......@@ -16,18 +16,24 @@
#include <memory>
#include <vector>
#include "aidge/utils/Registrar.hpp"
#include "aidge/operator/Operator.hpp"
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/data/Data.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/operator/Operator.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/StaticAttributes.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
enum class SoftmaxAttr { AxisIdx };
class Softmax_Op : public Operator,
public Registrable<Softmax_Op, std::string, std::unique_ptr<OperatorImpl>(const Softmax_Op&)> {
public Registrable<Softmax_Op,
std::string,
std::unique_ptr<OperatorImpl>(const Softmax_Op&)>,
public StaticAttributes<SoftmaxAttr, int> {
public:
// FIXME: change accessibility
std::shared_ptr<Tensor> mInput = std::make_shared<Tensor>();
......@@ -36,8 +42,14 @@ public:
public:
static constexpr const char* Type = "Softmax";
Softmax_Op()
: Operator(Type)
Softmax_Op() = delete;
using Attributes_ = StaticAttributes<SoftmaxAttr, int>;
template <SoftmaxAttr e> using attr = typename Attributes_::template attr<e>;
Softmax_Op(int axis)
: Operator(Type),
Attributes_(
attr<SoftmaxAttr::AxisIdx>(axis))
{
setDatatype(DataType::Float32);
}
......@@ -48,6 +60,7 @@ public:
*/
Softmax_Op(const Softmax_Op& op)
: Operator(Type),
Attributes_(op),
mOutput(std::make_shared<Tensor>(*op.mOutput))
{
// cpy-ctor
......@@ -64,7 +77,7 @@ public:
}
void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final {
assert(inputIdx == 0 && "operator supports only 1 input");
assert(inputIdx == 0 && "Softmax operator supports only 1 input");
(void) inputIdx; // avoid unused warning
assert(strcmp(data->type(), Tensor::Type)==0 && "input data must be of Tensor type");
mInput = std::dynamic_pointer_cast<Tensor>(data);
......@@ -85,24 +98,23 @@ public:
inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final {
assert((inputIdx == 0) && "Softmax Operator has only 1 input");
assert((inputIdx == 0) && "Softmax operator has only 1 input");
(void) inputIdx; // avoid unused warning
return mInput;
}
inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final {
assert((outputIdx == 0) && "Softmax Operator has only 1 output");
assert((outputIdx == 0) && "Softmax operator has only 1 output");
(void) outputIdx; // avoid unused warning
return mOutput;
}
std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final {
assert(inputIdx == 0 && "operator supports only 1 input");
assert(inputIdx == 0 && "Softmax operator supports only 1 input");
(void) inputIdx; // avoid unused warning
return std::static_pointer_cast<Data>(mInput);
}
std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final {
assert(outputIdx == 0 && "operator supports only 1 output");
assert(outputIdx == 0 && "Softmax operator supports only 1 output");
(void) outputIdx; // avoid unused warning
return std::static_pointer_cast<Data>(mOutput);
}
......@@ -133,9 +145,14 @@ public:
}
};
inline std::shared_ptr<Node> Softmax(const std::string& name = "") {
return std::make_shared<Node>(std::make_shared<Softmax_Op>(), name);
inline std::shared_ptr<Node> Softmax(int axis, const std::string& name = "") {
return std::make_shared<Node>(std::make_shared<Softmax_Op>(axis), name);
}
} // namespace Aidge
namespace {
template <>
const char *const EnumStrings<Aidge::SoftmaxAttr>::data[] = {"Axis"};
}
#endif /* AIDGE_CORE_OPERATOR_SOFTMAX_H_ */
......@@ -19,10 +19,10 @@ namespace py = pybind11;
namespace Aidge {
void init_Softmax(py::module& m) {
py::class_<Softmax_Op, std::shared_ptr<Softmax_Op>, Operator>(m, "SoftmaxOp", py::multiple_inheritance())
py::class_<Softmax_Op, std::shared_ptr<Softmax_Op>, Operator, Attributes>(m, "SoftmaxOp", py::multiple_inheritance())
.def("get_inputs_name", &Softmax_Op::getInputsName)
.def("get_outputs_name", &Softmax_Op::getOutputsName);
m.def("Softmax", &Softmax, py::arg("name") = "");
m.def("Softmax", &Softmax, py::arg("axis"), py::arg("name") = "");
}
} // namespace Aidge
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment