Skip to content
Snippets Groups Projects
Commit fc3993ae authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

fix negative axis value for ReduceMean

parent fbd36894
No related branches found
No related tags found
2 merge requests!59Improvements and fixes,!47Vit operators
......@@ -31,18 +31,18 @@ enum class ReduceMeanAttr { Axes, KeepDims };
template <DimIdx_t DIM>
class ReduceMean_Op : public OperatorTensor,
public Registrable<ReduceMean_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const ReduceMean_Op<DIM> &)>,
public StaticAttributes<ReduceMeanAttr, std::array<DimSize_t, DIM>, DimSize_t> {
public StaticAttributes<ReduceMeanAttr, std::array<int, DIM>, DimSize_t> {
public:
static constexpr const char *Type = "ReduceMean";
ReduceMean_Op() = delete;
using Attributes_ = StaticAttributes<ReduceMeanAttr, std::array<DimSize_t, DIM>, DimSize_t>;
using Attributes_ = StaticAttributes<ReduceMeanAttr, std::array<int, DIM>, DimSize_t>;
template <ReduceMeanAttr e>
using attr = typename Attributes_::template attr<e>;
constexpr ReduceMean_Op(const std::array<DimSize_t, DIM> &axes, DimSize_t keep_dims)
constexpr ReduceMean_Op(const std::array<int, DIM> &axes, DimSize_t keep_dims)
: OperatorTensor(Type, 1, 0, 1),
Attributes_(attr<ReduceMeanAttr::Axes>(axes),
attr<ReduceMeanAttr::KeepDims>(keep_dims)) {}
......@@ -74,7 +74,9 @@ class ReduceMean_Op : public OperatorTensor,
bool reducedDim = false;
for(std::size_t i=0; i<DIM; ++i)
{
if(this->template getAttr<ReduceMeanAttr::Axes>()[i] == d)
int axis_ = this->template getAttr<ReduceMeanAttr::Axes>()[i];
std::size_t axis= axis_>=0? axis_: axis_ + getInput(0)->nbDims();
if(axis == d)
{
reducedDim = true;
break;
......@@ -87,8 +89,11 @@ class ReduceMean_Op : public OperatorTensor,
}
else
outDims.push_back(getInput(0)->dims()[d]);
}
mOutputs[0]->resize(outDims);
}
if(outDims.size()>0)
mOutputs[0]->resize(outDims);
else
mOutputs[0]->resize({1});
}
}
......@@ -109,8 +114,8 @@ class ReduceMean_Op : public OperatorTensor,
};
template <std::array<DimSize_t, 1>::size_type DIM>
inline std::shared_ptr<Node> ReduceMean(const std::array<DimSize_t, DIM> &axes,
DimSize_t keep_dims,
inline std::shared_ptr<Node> ReduceMean(const std::array<int, DIM> &axes,
DimSize_t keep_dims=1,
const std::string& name = "") {
// FIXME: properly handle default w&b initialization in every cases
static_assert(DIM<=MaxDim,"Too many kernel dimensions required by ReduceMean, not supported");
......@@ -121,7 +126,7 @@ inline std::shared_ptr<Node> ReduceMean(const std::array<DimSize_t, DIM> &axes,
// helper with C-style array instead of std::array for kernel_dims to allow automatic template DIM deduction
template <DimSize_t DIM>
inline std::shared_ptr<Node> ReduceMean(
DimSize_t const (&axes)[DIM],
int const (&axes)[DIM],
DimSize_t keep_dims = 1,
const std::string& name = "") {
static_assert(DIM<=MaxDim,"Too many kernel dimensions required by ReduceMean, not supported");
......
......@@ -32,7 +32,7 @@ class Reshape_Op : public OperatorTensor,
public:
static constexpr const char* Type = "Reshape";
Reshape_Op() : OperatorTensor(Type, 2, 0, 1) {} //1,1,1
Reshape_Op() : OperatorTensor(Type, 2, 0, 1) {}
/**
* @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated).
......
......@@ -30,7 +30,7 @@ template <DimIdx_t DIM> void declare_ReduceMeanOp(py::module &m) {
.def("get_outputs_name", &ReduceMean_Op<DIM>::getOutputsName)
;
m.def(("ReduceMean" + std::to_string(DIM) + "D").c_str(), [](const std::vector<DimSize_t>& axes,
m.def(("ReduceMean" + std::to_string(DIM) + "D").c_str(), [](const std::vector<int>& axes,
DimSize_t keepDims,
const std::string& name) {
AIDGE_ASSERT(axes.size() == DIM, "axes size [%ld] does not match DIM [%d]", axes.size(), DIM);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment