Skip to content
Snippets Groups Projects
Commit 9ffeb961 authored by Maxence Naud's avatar Maxence Naud Committed by Maxence Naud
Browse files

Update ReduceMean python binding

parent e951bff9
No related branches found
No related tags found
No related merge requests found
......@@ -24,22 +24,22 @@
namespace py = pybind11;
namespace Aidge {
template <DimIdx_t DIM> void declare_ReduceMeanOp(py::module &m) {
const std::string pyClassName("ReduceMeanOp" + std::to_string(DIM) + "D");
py::class_<ReduceMean_Op<DIM>, std::shared_ptr<ReduceMean_Op<DIM>>, Attributes, OperatorTensor>(
void declare_ReduceMeanOp(py::module &m) {
const std::string pyClassName("ReduceMeanOp");
py::class_<ReduceMean_Op, std::shared_ptr<ReduceMean_Op>, Attributes, OperatorTensor>(
m, pyClassName.c_str(), py::multiple_inheritance())
.def("get_inputs_name", &ReduceMean_Op<DIM>::getInputsName)
.def("get_outputs_name", &ReduceMean_Op<DIM>::getOutputsName)
.def("attributes_name", &ReduceMean_Op<DIM>::staticGetAttrsName)
.def("get_inputs_name", &ReduceMean_Op::getInputsName)
.def("get_outputs_name", &ReduceMean_Op::getOutputsName)
.def("attributes_name", &ReduceMean_Op::staticGetAttrsName)
;
declare_registrable<ReduceMean_Op<DIM>>(m, pyClassName);
declare_registrable<ReduceMean_Op>(m, pyClassName);
m.def(("ReduceMean" + std::to_string(DIM) + "D").c_str(), [](const std::vector<int>& axes,
m.def("ReduceMean", [](const std::vector<int>& axes,
DimSize_t keepDims,
const std::string& name) {
AIDGE_ASSERT(axes.size() == DIM, "axes size [{}] does not match DIM [{}]", axes.size(), DIM);
// AIDGE_ASSERT(axes.size() == DIM, "axes size [{}] does not match DIM [{}]", axes.size(), DIM);
return ReduceMean<DIM>(to_array<DIM>(axes.begin()), keepDims, name);
return ReduceMean(axes, keepDims, name);
}, py::arg("axes"),
py::arg("keep_dims") = 1,
py::arg("name") = "");
......@@ -47,9 +47,9 @@ template <DimIdx_t DIM> void declare_ReduceMeanOp(py::module &m) {
void init_ReduceMean(py::module &m) {
declare_ReduceMeanOp<1>(m);
declare_ReduceMeanOp<2>(m);
declare_ReduceMeanOp<3>(m);
declare_ReduceMeanOp(m);
// declare_ReduceMeanOp<2>(m);
// declare_ReduceMeanOp<3>(m);
// FIXME:
// m.def("ReduceMean1D", static_cast<NodeAPI(*)(const char*, int, int, int const
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment