diff --git a/include/aidge/backend/cpu.hpp b/include/aidge/backend/cpu.hpp
index 8fae3439c9e5399f360d807e2695bafd19793ec5..c1f1cc71ee7d770d6e7e16dd3311f37f7280b41a 100644
--- a/include/aidge/backend/cpu.hpp
+++ b/include/aidge/backend/cpu.hpp
@@ -23,6 +23,7 @@
 #include "aidge/backend/cpu/operator/ErfImpl.hpp"
 #include "aidge/backend/cpu/operator/FCImpl.hpp"
 #include "aidge/backend/cpu/operator/GatherImpl.hpp"
+#include "aidge/backend/cpu/operator/GlobalAveragePoolingImpl.hpp"
 #include "aidge/backend/cpu/operator/LeakyReLUImpl.hpp"
 #include "aidge/backend/cpu/operator/MatMulImpl.hpp"
 #include "aidge/backend/cpu/operator/MulImpl.hpp"
diff --git a/include/aidge/backend/cpu/operator/GlobalAveragePoolingImpl.hpp b/include/aidge/backend/cpu/operator/GlobalAveragePoolingImpl.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..aeefdecff19be7eb5da13895a4e9efa3d5c2dd94
--- /dev/null
+++ b/include/aidge/backend/cpu/operator/GlobalAveragePoolingImpl.hpp
@@ -0,0 +1,56 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef AIDGE_CPU_OPERATOR_GLOBALAVERAGEPOOLINGIMPL_H_
+#define AIDGE_CPU_OPERATOR_GLOBALAVERAGEPOOLINGIMPL_H_
+
+#include <memory>
+#include <vector>
+
+#include "aidge/backend/OperatorImpl.hpp"
+#include "aidge/operator/GlobalAveragePooling.hpp"
+#include "aidge/utils/Registrar.hpp"
+#include "aidge/utils/Types.h"
+
+namespace Aidge
+{
+    // class GlobalAveragePooling_Op;
+
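+    // Kernel registries: each entry maps an (input DataType, output DataType)
+    // pair to a kernel taking (input dims, raw input pointer, raw output pointer).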
+    class GlobalAveragePoolingImplForward_cpu
+        : public Registrable<GlobalAveragePoolingImplForward_cpu, std::tuple<DataType, DataType>, void(const std::vector<DimSize_t> &, const void *, void *)>
+    {
+    };
+    class GlobalAveragePoolingImplBackward_cpu
+        : public Registrable<GlobalAveragePoolingImplBackward_cpu, std::tuple<DataType, DataType>, void(const std::vector<DimSize_t> &, const void *, void *)>
+    {
+    };
+    // Then we declare the Impl class for the operator
+    class GlobalAveragePoolingImpl_cpu : public OperatorImpl
+    {
+    public:
+        GlobalAveragePoolingImpl_cpu(const GlobalAveragePooling_Op &op) : OperatorImpl(op) {}
+
+        static std::unique_ptr<GlobalAveragePoolingImpl_cpu> create(const GlobalAveragePooling_Op &op)
+        {
+            return std::make_unique<GlobalAveragePoolingImpl_cpu>(op);
+        }
+
+        void forward() override;
+    };
+
+    // Finally, we register the implementation so it can be selected for the "cpu" backend
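+    // Usage sketch: with this registrar in place, setting the "cpu" backend on a
+    // GlobalAveragePooling operator (op->setBackend("cpu"), as done in the unit
+    // tests) instantiates GlobalAveragePoolingImpl_cpu for it.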
+    namespace
+    {
+        static Registrar<GlobalAveragePooling_Op> registrarGlobalAveragePoolingImpl_cpu("cpu", Aidge::GlobalAveragePoolingImpl_cpu::create);
+    }
+} // namespace Aidge
+
+#endif /* AIDGE_CPU_OPERATOR_GLOBALAVERAGEPOOLINGIMPL_H_ */
diff --git a/include/aidge/backend/cpu/operator/GlobalAveragePoolingImpl_forward_kernels.hpp b/include/aidge/backend/cpu/operator/GlobalAveragePoolingImpl_forward_kernels.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..d541480d87512e529cd62173d61a16f3c8c928c5
--- /dev/null
+++ b/include/aidge/backend/cpu/operator/GlobalAveragePoolingImpl_forward_kernels.hpp
@@ -0,0 +1,85 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef AIDGE_CPU_OPERATOR_GLOBALAVERAGEPOOLINGIMPL_FORWARD_KERNEL_H_
+#define AIDGE_CPU_OPERATOR_GLOBALAVERAGEPOOLINGIMPL_FORWARD_KERNEL_H_
+
+#include <cstddef>    // std::size_t
+#include <functional> // std::multiplies
+#include <numeric>    // std::accumulate
+#include <stdexcept>  // std::runtime_error
+#include <vector>
+
+#include "aidge/backend/cpu/operator/GlobalAveragePoolingImpl.hpp"
+#include "aidge/data/Data.hpp"
+#include "aidge/utils/ErrorHandling.hpp" // AIDGE_THROW_OR_ABORT
+#include "aidge/utils/Registrar.hpp"
+#include "aidge/utils/Types.h"
+
+namespace Aidge {
+template <class I, class O>
+void GlobalAveragePoolingImpl_cpu_forward_kernel(
+    const std::vector<DimSize_t> &dims, const void *input_, void *output_) {
+  // error checking
+  if (dims.size() < 3) {
+    AIDGE_THROW_OR_ABORT(std::runtime_error,
+                         "GlobalAveragePooling needs at least a 3-dimensional "
+                         "input, got %lu dimension(s)",
+                         dims.size());
+  }
+
+  // computation
+  const I *input = static_cast<const I *>(input_);
+  O *output = static_cast<O *>(output_);
+
+  DimSize_t nb_elems = std::accumulate(dims.begin(), dims.end(), std::size_t(1),
+                                       std::multiplies<std::size_t>());
+
+  const DimSize_t in_batch_nb_elems{nb_elems / dims[0]};
+  const DimSize_t in_channel_nb_elems{in_batch_nb_elems / dims[1]};
+  const DimSize_t out_batch_nb_elems{dims[1]};
+  // parse channel by channel and fill each output with the average of the
+  // values in the channel
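+  // Illustration (assuming an NCHW-like layout {N, C, D0, D1, ...}): for
+  // dims = {2, 3, 4, 5}, nb_elems = 120, in_batch_nb_elems = 60,
+  // in_channel_nb_elems = 20, and each of the 2 * 3 = 6 outputs is the mean
+  // of the 20 values of its channel.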
+  for (DimSize_t batch = 0; batch < dims[0]; ++batch) {
+    for (DimSize_t channel = 0; channel < dims[1]; ++channel) {
+      const I *filter_start = std::next(
+          input, batch * in_batch_nb_elems + (channel * in_channel_nb_elems));
+      const I sum = std::accumulate(filter_start,
+                                    filter_start + in_channel_nb_elems, I(0));
+      output[batch * out_batch_nb_elems + channel] =
+          sum / static_cast<I>(in_channel_nb_elems);
+    }
+  }
+}
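+
+// Worked example (illustrative): an input of shape {1, 1, 2, 2} holding
+// {1, 2, 3, 4} averages to (1 + 2 + 3 + 4) / 4 = 2.5, so the output has
+// shape {1, 1} and contains {2.5}.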
+
+// Then we add the Registrar declaration for different input/output types
+namespace {
+static Registrar<GlobalAveragePoolingImplForward_cpu>
+    registrarGlobalAveragePoolingImplForward_cpu_Float32(
+        {DataType::Float32, DataType::Float32},
+        Aidge::GlobalAveragePoolingImpl_cpu_forward_kernel<float, float>);
+static Registrar<GlobalAveragePoolingImplForward_cpu>
+    registrarGlobalAveragePoolingImplForward_cpu_Int32(
+        {DataType::Int32, DataType::Int32},
+        Aidge::GlobalAveragePoolingImpl_cpu_forward_kernel<int, int>);
+static Registrar<GlobalAveragePoolingImplForward_cpu>
+    registrarGlobalAveragePoolingImplForward_cpu_Float64(
+        {DataType::Float64, DataType::Float64},
+        Aidge::GlobalAveragePoolingImpl_cpu_forward_kernel<double, double>);
+} // namespace
+} // namespace Aidge
+
+#endif /* AIDGE_CPU_OPERATOR_GLOBALAVERAGEPOOLINGIMPL_FORWARD_KERNEL_H_ */
diff --git a/src/operator/GlobalAveragePoolingImpl.cpp b/src/operator/GlobalAveragePoolingImpl.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..50048f71504e5910226ece50373b49695a5c9094
--- /dev/null
+++ b/src/operator/GlobalAveragePoolingImpl.cpp
@@ -0,0 +1,37 @@
+/********************************************************************************
+ * Copyright (c) 2024 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <cassert>
+#include <memory> // std::static_pointer_cast
+#include <vector>
+
+#include "aidge/data/Tensor.hpp"
+#include "aidge/operator/GlobalAveragePooling.hpp"
+#include "aidge/utils/Types.h"
+
+#include "aidge/backend/cpu/operator/GlobalAveragePoolingImpl.hpp"
+#include "aidge/backend/cpu/operator/GlobalAveragePoolingImpl_forward_kernels.hpp"
+
+void Aidge::GlobalAveragePoolingImpl_cpu::forward()
+{
+    // Check if input is provided
+    assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input");
+
+    // Create the forward kernel for the requested data types
+    auto kernelFunc = Registrar<GlobalAveragePoolingImplForward_cpu>::create({std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(),
+                                                                              std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()});
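+    // e.g. a {Float32, Float32} pair resolves to the <float, float> kernel
+    // registered in GlobalAveragePoolingImpl_forward_kernels.hpp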
+
+    // Call kernel
+    kernelFunc(std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dims(),
+               std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->getImpl()->rawPtr(),
+               std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr());
+}
diff --git a/unit_tests/operator/Test_GlobalAveragePoolingImpl.cpp b/unit_tests/operator/Test_GlobalAveragePoolingImpl.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..7ed65d1894240f92f12f49c7768ee76a186556c2
--- /dev/null
+++ b/unit_tests/operator/Test_GlobalAveragePoolingImpl.cpp
@@ -0,0 +1,188 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <catch2/catch_test_macros.hpp>
+
+#include <chrono>
+#include <cstddef> // std::size_t
+#include <cstdint> // std::uint16_t
+#include <iostream>
+#include <memory>
+#include <numeric> // std::accumulate
+#include <random>  // std::random_device, std::mt19937, std::uniform_real_distribution
+
+#include "aidge/data/Tensor.hpp"
+#include "aidge/operator/GlobalAveragePooling.hpp"
+#include "aidge/operator/OperatorTensor.hpp"
+#include "aidge/utils/TensorUtils.hpp"
+#include "aidge/utils/Types.h"
+
+namespace Aidge {
+TEST_CASE("[cpu/operator] GlobalAveragePooling",
+          "[GlobalAveragePooling][CPU]") {
+  constexpr std::uint16_t NBTRIALS = 10;
+  // Create a random number generator
+  std::random_device rd;
+  std::mt19937 gen(rd());
+  std::uniform_real_distribution<float> valueDist(
+      0.1f, 1.1f); // Random float distribution between 0.1 and 1.1
+  std::uniform_int_distribution<std::size_t> dimSizeDist(std::size_t(2),
+                                                         std::size_t(10));
+
+  std::uniform_int_distribution<std::size_t> nbLowDimsDist(std::size_t(1),
+                                                           std::size_t(2));
+  std::uniform_int_distribution<std::size_t> nbHighDimsDist(std::size_t(3),
+                                                            std::size_t(7));
+
+  // Create GlobalAveragePooling Operator
+  std::shared_ptr<Node> globAvgPool = GlobalAveragePooling();
+  auto op =
+      std::static_pointer_cast<OperatorTensor>(globAvgPool->getOperator());
+  op->setDataType(DataType::Float32);
+  op->setBackend("cpu");
+
+  // Create the input Tensor
+  std::shared_ptr<Tensor> T0 = std::make_shared<Tensor>();
+  op->associateInput(0, T0);
+  T0->setDataType(DataType::Float32);
+  T0->setBackend("cpu");
+
+  // Create results Tensor
+  std::shared_ptr<Tensor> Tres = std::make_shared<Tensor>();
+  Tres->setDataType(DataType::Float32);
+  Tres->setBackend("cpu");
+
+  // To measure execution time of 'GlobalAveragePooling_Op::forward()' member
+  // function call
+  std::chrono::time_point<std::chrono::system_clock> start;
+  std::chrono::time_point<std::chrono::system_clock> end;
+  std::chrono::duration<double, std::micro> duration{};
+  int number_of_operation{0};
+
+  SECTION("GlobalAveragePoolingImpl_cpu::forward()") {
+    SECTION(
+        "1-2Dim > not enough dimensions leads to function throwing an error") {
+      // generate a random tensor with too few dimensions
+      const std::size_t nbDims = nbLowDimsDist(gen);
+      std::vector<std::size_t> dims;
+      for (std::size_t i = 0; i < nbDims; ++i) {
+        dims.push_back(dimSizeDist(gen));
+      }
+      const std::size_t nb_elements =
+          std::accumulate(dims.cbegin(), dims.cend(), std::size_t(1),
+                          std::multiplies<std::size_t>());
+
+      // fill the input so that the forward call can only fail on the
+      // dimension check
+      float *array0 = new float[nb_elements];
+      for (std::size_t i = 0; i < nb_elements; ++i) {
+        array0[i] = valueDist(gen);
+      }
+      T0->resize(dims);
+      T0->getImpl()->setRawPtr(array0, nb_elements);
+
+      REQUIRE_THROWS(globAvgPool->forward());
+      delete[] array0;
+    }
+
+    SECTION("3+Dim") {
+      SECTION("Fill a tensor with all values set as N will result with every "
+              "output being N") {
+        // generate the tensor
+        const std::size_t nbDims = nbHighDimsDist(gen);
+        std::vector<std::size_t> dims_in;
+        for (std::size_t i = 0; i < nbDims; ++i) {
+          dims_in.push_back(dimSizeDist(gen));
+        }
+        // create in nb_elems
+        const std::size_t in_nb_elems =
+            std::accumulate(dims_in.cbegin(), dims_in.cend(), std::size_t(1),
+                            std::multiplies<std::size_t>());
+        const DimSize_t in_batch_nb_elems = in_nb_elems / dims_in[0];
+        const DimSize_t in_channel_nb_elems = in_batch_nb_elems / dims_in[1];
+
+        // create out nb_elems
+        std::vector<std::size_t> dims_out{dims_in[0], dims_in[1]};
+        const std::size_t out_nb_elems =
+            std::accumulate(dims_out.cbegin(), dims_out.cend(), std::size_t(1),
+                            std::multiplies<std::size_t>());
+        const DimSize_t out_batch_nb_elems = out_nb_elems / dims_out[0];
+
+        // iterate over each batch/channel
+        float *array0 = new float[in_nb_elems];
+        float *result = new float[out_nb_elems];
+        float val = valueDist(gen);
+        std::cout << "val = " << val << std::endl;
+        for (std::size_t batch = 0; batch < dims_in[0]; ++batch) {
+          for (std::size_t channel = 0; channel < dims_in[1]; ++channel) {
+            for (std::size_t i = 0; i < in_channel_nb_elems; ++i) {
+              array0[batch * in_batch_nb_elems + channel * in_channel_nb_elems +
+                     i] = val;
+            }
+            result[batch * out_batch_nb_elems + channel] = val;
+          }
+        }
+
+        // input0
+        T0->resize(dims_in);
+        T0->getImpl()->setRawPtr(array0, in_nb_elems);
+
+        // results
+        Tres->resize(dims_out);
+        Tres->getImpl()->setRawPtr(result, out_nb_elems);
+
+        op->computeOutputDims();
+        start = std::chrono::system_clock::now();
+        REQUIRE_NOTHROW(globAvgPool->forward());
+        end = std::chrono::system_clock::now();
+        duration +=
+            std::chrono::duration_cast<std::chrono::microseconds>(end - start);
+
+        // Print tensors
+        std::cout << "input : size =  [";
+        for (auto &dim : op->getInput(0)->dims()) {
+          std::cout << dim << " , ";
+        }
+        std::cout << "]" << std::endl;
+        // T0->print();
+
+        std::cout << "output : size =  [";
+        for (auto &dim : op->getOutput(0)->dims()) {
+          std::cout << dim << " , ";
+        }
+        std::cout << "]" << std::endl;
+        op->getOutput(0)->print();
+
+        std::cout << "ref Tres : size = output size if no error occurred"
+                  << std::endl;
+        std::cout << "ref Tres: size =  [";
+        for (auto &dim : Tres->dims()) {
+          std::cout << dim << " , ";
+        }
+        std::cout << "]" << std::endl;
+
+        CHECK(Tres->nbDims() == op->getOutput(0)->nbDims());
+        for (DimSize_t i = 0; i < op->getOutput(0)->nbDims(); ++i) {
+          CHECK(Tres->dims().at(i) == op->getOutput(0)->dims().at(i));
+        }
+        Tres->print();
+
+        CHECK(approxEq<float>(*(op->getOutput(0)), *Tres));
+
+        delete[] array0;
+        delete[] result;
+      }
+      SECTION("Using result from a pytorch function as groundtruth") {}
+      SECTION("random testing") {}
+    }
+  }
+}
+} // namespace Aidge