From 9ffeb9610aa5b2536c025a5f096495774f4c2708 Mon Sep 17 00:00:00 2001 From: NAUD Maxence <maxence.naud@cea.fr> Date: Tue, 26 Mar 2024 00:12:06 +0000 Subject: [PATCH] Update ReduceMean python binding --- python_binding/operator/pybind_ReduceMean.cpp | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/python_binding/operator/pybind_ReduceMean.cpp b/python_binding/operator/pybind_ReduceMean.cpp index b89c09d56..599a648a3 100644 --- a/python_binding/operator/pybind_ReduceMean.cpp +++ b/python_binding/operator/pybind_ReduceMean.cpp @@ -24,22 +24,22 @@ namespace py = pybind11; namespace Aidge { -template <DimIdx_t DIM> void declare_ReduceMeanOp(py::module &m) { - const std::string pyClassName("ReduceMeanOp" + std::to_string(DIM) + "D"); - py::class_<ReduceMean_Op<DIM>, std::shared_ptr<ReduceMean_Op<DIM>>, Attributes, OperatorTensor>( +void declare_ReduceMeanOp(py::module &m) { + const std::string pyClassName("ReduceMeanOp"); + py::class_<ReduceMean_Op, std::shared_ptr<ReduceMean_Op>, Attributes, OperatorTensor>( m, pyClassName.c_str(), py::multiple_inheritance()) - .def("get_inputs_name", &ReduceMean_Op<DIM>::getInputsName) - .def("get_outputs_name", &ReduceMean_Op<DIM>::getOutputsName) - .def("attributes_name", &ReduceMean_Op<DIM>::staticGetAttrsName) + .def("get_inputs_name", &ReduceMean_Op::getInputsName) + .def("get_outputs_name", &ReduceMean_Op::getOutputsName) + .def("attributes_name", &ReduceMean_Op::staticGetAttrsName) ; - declare_registrable<ReduceMean_Op<DIM>>(m, pyClassName); + declare_registrable<ReduceMean_Op>(m, pyClassName); - m.def(("ReduceMean" + std::to_string(DIM) + "D").c_str(), [](const std::vector<int>& axes, + m.def("ReduceMean", [](const std::vector<int>& axes, DimSize_t keepDims, const std::string& name) { - AIDGE_ASSERT(axes.size() == DIM, "axes size [{}] does not match DIM [{}]", axes.size(), DIM); + // AIDGE_ASSERT(axes.size() == DIM, "axes size [{}] does not match DIM [{}]", axes.size(), DIM); - return ReduceMean<DIM>(to_array<DIM>(axes.begin()), keepDims, name); + return ReduceMean(axes, keepDims, name); }, py::arg("axes"), py::arg("keep_dims") = 1, py::arg("name") = ""); @@ -47,9 +47,9 @@ template <DimIdx_t DIM> void declare_ReduceMeanOp(py::module &m) { void init_ReduceMean(py::module &m) { - declare_ReduceMeanOp<1>(m); - declare_ReduceMeanOp<2>(m); - declare_ReduceMeanOp<3>(m); + declare_ReduceMeanOp(m); +// declare_ReduceMeanOp<2>(m); +// declare_ReduceMeanOp<3>(m); // FIXME: // m.def("ReduceMean1D", static_cast<NodeAPI(*)(const char*, int, int, int const -- GitLab