Skip to content
Snippets Groups Projects
Commit 610ef01a authored by Cyril Moineau's avatar Cyril Moineau
Browse files

[OperatorImpl] Add pybind/stl include.

parent 9cd6e1c9
No related branches found
No related tags found
1 merge request!19[OperatorImpl] Python OperatorImpl
...@@ -6,7 +6,7 @@ file(READ "${CMAKE_SOURCE_DIR}/project_name.txt" project) ...@@ -6,7 +6,7 @@ file(READ "${CMAKE_SOURCE_DIR}/project_name.txt" project)
message(STATUS "Project name: ${project}") message(STATUS "Project name: ${project}")
message(STATUS "Project version: ${version}") message(STATUS "Project version: ${version}")
# Note : project name is {project} and python module name is also {project} # Note : project name is {project} and python module name is also {project}
set(module_name _${project}) # target name set(module_name _${project}) # target name
...@@ -57,7 +57,7 @@ if (PYBIND) ...@@ -57,7 +57,7 @@ if (PYBIND)
# Handles Python + pybind11 headers dependencies # Handles Python + pybind11 headers dependencies
target_link_libraries(${module_name} target_link_libraries(${module_name}
PUBLIC PUBLIC
pybind11::pybind11 pybind11::pybind11
PRIVATE PRIVATE
Python::Python Python::Python
...@@ -101,8 +101,8 @@ install(DIRECTORY include/ DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}) ...@@ -101,8 +101,8 @@ install(DIRECTORY include/ DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
install(EXPORT ${project}-targets install(EXPORT ${project}-targets
FILE "${project}-targets.cmake" FILE "${project}-targets.cmake"
DESTINATION ${INSTALL_CONFIGDIR} DESTINATION ${INSTALL_CONFIGDIR}
# COMPONENT ${module_name} # COMPONENT ${module_name}
) )
#Create a ConfigVersion.cmake file #Create a ConfigVersion.cmake file
include(CMakePackageConfigHelpers) include(CMakePackageConfigHelpers)
...@@ -136,4 +136,4 @@ export(EXPORT ${project}-targets ...@@ -136,4 +136,4 @@ export(EXPORT ${project}-targets
if(TEST) if(TEST)
enable_testing() enable_testing()
add_subdirectory(unit_tests) add_subdirectory(unit_tests)
endif() endif()
\ No newline at end of file
...@@ -10,6 +10,8 @@ ...@@ -10,6 +10,8 @@
********************************************************************************/ ********************************************************************************/
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "aidge/backend/OperatorImpl.hpp" #include "aidge/backend/OperatorImpl.hpp"
namespace py = pybind11; namespace py = pybind11;
......
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include <pybind11/stl.h> #include <pybind11/stl.h>
#include <iostream>
#include <string> #include <string>
#include <vector> #include <vector>
#include <array> #include <array>
...@@ -40,16 +40,16 @@ template <DimIdx_t DIM> void declare_ConvOp(py::module &m) { ...@@ -40,16 +40,16 @@ template <DimIdx_t DIM> void declare_ConvOp(py::module &m) {
py::arg("stride_dims"), py::arg("stride_dims"),
py::arg("padding_dims"), py::arg("padding_dims"),
py::arg("dilation_dims")); py::arg("dilation_dims"));
m.def(("Conv" + std::to_string(DIM) + "D").c_str(), [](DimSize_t in_channels, m.def(("Conv" + std::to_string(DIM) + "D").c_str(), [](DimSize_t in_channels,
DimSize_t out_channels, DimSize_t out_channels,
const std::vector<DimSize_t>& kernel_dims, const std::vector<DimSize_t>& kernel_dims,
const std::string& name, const std::string& name,
const std::vector<DimSize_t> &stride_dims, const std::vector<DimSize_t> &stride_dims,
const std::vector<DimSize_t> &padding_dims, const std::vector<DimSize_t> &padding_dims,
const std::vector<DimSize_t> &dilation_dims) { const std::vector<DimSize_t> &dilation_dims) {
// Lambda function wrapper because PyBind fails to convert const array. // Lambda function wrapper because PyBind fails to convert const array.
// So we use a vector that we convert in this function to a const DimeSize_t [DIM] array. // So we use a vector that we convert in this function to a const DimeSize_t [DIM] array.
if (kernel_dims.size() != DIM) { if (kernel_dims.size() != DIM) {
throw std::runtime_error("kernel_dims size [" + std::to_string(kernel_dims.size()) + "] does not match DIM [" + std::to_string(DIM) +"]"); throw std::runtime_error("kernel_dims size [" + std::to_string(kernel_dims.size()) + "] does not match DIM [" + std::to_string(DIM) +"]");
} }
...@@ -90,7 +90,7 @@ template <DimIdx_t DIM> void declare_ConvOp(py::module &m) { ...@@ -90,7 +90,7 @@ template <DimIdx_t DIM> void declare_ConvOp(py::module &m) {
py::arg("stride_dims") = std::vector<DimSize_t>(DIM,1), py::arg("stride_dims") = std::vector<DimSize_t>(DIM,1),
py::arg("padding_dims") = std::vector<DimSize_t>(DIM<<1,0), py::arg("padding_dims") = std::vector<DimSize_t>(DIM<<1,0),
py::arg("dilation_dims") = std::vector<DimSize_t>(DIM,1)); py::arg("dilation_dims") = std::vector<DimSize_t>(DIM,1));
} }
...@@ -98,7 +98,7 @@ void init_Conv(py::module &m) { ...@@ -98,7 +98,7 @@ void init_Conv(py::module &m) {
declare_ConvOp<1>(m); declare_ConvOp<1>(m);
declare_ConvOp<2>(m); declare_ConvOp<2>(m);
declare_ConvOp<3>(m); declare_ConvOp<3>(m);
// FIXME: // FIXME:
// m.def("Conv1D", static_cast<NodeAPI(*)(const char*, int, int, int const // m.def("Conv1D", static_cast<NodeAPI(*)(const char*, int, int, int const
// (&)[1])>(&Conv)); // (&)[1])>(&Conv));
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment