diff --git a/include/aidge/operator/GenericOperator.hpp b/include/aidge/operator/GenericOperator.hpp index 184100174714df5fc059e374cb85549f6bfd4135..dc70646fbc8d53059ba04ff757fa7c3e22dcc6d0 100644 --- a/include/aidge/operator/GenericOperator.hpp +++ b/include/aidge/operator/GenericOperator.hpp @@ -206,8 +206,20 @@ class GenericOperator_Op void setBackend(const std::string & /*name*/) { printf("setBackend: not available yet.\n"); } void setDatatype(const DataType & /*datatype*/) { printf("setDatatype: not available yet.\n"); } - void forward() override final { printf("forward: not available yet.\n"); } - void backward() override final { printf("backward: not available yet.\n"); } + void forward() override final { + if(mImpl){ + mImpl->forward(); + }else{ + printf("forward: No implementation is linked.\n"); + } + } + void backward() override final { + if(mImpl){ + mImpl->backward(); + }else{ + printf("backward: No implementation is linked.\n"); + } + } inline IOIndex_t nbInputs() const noexcept override final { return mNbIn; }; inline IOIndex_t nbDataInputs() const noexcept override final { return mNbDataIn; }; diff --git a/include/aidge/operator/Operator.hpp b/include/aidge/operator/Operator.hpp index 3ac651cfd6f700a129e36fb461f948f50137cfd6..2582f83c45b431aad51e95ee4ba43e0db8abfe5a 100644 --- a/include/aidge/operator/Operator.hpp +++ b/include/aidge/operator/Operator.hpp @@ -26,7 +26,7 @@ namespace Aidge { class Operator : public std::enable_shared_from_this<Operator> { protected: - std::unique_ptr<OperatorImpl> mImpl; // implementation of the operator + std::shared_ptr<OperatorImpl> mImpl; // implementation of the operator std::map<std::string, std::shared_ptr<Hook>> mHooks; private: @@ -76,6 +76,14 @@ public: virtual void setBackend(const std::string& name) = 0; virtual void setDatatype(const DataType& datatype) = 0; + /** + * @brief Set the a new OperatorImpl to the Operator + * + */ + void setImpl(std::shared_ptr<OperatorImpl> impl){ + mImpl = impl; + } + /** * @brief Minimum amount of data from a specific input for one computation pass. * @param inputIdx Index of the input analysed. diff --git a/python_binding/backend/pybind_OperatorImpl.cpp b/python_binding/backend/pybind_OperatorImpl.cpp index 11189f2f3c4a46b31d8e08d73bea17f27df07765..3a104f162ab2b1f867d72fbf8fdc3ea035bbf25f 100644 --- a/python_binding/backend/pybind_OperatorImpl.cpp +++ b/python_binding/backend/pybind_OperatorImpl.cpp @@ -14,7 +14,105 @@ namespace py = pybind11; namespace Aidge { + +/** + * @brief Trampoline class for binding + * + */ +class pyOperatorImpl: public OperatorImpl { + public: + pyOperatorImpl(){} + + void forward() override { + PYBIND11_OVERRIDE( + void, + OperatorImpl, + forward, + + ); + } + void backward() override { + PYBIND11_OVERRIDE( + void, + OperatorImpl, + backward, + + ); + } + NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const override { + PYBIND11_OVERRIDE_PURE_NAME( + NbElts_t, + OperatorImpl, + "get_nb_required_data", + getNbRequiredData, + inputIdx + ); + } + NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override { + PYBIND11_OVERRIDE_PURE_NAME( + NbElts_t, + OperatorImpl, + "get_nb_required_protected", + getNbRequiredProtected, + inputIdx + + ); + } + NbElts_t getRequiredMemory(const IOIndex_t outputIdx, + const std::vector<DimSize_t> &inputsSize) const override { + PYBIND11_OVERRIDE_PURE_NAME( + NbElts_t, + OperatorImpl, + "get_required_memory", + getRequiredMemory, + outputIdx, + inputsSize + + ); + } + NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const override { + PYBIND11_OVERRIDE_PURE_NAME( + NbElts_t, + OperatorImpl, + "get_nb_consumed_data", + getNbConsumedData, + inputIdx + + ); + } + NbElts_t getNbProducedData(const IOIndex_t outputIdx) const override { + PYBIND11_OVERRIDE_PURE_NAME( + NbElts_t, + OperatorImpl, + "get_nb_produced_data", + getNbProducedData, + outputIdx + + ); + } + void updateConsummerProducer() override { + PYBIND11_OVERRIDE_PURE_NAME( + void, + OperatorImpl, + "update_consummer_producer", + updateConsummerProducer, + + ); + } +}; + void init_OperatorImpl(py::module& m){ - py::class_<OperatorImpl, std::shared_ptr<OperatorImpl>>(m, "OperatorImpl"); + + py::class_<OperatorImpl, std::shared_ptr<OperatorImpl>, pyOperatorImpl>(m, "OperatorImpl", py::dynamic_attr()) + .def(py::init<>()) + .def("forward", &OperatorImpl::forward) + .def("backward", &OperatorImpl::backward) + .def("get_nb_required_data", &OperatorImpl::getNbRequiredData) + .def("get_nb_required_protected", &OperatorImpl::getNbRequiredProtected) + .def("get_required_memory", &OperatorImpl::getRequiredMemory) + .def("get_nb_consumed_data", &OperatorImpl::getNbConsumedData) + .def("get_nb_produced_data", &OperatorImpl::getNbProducedData) + .def("update_consummer_producer", &OperatorImpl::updateConsummerProducer) + ; } } diff --git a/python_binding/operator/pybind_Operator.cpp b/python_binding/operator/pybind_Operator.cpp index ac9a34e0a14ace2cf264188302f52a27bf0f7222..b0866970c57798f11fe2efff6777e11af53dd37e 100644 --- a/python_binding/operator/pybind_Operator.cpp +++ b/python_binding/operator/pybind_Operator.cpp @@ -23,6 +23,9 @@ void init_Operator(py::module& m){ .def("associate_input", &Operator::associateInput, py::arg("inputIdx"), py::arg("data")) .def("set_datatype", &Operator::setDatatype, py::arg("datatype")) .def("set_backend", &Operator::setBackend, py::arg("name")) + .def("forward", &Operator::forward) + // py::keep_alive forbide Python to garbage collect implementation will the Operator is not garbade collected ! + .def("set_impl", &Operator::setImpl, py::arg("implementation"), py::keep_alive<1, 2>()) ; } }