Skip to content
Snippets Groups Projects
Commit 988ef2f4 authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

fix ReduceMean for dim 1

parent 5052290b
No related branches found
No related tags found
2 merge requests!59Improvements and fixes,!47Vit operators
Pipeline #34628 passed
...@@ -95,7 +95,7 @@ class ReduceMean_Op : public Operator, ...@@ -95,7 +95,7 @@ class ReduceMean_Op : public Operator,
break; break;
} }
} }
if(!reducedDim) if(reducedDim)
{ {
if(this->template getAttr<ReduceMeanAttr::KeepDims>()) if(this->template getAttr<ReduceMeanAttr::KeepDims>())
outDims.push_back(1); outDims.push_back(1);
......
// /******************************************************************************** /********************************************************************************
// * Copyright (c) 2023 CEA-List * Copyright (c) 2023 CEA-List
// * *
// * This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
// * terms of the Eclipse Public License 2.0 which is available at * terms of the Eclipse Public License 2.0 which is available at
// * http://www.eclipse.org/legal/epl-2.0. * http://www.eclipse.org/legal/epl-2.0.
// * *
// * SPDX-License-Identifier: EPL-2.0 * SPDX-License-Identifier: EPL-2.0
// * *
// ********************************************************************************/ ********************************************************************************/
// #include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
// #include <pybind11/stl.h> #include <pybind11/stl.h>
// #include <iostream> #include <iostream>
// #include <string> #include <string>
// #include <vector> #include <vector>
// #include <array> #include <array>
// #include "aidge/backend/OperatorImpl.hpp" #include "aidge/backend/OperatorImpl.hpp"
// #include "aidge/operator/ReduceMean.hpp" #include "aidge/operator/ReduceMean.hpp"
// #include "aidge/operator/Operator.hpp" #include "aidge/operator/Operator.hpp"
// #include "aidge/utils/Types.h" #include "aidge/utils/Types.h"
// namespace py = pybind11; namespace py = pybind11;
// namespace Aidge { namespace Aidge {
// template <DimIdx_t DIM> void declare_ReduceMeanOp(py::module &m) { template <DimIdx_t DIM> void declare_ReduceMeanOp(py::module &m) {
// py::class_<ReduceMean_Op<DIM>, std::shared_ptr<ReduceMean_Op<DIM>>, Operator, Attributes>( py::class_<ReduceMean_Op<DIM>, std::shared_ptr<ReduceMean_Op<DIM>>, Operator, Attributes>(
// m, ("ReduceMeanOp" + std::to_string(DIM) + "D").c_str(), m, ("ReduceMeanOp" + std::to_string(DIM) + "D").c_str(), py::multiple_inheritance())
// py::multiple_inheritance())
// .def(py::init<const std::array<DimSize_t, DIM> &, DimSize_t>(), // .def(py::init<const std::array<DimSize_t, DIM> &, DimSize_t>(),
// py::arg("axes"), // py::arg("axes"),
// py::arg("keep_dims")) // py::arg("keep_dims"))
// .def("get_inputs_name", &ReduceMean_Op<DIM>::getInputsName) .def("get_inputs_name", &ReduceMean_Op<DIM>::getInputsName)
// .def("get_outputs_name", &ReduceMean_Op<DIM>::getOutputsName) .def("get_outputs_name", &ReduceMean_Op<DIM>::getOutputsName)
// ; ;
// m.def(("ReduceMean" + std::to_string(DIM) + "D").c_str(), [](const std::vector<DimSize_t>& axes, m.def(("ReduceMean" + std::to_string(DIM) + "D").c_str(), [](const std::vector<DimSize_t>& axes,
// DimSize_t keepDims, DimSize_t keepDims,
// const std::string& name) { const std::string& name) {
// AIDGE_ASSERT(axes.size() == DIM, "axes size [%ld] does not match DIM [%d]", axes.size(), DIM); AIDGE_ASSERT(axes.size() == DIM, "axes size [%ld] does not match DIM [%d]", axes.size(), DIM);
// return ReduceMean<DIM>(to_array<DIM>(axes.begin()), keepDims, name); return ReduceMean<DIM>(to_array<DIM>(axes.begin()), keepDims, name);
// }, py::arg("axes"), }, py::arg("axes"),
// py::arg("keep_dims") = 1); py::arg("keep_dims") = 1,
// } py::arg("name") = "");
}
// void init_ReduceMean(py::module &m) {
// declare_ReduceMeanOp<1>(m); void init_ReduceMean(py::module &m) {
// declare_ReduceMeanOp<2>(m); declare_ReduceMeanOp<1>(m);
// declare_ReduceMeanOp<3>(m); declare_ReduceMeanOp<2>(m);
declare_ReduceMeanOp<3>(m);
// // FIXME:
// // m.def("ReduceMean1D", static_cast<NodeAPI(*)(const char*, int, int, int const // FIXME:
// // (&)[1])>(&ReduceMean)); // m.def("ReduceMean1D", static_cast<NodeAPI(*)(const char*, int, int, int const
// } // (&)[1])>(&ReduceMean));
// } // namespace Aidge }
} // namespace Aidge
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment