diff --git a/include/aidge/operator/Conv.hpp b/include/aidge/operator/Conv.hpp index be5fb3e393ced7ee7a53e27426b4247e48b478e8..6209d000665f8a5cf27958277b0a60e9f911b438 100644 --- a/include/aidge/operator/Conv.hpp +++ b/include/aidge/operator/Conv.hpp @@ -31,7 +31,7 @@ enum class ConvAttr { StrideDims, DilationDims, InChannels, OutChannels, KernelD template <DimIdx_t DIM> class Conv_Op : public OperatorTensor, - public Registrable<Conv_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const Conv_Op<DIM> &)>, + public Registrable<Conv_Op<DIM>, std::string, std::shared_ptr<OperatorImpl>(const Conv_Op<DIM> &)>, public StaticAttributes<ConvAttr, std::array<DimSize_t, DIM>, std::array<DimSize_t, DIM>, DimSize_t, DimSize_t, std::array<DimSize_t, DIM>> { @@ -245,4 +245,4 @@ const char *const EnumStrings<Aidge::ConvAttr>::data[] = { }; } -#endif /* AIDGE_CORE_OPERATOR_CONV_H_ */ \ No newline at end of file +#endif /* AIDGE_CORE_OPERATOR_CONV_H_ */ diff --git a/include/aidge/operator/Operator.hpp b/include/aidge/operator/Operator.hpp index cebc2d54041bb38c6e7f3434f12b559cec3d80af..d6ac2561654afa8d14ed2ac66ef1430ecd00ffee 100644 --- a/include/aidge/operator/Operator.hpp +++ b/include/aidge/operator/Operator.hpp @@ -115,13 +115,21 @@ public: virtual void setDataType(const DataType& dataType) const = 0; /** - * @brief Set the a new OperatorImpl to the Operator + * @brief Set a new OperatorImpl to the Operator * */ void setImpl(std::shared_ptr<OperatorImpl> impl){ mImpl = impl; } + /** + * @brief Get the OperatorImpl of the Operator + * + */ + std::shared_ptr<OperatorImpl> getImpl(){ + return mImpl; + } + /** * @brief Minimum amount of data from a specific input for one computation pass. * @param inputIdx Index of the input analysed. diff --git a/include/aidge/utils/Registrar.hpp b/include/aidge/utils/Registrar.hpp index 66a07eb0ce21354b20f1ca416cc68d26d9bd6280..4479c3202a4af8fc37fb37c0aa2f41f186d99bd2 100644 --- a/include/aidge/utils/Registrar.hpp +++ b/include/aidge/utils/Registrar.hpp @@ -14,6 +14,9 @@ #ifdef PYBIND #include <pybind11/pybind11.h> +#include <pybind11/stl.h> // declare_registrable key can recquire stl +#include <pybind11/functional.h>// declare_registrable allow binding of lambda fn + #endif #include <functional> @@ -57,7 +60,9 @@ struct Registrar { Registrar(const registrar_key& key, registrar_type func) { //printf("REGISTRAR: %s\n", key.c_str()); bool newInsert; - std::tie(std::ignore, newInsert) = C::registry().insert(std::make_pair(key, func)); + // std::tie(std::ignore, newInsert) = C::registry().insert(std::make_pair(key, func)); + C::registry().erase(key); + C::registry().insert(std::make_pair(key, func)); //assert(newInsert && "registrar already exists"); } @@ -79,6 +84,30 @@ struct Registrar { return keys; } }; + +#ifdef PYBIND +/** + * @brief Function to define register function for a registrable class + * Defined here to have access to this function in every module who wants + * to create a new registrable class. + * + * @tparam C registrable class + * @param m pybind module + * @param class_name python name of the class + */ +template <class C> +void declare_registrable(py::module& m, const std::string& class_name){ + typedef typename C::registrar_key registrar_key; + typedef typename C::registrar_type registrar_type; + m.def(("register_"+ class_name).c_str(), [](registrar_key& key, registrar_type function){ + Registrar<C>(key, function); + }) + .def(("get_keys_"+ class_name).c_str(), [](){ + return Registrar<C>::getKeys(); + }); +} +#endif + } #endif //AIDGE_CORE_UTILS_REGISTRAR_H_ diff --git a/python_binding/backend/pybind_OperatorImpl.cpp b/python_binding/backend/pybind_OperatorImpl.cpp index 34610069079ee792ebbe4b261b57177b3bbe2997..5e63d77b7a06948127975d311ecb0c0f5951b4de 100644 --- a/python_binding/backend/pybind_OperatorImpl.cpp +++ b/python_binding/backend/pybind_OperatorImpl.cpp @@ -107,7 +107,7 @@ public: void init_OperatorImpl(py::module& m){ py::class_<OperatorImpl, std::shared_ptr<OperatorImpl>, pyOperatorImpl>(m, "OperatorImpl", py::dynamic_attr()) - .def(py::init<const Operator&>()) + .def(py::init<const Operator&>(), py::keep_alive<1, 1>(), py::keep_alive<1, 2>()) .def("forward", &OperatorImpl::forward) .def("backward", &OperatorImpl::backward) .def("get_nb_required_data", &OperatorImpl::getNbRequiredData) diff --git a/python_binding/operator/pybind_Conv.cpp b/python_binding/operator/pybind_Conv.cpp index 455ea4024438b97b7ac6f07e5fc6722658b42ea4..56c537d602ba19825f940682de85ec885a49766f 100644 --- a/python_binding/operator/pybind_Conv.cpp +++ b/python_binding/operator/pybind_Conv.cpp @@ -19,13 +19,15 @@ #include "aidge/operator/Conv.hpp" #include "aidge/operator/OperatorTensor.hpp" #include "aidge/utils/Types.h" +#include "aidge/utils/Registrar.hpp" // declare_registrable namespace py = pybind11; namespace Aidge { template <DimIdx_t DIM> void declare_ConvOp(py::module &m) { + const std::string pyClassName("ConvOp" + std::to_string(DIM) + "D"); py::class_<Conv_Op<DIM>, std::shared_ptr<Conv_Op<DIM>>, Attributes, OperatorTensor>( - m, ("ConvOp" + std::to_string(DIM) + "D").c_str(), + m, pyClassName.c_str(), py::multiple_inheritance()) .def(py::init<DimSize_t, DimSize_t, @@ -41,6 +43,8 @@ template <DimIdx_t DIM> void declare_ConvOp(py::module &m) { .def("get_outputs_name", &Conv_Op<DIM>::getOutputsName) .def("attributes_name", &Conv_Op<DIM>::staticGetAttrsName) ; + declare_registrable<Conv_Op<DIM>>(m, pyClassName); + m.def(("Conv" + std::to_string(DIM) + "D").c_str(), [](DimSize_t in_channels, DimSize_t out_channels, @@ -66,9 +70,5 @@ void init_Conv(py::module &m) { declare_ConvOp<1>(m); declare_ConvOp<2>(m); declare_ConvOp<3>(m); - - // FIXME: - // m.def("Conv1D", static_cast<NodeAPI(*)(const char*, int, int, int const - // (&)[1])>(&Conv)); } } // namespace Aidge diff --git a/python_binding/operator/pybind_Operator.cpp b/python_binding/operator/pybind_Operator.cpp index 79a85cb92cf27c7edb745c36eefe61ae86c66786..05d6cd089754d1155e1506b4a491af7919bc4d31 100644 --- a/python_binding/operator/pybind_Operator.cpp +++ b/python_binding/operator/pybind_Operator.cpp @@ -1,3 +1,4 @@ + /******************************************************************************** * Copyright (c) 2023 CEA-List * @@ -32,10 +33,11 @@ void init_Operator(py::module& m){ .def("set_datatype", &Operator::setDataType, py::arg("dataType")) .def("set_backend", &Operator::setBackend, py::arg("name"), py::arg("device") = 0) .def("forward", &Operator::forward) - // py::keep_alive forbide Python to garbage collect implementation will the Operator is not garbade collected ! + // py::keep_alive forbide Python to garbage collect the implementation lambda as long as the Operator is not deleted ! .def("set_impl", &Operator::setImpl, py::arg("implementation"), py::keep_alive<1, 2>()) + .def("get_impl", &Operator::getImpl) .def("get_hook", &Operator::getHook) .def("add_hook", &Operator::addHook) ; } -} \ No newline at end of file +}