diff --git a/CMakeLists.txt b/CMakeLists.txt index 40d8837f41bdc0d8dfd7eac1c5960064967f1efb..f8dbe375e217020a4c4570bd67c1b466e6593130 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -6,7 +6,7 @@ file(READ "${CMAKE_SOURCE_DIR}/project_name.txt" project) message(STATUS "Project name: ${project}") message(STATUS "Project version: ${version}") -# Note : project name is {project} and python module name is also {project} +# Note : project name is {project} and python module name is also {project} set(module_name _${project}) # target name @@ -57,7 +57,7 @@ if (PYBIND) # Handles Python + pybind11 headers dependencies target_link_libraries(${module_name} - PUBLIC + PUBLIC pybind11::pybind11 PRIVATE Python::Python @@ -101,8 +101,8 @@ install(DIRECTORY include/ DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}) install(EXPORT ${project}-targets FILE "${project}-targets.cmake" DESTINATION ${INSTALL_CONFIGDIR} -# COMPONENT ${module_name} -) +# COMPONENT ${module_name} +) #Create a ConfigVersion.cmake file include(CMakePackageConfigHelpers) @@ -136,4 +136,4 @@ export(EXPORT ${project}-targets if(TEST) enable_testing() add_subdirectory(unit_tests) -endif() \ No newline at end of file +endif() diff --git a/aidge_core/unit_tests/test_operator_binding.py b/aidge_core/unit_tests/test_operator_binding.py index fc60f52274162155f8f891bf86c22c9a13b241f4..c7279afed2aed00981d0b15002b1676abcaef72e 100644 --- a/aidge_core/unit_tests/test_operator_binding.py +++ b/aidge_core/unit_tests/test_operator_binding.py @@ -102,5 +102,30 @@ class test_operator_binding(unittest.TestCase): genOp.get_operator().compute_output_dims() self.assertListEqual(genOp.get_operator().output(0).dims(), in_dims) + def test_set_impl(self): + + class PythonCustomImpl(aidge_core.OperatorImpl): + """Dummy implementation to test that C++ call python code + """ + def __init__(self): + aidge_core.OperatorImpl.__init__(self) # Recquired to avoid type error ! + self.idx = 0 + + def forward(self): + """Increment idx attribute on forward. + """ + self.idx += 1 + + generic_node = aidge_core.GenericOperator("Relu", 1, 1, 1, name="myReLu") + customImpl = PythonCustomImpl() + generic_op = generic_node.get_operator() + + generic_op.forward() # Do nothing, no implementation set + generic_op.set_impl(customImpl) + generic_op.forward() # Increment idx + self.assertEqual(customImpl.idx, 1) + + + if __name__ == '__main__': unittest.main() diff --git a/include/aidge/operator/GenericOperator.hpp b/include/aidge/operator/GenericOperator.hpp index 83b9a932633deb822ad86c24b96e6e928b5e2be2..55ccbf1516fa79663d57e1e44bc4017bc5c8b843 100644 --- a/include/aidge/operator/GenericOperator.hpp +++ b/include/aidge/operator/GenericOperator.hpp @@ -168,9 +168,20 @@ class GenericOperator_Op void setBackend(const std::string & /*name*/) override { printf("setBackend: not available yet.\n"); } void setDatatype(const DataType & /*datatype*/) override { printf("setDatatype: not available yet.\n"); } - void forward() override final { printf("forward: not available yet.\n"); } - void backward() override final { printf("backward: not available yet.\n"); } - + void forward() override final { + if(mImpl){ + mImpl->forward(); + }else{ + printf("forward: No implementation is linked.\n"); + } + } + void backward() override final { + if(mImpl){ + mImpl->backward(); + }else{ + printf("backward: No implementation is linked.\n"); + } + } inline IOIndex_t nbInputs() const noexcept override final { return mNbIn; }; inline IOIndex_t nbDataInputs() const noexcept override final { return mNbDataIn; }; inline IOIndex_t nbOutputs() const noexcept override final { return mNbOut; }; diff --git a/include/aidge/operator/Operator.hpp b/include/aidge/operator/Operator.hpp index a99e4e8ed37aeaa647da1dcaaa994b070901129b..903b6362adf3db0c867dc419086e0cb6ddaa65c7 100644 --- a/include/aidge/operator/Operator.hpp +++ b/include/aidge/operator/Operator.hpp @@ -26,7 +26,7 @@ namespace Aidge { class Operator : public std::enable_shared_from_this<Operator> { protected: - std::unique_ptr<OperatorImpl> mImpl; // implementation of the operator + std::shared_ptr<OperatorImpl> mImpl; // implementation of the operator std::map<std::string, std::shared_ptr<Hook>> mHooks; private: @@ -76,6 +76,14 @@ public: virtual void setBackend(const std::string& name) = 0; virtual void setDatatype(const DataType& datatype) = 0; + /** + * @brief Set the a new OperatorImpl to the Operator + * + */ + void setImpl(std::shared_ptr<OperatorImpl> impl){ + mImpl = impl; + } + /** * @brief Minimum amount of data from a specific input for one computation pass. * @param inputIdx Index of the input analysed. diff --git a/include/aidge/utils/Registrar.hpp b/include/aidge/utils/Registrar.hpp index 3b29c472b3a540c9ef3b8ed46520e3e718e8cbfb..ece74509d466800c870d73d1e0bbe1d639f8bf54 100644 --- a/include/aidge/utils/Registrar.hpp +++ b/include/aidge/utils/Registrar.hpp @@ -35,7 +35,7 @@ public: { #ifdef PYBIND #define _CRT_SECURE_NO_WARNINGS - if (std::getenv("AIDGE_CORE_WITH_PYBIND")){ + if (Py_IsInitialized()){ std::string name = std::string("registrar_")+typeid(Registrable<DerivedClass, Key, Func>).name(); static auto shared_data = reinterpret_cast<std::map<Key, std::function<Func>> *>(py::get_shared_data(name)); if (!shared_data) @@ -78,4 +78,4 @@ struct Registrar { }; } -#endif //AIDGE_CORE_UTILS_REGISTRAR_H_ \ No newline at end of file +#endif //AIDGE_CORE_UTILS_REGISTRAR_H_ diff --git a/python_binding/backend/pybind_OperatorImpl.cpp b/python_binding/backend/pybind_OperatorImpl.cpp index 11189f2f3c4a46b31d8e08d73bea17f27df07765..8eb9e2649b19374e4346be18f9a3ab8070e4dc3c 100644 --- a/python_binding/backend/pybind_OperatorImpl.cpp +++ b/python_binding/backend/pybind_OperatorImpl.cpp @@ -10,11 +10,111 @@ ********************************************************************************/ #include <pybind11/pybind11.h> +#include <pybind11/stl.h> + #include "aidge/backend/OperatorImpl.hpp" namespace py = pybind11; namespace Aidge { + +/** + * @brief Trampoline class for binding + * + */ +class pyOperatorImpl: public OperatorImpl { + public: + pyOperatorImpl(){} + + void forward() override { + PYBIND11_OVERRIDE( + void, + OperatorImpl, + forward, + + ); + } + void backward() override { + PYBIND11_OVERRIDE( + void, + OperatorImpl, + backward, + + ); + } + NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const override { + PYBIND11_OVERRIDE_PURE_NAME( + NbElts_t, + OperatorImpl, + "get_nb_required_data", + getNbRequiredData, + inputIdx + ); + } + NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override { + PYBIND11_OVERRIDE_PURE_NAME( + NbElts_t, + OperatorImpl, + "get_nb_required_protected", + getNbRequiredProtected, + inputIdx + + ); + } + NbElts_t getRequiredMemory(const IOIndex_t outputIdx, + const std::vector<DimSize_t> &inputsSize) const override { + PYBIND11_OVERRIDE_PURE_NAME( + NbElts_t, + OperatorImpl, + "get_required_memory", + getRequiredMemory, + outputIdx, + inputsSize + + ); + } + NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const override { + PYBIND11_OVERRIDE_PURE_NAME( + NbElts_t, + OperatorImpl, + "get_nb_consumed_data", + getNbConsumedData, + inputIdx + + ); + } + NbElts_t getNbProducedData(const IOIndex_t outputIdx) const override { + PYBIND11_OVERRIDE_PURE_NAME( + NbElts_t, + OperatorImpl, + "get_nb_produced_data", + getNbProducedData, + outputIdx + + ); + } + void updateConsummerProducer() override { + PYBIND11_OVERRIDE_PURE_NAME( + void, + OperatorImpl, + "update_consummer_producer", + updateConsummerProducer, + + ); + } +}; + void init_OperatorImpl(py::module& m){ - py::class_<OperatorImpl, std::shared_ptr<OperatorImpl>>(m, "OperatorImpl"); + + py::class_<OperatorImpl, std::shared_ptr<OperatorImpl>, pyOperatorImpl>(m, "OperatorImpl", py::dynamic_attr()) + .def(py::init<>()) + .def("forward", &OperatorImpl::forward) + .def("backward", &OperatorImpl::backward) + .def("get_nb_required_data", &OperatorImpl::getNbRequiredData) + .def("get_nb_required_protected", &OperatorImpl::getNbRequiredProtected) + .def("get_required_memory", &OperatorImpl::getRequiredMemory) + .def("get_nb_consumed_data", &OperatorImpl::getNbConsumedData) + .def("get_nb_produced_data", &OperatorImpl::getNbProducedData) + .def("update_consummer_producer", &OperatorImpl::updateConsummerProducer) + ; } } diff --git a/python_binding/operator/pybind_Conv.cpp b/python_binding/operator/pybind_Conv.cpp index 3801fac8a8ca8461fe6ec74cf75313fc362d15d4..f4f7946c6ecc180f83e4bf58eee16102752f0c6e 100644 --- a/python_binding/operator/pybind_Conv.cpp +++ b/python_binding/operator/pybind_Conv.cpp @@ -11,7 +11,7 @@ #include <pybind11/pybind11.h> #include <pybind11/stl.h> - +#include <iostream> #include <string> #include <vector> #include <array> diff --git a/python_binding/operator/pybind_Operator.cpp b/python_binding/operator/pybind_Operator.cpp index d945b212ff6fb643302ca7512e91c7a778a39419..6b535e8cf3293b26aaa64f95ca2f9a394768935f 100644 --- a/python_binding/operator/pybind_Operator.cpp +++ b/python_binding/operator/pybind_Operator.cpp @@ -24,6 +24,9 @@ void init_Operator(py::module& m){ .def("associate_input", &Operator::associateInput, py::arg("inputIdx"), py::arg("data")) .def("set_datatype", &Operator::setDatatype, py::arg("datatype")) .def("set_backend", &Operator::setBackend, py::arg("name")) + .def("forward", &Operator::forward) + // py::keep_alive forbide Python to garbage collect implementation will the Operator is not garbade collected ! + .def("set_impl", &Operator::setImpl, py::arg("implementation"), py::keep_alive<1, 2>()) ; } } diff --git a/python_binding/pybind_core.cpp b/python_binding/pybind_core.cpp index 7c58c61d4f6f44a69ae460f0c71221a82e30e295..04e39b11e58718dfcc5f9faef24b140132367700 100644 --- a/python_binding/pybind_core.cpp +++ b/python_binding/pybind_core.cpp @@ -49,14 +49,8 @@ void init_Recipies(py::module&); void init_Scheduler(py::module&); void init_TensorUtils(py::module&); -void set_python_flag(){ - // Set an env variable to know if we run with ypthon or cpp - py::module os_module = py::module::import("os"); - os_module.attr("environ")["AIDGE_CORE_WITH_PYBIND"] = "1"; -} void init_Aidge(py::module& m){ - set_python_flag(); init_Data(m); init_Tensor(m);