diff --git a/include/aidge/backend/cuda/data/TensorImpl.hpp b/include/aidge/backend/cuda/data/TensorImpl.hpp
index 86fe0c49695b5ba4039bda34a54b02bba204e0b6..1b939d70d02615eb04d89890c40a9da3aedbd531 100644
--- a/include/aidge/backend/cuda/data/TensorImpl.hpp
+++ b/include/aidge/backend/cuda/data/TensorImpl.hpp
@@ -13,8 +13,13 @@
 #include "aidge/backend/cuda/utils/CudaContext.hpp"
 
 namespace Aidge {
+class TensorImpl_cuda_ {
+public:
+    virtual const cudnnTensorDescriptor_t& getCudnnTensorDesc() const = 0;
+};
+
 template <class T>
-class TensorImpl_cuda : public TensorImpl {
+class TensorImpl_cuda : public TensorImpl, public TensorImpl_cuda_  {
    private:
     const Tensor &mTensor;  // Impl needs to access Tensor information, but is not
                             // supposed to change it!
@@ -53,7 +58,7 @@ class TensorImpl_cuda : public TensorImpl {
         return mData;
     };
 
-    const cudnnTensorDescriptor_t& getCudnnTensorDesc() const {
+    const cudnnTensorDescriptor_t& getCudnnTensorDesc() const override {
         if (mCudnnTensor == nullptr) {
             CHECK_CUDNN_STATUS(cudnnCreateTensorDescriptor(&mCudnnTensor));
 
diff --git a/include/aidge/backend/cuda/operator/ConvImpl.hpp b/include/aidge/backend/cuda/operator/ConvImpl.hpp
index a1979e1100b188f31c78375228119758dbf2aa17..cb1202f4cffd814fac67453065a786c50f1eed18 100644
--- a/include/aidge/backend/cuda/operator/ConvImpl.hpp
+++ b/include/aidge/backend/cuda/operator/ConvImpl.hpp
@@ -31,7 +31,7 @@ namespace Aidge {
 
 template <DimIdx_t DIM>
 class ConvImpl_cuda : public OperatorImpl {
-   private:
+private:
     const Conv_Op<DIM> &mOp;
     std::array<NbElts_t, 3> mNbConsumedData;
     std::array<NbElts_t, 1> mNbProducedData;
@@ -43,7 +43,7 @@ class ConvImpl_cuda : public OperatorImpl {
     cudnnConvolutionFwdAlgo_t mFwdAlgo;
     cudnnConvolutionDescriptor_t mConvDesc;
 
-   public:
+public:
     ConvImpl_cuda(const Conv_Op<DIM> &op) : mOp(op), mNbConsumedData({0, 0, 0}), mNbProducedData({0}) {
         CHECK_CUDNN_STATUS(cudnnCreateConvolutionDescriptor(&mConvDesc));
     }
@@ -52,7 +52,7 @@ class ConvImpl_cuda : public OperatorImpl {
         return std::make_unique<ConvImpl_cuda>(op);
     }
 
-   public:
+public:
     NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const override final;
     NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
     NbElts_t getRequiredMemory(const IOIndex_t /*outputIdx*/, const std::vector<DimSize_t> &/*inputsSize*/) const override final;
@@ -65,6 +65,9 @@ class ConvImpl_cuda : public OperatorImpl {
     void backward();
 
     ~ConvImpl_cuda();
+
+private:
+    template <class T> void forward_();
 };
 
 namespace {
diff --git a/src/operator/ConvImpl.cpp b/src/operator/ConvImpl.cpp
index f5e2a94b4b6bcea1574b763640b03dca217eed3f..fdfdc626cbbb86d03a2e24a0b2dc42134be8904c 100644
--- a/src/operator/ConvImpl.cpp
+++ b/src/operator/ConvImpl.cpp
@@ -77,7 +77,6 @@ void Aidge::ConvImpl_cuda<DIM>::forward() {
     // FIXME: uncomment the following code once memory handling will work
     assert(mOp.getInput(0) && "missing input #0");
     assert(mOp.getInput(1) && "missing input #1");
-    assert(mOp.getInput(2) && "missing input #2");
 
     const std::vector<int> strides(mOp.template get<ConvParam::StrideDims>().rbegin(), mOp.template get<ConvParam::StrideDims>().rend());
     const std::vector<int> paddings(DIM, 0);
@@ -111,22 +110,22 @@ void Aidge::ConvImpl_cuda<DIM>::forward() {
     int returnAlgoCounts = 0;
 
     std::vector<cudnnConvolutionFwdAlgoPerf_t> returnFwdAlgo(maxAlgoIterations);
-/**************************************************************************************************************
-https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnFindConvolutionForwardAlgorithm
-This function attempts all cuDNN algorithms (including CUDNN_TENSOR_OP_MATH and CUDNN_DEFAULT_MATH
-versions of algorithms where CUDNN_TENSOR_OP_MATH may be available) for cudnnConvolutionForward(),
-using memory allocated via cudaMalloc(), and outputs performance metrics to a user-allocated array
-of cudnnConvolutionFwdAlgoPerf_t. These metrics are written in sorted fashion where the first element
-has the lowest compute time. The total number of resulting algorithms can be queried through
-the API cudnnGetConvolutionForwardMaxCount().
-***************************************************************************************************************/
 
+    /**************************************************************************************************************
+    https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnFindConvolutionForwardAlgorithm
+    This function attempts all cuDNN algorithms (including CUDNN_TENSOR_OP_MATH and CUDNN_DEFAULT_MATH
+    versions of algorithms where CUDNN_TENSOR_OP_MATH may be available) for cudnnConvolutionForward(),
+    using memory allocated via cudaMalloc(), and outputs performance metrics to a user-allocated array
+    of cudnnConvolutionFwdAlgoPerf_t. These metrics are written in sorted fashion where the first element
+    has the lowest compute time. The total number of resulting algorithms can be queried through
+    the API cudnnGetConvolutionForwardMaxCount().
+    ***************************************************************************************************************/
     CHECK_CUDNN_STATUS(cudnnFindConvolutionForwardAlgorithm(
                         CudaContext::cudnnHandle(),
-                        static_cast<TensorImpl_cuda<float>*>(mOp.getInput(0)->getImpl().get())->getCudnnTensorDesc(),  // FIXME: PLAIN WRONG
+                        dynamic_cast<TensorImpl_cuda_*>(mOp.getInput(0)->getImpl().get())->getCudnnTensorDesc(),
                         mFilterDesc,
                         mConvDesc,
-                        static_cast<TensorImpl_cuda<float>*>(mOp.getOutput(0)->getImpl().get())->getCudnnTensorDesc(),  // FIXME: PLAIN WRONG
+                        dynamic_cast<TensorImpl_cuda_*>(mOp.getOutput(0)->getImpl().get())->getCudnnTensorDesc(),
                         maxAlgoIterations,
                         &returnAlgoCounts,
                         &returnFwdAlgo[0]));
@@ -171,6 +170,66 @@ the API cudnnGetConvolutionForwardMaxCount().
         //     << std::endl;
     }
     mFwdAlgo = returnFwdAlgo[0].algo;
+
+    size_t workspaceSize = 0;
+
+    CHECK_CUDNN_STATUS(cudnnGetConvolutionForwardWorkspaceSize(
+        CudaContext::cudnnHandle(),
+        dynamic_cast<TensorImpl_cuda_*>(mOp.getInput(0)->getImpl().get())->getCudnnTensorDesc(),
+        mFilterDesc,
+        mConvDesc,
+        dynamic_cast<TensorImpl_cuda_*>(mOp.getOutput(0)->getImpl().get())->getCudnnTensorDesc(),
+        mFwdAlgo,
+        &workspaceSize));
+
+    if (mWorkspaceSize != workspaceSize) {
+        if (mWorkspace != nullptr) {
+            cudaFree(mWorkspace);
+            mWorkspaceSize = 0;
+        }
+
+        CHECK_CUDA_STATUS(cudaMalloc(&mWorkspace, workspaceSize));
+        mWorkspaceSize = workspaceSize;
+    }
+
+    if (mOp.getOutput(0)->dataType() == DataType::Float64) {
+        forward_<double>();
+    }
+    else {
+        forward_<float>();
+    }
+}
+
+template <Aidge::DimIdx_t DIM>
+template <class T>
+void Aidge::ConvImpl_cuda<DIM>::forward_() {
+    const typename Cuda::cudnn_scaling_type<T>::type alpha = 1.0f;
+    typename Cuda::cudnn_scaling_type<T>::type beta = 0.0f;
+
+    CHECK_CUDNN_STATUS(
+        cudnnConvolutionForward(CudaContext::cudnnHandle(),
+                                &alpha,
+                                dynamic_cast<TensorImpl_cuda_*>(mOp.getInput(0)->getImpl().get())->getCudnnTensorDesc(),
+                                mOp.getInput(0)->getImpl()->rawPtr(),
+                                mFilterDesc,
+                                mOp.getInput(1)->getImpl()->rawPtr(),
+                                mConvDesc,
+                                mFwdAlgo,
+                                mWorkspace,
+                                mWorkspaceSize,
+                                &beta,
+                                dynamic_cast<TensorImpl_cuda_*>(mOp.getOutput(0)->getImpl().get())->getCudnnTensorDesc(),
+                                mOp.getOutput(0)->getImpl()->rawPtr()));
+
+    if (mOp.getInput(2) != nullptr) {
+        CHECK_CUDNN_STATUS(cudnnAddTensor(CudaContext::cudnnHandle(),
+                                            &alpha,
+                                            dynamic_cast<TensorImpl_cuda_*>(mOp.getInput(2)->getImpl().get())->getCudnnTensorDesc(),
+                                            mOp.getInput(2)->getImpl()->rawPtr(),
+                                            &alpha,
+                                            dynamic_cast<TensorImpl_cuda_*>(mOp.getOutput(0)->getImpl().get())->getCudnnTensorDesc(),
+                                            mOp.getOutput(0)->getImpl()->rawPtr()));
+    }
 }
 
 template <Aidge::DimIdx_t DIM>
diff --git a/unit_tests/Test_ConvImpl.cpp b/unit_tests/Test_ConvImpl.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..2746f82ffa9840eec7ceec15a1977d0e272a9bde
--- /dev/null
+++ b/unit_tests/Test_ConvImpl.cpp
@@ -0,0 +1,161 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <array>
+
+#include <catch2/catch_test_macros.hpp>
+
+#include "Test_cuda.hpp"
+
+#include "aidge/data/Tensor.hpp"
+
+#include "aidge/backend/cpu.hpp"
+#include "aidge/backend/cuda.hpp"
+
+using namespace Aidge;
+
+TEST_CASE("[gpu/operator] Conv(forward)") {
+    SECTION("Classic Conv") {
+        std::shared_ptr<Node> myConv = Conv(3,4,{3,3}, "myconv");
+        myConv->getOperator()->setDatatype(DataType::Float32);
+        myConv->getOperator()->setBackend("cuda");
+        std::shared_ptr<Tensor> myWeights = std::make_shared<Tensor>(Array4D<int,4,3,3,3> {
+            {
+                {
+                    {{  0,   1,   2},
+                    {  3,   4,   5},
+                    {  6,   7,   8}},
+                    {{  9,  10,  11},
+                    { 12,  13,  14},
+                    { 15,  16,  17}},
+                    {{ 18,  19,  20},
+                    { 21,  22,  23},
+                    { 24,  25,  26}}
+                },
+                {
+                    {{ 27,  28,  29},
+                    { 30,  31,  32},
+                    { 33,  34,  35}},
+                    {{ 36,  37,  38},
+                    { 39,  40,  41},
+                    { 42,  43,  44}},
+                    {{ 45,  46,  47},
+                    { 48,  49,  50},
+                    { 51,  52,  53}}
+                },
+                {
+                    {{ 54,  55,  56},
+                    { 57,  58,  59},
+                    { 60,  61,  62}},
+                    {{ 63,  64,  65},
+                    { 66,  67,  68},
+                    { 69,  70,  71}},
+                    {{ 72,  73,  74},
+                    { 75,  76,  77},
+                    { 78,  79,  80}}
+                },
+                {
+                    {{ 81,  82,  83},
+                    { 84,  85,  86},
+                    { 87,  88,  89}},
+                    {{ 90,  91,  92},
+                    { 93,  94,  95},
+                    { 96,  97,  98}},
+                    {{ 99, 100, 101},
+                    {102, 103, 104},
+                    {105, 106, 107}}
+                }
+            }
+        });
+        std::shared_ptr<Tensor> myBias = std::make_shared<Tensor>(Array1D<int,4> {{7,0,9,0}});
+        std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(Array4D<int,2,3,5,5> { //NCHW
+            {
+                {
+                    {{  0,   1,   2,   3,   4},
+                    {  5,   6,   7,   8,   9},
+                    { 10,  11,  12,  13,  14},
+                    { 15,  16,  17,  18,  19},
+                    { 20,  21,  22,  23,  24}},
+
+                    {{ 25,  26,  27,  28,  29},
+                    { 30,  31,  32,  33,  34},
+                    { 35,  36,  37,  38,  39},
+                    { 40,  41,  42,  43,  44},
+                    { 45,  46,  47,  48,  49}},
+
+                    {{ 50,  51,  52,  53,  54},
+                    { 55,  56,  57,  58,  59},
+                    { 60,  61,  62,  63,  64},
+                    { 65,  66,  67,  68,  69},
+                    { 70,  71,  72,  73,  74}}
+                },
+                {
+                    {{ 75,  76,  77,  78,  79},
+                    { 80,  81,  82,  83,  84},
+                    { 85,  86,  87,  88,  89},
+                    { 90,  91,  92,  93,  94},
+                    { 95,  96,  97,  98,  99}},
+
+                    {{100, 101, 102, 103, 104},
+                    {105, 106, 107, 108, 109},
+                    {110, 111, 112, 113, 114},
+                    {115, 116, 117, 118, 119},
+                    {120, 121, 122, 123, 124}},
+
+                    {{125, 126, 127, 128, 129},
+                    {130, 131, 132, 133, 134},
+                    {135, 136, 137, 138, 139},
+                    {140, 141, 142, 143, 144},
+                    {145, 146, 147, 148, 149}}
+                }
+            }
+        });
+        std::shared_ptr<Tensor> myOutput = std::make_shared<Tensor>(Array4D<int,2,4,3,3> { 
+            {
+                {
+                    {{ 15226,  15577,  15928},
+                    { 16981,  17332,  17683},
+                    { 18736,  19087,  19438}},
+                    {{ 37818,  38898,  39978},
+                    { 43218,  44298,  45378},
+                    { 48618,  49698,  50778}},
+                    {{ 60426,  62235,  64044},
+                    { 69471,  71280,  73089},
+                    { 78516,  80325,  82134}},
+                    {{ 83016,  85554,  88092},
+                    { 95706,  98244, 100782},
+                    {108396, 110934, 113472}}
+                },
+                {
+                    {{ 41551,  41902,  42253},
+                    { 43306,  43657,  44008},
+                    { 45061,  45412,  45763}},
+                    {{118818, 119898, 120978},
+                    {124218, 125298, 126378},
+                    {129618, 130698, 131778}},
+                    {{196101, 197910, 199719},
+                    {205146, 206955, 208764},
+                    {214191, 216000, 217809}},
+                    {{273366, 275904, 278442},
+                    {286056, 288594, 291132},
+                    {298746, 301284, 303822}}
+                }
+            }
+        });
+        myConv->getOperator()->associateInput(0,myInput);
+        myConv->getOperator()->associateInput(1,myWeights);
+        myConv->getOperator()->associateInput(2,myBias);
+        myConv->getOperator()->computeOutputDims();
+        myConv->forward();
+        // myConv->getOperator()->getOutput(0)->print();
+        REQUIRE(*(myConv->getOperator()->getOutput(0)) == *myOutput);
+    }
+}