Skip to content
Snippets Groups Projects
Commit 66f3a866 authored by Cyril Moineau's avatar Cyril Moineau Committed by Maxence Naud
Browse files

Initial working python registrar.

parent 4c08fa6c
No related branches found
No related tags found
2 merge requests!105version 0.2.0,!85Initial working python registrar.
......@@ -31,7 +31,7 @@ enum class ConvAttr { StrideDims, DilationDims, InChannels, OutChannels, KernelD
template <DimIdx_t DIM>
class Conv_Op : public OperatorTensor,
public Registrable<Conv_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const Conv_Op<DIM> &)>,
public Registrable<Conv_Op<DIM>, std::string, std::shared_ptr<OperatorImpl>(const Conv_Op<DIM> &)>,
public StaticAttributes<ConvAttr, std::array<DimSize_t, DIM>, std::array<DimSize_t, DIM>, DimSize_t,
DimSize_t, std::array<DimSize_t, DIM>> {
......@@ -245,4 +245,4 @@ const char *const EnumStrings<Aidge::ConvAttr>::data[] = {
};
}
#endif /* AIDGE_CORE_OPERATOR_CONV_H_ */
\ No newline at end of file
#endif /* AIDGE_CORE_OPERATOR_CONV_H_ */
......@@ -115,15 +115,21 @@ public:
virtual void setDataType(const DataType& dataType) const = 0;
/**
* @brief Set the a new OperatorImpl to the Operator
* @brief Set a new OperatorImpl to the Operator
*
*/
inline void setImpl(std::shared_ptr<OperatorImpl> impl) { mImpl = impl; }
/**
* @brief Minimum amount of data from a specific input required by the
* implementation to be run.
* @brief Get the OperatorImpl of the Operator
*
*/
std::shared_ptr<OperatorImpl> getImpl(){
return mImpl;
}
/**
* @brief Minimum amount of data from a specific input for one computation pass.
* @param inputIdx Index of the input analysed.
* @return NbElts_t
*/
......
......@@ -14,6 +14,9 @@
#ifdef PYBIND
#include <pybind11/pybind11.h>
#include <pybind11/stl.h> // declare_registrable key can recquire stl
#include <pybind11/functional.h>// declare_registrable allow binding of lambda fn
#endif
#include "aidge/utils/ErrorHandling.hpp"
......@@ -59,7 +62,9 @@ struct Registrar {
Registrar(const registrar_key& key, registrar_type func) {
//fmt::print("REGISTRAR: {}\n", key);
bool newInsert;
std::tie(std::ignore, newInsert) = C::registry().insert(std::make_pair(key, func));
// std::tie(std::ignore, newInsert) = C::registry().insert(std::make_pair(key, func));
C::registry().erase(key);
C::registry().insert(std::make_pair(key, func));
//assert(newInsert && "registrar already exists");
}
......@@ -81,6 +86,30 @@ struct Registrar {
return keys;
}
};
#ifdef PYBIND
/**
* @brief Function to define register function for a registrable class
* Defined here to have access to this function in every module who wants
* to create a new registrable class.
*
* @tparam C registrable class
* @param m pybind module
* @param class_name python name of the class
*/
template <class C>
void declare_registrable(py::module& m, const std::string& class_name){
typedef typename C::registrar_key registrar_key;
typedef typename C::registrar_type registrar_type;
m.def(("register_"+ class_name).c_str(), [](registrar_key& key, registrar_type function){
Registrar<C>(key, function);
})
.def(("get_keys_"+ class_name).c_str(), [](){
return Registrar<C>::getKeys();
});
}
#endif
}
#endif //AIDGE_CORE_UTILS_REGISTRAR_H_
......@@ -116,7 +116,7 @@ public:
void init_OperatorImpl(py::module& m){
py::class_<OperatorImpl, std::shared_ptr<OperatorImpl>, pyOperatorImpl>(m, "OperatorImpl", py::dynamic_attr())
.def(py::init<const Operator&>())
.def(py::init<const Operator&>(), py::keep_alive<1, 1>(), py::keep_alive<1, 2>())
.def("forward", &OperatorImpl::forward)
.def("backward", &OperatorImpl::backward)
.def("get_nb_required_data", &OperatorImpl::getNbRequiredData)
......
......@@ -19,13 +19,15 @@
#include "aidge/operator/Conv.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/utils/Types.h"
#include "aidge/utils/Registrar.hpp" // declare_registrable
namespace py = pybind11;
namespace Aidge {
template <DimIdx_t DIM> void declare_ConvOp(py::module &m) {
const std::string pyClassName("ConvOp" + std::to_string(DIM) + "D");
py::class_<Conv_Op<DIM>, std::shared_ptr<Conv_Op<DIM>>, Attributes, OperatorTensor>(
m, ("ConvOp" + std::to_string(DIM) + "D").c_str(),
m, pyClassName.c_str(),
py::multiple_inheritance())
.def(py::init<DimSize_t,
DimSize_t,
......@@ -41,6 +43,8 @@ template <DimIdx_t DIM> void declare_ConvOp(py::module &m) {
.def("get_outputs_name", &Conv_Op<DIM>::getOutputsName)
.def("attributes_name", &Conv_Op<DIM>::staticGetAttrsName)
;
declare_registrable<Conv_Op<DIM>>(m, pyClassName);
m.def(("Conv" + std::to_string(DIM) + "D").c_str(), [](DimSize_t in_channels,
DimSize_t out_channels,
......@@ -66,9 +70,5 @@ void init_Conv(py::module &m) {
declare_ConvOp<1>(m);
declare_ConvOp<2>(m);
declare_ConvOp<3>(m);
// FIXME:
// m.def("Conv1D", static_cast<NodeAPI(*)(const char*, int, int, int const
// (&)[1])>(&Conv));
}
} // namespace Aidge
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
......@@ -32,10 +33,11 @@ void init_Operator(py::module& m){
.def("set_datatype", &Operator::setDataType, py::arg("dataType"))
.def("set_backend", &Operator::setBackend, py::arg("name"), py::arg("device") = 0)
.def("forward", &Operator::forward)
// py::keep_alive forbide Python to garbage collect implementation will the Operator is not garbade collected !
// py::keep_alive forbide Python to garbage collect the implementation lambda as long as the Operator is not deleted !
.def("set_impl", &Operator::setImpl, py::arg("implementation"), py::keep_alive<1, 2>())
.def("get_impl", &Operator::getImpl)
.def("get_hook", &Operator::getHook)
.def("add_hook", &Operator::addHook)
;
}
}
\ No newline at end of file
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment