diff --git a/include/aidge/operator/Conv.hpp b/include/aidge/operator/Conv.hpp index 8290fb3d0d978e9af3291809c5057406424096d5..97e27f8d387943d70b410d83b82d8fce2d664142 100644 --- a/include/aidge/operator/Conv.hpp +++ b/include/aidge/operator/Conv.hpp @@ -31,7 +31,7 @@ enum class ConvAttr { StrideDims, DilationDims, InChannels, OutChannels, KernelD template <DimIdx_t DIM> class Conv_Op : public OperatorTensor, - public Registrable<Conv_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const Conv_Op<DIM> &)>, + public Registrable<Conv_Op<DIM>, std::string, std::shared_ptr<OperatorImpl>(const Conv_Op<DIM> &)>, public StaticAttributes<ConvAttr, std::array<DimSize_t, DIM>, std::array<DimSize_t, DIM>, DimSize_t, DimSize_t, std::array<DimSize_t, DIM>> { @@ -245,4 +245,4 @@ const char *const EnumStrings<Aidge::ConvAttr>::data[] = { }; } -#endif /* AIDGE_CORE_OPERATOR_CONV_H_ */ \ No newline at end of file +#endif /* AIDGE_CORE_OPERATOR_CONV_H_ */ diff --git a/include/aidge/operator/Operator.hpp b/include/aidge/operator/Operator.hpp index a0d2292b7860baa60fe537698784d4d250c81f42..0544683487e4069dce9d4c5ff6df0433f27ed827 100644 --- a/include/aidge/operator/Operator.hpp +++ b/include/aidge/operator/Operator.hpp @@ -115,15 +115,21 @@ public: virtual void setDataType(const DataType& dataType) const = 0; /** - * @brief Set the a new OperatorImpl to the Operator + * @brief Set a new OperatorImpl to the Operator * */ inline void setImpl(std::shared_ptr<OperatorImpl> impl) { mImpl = impl; } /** - * @brief Minimum amount of data from a specific input required by the - * implementation to be run. + * @brief Get the OperatorImpl of the Operator * + */ + std::shared_ptr<OperatorImpl> getImpl(){ + return mImpl; + } + + /** + * @brief Minimum amount of data from a specific input for one computation pass. * @param inputIdx Index of the input analysed. * @return NbElts_t */ diff --git a/include/aidge/utils/Registrar.hpp b/include/aidge/utils/Registrar.hpp index 4d604d520d3d8af532e196c7785896ddc1c242d0..a7214cdc61081978c74f6e172c2389ac82e51aa6 100644 --- a/include/aidge/utils/Registrar.hpp +++ b/include/aidge/utils/Registrar.hpp @@ -14,6 +14,9 @@ #ifdef PYBIND #include <pybind11/pybind11.h> +#include <pybind11/stl.h> // declare_registrable key can recquire stl +#include <pybind11/functional.h>// declare_registrable allow binding of lambda fn + #endif #include "aidge/utils/ErrorHandling.hpp" @@ -59,7 +62,9 @@ struct Registrar { Registrar(const registrar_key& key, registrar_type func) { //fmt::print("REGISTRAR: {}\n", key); bool newInsert; - std::tie(std::ignore, newInsert) = C::registry().insert(std::make_pair(key, func)); + // std::tie(std::ignore, newInsert) = C::registry().insert(std::make_pair(key, func)); + C::registry().erase(key); + C::registry().insert(std::make_pair(key, func)); //assert(newInsert && "registrar already exists"); } @@ -81,6 +86,30 @@ struct Registrar { return keys; } }; + +#ifdef PYBIND +/** + * @brief Function to define register function for a registrable class + * Defined here to have access to this function in every module who wants + * to create a new registrable class. + * + * @tparam C registrable class + * @param m pybind module + * @param class_name python name of the class + */ +template <class C> +void declare_registrable(py::module& m, const std::string& class_name){ + typedef typename C::registrar_key registrar_key; + typedef typename C::registrar_type registrar_type; + m.def(("register_"+ class_name).c_str(), [](registrar_key& key, registrar_type function){ + Registrar<C>(key, function); + }) + .def(("get_keys_"+ class_name).c_str(), [](){ + return Registrar<C>::getKeys(); + }); +} +#endif + } #endif //AIDGE_CORE_UTILS_REGISTRAR_H_ diff --git a/python_binding/backend/pybind_OperatorImpl.cpp b/python_binding/backend/pybind_OperatorImpl.cpp index a2a5e6b8bb2d0f2413ef94c360b383608c5b41b5..91d65484a122d6a651758e16eb0e925b6e0bfdd0 100644 --- a/python_binding/backend/pybind_OperatorImpl.cpp +++ b/python_binding/backend/pybind_OperatorImpl.cpp @@ -116,7 +116,7 @@ public: void init_OperatorImpl(py::module& m){ py::class_<OperatorImpl, std::shared_ptr<OperatorImpl>, pyOperatorImpl>(m, "OperatorImpl", py::dynamic_attr()) - .def(py::init<const Operator&>()) + .def(py::init<const Operator&>(), py::keep_alive<1, 1>(), py::keep_alive<1, 2>()) .def("forward", &OperatorImpl::forward) .def("backward", &OperatorImpl::backward) .def("get_nb_required_data", &OperatorImpl::getNbRequiredData) diff --git a/python_binding/operator/pybind_Conv.cpp b/python_binding/operator/pybind_Conv.cpp index 346acc5d9d05c24e9538c3b8c5edf1f7e37d6ba8..aea402017622655a577ac4f9e207141bff01d70d 100644 --- a/python_binding/operator/pybind_Conv.cpp +++ b/python_binding/operator/pybind_Conv.cpp @@ -19,13 +19,15 @@ #include "aidge/operator/Conv.hpp" #include "aidge/operator/OperatorTensor.hpp" #include "aidge/utils/Types.h" +#include "aidge/utils/Registrar.hpp" // declare_registrable namespace py = pybind11; namespace Aidge { template <DimIdx_t DIM> void declare_ConvOp(py::module &m) { + const std::string pyClassName("ConvOp" + std::to_string(DIM) + "D"); py::class_<Conv_Op<DIM>, std::shared_ptr<Conv_Op<DIM>>, Attributes, OperatorTensor>( - m, ("ConvOp" + std::to_string(DIM) + "D").c_str(), + m, pyClassName.c_str(), py::multiple_inheritance()) .def(py::init<DimSize_t, DimSize_t, @@ -41,6 +43,8 @@ template <DimIdx_t DIM> void declare_ConvOp(py::module &m) { .def("get_outputs_name", &Conv_Op<DIM>::getOutputsName) .def("attributes_name", &Conv_Op<DIM>::staticGetAttrsName) ; + declare_registrable<Conv_Op<DIM>>(m, pyClassName); + m.def(("Conv" + std::to_string(DIM) + "D").c_str(), [](DimSize_t in_channels, DimSize_t out_channels, @@ -66,9 +70,5 @@ void init_Conv(py::module &m) { declare_ConvOp<1>(m); declare_ConvOp<2>(m); declare_ConvOp<3>(m); - - // FIXME: - // m.def("Conv1D", static_cast<NodeAPI(*)(const char*, int, int, int const - // (&)[1])>(&Conv)); } } // namespace Aidge diff --git a/python_binding/operator/pybind_Operator.cpp b/python_binding/operator/pybind_Operator.cpp index 79a85cb92cf27c7edb745c36eefe61ae86c66786..05d6cd089754d1155e1506b4a491af7919bc4d31 100644 --- a/python_binding/operator/pybind_Operator.cpp +++ b/python_binding/operator/pybind_Operator.cpp @@ -1,3 +1,4 @@ + /******************************************************************************** * Copyright (c) 2023 CEA-List * @@ -32,10 +33,11 @@ void init_Operator(py::module& m){ .def("set_datatype", &Operator::setDataType, py::arg("dataType")) .def("set_backend", &Operator::setBackend, py::arg("name"), py::arg("device") = 0) .def("forward", &Operator::forward) - // py::keep_alive forbide Python to garbage collect implementation will the Operator is not garbade collected ! + // py::keep_alive forbide Python to garbage collect the implementation lambda as long as the Operator is not deleted ! .def("set_impl", &Operator::setImpl, py::arg("implementation"), py::keep_alive<1, 2>()) + .def("get_impl", &Operator::getImpl) .def("get_hook", &Operator::getHook) .def("add_hook", &Operator::addHook) ; } -} \ No newline at end of file +}