diff --git a/include/aidge/backend/cuda.hpp b/include/aidge/backend/cuda.hpp
index d34ad64b1b7751e850be6c900f086810d7c002ae..74b56cc58d075b5e167357a60c9cae95cf4fe70d 100644
--- a/include/aidge/backend/cuda.hpp
+++ b/include/aidge/backend/cuda.hpp
@@ -13,6 +13,7 @@
 #define AIDGE_BACKEND_CUDA_IMPORTS_H_
 
 #include "aidge/backend/cuda/data/TensorImpl.hpp"
+#include "aidge/backend/cuda/operator/AvgPoolingImpl.hpp"
 #include "aidge/backend/cuda/operator/ConvImpl.hpp"
 #include "aidge/backend/cuda/operator/ProducerImpl.hpp"
 #include "aidge/backend/cuda/operator/ReLUImpl.hpp"
diff --git a/include/aidge/backend/cuda/operator/AvgPoolingImpl.hpp b/include/aidge/backend/cuda/operator/AvgPoolingImpl.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..f32dcb1249a591d79671d8d014c56a649ebdebb6
--- /dev/null
+++ b/include/aidge/backend/cuda/operator/AvgPoolingImpl.hpp
@@ -0,0 +1,65 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef AIDGE_BACKEND_CUDA_OPERATOR_AVGPOOLINGIMPL_H_
+#define AIDGE_BACKEND_CUDA_OPERATOR_AVGPOOLINGIMPL_H_
+
+#include <array>
+#include <memory>
+#include <tuple>
+#include <vector>
+
+#include <cudnn.h>
+
+#include "aidge/backend/OperatorImpl.hpp"
+#include "aidge/operator/AvgPooling.hpp"
+#include "aidge/utils/Registrar.hpp"
+#include "aidge/utils/Types.h"
+
+#include "aidge/backend/cuda/utils/CudaUtils.hpp"
+
+namespace Aidge {
+// CUDA (cuDNN) backend implementation of the AvgPooling operator.
+// The cuDNN pooling descriptor is created lazily on the first forward()
+// call and released in the destructor.
+template <DimIdx_t DIM>
+class AvgPoolingImpl_cuda : public OperatorImpl {
+private:
+    // CuDNN specific variables
+    cudnnPoolingDescriptor_t mAvgPoolingDesc = nullptr;
+    cudnnPoolingMode_t mMode = CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
+
+public:
+    AvgPoolingImpl_cuda(const AvgPooling_Op<DIM> &op) : OperatorImpl(op) {}
+
+    // Factory used by the registrar below. Takes AvgPooling_Op<DIM> (not a
+    // hard-coded <2>) so the class template stays usable for any spatial
+    // dimensionality.
+    static std::unique_ptr<AvgPoolingImpl_cuda> create(const AvgPooling_Op<DIM> &op) {
+        return std::make_unique<AvgPoolingImpl_cuda>(op);
+    }
+
+    // Runs the cuDNN average-pooling forward pass on input #0.
+    void forward();
+    ~AvgPoolingImpl_cuda();
+
+private:
+    // Type-dispatched implementation (T is float or double).
+    template <class T> void forward_(const Tensor& input);
+};
+
+namespace {
+// add cuda backend to AvgPooling_Op<2> implementation registry
+static Registrar<AvgPooling_Op<2>> registrarAvgPoolingImpl_cuda("cuda", Aidge::AvgPoolingImpl_cuda<2>::create);
+}  // namespace
+}  // namespace Aidge
+
+#endif /* AIDGE_BACKEND_CUDA_OPERATOR_AVGPOOLINGIMPL_H_ */
diff --git a/src/operator/AvgPoolingImpl.cpp b/src/operator/AvgPoolingImpl.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..0257f1907d5f3712638a38cec935c3c8a08edd1e
--- /dev/null
+++ b/src/operator/AvgPoolingImpl.cpp
@@ -0,0 +1,92 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <cassert>
+#include <vector>
+
+#include "aidge/utils/Types.h"
+#include "aidge/operator/AvgPooling.hpp"
+
+#include "aidge/backend/cuda/data/TensorImpl.hpp"
+#include "aidge/backend/cuda/operator/AvgPoolingImpl.hpp"
+#include "aidge/backend/cuda/utils/CudaContext.hpp"
+
+// Runs the cuDNN average-pooling forward pass on input #0.
+// The input is cast (if needed) to the output's data type before the call,
+// and the cuDNN pooling descriptor is built once on the first invocation.
+template <Aidge::DimIdx_t DIM>
+void Aidge::AvgPoolingImpl_cuda<DIM>::forward() {
+    assert(mOp.getRawInput(0) && "missing input #0");
+
+    // Cast the input to the output's data type if they differ; inputFallback
+    // owns the converted copy when a cast is required.
+    std::shared_ptr<Tensor> inputFallback;
+    const auto& input = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->refCastFrom(inputFallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
+
+    // Lazy-initialize CuDNN AvgPooling descriptor
+    if (mAvgPoolingDesc == nullptr) {
+        const AvgPooling_Op<DIM>& avgPoolingOp = static_cast<const AvgPooling_Op<DIM>&>(mOp);
+        const std::vector<int> strides(avgPoolingOp.template getAttr<AvgPoolingAttr::StrideDims>().begin(), avgPoolingOp.template getAttr<AvgPoolingAttr::StrideDims>().end());
+        // AvgPooling_Op carries no padding attribute: zero padding on every dim.
+        const std::vector<int> paddings(DIM, 0);
+        const std::vector<int> window_dims(avgPoolingOp.template getAttr<AvgPoolingAttr::KernelDims>().begin(), avgPoolingOp.template getAttr<AvgPoolingAttr::KernelDims>().end());
+
+        CHECK_CUDNN_STATUS(cudnnCreatePoolingDescriptor(&mAvgPoolingDesc));
+        CHECK_CUDNN_STATUS(
+            cudnnSetPoolingNdDescriptor(mAvgPoolingDesc,
+                                        mMode,
+                                        CUDNN_NOT_PROPAGATE_NAN,
+                                        DIM,
+                                        &window_dims[0],
+                                        &paddings[0],
+                                        &strides[0]));
+    }
+
+    // Dispatch on the output data type (the input was cast to it above).
+    if (std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType() == DataType::Float64) {
+        forward_<double>(input);
+    }
+    else {
+        forward_<float>(input);
+    }
+}
+
+// Issues the actual cudnnPoolingForward call with T = float or double.
+template <Aidge::DimIdx_t DIM>
+template <class T>
+void Aidge::AvgPoolingImpl_cuda<DIM>::forward_(const Tensor& input) {
+    // Blending factors: output = alpha * pool(input) + beta * priorOutput
+    const T alpha = 1.0f;
+    const T beta = 0.0f;
+    CHECK_CUDNN_STATUS(
+        cudnnPoolingForward(
+            CudaContext::cudnnHandle(),
+            mAvgPoolingDesc,
+            &alpha,
+            dynamic_cast<TensorImpl_cuda_*>(input.getImpl().get())->getCudnnTensorDesc(),
+            input.getImpl()->rawPtr(),
+            &beta,
+            dynamic_cast<TensorImpl_cuda_*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl().get())->getCudnnTensorDesc(),
+            std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr()
+        )
+    );
+}
+
+// Releases the cuDNN pooling descriptor (if it was ever created).
+template <Aidge::DimIdx_t DIM>
+Aidge::AvgPoolingImpl_cuda<DIM>::~AvgPoolingImpl_cuda() {
+    if(mAvgPoolingDesc != nullptr)
+        cudnnDestroyPoolingDescriptor(mAvgPoolingDesc);
+}
+
+
+// Explicit instantiation for the registered 2-D case
+template class Aidge::AvgPoolingImpl_cuda<2>;
diff --git a/unit_tests/Test_AvgPoolingImpl.cpp b/unit_tests/Test_AvgPoolingImpl.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..def2e93f4105eb107c30f7cce3a2a2038da12d58
--- /dev/null
+++ b/unit_tests/Test_AvgPoolingImpl.cpp
@@ -0,0 +1,127 @@
+/********************************************************************************
+ * Copyright (c) 2024 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <array>
+#include <cmath>   // std::fabs
+#include <cstddef> // std::size_t
+
+#include <catch2/catch_test_macros.hpp>
+
+#include "Test_cuda.hpp"
+
+#include "aidge/data/Tensor.hpp"
+
+#include "aidge/backend/cpu.hpp"
+#include "aidge/backend/cuda.hpp"
+
+using namespace Aidge;
+
+TEST_CASE("[gpu/operator] AvgPooling(forward)", "[AvgPooling][GPU]") {
+    std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(Array4D<float,2,2,5,5> { //NCHW
+        {
+            {
+                {{  0,   1,   2,   3,   4},
+                 {  5,   6,   7,   8,   9},
+                 { 10,  11,  12,  13,  14},
+                 { 15,  16,  17,  18,  19},
+                 { 20,  21,  22,  23,  24}},
+
+                {{ 25,  26,  27,  28,  29},
+                 { 30,  31,  32,  33,  34},
+                 { 35,  36,  37,  38,  39},
+                 { 40,  41,  42,  43,  44},
+                 { 45,  46,  47,  48,  49}}
+            },
+            {
+                {{100, 101, 102, 103, 104},
+                 {105, 106, 107, 108, 109},
+                 {110, 111, 112, 113, 114},
+                 {115, 116, 117, 118, 119},
+                 {120, 121, 122, 123, 124}},
+
+                {{125, 126, 127, 128, 129},
+                 {130, 131, 132, 133, 134},
+                 {135, 136, 137, 138, 139},
+                 {140, 141, 142, 143, 144},
+                 {145, 146, 147, 148, 149}}
+            }
+        }
+    });
+    SECTION("Stride") {
+        std::shared_ptr<Node> myAvgPool = AvgPooling({2,2}, "mycdw", {2,2});
+        auto op = std::static_pointer_cast<OperatorTensor>(myAvgPool -> getOperator());
+
+        std::shared_ptr<Tensor> myOutput = std::make_shared<Tensor>(Array4D<float,2,2,2,2> {
+            {
+                {
+                    {{  3,   5},
+                     { 13,  15}},
+                    {{ 28,  30},
+                     { 38,  40}}
+                },
+                {
+                    {{103, 105},
+                     {113, 115}},
+                    {{128, 130},
+                     {138, 140}}
+                }
+            }
+        });
+        op->associateInput(0,myInput);
+        op->setDataType(DataType::Float32);
+        op->setBackend("cuda");
+        op->computeOutputDims();
+        myAvgPool->forward();
+
+        float* computedOutput   = new float[myOutput->size()]();
+        cudaMemcpy(computedOutput, op->getOutput(0)->getImpl()->rawPtr(), sizeof(float) * myOutput->size(), cudaMemcpyDeviceToHost);
+
+        // size_t index avoids a signed/unsigned comparison with size()
+        for(std::size_t i = 0; i < myOutput->size(); i++){
+            const float targetOutput = *(static_cast<float*>(myOutput->getImpl()->rawPtr()) + i);
+            REQUIRE(std::fabs(computedOutput[i] - targetOutput) < 1e-6);
+        }
+
+        delete[] computedOutput;
+    }
+
+    SECTION("Stride >= feature dim") {
+        std::shared_ptr<Tensor> myInput2 = std::make_shared<Tensor>(Array4D<float,1,1,3,3> { //NCHW
+        {
+            {
+                {{0.3745, 0.9507, 0.7320},
+                 {0.5987, 0.1560, 0.1560},
+                 {0.0581, 0.8662, 0.6011}}
+            }
+        }
+        });
+        std::shared_ptr<Node> myAvgPool = AvgPooling({3,3}, "mycdw", {3,3});
+        auto op = std::static_pointer_cast<OperatorTensor>(myAvgPool -> getOperator());
+
+        std::shared_ptr<Tensor> myOutput = std::make_shared<Tensor>(Array4D<float,1,1,1,1> {
+            {{{{(0.3745 + 0.9507 + 0.7320 + 0.5987 + 0.1560 + 0.1560 + 0.0581 + 0.8662 + 0.6011)/9.0}}}}
+        });
+        op->associateInput(0,myInput2);
+        op->setDataType(DataType::Float32);
+        op->setBackend("cuda");
+        op->computeOutputDims();
+        myAvgPool->forward();
+
+        float* computedOutput   = new float[myOutput->size()]();
+        cudaMemcpy(computedOutput, op->getOutput(0)->getImpl()->rawPtr(), sizeof(float) * myOutput->size(), cudaMemcpyDeviceToHost);
+
+        for(std::size_t i = 0; i < myOutput->size(); i++){
+            const float targetOutput = *(static_cast<float*>(myOutput->getImpl()->rawPtr()) + i);
+            REQUIRE(std::fabs(computedOutput[i] - targetOutput) < 1e-6);
+        }
+
+        delete[] computedOutput;
+    }
+}