diff --git a/include/aidge/backend/cuda.hpp b/include/aidge/backend/cuda.hpp
index 74b56cc58d075b5e167357a60c9cae95cf4fe70d..963c694e01d9f2813eb95622c1c475eb40bf7bb1 100644
--- a/include/aidge/backend/cuda.hpp
+++ b/include/aidge/backend/cuda.hpp
@@ -15,6 +15,7 @@
 #include "aidge/backend/cuda/data/TensorImpl.hpp"
 #include "aidge/backend/cuda/operator/AvgPoolingImpl.hpp"
 #include "aidge/backend/cuda/operator/ConvImpl.hpp"
+#include "aidge/backend/cuda/operator/MaxPoolingImpl.hpp"
 #include "aidge/backend/cuda/operator/ProducerImpl.hpp"
 #include "aidge/backend/cuda/operator/ReLUImpl.hpp"
 
diff --git a/include/aidge/backend/cuda/operator/MaxPoolingImpl.hpp b/include/aidge/backend/cuda/operator/MaxPoolingImpl.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..9da97736e31ddbc01e47e0bde903e5b8348a6f7f
--- /dev/null
+++ b/include/aidge/backend/cuda/operator/MaxPoolingImpl.hpp
@@ -0,0 +1,58 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef AIDGE_BACKEND_CUDA_OPERATOR_MAXPOOLINGIMPL_H_
+#define AIDGE_BACKEND_CUDA_OPERATOR_MAXPOOLINGIMPL_H_
+
+#include <array>
+#include <memory>
+#include <tuple>
+#include <vector>
+
+#include <cudnn.h>
+
+#include "aidge/backend/OperatorImpl.hpp"
+#include "aidge/operator/MaxPooling.hpp"
+#include "aidge/utils/Registrar.hpp"
+#include "aidge/utils/Types.h"
+
+#include "aidge/backend/cuda/utils/CudaUtils.hpp"
+
+namespace Aidge {
+template <DimIdx_t DIM>
+class MaxPoolingImpl_cuda : public OperatorImpl {
+private:
+    // CuDNN specific variables
+    cudnnPoolingDescriptor_t mMaxPoolingDesc = nullptr;
+    cudnnPoolingMode_t mMode = CUDNN_POOLING_MAX;
+
+public:
+    MaxPoolingImpl_cuda(const MaxPooling_Op<DIM> &op) : OperatorImpl(op) {}
+
+    static std::unique_ptr<MaxPoolingImpl_cuda> create(const MaxPooling_Op<2> &op) {
+        return std::make_unique<MaxPoolingImpl_cuda>(op);
+    }
+
+public:
+    void forward();
+    ~MaxPoolingImpl_cuda();
+
+private:
+    template <class T> void forward_(const Tensor& input);
+};
+
+namespace {
+// add cuda backend to MaxPooling_Op<2> implementation registry
+static Registrar<MaxPooling_Op<2>> registrarMaxPoolingImpl_cuda("cuda", Aidge::MaxPoolingImpl_cuda<2>::create);
+}  // namespace
+}  // namespace Aidge
+
+#endif /* AIDGE_BACKEND_CUDA_OPERATOR_MAXPOOLINGIMPL_H_ */
diff --git a/src/operator/MaxPoolingImpl.cpp b/src/operator/MaxPoolingImpl.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..8ba19a61c25c5a7afcb43016e118bef63785018b
--- /dev/null
+++ b/src/operator/MaxPoolingImpl.cpp
@@ -0,0 +1,85 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <cassert>
+#include <chrono>  // std::chrono::milliseconds
+#include <numeric> // std::accumulate
+#include <thread>  // std::this_thread::sleep_for
+#include <vector>
+
+#include "aidge/utils/Types.h"
+#include "aidge/operator/MaxPooling.hpp"
+
+#include "aidge/backend/cuda/data/TensorImpl.hpp"
+#include "aidge/backend/cuda/operator/MaxPoolingImpl.hpp"
+#include "aidge/backend/cuda/utils/CudaContext.hpp"
+
+template <Aidge::DimIdx_t DIM>
+void Aidge::MaxPoolingImpl_cuda<DIM>::forward() {
+    assert(mOp.getRawInput(0) && "missing input #0");
+
+    std::shared_ptr<Tensor> inputFallback;
+    const auto& input = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->refCastFrom(inputFallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
+
+    // Lazy-initialize CuDNN MaxPooling descriptor
+    if (mMaxPoolingDesc == nullptr) {
+        const MaxPooling_Op<DIM>& maxPoolingOp = static_cast<const MaxPooling_Op<DIM>&>(mOp);
+        const std::vector<int> strides(maxPoolingOp.template getAttr<MaxPoolingAttr::StrideDims>().begin(), maxPoolingOp.template getAttr<MaxPoolingAttr::StrideDims>().end());
+        const std::vector<int> paddings(DIM, 0);
+        const std::vector<int> window_dims(maxPoolingOp.template getAttr<MaxPoolingAttr::KernelDims>().begin(), maxPoolingOp.template getAttr<MaxPoolingAttr::KernelDims>().end());
+
+        CHECK_CUDNN_STATUS(cudnnCreatePoolingDescriptor(&mMaxPoolingDesc));
+        CHECK_CUDNN_STATUS(
+            cudnnSetPoolingNdDescriptor(mMaxPoolingDesc,
+                                        mMode,
+                                        CUDNN_NOT_PROPAGATE_NAN,
+                                        DIM,
+                                        &window_dims[0],
+                                        &paddings[0],
+                                        &strides[0]));
+    }
+
+    if (std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType() == DataType::Float64) {
+        forward_<double>(input);
+    }
+    else {
+        forward_<float>(input);
+    }
+}
+
+template <Aidge::DimIdx_t DIM>
+template <class T>
+void Aidge::MaxPoolingImpl_cuda<DIM>::forward_(const Tensor& input) {
+    const T alpha = 1.0f;
+    const T beta = 0.0f;
+    CHECK_CUDNN_STATUS(
+        cudnnPoolingForward(
+            CudaContext::cudnnHandle(),
+            mMaxPoolingDesc,
+            &alpha,
+            dynamic_cast<TensorImpl_cuda_*>(input.getImpl().get())->getCudnnTensorDesc(),
+            input.getImpl()->rawPtr(),
+            &beta,
+            dynamic_cast<TensorImpl_cuda_*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl().get())->getCudnnTensorDesc(),
+            std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr()
+        )
+    );
+}
+
+template <Aidge::DimIdx_t DIM>
+Aidge::MaxPoolingImpl_cuda<DIM>::~MaxPoolingImpl_cuda() {
+    if(mMaxPoolingDesc != nullptr)
+        cudnnDestroyPoolingDescriptor(mMaxPoolingDesc);
+}
+
+
+// Template declarations
+template class Aidge::MaxPoolingImpl_cuda<2>;
diff --git a/unit_tests/Test_MaxPoolingImpl.cpp b/unit_tests/Test_MaxPoolingImpl.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..b2ec0dfe5dc6df072b6be3b20c075190cd3f6fce
--- /dev/null
+++ b/unit_tests/Test_MaxPoolingImpl.cpp
@@ -0,0 +1,93 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <array>
+
+#include <catch2/catch_test_macros.hpp>
+
+#include "Test_cuda.hpp"
+
+#include "aidge/data/Tensor.hpp"
+
+#include "aidge/backend/cpu.hpp"
+#include "aidge/backend/cuda.hpp"
+
+using namespace Aidge;
+
+
+TEST_CASE("[cpu/operator] MaxPooling(forward)", "[MaxPooling][CPU]") {
+    std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(Array4D<float,2,2,5,5> { //NCHW
+        {
+            {
+                {{-0.3848,  0.2166, -0.4373,  0.6142,  0.5277},
+                 {0.7995,  0.3638, -1.4589, -1.0843,  1.0918},
+            	 {0.7147,  0.0936, -1.2902,  1.2037,  0.4874},
+                 {-0.5981,  2.1184, -0.9175,  1.3859,  0.3305},
+                 {-1.7700,  0.0563, -0.3914,  0.0538, -0.3955}},
+
+                {{-3.1409, -0.4554,  0.0524,  2.2291,  0.4859},
+                 {-0.7465, -0.6567, -2.3703, -0.6386, -1.4152},
+                 { 2.2329, -0.5850,  0.0700,  1.2838, -1.7363},
+                 { 0.2139,  0.0624, -1.0689, -0.8221, -0.8038},
+                 { 0.1886, -0.7840, -0.2313,  0.2651, -1.6244}}
+            },
+            {
+                {{ 0.4371,  1.6417,  0.9129,  0.6325,  0.5438},
+                 {-2.3552, -0.8850, -0.0232, -0.5462, -1.2011},
+                 {1.7653, -1.6668, -1.0814,  0.6182,  1.2071},
+                 {0.9541, -0.5133,  0.8664, -0.8892,  1.4585},
+                 {1.0220, -0.5107,  0.1829, -0.2301, -0.4268}},
+
+                {{ 1.0429,  0.6279, -0.2875,  0.7187, -0.1500},
+                 {1.6041,  2.9635,  1.4172, -0.7517,  0.5441},
+                 {-0.2276,  0.0857,  0.6776, -0.1389, -0.0614},
+                 {-0.1547, -0.3435,  0.0650, -0.5095, -1.8073},
+                 {1.7217,  0.3999, -0.5953,  1.0604, -0.4126}}
+            }
+        }
+    });
+    SECTION("Stride") {
+        std::shared_ptr<Node> myMaxPool = MaxPooling({2,2}, "mycdw", {2,2});
+        auto op = std::static_pointer_cast<OperatorTensor>(myMaxPool -> getOperator());
+
+        std::shared_ptr<Tensor> myOutput = std::make_shared<Tensor>(Array4D<float,2,2,2,2> {
+            {
+                {
+                    {{  0.7995,  0.6142},
+                     { 2.1184,  1.3859}},
+                    {{ -0.4554,  2.2291},
+                     {  2.2329,  1.2838}}
+                },
+                {
+                    {{1.6417,  0.9129},
+                     {1.7653,  0.8664}},
+                    {{2.9635,  1.4172},
+                     {0.0857,  0.6776}}
+                }
+            }
+        });
+        myMaxPool->getOperator()->associateInput(0,myInput);
+        myMaxPool->getOperator()->setDataType(DataType::Float32);
+        myMaxPool->getOperator()->setBackend("cuda");
+        op->computeOutputDims();
+        myMaxPool->forward();
+        
+        float* computedOutput   = new float[myOutput->size()]();
+        cudaMemcpy(computedOutput, op->getOutput(0)->getImpl()->rawPtr(), sizeof(float) * myOutput->size(), cudaMemcpyDeviceToHost);
+
+        for(int i = 0; i < myOutput->size(); i++){
+            const float targetOutput = *(static_cast<float*>(myOutput->getImpl()->rawPtr()) + i);
+            REQUIRE(fabs(computedOutput[i] - targetOutput) < 1e-6);
+        }
+
+        delete[] computedOutput;
+    }
+}
\ No newline at end of file