Skip to content
Snippets Groups Projects
Commit e341bc78 authored by Maxence Naud's avatar Maxence Naud Committed by Maxence Naud
Browse files

[WIP][NF] update ReduceMean

parent f0917214
No related branches found
No related tags found
No related merge requests found
...@@ -12,8 +12,10 @@ ...@@ -12,8 +12,10 @@
#ifndef AIDGE_CORE_OPERATOR_REDUCEMEAN_H_ #ifndef AIDGE_CORE_OPERATOR_REDUCEMEAN_H_
#define AIDGE_CORE_OPERATOR_REDUCEMEAN_H_ #define AIDGE_CORE_OPERATOR_REDUCEMEAN_H_
#include <algorithm> // std::for_each
#include <array> #include <array>
#include <cmath> #include <cmath>
#include <cstdint> // std::int32_t
#include <numeric> #include <numeric>
#include <vector> #include <vector>
...@@ -31,18 +33,18 @@ enum class ReduceMeanAttr { Axes, KeepDims }; ...@@ -31,18 +33,18 @@ enum class ReduceMeanAttr { Axes, KeepDims };
template <DimIdx_t DIM> template <DimIdx_t DIM>
class ReduceMean_Op : public OperatorTensor, class ReduceMean_Op : public OperatorTensor,
public Registrable<ReduceMean_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const ReduceMean_Op<DIM> &)>, public Registrable<ReduceMean_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const ReduceMean_Op<DIM> &)>,
public StaticAttributes<ReduceMeanAttr, std::array<int, DIM>, DimSize_t> { public StaticAttributes<ReduceMeanAttr, std::array<std::int32_t, DIM>, DimSize_t> {
public: public:
static const std::string Type; static const std::string Type;
ReduceMean_Op() = delete; ReduceMean_Op() = delete;
using Attributes_ = StaticAttributes<ReduceMeanAttr, std::array<int, DIM>, DimSize_t>; using Attributes_ = StaticAttributes<ReduceMeanAttr, std::array<std::int32_t, DIM>, DimSize_t>;
template <ReduceMeanAttr e> template <ReduceMeanAttr e>
using attr = typename Attributes_::template attr<e>; using attr = typename Attributes_::template attr<e>;
constexpr ReduceMean_Op(const std::array<int, DIM> &axes, DimSize_t keep_dims) constexpr ReduceMean_Op(const std::array<std::int32_t, DIM> &axes, DimSize_t keep_dims)
: OperatorTensor(Type, 1, 0, 1), : OperatorTensor(Type, 1, 0, 1),
Attributes_(attr<ReduceMeanAttr::Axes>(axes), Attributes_(attr<ReduceMeanAttr::Axes>(axes),
attr<ReduceMeanAttr::KeepDims>(keep_dims)) {} attr<ReduceMeanAttr::KeepDims>(keep_dims)) {}
...@@ -67,29 +69,28 @@ class ReduceMean_Op : public OperatorTensor, ...@@ -67,29 +69,28 @@ class ReduceMean_Op : public OperatorTensor,
} }
void computeOutputDims() override final { void computeOutputDims() override final {
if (!getInput(0)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Every input should be associated with a Tensor");
}
if (!getInput(0)->empty()) { if (!getInput(0)->empty()) {
std::vector<DimSize_t> outDims; // make Axes attribute positive
for(std::size_t d=0; d<getInput(0)->dims().size(); ++d) std::array<std::int32_t, DIM>& axes = this->template getAttr<ReduceMeanAttr::Axes>();
{ std::for_each(axes.begin(), axes.end(), [&] (std::int32_t& val) {
bool reducedDim = false; if (val < 0)
for(std::size_t i=0; i<DIM; ++i) val+=static_cast<std::int32_t>(getInput(0)->nbDims());
{ });
int axis_ = this->template getAttr<ReduceMeanAttr::Axes>()[i]; std::sort(axes.cbegin(), axes.cend());
std::size_t axis= axis_>=0? axis_: axis_ + getInput(0)->nbDims();
if(axis == d) // build output dimensions
{ std::vector<DimSize_t> outDims = getInput(0)->dims();
reducedDim = true; if (this->template getAttr<ReduceMeanAttr::KeepDims>()) {
break; std::for_each(axes.begin(), axes.end(), [&] (const std::int32_t& val) { outDims[val] = 1; });
} }
} else {
if(reducedDim) for (auto it = axes.crbegin(); it != axes.crend(); ++it)
{ outDims.erase(outDims.begin() + static_cast<std::size_t>(*it));
if(this->template getAttr<ReduceMeanAttr::KeepDims>())
outDims.push_back(1);
}
else
outDims.push_back(getInput(0)->dims()[d]);
} }
if(outDims.size()>0) if(outDims.size()>0)
mOutputs[0]->resize(outDims); mOutputs[0]->resize(outDims);
else else
...@@ -111,7 +112,7 @@ class ReduceMean_Op : public OperatorTensor, ...@@ -111,7 +112,7 @@ class ReduceMean_Op : public OperatorTensor,
}; };
template <std::array<DimSize_t, 1>::size_type DIM> template <std::array<DimSize_t, 1>::size_type DIM>
inline std::shared_ptr<Node> ReduceMean(const std::array<int, DIM> &axes, inline std::shared_ptr<Node> ReduceMean(const std::array<std::int32_t, DIM> &axes,
DimSize_t keep_dims=1, DimSize_t keep_dims=1,
const std::string& name = "") { const std::string& name = "") {
// FIXME: properly handle default w&b initialization in every cases // FIXME: properly handle default w&b initialization in every cases
...@@ -123,7 +124,7 @@ inline std::shared_ptr<Node> ReduceMean(const std::array<int, DIM> &axes, ...@@ -123,7 +124,7 @@ inline std::shared_ptr<Node> ReduceMean(const std::array<int, DIM> &axes,
// helper with C-style array instead of std::array for kernel_dims to allow automatic template DIM deduction // helper with C-style array instead of std::array for kernel_dims to allow automatic template DIM deduction
template <DimSize_t DIM> template <DimSize_t DIM>
inline std::shared_ptr<Node> ReduceMean( inline std::shared_ptr<Node> ReduceMean(
int const (&axes)[DIM], std::int32_t const (&axes)[DIM],
DimSize_t keep_dims = 1, DimSize_t keep_dims = 1,
const std::string& name = "") { const std::string& name = "") {
static_assert(DIM<=MaxDim,"Too many kernel dimensions required by ReduceMean, not supported"); static_assert(DIM<=MaxDim,"Too many kernel dimensions required by ReduceMean, not supported");
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment