Skip to content
Snippets Groups Projects
Commit a8c98204 authored by Cyril Moineau's avatar Cyril Moineau
Browse files

Merge branch 'main' into MaxPoolingAttr

parents 874cd289 f9ca57a0
No related branches found
No related tags found
1 merge request!36[MaxPooling] Add support for ceil_mode parameter.
Pipeline #33146 failed
......@@ -6,7 +6,7 @@ file(READ "${CMAKE_SOURCE_DIR}/project_name.txt" project)
message(STATUS "Project name: ${project}")
message(STATUS "Project version: ${version}")
# Note : project name is {project} and python module name is also {project}
# Note : project name is {project} and python module name is also {project}
set(module_name _${project}) # target name
......@@ -57,7 +57,7 @@ if (PYBIND)
# Handles Python + pybind11 headers dependencies
target_link_libraries(${module_name}
PUBLIC
PUBLIC
pybind11::pybind11
PRIVATE
Python::Python
......@@ -101,8 +101,8 @@ install(DIRECTORY include/ DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
install(EXPORT ${project}-targets
FILE "${project}-targets.cmake"
DESTINATION ${INSTALL_CONFIGDIR}
# COMPONENT ${module_name}
)
# COMPONENT ${module_name}
)
#Create a ConfigVersion.cmake file
include(CMakePackageConfigHelpers)
......@@ -136,4 +136,4 @@ export(EXPORT ${project}-targets
if(TEST)
enable_testing()
add_subdirectory(unit_tests)
endif()
\ No newline at end of file
endif()
......@@ -102,5 +102,30 @@ class test_operator_binding(unittest.TestCase):
genOp.get_operator().compute_output_dims()
self.assertListEqual(genOp.get_operator().output(0).dims(), in_dims)
def test_set_impl(self):
class PythonCustomImpl(aidge_core.OperatorImpl):
"""Dummy implementation to test that C++ call python code
"""
def __init__(self):
aidge_core.OperatorImpl.__init__(self) # Recquired to avoid type error !
self.idx = 0
def forward(self):
"""Increment idx attribute on forward.
"""
self.idx += 1
generic_node = aidge_core.GenericOperator("Relu", 1, 1, 1, name="myReLu")
customImpl = PythonCustomImpl()
generic_op = generic_node.get_operator()
generic_op.forward() # Do nothing, no implementation set
generic_op.set_impl(customImpl)
generic_op.forward() # Increment idx
self.assertEqual(customImpl.idx, 1)
if __name__ == '__main__':
unittest.main()
......@@ -477,13 +477,14 @@ class Tensor : public Data,
if (dims().empty()) { return "{}"; }
std::string res;
std::size_t dim = 0;
std::size_t *dimVals = new std::size_t[nbDims()];
for (std::size_t i = 0; i < nbDims(); ++i) {
dimVals[i] = 0;
}
std::size_t counter = 0;
res += "{\n";
if (nbDims()>=2){
if (nbDims()>=2) {
std::size_t *dimVals = new std::size_t[nbDims()];
for (std::size_t i = 0; i < nbDims(); ++i) {
dimVals[i] = 0;
}
// std::vector<std::size_t> dimVals = std::vector<std::size_t>(nbDims(), 0);
res += "{\n";
while (counter < mSize) {
std::string spaceString = std::string((dim+1)<<1,' ');
if (dim < nbDims()-2) {
......@@ -532,31 +533,35 @@ class Tensor : public Data,
}
res += "\n";
}
if (dim == 0) {
break;
}
dimVals[dim--] = 0;
dimVals[dim]++;
}
}
for(int i = static_cast<int>(dim); i>=0; --i) {
delete[] dimVals;
for(int i = static_cast<int>(dim); i > 0; --i) {
res += std::string((dim+1)<<1,' ') + "}\n";
}
}else{
} else {
res += "{";
for (DimSize_t j = 0; j < dims()[0]; ++j) {
switch (mDataType)
{
case DataType::Int32:
res += " " + std::to_string(static_cast<int *>(mImpl->rawPtr())[j]) + ((j < dims()[0]-1) ? "," : "\n");
res += " " + std::to_string(static_cast<int *>(mImpl->rawPtr())[j]) + ((j < dims()[0]-1) ? "," : "");
break;
case DataType::Float64:
res += " " + std::to_string(static_cast<double *>(mImpl->rawPtr())[j]) + ((j < dims()[0]-1) ? "," : "\n");
res += " " + std::to_string(static_cast<double *>(mImpl->rawPtr())[j]) + ((j < dims()[0]-1) ? "," : "");
break;
default:
res += " " + std::to_string(static_cast<float *>(mImpl->rawPtr())[j]) + ((j < dims()[0]-1) ? "," : "\n");
res += " " + std::to_string(static_cast<float *>(mImpl->rawPtr())[j]) + ((j < dims()[0]-1) ? "," : "");
break;
}
}
}
res += "}";
return res;
}
......
......@@ -168,9 +168,20 @@ class GenericOperator_Op
void setBackend(const std::string & /*name*/) override { printf("setBackend: not available yet.\n"); }
void setDatatype(const DataType & /*datatype*/) override { printf("setDatatype: not available yet.\n"); }
void forward() override final { printf("forward: not available yet.\n"); }
void backward() override final { printf("backward: not available yet.\n"); }
void forward() override final {
if(mImpl){
mImpl->forward();
}else{
printf("forward: No implementation is linked.\n");
}
}
void backward() override final {
if(mImpl){
mImpl->backward();
}else{
printf("backward: No implementation is linked.\n");
}
}
inline IOIndex_t nbInputs() const noexcept override final { return mNbIn; };
inline IOIndex_t nbDataInputs() const noexcept override final { return mNbDataIn; };
inline IOIndex_t nbOutputs() const noexcept override final { return mNbOut; };
......
......@@ -26,7 +26,7 @@ namespace Aidge {
class Operator : public std::enable_shared_from_this<Operator> {
protected:
std::unique_ptr<OperatorImpl> mImpl; // implementation of the operator
std::shared_ptr<OperatorImpl> mImpl; // implementation of the operator
std::map<std::string, std::shared_ptr<Hook>> mHooks;
private:
......@@ -76,6 +76,14 @@ public:
virtual void setBackend(const std::string& name) = 0;
virtual void setDatatype(const DataType& datatype) = 0;
/**
* @brief Set the a new OperatorImpl to the Operator
*
*/
void setImpl(std::shared_ptr<OperatorImpl> impl){
mImpl = impl;
}
/**
* @brief Minimum amount of data from a specific input for one computation pass.
* @param inputIdx Index of the input analysed.
......
......@@ -35,7 +35,7 @@ public:
{
#ifdef PYBIND
#define _CRT_SECURE_NO_WARNINGS
if (std::getenv("AIDGE_CORE_WITH_PYBIND")){
if (Py_IsInitialized()){
std::string name = std::string("registrar_")+typeid(Registrable<DerivedClass, Key, Func>).name();
static auto shared_data = reinterpret_cast<std::map<Key, std::function<Func>> *>(py::get_shared_data(name));
if (!shared_data)
......@@ -78,4 +78,4 @@ struct Registrar {
};
}
#endif //AIDGE_CORE_UTILS_REGISTRAR_H_
\ No newline at end of file
#endif //AIDGE_CORE_UTILS_REGISTRAR_H_
......@@ -10,11 +10,111 @@
********************************************************************************/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "aidge/backend/OperatorImpl.hpp"
namespace py = pybind11;
namespace Aidge {
/**
* @brief Trampoline class for binding
*
*/
class pyOperatorImpl: public OperatorImpl {
public:
pyOperatorImpl(){}
void forward() override {
PYBIND11_OVERRIDE(
void,
OperatorImpl,
forward,
);
}
void backward() override {
PYBIND11_OVERRIDE(
void,
OperatorImpl,
backward,
);
}
NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const override {
PYBIND11_OVERRIDE_PURE_NAME(
NbElts_t,
OperatorImpl,
"get_nb_required_data",
getNbRequiredData,
inputIdx
);
}
NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override {
PYBIND11_OVERRIDE_PURE_NAME(
NbElts_t,
OperatorImpl,
"get_nb_required_protected",
getNbRequiredProtected,
inputIdx
);
}
NbElts_t getRequiredMemory(const IOIndex_t outputIdx,
const std::vector<DimSize_t> &inputsSize) const override {
PYBIND11_OVERRIDE_PURE_NAME(
NbElts_t,
OperatorImpl,
"get_required_memory",
getRequiredMemory,
outputIdx,
inputsSize
);
}
NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const override {
PYBIND11_OVERRIDE_PURE_NAME(
NbElts_t,
OperatorImpl,
"get_nb_consumed_data",
getNbConsumedData,
inputIdx
);
}
NbElts_t getNbProducedData(const IOIndex_t outputIdx) const override {
PYBIND11_OVERRIDE_PURE_NAME(
NbElts_t,
OperatorImpl,
"get_nb_produced_data",
getNbProducedData,
outputIdx
);
}
void updateConsummerProducer() override {
PYBIND11_OVERRIDE_PURE_NAME(
void,
OperatorImpl,
"update_consummer_producer",
updateConsummerProducer,
);
}
};
void init_OperatorImpl(py::module& m){
py::class_<OperatorImpl, std::shared_ptr<OperatorImpl>>(m, "OperatorImpl");
py::class_<OperatorImpl, std::shared_ptr<OperatorImpl>, pyOperatorImpl>(m, "OperatorImpl", py::dynamic_attr())
.def(py::init<>())
.def("forward", &OperatorImpl::forward)
.def("backward", &OperatorImpl::backward)
.def("get_nb_required_data", &OperatorImpl::getNbRequiredData)
.def("get_nb_required_protected", &OperatorImpl::getNbRequiredProtected)
.def("get_required_memory", &OperatorImpl::getRequiredMemory)
.def("get_nb_consumed_data", &OperatorImpl::getNbConsumedData)
.def("get_nb_produced_data", &OperatorImpl::getNbProducedData)
.def("update_consummer_producer", &OperatorImpl::updateConsummerProducer)
;
}
}
......@@ -11,7 +11,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <iostream>
#include <string>
#include <vector>
#include <array>
......
......@@ -24,6 +24,9 @@ void init_Operator(py::module& m){
.def("associate_input", &Operator::associateInput, py::arg("inputIdx"), py::arg("data"))
.def("set_datatype", &Operator::setDatatype, py::arg("datatype"))
.def("set_backend", &Operator::setBackend, py::arg("name"))
.def("forward", &Operator::forward)
// py::keep_alive forbide Python to garbage collect implementation will the Operator is not garbade collected !
.def("set_impl", &Operator::setImpl, py::arg("implementation"), py::keep_alive<1, 2>())
;
}
}
......@@ -49,14 +49,8 @@ void init_Recipies(py::module&);
void init_Scheduler(py::module&);
void init_TensorUtils(py::module&);
void set_python_flag(){
// Set an env variable to know if we run with ypthon or cpp
py::module os_module = py::module::import("os");
os_module.attr("environ")["AIDGE_CORE_WITH_PYBIND"] = "1";
}
void init_Aidge(py::module& m){
set_python_flag();
init_Data(m);
init_Tensor(m);
......@@ -78,6 +72,7 @@ void init_Aidge(py::module& m){
init_LeakyReLU(m);
init_MatMul(m);
init_MaxPooling(m);
init_MetaOperatorDefs(m);
init_ReLU(m);
init_Softmax(m);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment