Skip to content
Snippets Groups Projects
Commit 95077929 authored by Maxence Naud's avatar Maxence Naud
Browse files

Merge branch 'vit_operators' into 'dev'

[Fix] vit_operators

See merge request eclipse/aidge/aidge_core!80
parents f0917214 e941e9fe
No related branches found
No related tags found
No related merge requests found
......@@ -12,8 +12,10 @@
#ifndef AIDGE_CORE_OPERATOR_REDUCEMEAN_H_
#define AIDGE_CORE_OPERATOR_REDUCEMEAN_H_
#include <algorithm> // std::for_each
#include <array>
#include <cmath>
#include <cstdint> // std::int32_t
#include <numeric>
#include <vector>
......@@ -31,18 +33,18 @@ enum class ReduceMeanAttr { Axes, KeepDims };
template <DimIdx_t DIM>
class ReduceMean_Op : public OperatorTensor,
public Registrable<ReduceMean_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const ReduceMean_Op<DIM> &)>,
public StaticAttributes<ReduceMeanAttr, std::array<int, DIM>, DimSize_t> {
public StaticAttributes<ReduceMeanAttr, std::array<std::int32_t, DIM>, DimSize_t> {
public:
static const std::string Type;
ReduceMean_Op() = delete;
using Attributes_ = StaticAttributes<ReduceMeanAttr, std::array<int, DIM>, DimSize_t>;
using Attributes_ = StaticAttributes<ReduceMeanAttr, std::array<std::int32_t, DIM>, DimSize_t>;
template <ReduceMeanAttr e>
using attr = typename Attributes_::template attr<e>;
constexpr ReduceMean_Op(const std::array<int, DIM> &axes, DimSize_t keep_dims)
constexpr ReduceMean_Op(const std::array<std::int32_t, DIM> &axes, DimSize_t keep_dims)
: OperatorTensor(Type, 1, 0, 1),
Attributes_(attr<ReduceMeanAttr::Axes>(axes),
attr<ReduceMeanAttr::KeepDims>(keep_dims)) {}
......@@ -67,29 +69,28 @@ class ReduceMean_Op : public OperatorTensor,
}
void computeOutputDims() override final {
if (!getInput(0)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Every input should be associated with a Tensor");
}
if (!getInput(0)->empty()) {
std::vector<DimSize_t> outDims;
for(std::size_t d=0; d<getInput(0)->dims().size(); ++d)
{
bool reducedDim = false;
for(std::size_t i=0; i<DIM; ++i)
{
int axis_ = this->template getAttr<ReduceMeanAttr::Axes>()[i];
std::size_t axis= axis_>=0? axis_: axis_ + getInput(0)->nbDims();
if(axis == d)
{
reducedDim = true;
break;
}
}
if(reducedDim)
{
if(this->template getAttr<ReduceMeanAttr::KeepDims>())
outDims.push_back(1);
}
else
outDims.push_back(getInput(0)->dims()[d]);
// make Axes attribute positive
std::array<std::int32_t, DIM>& axes = this->template getAttr<ReduceMeanAttr::Axes>();
std::for_each(axes.begin(), axes.end(), [&] (std::int32_t& val) {
if (val < 0)
val+=static_cast<std::int32_t>(getInput(0)->nbDims());
});
std::sort(axes.begin(), axes.end());
// build output dimensions
std::vector<DimSize_t> outDims = getInput(0)->dims();
if (this->template getAttr<ReduceMeanAttr::KeepDims>()) {
std::for_each(axes.begin(), axes.end(), [&outDims] (const std::int32_t& val) { outDims[val] = 1; });
}
else {
for (auto it = axes.crbegin(); it != axes.crend(); ++it)
outDims.erase(outDims.begin() + static_cast<std::size_t>(*it));
}
if(outDims.size()>0)
mOutputs[0]->resize(outDims);
else
......@@ -111,7 +112,7 @@ class ReduceMean_Op : public OperatorTensor,
};
template <std::array<DimSize_t, 1>::size_type DIM>
inline std::shared_ptr<Node> ReduceMean(const std::array<int, DIM> &axes,
inline std::shared_ptr<Node> ReduceMean(const std::array<std::int32_t, DIM> &axes,
DimSize_t keep_dims=1,
const std::string& name = "") {
// FIXME: properly handle default w&b initialization in every cases
......@@ -123,7 +124,7 @@ inline std::shared_ptr<Node> ReduceMean(const std::array<int, DIM> &axes,
// helper with C-style array instead of std::array for kernel_dims to allow automatic template DIM deduction
template <DimSize_t DIM>
inline std::shared_ptr<Node> ReduceMean(
int const (&axes)[DIM],
std::int32_t const (&axes)[DIM],
DimSize_t keep_dims = 1,
const std::string& name = "") {
static_assert(DIM<=MaxDim,"Too many kernel dimensions required by ReduceMean, not supported");
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment