diff --git a/include/aidge/operator/Gather.hpp b/include/aidge/operator/Gather.hpp
index 78557d49d5e35a6824b2e16d6f8dc2d5b520587c..be5fd648bbc4f40102aa7803c09564238b681efc 100644
--- a/include/aidge/operator/Gather.hpp
+++ b/include/aidge/operator/Gather.hpp
@@ -33,16 +33,16 @@ class Gather_Op : public OperatorTensor,
                 public Registrable<Gather_Op,
                                    std::string,
                                    std::shared_ptr<OperatorImpl>(const Gather_Op&)>,
-                public StaticAttributes<GatherAttr, std::int64_t> {
+                public StaticAttributes<GatherAttr, std::int8_t> {
 
 public:
     static const std::string Type;
 
     Gather_Op() = delete;
 
-    using Attributes_ = StaticAttributes<GatherAttr, std::int64_t>;
+    using Attributes_ = StaticAttributes<GatherAttr, std::int8_t>;
     template <GatherAttr e> using attr = typename Attributes_::template attr<e>;
-    Gather_Op(std::int64_t axis)
+    Gather_Op(std::int8_t axis)
             : OperatorTensor(Type, 2, 0, 1),
             Attributes_(attr<GatherAttr::Axis>(axis))
     {}
diff --git a/src/operator/Gather.cpp b/src/operator/Gather.cpp
index 1286ab2821fbb583ecaa18e1a7320c56e000849c..f40feb2c8021d3aff1b77316464df6804144d46b 100644
--- a/src/operator/Gather.cpp
+++ b/src/operator/Gather.cpp
@@ -33,9 +33,9 @@ void Aidge::Gather_Op::computeOutputDims() {
         std::vector<DimSize_t> outDims = getInput(0)->dims();
         std::vector<DimSize_t> indicesDims = getInput(1)->dims();
 
-        std::int64_t axisIdx = this->template getAttr<GatherAttr::Axis>()>=0?
-                               this->template getAttr<GatherAttr::Axis>():
-                               this->template getAttr<GatherAttr::Axis>()+outDims.size();
+        std::int8_t axisIdx = this->template getAttr<GatherAttr::Axis>()>=0?
+                              this->template getAttr<GatherAttr::Axis>():
+                              this->template getAttr<GatherAttr::Axis>()+outDims.size();
         outDims.erase(outDims.begin() + static_cast<std::size_t>(axisIdx));
         if( indicesDims[0]>0 ) // In case indices is a scalar indicesDims is a 0 
         {