diff --git a/include/aidge/backend/cuda/operator/FCImpl_CUDA_kernels.hpp b/include/aidge/backend/cuda/operator/FCImpl_CUDA_kernels.hpp
index ffbb06752d39c14dc15f24967a394ed6acb9521b..9c83332f9857ed802f5563faef558a7278d3e992 100644
--- a/include/aidge/backend/cuda/operator/FCImpl_CUDA_kernels.hpp
+++ b/include/aidge/backend/cuda/operator/FCImpl_CUDA_kernels.hpp
@@ -25,12 +25,12 @@ namespace Aidge {
 
 template <class T>
 cublasStatus_t cublasGemm(cublasHandle_t handle,
-                           cublasOperation_t transa, cublasOperation_t transb,
-                           int m, int n, int k,
-                           const T *alpha,
-                           const T *A, int lda,
-                           const T *B, int ldb,
-                           const T *beta,
-                           T *C, int ldc);
+                          cublasOperation_t transa, cublasOperation_t transb,
+                          int m, int n, int k,
+                          const T *alpha,
+                          const T *A, int lda,
+                          const T *B, int ldb,
+                          const T *beta,
+                          T *C, int ldc);
 }
 #endif /* AIDGE_CUDA_OPERATOR_FCIMPL_FORWARD_KERNEL_H_ */
\ No newline at end of file
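
// A minimal usage sketch of the templated cublasGemm wrapper declared above
// (illustrative only, not part of the patch; the explicit specializations
// live in src/operator/FCImpl_CUDA_kernels.cu further down). Assumes device
// buffers dA, dB, dC are already allocated. cuBLAS is column-major, so this
// computes C (m x n) = alpha * A (m x k) * B (k x n) + beta * C.
#include <cublas_v2.h>
#include "aidge/backend/cuda/operator/FCImpl_CUDA_kernels.hpp"

cublasStatus_t exampleGemm(cublasHandle_t handle,
                           const float* dA, const float* dB, float* dC,
                           int m, int n, int k) {
    const float alpha = 1.0f;
    const float beta  = 0.0f;
    return Aidge::cublasGemm<float>(handle,
                                    CUBLAS_OP_N, CUBLAS_OP_N,
                                    m, n, k,
                                    &alpha, dA, m, dB, k,
                                    &beta, dC, m);
}
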
diff --git a/include/aidge/backend/cuda/utils/CudaUtils.hpp b/include/aidge/backend/cuda/utils/CudaUtils.hpp
index 2f66d0e778778400f0b7def345619d635cc37674..a505d331dc182b6e24857b0d1045282688fdf8d8 100644
--- a/include/aidge/backend/cuda/utils/CudaUtils.hpp
+++ b/include/aidge/backend/cuda/utils/CudaUtils.hpp
@@ -11,6 +11,8 @@
 #include <cuda.h>
 #include <cudnn.h>
 
+#include "aidge/data/half.hpp"
+
 #define CHECK_CUDNN_STATUS(status)                                             \
     do {                                                                       \
         const cudnnStatus_t e = (status);                                      \
@@ -62,6 +64,29 @@
 
 namespace Aidge {
 namespace Cuda {
+    // cuDNN scaling parameters are typically "alpha" and "beta".
+    // Their type must be "float" for HALF and FLOAT (default template)
+    // and "double" for DOUBLE (specialized template)
+    template <class T>
+    struct cudnn_scaling_type {
+        typedef float type;
+    };
+
+    template <>
+    struct cudnn_scaling_type<double> {
+        typedef double type;
+    };
+
+    template <class T>
+    struct cuda_type {
+        typedef T type;
+    };
+
+    template <>
+    struct cuda_type<half_float::half> {
+        typedef __half type;
+    };
+
     const char* cublasGetErrorString(cublasStatus_t error);
 
     // Enable Peer-to-Peer communications between devices
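
// Compile-time sanity check (illustrative, not part of the patch) showing how
// the two traits above are expected to resolve: cuDNN takes float alpha/beta
// for HALF and FLOAT tensors and double for DOUBLE tensors, while cuBLAS
// wants the native __half device type in place of the host-side
// half_float::half.
#include <type_traits>
#include <cuda_fp16.h>
#include "aidge/backend/cuda/utils/CudaUtils.hpp"

static_assert(std::is_same<Aidge::Cuda::cudnn_scaling_type<half_float::half>::type, float>::value, "half scales with float");
static_assert(std::is_same<Aidge::Cuda::cudnn_scaling_type<float>::type, float>::value, "float scales with float");
static_assert(std::is_same<Aidge::Cuda::cudnn_scaling_type<double>::type, double>::value, "double scales with double");
static_assert(std::is_same<Aidge::Cuda::cuda_type<half_float::half>::type, __half>::value, "host half maps to __half");
static_assert(std::is_same<Aidge::Cuda::cuda_type<float>::type, float>::value, "float maps to itself");
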
diff --git a/src/operator/AvgPoolingImpl.cpp b/src/operator/AvgPoolingImpl.cpp
index 4737e3ebf43622f81dbe8938c5dddcaaca94cb80..861533eced6112903ea288f09711f3a382db542c 100644
--- a/src/operator/AvgPoolingImpl.cpp
+++ b/src/operator/AvgPoolingImpl.cpp
@@ -10,17 +10,14 @@
  ********************************************************************************/
 
 #include <cassert>
-#include <chrono>  // std::chrono::milliseconds
-#include <numeric> // std::accumulate
-#include <thread>  // std::this_thread::sleep_for
 #include <vector>
 
-#include "aidge/utils/Types.h"
-#include "aidge/operator/AvgPooling.hpp"
-
 #include "aidge/backend/cuda/data/TensorImpl.hpp"
 #include "aidge/backend/cuda/operator/AvgPoolingImpl.hpp"
 #include "aidge/backend/cuda/utils/CudaContext.hpp"
+#include "aidge/backend/cuda/utils/CudaUtils.hpp"
+#include "aidge/operator/AvgPooling.hpp"
+#include "aidge/utils/Types.h"
 
 template <Aidge::DimIdx_t DIM>
 void Aidge::AvgPoolingImpl_cuda<DIM>::forward() {
@@ -49,11 +46,18 @@ void Aidge::AvgPoolingImpl_cuda<DIM>::forward() {
                                         &strides[0]));
     }
 
-    if (std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType() == DataType::Float64) {
-        forward_<double>(input);
-    }
-    else {
-        forward_<float>(input);
+    switch(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()) {
+        case DataType::Float64:
+            forward_<double>(input);
+            break;
+        case DataType::Float32:
+            forward_<float>(input);
+            break;
+        case DataType::Float16:
+            forward_<half>(input);
+            break;
+        default:
+            AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by Backend Cuda");
     }
 }
 
@@ -61,8 +65,8 @@ template <Aidge::DimIdx_t DIM>
 template <class T>
 void Aidge::AvgPoolingImpl_cuda<DIM>::forward_(const Tensor& input) {
     const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
-    const T alpha = 1.0f;
-    const T beta = 0.0f;
+    const typename Cuda::cudnn_scaling_type<T>::type alpha = 1.0f;
+    const typename Cuda::cudnn_scaling_type<T>::type beta = 0.0f;
     CHECK_CUDNN_STATUS(
         cudnnPoolingForward(
             CudaContext::cudnnHandle(),
@@ -83,6 +87,5 @@ Aidge::AvgPoolingImpl_cuda<DIM>::~AvgPoolingImpl_cuda() {
         cudnnDestroyPoolingDescriptor(mAvgPoolingDesc);
 }
 
-
 // Template declarations
 template class Aidge::AvgPoolingImpl_cuda<2>;
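
// The dispatch idiom this patch applies to every operator, sketched in
// isolation (a hypothetical free-function variant; the patch itself keeps the
// switch inline in each forward()). The notable behavioural change: the old
// if/else silently ran any non-double type through the float kernel, while
// the switch now routes Float16 to a half kernel and rejects unsupported
// types explicitly.
#include <utility>      // std::forward
#include <cuda_fp16.h>  // half

template <class Impl, class... Args>
void dispatchByOutputType(Aidge::DataType dt, Impl& impl, Args&&... args) {
    switch (dt) {
        case Aidge::DataType::Float64:
            impl.template forward_<double>(std::forward<Args>(args)...);
            break;
        case Aidge::DataType::Float32:
            impl.template forward_<float>(std::forward<Args>(args)...);
            break;
        case Aidge::DataType::Float16:
            impl.template forward_<half>(std::forward<Args>(args)...);
            break;
        default:
            AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by Backend Cuda");
    }
}
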
diff --git a/src/operator/ConvImpl.cpp b/src/operator/ConvImpl.cpp
index 19ce56bcb99f60e08427f8d9b110637c90582adf..b64b2e3d470fd8813938a800293d55b7ba9a4076 100644
--- a/src/operator/ConvImpl.cpp
+++ b/src/operator/ConvImpl.cpp
@@ -10,17 +10,13 @@
  ********************************************************************************/
 
 #include <cassert>
-#include <chrono>  // std::chrono::milliseconds
-#include <numeric> // std::accumulate
-#include <thread>  // std::this_thread::sleep_for
 #include <vector>
 
-#include "aidge/utils/Types.h"
-#include "aidge/operator/Conv.hpp"
-
 #include "aidge/backend/cuda/data/TensorImpl.hpp"
 #include "aidge/backend/cuda/operator/ConvImpl.hpp"
-#include "aidge/backend/cuda/utils/CudaContext.hpp"
+#include "aidge/backend/cuda/utils/CudaUtils.hpp"
+#include "aidge/operator/Conv.hpp"
+#include "aidge/utils/Types.h"
 
 template <Aidge::DimIdx_t DIM>
 void Aidge::ConvImpl_cuda<DIM>::forward() {
@@ -106,11 +102,18 @@ void Aidge::ConvImpl_cuda<DIM>::forward() {
     // Do the actual forward computation
     // Template is only for scaling parameters, which are always in float
     // except when the convolution is performed in double precision.
-    if (op.getOutput(0)->dataType() == DataType::Float64) {
-        forward_<double>(input0, input1, input2);
-    }
-    else {
-        forward_<float>(input0, input1, input2);
+    switch(op.getOutput(0)->dataType()) {
+        case DataType::Float64:
+            forward_<double>(input0, input1, input2);
+            break;
+        case DataType::Float32:
+            forward_<float>(input0, input1, input2);
+            break;
+        case DataType::Float16:
+            forward_<half>(input0, input1, input2);
+            break;
+        default:
+            AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by Backend Cuda");
     }
 }
 
@@ -118,9 +121,8 @@ template <Aidge::DimIdx_t DIM>
 template <class T>
 void Aidge::ConvImpl_cuda<DIM>::forward_(const Tensor& input0, const Tensor& input1, const Tensor& input2) {
     const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
-    const T alpha = 1.0f;
-    const T beta = 0.0f;
-
+    const typename Cuda::cudnn_scaling_type<T>::type alpha = 1.0f;
+    const typename Cuda::cudnn_scaling_type<T>::type beta = 0.0f;
     CHECK_CUDNN_STATUS(cudnnConvolutionForward(CudaContext::cudnnHandle(),
         &alpha,
         std::dynamic_pointer_cast<TensorImpl_cuda_>(input0.getImpl())->getCudnnTensorDesc(input0),
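
// What the scaling trait buys at a call site like cudnnConvolutionForward
// (an illustrative check, not part of the patch): even when the kernel is
// instantiated with T = half, cuDNN still expects float alpha/beta, and only
// the double instantiation switches the scaling type to double.
#include <type_traits>
#include "aidge/backend/cuda/utils/CudaUtils.hpp"

template <class T>
void checkScalingType() {
    using ScaleT = typename Aidge::Cuda::cudnn_scaling_type<T>::type;
    static_assert(std::is_same<ScaleT,
                      std::conditional_t<std::is_same<T, double>::value, double, float>>::value,
                  "alpha/beta are float except for double-precision kernels");
}
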
diff --git a/src/operator/FCImpl.cpp b/src/operator/FCImpl.cpp
index bb107a5c8044bd40c874493f275e20d02b7298d2..a8f8da8e114baaae3f7d09146cdd5b664150f260 100644
--- a/src/operator/FCImpl.cpp
+++ b/src/operator/FCImpl.cpp
@@ -15,15 +15,13 @@
 #include <thread>  // std::this_thread::sleep_for
 #include <vector>
 
-#include "aidge/utils/Types.h"
-#include "aidge/operator/FC.hpp"
-
 #include "aidge/backend/cuda/data/TensorImpl.hpp"
 #include "aidge/backend/cuda/operator/FCImpl.hpp"
 #include "aidge/backend/cuda/operator/FCImpl_CUDA_kernels.hpp"
 #include "aidge/backend/cuda/utils/CudaContext.hpp"
-
 #include "aidge/backend/cuda/utils/CudaUtils.hpp"
+#include "aidge/operator/FC.hpp"
+#include "aidge/utils/Types.h"
 
 void Aidge::FCImpl_cuda::forward() {
     assert(mOp.getRawInput(0) && "missing input #0");
@@ -39,11 +37,18 @@ void Aidge::FCImpl_cuda::forward() {
     bool noBias = fcOp.template getAttr<FCAttr::NoBias>();
     std::size_t outChannels = static_cast<std::size_t>(fcOp.template getAttr<FCAttr::OutChannels>());
 
-    if (std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType() == DataType::Float64) {
-        forward_<double>(input0, input1, input2, noBias, outChannels);
-    }
-    else {
-        forward_<float>(input0, input1, input2, noBias, outChannels);
+    switch(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()) {
+        case DataType::Float64:
+            forward_<double>(input0, input1, input2, noBias, outChannels);
+            break;
+        case DataType::Float32:
+            forward_<float>(input0, input1, input2, noBias, outChannels);
+            break;
+        case DataType::Float16:
+            forward_<half>(input0, input1, input2, noBias, outChannels);
+            break;
+        default:
+            AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by Backend Cuda");
     }
 }
 
@@ -61,26 +66,26 @@ void Aidge::FCImpl_cuda::forward_(const Tensor& input0, const Tensor& input1, co
     int lda = k;
     int ldb = k;
     int ldc = n;
-    const T alpha = T(1.0);
-    const T beta = T(0.0);
+    const T alpha = 1.0f;
+    const T beta = 0.0f;
     CHECK_CUBLAS_STATUS(cublasGemm(CudaContext::cublasHandle(),
                                     CUBLAS_OP_T,
                                     CUBLAS_OP_N,
                                     n,
                                     m,
                                     k,
-                                    &alpha,
+                                    reinterpret_cast<const typename Cuda::cuda_type<T>::type*>(&alpha),
                                     weights,
                                     ldb,
                                     input, 
                                     lda,
-                                    &beta,
+                                    reinterpret_cast<const typename Cuda::cuda_type<T>::type*>(&beta),
                                     output,
                                     ldc));
 
     if(!noBias){
         T* onesVector;
-        cudaMalloc((void**)&onesVector, m * sizeof(T));
+        CHECK_CUDA_STATUS(cudaMalloc((void**)&onesVector, m * sizeof(T)));
         // Fill the vector with ones
         std::vector<T> onesVec(m, T(1.0));
         CHECK_CUDA_STATUS(cudaMemcpy(onesVector,
@@ -95,12 +100,12 @@ void Aidge::FCImpl_cuda::forward_(const Tensor& input0, const Tensor& input1, co
                                        n,
                                        m,
                                        1,
-                                       &alpha,
+                                       reinterpret_cast<const typename Cuda::cuda_type<T>::type*>(&alpha),
                                        biases,
                                        n,
                                        onesVector,
                                        1,
-                                       &alpha,
+                                       reinterpret_cast<const typename Cuda::cuda_type<T>::type*>(&alpha),
                                        output,
                                        n));
 
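
// CPU reference (illustrative, not part of the patch) of what the second
// cublasGemm above computes: a rank-1 update output += onesVector (m x 1) *
// biases (1 x n), which broadcasts the bias row over the m batch rows. Both
// scaling factors passed to cuBLAS are alpha = 1, so the product is
// accumulated into the existing output.
#include <cstddef>
#include <vector>

void addBiasReference(std::vector<float>& output,        // m x n, row-major
                      const std::vector<float>& biases,  // n
                      std::size_t m, std::size_t n) {
    for (std::size_t i = 0; i < m; ++i)
        for (std::size_t j = 0; j < n; ++j)
            output[i * n + j] += biases[j];
}
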
diff --git a/src/operator/FCImpl_CUDA_kernels.cu b/src/operator/FCImpl_CUDA_kernels.cu
index a30519ebe4bf262b87bb9e07342f18525df2e8f4..5139ac1d7edf61cf347870e6add2870b2792a0e5 100644
--- a/src/operator/FCImpl_CUDA_kernels.cu
+++ b/src/operator/FCImpl_CUDA_kernels.cu
@@ -16,61 +16,61 @@ namespace Aidge{
 
 template <>
 cublasStatus_t cublasGemm<__half>(cublasHandle_t handle,
-                           cublasOperation_t transa, cublasOperation_t transb,
-                           int m, int n, int k,
-                           const __half *alpha,
-                           const __half *A, int lda,
-                           const __half *B, int ldb,
-                           const __half *beta,
-                           __half *C, int ldc)
+                                  cublasOperation_t transa, cublasOperation_t transb,
+                                  int m, int n, int k,
+                                  const __half *alpha,
+                                  const __half *A, int lda,
+                                  const __half *B, int ldb,
+                                  const __half *beta,
+                                  __half *C, int ldc)
 {
     return cublasHgemm(handle,
-        transa, transb,
-        m, n, k,
-        alpha,
-        A, lda,
-        B, ldb,
-        beta,
-        C, ldc);
+                       transa, transb,
+                       m, n, k,
+                       alpha,
+                       A, lda,
+                       B, ldb,
+                       beta,
+                       C, ldc);
 }
 
 template <>
 cublasStatus_t cublasGemm<float>(cublasHandle_t handle,
-                           cublasOperation_t transa, cublasOperation_t transb,
-                           int m, int n, int k,
-                           const float *alpha,
-                           const float *A, int lda,
-                           const float *B, int ldb,
-                           const float *beta,
-                           float *C, int ldc)
+                                 cublasOperation_t transa, cublasOperation_t transb,
+                                 int m, int n, int k,
+                                 const float *alpha,
+                                 const float *A, int lda,
+                                 const float *B, int ldb,
+                                 const float *beta,
+                                 float *C, int ldc)
 {
     return cublasSgemm(handle,
-        transa, transb,
-        m, n, k,
-        alpha,
-        A, lda,
-        B, ldb,
-        beta,
-        C, ldc);
+                       transa, transb,
+                       m, n, k,
+                       alpha,
+                       A, lda,
+                       B, ldb,
+                       beta,
+                       C, ldc);
 }
 
 template <>
 cublasStatus_t cublasGemm<double>(cublasHandle_t handle,
-                           cublasOperation_t transa, cublasOperation_t transb,
-                           int m, int n, int k,
-                           const double *alpha,
-                           const double *A, int lda,
-                           const double *B, int ldb,
-                           const double *beta,
-                           double *C, int ldc)
+                                  cublasOperation_t transa, cublasOperation_t transb,
+                                  int m, int n, int k,
+                                  const double *alpha,
+                                  const double *A, int lda,
+                                  const double *B, int ldb,
+                                  const double *beta,
+                                  double *C, int ldc)
 {
     return cublasDgemm(handle,
-        transa, transb,
-        m, n, k,
-        alpha,
-        A, lda,
-        B, ldb,
-        beta,
-        C, ldc);
+                       transa, transb,
+                       m, n, k,
+                       alpha,
+                       A, lda,
+                       B, ldb,
+                       beta,
+                       C, ldc);
 }
 }
\ No newline at end of file
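
// Illustrative only: a tiny fp16 GEMM routed through the __half
// specialization above. __float2half converts the host-side scaling factors;
// cublasHgemm requires a GPU with native half-precision arithmetic.
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include "aidge/backend/cuda/operator/FCImpl_CUDA_kernels.hpp"

cublasStatus_t tinyHalfGemm(cublasHandle_t handle,
                            const __half* dA, const __half* dB, __half* dC) {
    const __half alpha = __float2half(1.0f);
    const __half beta  = __float2half(0.0f);
    return Aidge::cublasGemm<__half>(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                                     2, 2, 2,  // m, n, k for a 2x2 product
                                     &alpha, dA, 2, dB, 2,
                                     &beta, dC, 2);
}
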
diff --git a/src/operator/MaxPoolingImpl.cpp b/src/operator/MaxPoolingImpl.cpp
index 9304160d0d509014785820745f61187fdf13c17e..19a567fe4b273e821f95a16989d0a09bd510fe07 100644
--- a/src/operator/MaxPoolingImpl.cpp
+++ b/src/operator/MaxPoolingImpl.cpp
@@ -10,17 +10,14 @@
  ********************************************************************************/
 
 #include <cassert>
-#include <chrono>  // std::chrono::milliseconds
-#include <numeric> // std::accumulate
-#include <thread>  // std::this_thread::sleep_for
 #include <vector>
 
-#include "aidge/utils/Types.h"
-#include "aidge/operator/MaxPooling.hpp"
-
 #include "aidge/backend/cuda/data/TensorImpl.hpp"
 #include "aidge/backend/cuda/operator/MaxPoolingImpl.hpp"
 #include "aidge/backend/cuda/utils/CudaContext.hpp"
+#include "aidge/backend/cuda/utils/CudaUtils.hpp"
+#include "aidge/operator/MaxPooling.hpp"
+#include "aidge/utils/Types.h"
 
 template <Aidge::DimIdx_t DIM>
 void Aidge::MaxPoolingImpl_cuda<DIM>::forward() {
@@ -49,11 +46,18 @@ void Aidge::MaxPoolingImpl_cuda<DIM>::forward() {
                                         &strides[0]));
     }
 
-    if (std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType() == DataType::Float64) {
-        forward_<double>(input);
-    }
-    else {
-        forward_<float>(input);
+    switch(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()) {
+        case DataType::Float64:
+            forward_<double>(input);
+            break;
+        case DataType::Float32:
+            forward_<float>(input);
+            break;
+        case DataType::Float16:
+            forward_<half>(input);
+            break;
+        default:
+            AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by Backend Cuda");
     }
 }
 
@@ -61,8 +65,8 @@ template <Aidge::DimIdx_t DIM>
 template <class T>
 void Aidge::MaxPoolingImpl_cuda<DIM>::forward_(const Tensor& input) {
     const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
-    const T alpha = 1.0f;
-    const T beta = 0.0f;
+    const typename Cuda::cudnn_scaling_type<T>::type alpha = 1.0f;
+    const typename Cuda::cudnn_scaling_type<T>::type beta = 0.0f;
     CHECK_CUDNN_STATUS(
         cudnnPoolingForward(
             CudaContext::cudnnHandle(),
diff --git a/src/operator/ReLUImpl.cpp b/src/operator/ReLUImpl.cpp
index ed2e5d4a0ed14c68ff5329932e938876165e92e3..c880184bc51fce65710d02d9a483e43de6184d89 100644
--- a/src/operator/ReLUImpl.cpp
+++ b/src/operator/ReLUImpl.cpp
@@ -10,17 +10,14 @@
  ********************************************************************************/
 
 #include <cassert>
-#include <chrono>  // std::chrono::milliseconds
-#include <numeric> // std::accumulate
-#include <thread>  // std::this_thread::sleep_for
 #include <vector>
 
-#include "aidge/utils/Types.h"
-#include "aidge/operator/ReLU.hpp"
-
 #include "aidge/backend/cuda/data/TensorImpl.hpp"
 #include "aidge/backend/cuda/operator/ReLUImpl.hpp"
 #include "aidge/backend/cuda/utils/CudaContext.hpp"
+#include "aidge/backend/cuda/utils/CudaUtils.hpp"
+#include "aidge/operator/ReLU.hpp"
+#include "aidge/utils/Types.h"
 
 void Aidge::ReLUImpl_cuda::forward() {
     const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
@@ -52,8 +49,8 @@ void Aidge::ReLUImpl_cuda::forward() {
 template <class T>
 void Aidge::ReLUImpl_cuda::forward_(const Tensor& input) {
     const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
-    const T alpha = 1.0f;
-    const T beta = 0.0f;
+    const typename Cuda::cudnn_scaling_type<T>::type alpha = 1.0f;
+    const typename Cuda::cudnn_scaling_type<T>::type beta = 0.0f;
     CHECK_CUDNN_STATUS(
         cudnnActivationForward(CudaContext::cudnnHandle(),
                                mReLUDesc,
diff --git a/unit_tests/Test_AvgPoolingImpl.cpp b/unit_tests/Test_AvgPoolingImpl.cpp
index def2e93f4105eb107c30f7cce3a2a2038da12d58..9b7f898ecb1a321713dc114d6a03d057a810271f 100644
--- a/unit_tests/Test_AvgPoolingImpl.cpp
+++ b/unit_tests/Test_AvgPoolingImpl.cpp
@@ -12,9 +12,11 @@
 #include <array>
 
 #include <catch2/catch_test_macros.hpp>
+#include <cuda_fp16.h>
 
 #include "Test_cuda.hpp"
 
+#include "aidge/data/half.hpp"
 #include "aidge/data/Tensor.hpp"
 
 #include "aidge/backend/cpu.hpp"
@@ -122,4 +124,38 @@ TEST_CASE("[gpu/operator] AvgPooling(forward)", "[AvgPooling][GPU]") {
 
         delete[] computedOutput;
     }
+
+    SECTION("half") {
+        std::shared_ptr<Tensor> myInput2 = std::make_shared<Tensor>(Array4D<half_float::half,1,1,3,3> { //NCHW
+        {
+            {
+                {{half_float::half(0.3745), half_float::half(0.9507), half_float::half(0.7320)},
+                 {half_float::half(0.5987), half_float::half(0.1560), half_float::half(0.1560)},
+                 {half_float::half(0.0581), half_float::half(0.8662), half_float::half(0.6011)}}
+            }
+        }
+        });
+        myInput2->setBackend("cuda");
+
+        std::shared_ptr<Node> myAvgPool = AvgPooling({3,3}, "mycdw", {3,3});
+        auto op = std::static_pointer_cast<OperatorTensor>(myAvgPool -> getOperator());
+        std::shared_ptr<Tensor> myOutput = std::make_shared<Tensor>(Array4D<half_float::half,1,1,1,1> {
+            {{{{(half_float::half(0.3745) + half_float::half(0.9507) + half_float::half(0.7320) + half_float::half(0.5987) + half_float::half(0.1560) + half_float::half(0.1560) + half_float::half(0.0581) + half_float::half(0.8662) + half_float::half(0.6011))/half_float::half(9.0)}}}}
+        });
+        op->associateInput(0,myInput2);
+        op->setDataType(DataType::Float16);
+        op->setBackend("cuda");
+        op->computeOutputDims();
+        myAvgPool->forward();
+
+        half* computedOutput   = new half[myOutput->size()]();
+        cudaMemcpy(computedOutput, op->getOutput(0)->getImpl()->rawPtr(), sizeof(half) * myOutput->size(), cudaMemcpyDeviceToHost);
+
+        for(std::size_t i = 0; i < myOutput->size(); i++){
+            const half_float::half targetOutput = *(static_cast<half_float::half*>(myOutput->getImpl()->rawPtr()) + i);
+            REQUIRE(fabs(computedOutput[i] - targetOutput) < 1e-6);
+        }
+
+        delete[] computedOutput;
+    }
 }
\ No newline at end of file
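
// A note on the tolerance above (illustrative helper, not part of the patch):
// binary16 has epsilon 2^-10 ~ 9.77e-4, so an absolute bound of 1e-6 only
// holds when the device result and the host half_float reference round
// identically. A relative bound scaled to half precision is more robust:
#include <algorithm>
#include <cmath>
#include "aidge/data/half.hpp"

inline bool approxEqualHalf(float computed, half_float::half target) {
    const float eps = 9.77e-4f;  // ~ std::numeric_limits<half_float::half>::epsilon()
    return std::fabs(computed - float(target)) <= eps * std::max(1.0f, std::fabs(float(target)));
}
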