diff --git a/include/aidge/backend/cuda/operator/ILayerNormImpl.hpp b/include/aidge/backend/cuda/operator/ILayerNormImpl.hpp
index 2c0c2791c8e3a4531dab32fe657f2663914a0fe5..30fcd84b1574d9f8efa654fa43c727d513ce55d3 100644
--- a/include/aidge/backend/cuda/operator/ILayerNormImpl.hpp
+++ b/include/aidge/backend/cuda/operator/ILayerNormImpl.hpp
@@ -1,11 +1,13 @@
 /********************************************************************************
- * Copyright (c) 2023 CEA-List
+ * Copyright (c) 2024 Thales
  *
  * This program and the accompanying materials are made available under the
  * terms of the Eclipse Public License 2.0 which is available at
  * http://www.eclipse.org/legal/epl-2.0.
  *
  * SPDX-License-Identifier: EPL-2.0
+ * Author: Lucas RAKOTOARIVONY, Thales Research & Technology France
+ * Date: 10.09.2024
  *
  ********************************************************************************/
 
@@ -50,7 +52,7 @@ private:
 };
 
 namespace {
-// add cuda backend to ShiftMax_Op implementation registry
+// add cuda backend to ILayerNorm_Op implementation registry
 static Registrar<ILayerNorm_Op> registrarILayerNormImpl_cuda("cuda", Aidge::ILayerNormImpl_cuda::create);
 }  // namespace
 }  // namespace Aidge
diff --git a/include/aidge/backend/cuda/operator/ILayerNormImpl_CUDA_kernels.hpp b/include/aidge/backend/cuda/operator/ILayerNormImpl_CUDA_kernels.hpp
index 8f40d2a34e7ee57e8a13cab97d5d3487f2a833f7..aa54029ea29bc46809f227038a1a23d91bc161ee 100644
--- a/include/aidge/backend/cuda/operator/ILayerNormImpl_CUDA_kernels.hpp
+++ b/include/aidge/backend/cuda/operator/ILayerNormImpl_CUDA_kernels.hpp
@@ -1,11 +1,13 @@
 /********************************************************************************
- * Copyright (c) 2023 CEA-List
+ * Copyright (c) 2024 Thales
  *
  * This program and the accompanying materials are made available under the
  * terms of the Eclipse Public License 2.0 which is available at
  * http://www.eclipse.org/legal/epl-2.0.
  *
  * SPDX-License-Identifier: EPL-2.0
+ * Author: Lucas RAKOTOARIVONY, Thales Research & Technology France
+ * Date: 10.09.2024
  *
  ********************************************************************************/
 
@@ -23,17 +25,67 @@
 
 namespace Aidge {
 
+/**
+    * @brief CUDA kernel computing the forward pass of ILayerNorm
+    * @param input: Input tensor
+    * @param SF: Scaling factor of the input tensor
+    * @param dims: Dimensions of the input tensor
+    * @param quantized_tensor: Quantized output tensor
+    * @param square_tensor: Scratch tensor used for intermediate computation
+    * @param weight: Weight of the ILayerNorm layer
+    * @param bias: Bias of the ILayerNorm layer
+    * @param new_SF: Scaling factor of the output, which can be used to dequantize it
+*/
 template <class T>
-void ILayerNormLaunchKernel(const T* input, T* output, double SF, const T* weight_raw, const T* bias_raw, size_t size, std::vector<long unsigned int> dims_input);
+__global__ void ILayerNormforward_(T* input, double SF, int* dims, int* quantized_tensor, long long int* square_tensor, T* weight, T* biase, double new_SF);
 
+/**
+    * @brief Wrapper that launches the ILayerNormforward_ kernel
+    * @note The output corresponds to the dequantized tensor; to obtain the quantized tensor, copy quantized_tensor rather than input_cuda_tensor
+    * @param input: Input tensor
+    * @param output: Output tensor (dequantized)
+    * @param SF: Scaling factor of the input tensor
+    * @param weight_raw: Weight of the ILayerNorm layer
+    * @param bias_raw: Bias of the ILayerNorm layer
+    * @param size: Number of elements in the input tensor
+    * @param dims_input: Dimensions of the input tensor
+*/
 template <class T>
-void ILayerNormBackPropagation(const T* InputTensor, const T* Grad, const T* Normalised_Tensor,const T* mean,const T* var, const T* weight, const T* bias, T* grad_input, T* grad_weight, T* grad_bias, size_t size, std::vector<long unsigned int> dims_input);
+void ILayerNormforward(const T* input, T* output, double SF, const T* weight_raw, const T* bias_raw, size_t size, std::vector<long unsigned int> dims_input);
 
+/**
+    * @brief CUDA kernel computing the backward pass of ILayerNorm
+    * @param output_grad: Gradient of the output tensor
+    * @param input_tensor: Input tensor
+    * @param output_tensor: Output tensor produced by the forward pass
+    * @param mean: Arithmetic mean of the input tensor
+    * @param var: Arithmetic variance of the input tensor
+    * @param weight: Weight of the ILayerNorm layer
+    * @param bias: Bias of the ILayerNorm layer
+    * @param input_grad: Gradient of the input tensor
+    * @param weight_grad: Gradient of the ILayerNorm weight
+    * @param bias_grad: Gradient of the ILayerNorm bias
+    * @param size: Number of elements in the input tensor
+*/
 template <class T>
-__global__ void ILayerNormKernel(T* input, double SF, int* dims, int* quantized_tensor,long long int* square_tensor, T* weight, T* biase, double new_SF) ;
+__global__ void ILayerNormbackward_(T* output_grad, T* input_tensor, T* output_tensor, T* mean, T* var, T* weight, T* bias, T* input_grad, T* weight_grad, T* bias_grad, int size);
 
+/**
+    * @brief Wrapper that launches the ILayerNormbackward_ kernel
+    * @param input_tensor: Input tensor
+    * @param output_grad: Gradient of the output tensor
+    * @param output_tensor: Output tensor produced by the forward pass
+    * @param mean: Arithmetic mean of the input tensor
+    * @param var: Arithmetic variance of the input tensor
+    * @param weight: Weight of the ILayerNorm layer
+    * @param bias: Bias of the ILayerNorm layer
+    * @param input_grad: Gradient of the input tensor
+    * @param weight_grad: Gradient of the ILayerNorm weight
+    * @param bias_grad: Gradient of the ILayerNorm bias
+    * @param size: Number of elements in the input tensor
+*/
 template <class T>
-__global__ void ILayerNormBackward(T* d_output, T* x, T* norm_x, T* mean, T* var, T* weight, T* bias, T* grad_x, T* grad_weight, T* grad_bias,int size);
+void ILayerNormbackward(const T* input_tensor, const T* output_grad, const T* output_tensor, const T* mean, const T* var, const T* weight, const T* bias, T* input_grad, T* weight_grad, T* bias_grad, size_t size);
 
 }
 
diff --git a/include/aidge/backend/cuda/operator/ShiftGELUImpl_CUDA_kernels.hpp b/include/aidge/backend/cuda/operator/ShiftGELUImpl_CUDA_kernels.hpp
index 7d6c5920ee76d1b18bf5db077ed57c6c6c8a4d68..ab92ea91c6d6a9a08f7d0423df0f4a5e45b0df66 100644
--- a/include/aidge/backend/cuda/operator/ShiftGELUImpl_CUDA_kernels.hpp
+++ b/include/aidge/backend/cuda/operator/ShiftGELUImpl_CUDA_kernels.hpp
@@ -25,17 +25,53 @@
 
 namespace Aidge {
 
+/**
+    * @brief CUDA kernel computing the forward pass of ShiftGELU
+    * @param input: Input tensor
+    * @param quantized_tensor: Quantized output tensor
+    * @param GELUtensor: Pointer to an empty memory block allocated on the GPU (scratch space)
+    * @param SumTensor: Pointer to an empty memory block allocated on the GPU (scratch space)
+    * @param dims: Dimensions of the input tensor
+    * @param SF: Scaling factor of the input tensor
+    * @param N: Arithmetic precision, set to 15 as in I-ViT (larger N gives a more precise approximation but requires more bits)
+    * @param output_bits: Desired bit precision (e.g. 8 for int8)
+*/
 template <class T>
-void ShiftGELULaunchKernel(const T* input, T* output, double SF,int N, int output_bits, size_t size, std::vector<long unsigned int> dims_input);
+__global__ void ShiftGELUforward_(T* input, int* quantized_tensor, int* GELUtensor, int* SumTensor, int* dims, double SF, int N, int output_bits);
 
+/**
+    * @brief Wrapper that launches the ShiftGELUforward_ kernel
+    * @note The output corresponds to the dequantized tensor; to obtain the quantized tensor, copy quantized_tensor rather than input_cuda_tensor
+    * @param input: Input tensor
+    * @param output: Output tensor (dequantized)
+    * @param SF: Scaling factor of the input tensor
+    * @param N: Arithmetic precision, set to 15 as in I-ViT (larger N gives a more precise approximation but requires more bits)
+    * @param output_bits: Desired bit precision (e.g. 8 for int8)
+    * @param size: Number of elements in the input tensor
+    * @param dims_input: Dimensions of the input tensor
+*/
 template <class T>
-void ShiftGELUBackPropagation(const T* InputTensor, const T* Grad, T* output, size_t size);
+void ShiftGELUforward(const T* input, T* output, double SF, int N, int output_bits, size_t size, std::vector<long unsigned int> dims_input);
 
+/**
+    * @brief CUDA kernel computing the backward pass of ShiftGELU
+    * @param input_grad: Gradient of the input tensor (the result to compute)
+    * @param output_tensor: Output tensor produced by the forward pass
+    * @param output_grad: Gradient of the output tensor
+    * @param size: Number of elements in the input tensor
+*/
 template <class T>
-__global__ void ShiftGELUWholeKernel(T* input,int* quantized_tensor,int* GELUtensor,int* SumTensor, int* dims, double SF, int N, int output_bits);
+__global__ void ShiftGELUbackward_(T* input_grad, const T* output_tensor, const T* output_grad, int size);
 
+/**
+    * @brief Wrapper that launches the ShiftGELUbackward_ kernel
+    * @param output_tensor: Output tensor produced by the forward pass
+    * @param output_grad: Gradient of the output tensor
+    * @param input_grad: Gradient of the input tensor (the result to compute)
+    * @param size: Number of elements in the input tensor
+*/
 template <class T>
-__global__ void gelu_backward(T* grad_input, const T* GELU_output, const T* grad_output, int size);
+void ShiftGELUbackward(const T* output_tensor, const T* output_grad, T* input_grad, size_t size);
 
 }
 
diff --git a/include/aidge/backend/cuda/operator/ShiftMaxImpl_CUDA_kernels.hpp b/include/aidge/backend/cuda/operator/ShiftMaxImpl_CUDA_kernels.hpp
index e1e7ce74226fc7f810792c402ae6627bb1c02648..a6eea419f0787e35aef9a41e41e0da6d883baca0 100644
--- a/include/aidge/backend/cuda/operator/ShiftMaxImpl_CUDA_kernels.hpp
+++ b/include/aidge/backend/cuda/operator/ShiftMaxImpl_CUDA_kernels.hpp
@@ -25,17 +25,54 @@
 
 namespace Aidge {
 
+/**
+    * @brief CUDA kernel computing the forward pass of ShiftMax
+    * @param input: Input tensor
+    * @param quantized_tensor: Quantized output tensor
+    * @param factor: Pointer to an empty memory block allocated on the GPU (scratch space)
+    * @param dims: Dimensions of the input tensor
+    * @param SF: Scaling factor of the input tensor
+    * @param N: Arithmetic precision, set to 15 as in I-ViT (larger N gives a more precise approximation but requires more bits)
+    * @param output_bits: Desired bit precision (e.g. 8 for int8)
+    * @param new_SF: Scaling factor of the output, which can be used to dequantize it
+*/
 template <class T>
-void ShiftMaxLaunchKernel(const T* input, T* output, double SF,int N, int output_bits, size_t size, std::vector<long unsigned int> dims_input);
+__global__ void ShiftMaxforward_(T* input, int* quantized_tensor, int* factor, int* dims, double SF, int N, int output_bits, double new_SF);
 
+/**
+    * @brief Wrapper that launches the ShiftMaxforward_ kernel
+    * @note The output corresponds to the dequantized tensor; to obtain the quantized tensor, copy quantized_tensor rather than input_cuda_tensor
+    * @param input: Input tensor
+    * @param output: Output tensor (dequantized)
+    * @param SF: Scaling factor of the input tensor
+    * @param N: Arithmetic precision, set to 15 as in I-ViT (larger N gives a more precise approximation but requires more bits)
+    * @param output_bits: Desired bit precision (e.g. 8 for int8)
+    * @param size: Number of elements in the input tensor
+    * @param dims_input: Dimensions of the input tensor
+*/
 template <class T>
-void ShiftMaxBackPropagation(const T* InputTensor, const T* Grad, T* output, size_t size, std::vector<long unsigned int> dims_input) ;
+void ShiftMaxforward(const T* input, T* output, double SF, int N, int output_bits, size_t size, std::vector<long unsigned int> dims_input);
 
+/**
+    * @brief CUDA kernel computing the backward pass of ShiftMax
+    * @param input_grad: Gradient of the input tensor (the result to compute)
+    * @param output_tensor: Output tensor produced by the forward pass
+    * @param output_grad: Gradient of the output tensor
+    * @param dims: Dimensions of the input tensor
+*/
 template <class T>
-__global__ void ShiftMaxWholeKernel(T* input,int* quantized_tensor,int* factor, int* dims, double SF, int N, int output_bits,double new_SF) ;
+__global__ void ShiftMaxbackward_(T* input_grad, const T* output_tensor, const T* output_grad, const int* dims);
 
+/**
+    * @brief Wrapper that launches the ShiftMaxbackward_ kernel
+    * @param output_tensor: Output tensor produced by the forward pass
+    * @param output_grad: Gradient of the output tensor
+    * @param input_grad: Gradient of the input tensor (the result to compute)
+    * @param size: Number of elements in the input tensor
+    * @param dims: Dimensions of the input tensor
+*/
 template <class T>
-__global__ void shiftmax_backward(T* grad_input, const T* shiftmax_output, const T* grad_output, const int* dims) ;
+void ShiftMaxbackward(const T* output_tensor, const T* output_grad, T* input_grad, size_t size, std::vector<long unsigned int> dims);
 
 }
 
diff --git a/src/operator/ILayerNormImpl.cpp b/src/operator/ILayerNormImpl.cpp
index cbcaaf8f2a7ffd8672a507b122d04c4a86b0a0b0..47dd1d5d1a3f127c9e08788f605796020a7814a7 100644
--- a/src/operator/ILayerNormImpl.cpp
+++ b/src/operator/ILayerNormImpl.cpp
@@ -1,11 +1,13 @@
 /********************************************************************************
- * Copyright (c) 2023 CEA-List
+ * Copyright (c) 2024 Thales
  *
  * This program and the accompanying materials are made available under the
  * terms of the Eclipse Public License 2.0 which is available at
  * http://www.eclipse.org/legal/epl-2.0.
  *
  * SPDX-License-Identifier: EPL-2.0
+ * Author: Lucas RAKOTOARIVONY, Thales Research & Technology France
+ * Date: 10.09.2024
  *
  ********************************************************************************/
 
@@ -39,22 +41,13 @@ void Aidge::ILayerNormImpl_cuda::forward() {
     const auto& input1 = op.getInput(1)->refCastFrom(mInput1Fallback, *op.getOutput(0));
     const auto& input2 = op.getInput(2)->refCastFrom(mInput2Fallback, *op.getOutput(0));
 
-    //forward_<half>(input);
-    //forward_<float>(input0, input1, input2);
-
-    //should template and changing type
     switch(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()) {
         case DataType::Float64:
-            //forward_<float>(input);
             forward_<double>(input0, input1, input2);
             break;
         case DataType::Float32:
             forward_<float>(input0, input1, input2);
             break;
-        case DataType::Float16:
-            //forward_<half>(input);
-            forward_<float>(input0, input1, input2);
-            break;
         default:
             AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by Backend Cuda");
     }
@@ -68,13 +61,15 @@ void Aidge::ILayerNormImpl_cuda::forward_(const Tensor& input0, const Tensor& in
     const T * input_raw = static_cast<const T*>(input0.getImpl()->rawPtr());
     const T * weight = static_cast<const T*>(input1.getImpl()->rawPtr());
     const T * bias = static_cast<const T*>(input2.getImpl()->rawPtr());
+    T * output = static_cast<T*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr());
 
     int N = 15;
     int output_bits = 8;
-
     size_t size = input0.size();
     std::vector<DimSize_t> dims_input = input0.dims();
 
+    // TODO: compute the scaling factor more efficiently (a min/max reduction would avoid this full scan)
+
     double min = std::numeric_limits<double>::max();
     double max = std::numeric_limits<double>::min();
     for(std::size_t i = 0; i < dims_input[0]; i++) {
@@ -93,20 +88,14 @@ void Aidge::ILayerNormImpl_cuda::forward_(const Tensor& input0, const Tensor& in
             }     
         }
     }
-    
     double m = std::max(std::abs(min), std::abs(max));
-
-    // Calculate the normalization factor
     double normalization_factor = static_cast<double>(1 << (output_bits - 1)) - 1;
-
-    // Return the normalized maximum
-    double final_sf =  m / normalization_factor;
-    T * output = static_cast<T*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr());
+    double scaling_factor =  m / normalization_factor;
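+    // e.g. with output_bits = 8: normalization_factor = 2^7 - 1 = 127, so the largest
+    // absolute value m is mapped onto the int8 extreme 127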
     
-    double new_SF = 1/std::pow(2,2*output_bits-1); // Le nouveau scaling factor renvoyé par la fonction ILayerNorm, utilisé pour déquantifier le tenseur renvoyé
+    // The new scaling factor that could be used to dequantize the returned tensor (unused here)
+    // double new_SF = 1/std::pow(2,2*output_bits-1);
 
-    //ILayerNormLaunchKernel(input_raw, output, final_sf,N, output_bits, size, dims_input);
-    ILayerNormLaunchKernel(input_raw, output, final_sf, weight, bias, size, dims_input);
+    ILayerNormforward(input_raw, output, scaling_factor, weight, bias, size, dims_input);
 }
 
 void Aidge::ILayerNormImpl_cuda::backward() {
@@ -116,26 +105,22 @@ void Aidge::ILayerNormImpl_cuda::backward() {
 
     const auto& output_grad = op.getOutput(0)->grad()->refCastFrom(mOutputGradFallback, *op.getOutput(0)->grad());
 
-    backward_<float>(output_grad);
-    // Do the actual backward computation
-    // Template is only for scaling parameters, which are always in float
-    // excepted when the convolution is performed in double precision.
-    /*if (op.getInput(0)->grad()->dataType() == DataType::Float64) {
+    if (op.getInput(0)->grad()->dataType() == DataType::Float64) {
         backward_<double>(output_grad);
     }
     else {
         backward_<float>(output_grad);
-    }*/
+    }
 }
 
 template <class T>
 void Aidge::ILayerNormImpl_cuda::backward_(const Tensor& output_grad) {
     const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
-    const T * output = static_cast<const T*>(std::static_pointer_cast<Tensor>(op.getRawOutput(0))->getImpl()->rawPtr());
-
     size_t size = output_grad.size();
     std::vector<DimSize_t> dims_input = output_grad.dims();
 
+    const T * output = static_cast<const T*>(std::static_pointer_cast<Tensor>(op.getRawOutput(0))->getImpl()->rawPtr());
+
     T * input_grad = static_cast<T*>(op.getInput(0)->grad()->getImpl()->rawPtr());
     T * weight_grad = static_cast<T*>(op.getInput(1)->grad()->getImpl()->rawPtr());
     T * bias_grad = static_cast<T*>(op.getInput(2)->grad()->getImpl()->rawPtr());
@@ -144,9 +129,76 @@ void Aidge::ILayerNormImpl_cuda::backward_(const Tensor& output_grad) {
     const T * weight = static_cast<const T*>(op.getInput(1)->getImpl()->rawPtr());
     const T * bias = static_cast<const T*>(op.getInput(2)->getImpl()->rawPtr());
 
-    const T* mean_ = static_cast<const T*>(op.getInput(2)->getImpl()->rawPtr());
-    const T* var_ = static_cast<const T*>(op.getInput(1)->getImpl()->rawPtr());
+    // TODO: compute the mean and variance tensors more efficiently
+
+    std::vector<std::vector<std::vector<std::vector<T>>>> means(dims_input[0],
+        std::vector<std::vector<std::vector<T>>>(dims_input[1],
+            std::vector<std::vector<T>>(dims_input[2],
+                std::vector<T>(dims_input[3], 0.0f))));
+
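+    // the per-row mean over the last axis is broadcast along l so that the backward
+    // kernel can address mean/var with the same flat index as the other tensors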
+    for (std::size_t i = 0; i < dims_input[0]; i++) {
+        for (std::size_t j = 0; j < dims_input[1]; j++) {
+            for (std::size_t k = 0; k < dims_input[2]; k++) {
+                T sum = 0.0f;
+
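+                // assumption: Tensor::get<T>(coordIdx) is the element accessor for these coordinates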
+                for (std::size_t l = 0; l < dims_input[3]; l++) {
+                    std::vector<std::size_t> coordIdx = {i, j, k, l};
+                    sum += output_grad.get<T>(coordIdx);
+                }
+                for (std::size_t l = 0; l < dims_input[3]; l++) {
+                    std::vector<std::size_t> coordIdx = {i, j, k, l};
+                    means[i][j][k][l] = sum / static_cast<T>(dims_input[3]);
+                }
+            }
+        }
+    }
+    std::vector<T> flat_means;
+
+    for (const auto &vec3d : means) {
+        for (const auto &vec2d : vec3d) {
+            for (const auto &vec1d : vec2d) {
+                flat_means.insert(flat_means.end(), vec1d.begin(), vec1d.end());
+            }
+        }
+    }
+
+    std::vector<std::vector<std::vector<std::vector<T>>>> vars(dims_input[0],
+        std::vector<std::vector<std::vector<T>>>(dims_input[1],
+            std::vector<std::vector<T>>(dims_input[2],
+                std::vector<T>(dims_input[3], 0.0f))));
+    
+    for (std::size_t i = 0; i < dims_input[0]; i++) {
+        for (std::size_t j = 0; j < dims_input[1]; j++) {
+            for (std::size_t k = 0; k < dims_input[2]; k++) {
+                T sum_sq_diff = 0.0f;
+
+                for (std::size_t l = 0; l < dims_input[3]; l++) {
+                    std::vector<std::size_t> coordIdx = {i, j, k, l};
+                    T value = output_grad.get<T>(coordIdx);
+                    T diff = value - means[i][j][k][l];
+                    sum_sq_diff += diff * diff;
+                }
+                T variance = sum_sq_diff / static_cast<T>(dims_input[3]);
+                for (std::size_t l = 0; l < dims_input[3]; l++) {
+                    vars[i][j][k][l] = variance;
+                }
+            }
+        }
+    }
+
+    std::vector<T> flat_vars;
+
+    for (const auto &vec3d : vars) {
+        for (const auto &vec2d : vec3d) {
+            for (const auto &vec1d : vec2d) {
+                flat_vars.insert(flat_vars.end(), vec1d.begin(), vec1d.end());
+            }
+        }
+    }
+
+    const T* mean_ = flat_means.data();
+    const T* var_ = flat_vars.data();
+    const T * output_grad_raw = static_cast<const T*>(output_grad.getImpl()->rawPtr());
 
-    const T * output_grad_raw = static_cast<const T*>(output_grad.getImpl()->rawPtr());    
-    ILayerNormBackPropagation(output, output_grad_raw, input, mean_, var_, weight, bias, input_grad, weight_grad, bias_grad, size, dims_input);
+    ILayerNormbackward(output, output_grad_raw, input, mean_, var_, weight, bias, input_grad, weight_grad, bias_grad, size);
 }
diff --git a/src/operator/ILayerNormImpl_CUDA_kernels.cu b/src/operator/ILayerNormImpl_CUDA_kernels.cu
index 7f03a1d84aef3c513c1f2e9c8aef130f759484d8..fafdc176fdad6a6130c9bc4374d75f8a773f2c16 100644
--- a/src/operator/ILayerNormImpl_CUDA_kernels.cu
+++ b/src/operator/ILayerNormImpl_CUDA_kernels.cu
@@ -1,13 +1,16 @@
 /********************************************************************************
- * Copyright (c) 2023 CEA-List
+ * Copyright (c) 2024 Thales
  *
  * This program and the accompanying materials are made available under the
  * terms of the Eclipse Public License 2.0 which is available at
  * http://www.eclipse.org/legal/epl-2.0.
  *
  * SPDX-License-Identifier: EPL-2.0
+ * Author: Lucas RAKOTOARIVONY, Thales Research & Technology France
+ * Date: 10.09.2024
  *
  ********************************************************************************/
+
 #define MAX(X,Y) (((X) > (Y)) ? (X) : (Y))
 #define CLAMP(X) (((X) < (0)) ? (0) : (X))
 
@@ -18,12 +21,11 @@
 
 namespace Aidge{
 
-
 template <class T>
-__global__ void ILayerNormKernel(T* input, double SF, int* dims, int* quantized_tensor,long long int* square_tensor, T* weight, T* biase, double new_SF) {
-    int x = blockIdx.x * blockDim.x + threadIdx.x; // Dim1
-    int y = blockIdx.y * blockDim.y + threadIdx.y; // Dim2
-    int z = blockIdx.z * blockDim.z + threadIdx.z; // Dim3
+__global__ void ILayerNormforward_(T* input, double SF, int* dims, int* quantized_tensor,long long int* square_tensor, T* weight, T* biase, double new_SF) {
+    int x = blockIdx.x * blockDim.x + threadIdx.x;
+    int y = blockIdx.y * blockDim.y + threadIdx.y;
+    int z = blockIdx.z * blockDim.z + threadIdx.z;
 
     int k = 1 << 16;
     long long int sum = 0;
@@ -56,31 +58,18 @@ __global__ void ILayerNormKernel(T* input, double SF, int* dims, int* quantized_
         int factor = (((1 << 31) - 1) / k);
         for (int i = 0; i < dims[3]; i++) {
             int idx = maxIdx + i;
-            //printf("Value at index before %d: %f\n", idx, input[idx]);
-            //printf("Weight at index before %d: %f\n", idx, weight[idx]);
-            //printf("Bias at index before %d: %f\n", idx, biase[idx]);
             square_tensor[idx] =  (biase[idx]/weight[idx])/new_SF;
             quantized_tensor[idx] = (quantized_tensor[idx] * factor / 2) + biase[maxIdx];
             input[idx] = quantized_tensor[idx] * new_SF;
-            //printf("Weight at index after %d: %f\n", idx, weight[idx]);
-            //printf("Bias at index after %d: %f\n", idx, biase[idx]);
-            //printf("Square at index after %d: %d\n", idx, square_tensor[idx]);
-            //printf("Quantized at index after %d: %d\n", idx, quantized_tensor[idx]);
         }
 
     }
 }
 
 template <>
-void ILayerNormLaunchKernel<float>(const float* input, float* output, double SF, const float* weight_raw, const float* bias_raw, size_t size, std::vector<long unsigned int> dims_input)
+void ILayerNormforward<float>(const float* input, float* output, double SF, const float* weight_raw, const float* bias_raw, size_t size, std::vector<long unsigned int> dims_input)
 {
-    /*
-     * Weight => Matrice de poids pour le layernorm
-     * Biase => Matrice de biais pour le layer norm
-     */
-
-    int dims_input_cuda[4] = {1, 1, 1, 1};  // Default initialization in case dims_input has less than 4 elements
-    //IndexType dims_input_cuda[4];
+    int dims_input_cuda[4] = {1, 1, 1, 1};
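+    // pad with 1s so inputs with fewer than 4 dimensions still map onto the 4D kernel grid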
     for (std::size_t i = 0; i < std::min(dims_input.size(), size_t(4)); ++i) {
         dims_input_cuda[i] = static_cast<int>(dims_input[i]);
     }
@@ -114,8 +103,7 @@ void ILayerNormLaunchKernel<float>(const float* input, float* output, double SF,
                    (dims_input_cuda[1] + threadsPerBlock.y - 1) / threadsPerBlock.y,
                    (dims_input_cuda[2] + threadsPerBlock.z - 1) / threadsPerBlock.z);
 
-
-    ILayerNormKernel<float><<<numBlocks,threadsPerBlock>>>(input_cuda_tensor,SF,dims,quantized_tensor,Squaretensor,weight,bias,new_SF);
+    ILayerNormforward_<float><<<numBlocks,threadsPerBlock>>>(input_cuda_tensor,SF,dims,quantized_tensor,Squaretensor,weight,bias,new_SF);
     cudaDeviceSynchronize();
 
     cudaError_t err = cudaGetLastError();
@@ -134,15 +122,9 @@ void ILayerNormLaunchKernel<float>(const float* input, float* output, double SF,
 }
 
 template <>
-void ILayerNormLaunchKernel<double>(const double* input, double* output, double SF, const double* weight_raw, const double* bias_raw, size_t size, std::vector<long unsigned int> dims_input)
+void ILayerNormforward<double>(const double* input, double* output, double SF, const double* weight_raw, const double* bias_raw, size_t size, std::vector<long unsigned int> dims_input)
 {
-    /*
-     * Weight => Matrice de poids pour le layernorm
-     * Biase => Matrice de biais pour le layer norm
-     */
-
-    int dims_input_cuda[4] = {1, 1, 1, 1};  // Default initialization in case dims_input has less than 4 elements
-    //IndexType dims_input_cuda[4];
+    int dims_input_cuda[4] = {1, 1, 1, 1};
     for (std::size_t i = 0; i < std::min(dims_input.size(), size_t(4)); ++i) {
         dims_input_cuda[i] = static_cast<int>(dims_input[i]);
     }
@@ -176,8 +158,7 @@ void ILayerNormLaunchKernel<double>(const double* input, double* output, double
                    (dims_input_cuda[1] + threadsPerBlock.y - 1) / threadsPerBlock.y,
                    (dims_input_cuda[2] + threadsPerBlock.z - 1) / threadsPerBlock.z);
 
-    ILayerNormKernel<double><<<numBlocks,threadsPerBlock>>>(input_cuda_tensor,SF,dims,quantized_tensor,Squaretensor,weight,bias,new_SF);
-    //T* input, double SF, int* dims, int* quantized_tensor,int* square_tensor, T* weight, T* biase, double new_SF
+    ILayerNormforward_<double><<<numBlocks,threadsPerBlock>>>(input_cuda_tensor,SF,dims,quantized_tensor,Squaretensor,weight,bias,new_SF);
     cudaDeviceSynchronize();
 
     cudaError_t err = cudaGetLastError();
@@ -196,54 +177,35 @@ void ILayerNormLaunchKernel<double>(const double* input, double* output, double
 }
 
 template <class T>
-__global__ void ILayerNormBackward(T* d_output, T* x, T* norm_x, T* mean, T* var, T* weight, T* bias, T* grad_x, T* grad_weight, T* grad_bias, int size)
-/*
-    * d_output => Gradient rétropropagé de la fonction.
-    * x => Entrée originale avant la normalisation.
-    * norm_x => Entrée normalisée.
-    * mean => Moyenne des valeurs de l'entrée.
-    * var => Variance des valeurs de l'entrée.
-    * weight => Poids associés à l'entrée.
-    * bias => Biais associés à l'entrée.
-    * size => Taille de l'entrée.
-    */{
+__global__ void ILayerNormbackward_(T* output_grad, T* input_tensor, T* output_tensor, T* mean, T* var, T* weight, T* bias, T* input_grad, T* weight_grad, T* bias_grad, int size)
+{
     int i = blockIdx.x * blockDim.x + threadIdx.x;
     if (i < size) {
-        T d_norm_x = d_output[i] * weight[i];
-        T d_var = d_norm_x * (x[i] - mean[i]) * -0.5 * powf(var[i] + 1e-6, -1.5);
-        T d_mean = d_norm_x * -1 / sqrtf(var[i] + 1e-6) + d_var * -2 * mean[i] / size;
-        T d_x = d_norm_x / sqrtf(var[i] + 1e-6) + d_var * 2 * (x[i] - mean[i]) / size + d_mean / size;
-
-        grad_x[i] = d_x;                        // Input gradient
-        grad_weight[i] = d_output[i] * norm_x[i]; // Weight gradient
-        grad_bias[i] = d_output[i]; 
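+        // LayerNorm backward with eps = 1e-6: output_grad is propagated through
+        // y = (x - mean) / sqrt(var + eps) * weight + bias, accumulating the variance
+        // and mean contributions into the input gradient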
+        T d_norm = output_grad[i] * weight[i];
+        T d_var = d_norm * (input_tensor[i] - mean[i]) * -0.5 * powf(var[i] + 1e-6, -1.5);
+        T d_mean = d_norm * -1 / sqrtf(var[i] + 1e-6) + d_var * -2 * mean[i] / size;
+        T d_input = d_norm / sqrtf(var[i] + 1e-6) + d_var * 2 * (input_tensor[i] - mean[i]) / size + d_mean / size;
+
+        input_grad[i] = d_input;
+        weight_grad[i] = output_grad[i] * output_tensor[i];
+        bias_grad[i] = output_grad[i]; 
     }
 }
 
 template <>
-void ILayerNormBackPropagation<float>(const float* InputTensor, const float* Grad, const float* Normalised_Tensor,const float* mean,const float* var, const float* weight, const float* bias, float* grad_input, float* grad_weight, float* grad_bias, size_t size, std::vector<long unsigned int> dims_input)
+void ILayerNormbackward<float>(const float* input_tensor, const float* output_grad, const float* output_tensor, const float* mean, const float* var, const float* weight, const float* bias, float* input_grad, float* weight_grad, float* bias_grad, size_t size)
 {
-    int dims_input_cuda[4];
-    if (dims_input.size() >= 4) {
-        // Fixed-size array to store the first 4 elements
-
-        // Copy the first 4 elements from dims_input to dims_input2
-        for (std::size_t i = 0; i < 4; ++i) {
-            dims_input_cuda[i] = static_cast<int>(dims_input[i]);
-        }
-    } 
-
     float* input_cuda_tensor;
     cudaMalloc(&input_cuda_tensor,size*sizeof(float));
-    cudaMemcpy(input_cuda_tensor,InputTensor,size*sizeof(float),cudaMemcpyHostToDevice);
+    cudaMemcpy(input_cuda_tensor,input_tensor,size*sizeof(float),cudaMemcpyHostToDevice);
 
-    float* grad;
-    cudaMalloc(&grad,size*sizeof(float));
-    cudaMemcpy(grad,Grad,size*sizeof(float),cudaMemcpyHostToDevice);
+    float* output_grad_;
+    cudaMalloc(&output_grad_,size*sizeof(float));
+    cudaMemcpy(output_grad_,output_grad,size*sizeof(float),cudaMemcpyHostToDevice);
 
-    float* NormalizedTensor;
-    cudaMalloc(&NormalizedTensor,size*sizeof(float));
-    cudaMemcpy(NormalizedTensor,Normalised_Tensor,size*sizeof(float),cudaMemcpyHostToDevice);
+    float* output_tensor_;
+    cudaMalloc(&output_tensor_,size*sizeof(float));
+    cudaMemcpy(output_tensor_,output_tensor,size*sizeof(float),cudaMemcpyHostToDevice);
 
     float* mean_;
     cudaMalloc(&mean_,size*sizeof(float));
@@ -262,69 +224,58 @@ void ILayerNormBackPropagation<float>(const float* InputTensor, const float* Gra
     cudaMemcpy(bias_,bias,size*sizeof(float),cudaMemcpyHostToDevice);
 
     
-    float* grad_input_;
-    cudaMalloc(&grad_input_,size*sizeof(float));
+    float* input_grad_;
+    cudaMalloc(&input_grad_,size*sizeof(float));
 
-    float* grad_weight_;
-    cudaMalloc(&grad_weight_,size*sizeof(float));
+    float* weight_grad_;
+    cudaMalloc(&weight_grad_,size*sizeof(float));
 
-    float* grad_bias_;
-    cudaMalloc(&grad_bias_,size*sizeof(float));
+    float* bias_grad_;
+    cudaMalloc(&bias_grad_,size*sizeof(float));
 
 
     dim3 threadParBlock(256);
     dim3 Blocks((size + threadParBlock.x -1) / threadParBlock.x);
 
-    ILayerNormBackward<<<Blocks,threadParBlock>>>(grad,input_cuda_tensor,NormalizedTensor,mean_,var_,weight_,bias_,grad_input_, grad_weight_, grad_bias_, size);
+    ILayerNormbackward_<<<Blocks,threadParBlock>>>(output_grad_,input_cuda_tensor,output_tensor_,mean_,var_,weight_,bias_,input_grad_, weight_grad_, bias_grad_, size);
 
     cudaDeviceSynchronize();
     cudaError_t err = cudaGetLastError();
     if(err != cudaSuccess)
     {
-        printf("Erreur CUDA: %s\n", cudaGetErrorString(err));
+        printf("CUDA Error: %s\n", cudaGetErrorString(err));
     }
 
-
-    cudaMemcpy(grad_input , grad_input_, (size) * sizeof(float), cudaMemcpyDeviceToHost);
-    cudaMemcpy(grad_weight , grad_weight_, (size) * sizeof(float), cudaMemcpyDeviceToHost);
-    cudaMemcpy(grad_bias , grad_bias_, (size) * sizeof(float), cudaMemcpyDeviceToHost);
+    cudaMemcpy(input_grad , input_grad_, (size) * sizeof(float), cudaMemcpyDeviceToHost);
+    cudaMemcpy(weight_grad , weight_grad_, (size) * sizeof(float), cudaMemcpyDeviceToHost);
+    cudaMemcpy(bias_grad , bias_grad_, (size) * sizeof(float), cudaMemcpyDeviceToHost);
 
     cudaFree(input_cuda_tensor);
-    cudaFree(grad);
+    cudaFree(output_grad_);
     cudaFree(mean_);
     cudaFree(var_);
     cudaFree(weight_);
     cudaFree(bias_);
-    cudaFree(grad_input_);
-    cudaFree(grad_weight_);
-    cudaFree(grad_bias_);
+    cudaFree(input_grad_);
+    cudaFree(weight_grad_);
+    cudaFree(bias_grad_);
     
 }
 
 template <>
-void ILayerNormBackPropagation<double>(const double* InputTensor, const double* Grad, const double* Normalised_Tensor,const double* mean,const double* var, const double* weight, const double* bias, double* grad_input, double* grad_weight, double* grad_bias, size_t size, std::vector<long unsigned int> dims_input)
+void ILayerNormbackward<double>(const double* input_tensor, const double* output_grad, const double* output_tensor, const double* mean, const double* var, const double* weight, const double* bias, double* input_grad, double* weight_grad, double* bias_grad, size_t size)
 {
-    int dims_input_cuda[4];
-    if (dims_input.size() >= 4) {
-        // Fixed-size array to store the first 4 elements
-
-        // Copy the first 4 elements from dims_input to dims_input2
-        for (std::size_t i = 0; i < 4; ++i) {
-            dims_input_cuda[i] = static_cast<int>(dims_input[i]);
-        }
-    } 
-
     double* input_cuda_tensor;
     cudaMalloc(&input_cuda_tensor,size*sizeof(double));
-    cudaMemcpy(input_cuda_tensor,InputTensor,size*sizeof(double),cudaMemcpyHostToDevice);
+    cudaMemcpy(input_cuda_tensor,input_tensor,size*sizeof(double),cudaMemcpyHostToDevice);
 
-    double* grad;
-    cudaMalloc(&grad,size*sizeof(double));
-    cudaMemcpy(grad,Grad,size*sizeof(double),cudaMemcpyHostToDevice);
+    double* output_grad_;
+    cudaMalloc(&output_grad_,size*sizeof(double));
+    cudaMemcpy(output_grad_,output_grad,size*sizeof(double),cudaMemcpyHostToDevice);
 
-    double* NormalizedTensor;
-    cudaMalloc(&NormalizedTensor,size*sizeof(double));
-    cudaMemcpy(NormalizedTensor,Normalised_Tensor,size*sizeof(double),cudaMemcpyHostToDevice);
+    double* output_tensor_;
+    cudaMalloc(&output_tensor_,size*sizeof(double));
+    cudaMemcpy(output_tensor_,output_tensor,size*sizeof(double),cudaMemcpyHostToDevice);
 
     double* mean_;
     cudaMalloc(&mean_,size*sizeof(double));
@@ -343,43 +294,42 @@ void ILayerNormBackPropagation<double>(const double* InputTensor, const double*
     cudaMemcpy(bias_,bias,size*sizeof(double),cudaMemcpyHostToDevice);
 
     
-    double* grad_input_;
-    cudaMalloc(&grad_input_,size*sizeof(double));
+    double* input_grad_;
+    cudaMalloc(&input_grad_,size*sizeof(double));
 
-    double* grad_weight_;
-    cudaMalloc(&grad_weight_,size*sizeof(double));
+    double* weight_grad_;
+    cudaMalloc(&weight_grad_,size*sizeof(double));
 
-    double* grad_bias_;
-    cudaMalloc(&grad_bias_,size*sizeof(double));
+    double* bias_grad_;
+    cudaMalloc(&bias_grad_,size*sizeof(double));
 
 
     dim3 threadParBlock(256);
     dim3 Blocks((size + threadParBlock.x -1) / threadParBlock.x);
 
-    ILayerNormBackward<<<Blocks,threadParBlock>>>(grad,input_cuda_tensor,NormalizedTensor,mean_,var_,weight_,bias_,grad_input_, grad_weight_, grad_bias_, size);
+    ILayerNormbackward_<<<Blocks,threadParBlock>>>(output_grad_,input_cuda_tensor,output_tensor_,mean_,var_,weight_,bias_,input_grad_, weight_grad_, bias_grad_, size);
 
     cudaDeviceSynchronize();
     cudaError_t err = cudaGetLastError();
     if(err != cudaSuccess)
     {
-        printf("Erreur CUDA: %s\n", cudaGetErrorString(err));
+        printf("CUDA Error: %s\n", cudaGetErrorString(err));
     }
 
 
-    cudaMemcpy(grad_input , grad_input_, (size) * sizeof(double), cudaMemcpyDeviceToHost);
-    cudaMemcpy(grad_weight , grad_weight_, (size) * sizeof(double), cudaMemcpyDeviceToHost);
-    cudaMemcpy(grad_bias , grad_bias_, (size) * sizeof(double), cudaMemcpyDeviceToHost);
+    cudaMemcpy(input_grad , input_grad_, (size) * sizeof(double), cudaMemcpyDeviceToHost);
+    cudaMemcpy(weight_grad , weight_grad_, (size) * sizeof(double), cudaMemcpyDeviceToHost);
+    cudaMemcpy(bias_grad , bias_grad_, (size) * sizeof(double), cudaMemcpyDeviceToHost);
 
     cudaFree(input_cuda_tensor);
-    cudaFree(grad);
+    cudaFree(output_grad_);
     cudaFree(mean_);
     cudaFree(var_);
     cudaFree(weight_);
     cudaFree(bias_);
-    cudaFree(grad_input_);
-    cudaFree(grad_weight_);
-    cudaFree(grad_bias_);
-    
+    cudaFree(input_grad_);
+    cudaFree(weight_grad_);
+    cudaFree(bias_grad_);
 }
 
 }
\ No newline at end of file
diff --git a/src/operator/ShiftGELUImpl.cpp b/src/operator/ShiftGELUImpl.cpp
index 40b528215f5fcdd24310b8456e4fc9ef606682ef..c2774804d04a422aefd0c66ed0d1fc1d949b1f06 100644
--- a/src/operator/ShiftGELUImpl.cpp
+++ b/src/operator/ShiftGELUImpl.cpp
@@ -41,9 +41,6 @@ void Aidge::ShiftGELUImpl_cuda::forward() {
         case DataType::Float32:
             forward_<float>(input);
             break;
-        case DataType::Float16:
-            forward_<float>(input);
-            break;
         default:
             AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by Backend Cuda");
     }
@@ -54,13 +51,15 @@ void Aidge::ShiftGELUImpl_cuda::forward_(const Tensor& input)
 {
     const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
     const T * input_raw = static_cast<const T*>(input.getImpl()->rawPtr());
+    T * output = static_cast<T*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr());
 
     int N = 15;
     int output_bits = 8;
-
     size_t size = input.size();
     std::vector<DimSize_t> dims_input = input.dims();
 
+    // TODO: compute the scaling factor more efficiently (a min/max reduction would avoid this full scan)
+
     double min = std::numeric_limits<double>::max();
     double max = std::numeric_limits<double>::min();
     for(std::size_t i = 0; i < dims_input[0]; i++) {
@@ -81,15 +80,13 @@ void Aidge::ShiftGELUImpl_cuda::forward_(const Tensor& input)
     }
 
     double m = std::max(std::abs(min), std::abs(max));
-
-    // Calculate the normalization factor
     double normalization_factor = static_cast<double>(1 << (output_bits - 1)) - 1;
+    double scaling_factor =  m / normalization_factor;
 
-    // Return the normalized maximum
-    double final_sf =  m / normalization_factor;
-    T * output = static_cast<T*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr());
-    double new_SF = 1/std::pow(2,2*output_bits-1); // Le nouveau scaling factor renvoyé par la fonction shiftmax, utilisé pour déquantifier le tenseur renvoyé
-    ShiftGELULaunchKernel(input_raw, output, final_sf,N, output_bits, size, dims_input);
+    // The new scaling factor that could be used to dequantize the returned tensor (unused here)
+    // double new_SF = 1/std::pow(2,2*output_bits-1);
+
+    ShiftGELUforward(input_raw, output, scaling_factor, N, output_bits, size, dims_input);
 }
 
 void Aidge::ShiftGELUImpl_cuda::backward() {
@@ -99,16 +96,12 @@ void Aidge::ShiftGELUImpl_cuda::backward() {
 
     const auto& output_grad = op.getOutput(0)->grad()->refCastFrom(mOutputGradFallback, *op.getOutput(0)->grad());
 
-    backward_<float>(output_grad);
-    // Do the actual backward computation
-    // Template is only for scaling parameters, which are always in float
-    // excepted when the convolution is performed in double precision.
-    /*if (op.getInput(0)->grad()->dataType() == DataType::Float64) {
+    if (op.getInput(0)->grad()->dataType() == DataType::Float64) {
         backward_<double>(output_grad);
     }
     else {
         backward_<float>(output_grad);
-    }*/
+    }
 }
 
 template <class T>
@@ -121,6 +114,6 @@ void Aidge::ShiftGELUImpl_cuda::backward_(const Tensor& output_grad) {
     T * output = static_cast<T*>(op.getInput(0)->grad()->getImpl()->rawPtr());
 
     const T * output_grad_raw = static_cast<const T*>(output_grad.getImpl()->rawPtr());
-    ShiftGELUBackPropagation(input, output_grad_raw, output, size);
+    ShiftGELUbackward(input, output_grad_raw, output, size);
 
 }
\ No newline at end of file
diff --git a/src/operator/ShiftGELUImpl_CUDA_kernels.cu b/src/operator/ShiftGELUImpl_CUDA_kernels.cu
index 99b95e8fe3b6424f7fda57cc6c2520f30d41d839..aabd89c04e960f9f19eca69247173168d3eaf71e 100644
--- a/src/operator/ShiftGELUImpl_CUDA_kernels.cu
+++ b/src/operator/ShiftGELUImpl_CUDA_kernels.cu
@@ -26,45 +26,35 @@ __device__ inline int ExpShift(int I,int N, double SF)
     int q = floorf(Ip / (I0));
     int r = Ip -(I0*q);
     int Ib = r/2 - I0;
-    Ib = CLAMP(Ib * powf(2,N-q));//BitShift?
+    Ib = CLAMP(Ib * powf(2,N-q));
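+    // (I-ViT-style integer-only exp: the exponent is decomposed as q * I0 + r, exp is
+    // approximated linearly on the remainder r, and the 2^-q factor is folded into the
+    // final shift by N - q)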
     return (int)Ib;
 }
 
 namespace Aidge{
 
 template <class T>
-__global__ void ShiftGELUWholeKernel(T* input,int* quantized_tensor,int* GELUtensor,int* SumTensor, int* dims, double SF, int N, int output_bits) {
-    /*
- * Kernels du Forward de GeLU
- * Input => Tenseur représentant l'entrée (non quantifiée (flottant)) (pointeur vers le bloc de mémoire de type T)
- * quantized_tensor => pointeur vers un bloc mémoire vide alloué sur le GPU
- * geLUTensor => pointeur vers un bloc mémoire vide alloué sur le GPU
- * SumTensor => pointeur vers un bloc mémoire vide alloué sur le GPU
- * dims => int[4] sous forme de pointeur qui représente les 4 dimensions du tenseurs
- * SF => Scaling Factor
- * N => precision du Softmax arithmétique (plus N est grand plus l'opération est précise mais plus elle nécessite un nombre de bit elevé)
- * output_bits => précision en bit souhaité (8 pour int8 par exemple)
- */
-    int x = blockIdx.x * blockDim.x + threadIdx.x; // Dim1
-    int y = blockIdx.y * blockDim.y + threadIdx.y; // Dim2
-    int z = blockIdx.z * blockDim.z + threadIdx.z; // Dim3
-
-    double SF_sig = SF * 1.702;// SF multiplié par une constante utilisé dans l'algo
+__global__ void ShiftGELUforward_(T* input,int* quantized_tensor,int* GELUtensor,int* SumTensor, int* dims, double SF, int N, int output_bits) {
+
+    int x = blockIdx.x * blockDim.x + threadIdx.x;
+    int y = blockIdx.y * blockDim.y + threadIdx.y;
+    int z = blockIdx.z * blockDim.z + threadIdx.z;
+
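+    // GELU(x) ~= x * sigmoid(1.702 * x) (sigmoid approximation of GELU), so the
+    // sigmoid path runs at scaling factor SF * 1.702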
+    double SF_sig = SF * 1.702;
     double Final_SF = SF / powf(2,(output_bits-1));
 
     if (x < dims[0] && y < dims[1] && z < dims[2]) {
         int maxIdx = x * dims[1] * dims[2] * dims[3] + y * dims[2] * dims[3] + z * dims[3];
-        for (int i = 0; i < dims[3]; i++) { //Quantization (1thread per last dim of tensor)
+        for (int i = 0; i < dims[3]; i++) {
             int idx = maxIdx + i;
             quantized_tensor[idx] = roundf(input[idx] / SF);
         }
         int maxVal = quantized_tensor[maxIdx];
-        for (int i = 1; i < dims[3]; i++) { // Computing max value
+        for (int i = 1; i < dims[3]; i++) {
             int idx = maxIdx + i;
             maxVal = MAX(maxVal, quantized_tensor[idx]);
         }
         int Max_Exp = ExpShift(-maxVal,N,SF_sig);
-        for (int i = 0; i < dims[3]; i++) { //Exponential (artihmetic)
+        for (int i = 0; i < dims[3]; i++) {
             int idx = maxIdx + i;
             GELUtensor[idx] = ExpShift(quantized_tensor[idx] - maxVal,N,SF_sig);
             if(GELUtensor[idx] > INT_MAX - Max_Exp) {
@@ -74,7 +64,6 @@ __global__ void ShiftGELUWholeKernel(T* input,int* quantized_tensor,int* GELUten
             {
                 SumTensor[idx] = floorf(INT_MAX/(GELUtensor[idx] + Max_Exp));
             }
-            //SigMoidInt.
             SumTensor[idx] = floorf((GELUtensor[idx] * SumTensor[idx]) >> (31 - output_bits + 1));
             quantized_tensor[idx] *= SumTensor[idx];
             input[idx] = quantized_tensor[idx] * Final_SF;
@@ -83,175 +72,185 @@ __global__ void ShiftGELUWholeKernel(T* input,int* quantized_tensor,int* GELUten
 }
 
 template <>
-void ShiftGELULaunchKernel<float>(const float* input, float* output, double SF,int N, int output_bits, size_t size, std::vector<long unsigned int> dims_input) {
+void ShiftGELUforward<float>(const float* input, float* output, double SF,int N, int output_bits, size_t size, std::vector<long unsigned int> dims_input) {
 
-    double new_SF = 1/std::pow(2,2*output_bits-1); // Le nouveau scaling factor renvoyé par la fonction shiftgelu, utilisé pour déquantifier le tenseur renvoyé
+    double new_SF = 1/std::pow(2,2*output_bits-1);
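+    // scaling factor of the quantized output (currently not passed on to the kernel)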
 
     int dims_input_cuda[4];
     if (dims_input.size() >= 4) {
-        // Fixed-size array to store the first 4 elements
-
-        // Copy the first 4 elements from dims_input to dims_input2
         for (std::size_t i = 0; i < 4; ++i) {
             dims_input_cuda[i] = static_cast<int>(dims_input[i]);
         }
     } 
 
-
     float* input_cuda_tensor;
-    cudaMalloc(&input_cuda_tensor,size*sizeof(float));                                                      //|
+    cudaMalloc(&input_cuda_tensor,size*sizeof(float));
     cudaMemcpy(input_cuda_tensor,input,size*sizeof(float),cudaMemcpyHostToDevice);
-                                                                                                                        //|
-    int* quantized_tensor;                                                                                              //|
-    cudaMalloc(&quantized_tensor,size*sizeof(int));                                                         //|
-                                                                                                                        //| => Allocation des blocs mémoire sur le GPU    
+
+    int* quantized_tensor;
+    cudaMalloc(&quantized_tensor,size*sizeof(int));
+    
     int* GELUtensor;
     cudaMalloc(&GELUtensor,size*sizeof(int));
 
     int* SumTensor;
-    cudaMalloc(&SumTensor,size*sizeof(int));                                                                 //|
-                                                                                                                        //|
-    int* dims;                                                                                                          //|
+    cudaMalloc(&SumTensor,size*sizeof(int));
+
+    int* dims;
     cudaMalloc(&dims,4*sizeof(int));
 
     cudaMemcpy(dims,dims_input_cuda,4*sizeof(int),cudaMemcpyHostToDevice);     
 
-    dim3 threadsPerBlock(10, 10, 10);                                       //| Calculs du nombre de thread par blocs et du nombre de bloc a lancé en parrallèle sur
-    dim3 numBlocks((dims_input[0] + threadsPerBlock.x - 1) / threadsPerBlock.x,         //| le GPU pour en fonctions des dimensions du tenseur en entrée
-                   (dims_input[1] + threadsPerBlock.y - 1) / threadsPerBlock.y,         //|
-                   (dims_input[2] + threadsPerBlock.z - 1) / threadsPerBlock.z);        //|
+    dim3 threadsPerBlock(10, 10, 10);
+    dim3 numBlocks((dims_input[0] + threadsPerBlock.x - 1) / threadsPerBlock.x,
+                   (dims_input[1] + threadsPerBlock.y - 1) / threadsPerBlock.y,
+                   (dims_input[2] + threadsPerBlock.z - 1) / threadsPerBlock.z);
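+    // one thread per index of the first three dimensions; each thread walks the full last axis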
 
-    ShiftGELUWholeKernel<float><<<numBlocks, threadsPerBlock>>>(input_cuda_tensor, quantized_tensor,GELUtensor,SumTensor, dims, SF,N,8);//Lancement du Kernel
-    cudaDeviceSynchronize(); //Attente de la fin d'execution du kernel. Ligne très importante puisque sans celle ci le CPU continue l'execution du programme sans attendre le retour du GPU
+    ShiftGELUforward_<float><<<numBlocks, threadsPerBlock>>>(input_cuda_tensor, quantized_tensor, GELUtensor, SumTensor, dims, SF, N, output_bits);
+    cudaDeviceSynchronize();
 
     cudaError_t err = cudaGetLastError();
     if(err != cudaSuccess)
     {
-        printf("Erreur CUDA: %s\n", cudaGetErrorString(err));
-    } //Checks des possibles erreurs sur le GPU
-
-
-    //float* ControlFinal = (float*)malloc(size*sizeof(float));
-    //cudaMemcpy(ControlFinal,input_cuda_tensor,size*sizeof(float),cudaMemcpyDeviceToHost);
-    //MyTensor<float> control(ControlFinal,x.dims);
+        printf("CUDA Error: %s\n", cudaGetErrorString(err));
+    }
 
     cudaMemcpy(output,input_cuda_tensor,size*sizeof(float),cudaMemcpyDeviceToHost);
 
-    cudaFree(quantized_tensor); //|
-    cudaFree(GELUtensor);       //|
-    cudaFree(SumTensor);        //|
-    cudaFree(dims);             //| => Free sur GPU et CPU (tout ce qui a été malloc et cudaMalloc en gros)
-    cudaFree(input_cuda_tensor);//|
+    cudaFree(quantized_tensor);
+    cudaFree(GELUtensor);
+    cudaFree(SumTensor);
+    cudaFree(dims);
+    cudaFree(input_cuda_tensor);
 }
 
 template <>
-void ShiftGELULaunchKernel<double>(const double* input, double* output, double SF,int N, int output_bits, size_t size, std::vector<long unsigned int> dims_input) {
+void ShiftGELUforward<double>(const double* input, double* output, double SF,int N, int output_bits, size_t size, std::vector<long unsigned int> dims_input) {
 
-    double new_SF = 1/std::pow(2,2*output_bits-1); // Le nouveau scaling factor renvoyé par la fonction shiftgelu, utilisé pour déquantifier le tenseur renvoyé
+    double new_SF = 1/std::pow(2,2*output_bits-1);
 
     int dims_input_cuda[4];
     if (dims_input.size() >= 4) {
-        // Fixed-size array to store the first 4 elements
-
-        // Copy the first 4 elements from dims_input to dims_input2
         for (std::size_t i = 0; i < 4; ++i) {
             dims_input_cuda[i] = static_cast<int>(dims_input[i]);
         }
     } 
 
-
     double* input_cuda_tensor;
-    cudaMalloc(&input_cuda_tensor,size*sizeof(double));                                                      //|
+    cudaMalloc(&input_cuda_tensor,size*sizeof(double));
     cudaMemcpy(input_cuda_tensor,input,size*sizeof(double),cudaMemcpyHostToDevice);
-                                                                                                                        //|
-    int* quantized_tensor;                                                                                              //|
-    cudaMalloc(&quantized_tensor,size*sizeof(int));                                                         //|
-                                                                                                                        //| => Allocation des blocs mémoire sur le GPU    
+    
+    int* quantized_tensor;
+    cudaMalloc(&quantized_tensor,size*sizeof(int));
+
     int* GELUtensor;
     cudaMalloc(&GELUtensor,size*sizeof(int));
 
     int* SumTensor;
-    cudaMalloc(&SumTensor,size*sizeof(int));                                                                 //|
-                                                                                                                        //|
-    int* dims;                                                                                                          //|
+    cudaMalloc(&SumTensor,size*sizeof(int));
+
+    int* dims;
     cudaMalloc(&dims,4*sizeof(int));
 
     cudaMemcpy(dims,dims_input_cuda,4*sizeof(int),cudaMemcpyHostToDevice);     
 
-    dim3 threadsPerBlock(10, 10, 10);                                       //| Calculs du nombre de thread par blocs et du nombre de bloc a lancé en parrallèle sur
-    dim3 numBlocks((dims_input[0] + threadsPerBlock.x - 1) / threadsPerBlock.x,         //| le GPU pour en fonctions des dimensions du tenseur en entrée
-                   (dims_input[1] + threadsPerBlock.y - 1) / threadsPerBlock.y,         //|
-                   (dims_input[2] + threadsPerBlock.z - 1) / threadsPerBlock.z);        //|
+    dim3 threadsPerBlock(10, 10, 10);
+    dim3 numBlocks((dims_input[0] + threadsPerBlock.x - 1) / threadsPerBlock.x,
+                   (dims_input[1] + threadsPerBlock.y - 1) / threadsPerBlock.y,
+                   (dims_input[2] + threadsPerBlock.z - 1) / threadsPerBlock.z);
 
-    ShiftGELUWholeKernel<double><<<numBlocks, threadsPerBlock>>>(input_cuda_tensor, quantized_tensor,GELUtensor,SumTensor, dims, SF,N,8);//Lancement du Kernel
-    cudaDeviceSynchronize(); //Attente de la fin d'execution du kernel. Ligne très importante puisque sans celle ci le CPU continue l'execution du programme sans attendre le retour du GPU
+    ShiftGELUforward_<double><<<numBlocks, threadsPerBlock>>>(input_cuda_tensor, quantized_tensor, GELUtensor, SumTensor, dims, SF, N, output_bits);
+    cudaDeviceSynchronize();
 
     cudaError_t err = cudaGetLastError();
     if(err != cudaSuccess)
     {
-        printf("Erreur CUDA: %s\n", cudaGetErrorString(err));
-    } //Checks des possibles erreurs sur le GPU
-
-
-    //float* ControlFinal = (float*)malloc(size*sizeof(float));
-    //cudaMemcpy(ControlFinal,input_cuda_tensor,size*sizeof(float),cudaMemcpyDeviceToHost);
-    //MyTensor<float> control(ControlFinal,x.dims);
+        printf("CUDA Error: %s\n", cudaGetErrorString(err));
+    }
 
     cudaMemcpy(output,input_cuda_tensor,size*sizeof(double),cudaMemcpyDeviceToHost);
 
-    cudaFree(quantized_tensor); //|
-    cudaFree(GELUtensor);       //|
-    cudaFree(SumTensor);        //|
-    cudaFree(dims);             //| => Free sur GPU et CPU (tout ce qui a été malloc et cudaMalloc en gros)
-    cudaFree(input_cuda_tensor);//|
+    cudaFree(quantized_tensor);
+    cudaFree(GELUtensor);
+    cudaFree(SumTensor);
+    cudaFree(dims);
+    cudaFree(input_cuda_tensor);
 }
 
 template <class T>
-__global__ void gelu_backward(T* grad_input, const T* GELU_output, const T* grad_output, int size) {
-    /*
-     * Pareil que pour softmax
-     */
+__global__ void ShiftGELUbackward_(T* input_grad, const T* output_tensor, const T* output_grad, int size) {
+
     int index = blockIdx.x * blockDim.x + threadIdx.x;
     if (index < size) {
-        float x = GELU_output[index];
-        float grad = grad_output[index];
+        float x = output_tensor[index];
+        float grad = output_grad[index];
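+        // cdf: tanh approximation of the Gaussian CDF used in GELU; pdf: standard normal density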
 
         float cdf = 0.5 * (1.0 + tanh(sqrt(2.0 / M_PI) * (x + 0.044715 * pow(x, 3))));
         float pdf = exp(-0.5 * x * x) / sqrt(2.0 * M_PI);
         float dx = cdf + x * pdf; // GELU'(x) = CDF(x) + x * PDF(x)
         float backprop_grad = grad * dx;
-        grad_input[index] = backprop_grad;
+        input_grad[index] = backprop_grad;
     }
 }
 
 template <>
-void ShiftGELUBackPropagation<float>(const float* InputTensor, const float* Grad, float* output, size_t size)
+void ShiftGELUbackward<float>(const float* output_tensor, const float* output_grad, float* input_grad, size_t size)
 {
-    float* input_cuda_tensor;
-    cudaMalloc(&input_cuda_tensor,size*sizeof(float));
-    cudaMemcpy(input_cuda_tensor,InputTensor,size*sizeof(float),cudaMemcpyHostToDevice);
+    float* output_cuda_tensor;
+    cudaMalloc(&output_cuda_tensor,size*sizeof(float));
+    cudaMemcpy(output_cuda_tensor,output_tensor,size*sizeof(float),cudaMemcpyHostToDevice);
 
-    float* grad;
-    cudaMalloc(&grad,size*sizeof(float));
-    cudaMemcpy(grad,Grad,size*sizeof(float),cudaMemcpyHostToDevice);
+    float* output_grad_;
+    cudaMalloc(&output_grad_,size*sizeof(float));
+    cudaMemcpy(output_grad_,output_grad,size*sizeof(float),cudaMemcpyHostToDevice);
 
-    float *out_grad;
-    cudaMalloc(&out_grad, size * sizeof(float));
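+    // Device buffer that receives the computed input gradient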
+    float *input_grad_;
+    cudaMalloc(&input_grad_, size * sizeof(float));
 
     dim3 threadParBlock(256);
     dim3 Blocks((size + threadParBlock.x -1) / threadParBlock.x);
 
-    gelu_backward<float><<<Blocks,threadParBlock>>>(out_grad,input_cuda_tensor,grad,size);
+    ShiftGELUbackward_<float><<<Blocks,threadParBlock>>>(input_grad_,output_cuda_tensor,output_grad_,size);
     cudaDeviceSynchronize();
     cudaError_t err = cudaGetLastError();
     if(err != cudaSuccess)
     {
-        printf("Erreur CUDA: %s\n", cudaGetErrorString(err));
+        printf("CUDA Error: %s\n", cudaGetErrorString(err));
     }
-    cudaMemcpy(output,out_grad, (size) * sizeof(float), cudaMemcpyDeviceToHost);
-    cudaFree(input_cuda_tensor);
-    cudaFree(grad);
-    cudaFree(out_grad);
+    cudaMemcpy(input_grad,input_grad_, (size) * sizeof(float), cudaMemcpyDeviceToHost);
+    cudaFree(output_cuda_tensor);
+    cudaFree(input_grad_);
+    cudaFree(output_grad_);
+}
+
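+// Double-precision specialization: mirrors the float path above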
+template <>
+void ShiftGELUbackward<double>(const double* output_tensor, const double* output_grad, double* input_grad, size_t size)
+{
+    double* output_cuda_tensor;
+    cudaMalloc(&output_cuda_tensor,size*sizeof(double));
+    cudaMemcpy(output_cuda_tensor,output_tensor,size*sizeof(double),cudaMemcpyHostToDevice);
+
+    double* output_grad_;
+    cudaMalloc(&output_grad_,size*sizeof(double));
+    cudaMemcpy(output_grad_,output_grad,size*sizeof(double),cudaMemcpyHostToDevice);
+
+    double *input_grad_;
+    cudaMalloc(&input_grad_, size * sizeof(double));
+
+    dim3 threadParBlock(256);
+    dim3 Blocks((size + threadParBlock.x -1) / threadParBlock.x);
+
+    ShiftGELUbackward_<double><<<Blocks,threadParBlock>>>(input_grad_,output_cuda_tensor,output_grad_,size);
+    cudaDeviceSynchronize();
+    cudaError_t err = cudaGetLastError();
+    if(err != cudaSuccess)
+    {
+        printf("CUDA Error: %s\n", cudaGetErrorString(err));
+    }
+    cudaMemcpy(input_grad,input_grad_, (size) * sizeof(double), cudaMemcpyDeviceToHost);
+    cudaFree(output_cuda_tensor);
+    cudaFree(input_grad_);
+    cudaFree(output_grad_);
 }
 
 }
\ No newline at end of file
diff --git a/src/operator/ShiftMaxImpl.cpp b/src/operator/ShiftMaxImpl.cpp
index 6abf2d3cb6433dd9f7a8e85636c2fc243ba0dfc6..1134cc5d6b99e53eb492c82e32d811bc0bcba0e0 100644
--- a/src/operator/ShiftMaxImpl.cpp
+++ b/src/operator/ShiftMaxImpl.cpp
@@ -34,7 +34,6 @@ void Aidge::ShiftMaxImpl_cuda::forward() {
     assert(mOp.getRawInput(0) && "missing input #0");
     const auto& input = op.getInput(0)->refCastFrom(mInputFallback, *op.getOutput(0));
 
-
     switch(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()) {
         case DataType::Float64:
             forward_<double>(input);
@@ -42,9 +41,6 @@ void Aidge::ShiftMaxImpl_cuda::forward() {
         case DataType::Float32:
             forward_<float>(input);
             break;
-        case DataType::Float16:
-            forward_<float>(input);
-            break;
         default:
             AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by Backend Cuda");
     }
@@ -55,13 +51,15 @@ void Aidge::ShiftMaxImpl_cuda::forward_(const Tensor& input)
 {
     const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
     const T * input_raw = static_cast<const T*>(input.getImpl()->rawPtr());
+    T * output = static_cast<T*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr());
 
     int N = 15;
     int output_bits = 8;
-
     size_t size = input.size();
     std::vector<DimSize_t> dims_input = input.dims();
 
+    // TODO: find a more efficient way to compute the scaling factor (max and min reductions could retrieve it directly)
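+    // A minimal sketch of that idea (hypothetical, not part of this patch; assumes the casted
+    // input buffer is host-accessible and <algorithm> is included):
+    //   auto mm = std::minmax_element(input_raw, input_raw + size);
+    //   double m = std::max(std::abs((double)*mm.first), std::abs((double)*mm.second));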
+
     double min = std::numeric_limits<double>::max();
     double max = std::numeric_limits<double>::min();
     for(std::size_t i = 0; i < dims_input[0]; i++) {
@@ -82,15 +80,13 @@ void Aidge::ShiftMaxImpl_cuda::forward_(const Tensor& input)
     }
 
     double m = std::max(std::abs(min), std::abs(max));
-
-    // Calculate the normalization factor
     double normalization_factor = static_cast<double>(1 << (output_bits - 1)) - 1;
+    double scaling_factor = m / normalization_factor;
+
+    // The new scaling factor that can be used to dequantize the returned tensor (not used here)
+    // double new_SF = 1/std::pow(2,2*output_bits-1);
 
-    // Return the normalized maximum
-    double final_sf =  m / normalization_factor;
-    T * output = static_cast<T*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr());
-    double new_SF = 1/std::pow(2,2*output_bits-1); // Le nouveau scaling factor renvoyé par la fonction shiftmax, utilisé pour déquantifier le tenseur renvoyé
-    ShiftMaxLaunchKernel(input_raw, output, final_sf,N, output_bits, size, dims_input);
+    ShiftMaxforward(input_raw, output, scaling_factor,N, output_bits, size, dims_input);
 }
 
 
@@ -101,9 +97,6 @@ void Aidge::ShiftMaxImpl_cuda::backward() {
 
     const auto& output_grad = op.getOutput(0)->grad()->refCastFrom(mOutputGradFallback, *op.getOutput(0)->grad());
 
-    // Do the actual backward computation
-    // Template is only for scaling parameters, which are always in float
-    // excepted when the convolution is performed in double precision.
     if (op.getInput(0)->grad()->dataType() == DataType::Float64) {
         backward_<double>(output_grad);
     }
@@ -115,25 +108,14 @@ void Aidge::ShiftMaxImpl_cuda::backward() {
 template <class T>
 void Aidge::ShiftMaxImpl_cuda::backward_(const Tensor& output_grad) {
     const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
-    const T * input = static_cast<const T*>(std::static_pointer_cast<Tensor>(op.getRawOutput(0))->getImpl()->rawPtr());
+    const T * output_tensor = static_cast<const T*>(std::static_pointer_cast<Tensor>(op.getRawOutput(0))->getImpl()->rawPtr());
 
     size_t size = output_grad.size();
     std::vector<DimSize_t> dims_output = output_grad.dims();
 
-    //T * output = static_cast<T*>(std::static_pointer_cast<Tensor>(op.getInput(0)->grad()->getImpl()->rawPtr()));
-    T * output = static_cast<T*>(op.getInput(0)->grad()->getImpl()->rawPtr());
-    //op.getInput(0)->grad()->getImpl()->rawPtr();
-    //T * output = static_cast<T*>(std::static_pointer_cast<Tensor>(op.getRawOutput(0))->getImpl()->rawPtr());
+    T * input_grad = static_cast<T*>(op.getInput(0)->grad()->getImpl()->rawPtr());
 
     const T * output_grad_raw = static_cast<const T*>(output_grad.getImpl()->rawPtr());
-    ShiftMaxBackPropagation(input, output_grad_raw, output, size, dims_output);
-
-}
+    ShiftMaxbackward(output_tensor, output_grad_raw, input_grad, size, dims_output);
 
-/*const float* InputTensor, const float* Grad, float* output
-/*
-   * output : Représente le gradient rétropropagé de la fonction.
-   * InputTensor : Sortie de la fonction softmax.
-   * Grad : Gradient à l'entrée de la fonction (avant la rétropropagation), utilisé pour calculer le gradient en entrée.
-   * dims : Dimensions des différents tenseurs (int[4]).
-   */
\ No newline at end of file
+}
\ No newline at end of file
diff --git a/src/operator/ShiftMaxImpl_CUDA_kernels.cu b/src/operator/ShiftMaxImpl_CUDA_kernels.cu
index abf4e7aecea2a964c0e3c1bc132f2f67a4dd892e..ba3cfcb51e02fb0befbf9f7c1fc054e73a2a7157 100644
--- a/src/operator/ShiftMaxImpl_CUDA_kernels.cu
+++ b/src/operator/ShiftMaxImpl_CUDA_kernels.cu
@@ -26,53 +26,38 @@ __device__ inline int ExpShift(int I,int N, double SF)
     int q = floorf(Ip / (I0));
     int r = Ip -(I0*q);
     int Ib = r/2 - I0;
-    Ib = CLAMP(Ib * powf(2,N-q));//BitShift?
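+    // Multiply by 2^(N-q): the power-of-two shift that completes the integer-only exp approximation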
+    Ib = CLAMP(Ib * powf(2,N-q));
     return (int)Ib;
 }
 
 namespace Aidge{
 
 template <class T>
-__global__ void ShiftMaxWholeKernel(T* input,int* quantized_tensor,int* factor, int* dims, double SF, int N, int output_bits,double new_SF)
-/*
- * Kernels du Forward de Shiftmax
- * Input => Tenseur représentant l'entrée (non quantifiée (flottant)) (pointeur vers le bloc de mémoire de type T)
- * quantized_tensor => pointeur vers un bloc mémoire vide alloué sur le GPU
- * factor => pointeur vers un bloc mémoire vide alloué sur le GPU
- * dims => int[4] sous forme de pointeur qui représente les 4 dimensions du tenseurs
- * SF => Scaling Factor
- * N => precision du Softmax arithmétique (plus N est grand plus l'opération est précise mais plus elle nécessite un nombre de bit elevé)
- * output_bits => précision en bit souhaité (8 pour int8 par exemple)
- * new_SF => Nouveau SF pour déquantifier le tenseur
- */
+__global__ void ShiftMaxforward_(T* input,int* quantized_tensor,int* factor, int* dims, double SF, int N, int output_bits,double new_SF)
 {
-    int x = blockIdx.x * blockDim.x + threadIdx.x; // Dim1
-    int y = blockIdx.y * blockDim.y + threadIdx.y; // Dim2
-    int z = blockIdx.z * blockDim.z + threadIdx.z; // Dim3
+    int x = blockIdx.x * blockDim.x + threadIdx.x;
+    int y = blockIdx.y * blockDim.y + threadIdx.y;
+    int z = blockIdx.z * blockDim.z + threadIdx.z;
     int sum = 0;
-    /*
-     * x,y et z représente les indices des dimensions 1,2 et 3 du tenseur, toutes les combinaisons possible de x,y et z
-     * sont appelés en parralèle ce qui permet le speedup GPU
-     * pour iterer dans la derniere dimensions on utilise les boucles "for" ci dessous
-     * */
+
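+    // x, y and z index the first three tensor dimensions; all combinations run in parallel,
+    // while each thread iterates over the last dimension in the loops below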
     if (x < dims[0] && y < dims[1] && z < dims[2]) {
         int maxIdx = x * dims[1] * dims[2] * dims[3] + y * dims[2] * dims[3] + z * dims[3];
-        for (int i = 0; i < dims[3]; i++) { //Quantization (1thread per last dim of tensor)
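+        // Quantize the input along the last dimension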
+        for (int i = 0; i < dims[3]; i++) {
             int idx = maxIdx + i;
             quantized_tensor[idx] = roundf(input[idx] / SF);
         }
         int maxVal = quantized_tensor[maxIdx];
-        for (int i = 1; i < dims[3]; i++) { // max value par dimensions 4
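+        // Find the maximum quantized value for numerical stability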
+        for (int i = 1; i < dims[3]; i++) {
             int idx = maxIdx + i;
             maxVal = MAX(maxVal, quantized_tensor[idx]);
         }
-        for (int i = 0; i < dims[3]; i++) { //Expo (artihmetic)
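+        // Integer-only (arithmetic) exponential of the max-shifted values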
+        for (int i = 0; i < dims[3]; i++) {
             int idx = maxIdx + i;
             quantized_tensor[idx] = ExpShift(quantized_tensor[idx]-maxVal,N,SF);
         }
-        for (int i = 0; i < dims[3]; i++) { // Sum et clamp quand dépassement de valeur
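+        // Sum the exponentials, clamping at INT_MAX on overflow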
+        for (int i = 0; i < dims[3]; i++) {
             int idx = maxIdx + i;
-            if(quantized_tensor[idx] > 0 && sum > INT_MAX - quantized_tensor[idx])//CLAMP(2**31-1)
+            if(quantized_tensor[idx] > 0 && sum > INT_MAX - quantized_tensor[idx])
             {
                 sum = INT_MAX;
                 break;
@@ -82,7 +67,7 @@ __global__ void ShiftMaxWholeKernel(T* input,int* quantized_tensor,int* factor,
             }
         }
         factor[x * dims[1] * dims[2] + y * dims[2] + z] = floorf(INT_MAX/sum);
-        for(int i= 0; i < dims[3]; ++i) //bitshift pour quantifier sur 8 bits
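+        // Rescale by the per-row factor and shift down to output_bits precision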
+        for(int i= 0; i < dims[3]; ++i)
         {
             int idx = maxIdx + i;
             quantized_tensor[idx] = (quantized_tensor[idx] * factor[x * dims[1] * dims[2] + y * dims[2] + z]) >> (31-(2*output_bits-1));
@@ -92,13 +77,11 @@ __global__ void ShiftMaxWholeKernel(T* input,int* quantized_tensor,int* factor,
 }
 
 template <>
-void ShiftMaxLaunchKernel<float>(const float* input, float* output, double SF, int N, int output_bits, size_t size, std::vector<long unsigned int> dims_input) {
+void ShiftMaxforward<float>(const float* input, float* output, double SF, int N, int output_bits, size_t size, std::vector<long unsigned int> dims_input) {
 
     double new_SF = 1 / std::pow(2, 2 * output_bits - 1); // New scaling factor
 
-    // Ensure that dims_input has at least 4 elements
-    int dims_input_cuda[4] = {1, 1, 1, 1};  // Default initialization in case dims_input has less than 4 elements
-    //IndexType dims_input_cuda[4];
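+    // Pad the dimensions to 4 entries, defaulting missing dimensions to 1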
+    int dims_input_cuda[4] = {1, 1, 1, 1};
     for (std::size_t i = 0; i < std::min(dims_input.size(), size_t(4)); ++i) {
         dims_input_cuda[i] = static_cast<int>(dims_input[i]);
     }
@@ -127,7 +110,7 @@ void ShiftMaxLaunchKernel<float>(const float* input, float* output, double SF, i
     );
 
-    // Launch the kernel (assuming a templated ShiftMaxWholeKernel function exists)
+    // Launch the ShiftMaxforward_ kernel
-    ShiftMaxWholeKernel<float><<<numBlocks, threadsPerBlock>>>(input_cuda_tensor, quantized_tensor, factor, dims, SF, N, output_bits, new_SF);
+    ShiftMaxforward_<float><<<numBlocks, threadsPerBlock>>>(input_cuda_tensor, quantized_tensor, factor, dims, SF, N, output_bits, new_SF);
     cudaDeviceSynchronize();
 
     // Check for CUDA errors
@@ -147,13 +130,11 @@ void ShiftMaxLaunchKernel<float>(const float* input, float* output, double SF, i
 }
 
 template <>
-void ShiftMaxLaunchKernel<double>(const double* input, double* output, double SF, int N, int output_bits, size_t size, std::vector<long unsigned int> dims_input) {
+void ShiftMaxforward<double>(const double* input, double* output, double SF, int N, int output_bits, size_t size, std::vector<long unsigned int> dims_input) {
 
-    double new_SF = 1 / std::pow(2, 2 * output_bits - 1); // New scaling factor
+    double new_SF = 1 / std::pow(2, 2 * output_bits - 1);
 
-    // Ensure that dims_input has at least 4 elements
-    int dims_input_cuda[4] = {1, 1, 1, 1};  // Default initialization in case dims_input has less than 4 elements
-    //IndexType dims_input_cuda[4];
+    int dims_input_cuda[4] = {1, 1, 1, 1};
     for (std::size_t i = 0; i < std::min(dims_input.size(), size_t(4)); ++i) {
         dims_input_cuda[i] = static_cast<int>(dims_input[i]);
     }
@@ -182,7 +163,7 @@ void ShiftMaxLaunchKernel<double>(const double* input, double* output, double SF
     );
 
-    // Launch the kernel (assuming a templated ShiftMaxWholeKernel function exists)
+    // Launch the ShiftMaxforward_ kernel
-    ShiftMaxWholeKernel<double><<<numBlocks, threadsPerBlock>>>(input_cuda_tensor, quantized_tensor, factor, dims, SF, N, output_bits, new_SF);
+    ShiftMaxforward_<double><<<numBlocks, threadsPerBlock>>>(input_cuda_tensor, quantized_tensor, factor, dims, SF, N, output_bits, new_SF);
     cudaDeviceSynchronize();
 
     // Check for CUDA errors
@@ -203,13 +184,7 @@ void ShiftMaxLaunchKernel<double>(const double* input, double* output, double SF
 
 
 template <class T>
-__global__ void shiftmax_backward(T* grad_input, const T* shiftmax_output, const T* grad_output, const int* dims) {
-    /*
-   * grad_input : Représente le gradient rétropropagé de la fonction.
-   * softmax_output : Sortie de la fonction softmax.
-   * grad_output : Gradient à l'entrée de la fonction (avant la rétropropagation), utilisé pour calculer le gradient en entrée.
-   * dims : Dimensions des différents tenseurs (int[4]).
-   */
+__global__ void ShiftMaxbackward_(T* input_grad, const T* output_tensor, const T* output_grad, const int* dims) {
     int index = blockIdx.x * blockDim.x + threadIdx.x;
     if (index < dims[0] * dims[1] * dims[2] * dims[3]) {
         int w = (index / dims[3]) % dims[2];
@@ -218,106 +193,94 @@ __global__ void shiftmax_backward(T* grad_input, const T* shiftmax_output, const
 
         float sum = 0.0f;
         for (int i = 0; i < dims[3]; ++i) {
-            sum += shiftmax_output[n * dims[1] * dims[2] * dims[3] + h * dims[2] * dims[3] + w * dims[3] + i] * grad_output[n * dims[1] * dims[2] * dims[3] + h * dims[2] * dims[3] + w * dims[3] + i];
+            sum += output_tensor[n * dims[1] * dims[2] * dims[3] + h * dims[2] * dims[3] + w * dims[3] + i] * output_grad[n * dims[1] * dims[2] * dims[3] + h * dims[2] * dims[3] + w * dims[3] + i];
         }
-        grad_input[index] = shiftmax_output[index] * (grad_output[index] - sum);
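+        // Softmax backward: dL/dx_i = y_i * (dL/dy_i - sum_j y_j * dL/dy_j)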
+        input_grad[index] = output_tensor[index] * (output_grad[index] - sum);
     }
 }
 
 template <>
-void ShiftMaxBackPropagation<float>(const float* InputTensor, const float* Grad, float* output, size_t size, std::vector<long unsigned int> dims_input)
+void ShiftMaxbackward<float>(const float* output_tensor, const float* output_grad, float* input_grad, size_t size, std::vector<long unsigned int> dims)
 {   
-    /*
-   * output : Représente le gradient rétropropagé de la fonction.
-   * InputTensor : Sortie de la fonction softmax.
-   * Grad : Gradient à l'entrée de la fonction (avant la rétropropagation), utilisé pour calculer le gradient en entrée.
-   * dims : Dimensions des différents tenseurs (int[4]).
-   */
-    int dims_input_cuda[4] = {1, 1, 1, 1};  // Default initialization in case dims_input has less than 4 elements
-    //IndexType dims_input_cuda[4];
-    for (std::size_t i = 0; i < std::min(dims_input.size(), size_t(4)); ++i) {
-        dims_input_cuda[i] = static_cast<int>(dims_input[i]);
+    int dims_input_cuda[4] = {1, 1, 1, 1};
+    for (std::size_t i = 0; i < std::min(dims.size(), size_t(4)); ++i) {
+        dims_input_cuda[i] = static_cast<int>(dims[i]);
     }
 
-    float* input_cuda_tensor;
-    cudaMalloc(&input_cuda_tensor,size*sizeof(float));
-    cudaMemcpy(input_cuda_tensor,InputTensor,size*sizeof(float),cudaMemcpyHostToDevice);
+    float* output_cuda_tensor;
+    cudaMalloc(&output_cuda_tensor,size*sizeof(float));
+    cudaMemcpy(output_cuda_tensor,output_tensor,size*sizeof(float),cudaMemcpyHostToDevice);
 
-    float* grad;
-    cudaMalloc(&grad,size*sizeof(float));
-    cudaMemcpy(grad,Grad,size*sizeof(float),cudaMemcpyHostToDevice);
+    float* output_grad_;
+    cudaMalloc(&output_grad_,size*sizeof(float));
+    cudaMemcpy(output_grad_,output_grad,size*sizeof(float),cudaMemcpyHostToDevice);
 
-    float *out_grad;
-    cudaMalloc(&out_grad, size * sizeof(float));
+    float *input_grad_;
+    cudaMalloc(&input_grad_, size * sizeof(float));
 
-    int *dims;
-    cudaMalloc(&dims, 4 * sizeof(int));
-    cudaMemcpy(dims, dims_input_cuda, 4 * sizeof(int), cudaMemcpyHostToDevice);
+    int *dims_;
+    cudaMalloc(&dims_, 4 * sizeof(int));
+    cudaMemcpy(dims_, dims_input_cuda, 4 * sizeof(int), cudaMemcpyHostToDevice);
 
     dim3 threadParBlock(256);
     dim3 Blocks((size + threadParBlock.x -1) / threadParBlock.x);
 
-    shiftmax_backward<float><<<Blocks,threadParBlock>>>(out_grad,input_cuda_tensor,grad,dims);
+    ShiftMaxbackward_<float><<<Blocks,threadParBlock>>>(input_grad_,output_cuda_tensor,output_grad_,dims_);
     cudaDeviceSynchronize();
     cudaError_t err = cudaGetLastError();
     if(err != cudaSuccess)
     {
-        printf("Erreur CUDA: %s\n", cudaGetErrorString(err));
+        printf("CUDA Error: %s\n", cudaGetErrorString(err));
     }
 
-    cudaMemcpy(output,out_grad, (size) * sizeof(float), cudaMemcpyDeviceToHost);
-    cudaFree(input_cuda_tensor);
-    cudaFree(grad);
-    cudaFree(dims);
-    cudaFree(out_grad);
+    cudaMemcpy(input_grad, input_grad_, (size) * sizeof(float), cudaMemcpyDeviceToHost);
+    cudaFree(output_cuda_tensor);
+    cudaFree(input_grad_);
+    cudaFree(dims_);
+    cudaFree(output_grad_);
 }
 
 template <>
-void ShiftMaxBackPropagation<double>(const double* InputTensor, const double* Grad, double* output, size_t size, std::vector<long unsigned int> dims_input)
+void ShiftMaxbackward<double>(const double* output_tensor, const double* output_grad, double* input_grad, size_t size, std::vector<long unsigned int> dims)
 {   
-    /*
-   * output : Représente le gradient rétropropagé de la fonction.
-   * InputTensor : Sortie de la fonction softmax.
-   * Grad : Gradient à l'entrée de la fonction (avant la rétropropagation), utilisé pour calculer le gradient en entrée.
-   * dims : Dimensions des différents tenseurs (int[4]).
-   */
-    int dims_input_cuda[4] = {1, 1, 1, 1};  // Default initialization in case dims_input has less than 4 elements
-    //IndexType dims_input_cuda[4];
-    for (std::size_t i = 0; i < std::min(dims_input.size(), size_t(4)); ++i) {
-        dims_input_cuda[i] = static_cast<int>(dims_input[i]);
+    int dims_input_cuda[4] = {1, 1, 1, 1};
+    for (std::size_t i = 0; i < std::min(dims.size(), size_t(4)); ++i) {
+        dims_input_cuda[i] = static_cast<int>(dims[i]);
     }
 
-    double* input_cuda_tensor;
-    cudaMalloc(&input_cuda_tensor,size*sizeof(double));
-    cudaMemcpy(input_cuda_tensor,InputTensor,size*sizeof(double),cudaMemcpyHostToDevice);
+    double* output_cuda_tensor;
+    cudaMalloc(&output_cuda_tensor,size*sizeof(double));
+    cudaMemcpy(output_cuda_tensor,output_tensor,size*sizeof(double),cudaMemcpyHostToDevice);
 
-    double* grad;
-    cudaMalloc(&grad,size*sizeof(double));
-    cudaMemcpy(grad,Grad,size*sizeof(double),cudaMemcpyHostToDevice);
+    double* output_grad_;
+    cudaMalloc(&output_grad_,size*sizeof(double));
+    cudaMemcpy(output_grad_,output_grad,size*sizeof(double),cudaMemcpyHostToDevice);
 
-    double *out_grad;
-    cudaMalloc(&out_grad, size * sizeof(double));
+    double *input_grad_;
+    cudaMalloc(&input_grad_, size * sizeof(double));
 
-    int *dims;
-    cudaMalloc(&dims, 4 * sizeof(int));
-    cudaMemcpy(dims, dims_input_cuda, 4 * sizeof(int), cudaMemcpyHostToDevice);
+    int *dims_;
+    cudaMalloc(&dims_, 4 * sizeof(int));
+    cudaMemcpy(dims_, dims_input_cuda, 4 * sizeof(int), cudaMemcpyHostToDevice);
 
     dim3 threadParBlock(256);
     dim3 Blocks((size + threadParBlock.x -1) / threadParBlock.x);
 
-    shiftmax_backward<double><<<Blocks,threadParBlock>>>(out_grad,input_cuda_tensor,grad,dims);
+    ShiftMaxbackward_<double><<<Blocks,threadParBlock>>>(input_grad_,output_cuda_tensor,output_grad_,dims_);
     cudaDeviceSynchronize();
     cudaError_t err = cudaGetLastError();
     if(err != cudaSuccess)
     {
-        printf("Erreur CUDA: %s\n", cudaGetErrorString(err));
+        printf("CUDA Error: %s\n", cudaGetErrorString(err));
     }
 
-    cudaMemcpy(output,out_grad, (size) * sizeof(double), cudaMemcpyDeviceToHost);
-    cudaFree(input_cuda_tensor);
-    cudaFree(grad);
-    cudaFree(dims);
-    cudaFree(out_grad);
+    cudaMemcpy(input_grad,input_grad_, (size) * sizeof(double), cudaMemcpyDeviceToHost);
+    cudaFree(output_cuda_tensor);
+    cudaFree(input_grad_);
+    cudaFree(dims_);
+    cudaFree(output_grad_);
 }
 
 }
\ No newline at end of file
diff --git a/unit_tests/Test_ILayerNormImpl.cpp b/unit_tests/Test_ILayerNormImpl.cpp
index 0f6d61fdf47a6783e468b050285afb874ac26b9f..0487b7c4716596e0d2e7bcbdaf812358be4de3bf 100644
--- a/unit_tests/Test_ILayerNormImpl.cpp
+++ b/unit_tests/Test_ILayerNormImpl.cpp
@@ -1,11 +1,13 @@
 /********************************************************************************
- * Copyright (c) 2023 CEA-List
+ * Copyright (c) 2024 Thales
  *
  * This program and the accompanying materials are made available under the
  * terms of the Eclipse Public License 2.0 which is available at
  * http://www.eclipse.org/legal/epl-2.0.
  *
  * SPDX-License-Identifier: EPL-2.0
+ * Author: Lucas RAKOTOARIVONY, Thales Research & Technology France
+ * Date: 10.09.2024
  *
  ********************************************************************************/
 
@@ -51,7 +53,24 @@ TEST_CASE("[gpu/operator] ILayerNorm(forward)", "[ILayerNorm][GPU]") {
 
         std::shared_ptr<Tensor> myBias = std::make_shared<Tensor>(Array1D<float, 10>{{0, 0, 0, 0, 0, 0, 0, 0, 0, 0}});
         std::shared_ptr<Tensor> myWeight = std::make_shared<Tensor>(Array1D<float, 10>{{0.1617684f, 0.3833238f ,-0.6842308f ,-0.4342245f ,-0.4717381f ,-0.1776187f, -0.2728751f, -0.4638580f, 0.2936697f, -0.9011016f}});
-        
+
+        myWeight->setBackend("cuda");
+        myBias->setBackend("cuda");
+
+        std::shared_ptr<Node> myILayerNorm = ILayerNorm();
+        auto op = std::static_pointer_cast<OperatorTensor>(myILayerNorm -> getOperator());
+
+        op -> associateInput(1, myWeight);
+        op -> associateInput(2, myBias);
+
+        input0->setBackend("cuda");
+
+        op -> associateInput(0,input0);
+        op->setDataType(DataType::Float32);
+        op->setBackend("cuda");
+        op->forward();
+
+        // expected output of ILayerNorm forward operator
         std::shared_ptr<Tensor> output_ilayernorm = std::make_shared<Tensor>(Array4D<float,2,2,2,10> {
         {
             {
@@ -77,27 +96,11 @@ TEST_CASE("[gpu/operator] ILayerNorm(forward)", "[ILayerNorm][GPU]") {
         }
     });
 
-        myWeight->setBackend("cuda");
-        myBias->setBackend("cuda");
-
-        std::shared_ptr<Node> myILayerNorm = ILayerNorm();
-        auto op = std::static_pointer_cast<OperatorTensor>(myILayerNorm -> getOperator());
-
-        op -> associateInput(1, myWeight);
-        op -> associateInput(2, myBias);
-
-        input0->setBackend("cuda");
-
-        op -> associateInput(0,input0);
-        op->setDataType(DataType::Float32);
-        op->setBackend("cuda");
-        op->forward();
-
 
         float* computedOutput   = new float[output_ilayernorm->size()]();
-
         cudaMemcpy(computedOutput, op->getOutput(0)->getImpl()->rawPtr(), sizeof(float) * output_ilayernorm->size(), cudaMemcpyDeviceToHost);
 
+        //test if forward results are as expected
         for(int i = 0; i < output_ilayernorm->size(); i++){
             const float targetOutput = *(static_cast<float*>(output_ilayernorm->getImpl()->rawPtr()) + i);
             REQUIRE(fabs(computedOutput[i] - targetOutput) < 1e-6);
@@ -178,7 +181,7 @@ TEST_CASE("[gpu/operator] ILayerNorm(backward)", "[ILayerNorm][GPU]")
             {
                 {
                     {
-                        { 1.04526, 0.637168, 0.337648, 0.923333, 0.13582, 0.39975, 0.117984, -0.187983},
+                        { 0.467678, 0.310749, 0.1129, 0.351786, 0.0507252, 0.101587, 0.130249, -0.0646476},
                     },
                 },
             }
@@ -188,12 +191,11 @@ TEST_CASE("[gpu/operator] ILayerNorm(backward)", "[ILayerNorm][GPU]")
     float *computedInputGradCuda = new float[myOutputGrad->size()]();
     cudaMemcpy(computedInputGradCuda, op->getInput(0)->grad()->getImpl()->rawPtr(), sizeof(float) * myOutputGrad->size(), cudaMemcpyDeviceToHost);
 
+    //test if backward results are as expected
     for(int i = 0; i < expectedInputGradILayerNorm->size(); i++){
         const float targetOutput = *(static_cast<float*>(expectedInputGradILayerNorm->getImpl()->rawPtr()) + i);
         REQUIRE(fabs(computedInputGradCuda[i] - targetOutput) < 2e-6);  
     }
 
-
-
     delete[] computedInputGradCuda;
 }
diff --git a/unit_tests/Test_ShiftGELUImpl.cpp b/unit_tests/Test_ShiftGELUImpl.cpp
index 69ee300eedc5b8b2e2893e8bbc0cb722a83316b4..16c8e405496b8dfcb3e7a26e4e536d7403d865ce 100644
--- a/unit_tests/Test_ShiftGELUImpl.cpp
+++ b/unit_tests/Test_ShiftGELUImpl.cpp
@@ -51,6 +51,7 @@ TEST_CASE("[gpu/operator] ShiftGELU(forward)", "[ShiftGELU][GPU]") {
             }
         });
 
+        //expected output of ShiftGELU forward operator
         std::shared_ptr<Tensor> output_shiftGELU = std::make_shared<Tensor>(Array4D<float,2,2,2,10> {
             {
                 {
@@ -76,6 +77,7 @@ TEST_CASE("[gpu/operator] ShiftGELU(forward)", "[ShiftGELU][GPU]") {
             }
         });
 
+        //expected output of GELU forward operator (computed with PyTorch)
         std::shared_ptr<Tensor> output_GELU = std::make_shared<Tensor>(Array4D<float, 2, 2, 2, 10> {
             {
                 {
@@ -99,7 +101,7 @@ TEST_CASE("[gpu/operator] ShiftGELU(forward)", "[ShiftGELU][GPU]") {
                     }
                 }
             }
-        }); //value given by torch nn GELU
+        });
 
         std::shared_ptr<Node> myShiftGELU = ShiftGELU();
         auto op = std::static_pointer_cast<OperatorTensor>(myShiftGELU -> getOperator());
@@ -111,11 +113,13 @@ TEST_CASE("[gpu/operator] ShiftGELU(forward)", "[ShiftGELU][GPU]") {
         float* computedOutput   = new float[output_shiftGELU->size()]();
         cudaMemcpy(computedOutput, op->getOutput(0)->getImpl()->rawPtr(), sizeof(float) * output_shiftGELU->size(), cudaMemcpyDeviceToHost);
 
+        //test if forward results are as expected
         for(int i = 0; i < output_shiftGELU->size(); i++){
             const float targetOutput = *(static_cast<float*>(output_shiftGELU->getImpl()->rawPtr()) + i);
             REQUIRE(fabs(computedOutput[i] - targetOutput) < 1e-6);
         }
 
+        //measure the difference between GELU and ShiftGELU
         float sum = 0.0;
         for(int i = 0; i < output_GELU->size(); i++){
             const float targetOutput = *(static_cast<float*>(output_GELU->getImpl()->rawPtr()) + i);
@@ -170,6 +174,7 @@ TEST_CASE("[gpu/operator] ShiftGELU(backward)", "[ShiftGELU][GPU]")
     predictedOutput->setGrad(myOutputGrad);
     REQUIRE_NOTHROW(myShiftGELU->backward());
 
+    //expected output of ShiftGELU backward operator
     std::shared_ptr<Tensor> expectedInputGradShiftGELU = std::make_shared<Tensor>(Array4D<float,1,1,1,8> {
             {
                 {
@@ -180,6 +185,7 @@ TEST_CASE("[gpu/operator] ShiftGELU(backward)", "[ShiftGELU][GPU]")
             }
         });
 
+    //expected output of GELU backward operator (computed with PyTorch)
     std::shared_ptr<Tensor> expectedInputGradGELU = std::make_shared<Tensor>(Array4D<float,1,1,1,8> {
             {
                 {
@@ -195,12 +201,13 @@ TEST_CASE("[gpu/operator] ShiftGELU(backward)", "[ShiftGELU][GPU]")
 
     cudaMemcpy(computedGradCuda, input->grad()->getImpl()->rawPtr(), sizeof(float) * myOutputGrad->size(), cudaMemcpyDeviceToHost);
 
-
+    //test if backward results are as expected
     for(int i = 0; i < expectedInputGradShiftGELU->size(); i++){
         const float targetOutput = *(static_cast<float*>(expectedInputGradShiftGELU->getImpl()->rawPtr()) + i);
         REQUIRE(fabs(computedGradCuda[i] - targetOutput) < 2e-6);  
     }
 
+    //measure the difference between GELU and ShiftGELU
     float sum = 0.0;
         for(int i = 0; i < expectedInputGradGELU->size(); i++){
             const float targetOutput = *(static_cast<float*>(expectedInputGradGELU->getImpl()->rawPtr()) + i);
diff --git a/unit_tests/Test_ShiftMaxImpl.cpp b/unit_tests/Test_ShiftMaxImpl.cpp
index ff1e0e67c8e5e87fc8505262841a522ea9dd9f6a..2a94a23c3a04edd72cb535ebfb6e2c538e4aeee8 100644
--- a/unit_tests/Test_ShiftMaxImpl.cpp
+++ b/unit_tests/Test_ShiftMaxImpl.cpp
@@ -50,6 +50,7 @@ TEST_CASE("[gpu/operator] ShiftMax(forward)", "[ShiftMax][GPU]") {
                 }
             }
         });
+        //expected output of ShiftMax forward operator
         std::shared_ptr<Tensor> output_shiftmax = std::make_shared<Tensor>(Array4D<float,2,2,2,10> {
             {
                 {
@@ -74,6 +75,7 @@ TEST_CASE("[gpu/operator] ShiftMax(forward)", "[ShiftMax][GPU]") {
                 }
             }
         });
+        //expected output of softmax forward operator (computed with PyTorch)
         std::shared_ptr<Tensor> output_softmax = std::make_shared<Tensor>(Array4D<float, 2, 2, 2, 10> {
             {
                 {
@@ -97,7 +99,7 @@ TEST_CASE("[gpu/operator] ShiftMax(forward)", "[ShiftMax][GPU]") {
                     }
                 }
             }
-        }); //softmax value given by torch softmax
+        });
 
         std::shared_ptr<Node> myShiftMax = ShiftMax();
         auto op = std::static_pointer_cast<OperatorTensor>(myShiftMax -> getOperator());
@@ -109,11 +111,13 @@ TEST_CASE("[gpu/operator] ShiftMax(forward)", "[ShiftMax][GPU]") {
         float* computedOutput   = new float[output_shiftmax->size()]();
         cudaMemcpy(computedOutput, op->getOutput(0)->getImpl()->rawPtr(), sizeof(float) * output_shiftmax->size(), cudaMemcpyDeviceToHost);
 
+        //test if forward results are as expected
         for(int i = 0; i < output_shiftmax->size(); i++){
             const float targetOutput = *(static_cast<float*>(output_shiftmax->getImpl()->rawPtr()) + i);
             REQUIRE(fabs(computedOutput[i] - targetOutput) < 1e-6);
         }
 
+        //measure the difference between softmax and ShiftMax
         float sum = 0.0;
         for(int i = 0; i < output_softmax->size(); i++){
             const float targetOutput = *(static_cast<float*>(output_softmax->getImpl()->rawPtr()) + i);
@@ -167,6 +171,7 @@ TEST_CASE("[gpu/operator] ShiftMax(backward)", "[ShiftMax][GPU]")
     predictedOutput->setGrad(myOutputGrad);
     REQUIRE_NOTHROW(myShiftMax->backward());
 
+    //expected output of ShiftMax backward operator
     std::shared_ptr<Tensor> expectedInputGradShiftMax = std::make_shared<Tensor>(Array4D<float,1,1,1,8> {
             {
                 {
@@ -177,6 +182,7 @@ TEST_CASE("[gpu/operator] ShiftMax(backward)", "[ShiftMax][GPU]")
             }
         });
 
+    //expected output of softmax backward operator (computed with PyTorch)
     std::shared_ptr<Tensor> expectedInputGradSoftmax = std::make_shared<Tensor>(Array4D<float,1,1,1,8> {
             {
                 {
@@ -192,11 +198,13 @@ TEST_CASE("[gpu/operator] ShiftMax(backward)", "[ShiftMax][GPU]")
 
     cudaMemcpy(computedGradCuda, input->grad()->getImpl()->rawPtr(), sizeof(float) * myOutputGrad->size(), cudaMemcpyDeviceToHost);
 
+    //test if backward results are as expected
     for(int i = 0; i < expectedInputGradShiftMax->size(); i++){
         const float targetOutput = *(static_cast<float*>(expectedInputGradShiftMax->getImpl()->rawPtr()) + i);
         REQUIRE(fabs(computedGradCuda[i] - targetOutput) < 1e-6);
     }
 
+    //measure the difference between softmax and ShiftMax
     float sum = 0.0;
         for(int i = 0; i < expectedInputGradSoftmax->size(); i++){
             const float targetOutput = *(static_cast<float*>(expectedInputGradSoftmax->getImpl()->rawPtr()) + i);