diff --git a/src/operator/BitShitImpl.cpp b/src/operator/BitShitImpl.cpp
index d453eb4d01ab359f9ab2db6f01b98a77e3e0d93f..3a5db9b1439e36ad866d40ff6928f4424d19a182 100644
--- a/src/operator/BitShitImpl.cpp
+++ b/src/operator/BitShitImpl.cpp
@@ -39,8 +39,14 @@
      std::vector<std::vector<int>> dims(op.nbInputs()); // For broadcasted dims
      std::vector<std::vector<int>> strides(op.nbInputs()); // For the cooresponding strides
      for (IOIndex_t i = 0; i < op.nbInputs(); ++i) {
-         inputs[i] = op.getInput(i)->refCastFrom(inputFallbacks[i], *op.getOutput(0));
- 
+        // TODO: remove the forced cast to Int32 (float outputs are staged through an Int32 copy)
+        const auto dt = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType();
+        if(dt == DataType::Float32 || dt == DataType::Float64 ) {
+            inputs[i] = op.getInput(i)->refCast(inputFallbacks[i], DataType::Int32);
+        }
+        else {
+            inputs[i] = op.getInput(i)->refCastFrom(inputFallbacks[i], *op.getOutput(0));
+        } 
          // Get tensor dims and broadcast them
          std::copy(inputs[i].dims().begin(), inputs[i].dims().end(), std::back_inserter(dims[i]));
          dims[i].insert(dims[i].cbegin(), op.getOutput(0)->nbDims() - dims[i].size(), int(1));
@@ -59,7 +65,7 @@
          strides[i] = tensorStrides;
      }
      bool left = op.direction() == BitShift_Op::BitShiftDirection::left;
