Skip to content
Snippets Groups Projects
Commit f9ca57a0 authored by Cyril Moineau's avatar Cyril Moineau
Browse files

Merge branch 'PythonOpImpl' into 'main'

[OperatorImpl] Python OperatorImpl

Closes #24

See merge request !19
parents 4feec190 f08fe6c5
No related branches found
No related tags found
1 merge request!19[OperatorImpl] Python OperatorImpl
Pipeline #33099 passed
......@@ -6,7 +6,7 @@ file(READ "${CMAKE_SOURCE_DIR}/project_name.txt" project)
message(STATUS "Project name: ${project}")
message(STATUS "Project version: ${version}")
# Note : project name is {project} and python module name is also {project}
# Note : project name is {project} and python module name is also {project}
set(module_name _${project}) # target name
......@@ -57,7 +57,7 @@ if (PYBIND)
# Handles Python + pybind11 headers dependencies
target_link_libraries(${module_name}
PUBLIC
PUBLIC
pybind11::pybind11
PRIVATE
Python::Python
......@@ -101,8 +101,8 @@ install(DIRECTORY include/ DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
install(EXPORT ${project}-targets
FILE "${project}-targets.cmake"
DESTINATION ${INSTALL_CONFIGDIR}
# COMPONENT ${module_name}
)
# COMPONENT ${module_name}
)
#Create a ConfigVersion.cmake file
include(CMakePackageConfigHelpers)
......@@ -136,4 +136,4 @@ export(EXPORT ${project}-targets
if(TEST)
enable_testing()
add_subdirectory(unit_tests)
endif()
\ No newline at end of file
endif()
......@@ -102,5 +102,30 @@ class test_operator_binding(unittest.TestCase):
genOp.get_operator().compute_output_dims()
self.assertListEqual(genOp.get_operator().output(0).dims(), in_dims)
def test_set_impl(self):
class PythonCustomImpl(aidge_core.OperatorImpl):
"""Dummy implementation to test that C++ call python code
"""
def __init__(self):
aidge_core.OperatorImpl.__init__(self) # Recquired to avoid type error !
self.idx = 0
def forward(self):
"""Increment idx attribute on forward.
"""
self.idx += 1
generic_node = aidge_core.GenericOperator("Relu", 1, 1, 1, name="myReLu")
customImpl = PythonCustomImpl()
generic_op = generic_node.get_operator()
generic_op.forward() # Do nothing, no implementation set
generic_op.set_impl(customImpl)
generic_op.forward() # Increment idx
self.assertEqual(customImpl.idx, 1)
if __name__ == '__main__':
unittest.main()
......@@ -168,9 +168,20 @@ class GenericOperator_Op
void setBackend(const std::string & /*name*/) override { printf("setBackend: not available yet.\n"); }
void setDatatype(const DataType & /*datatype*/) override { printf("setDatatype: not available yet.\n"); }
void forward() override final { printf("forward: not available yet.\n"); }
void backward() override final { printf("backward: not available yet.\n"); }
void forward() override final {
if(mImpl){
mImpl->forward();
}else{
printf("forward: No implementation is linked.\n");
}
}
void backward() override final {
if(mImpl){
mImpl->backward();
}else{
printf("backward: No implementation is linked.\n");
}
}
inline IOIndex_t nbInputs() const noexcept override final { return mNbIn; };
inline IOIndex_t nbDataInputs() const noexcept override final { return mNbDataIn; };
inline IOIndex_t nbOutputs() const noexcept override final { return mNbOut; };
......
......@@ -26,7 +26,7 @@ namespace Aidge {
class Operator : public std::enable_shared_from_this<Operator> {
protected:
std::unique_ptr<OperatorImpl> mImpl; // implementation of the operator
std::shared_ptr<OperatorImpl> mImpl; // implementation of the operator
std::map<std::string, std::shared_ptr<Hook>> mHooks;
private:
......@@ -76,6 +76,14 @@ public:
virtual void setBackend(const std::string& name) = 0;
virtual void setDatatype(const DataType& datatype) = 0;
/**
* @brief Set the a new OperatorImpl to the Operator
*
*/
void setImpl(std::shared_ptr<OperatorImpl> impl){
mImpl = impl;
}
/**
* @brief Minimum amount of data from a specific input for one computation pass.
* @param inputIdx Index of the input analysed.
......
......@@ -35,7 +35,7 @@ public:
{
#ifdef PYBIND
#define _CRT_SECURE_NO_WARNINGS
if (std::getenv("AIDGE_CORE_WITH_PYBIND")){
if (Py_IsInitialized()){
std::string name = std::string("registrar_")+typeid(Registrable<DerivedClass, Key, Func>).name();
static auto shared_data = reinterpret_cast<std::map<Key, std::function<Func>> *>(py::get_shared_data(name));
if (!shared_data)
......@@ -78,4 +78,4 @@ struct Registrar {
};
}
#endif //AIDGE_CORE_UTILS_REGISTRAR_H_
\ No newline at end of file
#endif //AIDGE_CORE_UTILS_REGISTRAR_H_
......@@ -10,11 +10,111 @@
********************************************************************************/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "aidge/backend/OperatorImpl.hpp"
namespace py = pybind11;
namespace Aidge {
/**
* @brief Trampoline class for binding
*
*/
class pyOperatorImpl: public OperatorImpl {
public:
pyOperatorImpl(){}
void forward() override {
PYBIND11_OVERRIDE(
void,
OperatorImpl,
forward,
);
}
void backward() override {
PYBIND11_OVERRIDE(
void,
OperatorImpl,
backward,
);
}
NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const override {
PYBIND11_OVERRIDE_PURE_NAME(
NbElts_t,
OperatorImpl,
"get_nb_required_data",
getNbRequiredData,
inputIdx
);
}
NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override {
PYBIND11_OVERRIDE_PURE_NAME(
NbElts_t,
OperatorImpl,
"get_nb_required_protected",
getNbRequiredProtected,
inputIdx
);
}
NbElts_t getRequiredMemory(const IOIndex_t outputIdx,
const std::vector<DimSize_t> &inputsSize) const override {
PYBIND11_OVERRIDE_PURE_NAME(
NbElts_t,
OperatorImpl,
"get_required_memory",
getRequiredMemory,
outputIdx,
inputsSize
);
}
NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const override {
PYBIND11_OVERRIDE_PURE_NAME(
NbElts_t,
OperatorImpl,
"get_nb_consumed_data",
getNbConsumedData,
inputIdx
);
}
NbElts_t getNbProducedData(const IOIndex_t outputIdx) const override {
PYBIND11_OVERRIDE_PURE_NAME(
NbElts_t,
OperatorImpl,
"get_nb_produced_data",
getNbProducedData,
outputIdx
);
}
void updateConsummerProducer() override {
PYBIND11_OVERRIDE_PURE_NAME(
void,
OperatorImpl,
"update_consummer_producer",
updateConsummerProducer,
);
}
};
void init_OperatorImpl(py::module& m){
py::class_<OperatorImpl, std::shared_ptr<OperatorImpl>>(m, "OperatorImpl");
py::class_<OperatorImpl, std::shared_ptr<OperatorImpl>, pyOperatorImpl>(m, "OperatorImpl", py::dynamic_attr())
.def(py::init<>())
.def("forward", &OperatorImpl::forward)
.def("backward", &OperatorImpl::backward)
.def("get_nb_required_data", &OperatorImpl::getNbRequiredData)
.def("get_nb_required_protected", &OperatorImpl::getNbRequiredProtected)
.def("get_required_memory", &OperatorImpl::getRequiredMemory)
.def("get_nb_consumed_data", &OperatorImpl::getNbConsumedData)
.def("get_nb_produced_data", &OperatorImpl::getNbProducedData)
.def("update_consummer_producer", &OperatorImpl::updateConsummerProducer)
;
}
}
......@@ -11,7 +11,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <iostream>
#include <string>
#include <vector>
#include <array>
......
......@@ -24,6 +24,9 @@ void init_Operator(py::module& m){
.def("associate_input", &Operator::associateInput, py::arg("inputIdx"), py::arg("data"))
.def("set_datatype", &Operator::setDatatype, py::arg("datatype"))
.def("set_backend", &Operator::setBackend, py::arg("name"))
.def("forward", &Operator::forward)
// py::keep_alive forbide Python to garbage collect implementation will the Operator is not garbade collected !
.def("set_impl", &Operator::setImpl, py::arg("implementation"), py::keep_alive<1, 2>())
;
}
}
......@@ -49,14 +49,8 @@ void init_Recipies(py::module&);
void init_Scheduler(py::module&);
void init_TensorUtils(py::module&);
void set_python_flag(){
// Set an env variable to know if we run with ypthon or cpp
py::module os_module = py::module::import("os");
os_module.attr("environ")["AIDGE_CORE_WITH_PYBIND"] = "1";
}
void init_Aidge(py::module& m){
set_python_flag();
init_Data(m);
init_Tensor(m);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment