diff --git a/.gitlab/ci/_global.gitlab-ci.yml b/.gitlab/ci/_global.gitlab-ci.yml
index 1615b8974db11d93cb3305ce800e46cf5377bc33..fc335f699207a53eaf43038198a0cdeb5ca9bc4b 100644
--- a/.gitlab/ci/_global.gitlab-ci.yml
+++ b/.gitlab/ci/_global.gitlab-ci.yml
@@ -14,3 +14,4 @@ default:
   before_script:
     - apt update
     - apt install -y cmake cppcheck python-is-python3 pip git gcovr unzip curl
+    - apt install -y libcudnn8-dev
diff --git a/CMakeLists.txt b/CMakeLists.txt
index cf4866d9ef0bd32e5f264b356e15133b7d2b56e1..6e30fb0010f05825123586e4c4d4c9e56873e854 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -50,6 +50,7 @@ target_link_libraries(${module_name}
         _aidge_core # _ is added because we link the target not the project
         _aidge_backend_cpu # _ is added because we link the target not the project
         CUDA::cudart
+        cudnn
 )
 
 #Set target properties
diff --git a/include/aidge/backend/cuda/data/TensorImpl.hpp b/include/aidge/backend/cuda/data/TensorImpl.hpp
index 78403519283789b2ed5fd839d0c08557854c6d5b..86fe0c49695b5ba4039bda34a54b02bba204e0b6 100644
--- a/include/aidge/backend/cuda/data/TensorImpl.hpp
+++ b/include/aidge/backend/cuda/data/TensorImpl.hpp
@@ -10,6 +10,7 @@
 #include "aidge/utils/Types.h"
 
 #include "aidge/backend/cuda/utils/CudaUtils.hpp"
+#include "aidge/backend/cuda/utils/CudaContext.hpp"
 
 namespace Aidge {
 template <class T>
@@ -18,6 +19,7 @@ class TensorImpl_cuda : public TensorImpl {
     const Tensor &mTensor;  // Impl needs to access Tensor information, but is not
                             // supposed to change it!
     T* mData = nullptr;
+    mutable cudnnTensorDescriptor_t mCudnnTensor = nullptr;
 
    public:
     static constexpr const char *Backend = "cuda";
@@ -51,9 +53,55 @@ class TensorImpl_cuda : public TensorImpl {
         return mData;
     };
 
+    const cudnnTensorDescriptor_t& getCudnnTensorDesc() const {
+        if (mCudnnTensor == nullptr) {
+            CHECK_CUDNN_STATUS(cudnnCreateTensorDescriptor(&mCudnnTensor));
+
+            if (mTensor.size() > 0) {
+                /**
+                **      cuDNN tensors are restricted to having at least 4 dimensions:
+                **      when working with lower-dimensional data, unused dimensions are set to 1.
+                **      Refer to the cudnnSetTensorNdDescriptor documentation at:
+                **      https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html
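+                **      E.g. a 1-D tensor of size N ends up described to cuDNN as a
+                **      4-D tensor with dims {1, 1, 1, N} and strides {1, 1, 1, 1}
+                **      (after the reversal performed below).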
+                **/
+                std::vector<int> dims(4,1);
+                std::vector<int> strides(4,1);
+                int stride = 1;
+
+                for (unsigned int dim = 0; dim < 4; ++dim) {
+                    if (dim < mTensor.nbDims()) {
+                        dims[dim] = mTensor.dims()[dim];
+                        strides[dim] = stride;
+                        stride *= mTensor.dims()[dim];
+                    }
+                }
+
+                for (unsigned int dim = 4; dim < mTensor.nbDims(); ++dim) {
+                    dims.push_back(mTensor.dims()[dim]);
+                    strides.push_back(stride);
+                    stride *= mTensor.dims()[dim];
+                }
+
+                std::reverse(dims.begin(), dims.end());
+                std::reverse(strides.begin(), strides.end());
+
+                CHECK_CUDNN_STATUS(cudnnSetTensorNdDescriptor(mCudnnTensor,
+                                            CudaContext::data_type<T>::value,
+                                            dims.size(),
+                                            &dims[0],
+                                            &strides[0]));
+            }
+        }
+
+        return mCudnnTensor;
+    }
+
     virtual ~TensorImpl_cuda() {
         if (mData != nullptr)
             cudaFree(mData);
+
+        if (mCudnnTensor != nullptr)
+            cudnnDestroyTensorDescriptor(mCudnnTensor);
     }
 
     void setRawPtr(void* /*ptr*/) override final {
diff --git a/include/aidge/backend/cuda/operator/ConvImpl.hpp b/include/aidge/backend/cuda/operator/ConvImpl.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..a1979e1100b188f31c78375228119758dbf2aa17
--- /dev/null
+++ b/include/aidge/backend/cuda/operator/ConvImpl.hpp
@@ -0,0 +1,76 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef AIDGE_BACKEND_CUDA_OPERATOR_CONVIMPL_H_
+#define AIDGE_BACKEND_CUDA_OPERATOR_CONVIMPL_H_
+
+#include <array>
+#include <memory>
+#include <tuple>
+#include <vector>
+
+#include <cudnn.h>
+
+#include "aidge/backend/OperatorImpl.hpp"
+#include "aidge/operator/Conv.hpp"
+#include "aidge/utils/Registrar.hpp"
+#include "aidge/utils/Types.h"
+
+#include "aidge/backend/cuda/utils/CudaUtils.hpp"
+
+namespace Aidge {
+// class Conv_Op;
+
+template <DimIdx_t DIM>
+class ConvImpl_cuda : public OperatorImpl {
+   private:
+    const Conv_Op<DIM> &mOp;
+    std::array<NbElts_t, 3> mNbConsumedData;
+    std::array<NbElts_t, 1> mNbProducedData;
+
+    size_t mWorkspaceSize = 0;
+    void* mWorkspace = nullptr;
+
+    cudnnFilterDescriptor_t mFilterDesc = nullptr;  // created lazily in forward()
+    cudnnConvolutionFwdAlgo_t mFwdAlgo;
+    cudnnConvolutionDescriptor_t mConvDesc;
+
+   public:
+    ConvImpl_cuda(const Conv_Op<DIM> &op) : mOp(op), mNbConsumedData({0, 0, 0}), mNbProducedData({0}) {
+        CHECK_CUDNN_STATUS(cudnnCreateConvolutionDescriptor(&mConvDesc));
+    }
+
+    static std::unique_ptr<ConvImpl_cuda> create(const Conv_Op<DIM> &op) {
+        return std::make_unique<ConvImpl_cuda>(op);
+    }
+
+   public:
+    NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const override final;
+    NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
+    NbElts_t getRequiredMemory(const IOIndex_t /*outputIdx*/, const std::vector<DimSize_t> &/*inputsSize*/) const override final;
+    NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const override final;
+    NbElts_t getNbProducedData(const IOIndex_t outputIdx) const override final;
+    void updateConsummerProducer() override final;
+
+    void forward();
+
+    void backward();
+
+    ~ConvImpl_cuda();
+};
+
+namespace {
+// add cuda backend to Conv_Op<2> implementation registry
+static Registrar<Conv_Op<2>> registrarConvImpl_cuda("cuda", Aidge::ConvImpl_cuda<2>::create);
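+// Through the aidge_core Registrar mechanism, this registration is what makes
+// setBackend("cuda") on a Conv_Op<2> resolve to ConvImpl_cuda<2>::create
+// (assuming the standard aidge_core registration flow).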
+}  // namespace
+}  // namespace Aidge
+
+#endif /* AIDGE_BACKEND_CUDA_OPERATOR_CONVIMPL_H_ */
diff --git a/include/aidge/backend/cuda/utils/CudaContext.hpp b/include/aidge/backend/cuda/utils/CudaContext.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..a66ccdf690603c39ba4a7bf691f0dffea64ddddb
--- /dev/null
+++ b/include/aidge/backend/cuda/utils/CudaContext.hpp
@@ -0,0 +1,165 @@
+#ifndef AIDGE_BACKEND_CUDA_CUDA_CONTEXT_H
+#define AIDGE_BACKEND_CUDA_CUDA_CONTEXT_H
+
+#include <vector>
+#include <cstdio>
+
+#include "aidge/backend/cuda/utils/CudaUtils.hpp"
+
+namespace Aidge {
+class CudaContext {
+public:
+    static int nbDevice(){
+        int count = 1;
+        CHECK_CUDA_STATUS(cudaGetDeviceCount(&count));
+        return count;
+    }
+    static void setDevice(int device = -1)
+    {
+        static int prevDevice = 0;
+
+        if (device >= 0)
+            prevDevice = device;
+        else
+            device = prevDevice;
+
+        CHECK_CUDA_STATUS(cudaSetDevice(device));
+    }
+
+    static std::pair<size_t, size_t> getMemInfo(){
+        size_t free;
+        size_t total;
+        CHECK_CUDA_STATUS(cudaMemGetInfo(&free, &total));
+        return std::make_pair(free, total);
+    }
+
+    static int getDevice(){
+        int dev;
+        CHECK_CUDA_STATUS(cudaGetDevice(&dev));
+        return dev;
+    }
+
+    static const cudaDeviceProp& getDeviceProp()
+    {
+        static std::vector<cudaDeviceProp> deviceProp;
+        static std::vector<bool> init;
+
+        if (deviceProp.empty()) {
+//#pragma omp critical(CudaContext__getDeviceProp)
+            if (deviceProp.empty()) {
+                int count = 1;
+                CHECK_CUDA_STATUS(cudaGetDeviceCount(&count));
+
+                deviceProp.resize(count);
+                init.resize(count, false);
+            }
+        }
+
+        int dev;
+        CHECK_CUDA_STATUS(cudaGetDevice(&dev));
+
+        if (!init[dev]) {
+            CHECK_CUDA_STATUS(cudaGetDeviceProperties(&deviceProp[dev], dev));
+            init[dev] = true;
+        }
+
+        return deviceProp[dev];
+    }
+
+    // Declare cublas handle
+    static cublasHandle_t& cublasHandle()
+    {
+        static std::vector<cublasHandle_t> cublas_h;
+
+        if (cublas_h.empty()) {
+//#pragma omp critical(CudaContext__cublasHandle)
+            if (cublas_h.empty()) {
+                int count = 1;
+                CHECK_CUDA_STATUS(cudaGetDeviceCount(&count));
+
+                cublas_h.resize(count, NULL);
+            }
+        }
+
+        int dev;
+        CHECK_CUDA_STATUS(cudaGetDevice(&dev));
+
+        if (cublas_h[dev] == NULL) {
+            CHECK_CUBLAS_STATUS(cublasCreate(&cublas_h[dev]));
+            printf("CUBLAS initialized on device #%d\n", dev);
+        }
+
+        return cublas_h[dev];
+    }
+
+    // Declare cudnn handle
+    static cudnnHandle_t& cudnnHandle()
+    {
+        static std::vector<cudnnHandle_t> cudnn_h;
+
+        if (cudnn_h.empty()) {
+//#pragma omp critical(CudaContext__cudnnHandle)
+            if (cudnn_h.empty()) {
+                int count = 1;
+                CHECK_CUDA_STATUS(cudaGetDeviceCount(&count));
+
+                cudnn_h.resize(count, NULL);
+            }
+        }
+
+        int dev;
+        CHECK_CUDA_STATUS(cudaGetDevice(&dev));
+
+        if (cudnn_h[dev] == NULL) {
+            CHECK_CUDNN_STATUS(cudnnCreate(&cudnn_h[dev]));
+            printf("CUDNN initialized on device #%d\n", dev);
+        }
+
+        return cudnn_h[dev];
+    }
+
+    template <class T>
+    struct data_type {
+        static const cudnnDataType_t value = CUDNN_DATA_FLOAT;
+                                            // dummy default; specialized below for supported types
+    };
+};
+}
+
+namespace Aidge {
+    template <>
+    struct CudaContext::data_type<float> {
+        static const cudnnDataType_t value = CUDNN_DATA_FLOAT;
+    };
+
+    template <>
+    struct CudaContext::data_type<double> {
+        static const cudnnDataType_t value = CUDNN_DATA_DOUBLE;
+    };
+
+    inline cudnnDataType_t DataTypeToCudnn(DataType type) {
+        if (type == DataType::Float32)
+            return CUDNN_DATA_FLOAT;
+
+        if (type == DataType::Float64)
+            return CUDNN_DATA_DOUBLE;
+
+        if (type == DataType::Int8)
+            return CUDNN_DATA_INT8;
+
+        if (type == DataType::UInt8)
+            return CUDNN_DATA_UINT8;
+
+        if (type == DataType::Int32)
+            return CUDNN_DATA_INT32;
+
+        if (type == DataType::Int64)
+            return CUDNN_DATA_INT64;
+
+        assert(false && "Unsupported cuDNN type");
+        return CUDNN_DATA_FLOAT;  // fallback, only reachable when assertions are disabled
+    }
+}
+
+#endif // AIDGE_BACKEND_CUDA_CUDA_CONTEXT_H
diff --git a/include/aidge/backend/cuda/utils/CudaUtils.hpp b/include/aidge/backend/cuda/utils/CudaUtils.hpp
index 767025c2d3306565d7efd49483143db216304ad0..76d7ea48e02473deeaa2cb0801a292623a666a1d 100644
--- a/include/aidge/backend/cuda/utils/CudaUtils.hpp
+++ b/include/aidge/backend/cuda/utils/CudaUtils.hpp
@@ -1,23 +1,91 @@
-#ifndef CudaUtils_cuda_H_
-#define CudaUtils_cuda_H_
+#ifndef AIDGE_BACKEND_CUDA_CUDA_UTILS_H
+#define AIDGE_BACKEND_CUDA_CUDA_UTILS_H
 
 #include <string>
-#include <cassert>
+#include <memory>
+#include <sstream>
+#include <iostream>
+#include <stdexcept>
 
+#include <cublas_v2.h>
 #include <cuda.h>
+#include <cudnn.h>
+
+#define CHECK_CUDNN_STATUS(status)                                             \
+    do {                                                                       \
+        const cudnnStatus_t e = (status);                                      \
+        if (e != CUDNN_STATUS_SUCCESS) {                                       \
+            std::stringstream error;                                           \
+            error << "CUDNN failure: " << cudnnGetErrorString(e) << " ("       \
+                  << static_cast<int>(e) << ") in " << __FILE__ << ':' << __LINE__;         \
+            int status_dev;                                                           \
+            if (cudaGetDevice(&status_dev) == cudaSuccess)                            \
+                error << " on device #" << status_dev;                                \
+            std::cerr << error.str() << std::endl;                             \
+            cudaDeviceReset();                                                 \
+            throw std::runtime_error(error.str());                             \
+        }                                                                      \
+    } while(0)
 
 #define CHECK_CUDA_STATUS(status)                                              \
     do {                                                                       \
         const cudaError_t e = (status);                                        \
         if ((e) != cudaSuccess) {                                              \
-            printf("Cuda failure: %s in %s:%d", cudaGetErrorString(e), __FILE__, __LINE__);                            \
-            int dev;                                                           \
-            if (cudaGetDevice(&dev) == cudaSuccess)                            \
-                printf(" on device #%d", dev); \
-            printf("\n"); \
+            std::stringstream error;                                           \
+            error << "Cuda failure: " << cudaGetErrorString(e) << " ("         \
+                  << static_cast<int>(e) << ") in " << __FILE__ << ':' << __LINE__;         \
+            int status_dev;                                                           \
+            if (cudaGetDevice(&status_dev) == cudaSuccess)                            \
+                error << " on device #" << status_dev;                                \
+            std::cerr << error.str() << std::endl;                             \
             cudaDeviceReset();                                                 \
-            assert(false && "Cuda failure");                             \
+            throw std::runtime_error(error.str());                             \
         }                                                                      \
     } while(0)
 
-#endif // CudaUtils_cuda_H_
\ No newline at end of file
+#define CHECK_CUBLAS_STATUS(status)                                            \
+    do {                                                                       \
+        const cublasStatus_t e = (status);                                     \
+        if (e != CUBLAS_STATUS_SUCCESS) {                                      \
+            std::stringstream error;                                           \
+            error << "Cublas failure: "                                        \
+                  << Aidge::Cuda::cublasGetErrorString(e) << " ("               \
+                  << static_cast<int>(e) << ") in " << __FILE__ << ':' << __LINE__;         \
+            int status_dev;                                                           \
+            if (cudaGetDevice(&status_dev) == cudaSuccess)                            \
+                error << " on device #" << status_dev;                                \
+            std::cerr << error.str() << std::endl;                             \
+            cudaDeviceReset();                                                 \
+            throw std::runtime_error(error.str());                             \
+        }                                                                      \
+    } while(0)
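+
+// Usage sketch: wrap each CUDA/cuDNN/cuBLAS runtime call so that failures
+// throw a std::runtime_error with file/line context, e.g.:
+//   CHECK_CUDA_STATUS(cudaMalloc(&devPtr, size));
+//   CHECK_CUDNN_STATUS(cudnnCreate(&handle));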
+
+namespace Aidge {
+namespace Cuda {
+    const char* cublasGetErrorString(cublasStatus_t error);
+
+    // Enable Peer-to-Peer communications between devices
+    // when it is possible
+    void setMultiDevicePeerAccess(unsigned int size, unsigned int* devices);
+
+    // CuDNN scaling parameters are typically "alpha" and "beta".
+    // Their type must be "float" for HALF and FLOAT (default template)
+    // and "double" for DOUBLE (specialized template)
+    template <class T>
+    struct cudnn_scaling_type {
+        typedef float type;
+    };
+
+    template <>
+    struct cudnn_scaling_type<double> {
+        typedef double type;
+    };
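+
+    // Usage sketch (handle/alpha/beta names are illustrative, not part of this header):
+    //   typename Cuda::cudnn_scaling_type<T>::type alpha = 1.0, beta = 0.0;
+    //   cudnnConvolutionForward(handle, &alpha, /* ... */, &beta, /* ... */);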
+
+    template <class T>
+    struct cuda_type {
+        typedef T type;
+    };
+}
+}
+
+#endif // AIDGE_BACKEND_CUDA_CUDA_UTILS_H
diff --git a/src/operator/ConvImpl.cpp b/src/operator/ConvImpl.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..c745aef46c5c3531c57be37aef9072f5f1357d96
--- /dev/null
+++ b/src/operator/ConvImpl.cpp
@@ -0,0 +1,189 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <cassert>
+#include <chrono>  // std::chrono::milliseconds
+#include <numeric> // std::accumulate
+#include <thread>  // std::this_thread::sleep_for
+#include <vector>
+
+#include "aidge/utils/Types.h"
+#include "aidge/operator/Conv.hpp"
+
+#include "aidge/backend/cuda/data/TensorImpl.hpp"
+#include "aidge/backend/cuda/operator/ConvImpl.hpp"
+#include "aidge/backend/cuda/utils/CudaContext.hpp"
+
+template <Aidge::DimIdx_t DIM>
+Aidge::NbElts_t Aidge::ConvImpl_cuda<DIM>::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const {
+    assert(mOp.getInput(inputIdx) && "requires valid input");
+
+    // Requires the whole tensors
+    const auto &inputDims = std::static_pointer_cast<Tensor>(mOp.getInput(inputIdx))->dims();
+
+    return std::accumulate(inputDims.begin(), inputDims.end(), Aidge::NbElts_t(1), std::multiplies<NbElts_t>());
+}
+
+template <Aidge::DimIdx_t DIM>
+Aidge::NbElts_t Aidge::ConvImpl_cuda<DIM>::getNbRequiredProtected(IOIndex_t /*inputIdx*/) const {
+    // for the direct convolution algorithm, convolutions can be in-place, if
+    // there is no padding!
+    return 0;
+}
+
+template <Aidge::DimIdx_t DIM>
+Aidge::NbElts_t Aidge::ConvImpl_cuda<DIM>::getRequiredMemory(const Aidge::IOIndex_t outputIdx,
+                                                         const std::vector<Aidge::DimSize_t> &/*inputsSize*/) const {
+    // Requires the whole tensors, regardless of available data on inputs
+    assert(outputIdx == 0 && "operator has only one output");
+    (void) outputIdx;
+
+    const auto &outputDims = std::static_pointer_cast<Tensor>(mOp.getOutput(0))->dims();
+    return std::accumulate(outputDims.begin(), outputDims.end(), NbElts_t(1), std::multiplies<NbElts_t>());
+}
+
+template <Aidge::DimIdx_t DIM>
+Aidge::NbElts_t Aidge::ConvImpl_cuda<DIM>::getNbConsumedData(Aidge::IOIndex_t inputIdx) const {
+    assert(static_cast<std::size_t>(inputIdx) < mNbConsumedData.size());
+    return mNbConsumedData[static_cast<std::size_t>(inputIdx)];
+}
+
+template <Aidge::DimIdx_t DIM>
+Aidge::NbElts_t Aidge::ConvImpl_cuda<DIM>::getNbProducedData(Aidge::IOIndex_t outputIdx) const {
+    assert((outputIdx == 0) && (static_cast<std::size_t>(outputIdx) < mNbProducedData.size()));
+    return mNbProducedData[static_cast<std::size_t>(outputIdx)];
+}
+
+template <Aidge::DimIdx_t DIM>
+void Aidge::ConvImpl_cuda<DIM>::updateConsummerProducer(){
+    // Update producer-consumer data:
+    // each input is consumed by the minimum amount required for a forward pass
+    for (std::size_t inputIdx = 0; inputIdx < mNbConsumedData.size(); ++inputIdx)
+        mNbConsumedData[inputIdx] += getNbRequiredData(static_cast<IOIndex_t>(inputIdx));
+
+    mNbProducedData[0] += getRequiredMemory(0, {});
+}
+
+template <Aidge::DimIdx_t DIM>
+void Aidge::ConvImpl_cuda<DIM>::forward() {
+    // FIXME: uncomment the following code once memory handling works
+    assert(mOp.getInput(0) && "missing input #0");
+    assert(mOp.getInput(1) && "missing input #1");
+    assert(mOp.getInput(2) && "missing input #2");
+
+    const std::vector<int> strides(mOp.template get<ConvParam::StrideDims>().rbegin(), mOp.template get<ConvParam::StrideDims>().rend());
+    const std::vector<int> paddings(DIM, 0);
+    const std::vector<int> upscales(mOp.template get<ConvParam::DilationDims>().rbegin(), mOp.template get<ConvParam::DilationDims>().rend());
+
+    CHECK_CUDNN_STATUS(
+        cudnnSetConvolutionNdDescriptor(mConvDesc,
+                                        DIM,
+                                        &paddings[0],
+                                        &strides[0],
+                                        &upscales[0],
+                                        CUDNN_CROSS_CORRELATION,
+                                        DataTypeToCudnn(mOp.getInput(2)->dataType())));
+
+    const std::vector<int> cudaKernelDims(mOp.getInput(1)->dims().rbegin(),
+                                          mOp.getInput(1)->dims().rend());
+
+    // forward() may be called several times; only create the filter descriptor once
+    if (mFilterDesc == nullptr)
+        CHECK_CUDNN_STATUS(cudnnCreateFilterDescriptor(&mFilterDesc));
+    CHECK_CUDNN_STATUS(cudnnSetFilterNdDescriptor(mFilterDesc,
+                                                DataTypeToCudnn(mOp.getInput(1)->dataType()),
+                                                CUDNN_TENSOR_NCHW,
+                                                cudaKernelDims.size(),
+                                                &cudaKernelDims[0]));
+
+    int maxAlgoIterations = 0;
+    CHECK_CUDNN_STATUS(cudnnGetConvolutionForwardAlgorithmMaxCount(CudaContext::cudnnHandle(),
+                                                                   &maxAlgoIterations));
+
+    assert(maxAlgoIterations > 0 && "No available CUDNN ConvolutionForwardAlgorithm");
+
+    int returnAlgoCounts = 0;
+
+    std::vector<cudnnConvolutionFwdAlgoPerf_t> returnFwdAlgo(maxAlgoIterations);
+/**************************************************************************************************************
+https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnFindConvolutionForwardAlgorithm
+This function attempts all cuDNN algorithms (including CUDNN_TENSOR_OP_MATH and CUDNN_DEFAULT_MATH
+versions of algorithms where CUDNN_TENSOR_OP_MATH may be available) for cudnnConvolutionForward(),
+using memory allocated via cudaMalloc(), and outputs performance metrics to a user-allocated array
+of cudnnConvolutionFwdAlgoPerf_t. These metrics are written in sorted fashion where the first element
+has the lowest compute time. The total number of resulting algorithms can be queried through
+the API cudnnGetConvolutionForwardAlgorithmMaxCount().
+***************************************************************************************************************/
+
+    CHECK_CUDNN_STATUS(cudnnFindConvolutionForwardAlgorithm(
+                        CudaContext::cudnnHandle(),
+                        static_cast<TensorImpl_cuda<float>*>(mOp.getInput(0)->getImpl().get())->getCudnnTensorDesc(),  // FIXME: PLAIN WRONG
+                        mFilterDesc,
+                        mConvDesc,
+                        static_cast<TensorImpl_cuda<float>*>(mOp.getOutput(0)->getImpl().get())->getCudnnTensorDesc(),  // FIXME: PLAIN WRONG
+                        maxAlgoIterations,
+                        &returnAlgoCounts,
+                        &returnFwdAlgo[0]));
+    // std::cout << "Layer " << mName << "(" << k  << ")"
+    //     << " cuDNN forward algorithm heuristic results: " << std::endl;
+
+    // Only the first returnAlgoCounts entries of returnFwdAlgo are written by cuDNN
+    for (int fwdAlgo = 0; fwdAlgo < returnAlgoCounts; ++fwdAlgo)
+    {
+        const char* algoName = "Undetermined Algorithm";
+
+        switch (returnFwdAlgo[fwdAlgo].algo) {
+        case CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM:
+            algoName = "CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM"; break;
+        case CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM:
+            algoName = "CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM"; break;
+        case CUDNN_CONVOLUTION_FWD_ALGO_GEMM:
+            algoName = "CUDNN_CONVOLUTION_FWD_ALGO_GEMM"; break;
+        case CUDNN_CONVOLUTION_FWD_ALGO_DIRECT:
+            algoName = "CUDNN_CONVOLUTION_FWD_ALGO_DIRECT"; break;
+        case CUDNN_CONVOLUTION_FWD_ALGO_FFT:
+            algoName = "CUDNN_CONVOLUTION_FWD_ALGO_FFT"; break;
+        case CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING:
+            algoName = "CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING"; break;
+        case CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD:
+            algoName = "CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD"; break;
+        case CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED:
+            algoName = "CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED"; break;
+        case CUDNN_CONVOLUTION_FWD_ALGO_COUNT:
+            algoName = "CUDNN_CONVOLUTION_FWD_ALGO_COUNT"; break;
+        }
+        (void) algoName;  // only used by the diagnostic output below
+
+        // std::cout << "----> Forward convolution algorithm: " << algoName
+        //     << " [" << returnFwdAlgo[fwdAlgo].time << " ms][" << returnFwdAlgo[fwdAlgo].memory / 1.0e6 << " MB]"
+        //     << std::endl;
+    }
+    mFwdAlgo = returnFwdAlgo[0].algo;
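+
+    // Sketch of the remaining forward steps, not implemented yet in this WIP
+    // (descriptor/pointer names below are placeholders):
+    //   CHECK_CUDNN_STATUS(cudnnGetConvolutionForwardWorkspaceSize(
+    //       CudaContext::cudnnHandle(), inputDesc, mFilterDesc, mConvDesc,
+    //       outputDesc, mFwdAlgo, &mWorkspaceSize));
+    //   CHECK_CUDA_STATUS(cudaMalloc(&mWorkspace, mWorkspaceSize));
+    //   CHECK_CUDNN_STATUS(cudnnConvolutionForward(
+    //       CudaContext::cudnnHandle(), &alpha, inputDesc, inputPtr,
+    //       mFilterDesc, weightsPtr, mConvDesc, mFwdAlgo,
+    //       mWorkspace, mWorkspaceSize, &beta, outputDesc, outputPtr));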
+}
+
+template <Aidge::DimIdx_t DIM>
+Aidge::ConvImpl_cuda<DIM>::~ConvImpl_cuda() {
+    if (mFilterDesc != nullptr)
+        cudnnDestroyFilterDescriptor(mFilterDesc);
+    cudnnDestroyConvolutionDescriptor(mConvDesc);
+    if (mWorkspace != nullptr)
+        cudaFree(mWorkspace);
+}
+
+template <Aidge::DimIdx_t DIM>
+void Aidge::ConvImpl_cuda<DIM>::backward() { printf("Not implemented yet.\n"); }
+
+
+// Explicit template instantiation, so that all members of ConvImpl_cuda<2>
+// (forward, backward, ...) are compiled into this translation unit
+template class Aidge::ConvImpl_cuda<2>;
diff --git a/src/utils/CudaUtils.cpp b/src/utils/CudaUtils.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..a6e0514f9a949a805561d966e2a712701c18936c
--- /dev/null
+++ b/src/utils/CudaUtils.cpp
@@ -0,0 +1,51 @@
+#include "aidge/backend/cuda/utils/CudaUtils.hpp"
+
+const char* Aidge::Cuda::cublasGetErrorString(cublasStatus_t error)
+{
+    switch (error) {
+    case CUBLAS_STATUS_SUCCESS:
+        return "CUBLAS_STATUS_SUCCESS";
+    case CUBLAS_STATUS_NOT_INITIALIZED:
+        return "CUBLAS_STATUS_NOT_INITIALIZED";
+    case CUBLAS_STATUS_ALLOC_FAILED:
+        return "CUBLAS_STATUS_ALLOC_FAILED";
+    case CUBLAS_STATUS_INVALID_VALUE:
+        return "CUBLAS_STATUS_INVALID_VALUE";
+    case CUBLAS_STATUS_ARCH_MISMATCH:
+        return "CUBLAS_STATUS_ARCH_MISMATCH";
+    case CUBLAS_STATUS_MAPPING_ERROR:
+        return "CUBLAS_STATUS_MAPPING_ERROR";
+    case CUBLAS_STATUS_EXECUTION_FAILED:
+        return "CUBLAS_STATUS_EXECUTION_FAILED";
+    case CUBLAS_STATUS_INTERNAL_ERROR:
+        return "CUBLAS_STATUS_INTERNAL_ERROR";
+    case CUBLAS_STATUS_NOT_SUPPORTED:
+        return "CUBLAS_STATUS_NOT_SUPPORTED";
+    case CUBLAS_STATUS_LICENSE_ERROR:
+        return "CUBLAS_STATUS_LICENSE_ERROR";
+    }
+
+    return "<unknown>";
+}
+
+void Aidge::Cuda::setMultiDevicePeerAccess(unsigned int size, unsigned int* devices)
+{
+    for (unsigned int i = 0; i < size; ++i) {
+        for (unsigned int j = 0; j < size; ++j) {
+            if (i != j) {
+                int canAccessPeer = 0;
+                CHECK_CUDA_STATUS(cudaDeviceCanAccessPeer(&canAccessPeer,
+                                                          devices[j], devices[i]));
+                if (canAccessPeer) {
+                    CHECK_CUDA_STATUS(cudaSetDevice(devices[j]));
+                    const cudaError_t status = cudaDeviceEnablePeerAccess(devices[i], 0);
+                    if (status == cudaErrorPeerAccessAlreadyEnabled) {
+                        printf("Peer access already enabled between device %d and device %d\n", devices[j], devices[i]);
+                    } else {
+                        CHECK_CUDA_STATUS(status);
+                    }
+                }
+            }
+        }
+    }
+}
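+
+// Usage sketch (device ids are illustrative): enable P2P between devices 0 and 1:
+//   unsigned int devices[2] = {0, 1};
+//   Aidge::Cuda::setMultiDevicePeerAccess(2, devices);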