-     switch(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()) {
+     switch(inputs[0].dataType()) {
          case DataType::Int64:
              forward_<int64_t>(inputs, dims, strides, left);
              break;
@@ -78,21 +84,36 @@
      // const typename Cuda::cudnn_scaling_type<T>::type beta = 0.0f;
      const T * input1Ptr = static_cast<const T*>(inputs[0].getImpl()->rawPtr());
      const T * input2Ptr = static_cast<const T*>(inputs[1].getImpl()->rawPtr());
-     T * outputPtr = static_cast<T*>(op.getOutput(0)->getImpl()->rawPtr());
- 
-     std::vector<int> outputStrides(op.getOutput(0)->nbDims(), 1);
-     if(op.getOutput(0)->nbDims()>1) {
-         for (int i = op.getOutput(0)->nbDims()-2; i >= 0; i--) {
-             outputStrides[i] = outputStrides[i+1] *  op.getOutput(0)->dims()[i+1];
-         }
-     }
-     std::vector<int> outDims(std::max(op.getOutput(0)->nbDims(),std::size_t(4)), 1);
-     for (std::size_t i = 0; i < op.getOutput(0)->nbDims(); i++) {
-         outDims[i] = static_cast<int>(op.getOutput(0)->dims()[i]);
-     }
- 
-     Aidge::bitShiftForward<T>(input1Ptr, input2Ptr, outputPtr,
-                 inputsDims[0], inputsDims[1], outDims,
-                 inputsStrides[0], inputsStrides[1], outputStrides,
-                 static_cast<int>(op.getOutput(0)->size()), left);
- }
+    //  T * outputPtr = static_cast<T*>(op.getOutput(0)->getImpl()->rawPtr());
+
+    std::shared_ptr<Tensor> outputFallback;
+    const auto dt = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType();
+    Tensor outputCasted;
+    if(dt == DataType::Float32 || dt == DataType::Float64 ) {
+        // NOTE(review): refCastFrom presumably also copy-casts the still-uninitialized
+        // output contents into the Int32 staging tensor — an allocation-only path would
+        // avoid that wasted device copy; confirm against the Tensor API.
+        outputCasted = op.getOutput(0)->refCastFrom(outputFallback, DataType::Int32, "cuda", op.getOutput(0)->device());
+    } else {
+        outputCasted = op.getOutput(0)->refCastFrom(outputFallback, *op.getOutput(0));
+    }
+    std::vector<int> outputStrides(op.getOutput(0)->nbDims(), 1);
+    if(op.getOutput(0)->nbDims()>1) {
+        for (int i = op.getOutput(0)->nbDims()-2; i >= 0; i--) {
+            outputStrides[i] = outputStrides[i+1] *  op.getOutput(0)->dims()[i+1];
+        }
+    }
+    std::vector<int> outDims(std::max(op.getOutput(0)->nbDims(),std::size_t(4)), 1);
+    for (std::size_t i = 0; i < op.getOutput(0)->nbDims(); i++) {
+        outDims[i] = static_cast<int>(op.getOutput(0)->dims()[i]);
+    }
+
+    Aidge::bitShiftForward<T>(input1Ptr, input2Ptr, static_cast<T*>(outputCasted.getImpl()->rawPtr()),
+                inputsDims[0], inputsDims[1], outDims,
+                inputsStrides[0], inputsStrides[1], outputStrides,
+                static_cast<int>(op.getOutput(0)->size()), left);
+
+   if(dt == DataType::Float32 || dt == DataType::Float64 ) {
+       op.getOutput(0)->getImpl()->copyCast(outputCasted.getImpl()->rawPtr(),DataType::Int32, outputCasted.size());
+   }else {
+       // op.getOutput(0)->getImpl()->copy(outputCasted.getImpl()->rawPtr(),outputCasted.size());
+       // NOTE(review): sizeof(int) under-copies when the output dtype is Int64 (the
+       // switch above dispatches int64_t) — use the output's element size instead; also,
+       // if refCastFrom returned the output itself here (no cast/backend change needed),
+       // src == dst and this memcpy is redundant. Verify both before merging.
+       CHECK_CUDA_STATUS(cudaMemcpy(op.getOutput(0)->getImpl()->rawPtr(), outputCasted.getImpl()->rawPtr(), outputCasted.size() * sizeof(int), cudaMemcpyDeviceToDevice));
+   }
+}
diff --git a/unit_tests/Test_BitShiftImpl.cpp b/unit_tests/Test_BitShiftImpl.cpp
index 454e969a7a825971f7f6c4daa26c00812ac970fb..9b111496dd35dba482bbb1a437b3e7c2b164fbdf 100644
--- a/unit_tests/Test_BitShiftImpl.cpp
+++ b/unit_tests/Test_BitShiftImpl.cpp
@@ -38,6 +38,7 @@ using namespace Aidge;
 
 TEST_CASE("[gpu/operator] BitShift(forward)", "[BitShift][GPU]")
 {
+    SECTION("Int") {
     constexpr std::uint16_t NBTRIALS = 15;
     // Create a random number generator
     std::random_device rd;
@@ -149,6 +150,121 @@ TEST_CASE("[gpu/operator] BitShift(forward)", "[BitShift][GPU]")
                 cudaFree(array1_d);
 
 
+            }
+            Log::info("number of elements over time spent: {}\n", (number_of_operation / duration.count()));
+            Log::info("total time: {}μs\n", duration.count());
+        }
+    }
+
+        SECTION("Float cast") {
+            constexpr std::uint16_t NBTRIALS = 15;
+            // Create a random number generator
+            std::random_device rd;
+            std::mt19937 gen(rd());
+            std::uniform_int_distribution<int> valueDist(-15, 15);
+            std::uniform_int_distribution<std::size_t> dimSizeDist(std::size_t(2), std::size_t(5));
+            std::uniform_int_distribution<std::size_t> nbDimsDist(std::size_t(1), std::size_t(3));
+            std::uniform_int_distribution<int> boolDist(0,1);
+        
+            BitShift_Op::BitShiftDirection direction = BitShift_Op::BitShiftDirection::left;
+        
+            if(valueDist(gen) % 2 == 0)
+            {
+                direction = BitShift_Op::BitShiftDirection::right;
+            }
+        
+            // Create BitShift Operator
+            std::shared_ptr<Node> myBitShift = BitShift(direction);
+            auto op = std::static_pointer_cast<OperatorTensor>(myBitShift-> getOperator());
+            op->setDataType(DataType::Float32);
+            op->setBackend("cuda");
+        
+            // Create 2 input Tensors
+            std::shared_ptr<Tensor> T0 = std::make_shared<Tensor>();
+            op->associateInput(0,T0);
+            T0->setDataType(DataType::Float32);
+            T0->setBackend("cuda");
+            std::shared_ptr<Tensor> T1 = std::make_shared<Tensor>();
+            op -> associateInput(1,T1);
+            T1->setDataType(DataType::Float32);
+            T1->setBackend("cuda");
+        
+            // Create results Tensor
+            std::shared_ptr<Tensor> Tres = std::make_shared<Tensor>();
+            Tres->setDataType(DataType::Float32);
+            Tres->setBackend("cpu");
+        
+            // To measure execution time of 'BitShift_Op::forward()' member function call
+            std::chrono::time_point<std::chrono::system_clock> start;
+        
+            std::chrono::time_point<std::chrono::system_clock> end;
+            std::chrono::duration<double, std::micro> duration{};
+            std::size_t number_of_operation = 0;
+
+            for (std::uint16_t trial = 0; trial < NBTRIALS; ++trial) {
+                // generate 2 random Tensors
+                const std::size_t nbDims = nbDimsDist(gen);
+                std::vector<std::size_t> dims;
+                for (std::size_t i = 0; i < nbDims; ++i) {
+                    dims.push_back(dimSizeDist(gen));
+                }
+                const std::size_t nb_elements = std::accumulate(dims.cbegin(), dims.cend(), std::size_t(1), std::multiplies<std::size_t>());
+                number_of_operation += nb_elements;
+
+                // without broadcasting
+                float* array0 = new float[nb_elements];
+                float* array1 = new float[nb_elements];
+                float* result = new float[nb_elements];
+
+                for (std::size_t i = 0; i < nb_elements; ++i) {
+                    array0[i] = float(valueDist(gen));
+                    array1[i] = float(std::abs(valueDist(gen))); // bitshift is impossible with negative value
+                    if(direction == BitShift_Op::BitShiftDirection::left)
+                    {
+                        result[i] = float(int(array0[i]) << int(array1[i]));
+                    }
+                    else
+                    {
+                        result[i] = float(int(array0[i]) >> int(array1[i]));
+                    }
+                }
+
+                float *array0_d, *array1_d;
+        
+                // input0
+                T0->resize(dims);
+                cudaMalloc(reinterpret_cast<void **>(&array0_d), sizeof(float) * nb_elements);
+                cudaMemcpy(array0_d, array0, sizeof(float) * nb_elements, cudaMemcpyHostToDevice);
+                T0->getImpl()->setRawPtr(array0_d, nb_elements);
+
+                // input1
+                T1->resize(dims);
+                cudaMalloc(reinterpret_cast<void **>(&array1_d), sizeof(float) * nb_elements);
+                cudaMemcpy(array1_d, array1, sizeof(float) * nb_elements, cudaMemcpyHostToDevice);
+                T1->getImpl()->setRawPtr(array1_d, nb_elements);
+
+                // results
+                Tres->resize(dims);
+                Tres -> getImpl() -> setRawPtr(result, nb_elements);
+
+                op->forwardDims();
+                start = std::chrono::system_clock::now();
+                myBitShift->forward();
+                end = std::chrono::system_clock::now();
+                duration += std::chrono::duration_cast<std::chrono::microseconds>(end - start);
+
+
+                std::shared_ptr<Tensor> outputFallback;
+                const auto& cudaOutput = op->getOutput(0)->refCastFrom(outputFallback, *Tres);
+                REQUIRE(approxEq<float>(cudaOutput, *(Tres)));
+
+                delete[] array0;
+                delete[] array1;
+                delete[] result;
+                cudaFree(array0_d);
+                cudaFree(array1_d);
+
+
             }
             Log::info("number of elements over time spent: {}\n", (number_of_operation / duration.count()));
             Log::info("total time: {}μs\n", duration.count());