diff --git a/include/aidge/backend/cpu/operator/BitShiftImpl.hpp b/include/aidge/backend/cpu/operator/BitShiftImpl.hpp
index 52400baad4922d5729ea4b13d260f04a2836ed59..0baf6874f10eeb3f174e0c98ca7e315f1f28165c 100644
--- a/include/aidge/backend/cpu/operator/BitShiftImpl.hpp
+++ b/include/aidge/backend/cpu/operator/BitShiftImpl.hpp
@@ -25,10 +25,10 @@ namespace Aidge {
 
 // compute kernel registry for forward and backward
 class BitShiftImplForward_cpu
-    : public Registrable<BitShiftImplForward_cpu, std::tuple<DataType, DataType, DataType>, void(const Direction,const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, const void*, const void*,void*)> {
+    : public Registrable<BitShiftImplForward_cpu, std::tuple<DataType, DataType, DataType>, void(const BitShiftDirection,const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, const void*, const void*,void*)> {
 };
 class BitShiftImplBackward_cpu
-    : public Registrable<BitShiftImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const Direction,const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, const void*, const void*, void*)> {
+    : public Registrable<BitShiftImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const BitShiftDirection,const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, const void*, const void*, void*)> {
 };
 
 class BitShiftImpl_cpu : public OperatorImpl {
diff --git a/include/aidge/backend/cpu/operator/BitShiftImpl_forward_kernels.hpp b/include/aidge/backend/cpu/operator/BitShiftImpl_forward_kernels.hpp
index 98b34b9d5d388fd9cabb562f8bd7d311644b73c1..bd841968c4481f5b27359747180472bdeaf77b31 100644
--- a/include/aidge/backend/cpu/operator/BitShiftImpl_forward_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/BitShiftImpl_forward_kernels.hpp
@@ -25,7 +25,7 @@
 namespace Aidge {
 template <class I1, class I2, class O>
 void BitShiftImpl_cpu_forward_kernel(
-                                const Direction direction,
+                                const BitShiftDirection direction,
                                 const std::vector<std::size_t>& input1Dims,
                                 const std::vector<std::size_t>& input2Dims,
                                 const std::vector<std::size_t>& outputDims,
@@ -34,27 +34,22 @@ void BitShiftImpl_cpu_forward_kernel(
                                 void* output_
                                 ) {
 
-    //Cast des entrées en classes I / O                                    
     const I1* input_1 = static_cast<const I1*>(input1_);
     const I2* input_2 = static_cast<const I2*>(input2_);
     O* output = static_cast<O*>(output_);
 
-    size_t totalElements = 1;
-    for (size_t dimSize : outputDims) {
-        totalElements *= dimSize;
-    }
-
+    const size_t totalElements = std::accumulate(outputDims.begin(), outputDims.end(), std::size_t(1), std::multiplies<std::size_t>());
+    
     for (std::size_t oIndex = 0; oIndex < totalElements; ++oIndex)
     {
         std::vector<size_t> indexes = getMultiDimIndices(outputDims, oIndex);
         std::size_t idx1 = getFlattenedIndex(input1Dims, indexes);
         std::size_t idx2 = getFlattenedIndex(input2Dims, indexes);
-        if(direction == Direction::right)
+        if(direction == BitShiftDirection::right)
         {
-                //BitShift ne fonctionne pas sur les types à virgule flottante
                 output[oIndex]= input_1[idx1] >> input_2[idx2];
         }
-        else if(direction == Direction::left)
+        else
         {
                 output[oIndex] = input_1[idx1] << input_2[idx2];
         }
diff --git a/src/operator/BitShiftImpl.cpp b/src/operator/BitShiftImpl.cpp
index 600c3a4bc930e71ecd9e72e2af7c5fedfa4a1de2..16e13f8f294f4bd54142a1e431e005e1c8544f85 100644
--- a/src/operator/BitShiftImpl.cpp
+++ b/src/operator/BitShiftImpl.cpp
@@ -42,7 +42,7 @@ void Aidge::BitShiftImpl_cpu::forward() {
     const std::vector<std::size_t> inputDims1 = getBroadcastedDims(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dims(),
                                                                    std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->dims());
 
-    Direction direction = op_.direction();
+    BitShiftDirection direction = op_.direction();
 
     // Call kernel
     kernelFunc(
diff --git a/unit_tests/operator/Test_BitShift.cpp b/unit_tests/operator/Test_BitShift.cpp
index 33c2bd1d879f7ea22d8d8f4448a90fb95ed2e349..59e2bca5e96b6f875b8a57d97267bc6886a490dd 100644
--- a/unit_tests/operator/Test_BitShift.cpp
+++ b/unit_tests/operator/Test_BitShift.cpp
@@ -29,20 +29,20 @@ TEST_CASE("[cpu/operator] BitShift_TEST", "[BitShift][CPU]") {
     // Create a random number generator
     std::random_device rd;
     std::mt19937 gen(rd());
-    std::uniform_int_distribution<int> valueDist(-15, 15); // Random int distribution between -15 and 15 
+    std::uniform_int_distribution<int> valueDist(-15, 15); 
     std::uniform_int_distribution<std::size_t> dimSizeDist(std::size_t(2), std::size_t(5));
     std::uniform_int_distribution<std::size_t> nbDimsDist(std::size_t(1), std::size_t(3));
     std::uniform_int_distribution<int> boolDist(0,1);
 
-    Direction direction = Direction::left;
+    BitShiftDirection direction = BitShiftDirection::left;
 
     if(valueDist(gen) % 2 == 0)
     {
-        direction = Direction::right;
+        direction = BitShiftDirection::right;
     }
 
     // Create BitShift Operator
-    std::shared_ptr<Node> myBitShift = BitShift(direction); // Left opérator to start
+    std::shared_ptr<Node> myBitShift = BitShift(direction);
     auto op = std::static_pointer_cast<OperatorTensor>(myBitShift-> getOperator());
     op->setDataType(DataType::Int32);
     op->setBackend("cpu");
@@ -69,7 +69,7 @@ TEST_CASE("[cpu/operator] BitShift_TEST", "[BitShift][CPU]") {
     std::chrono::duration<double, std::micro> duration{};
 
     SECTION("BitShiftImpl_cpu::forward()") {
-        SECTION("+1-D Tensor / +1-D Tensor - same dimensions") {
+        SECTION("Test Forward Kernel with same dimensions") {
             std::size_t number_of_operation = 0;
 
             for (std::uint16_t trial = 0; trial < NBTRIALS; ++trial) {
@@ -90,7 +90,7 @@ TEST_CASE("[cpu/operator] BitShift_TEST", "[BitShift][CPU]") {
                 for (std::size_t i = 0; i < nb_elements; ++i) {
                     array0[i] = valueDist(gen);
                     array1[i] = std::abs(valueDist(gen)); // bitshift is impossible with negative value
-                    if(direction == Direction::left)
+                    if(direction == BitShiftDirection::left)
                     {
                         result[i] = array0[i] << array1[i];
                     }
@@ -134,7 +134,7 @@ TEST_CASE("[cpu/operator] BitShift_TEST", "[BitShift][CPU]") {
             std::cout << "number of elements over time spent: " << (number_of_operation / duration.count())<< std::endl;
             std::cout << "total time: " << duration.count() << "μs" << std::endl;
         }
-        SECTION("+1-D Tensor / +1-D Tensor - broadcasting") {
+        SECTION("Test BitShift kernels with Broadcasting") {
             std::size_t number_of_operation = 0;
 
             for (std::uint16_t trial = 0; trial < NBTRIALS; ++trial) {
@@ -188,7 +188,7 @@ TEST_CASE("[cpu/operator] BitShift_TEST", "[BitShift][CPU]") {
                                 std::size_t idx1 = idx1_0
                                                     + strides1[2] * ((dims1[2] > 1) ? c : 0)
                                                     + ((dims1[3] > 1) ? d : 0);
-                                if(direction == Direction::left)
+                                if(direction == BitShiftDirection::left)
                                 {
                                     result[idx_out + d] = array0[idx0] << array1[idx1];
                                 }