From ad0225dfafec58c448062b82f2da0c08ff33dd83 Mon Sep 17 00:00:00 2001 From: hrouis <houssemeddine.rouis92@gmail.com> Date: Thu, 25 Jan 2024 15:50:44 +0100 Subject: [PATCH] changed attrs to int32 --- include/aidge/operator/Gather.hpp | 8 ++++---- include/aidge/operator/Reshape.hpp | 8 ++++---- src/operator/Gather.cpp | 2 +- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/include/aidge/operator/Gather.hpp b/include/aidge/operator/Gather.hpp index f6647f991..2fd2efa5e 100644 --- a/include/aidge/operator/Gather.hpp +++ b/include/aidge/operator/Gather.hpp @@ -33,16 +33,16 @@ class Gather_Op : public OperatorTensor, public Registrable<Gather_Op, std::string, std::unique_ptr<OperatorImpl>(const Gather_Op&)>, - public StaticAttributes<GatherAttr, std::vector<std::int64_t>, std::vector<DimSize_t>, std::int64_t> { + public StaticAttributes<GatherAttr, std::vector<std::int32_t>, std::vector<DimSize_t>, std::int32_t> { public: static const std::string Type; Gather_Op() = delete; - using Attributes_ = StaticAttributes<GatherAttr, std::vector<std::int64_t>, std::vector<DimSize_t>, std::int64_t>; + using Attributes_ = StaticAttributes<GatherAttr, std::vector<std::int32_t>, std::vector<DimSize_t>, std::int32_t>; template <GatherAttr e> using attr = typename Attributes_::template attr<e>; - Gather_Op(const std::vector<std::int64_t>& indices, const std::vector<DimSize_t>& gatheredShape, std::int64_t axis) + Gather_Op(const std::vector<std::int32_t>& indices, const std::vector<DimSize_t>& gatheredShape, std::int32_t axis) : OperatorTensor(Type, 1, 0, 1), Attributes_( attr<GatherAttr::Indices>(indices), @@ -84,7 +84,7 @@ public: } }; -inline std::shared_ptr<Node> Gather( const std::vector<std::int64_t>& indices, const std::vector<DimSize_t>& gatheredShape, std::int64_t axis = 0, const std::string& name = "") { +inline std::shared_ptr<Node> Gather( const std::vector<std::int32_t>& indices, const std::vector<DimSize_t>& gatheredShape, std::int32_t axis = 0, const std::string& name = "") { return std::make_shared<Node>(std::make_shared<Gather_Op>(indices, gatheredShape, axis), name); } } // namespace Aidge diff --git a/include/aidge/operator/Reshape.hpp b/include/aidge/operator/Reshape.hpp index 32d71d5ad..f98a109ce 100644 --- a/include/aidge/operator/Reshape.hpp +++ b/include/aidge/operator/Reshape.hpp @@ -29,18 +29,18 @@ enum class ReshapeAttr { Shape }; class Reshape_Op : public OperatorTensor, public Registrable<Reshape_Op, std::string, std::unique_ptr<OperatorImpl>(const Reshape_Op&)>, - public StaticAttributes<ReshapeAttr, std::vector<std::int64_t>> { + public StaticAttributes<ReshapeAttr, std::vector<std::int32_t>> { public: static const std::string Type; Reshape_Op() = delete; - using Attributes_ = StaticAttributes<ReshapeAttr, std::vector<std::int64_t>>; + using Attributes_ = StaticAttributes<ReshapeAttr, std::vector<std::int32_t>>; template <ReshapeAttr e> using attr = typename Attributes_::template attr<e>; - Reshape_Op(const std::vector<std::int64_t>& shape) + Reshape_Op(const std::vector<std::int32_t>& shape) : OperatorTensor(Type, 1, 0, 1), Attributes_(attr<ReshapeAttr::Shape>(shape)) {} @@ -79,7 +79,7 @@ public: } }; -inline std::shared_ptr<Node> Reshape(const std::vector<std::int64_t>& shape, +inline std::shared_ptr<Node> Reshape(const std::vector<std::int32_t>& shape, const std::string &name = "") { // FIXME: properly handle default w&b initialization in every cases return std::make_shared<Node>(std::make_shared<Reshape_Op>(shape), name); diff --git a/src/operator/Gather.cpp b/src/operator/Gather.cpp index 3eafe99ef..fd0fc83fe 100644 --- a/src/operator/Gather.cpp +++ b/src/operator/Gather.cpp @@ -30,7 +30,7 @@ void Aidge::Gather_Op::computeOutputDims() { const std::vector<DimSize_t> gatheredShape = this->template getAttr<GatherAttr::GatheredShape>(); // TODO: check indices and gatheredShape - const std::int64_t axisIdx = this->template getAttr<GatherAttr::Axis>() >= 0 ? + const std::int32_t axisIdx = this->template getAttr<GatherAttr::Axis>() >= 0 ? this->template getAttr<GatherAttr::Axis>() : this->template getAttr<GatherAttr::Axis>() + outDims.size(); outDims.erase(outDims.begin() + static_cast<std::size_t>(axisIdx)); -- GitLab