Skip to content
Snippets Groups Projects
Commit aef244d0 authored by Cyril Moineau's avatar Cyril Moineau
Browse files

Initial working python registrar.

parent f0917214
No related branches found
No related tags found
2 merge requests!105version 0.2.0,!85Initial working python registrar.
Pipeline #39650 failed
......@@ -31,7 +31,7 @@ enum class ConvAttr { StrideDims, DilationDims, InChannels, OutChannels, KernelD
template <DimIdx_t DIM>
class Conv_Op : public OperatorTensor,
public Registrable<Conv_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const Conv_Op<DIM> &)>,
public Registrable<Conv_Op<DIM>, std::string, std::shared_ptr<OperatorImpl>(const Conv_Op<DIM> &)>,
public StaticAttributes<ConvAttr, std::array<DimSize_t, DIM>, std::array<DimSize_t, DIM>, DimSize_t,
DimSize_t, std::array<DimSize_t, DIM>> {
......@@ -245,4 +245,4 @@ const char *const EnumStrings<Aidge::ConvAttr>::data[] = {
};
}
#endif /* AIDGE_CORE_OPERATOR_CONV_H_ */
\ No newline at end of file
#endif /* AIDGE_CORE_OPERATOR_CONV_H_ */
......@@ -115,13 +115,21 @@ public:
virtual void setDataType(const DataType& dataType) const = 0;
/**
* @brief Set the a new OperatorImpl to the Operator
* @brief Set a new OperatorImpl to the Operator
*
*/
void setImpl(std::shared_ptr<OperatorImpl> impl){
mImpl = impl;
}
/**
* @brief Get the OperatorImpl of the Operator
*
*/
std::shared_ptr<OperatorImpl> getImpl(){
return mImpl;
}
/**
* @brief Minimum amount of data from a specific input for one computation pass.
* @param inputIdx Index of the input analysed.
......
......@@ -14,6 +14,9 @@
#ifdef PYBIND
#include <pybind11/pybind11.h>
#include <pybind11/stl.h> // declare_registrable key can recquire stl
#include <pybind11/functional.h>// declare_registrable allow binding of lambda fn
#endif
#include <functional>
......@@ -57,7 +60,9 @@ struct Registrar {
Registrar(const registrar_key& key, registrar_type func) {
//printf("REGISTRAR: %s\n", key.c_str());
bool newInsert;
std::tie(std::ignore, newInsert) = C::registry().insert(std::make_pair(key, func));
// std::tie(std::ignore, newInsert) = C::registry().insert(std::make_pair(key, func));
C::registry().erase(key);
C::registry().insert(std::make_pair(key, func));
//assert(newInsert && "registrar already exists");
}
......@@ -79,6 +84,30 @@ struct Registrar {
return keys;
}
};
#ifdef PYBIND
/**
* @brief Function to define register function for a registrable class
* Defined here to have access to this function in every module who wants
* to create a new registrable class.
*
* @tparam C registrable class
* @param m pybind module
* @param class_name python name of the class
*/
template <class C>
void declare_registrable(py::module& m, const std::string& class_name){
typedef typename C::registrar_key registrar_key;
typedef typename C::registrar_type registrar_type;
m.def(("register_"+ class_name).c_str(), [](registrar_key& key, registrar_type function){
Registrar<C>(key, function);
})
.def(("get_keys_"+ class_name).c_str(), [](){
return Registrar<C>::getKeys();
});
}
#endif
}
#endif //AIDGE_CORE_UTILS_REGISTRAR_H_
......@@ -107,7 +107,7 @@ public:
void init_OperatorImpl(py::module& m){
py::class_<OperatorImpl, std::shared_ptr<OperatorImpl>, pyOperatorImpl>(m, "OperatorImpl", py::dynamic_attr())
.def(py::init<const Operator&>())
.def(py::init<const Operator&>(), py::keep_alive<1, 1>(), py::keep_alive<1, 2>())
.def("forward", &OperatorImpl::forward)
.def("backward", &OperatorImpl::backward)
.def("get_nb_required_data", &OperatorImpl::getNbRequiredData)
......
......@@ -19,13 +19,15 @@
#include "aidge/operator/Conv.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/utils/Types.h"
#include "aidge/utils/Registrar.hpp" // declare_registrable
namespace py = pybind11;
namespace Aidge {
template <DimIdx_t DIM> void declare_ConvOp(py::module &m) {
const std::string pyClassName("ConvOp" + std::to_string(DIM) + "D");
py::class_<Conv_Op<DIM>, std::shared_ptr<Conv_Op<DIM>>, Attributes, OperatorTensor>(
m, ("ConvOp" + std::to_string(DIM) + "D").c_str(),
m, pyClassName.c_str(),
py::multiple_inheritance())
.def(py::init<DimSize_t,
DimSize_t,
......@@ -41,6 +43,8 @@ template <DimIdx_t DIM> void declare_ConvOp(py::module &m) {
.def("get_outputs_name", &Conv_Op<DIM>::getOutputsName)
.def("attributes_name", &Conv_Op<DIM>::staticGetAttrsName)
;
declare_registrable<Conv_Op<DIM>>(m, pyClassName);
m.def(("Conv" + std::to_string(DIM) + "D").c_str(), [](DimSize_t in_channels,
DimSize_t out_channels,
......@@ -66,9 +70,5 @@ void init_Conv(py::module &m) {
declare_ConvOp<1>(m);
declare_ConvOp<2>(m);
declare_ConvOp<3>(m);
// FIXME:
// m.def("Conv1D", static_cast<NodeAPI(*)(const char*, int, int, int const
// (&)[1])>(&Conv));
}
} // namespace Aidge
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
......@@ -32,10 +33,11 @@ void init_Operator(py::module& m){
.def("set_datatype", &Operator::setDataType, py::arg("dataType"))
.def("set_backend", &Operator::setBackend, py::arg("name"), py::arg("device") = 0)
.def("forward", &Operator::forward)
// py::keep_alive forbide Python to garbage collect implementation will the Operator is not garbade collected !
// py::keep_alive forbide Python to garbage collect the implementation lambda as long as the Operator is not deleted !
.def("set_impl", &Operator::setImpl, py::arg("implementation"), py::keep_alive<1, 2>())
.def("get_impl", &Operator::getImpl)
.def("get_hook", &Operator::getHook)
.def("add_hook", &Operator::addHook)
;
}
}
\ No newline at end of file
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment