From 1262618bd79d70068bf56762bc68d2059edc0a03 Mon Sep 17 00:00:00 2001 From: cmoineau <cyril.moineau@cea.fr> Date: Wed, 4 Oct 2023 09:32:29 +0000 Subject: [PATCH] [OperatorImpl] Bind OperatorImpl so that it can be extended on Python side + add Operator.setImpl method. --- include/aidge/operator/GenericOperator.hpp | 16 ++- include/aidge/operator/Operator.hpp | 10 +- .../backend/pybind_OperatorImpl.cpp | 100 +++++++++++++++++- python_binding/operator/pybind_Operator.cpp | 3 + 4 files changed, 125 insertions(+), 4 deletions(-) diff --git a/include/aidge/operator/GenericOperator.hpp b/include/aidge/operator/GenericOperator.hpp index 184100174..dc70646fb 100644 --- a/include/aidge/operator/GenericOperator.hpp +++ b/include/aidge/operator/GenericOperator.hpp @@ -206,8 +206,20 @@ class GenericOperator_Op void setBackend(const std::string & /*name*/) { printf("setBackend: not available yet.\n"); } void setDatatype(const DataType & /*datatype*/) { printf("setDatatype: not available yet.\n"); } - void forward() override final { printf("forward: not available yet.\n"); } - void backward() override final { printf("backward: not available yet.\n"); } + void forward() override final { + if(mImpl){ + mImpl->forward(); + }else{ + printf("forward: No implementation is linked.\n"); + } + } + void backward() override final { + if(mImpl){ + mImpl->backward(); + }else{ + printf("backward: No implementation is linked.\n"); + } + } inline IOIndex_t nbInputs() const noexcept override final { return mNbIn; }; inline IOIndex_t nbDataInputs() const noexcept override final { return mNbDataIn; }; diff --git a/include/aidge/operator/Operator.hpp b/include/aidge/operator/Operator.hpp index 3ac651cfd..2582f83c4 100644 --- a/include/aidge/operator/Operator.hpp +++ b/include/aidge/operator/Operator.hpp @@ -26,7 +26,7 @@ namespace Aidge { class Operator : public std::enable_shared_from_this<Operator> { protected: - std::unique_ptr<OperatorImpl> mImpl; // implementation of the operator + std::shared_ptr<OperatorImpl> mImpl; // implementation of the operator std::map<std::string, std::shared_ptr<Hook>> mHooks; private: @@ -76,6 +76,14 @@ public: virtual void setBackend(const std::string& name) = 0; virtual void setDatatype(const DataType& datatype) = 0; + /** + * @brief Set the a new OperatorImpl to the Operator + * + */ + void setImpl(std::shared_ptr<OperatorImpl> impl){ + mImpl = impl; + } + /** * @brief Minimum amount of data from a specific input for one computation pass. * @param inputIdx Index of the input analysed. diff --git a/python_binding/backend/pybind_OperatorImpl.cpp b/python_binding/backend/pybind_OperatorImpl.cpp index 11189f2f3..3a104f162 100644 --- a/python_binding/backend/pybind_OperatorImpl.cpp +++ b/python_binding/backend/pybind_OperatorImpl.cpp @@ -14,7 +14,105 @@ namespace py = pybind11; namespace Aidge { + +/** + * @brief Trampoline class for binding + * + */ +class pyOperatorImpl: public OperatorImpl { + public: + pyOperatorImpl(){} + + void forward() override { + PYBIND11_OVERRIDE( + void, + OperatorImpl, + forward, + + ); + } + void backward() override { + PYBIND11_OVERRIDE( + void, + OperatorImpl, + backward, + + ); + } + NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const override { + PYBIND11_OVERRIDE_PURE_NAME( + NbElts_t, + OperatorImpl, + "get_nb_required_data", + getNbRequiredData, + inputIdx + ); + } + NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override { + PYBIND11_OVERRIDE_PURE_NAME( + NbElts_t, + OperatorImpl, + "get_nb_required_protected", + getNbRequiredProtected, + inputIdx + + ); + } + NbElts_t getRequiredMemory(const IOIndex_t outputIdx, + const std::vector<DimSize_t> &inputsSize) const override { + PYBIND11_OVERRIDE_PURE_NAME( + NbElts_t, + OperatorImpl, + "get_required_memory", + getRequiredMemory, + outputIdx, + inputsSize + + ); + } + NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const override { + PYBIND11_OVERRIDE_PURE_NAME( + NbElts_t, + OperatorImpl, + "get_nb_consumed_data", + getNbConsumedData, + inputIdx + + ); + } + NbElts_t getNbProducedData(const IOIndex_t outputIdx) const override { + PYBIND11_OVERRIDE_PURE_NAME( + NbElts_t, + OperatorImpl, + "get_nb_produced_data", + getNbProducedData, + outputIdx + + ); + } + void updateConsummerProducer() override { + PYBIND11_OVERRIDE_PURE_NAME( + void, + OperatorImpl, + "update_consummer_producer", + updateConsummerProducer, + + ); + } +}; + void init_OperatorImpl(py::module& m){ - py::class_<OperatorImpl, std::shared_ptr<OperatorImpl>>(m, "OperatorImpl"); + + py::class_<OperatorImpl, std::shared_ptr<OperatorImpl>, pyOperatorImpl>(m, "OperatorImpl", py::dynamic_attr()) + .def(py::init<>()) + .def("forward", &OperatorImpl::forward) + .def("backward", &OperatorImpl::backward) + .def("get_nb_required_data", &OperatorImpl::getNbRequiredData) + .def("get_nb_required_protected", &OperatorImpl::getNbRequiredProtected) + .def("get_required_memory", &OperatorImpl::getRequiredMemory) + .def("get_nb_consumed_data", &OperatorImpl::getNbConsumedData) + .def("get_nb_produced_data", &OperatorImpl::getNbProducedData) + .def("update_consummer_producer", &OperatorImpl::updateConsummerProducer) + ; } } diff --git a/python_binding/operator/pybind_Operator.cpp b/python_binding/operator/pybind_Operator.cpp index ac9a34e0a..b0866970c 100644 --- a/python_binding/operator/pybind_Operator.cpp +++ b/python_binding/operator/pybind_Operator.cpp @@ -23,6 +23,9 @@ void init_Operator(py::module& m){ .def("associate_input", &Operator::associateInput, py::arg("inputIdx"), py::arg("data")) .def("set_datatype", &Operator::setDatatype, py::arg("datatype")) .def("set_backend", &Operator::setBackend, py::arg("name")) + .def("forward", &Operator::forward) + // py::keep_alive forbide Python to garbage collect implementation will the Operator is not garbade collected ! + .def("set_impl", &Operator::setImpl, py::arg("implementation"), py::keep_alive<1, 2>()) ; } } -- GitLab