From 9ffeb9610aa5b2536c025a5f096495774f4c2708 Mon Sep 17 00:00:00 2001
From: NAUD Maxence <maxence.naud@cea.fr>
Date: Tue, 26 Mar 2024 00:12:06 +0000
Subject: [PATCH] Update ReduceMean python binding

---
 python_binding/operator/pybind_ReduceMean.cpp | 26 +++++++++----------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/python_binding/operator/pybind_ReduceMean.cpp b/python_binding/operator/pybind_ReduceMean.cpp
index b89c09d56..599a648a3 100644
--- a/python_binding/operator/pybind_ReduceMean.cpp
+++ b/python_binding/operator/pybind_ReduceMean.cpp
@@ -24,22 +24,22 @@
 namespace py = pybind11;
 namespace Aidge {
 
-template <DimIdx_t DIM> void declare_ReduceMeanOp(py::module &m) {
-  const std::string pyClassName("ReduceMeanOp" + std::to_string(DIM) + "D");
-  py::class_<ReduceMean_Op<DIM>, std::shared_ptr<ReduceMean_Op<DIM>>, Attributes, OperatorTensor>(
+void declare_ReduceMeanOp(py::module &m) {
+  const std::string pyClassName("ReduceMeanOp");
+  py::class_<ReduceMean_Op, std::shared_ptr<ReduceMean_Op>, Attributes, OperatorTensor>(
     m, pyClassName.c_str(), py::multiple_inheritance())
-    .def("get_inputs_name", &ReduceMean_Op<DIM>::getInputsName)
-    .def("get_outputs_name", &ReduceMean_Op<DIM>::getOutputsName)
-    .def("attributes_name", &ReduceMean_Op<DIM>::staticGetAttrsName)
+    .def("get_inputs_name", &ReduceMean_Op::getInputsName)
+    .def("get_outputs_name", &ReduceMean_Op::getOutputsName)
+    .def("attributes_name", &ReduceMean_Op::staticGetAttrsName)
     ;
-  declare_registrable<ReduceMean_Op<DIM>>(m, pyClassName);
+  declare_registrable<ReduceMean_Op>(m, pyClassName);
 
-  m.def(("ReduceMean" + std::to_string(DIM) + "D").c_str(), [](const std::vector<int>& axes,
+  m.def("ReduceMean", [](const std::vector<int>& axes,
                                                                 DimSize_t keepDims,
                                                                 const std::string& name) {
-        AIDGE_ASSERT(axes.size() == DIM, "axes size [{}] does not match DIM [{}]", axes.size(), DIM);
+        // AIDGE_ASSERT(axes.size() == DIM, "axes size [{}] does not match DIM [{}]", axes.size(), DIM);
 
-        return ReduceMean<DIM>(to_array<DIM>(axes.begin()), keepDims, name);
+        return ReduceMean(axes, keepDims, name);
     }, py::arg("axes"),
        py::arg("keep_dims") = 1,
        py::arg("name") = "");
@@ -47,9 +47,9 @@ template <DimIdx_t DIM> void declare_ReduceMeanOp(py::module &m) {
 
 
 void init_ReduceMean(py::module &m) {
-  declare_ReduceMeanOp<1>(m);
-  declare_ReduceMeanOp<2>(m);
-  declare_ReduceMeanOp<3>(m);
+  declare_ReduceMeanOp(m);
+//   declare_ReduceMeanOp<2>(m);
+//   declare_ReduceMeanOp<3>(m);
 
   // FIXME:
   // m.def("ReduceMean1D", static_cast<NodeAPI(*)(const char*, int, int, int const
-- 
GitLab