Skip to content
Snippets Groups Projects
Commit 8be9cf2f authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Operator binding hotfix

parent f881c84c
No related branches found
No related tags found
No related merge requests found
Pipeline #33213 canceled
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include <pybind11/stl.h> #include <pybind11/stl.h>
#include "aidge/operator/Operator.hpp"
#include "aidge/backend/OperatorImpl.hpp" #include "aidge/backend/OperatorImpl.hpp"
namespace py = pybind11; namespace py = pybind11;
...@@ -22,8 +23,8 @@ namespace Aidge { ...@@ -22,8 +23,8 @@ namespace Aidge {
* *
*/ */
class pyOperatorImpl: public OperatorImpl { class pyOperatorImpl: public OperatorImpl {
public: public:
pyOperatorImpl(){} using OperatorImpl::OperatorImpl; // Inherit constructors
void forward() override { void forward() override {
PYBIND11_OVERRIDE( PYBIND11_OVERRIDE(
...@@ -42,7 +43,7 @@ class pyOperatorImpl: public OperatorImpl { ...@@ -42,7 +43,7 @@ class pyOperatorImpl: public OperatorImpl {
); );
} }
NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const override { NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const override {
PYBIND11_OVERRIDE_PURE_NAME( PYBIND11_OVERRIDE_NAME(
NbElts_t, NbElts_t,
OperatorImpl, OperatorImpl,
"get_nb_required_data", "get_nb_required_data",
...@@ -51,7 +52,7 @@ class pyOperatorImpl: public OperatorImpl { ...@@ -51,7 +52,7 @@ class pyOperatorImpl: public OperatorImpl {
); );
} }
NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override { NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override {
PYBIND11_OVERRIDE_PURE_NAME( PYBIND11_OVERRIDE_NAME(
NbElts_t, NbElts_t,
OperatorImpl, OperatorImpl,
"get_nb_required_protected", "get_nb_required_protected",
...@@ -62,7 +63,7 @@ class pyOperatorImpl: public OperatorImpl { ...@@ -62,7 +63,7 @@ class pyOperatorImpl: public OperatorImpl {
} }
NbElts_t getRequiredMemory(const IOIndex_t outputIdx, NbElts_t getRequiredMemory(const IOIndex_t outputIdx,
const std::vector<DimSize_t> &inputsSize) const override { const std::vector<DimSize_t> &inputsSize) const override {
PYBIND11_OVERRIDE_PURE_NAME( PYBIND11_OVERRIDE_NAME(
NbElts_t, NbElts_t,
OperatorImpl, OperatorImpl,
"get_required_memory", "get_required_memory",
...@@ -73,7 +74,7 @@ class pyOperatorImpl: public OperatorImpl { ...@@ -73,7 +74,7 @@ class pyOperatorImpl: public OperatorImpl {
); );
} }
NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const override { NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const override {
PYBIND11_OVERRIDE_PURE_NAME( PYBIND11_OVERRIDE_NAME(
NbElts_t, NbElts_t,
OperatorImpl, OperatorImpl,
"get_nb_consumed_data", "get_nb_consumed_data",
...@@ -83,7 +84,7 @@ class pyOperatorImpl: public OperatorImpl { ...@@ -83,7 +84,7 @@ class pyOperatorImpl: public OperatorImpl {
); );
} }
NbElts_t getNbProducedData(const IOIndex_t outputIdx) const override { NbElts_t getNbProducedData(const IOIndex_t outputIdx) const override {
PYBIND11_OVERRIDE_PURE_NAME( PYBIND11_OVERRIDE_NAME(
NbElts_t, NbElts_t,
OperatorImpl, OperatorImpl,
"get_nb_produced_data", "get_nb_produced_data",
...@@ -93,7 +94,7 @@ class pyOperatorImpl: public OperatorImpl { ...@@ -93,7 +94,7 @@ class pyOperatorImpl: public OperatorImpl {
); );
} }
void updateConsummerProducer() override { void updateConsummerProducer() override {
PYBIND11_OVERRIDE_PURE_NAME( PYBIND11_OVERRIDE_NAME(
void, void,
OperatorImpl, OperatorImpl,
"update_consummer_producer", "update_consummer_producer",
...@@ -106,7 +107,7 @@ class pyOperatorImpl: public OperatorImpl { ...@@ -106,7 +107,7 @@ class pyOperatorImpl: public OperatorImpl {
void init_OperatorImpl(py::module& m){ void init_OperatorImpl(py::module& m){
py::class_<OperatorImpl, std::shared_ptr<OperatorImpl>, pyOperatorImpl>(m, "OperatorImpl", py::dynamic_attr()) py::class_<OperatorImpl, std::shared_ptr<OperatorImpl>, pyOperatorImpl>(m, "OperatorImpl", py::dynamic_attr())
.def(py::init<>()) .def(py::init<const Operator&>())
.def("forward", &OperatorImpl::forward) .def("forward", &OperatorImpl::forward)
.def("backward", &OperatorImpl::backward) .def("backward", &OperatorImpl::backward)
.def("get_nb_required_data", &OperatorImpl::getNbRequiredData) .def("get_nb_required_data", &OperatorImpl::getNbRequiredData)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment