Skip to content
Snippets Groups Projects
Commit af1ba81f authored by Grégoire Kubler's avatar Grégoire Kubler
Browse files

feat : add support for py registrar to GlobalAveragePooling_Op

parent bc065811
No related branches found
No related tags found
No related merge requests found
...@@ -34,7 +34,7 @@ namespace Aidge { ...@@ -34,7 +34,7 @@ namespace Aidge {
class GlobalAveragePooling_Op class GlobalAveragePooling_Op
: public OperatorTensor, : public OperatorTensor,
public Registrable<GlobalAveragePooling_Op, std::string, public Registrable<GlobalAveragePooling_Op, std::string,
std::unique_ptr<OperatorImpl>( std::shared_ptr<OperatorImpl>(
const GlobalAveragePooling_Op &)> { const GlobalAveragePooling_Op &)> {
public: public:
static const std::string Type; static const std::string Type;
...@@ -43,9 +43,11 @@ public: ...@@ -43,9 +43,11 @@ public:
GlobalAveragePooling_Op(const GlobalAveragePooling_Op &op) GlobalAveragePooling_Op(const GlobalAveragePooling_Op &op)
: OperatorTensor(op) { : OperatorTensor(op) {
mImpl = op.mImpl ? Registrar<GlobalAveragePooling_Op>::create( if (op.mImpl){
op.mOutputs[0]->getImpl()->backend())(*this) SET_IMPL_MACRO(GlobalAveragePooling_Op, *this, op.mOutputs[0]->getImpl()->backend());
: nullptr; }else{
mImpl = nullptr;
}
} }
std::shared_ptr<Operator> clone() const override { std::shared_ptr<Operator> clone() const override {
...@@ -55,7 +57,7 @@ public: ...@@ -55,7 +57,7 @@ public:
void computeOutputDims() override final; void computeOutputDims() override final;
void setBackend(const std::string &name, DeviceIdx_t device = 0) override { void setBackend(const std::string &name, DeviceIdx_t device = 0) override {
mImpl = Registrar<GlobalAveragePooling_Op>::create(name)(*this); mImpl = SET_IMPL_MACRO(GlobalAveragePooling_Op, *this, name);
mOutputs[0]->setBackend(name, device); mOutputs[0]->setBackend(name, device);
} }
......
...@@ -18,13 +18,14 @@ ...@@ -18,13 +18,14 @@
namespace py = pybind11; namespace py = pybind11;
namespace Aidge { namespace Aidge {
const std::string pyClassName("GlobalAveragePoolingOp");
void init_GlobalAveragePooling(py::module &m) { void init_GlobalAveragePooling(py::module &m) {
py::class_<GlobalAveragePooling_Op, std::shared_ptr<GlobalAveragePooling_Op>, py::class_<GlobalAveragePooling_Op, std::shared_ptr<GlobalAveragePooling_Op>,
OperatorTensor>(m, "GlobalAveragePooling", OperatorTensor>(m, pyClassName.c_str,
py::multiple_inheritance()) py::multiple_inheritance())
.def("get_inputs_name", &GlobalAveragePooling_Op::getInputsName) .def("get_inputs_name", &GlobalAveragePooling_Op::getInputsName)
.def("get_outputs_name", &GlobalAveragePooling_Op::getOutputsName); .def("get_outputs_name", &GlobalAveragePooling_Op::getOutputsName);
declare_registrable<GlobalAveragePooling_Op>(m, pyClassName);
m.def("globalaveragepooling", &GlobalAveragePooling, py::arg("name") = ""); m.def("globalaveragepooling", &GlobalAveragePooling, py::arg("name") = "");
} }
} // namespace Aidge } // namespace Aidge
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment