Skip to content
Snippets Groups Projects
Commit 30df0333 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Missing removal in MaxPooling binding

parent 72fa1b79
No related branches found
No related tags found
1 merge request!11Removed padding from conv and pool and added Pad operator
Pipeline #31920 passed
......@@ -31,16 +31,13 @@ template <DimIdx_t DIM> void declare_MaxPoolingOp(py::module &m) {
m, ("MaxPoolingOp" + std::to_string(DIM) + "D").c_str(),
py::multiple_inheritance())
.def(py::init<const std::array<DimSize_t, DIM> &,
const std::array<DimSize_t, DIM> &,
const std::array<DimSize_t, (DIM<<1)> &>(),
const std::array<DimSize_t, DIM> &>(),
py::arg("kernel_dims"),
py::arg("stride_dims"),
py::arg("padding_dims"));
py::arg("stride_dims"));
m.def(("MaxPooling" + std::to_string(DIM) + "D").c_str(), [](const std::vector<DimSize_t>& kernel_dims,
const std::string& name,
const std::vector<DimSize_t> &stride_dims,
const std::vector<DimSize_t> &padding_dims) {
const std::vector<DimSize_t> &stride_dims) {
// Lambda function wrapper because PyBind fails to convert const array.
// So we use a vector that we convert in this function to a const DimeSize_t [DIM] array.
if (kernel_dims.size() != DIM) {
......@@ -49,9 +46,6 @@ template <DimIdx_t DIM> void declare_MaxPoolingOp(py::module &m) {
if (stride_dims.size() != DIM) {
throw std::runtime_error("stride_dims size [" + std::to_string(stride_dims.size()) + "] does not match DIM [" + std::to_string(DIM) +"]");
}
if (padding_dims.size() != (DIM<<1)) {
throw std::runtime_error("padding_dims size [" + std::to_string(padding_dims.size()) + "] does not match DIM [" + std::to_string(DIM<<1) +"]");
}
DimSize_t tmp_kernel_dims_array[DIM];
for (size_t i = 0; i < DIM; ++i) {
tmp_kernel_dims_array[i] = kernel_dims[i];
......@@ -60,18 +54,12 @@ template <DimIdx_t DIM> void declare_MaxPoolingOp(py::module &m) {
for (size_t i = 0; i < DIM; ++i) {
tmp_stride_dims_array[i] = stride_dims[i];
}
DimSize_t tmp_padding_dims_array[DIM<<1];
for (size_t i = 0; i < (DIM<<1); ++i) {
tmp_padding_dims_array[i] = padding_dims[i];
}
const DimSize_t (&kernel_dims_array)[DIM] = tmp_kernel_dims_array;
const DimSize_t (&stride_dims_array)[DIM] = tmp_stride_dims_array;
const DimSize_t (&padding_dims_array)[DIM<<1] = tmp_padding_dims_array;
return MaxPooling<DIM>(to_array(kernel_dims_array), name, to_array(stride_dims_array), to_array(padding_dims_array));
return MaxPooling<DIM>(to_array(kernel_dims_array), name, to_array(stride_dims_array));
}, py::arg("kernel_dims"),
py::arg("name") = "",
py::arg("stride_dims") = std::vector<DimSize_t>(DIM,1),
py::arg("padding_dims") = std::vector<DimSize_t>(DIM<<1,0));
py::arg("stride_dims") = std::vector<DimSize_t>(DIM,1));
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment