diff --git a/CMakeLists.txt b/CMakeLists.txt
index 01ebb6f258b173aee6df867c5c5c991ec936df57..8dfd9982c7562a92e82212a6b2c9536b6fa5f451 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -75,7 +75,7 @@ target_link_libraries(${module_name}
 )
 
 if( ${ENABLE_ASAN} )
-    message("Building ${module_name} with ASAN.")
+    message("Building ${module_name} with ASAN.")
     set(SANITIZE_FLAGS -fsanitize=address -fno-omit-frame-pointer)
     target_link_libraries(${module_name}
         PUBLIC
diff --git a/include/aidge/backend/cuda.hpp b/include/aidge/backend/cuda.hpp
index 580dce246b4c43e9a82fc977103145f79ae0976e..da62b81022550a79d63fa1f20aa9429753e5ab6c 100644
--- a/include/aidge/backend/cuda.hpp
+++ b/include/aidge/backend/cuda.hpp
@@ -22,6 +22,8 @@
 #include "aidge/backend/cuda/operator/MaxPoolingImpl.hpp"
 #include "aidge/backend/cuda/operator/PadImpl.hpp"
 #include "aidge/backend/cuda/operator/ReLUImpl.hpp"
+#include "aidge/backend/cuda/operator/ShiftMaxImpl.hpp"
+#include "aidge/backend/cuda/operator/ShiftGELUImpl.hpp"
 #include "aidge/backend/cuda/operator/ReshapeImpl.hpp"
 #include "aidge/backend/cuda/operator/SigmoidImpl.hpp"
 #include "aidge/backend/cuda/operator/SubImpl.hpp"
diff --git a/include/aidge/backend/cuda/operator/ShiftGELUImpl.hpp b/include/aidge/backend/cuda/operator/ShiftGELUImpl.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..c4c6dc6eb57261dd230c023722a131b8858f5951
--- /dev/null
+++ b/include/aidge/backend/cuda/operator/ShiftGELUImpl.hpp
@@ -0,0 +1,57 @@
+/********************************************************************************
+ * Copyright (c) 2024 Thales
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ * Author: Lucas RAKOTOARIVONY, Thales Research & Technology France
+ * Date: 25.06.2024
+ *
+ ********************************************************************************/
+
+#ifndef AIDGE_BACKEND_CUDA_OPERATOR_SHIFTGELUIMPL_H_
+#define AIDGE_BACKEND_CUDA_OPERATOR_SHIFTGELUIMPL_H_
+
+#include <array>
+#include <memory>
+#include <tuple>
+#include <vector>
+
+#include <cudnn.h>
+
+#include "aidge/backend/OperatorImpl.hpp"
+#include "aidge/operator/ShiftGELU.hpp"
+#include "aidge/utils/Registrar.hpp"
+#include "aidge/utils/Types.h"
+
+#include "aidge/backend/cuda/utils/CudaUtils.hpp"
+
+namespace Aidge {
+class ShiftGELUImpl_cuda : public OperatorImpl {
+private:
+    std::shared_ptr<Tensor> mInputFallback;
+public:
+    ShiftGELUImpl_cuda(const ShiftGELU_Op &op) : OperatorImpl(op, "cuda") {}
+
+    static std::unique_ptr<ShiftGELUImpl_cuda> create(const ShiftGELU_Op &op) {
+        return std::make_unique<ShiftGELUImpl_cuda>(op);
+    }
+
+    void forward();
+
+private:
+    template <class T> void forward_(const Tensor& input);
+};
+
+namespace {
+// add cuda backend to ShiftGELU_Op implementation registry
+static Registrar<ShiftGELU_Op> registrarShiftGELUImpl_cuda("cuda", Aidge::ShiftGELUImpl_cuda::create);
+}  // namespace
+}  // namespace Aidge
+
+#endif /* AIDGE_BACKEND_CUDA_OPERATOR_SHIFTGELUIMPL_H_ */
\ No newline at end of file
diff --git a/include/aidge/backend/cuda/operator/ShiftGELUImpl_CUDA_kernels.hpp b/include/aidge/backend/cuda/operator/ShiftGELUImpl_CUDA_kernels.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..cc259366fd3d06744e6f22b13d7ef651cb535d0a
--- /dev/null
+++ b/include/aidge/backend/cuda/operator/ShiftGELUImpl_CUDA_kernels.hpp
@@ -0,0 +1,34 @@
+/********************************************************************************
+ * Copyright (c) 2024 Thales
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ * Author: Lucas RAKOTOARIVONY, Thales Research & Technology France
+ * Date: 25.06.2024
+ *
+ ********************************************************************************/
+
+#ifndef AIDGE_CUDA_OPERATOR_SHIFTGELUIMPL_FORWARD_KERNEL_H_
+#define AIDGE_CUDA_OPERATOR_SHIFTGELUIMPL_FORWARD_KERNEL_H_
+
+#include <cfloat>
+#include <cstddef>    // std::size_t
+#include <stdexcept>
+#include <vector>     // std::vector in the launcher signature
+
+#include <cuda.h>
+#include <cuda_runtime_api.h>
+#include <cuda_fp16.h>
+
+#include "aidge/data/Data.hpp"
+#include "aidge/backend/cuda/utils/CudaUtils.hpp"
+
+namespace Aidge {
+
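+// Host-side launcher: allocates the device buffers, uploads `input` (a
+// host-accessible buffer of `size` floats), runs the quantized GELU kernel
+// and copies the result back into `output`. SF is the input scaling factor,
+// N the precision of the arithmetic exponential, output_bits the target
+// bit-width (e.g. 8 for int8) and dims_input the 4D shape of the tensor.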
+extern void ShiftGELULaunchKernel(const float* input, float* output, double SF, int N, int output_bits, size_t size, std::vector<long unsigned int> dims_input);
+
+template <class T>
+__global__ void ShiftGELUWholeKernel(T* input, int* quantized_tensor, int* GELUtensor, int* SumTensor, int* dims, double SF, int N, int output_bits);
+
+}  // namespace Aidge
+
+#endif /* AIDGE_CUDA_OPERATOR_SHIFTGELUIMPL_FORWARD_KERNEL_H_ */
\ No newline at end of file
diff --git a/include/aidge/backend/cuda/operator/ShiftMaxImpl.hpp b/include/aidge/backend/cuda/operator/ShiftMaxImpl.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..8d72ba0b15cb3d9a91eedab2c2eab1758d0ee00f
--- /dev/null
+++ b/include/aidge/backend/cuda/operator/ShiftMaxImpl.hpp
@@ -0,0 +1,57 @@
+/********************************************************************************
+ * Copyright (c) 2024 Thales
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ * Author: Lucas RAKOTOARIVONY, Thales Research & Technology France
+ * Date: 25.06.2024
+ *
+ ********************************************************************************/
+
+#ifndef AIDGE_BACKEND_CUDA_OPERATOR_SHIFTMAXIMPL_H_
+#define AIDGE_BACKEND_CUDA_OPERATOR_SHIFTMAXIMPL_H_
+
+#include <array>
+#include <memory>
+#include <tuple>
+#include <vector>
+
+#include <cudnn.h>
+
+#include "aidge/backend/OperatorImpl.hpp"
+#include "aidge/operator/ShiftMax.hpp"
+#include "aidge/utils/Registrar.hpp"
+#include "aidge/utils/Types.h"
+
+#include "aidge/backend/cuda/utils/CudaUtils.hpp"
+
+namespace Aidge {
+class ShiftMaxImpl_cuda : public OperatorImpl {
+private:
+    std::shared_ptr<Tensor> mInputFallback;
+public:
+    ShiftMaxImpl_cuda(const ShiftMax_Op &op) : OperatorImpl(op, "cuda") {}
+
+    static std::unique_ptr<ShiftMaxImpl_cuda> create(const ShiftMax_Op &op) {
+        return std::make_unique<ShiftMaxImpl_cuda>(op);
+    }
+
+    void forward();
+
+private:
+    template <class T> void forward_(const Tensor& input);
+};
+
+namespace {
+// add cuda backend to ShiftMax_Op implementation registry
+static Registrar<ShiftMax_Op> registrarShiftMaxImpl_cuda("cuda", Aidge::ShiftMaxImpl_cuda::create);
+}  // namespace
+}  // namespace Aidge
+
+#endif /* AIDGE_BACKEND_CUDA_OPERATOR_SHIFTMAXIMPL_H_ */
diff --git a/include/aidge/backend/cuda/operator/ShiftMaxImpl_CUDA_kernels.hpp b/include/aidge/backend/cuda/operator/ShiftMaxImpl_CUDA_kernels.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..e6f5205c0287c039fedfa88ff05e934c33873b8a
--- /dev/null
+++ b/include/aidge/backend/cuda/operator/ShiftMaxImpl_CUDA_kernels.hpp
@@ -0,0 +1,34 @@
+/********************************************************************************
+ * Copyright (c) 2024 Thales
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ * Author: Lucas RAKOTOARIVONY, Thales Research & Technology France
+ * Date: 25.06.2024
+ *
+ ********************************************************************************/
+
+#ifndef AIDGE_CUDA_OPERATOR_SHIFTMAXIMPL_FORWARD_KERNEL_H_
+#define AIDGE_CUDA_OPERATOR_SHIFTMAXIMPL_FORWARD_KERNEL_H_
+
+#include <cfloat>
+#include <cstddef>    // std::size_t
+#include <stdexcept>
+#include <vector>     // std::vector in the launcher signature
+
+#include <cuda.h>
+#include <cuda_runtime_api.h>
+#include <cuda_fp16.h>
+
+#include "aidge/data/Data.hpp"
+#include "aidge/backend/cuda/utils/CudaUtils.hpp"
+
+namespace Aidge {
+
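+// Host-side launcher: allocates the device buffers, uploads `input` (a
+// host-accessible buffer of `size` floats), runs the quantized softmax
+// kernel and copies the result back into `output`. SF is the input scaling
+// factor, N the precision of the arithmetic softmax, output_bits the target
+// bit-width (e.g. 8 for int8) and dims_input the 4D shape of the tensor.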
+extern void ShiftMaxLaunchKernel(const float* input, float* output, double SF, int N, int output_bits, size_t size, std::vector<long unsigned int> dims_input);
+
+template <class T>
+__global__ void ShiftMaxWholeKernel(T* input, int* quantized_tensor, int* factor, int* dims, double SF, int N, int output_bits, double new_SF);
+
+}  // namespace Aidge
+
+#endif /* AIDGE_CUDA_OPERATOR_SHIFTMAXIMPL_FORWARD_KERNEL_H_ */
\ No newline at end of file
diff --git a/src/operator/ShiftGELUImpl.cpp b/src/operator/ShiftGELUImpl.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..779fbb9175893f80dfa9110ee449e281fe3d5ca6
--- /dev/null
+++ b/src/operator/ShiftGELUImpl.cpp
@@ -0,0 +1,96 @@
+/********************************************************************************
+ * Copyright (c) 2024 Thales
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ * Author: Lucas RAKOTOARIVONY, Thales Research & Technology France
+ * Date: 25.06.2024
+ *
+ ********************************************************************************/
+
+#include <algorithm>  // std::max, std::min
+#include <cassert>
+#include <cmath>      // std::abs
+#include <limits>     // std::numeric_limits
+#include <vector>
+
+#include "aidge/backend/cuda/data/TensorImpl.hpp"
+#include "aidge/backend/cuda/operator/ShiftGELUImpl.hpp"
+#include "aidge/backend/cuda/operator/ShiftGELUImpl_CUDA_kernels.hpp"
+#include "aidge/backend/cuda/utils/CudaContext.hpp"
+#include "aidge/backend/cuda/utils/CudaUtils.hpp"
+#include "aidge/operator/ShiftGELU.hpp"
+#include "aidge/utils/Types.h"
+
+void Aidge::ShiftGELUImpl_cuda::forward() {
+
+    const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
+    assert(mOp.getRawInput(0) && "missing input #0");
+    const auto& input = op.getInput(0)->refCastFrom(mInputFallback, *op.getOutput(0));
+
+    // TODO: proper templating; for now every supported data type is handled
+    // by the float kernel
+    switch(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()) {
+        case DataType::Float64:
+        case DataType::Float32:
+        case DataType::Float16:
+            forward_<float>(input);
+            break;
+        default:
+            AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by Backend Cuda");
+    }
+}
+
+template<class T>
+void Aidge::ShiftGELUImpl_cuda::forward_(const Tensor& input)
+{
+    const T* input_raw = static_cast<const T*>(input.getImpl()->rawPtr());
+
+    const int N = 15;           // precision of the arithmetic exponential
+    const int output_bits = 8;  // target quantization bit-width
+
+    const size_t size = input.size();
+    std::vector<DimSize_t> dims_input = input.dims();
+
+    // Scan the input values for their extrema to derive the quantization
+    // range (the buffer is assumed host-accessible, as the launcher below
+    // also assumes when it copies it host-to-device)
+    double min = std::numeric_limits<double>::max();
+    double max = std::numeric_limits<double>::lowest();
+    for (std::size_t i = 0; i < dims_input[0]; i++) {
+        for (std::size_t j = 0; j < dims_input[1]; j++) {
+            for (std::size_t k = 0; k < dims_input[2]; k++) {
+                for (std::size_t l = 0; l < dims_input[3]; l++) {
+                    const std::size_t flatIdx = input.getIdx({i, j, k, l});
+                    const double value = static_cast<double>(input_raw[flatIdx]);
+                    min = std::min(min, value);
+                    max = std::max(max, value);
+                }
+            }
+        }
+    }
+
+    const double m = std::max(std::abs(min), std::abs(max));
+
+    // Map the largest input magnitude onto the signed integer range of
+    // output_bits to obtain the quantization scaling factor
+    const double normalization_factor = static_cast<double>(1 << (output_bits - 1)) - 1;
+    const double final_sf = m / normalization_factor;
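+
+    // Worked example, with the defaults above: output_bits = 8 gives
+    // normalization_factor = 2^7 - 1 = 127, so an input whose largest
+    // magnitude is m = 1.0 gets final_sf = 1/127 ~= 0.00787 and
+    // roundf(x / final_sf) spans the signed int8 range [-127, 127].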
+    T* output = static_cast<T*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr());
+    ShiftGELULaunchKernel(input_raw, output, final_sf, N, output_bits, size, dims_input);
+}
\ No newline at end of file
diff --git a/src/operator/ShiftGELUImpl_CUDA_kernels.cu b/src/operator/ShiftGELUImpl_CUDA_kernels.cu
new file mode 100644
index 0000000000000000000000000000000000000000..c51af4e51c1270bbe0b5c7cfe658a9d84614c58a
--- /dev/null
+++ b/src/operator/ShiftGELUImpl_CUDA_kernels.cu
@@ -0,0 +1,147 @@
+/********************************************************************************
+ * Copyright (c) 2024 Thales
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ * Author: Lucas RAKOTOARIVONY, Thales Research & Technology France
+ * Date: 25.06.2024
+ *
+ ********************************************************************************/
+#include <algorithm>  // std::min
+#include <climits>    // INT_MAX
+#include <cstdio>     // printf
+
+#include <cuda_runtime.h>
+
+#include "aidge/backend/cuda/operator/ShiftGELUImpl_CUDA_kernels.hpp"
+
+#define MAX(X,Y) (((X) > (Y)) ? (X) : (Y))
+#define CLAMP(X) (((X) < (0)) ? (0) : (X))
+
+__device__ inline int ExpShift(int I, int N, double SF)
+{
+    int Ip = I + (I >> 1) - (I >> 4);
+    int I0 = floorf(-1.0 / SF);
+    Ip = MAX(Ip, N * I0);
+    int q = floorf(Ip / (I0));
+    int r = Ip - (I0 * q);
+    int Ib = r / 2 - I0;
+    Ib = CLAMP(Ib * powf(2, N - q)); // 2^(N-q); a true bit-shift would avoid the float round-trip
+    return (int)Ib;
+}
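+
+/*
+ * Reading of the integer-only exponential above (an interpretation; the
+ * construction matches the I-BERT-style "ShiftExp"):
+ *   - Ip = I + (I >> 1) - (I >> 4) scales I by 1.4375 ~= 1/ln(2), turning
+ *     exp(I * SF) into 2^(Ip * SF);
+ *   - I0 = floor(-1/SF) is the value -1.0 expressed in quantized units, so
+ *     q = floor(Ip / I0) and the remainder r split the base-2 exponent into
+ *     its integer and fractional parts;
+ *   - Ib = r/2 - I0 linearly approximates 2^r on the remainder (r = I0 gives
+ *     0.5/SF, r = 0 gives 1/SF), and the final 2^(N-q) factor applies the
+ *     integer part of the exponent at N bits of precision.
+ */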
+
+namespace Aidge{
+
+template <class T>
+__global__ void ShiftGELUWholeKernel(T* input,int* quantized_tensor,int* GELUtensor,int* SumTensor, int* dims, double SF, int N, int output_bits) {
+    /*
+ * Kernels du Forward de GeLU
+ * Input => Tenseur représentant l'entrée (non quantifiée (flottant)) (pointeur vers le bloc de mémoire de type T)
+ * quantized_tensor => pointeur vers un bloc mémoire vide alloué sur le GPU
+ * geLUTensor => pointeur vers un bloc mémoire vide alloué sur le GPU
+ * SumTensor => pointeur vers un bloc mémoire vide alloué sur le GPU
+ * dims => int[4] sous forme de pointeur qui représente les 4 dimensions du tenseurs
+ * SF => Scaling Factor
+ * N => precision du Softmax arithmétique (plus N est grand plus l'opération est précise mais plus elle nécessite un nombre de bit elevé)
+ * output_bits => précision en bit souhaité (8 pour int8 par exemple)
+ */
+    int x = blockIdx.x * blockDim.x + threadIdx.x; // Dim 1
+    int y = blockIdx.y * blockDim.y + threadIdx.y; // Dim 2
+    int z = blockIdx.z * blockDim.z + threadIdx.z; // Dim 3
+
+    double SF_sig = SF * 1.702;  // SF times the 1.702 constant of the sigmoid-based GELU approximation
+    double Final_SF = SF / powf(2, (output_bits - 1));
+
+    if (x < dims[0] && y < dims[1] && z < dims[2]) {
+        int maxIdx = x * dims[1] * dims[2] * dims[3] + y * dims[2] * dims[3] + z * dims[3];
+        for (int i = 0; i < dims[3]; i++) { // Quantization (one thread per innermost row of the tensor)
+            int idx = maxIdx + i;
+            quantized_tensor[idx] = roundf(input[idx] / SF);
+        }
+        int maxVal = quantized_tensor[maxIdx];
+        for (int i = 1; i < dims[3]; i++) { // Row maximum along the last dimension
+            int idx = maxIdx + i;
+            maxVal = MAX(maxVal, quantized_tensor[idx]);
+        }
+        int Max_Exp = ExpShift(-maxVal, N, SF_sig);
+        for (int i = 0; i < dims[3]; i++) { // Exponential (arithmetic)
+            int idx = maxIdx + i;
+            GELUtensor[idx] = ExpShift(quantized_tensor[idx] - maxVal, N, SF_sig);
+            if (GELUtensor[idx] > INT_MAX - Max_Exp) {
+                SumTensor[idx] = 1;
+            }
+            else {
+                SumTensor[idx] = floorf(INT_MAX / (GELUtensor[idx] + Max_Exp));
+            }
+            // Integer sigmoid, then multiply by the quantized input
+            SumTensor[idx] = (GELUtensor[idx] * SumTensor[idx]) >> (31 - output_bits + 1);
+            quantized_tensor[idx] *= SumTensor[idx];
+            input[idx] = quantized_tensor[idx] * Final_SF;
+        }
+    }
+}
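+
+/*
+ * Overall flow above: quantize each row, subtract the row maximum for
+ * numerical stability, evaluate an integer sigmoid via ExpShift
+ * (exp(q - max) / (exp(q - max) + exp(-max)) = sigmoid(q)), then multiply by
+ * the quantized input, matching the GELU approximation x * sigmoid(1.702 * x).
+ */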
+
+// TODO: template on the data type
+void ShiftGELULaunchKernel(const float* input, float* output, double SF, int N, int output_bits, size_t size, std::vector<long unsigned int> dims_input) {
+
+    // The kernel expects a 4D tensor: copy the first four dimensions,
+    // padding with 1 when fewer are available
+    int dims_input_cuda[4] = {1, 1, 1, 1};
+    for (std::size_t i = 0; i < std::min(dims_input.size(), std::size_t(4)); ++i) {
+        dims_input_cuda[i] = static_cast<int>(dims_input[i]);
+    }
+
+    // Allocate the device buffers and upload the input
+    float* input_cuda_tensor;
+    cudaMalloc(&input_cuda_tensor, size * sizeof(float));
+    cudaMemcpy(input_cuda_tensor, input, size * sizeof(float), cudaMemcpyHostToDevice);
+
+    int* quantized_tensor;
+    cudaMalloc(&quantized_tensor, size * sizeof(int));
+
+    int* GELUtensor;
+    cudaMalloc(&GELUtensor, size * sizeof(int));
+
+    int* SumTensor;
+    cudaMalloc(&SumTensor, size * sizeof(int));
+
+    int* dims;
+    cudaMalloc(&dims, 4 * sizeof(int));
+    cudaMemcpy(dims, dims_input_cuda, 4 * sizeof(int), cudaMemcpyHostToDevice);
+
+    // One thread per (dim0, dim1, dim2) coordinate; each thread iterates over the last dimension
+    dim3 threadsPerBlock(10, 10, 10);
+    dim3 numBlocks((dims_input[0] + threadsPerBlock.x - 1) / threadsPerBlock.x,
+                   (dims_input[1] + threadsPerBlock.y - 1) / threadsPerBlock.y,
+                   (dims_input[2] + threadsPerBlock.z - 1) / threadsPerBlock.z);
+
+    ShiftGELUWholeKernel<float><<<numBlocks, threadsPerBlock>>>(input_cuda_tensor, quantized_tensor, GELUtensor, SumTensor, dims, SF, N, output_bits);
+    cudaDeviceSynchronize(); // Block until the kernel finishes; without this the CPU would race ahead of the GPU
+
+    cudaError_t err = cudaGetLastError();
+    if (err != cudaSuccess) {
+        printf("CUDA error: %s\n", cudaGetErrorString(err));
+    }
+
+    cudaMemcpy(output, input_cuda_tensor, size * sizeof(float), cudaMemcpyDeviceToHost);
+
+    // Free every device buffer allocated above
+    cudaFree(quantized_tensor);
+    cudaFree(GELUtensor);
+    cudaFree(SumTensor);
+    cudaFree(dims);
+    cudaFree(input_cuda_tensor);
+}
+
+}  // namespace Aidge
\ No newline at end of file
diff --git a/src/operator/ShiftMaxImpl.cpp b/src/operator/ShiftMaxImpl.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..470bac840969eb114d7b7e34e51fb9e733d41ae9
--- /dev/null
+++ b/src/operator/ShiftMaxImpl.cpp
@@ -0,0 +1,96 @@
+/********************************************************************************
+ * Copyright (c) 2024 Thales
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ * Author: Lucas RAKOTOARIVONY, Thales Research & Technology France
+ * Date: 25.06.2024
+ *
+ ********************************************************************************/
+
+#include <algorithm>  // std::max, std::min
+#include <cassert>
+#include <cmath>      // std::abs
+#include <limits>     // std::numeric_limits
+#include <vector>
+
+#include "aidge/backend/cuda/data/TensorImpl.hpp"
+#include "aidge/backend/cuda/operator/ShiftMaxImpl.hpp"
+#include "aidge/backend/cuda/operator/ShiftMaxImpl_CUDA_kernels.hpp"
+#include "aidge/backend/cuda/utils/CudaContext.hpp"
+#include "aidge/backend/cuda/utils/CudaUtils.hpp"
+#include "aidge/operator/ShiftMax.hpp"
+#include "aidge/utils/Types.h"
+
+void Aidge::ShiftMaxImpl_cuda::forward() {
+
+    const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
+    assert(mOp.getRawInput(0) && "missing input #0");
+    const auto& input = op.getInput(0)->refCastFrom(mInputFallback, *op.getOutput(0));
+
+    // TODO: proper templating; for now every supported data type is handled
+    // by the float kernel
+    switch(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()) {
+        case DataType::Float64:
+        case DataType::Float32:
+        case DataType::Float16:
+            forward_<float>(input);
+            break;
+        default:
+            AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by Backend Cuda");
+    }
+}
+
+template<class T>
+void Aidge::ShiftMaxImpl_cuda::forward_(const Tensor& input)
+{
+    const T* input_raw = static_cast<const T*>(input.getImpl()->rawPtr());
+
+    const int N = 15;           // precision of the arithmetic softmax
+    const int output_bits = 8;  // target quantization bit-width
+
+    const size_t size = input.size();
+    std::vector<DimSize_t> dims_input = input.dims();
+
+    // Scan the input values for their extrema to derive the quantization
+    // range (the buffer is assumed host-accessible, as the launcher below
+    // also assumes when it copies it host-to-device)
+    double min = std::numeric_limits<double>::max();
+    double max = std::numeric_limits<double>::lowest();
+    for (std::size_t i = 0; i < dims_input[0]; i++) {
+        for (std::size_t j = 0; j < dims_input[1]; j++) {
+            for (std::size_t k = 0; k < dims_input[2]; k++) {
+                for (std::size_t l = 0; l < dims_input[3]; l++) {
+                    const std::size_t flatIdx = input.getIdx({i, j, k, l});
+                    const double value = static_cast<double>(input_raw[flatIdx]);
+                    min = std::min(min, value);
+                    max = std::max(max, value);
+                }
+            }
+        }
+    }
+
+    const double m = std::max(std::abs(min), std::abs(max));
+
+    // Map the largest input magnitude onto the signed integer range of
+    // output_bits to obtain the quantization scaling factor
+    const double normalization_factor = static_cast<double>(1 << (output_bits - 1)) - 1;
+    const double final_sf = m / normalization_factor;
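+
+    // e.g. with output_bits = 8: normalization_factor = 2^7 - 1 = 127, so
+    // final_sf maps the largest input magnitude onto the int8 range.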
+    T* output = static_cast<T*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr());
+    ShiftMaxLaunchKernel(input_raw, output, final_sf, N, output_bits, size, dims_input);
+}
\ No newline at end of file
diff --git a/src/operator/ShiftMaxImpl_CUDA_kernels.cu b/src/operator/ShiftMaxImpl_CUDA_kernels.cu
new file mode 100644
index 0000000000000000000000000000000000000000..c4c619cdde1a1ee123f19a3d75f3f14bc542bc52
--- /dev/null
+++ b/src/operator/ShiftMaxImpl_CUDA_kernels.cu
@@ -0,0 +1,152 @@
+/********************************************************************************
+ * Copyright (c) 2024 Thales
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ * Author: Lucas RAKOTOARIVONY, Thales Research & Technology France
+ * Date: 25.06.2024
+ *
+ ********************************************************************************/
+#include <algorithm>  // std::min
+#include <climits>    // INT_MAX
+#include <cmath>      // std::pow
+#include <cstdio>     // printf
+
+#include <cuda_runtime.h>
+
+#include "aidge/backend/cuda/operator/ShiftMaxImpl_CUDA_kernels.hpp"
+
+#define MAX(X,Y) (((X) > (Y)) ? (X) : (Y))
+#define CLAMP(X) (((X) < (0)) ? (0) : (X))
+
+__device__ inline int ExpShift(int I, int N, double SF)
+{
+    int Ip = I + (I >> 1) - (I >> 4);
+    int I0 = floorf(-1.0 / SF);
+    Ip = MAX(Ip, N * I0);
+    int q = floorf(Ip / (I0));
+    int r = Ip - (I0 * q);
+    int Ib = r / 2 - I0;
+    Ib = CLAMP(Ib * powf(2, N - q)); // 2^(N-q); a true bit-shift would avoid the float round-trip
+    return (int)Ib;
+}
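+
+// Same integer-only exponential as in ShiftGELUImpl_CUDA_kernels.cu; see the
+// comment there for a sketch of the underlying 2^x decomposition. The helper
+// is duplicated because __device__ functions are resolved per translation unit.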
+
+namespace Aidge{
+
+template <class T>
+__global__ void ShiftMaxWholeKernel(T* input, int* quantized_tensor, int* factor, int* dims, double SF, int N, int output_bits, double new_SF)
+/*
+ * ShiftMax forward kernel
+ * input            => input tensor (not quantized, floating point); also receives the output
+ * quantized_tensor => pointer to an empty buffer allocated on the GPU
+ * factor           => pointer to an empty buffer allocated on the GPU
+ * dims             => pointer to an int[4] holding the four tensor dimensions
+ * SF               => scaling factor of the quantization
+ * N                => precision of the arithmetic softmax (larger N is more accurate but needs more bits)
+ * output_bits      => target bit-width (e.g. 8 for int8)
+ * new_SF           => new scaling factor used to dequantize the output tensor
+ */
+{
+    int x = blockIdx.x * blockDim.x + threadIdx.x; // Dim 1
+    int y = blockIdx.y * blockDim.y + threadIdx.y; // Dim 2
+    int z = blockIdx.z * blockDim.z + threadIdx.z; // Dim 3
+    int sum = 0;
+    /*
+     * x, y and z index dimensions 1, 2 and 3 of the tensor; every (x, y, z)
+     * combination runs in parallel, which provides the GPU speedup. The
+     * "for" loops below iterate over the last dimension.
+     */
+    if (x < dims[0] && y < dims[1] && z < dims[2]) {
+        int maxIdx = x * dims[1] * dims[2] * dims[3] + y * dims[2] * dims[3] + z * dims[3];
+        for (int i = 0; i < dims[3]; i++) { // Quantization (one thread per innermost row of the tensor)
+            int idx = maxIdx + i;
+            quantized_tensor[idx] = roundf(input[idx] / SF);
+        }
+        int maxVal = quantized_tensor[maxIdx];
+        for (int i = 1; i < dims[3]; i++) { // Row maximum along the last dimension
+            int idx = maxIdx + i;
+            maxVal = MAX(maxVal, quantized_tensor[idx]);
+        }
+        for (int i = 0; i < dims[3]; i++) { // Exponential (arithmetic)
+            int idx = maxIdx + i;
+            quantized_tensor[idx] = ExpShift(quantized_tensor[idx] - maxVal, N, SF);
+        }
+        for (int i = 0; i < dims[3]; i++) { // Sum, clamped to INT_MAX on overflow
+            int idx = maxIdx + i;
+            if (quantized_tensor[idx] > 0 && sum > INT_MAX - quantized_tensor[idx]) {
+                sum = INT_MAX;
+                break;
+            }
+            else {
+                sum += quantized_tensor[idx];
+            }
+        }
+        factor[x * dims[1] * dims[2] + y * dims[2] + z] = floorf(INT_MAX / sum);
+        for (int i = 0; i < dims[3]; ++i) { // Bit-shift back to output_bits precision
+            int idx = maxIdx + i;
+            quantized_tensor[idx] = (quantized_tensor[idx] * factor[x * dims[1] * dims[2] + y * dims[2] + z]) >> (31 - (2 * output_bits - 1));
+            input[idx] = quantized_tensor[idx] * new_SF;
+        }
+    }
+}
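+
+/*
+ * Overall flow above: quantize each row, subtract the row maximum, take the
+ * integer exponential, accumulate the row sum with overflow clamping, then
+ * normalize through the INT_MAX/sum fixed-point reciprocal: an integer-only
+ * softmax along the last dimension.
+ */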
+
+// TODO: template on the data type
+void ShiftMaxLaunchKernel(const float* input, float* output, double SF, int N, int output_bits, size_t size, std::vector<long unsigned int> dims_input) {
+
+    // Scaling factor of the quantized tensor returned by shiftmax, used to
+    // dequantize the output
+    double new_SF = 1 / std::pow(2, 2 * output_bits - 1);
+
+    // The kernel expects a 4D tensor: copy the first four dimensions,
+    // padding with 1 when fewer are available
+    int dims_input_cuda[4] = {1, 1, 1, 1};
+    for (std::size_t i = 0; i < std::min(dims_input.size(), std::size_t(4)); ++i) {
+        dims_input_cuda[i] = static_cast<int>(dims_input[i]);
+    }
+
+    // Allocate the device buffers and upload the input
+    float* input_cuda_tensor;
+    cudaMalloc(&input_cuda_tensor, size * sizeof(float));
+    cudaMemcpy(input_cuda_tensor, input, size * sizeof(float), cudaMemcpyHostToDevice);
+
+    int* quantized_tensor;
+    cudaMalloc(&quantized_tensor, size * sizeof(int));
+
+    int* factor;
+    cudaMalloc(&factor, size * sizeof(int));
+
+    int* dims;
+    cudaMalloc(&dims, 4 * sizeof(int));
+    cudaMemcpy(dims, dims_input_cuda, 4 * sizeof(int), cudaMemcpyHostToDevice);
+
+    // One thread per (dim0, dim1, dim2) coordinate; each thread iterates over the last dimension
+    dim3 threadsPerBlock(10, 10, 10);
+    dim3 numBlocks((dims_input[0] + threadsPerBlock.x - 1) / threadsPerBlock.x,
+                   (dims_input[1] + threadsPerBlock.y - 1) / threadsPerBlock.y,
+                   (dims_input[2] + threadsPerBlock.z - 1) / threadsPerBlock.z);
+
+    ShiftMaxWholeKernel<float><<<numBlocks, threadsPerBlock>>>(input_cuda_tensor, quantized_tensor, factor, dims, SF, N, output_bits, new_SF);
+    cudaDeviceSynchronize(); // Block until the kernel finishes; without this the CPU would race ahead of the GPU
+
+    cudaError_t err = cudaGetLastError();
+    if (err != cudaSuccess) {
+        printf("CUDA error: %s\n", cudaGetErrorString(err));
+    }
+
+    cudaMemcpy(output, input_cuda_tensor, size * sizeof(float), cudaMemcpyDeviceToHost);
+
+    // Free every device buffer allocated above
+    cudaFree(quantized_tensor);
+    cudaFree(factor);
+    cudaFree(dims);
+    cudaFree(input_cuda_tensor);
+}
+
+}  // namespace Aidge
\ No newline at end of file
diff --git a/unit_tests/Test_ShiftGELUImpl.cpp b/unit_tests/Test_ShiftGELUImpl.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..d5382c29cef00587eaed3dcd789352c4e8263d31
--- /dev/null
+++ b/unit_tests/Test_ShiftGELUImpl.cpp
@@ -0,0 +1,131 @@
+/********************************************************************************
+ * Copyright (c) 2024 Thales
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ * Author: Lucas RAKOTOARIVONY, Thales Research & Technology France
+ * Date: 25.06.2024
+ *
+ ********************************************************************************/
+
+#include <array>
+
+#include <catch2/catch_test_macros.hpp>
+
+#include "Test_cuda.hpp"
+
+#include "aidge/data/Tensor.hpp"
+
+#include "aidge/backend/cpu.hpp"
+#include "aidge/backend/cuda.hpp"
+
+using namespace Aidge;
+
+TEST_CASE("[gpu/operator] ShiftGELU(forward)", "[ShiftGELU][GPU]") {
+    SECTION("4D Tensor") {
+        std::shared_ptr<Tensor> input0 = std::make_shared<Tensor>(Array4D<float,2,2,2,10> {
+            {
+                {
+                    {
+                        {0.96, 0.48, 0.54, 0.49, 0.59, 0.93, 0.00, 0.00, 0.61, 0.61},
+                        {0.85, 0.06, 0.11, 0.87, 0.55, 0.12, 0.80, 0.48, 0.41, 0.16}
+                    },
+                    {
+                        {0.24, 0.46, 0.97, 0.19, 0.65, 0.12, 0.44, 1.00, 0.37, 0.09},
+                        {0.44, 0.64, 0.21, 0.58, 0.05, 0.24, 0.56, 0.07, 0.49, 0.79}
+                    }
+                },
+                {
+                    {
+                        {0.00, 0.13, 0.55, 0.42, 0.49, 0.28, 0.52, 0.55, 0.34, 0.85},
+                        {0.98, 0.32, 0.09, 0.05, 0.37, 0.47, 0.63, 0.13, 0.70, 0.02}
+                    },
+                    {
+                        {0.69, 0.13, 0.74, 0.61, 0.25, 0.87, 0.46, 0.40, 0.81, 0.06},
+                        {0.89, 0.32, 0.61, 0.24, 0.70, 0.23, 0.09, 0.03, 0.14, 0.80}
+                    }
+                }
+            }
+        });
+
+        std::shared_ptr<Tensor> output_shiftGELU = std::make_shared<Tensor>(Array4D<float,2,2,2,10> {
+            {
+                {
+                    {
+                        { 0.991388f, 0.413078f, 0.413078f, 0.413078f, 0.413078f, 0.413078f, 0.0f, 0.0f, 0.413078f, 0.413078f },
+                        { 0.413078f, 0.0f, 0.0f, 0.413078f, 0.413078f, 0.0f, 0.413078f, 0.413078f, 0.413078f, 0.0f }
+                    },
+                    {
+                        { 0.0f, 0.413078f, 0.991388f, 0.0f, 0.413078f, 0.0f, 0.413078f, 0.991388f, 0.413078f, 0.0f },
+                        { 0.413078f, 0.413078f, 0.0f, 0.413078f, 0.0f, 0.0f, 0.413078f, 0.0f, 0.413078f, 0.413078f }
+                    }
+                },
+                {
+                    {
+                        { 0.0f, 0.0f, 0.413078f, 0.413078f, 0.413078f, 0.0f, 0.413078f, 0.413078f, 0.413078f, 0.413078f },
+                        { 0.991388f, 0.413078f, 0.0f, 0.0f, 0.413078f, 0.413078f, 0.413078f, 0.0f, 0.413078f, 0.0f}
+                    },
+                    {
+                        { 0.413078f, 0.0f, 0.413078f, 0.413078f, 0.0f, 0.413078f, 0.413078f, 0.413078f, 0.413078f, 0.0f },
+                        { 0.413078f, 0.413078f, 0.413078f, 0.0f, 0.413078f, 0.0f, 0.0f, 0.0f, 0.0f, 0.413078f }
+                    }
+                }
+            }
+        });
+
+        std::shared_ptr<Tensor> output_GELU = std::make_shared<Tensor>(Array4D<float, 2, 2, 2, 10> {
+            {
+                {
+                    {
+                        { 0.7982f, 0.3285f, 0.3809f, 0.3371f, 0.4262f, 0.7661f, 0.0000f, 0.0000f, 0.4447f, 0.4447f },
+                        { 0.6820f, 0.0314f, 0.0598f, 0.7028f, 0.3899f, 0.0657f, 0.6305f, 0.3285f, 0.2702f, 0.0902f }
+                    },
+                    {
+                        { 0.1428f, 0.3115f, 0.8090f, 0.1093f, 0.4824f, 0.0657f, 0.2948f, 0.8413f, 0.2384f, 0.0482f },
+                        { 0.2948f, 0.4729f, 0.1225f, 0.4170f, 0.0260f, 0.1428f, 0.3989f, 0.0370f, 0.3371f, 0.6203f }
+                    }
+                },
+                {
+                    {
+                        { 0.0000f, 0.0717f, 0.3899f, 0.2784f, 0.3371f, 0.1709f, 0.3632f, 0.3899f, 0.2152f, 0.6820f },
+                        { 0.8197f, 0.2002f, 0.0482f, 0.0260f, 0.2384f, 0.3200f, 0.4635f, 0.0717f, 0.5306f, 0.0102f }
+                    },
+                    {
+                        { 0.5209f, 0.0717f, 0.5701f, 0.4447f, 0.1497f, 0.7028f, 0.3115f, 0.2622f, 0.6407f, 0.0314f },
+                        { 0.7238f, 0.2002f, 0.4447f, 0.1428f, 0.5306f, 0.1359f, 0.0482f, 0.0154f, 0.0778f, 0.6305f }
+                    }
+                }
+            }
+        }); // reference values produced by torch.nn.GELU
+
+        std::shared_ptr<Node> myShiftGELU = ShiftGELU();
+        auto op = std::static_pointer_cast<OperatorTensor>(myShiftGELU -> getOperator());
+        op->associateInput(0,input0);
+        op->setDataType(DataType::Float32);
+        op->setBackend("cuda");
+        op->forward();
+
+        float* computedOutput = new float[output_shiftGELU->size()]();
+        cudaMemcpy(computedOutput, op->getOutput(0)->getImpl()->rawPtr(), sizeof(float) * output_shiftGELU->size(), cudaMemcpyDeviceToHost);
+
+        for (std::size_t i = 0; i < output_shiftGELU->size(); i++) {
+            const float targetOutput = *(static_cast<float*>(output_shiftGELU->getImpl()->rawPtr()) + i);
+            REQUIRE(fabs(computedOutput[i] - targetOutput) < 1e-6);
+        }
+
+        // Mean absolute error against the floating-point GELU reference
+        float sum = 0.0;
+        for (std::size_t i = 0; i < output_GELU->size(); i++) {
+            const float targetOutput = *(static_cast<float*>(output_GELU->getImpl()->rawPtr()) + i);
+            sum += fabs(computedOutput[i] - targetOutput);
+        }
+        sum = sum / output_GELU->size();
+        REQUIRE(sum < 1.5e-1);
+
+        delete[] computedOutput;
+    }
+
+}
\ No newline at end of file
diff --git a/unit_tests/Test_ShiftMaxImpl.cpp b/unit_tests/Test_ShiftMaxImpl.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..a9f8b2a665d9b77f8ba6ce7c4d251f9b6c0da166
--- /dev/null
+++ b/unit_tests/Test_ShiftMaxImpl.cpp
@@ -0,0 +1,128 @@
+/********************************************************************************
+ * Copyright (c) 2024 Thales
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ * Author: Lucas RAKOTOARIVONY, Thales Research & Technology France
+ * Date: 25.06.2024
+ *
+ ********************************************************************************/
+
+#include <array>
+
+#include <catch2/catch_test_macros.hpp>
+
+#include "Test_cuda.hpp"
+
+#include "aidge/data/Tensor.hpp"
+
+#include "aidge/backend/cpu.hpp"
+#include "aidge/backend/cuda.hpp"
+
+using namespace Aidge;
+
+TEST_CASE("[gpu/operator] ShiftMax(forward)", "[ShiftMax][GPU]") {
+    SECTION("4D Tensor") {
+        std::shared_ptr<Tensor> input0 = std::make_shared<Tensor>(Array4D<float,2,2,2,10> {
+            {
+                {
+                    {
+                        {0.96, 0.48, 0.54, 0.49, 0.59, 0.93, 0.00, 0.00, 0.61, 0.61},
+                        {0.85, 0.06, 0.11, 0.87, 0.55, 0.12, 0.80, 0.48, 0.41, 0.16}
+                    },
+                    {
+                        {0.24, 0.46, 0.97, 0.19, 0.65, 0.12, 0.44, 1.00, 0.37, 0.09},
+                        {0.44, 0.64, 0.21, 0.58, 0.05, 0.24, 0.56, 0.07, 0.49, 0.79}
+                    }
+                },
+                {
+                    {
+                        {0.00, 0.13, 0.55, 0.42, 0.49, 0.28, 0.52, 0.55, 0.34, 0.85},
+                        {0.98, 0.32, 0.09, 0.05, 0.37, 0.47, 0.63, 0.13, 0.70, 0.02}
+                    },
+                    {
+                        {0.69, 0.13, 0.74, 0.61, 0.25, 0.87, 0.46, 0.40, 0.81, 0.06},
+                        {0.89, 0.32, 0.61, 0.24, 0.70, 0.23, 0.09, 0.03, 0.14, 0.80}
+                    }
+                }
+            }
+        });
+        std::shared_ptr<Tensor> output_shiftmax = std::make_shared<Tensor>(Array4D<float,2,2,2,10> {
+            {
+                {
+                    {
+                        { 0.111084f, 0.111084f, 0.111084f, 0.111084f, 0.111084f, 0.111084f, 0.055542f, 0.055542f, 0.111084f, 0.111084f },
+                        { 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f }
+                    },
+                    {
+                        { 0.0624695f, 0.124969f, 0.124969f, 0.0624695f, 0.124969f, 0.0624695f, 0.124969f, 0.124969f, 0.124969f, 0.0624695f },
+                        { 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f }
+                    }
+                },
+                {
+                    {
+                        { 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f },
+                        { 0.124969f, 0.124969f, 0.0624695f, 0.0624695f, 0.124969f, 0.124969f, 0.124969f, 0.0624695f, 0.124969f, 0.0624695f }
+                    },
+                    {
+                        { 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f },
+                        { 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f }
+                    }
+                }
+            }
+        });
+        std::shared_ptr<Tensor> output_softmax = std::make_shared<Tensor>(Array4D<float, 2, 2, 2, 10> {
+            {
+                {
+                    {
+                        { 0.1484f, 0.0918f, 0.0975f, 0.0928f, 0.1025f, 0.1440f, 0.0568f, 0.0568f, 0.1046f, 0.1046f },
+                        { 0.1436f, 0.0652f, 0.0685f, 0.1465f, 0.1064f, 0.0692f, 0.1366f, 0.0992f, 0.0925f, 0.0721f }
+                    },
+                    {
+                        { 0.0768f, 0.0957f, 0.1593f, 0.0730f, 0.1157f, 0.0681f, 0.0938f, 0.1642f, 0.0874f, 0.0661f },
+                        { 0.1005f, 0.1227f, 0.0798f, 0.1156f, 0.0680f, 0.0823f, 0.1133f, 0.0694f, 0.1056f, 0.1426f }
+                    }
+                },
+                {
+                    {
+                        { 0.0645f, 0.0734f, 0.1118f, 0.0981f, 0.1052f, 0.0853f, 0.1085f, 0.1118f, 0.0906f, 0.1509f },
+                        { 0.1743f, 0.0901f, 0.0716f, 0.0688f, 0.0947f, 0.1047f, 0.1228f, 0.0745f, 0.1317f, 0.0667f }
+                    },
+                    {
+                        { 0.1164f, 0.0665f, 0.1224f, 0.1075f, 0.0750f, 0.1394f, 0.0925f, 0.0871f, 0.1313f, 0.0620f },
+                        { 0.1551f, 0.0877f, 0.1172f, 0.0810f, 0.1283f, 0.0802f, 0.0697f, 0.0656f, 0.0733f, 0.1418f }
+                    }
+                }
+            }
+        }); // reference values produced by torch.nn.Softmax
+
+        std::shared_ptr<Node> myShiftMax = ShiftMax();
+        auto op = std::static_pointer_cast<OperatorTensor>(myShiftMax -> getOperator());
+        op->associateInput(0,input0);
+        op->setDataType(DataType::Float32);
+        op->setBackend("cuda");
+        op->forward();
+
+        float* computedOutput = new float[output_shiftmax->size()]();
+        cudaMemcpy(computedOutput, op->getOutput(0)->getImpl()->rawPtr(), sizeof(float) * output_shiftmax->size(), cudaMemcpyDeviceToHost);
+
+        for (std::size_t i = 0; i < output_shiftmax->size(); i++) {
+            const float targetOutput = *(static_cast<float*>(output_shiftmax->getImpl()->rawPtr()) + i);
+            REQUIRE(fabs(computedOutput[i] - targetOutput) < 1e-6);
+        }
+
+        // Mean absolute error against the floating-point softmax reference
+        float sum = 0.0;
+        for (std::size_t i = 0; i < output_softmax->size(); i++) {
+            const float targetOutput = *(static_cast<float*>(output_softmax->getImpl()->rawPtr()) + i);
+            sum += fabs(computedOutput[i] - targetOutput);
+        }
+        sum = sum / output_softmax->size();
+        REQUIRE(sum < 4e-2);
+
+        delete[] computedOutput;
+    }
+
+}
\ No newline at end of file