Skip to content
Snippets Groups Projects
Commit a6ca5440 authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

change Gather's axis type to int8

parent 9c1ff3f8
No related branches found
No related tags found
2 merge requests!152Update Aidge export to take a graph view has an argument instead of a...,!93Change Gather and Slice's attributes into intputs
......@@ -33,16 +33,16 @@ class Gather_Op : public OperatorTensor,
public Registrable<Gather_Op,
std::string,
std::shared_ptr<OperatorImpl>(const Gather_Op&)>,
public StaticAttributes<GatherAttr, std::int64_t> {
public StaticAttributes<GatherAttr, std::int8_t> {
public:
static const std::string Type;
Gather_Op() = delete;
using Attributes_ = StaticAttributes<GatherAttr, std::int64_t>;
using Attributes_ = StaticAttributes<GatherAttr, std::int8_t>;
template <GatherAttr e> using attr = typename Attributes_::template attr<e>;
Gather_Op(std::int64_t axis)
Gather_Op(std::int8_t axis)
: OperatorTensor(Type, 2, 0, 1),
Attributes_(attr<GatherAttr::Axis>(axis))
{}
......
......@@ -33,9 +33,9 @@ void Aidge::Gather_Op::computeOutputDims() {
std::vector<DimSize_t> outDims = getInput(0)->dims();
std::vector<DimSize_t> indicesDims = getInput(1)->dims();
std::int64_t axisIdx = this->template getAttr<GatherAttr::Axis>()>=0?
this->template getAttr<GatherAttr::Axis>():
this->template getAttr<GatherAttr::Axis>()+outDims.size();
std::int8_t axisIdx = this->template getAttr<GatherAttr::Axis>()>=0?
this->template getAttr<GatherAttr::Axis>():
this->template getAttr<GatherAttr::Axis>()+outDims.size();
outDims.erase(outDims.begin() + static_cast<std::size_t>(axisIdx));
if( indicesDims[0]>0 ) // In case indices is a scalar indicesDims is a 0
{
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment