From a6ca5440f92996626b9b234fb34335cca6d14801 Mon Sep 17 00:00:00 2001
From: hrouis <houssemeddine.rouis92@gmail.com>
Date: Tue, 19 Mar 2024 10:53:10 +0100
Subject: [PATCH] change Gather's axis type to int8

---
 include/aidge/operator/Gather.hpp | 6 +++---
 src/operator/Gather.cpp           | 6 +++---
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/include/aidge/operator/Gather.hpp b/include/aidge/operator/Gather.hpp
index 78557d49d..be5fd648b 100644
--- a/include/aidge/operator/Gather.hpp
+++ b/include/aidge/operator/Gather.hpp
@@ -33,16 +33,16 @@ class Gather_Op : public OperatorTensor,
                 public Registrable<Gather_Op,
                                    std::string,
                                    std::shared_ptr<OperatorImpl>(const Gather_Op&)>,
-                public StaticAttributes<GatherAttr, std::int64_t> {
+                public StaticAttributes<GatherAttr, std::int8_t> {
 
 public:
     static const std::string Type;
 
     Gather_Op() = delete;
 
-    using Attributes_ = StaticAttributes<GatherAttr, std::int64_t>;
+    using Attributes_ = StaticAttributes<GatherAttr, std::int8_t>;
     template <GatherAttr e> using attr = typename Attributes_::template attr<e>;
-    Gather_Op(std::int64_t axis)
+    Gather_Op(std::int8_t axis)
             : OperatorTensor(Type, 2, 0, 1),
             Attributes_(attr<GatherAttr::Axis>(axis))
     {}
diff --git a/src/operator/Gather.cpp b/src/operator/Gather.cpp
index 1286ab282..f40feb2c8 100644
--- a/src/operator/Gather.cpp
+++ b/src/operator/Gather.cpp
@@ -33,9 +33,9 @@ void Aidge::Gather_Op::computeOutputDims() {
         std::vector<DimSize_t> outDims = getInput(0)->dims();
         std::vector<DimSize_t> indicesDims = getInput(1)->dims();
 
-        std::int64_t axisIdx = this->template getAttr<GatherAttr::Axis>()>=0?
-                               this->template getAttr<GatherAttr::Axis>():
-                               this->template getAttr<GatherAttr::Axis>()+outDims.size();
+        std::int8_t axisIdx = this->template getAttr<GatherAttr::Axis>()>=0?
+                              this->template getAttr<GatherAttr::Axis>():
+                              this->template getAttr<GatherAttr::Axis>()+outDims.size();
         outDims.erase(outDims.begin() + static_cast<std::size_t>(axisIdx));
         if( indicesDims[0]>0 ) // In case indices is a scalar indicesDims is a 0 
         {
-- 
GitLab