diff --git a/include/aidge/operator/ReduceMean.hpp b/include/aidge/operator/ReduceMean.hpp index 31456f5d9890ecf99e9ea4191630fa65bda81942..917f5bd37aaf09731bef0a01e49ca822ea333928 100644 --- a/include/aidge/operator/ReduceMean.hpp +++ b/include/aidge/operator/ReduceMean.hpp @@ -95,7 +95,7 @@ class ReduceMean_Op : public Operator, break; } } - if(!reducedDim) + if(reducedDim) { if(this->template getAttr<ReduceMeanAttr::KeepDims>()) outDims.push_back(1); diff --git a/python_binding/operator/pybind_ReduceMean.cpp b/python_binding/operator/pybind_ReduceMean.cpp index 3322de897932e46915c819e75981b9c147d25da4..90c00ff8e5fc29cac358234fc5092aa6a9668f8c 100644 --- a/python_binding/operator/pybind_ReduceMean.cpp +++ b/python_binding/operator/pybind_ReduceMean.cpp @@ -1,58 +1,58 @@ -// /******************************************************************************** -// * Copyright (c) 2023 CEA-List -// * -// * This program and the accompanying materials are made available under the -// * terms of the Eclipse Public License 2.0 which is available at -// * http://www.eclipse.org/legal/epl-2.0. -// * -// * SPDX-License-Identifier: EPL-2.0 -// * -// ********************************************************************************/ - -// #include <pybind11/pybind11.h> -// #include <pybind11/stl.h> -// #include <iostream> -// #include <string> -// #include <vector> -// #include <array> - -// #include "aidge/backend/OperatorImpl.hpp" -// #include "aidge/operator/ReduceMean.hpp" -// #include "aidge/operator/Operator.hpp" -// #include "aidge/utils/Types.h" - -// namespace py = pybind11; -// namespace Aidge { - -// template <DimIdx_t DIM> void declare_ReduceMeanOp(py::module &m) { -// py::class_<ReduceMean_Op<DIM>, std::shared_ptr<ReduceMean_Op<DIM>>, Operator, Attributes>( -// m, ("ReduceMeanOp" + std::to_string(DIM) + "D").c_str(), -// py::multiple_inheritance()) +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> +#include <pybind11/stl.h> +#include <iostream> +#include <string> +#include <vector> +#include <array> + +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/operator/ReduceMean.hpp" +#include "aidge/operator/Operator.hpp" +#include "aidge/utils/Types.h" + +namespace py = pybind11; +namespace Aidge { + +template <DimIdx_t DIM> void declare_ReduceMeanOp(py::module &m) { + py::class_<ReduceMean_Op<DIM>, std::shared_ptr<ReduceMean_Op<DIM>>, Operator, Attributes>( + m, ("ReduceMeanOp" + std::to_string(DIM) + "D").c_str(), py::multiple_inheritance()) // .def(py::init<const std::array<DimSize_t, DIM> &, DimSize_t>(), // py::arg("axes"), // py::arg("keep_dims")) -// .def("get_inputs_name", &ReduceMean_Op<DIM>::getInputsName) -// .def("get_outputs_name", &ReduceMean_Op<DIM>::getOutputsName) -// ; - -// m.def(("ReduceMean" + std::to_string(DIM) + "D").c_str(), [](const std::vector<DimSize_t>& axes, -// DimSize_t keepDims, -// const std::string& name) { -// AIDGE_ASSERT(axes.size() == DIM, "axes size [%ld] does not match DIM [%d]", axes.size(), DIM); - -// return ReduceMean<DIM>(to_array<DIM>(axes.begin()), keepDims, name); -// }, py::arg("axes"), -// py::arg("keep_dims") = 1); -// } - - -// void init_ReduceMean(py::module &m) { -// declare_ReduceMeanOp<1>(m); -// declare_ReduceMeanOp<2>(m); -// declare_ReduceMeanOp<3>(m); - -// // FIXME: -// // m.def("ReduceMean1D", static_cast<NodeAPI(*)(const char*, int, int, int const -// // (&)[1])>(&ReduceMean)); -// } -// } // namespace Aidge + .def("get_inputs_name", &ReduceMean_Op<DIM>::getInputsName) + .def("get_outputs_name", &ReduceMean_Op<DIM>::getOutputsName) + ; + + m.def(("ReduceMean" + std::to_string(DIM) + "D").c_str(), [](const std::vector<DimSize_t>& axes, + DimSize_t keepDims, + const std::string& name) { + AIDGE_ASSERT(axes.size() == DIM, "axes size [%ld] does not match DIM [%d]", axes.size(), DIM); + + return ReduceMean<DIM>(to_array<DIM>(axes.begin()), keepDims, name); + }, py::arg("axes"), + py::arg("keep_dims") = 1, + py::arg("name") = ""); +} + + +void init_ReduceMean(py::module &m) { + declare_ReduceMeanOp<1>(m); + declare_ReduceMeanOp<2>(m); + declare_ReduceMeanOp<3>(m); + + // FIXME: + // m.def("ReduceMean1D", static_cast<NodeAPI(*)(const char*, int, int, int const + // (&)[1])>(&ReduceMean)); +} +} // namespace Aidge