diff --git a/include/aidge/operator/Gather.hpp b/include/aidge/operator/Gather.hpp index f6647f99151304d0cf083aed109cc642c9f1ecc2..2fd2efa5e7a99f4a6effe826e004da97f1c8dbdc 100644 --- a/include/aidge/operator/Gather.hpp +++ b/include/aidge/operator/Gather.hpp @@ -33,16 +33,16 @@ class Gather_Op : public OperatorTensor, public Registrable<Gather_Op, std::string, std::unique_ptr<OperatorImpl>(const Gather_Op&)>, - public StaticAttributes<GatherAttr, std::vector<std::int64_t>, std::vector<DimSize_t>, std::int64_t> { + public StaticAttributes<GatherAttr, std::vector<std::int32_t>, std::vector<DimSize_t>, std::int32_t> { public: static const std::string Type; Gather_Op() = delete; - using Attributes_ = StaticAttributes<GatherAttr, std::vector<std::int64_t>, std::vector<DimSize_t>, std::int64_t>; + using Attributes_ = StaticAttributes<GatherAttr, std::vector<std::int32_t>, std::vector<DimSize_t>, std::int32_t>; template <GatherAttr e> using attr = typename Attributes_::template attr<e>; - Gather_Op(const std::vector<std::int64_t>& indices, const std::vector<DimSize_t>& gatheredShape, std::int64_t axis) + Gather_Op(const std::vector<std::int32_t>& indices, const std::vector<DimSize_t>& gatheredShape, std::int32_t axis) : OperatorTensor(Type, 1, 0, 1), Attributes_( attr<GatherAttr::Indices>(indices), @@ -84,7 +84,7 @@ public: } }; -inline std::shared_ptr<Node> Gather( const std::vector<std::int64_t>& indices, const std::vector<DimSize_t>& gatheredShape, std::int64_t axis = 0, const std::string& name = "") { +inline std::shared_ptr<Node> Gather( const std::vector<std::int32_t>& indices, const std::vector<DimSize_t>& gatheredShape, std::int32_t axis = 0, const std::string& name = "") { return std::make_shared<Node>(std::make_shared<Gather_Op>(indices, gatheredShape, axis), name); } } // namespace Aidge diff --git a/include/aidge/operator/Reshape.hpp b/include/aidge/operator/Reshape.hpp index 32d71d5adc3cfd92c9840dcb5bc61bfb6399c6db..f98a109ce5764d41af67dcbbeb9eaf87e4188db5 100644 --- a/include/aidge/operator/Reshape.hpp +++ b/include/aidge/operator/Reshape.hpp @@ -29,18 +29,18 @@ enum class ReshapeAttr { Shape }; class Reshape_Op : public OperatorTensor, public Registrable<Reshape_Op, std::string, std::unique_ptr<OperatorImpl>(const Reshape_Op&)>, - public StaticAttributes<ReshapeAttr, std::vector<std::int64_t>> { + public StaticAttributes<ReshapeAttr, std::vector<std::int32_t>> { public: static const std::string Type; Reshape_Op() = delete; - using Attributes_ = StaticAttributes<ReshapeAttr, std::vector<std::int64_t>>; + using Attributes_ = StaticAttributes<ReshapeAttr, std::vector<std::int32_t>>; template <ReshapeAttr e> using attr = typename Attributes_::template attr<e>; - Reshape_Op(const std::vector<std::int64_t>& shape) + Reshape_Op(const std::vector<std::int32_t>& shape) : OperatorTensor(Type, 1, 0, 1), Attributes_(attr<ReshapeAttr::Shape>(shape)) {} @@ -79,7 +79,7 @@ public: } }; -inline std::shared_ptr<Node> Reshape(const std::vector<std::int64_t>& shape, +inline std::shared_ptr<Node> Reshape(const std::vector<std::int32_t>& shape, const std::string &name = "") { // FIXME: properly handle default w&b initialization in every cases return std::make_shared<Node>(std::make_shared<Reshape_Op>(shape), name); diff --git a/src/operator/Gather.cpp b/src/operator/Gather.cpp index 3eafe99efcf46e4ba498351eed160f7d48e37a17..fd0fc83fe5821d61f4d2c3dace8b66033a596450 100644 --- a/src/operator/Gather.cpp +++ b/src/operator/Gather.cpp @@ -30,7 +30,7 @@ void Aidge::Gather_Op::computeOutputDims() { const std::vector<DimSize_t> gatheredShape = this->template getAttr<GatherAttr::GatheredShape>(); // TODO: check indices and gatheredShape - const std::int64_t axisIdx = this->template getAttr<GatherAttr::Axis>() >= 0 ? + const std::int32_t axisIdx = this->template getAttr<GatherAttr::Axis>() >= 0 ? this->template getAttr<GatherAttr::Axis>() : this->template getAttr<GatherAttr::Axis>() + outDims.size(); outDims.erase(outDims.begin() + static_cast<std::size_t>(axisIdx));