diff --git a/include/aidge/operator/Gather.hpp b/include/aidge/operator/Gather.hpp index 78557d49d5e35a6824b2e16d6f8dc2d5b520587c..be5fd648bbc4f40102aa7803c09564238b681efc 100644 --- a/include/aidge/operator/Gather.hpp +++ b/include/aidge/operator/Gather.hpp @@ -33,16 +33,16 @@ class Gather_Op : public OperatorTensor, public Registrable<Gather_Op, std::string, std::shared_ptr<OperatorImpl>(const Gather_Op&)>, - public StaticAttributes<GatherAttr, std::int64_t> { + public StaticAttributes<GatherAttr, std::int8_t> { public: static const std::string Type; Gather_Op() = delete; - using Attributes_ = StaticAttributes<GatherAttr, std::int64_t>; + using Attributes_ = StaticAttributes<GatherAttr, std::int8_t>; template <GatherAttr e> using attr = typename Attributes_::template attr<e>; - Gather_Op(std::int64_t axis) + Gather_Op(std::int8_t axis) : OperatorTensor(Type, 2, 0, 1), Attributes_(attr<GatherAttr::Axis>(axis)) {} diff --git a/src/operator/Gather.cpp b/src/operator/Gather.cpp index 1286ab2821fbb583ecaa18e1a7320c56e000849c..f40feb2c8021d3aff1b77316464df6804144d46b 100644 --- a/src/operator/Gather.cpp +++ b/src/operator/Gather.cpp @@ -33,9 +33,9 @@ void Aidge::Gather_Op::computeOutputDims() { std::vector<DimSize_t> outDims = getInput(0)->dims(); std::vector<DimSize_t> indicesDims = getInput(1)->dims(); - std::int64_t axisIdx = this->template getAttr<GatherAttr::Axis>()>=0? - this->template getAttr<GatherAttr::Axis>(): - this->template getAttr<GatherAttr::Axis>()+outDims.size(); + std::int8_t axisIdx = this->template getAttr<GatherAttr::Axis>()>=0? + this->template getAttr<GatherAttr::Axis>(): + this->template getAttr<GatherAttr::Axis>()+outDims.size(); outDims.erase(outDims.begin() + static_cast<std::size_t>(axisIdx)); if( indicesDims[0]>0 ) // In case indices is a scalar indicesDims is a 0 {