From 504790997a45d583b77ee2187c553a2c17f932de Mon Sep 17 00:00:00 2001
From: Noam ZERAH <noam.zerah@cea.fr>
Date: Tue, 25 Feb 2025 14:52:14 +0000
Subject: [PATCH] Updating cpu backend for bitshift with the new rounding
 attribute

---
 .../backend/cpu/operator/BitShiftImpl.hpp     |  1 +
 .../cpu/operator/BitShiftImpl_kernels.hpp     | 17 ++--
 src/operator/BitShiftImpl.cpp                 |  1 +
 unit_tests/operator/Test_BitShift.cpp         | 77 ++++++++++++++++++-
 4 files changed, 90 insertions(+), 6 deletions(-)

diff --git a/include/aidge/backend/cpu/operator/BitShiftImpl.hpp b/include/aidge/backend/cpu/operator/BitShiftImpl.hpp
index 807d2b97..79b0c5a3 100644
--- a/include/aidge/backend/cpu/operator/BitShiftImpl.hpp
+++ b/include/aidge/backend/cpu/operator/BitShiftImpl.hpp
@@ -24,6 +24,7 @@ namespace Aidge {
 // Operator implementation entry point for the backend
 using BitShiftImpl_cpu = OperatorImpl_cpu<BitShift_Op,
     void(const BitShift_Op::BitShiftDirection,
+    const bool,
     std::vector<std::size_t>,
     std::vector<std::size_t>,
     const std::vector<std::size_t>&,
diff --git a/include/aidge/backend/cpu/operator/BitShiftImpl_kernels.hpp b/include/aidge/backend/cpu/operator/BitShiftImpl_kernels.hpp
index 1f2561af..89921d36 100644
--- a/include/aidge/backend/cpu/operator/BitShiftImpl_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/BitShiftImpl_kernels.hpp
@@ -27,6 +27,7 @@ namespace {
 template <class I1, class I2, class O>
 void bitshift_contiguous_arrays(
     const Aidge::BitShift_Op::BitShiftDirection direction,
+    const bool rounding,
     const std::size_t input1size,
     const std::size_t input2size,
     const std::size_t output1size,
@@ -34,13 +35,18 @@ void bitshift_contiguous_arrays(
     const I2* input_2,
     O* output)
 {
-    if(direction == Aidge::BitShift_Op::BitShiftDirection::right) {
+    if (direction == Aidge::BitShift_Op::BitShiftDirection::right) {
         for (std::size_t i = 0; i < output1size; ++i) {
             const std::size_t idx1 = (input1size != 1) ? i : 0;
             const std::size_t idx2 = (input2size != 1) ? i : 0;
-            output[i]= input_1[idx1] >> input_2[idx2];
+            const int shift = input_2[idx2];
+
+            if (rounding && shift > 0) {
+                output[i] = ((input_1[idx1] >> (shift - 1)) + 1) >> 1;
+            } else {
+                output[i] = input_1[idx1] >> shift;
+            }
         }
-
     } else {
         for (std::size_t i = 0; i < output1size; ++i) {
             const std::size_t idx1 = (input1size != 1) ? i : 0;
@@ -55,6 +61,7 @@ namespace Aidge {
 template <class I1, class I2, class O>
 void BitShiftImpl_cpu_forward_kernel(
                                 const BitShift_Op::BitShiftDirection direction,
+                                const bool rounding,
                                 std::vector<std::size_t> dims0,
                                 std::vector<std::size_t> dims1,
                                 const std::vector<std::size_t>& outputDims,
@@ -79,7 +86,7 @@ void BitShiftImpl_cpu_forward_kernel(
     // special case for equal dimensions, the kernel is called with the entire arrays at once
     if (dims0 == dims1) {
         const std::size_t input0_contiguous_size = std::accumulate(dims0.cbegin(), dims0.cend(), std::size_t(1), std::multiplies<std::size_t>());
-        bitshift_contiguous_arrays(direction, input0_contiguous_size, input0_contiguous_size, input0_contiguous_size, input_0, input_1, output);
+        bitshift_contiguous_arrays(direction, rounding, input0_contiguous_size, input0_contiguous_size, input0_contiguous_size, input_0, input_1, output);
         return;
     }
 
@@ -142,7 +149,7 @@ void BitShiftImpl_cpu_forward_kernel(
     std::size_t dim = contiguousIdx - 1;
     const std::size_t nbStacks = std::accumulate(outputDims.cbegin(), outputDims.cbegin() + contiguousIdx, std::size_t(1), std::multiplies<std::size_t>());
     for (std::size_t stack = 0; stack < nbStacks;) {
-        bitshift_contiguous_arrays<I1,I2,O>(direction, input0_contiguous_size, input1_contiguous_size, output_contiguous_size,
+        bitshift_contiguous_arrays<I1,I2,O>(direction, rounding, input0_contiguous_size, input1_contiguous_size, output_contiguous_size,
                     input_0 + offsetIn0*input0_contiguous_size,
                     input_1 + offsetIn1*input1_contiguous_size,
                     output + offsetOut*output_contiguous_size);
diff --git a/src/operator/BitShiftImpl.cpp b/src/operator/BitShiftImpl.cpp
index c6940554..ad41cb15 100644
--- a/src/operator/BitShiftImpl.cpp
+++ b/src/operator/BitShiftImpl.cpp
@@ -33,6 +33,7 @@ void Aidge::BitShiftImpl_cpu::forward() {
     // Call kernel
     impl.forward(
         op_.direction(),
+        op_.rounding(),
         op_.getInput(0)->dims(),
         op_.getInput(1)->dims(),
         op_.getOutput(0)->dims(),
diff --git a/unit_tests/operator/Test_BitShift.cpp b/unit_tests/operator/Test_BitShift.cpp
index 33ab932e..9cce9d6d 100644
--- a/unit_tests/operator/Test_BitShift.cpp
+++ b/unit_tests/operator/Test_BitShift.cpp
@@ -8,7 +8,6 @@
  * SPDX-License-Identifier: EPL-2.0
  *
  ********************************************************************************/
-
 #include <chrono>      // std::micro, std::chrono::time_point,
                        // std::chrono::system_clock
 #include <cstddef>   // std::size_t
@@ -139,6 +138,82 @@ TEST_CASE("[cpu/operator] BitShift_TEST", "[BitShift][CPU]") {
             Log::info("number of elements over time spent: {}\n", (number_of_operation / duration.count()));
             Log::info("total time: {}μs\n", duration.count());
         }
+        SECTION("Test Forward Kernel with same dimensions and applying rounding") {
+            std::shared_ptr<Node> RoundBitShift = BitShift(BitShift_Op::BitShiftDirection::right,true);
+            auto op_r = std::static_pointer_cast<OperatorTensor>(RoundBitShift-> getOperator());
+            op_r->setDataType(DataType::Int32);
+            op_r->setBackend("cpu");
+
+            // Create 2 input Tensors
+            std::shared_ptr<Tensor> T0_r = std::make_shared<Tensor>();
+            op_r->associateInput(0,T0_r);
+            T0_r->setDataType(DataType::Int32);
+            T0_r->setBackend("cpu");
+            std::shared_ptr<Tensor> T1_r = std::make_shared<Tensor>();
+            op_r -> associateInput(1,T1_r);
+            T1_r->setDataType(DataType::Int32);
+            T1_r->setBackend("cpu");
+
+            // Create results Tensor
+            std::shared_ptr<Tensor> Tres_r = std::make_shared<Tensor>();
+            Tres_r->setDataType(DataType::Int32);
+            Tres_r->setBackend("cpu");
+            std::size_t number_of_operation = 0;
+
+            for (std::uint16_t trial = 0; trial < NBTRIALS; ++trial) {
+                // generate 2 random Tensors
+                const std::size_t nbDims = nbDimsDist(gen);
+                std::vector<std::size_t> dims;
+                for (std::size_t i = 0; i < nbDims; ++i) {
+                    dims.push_back(dimSizeDist(gen));
+                }
+                const std::size_t nb_elements = std::accumulate(dims.cbegin(), dims.cend(), std::size_t(1), std::multiplies<std::size_t>());
+                number_of_operation += nb_elements;
+
+                // without broadcasting
+                int* array0 = new int[nb_elements];
+                int* array1 = new int[nb_elements];
+                int* result = new int[nb_elements];
+                for (std::size_t i = 0; i < nb_elements; ++i)
+                {
+                    array0[i] = valueDist(gen);
+                    array1[i] = std::abs(valueDist(gen)); // bitshift is impossible with negative values
+                    result[i] = array0[i] >> array1[i];
+                    if(array1[i] > 0) // Cannot use rounding when shift value is 0
+                        result[i] = ((array0[i] >> (array1[i] - 1)) + 1) >> 1;
+                }
+
+                // input0
+                T0_r->resize(dims);
+                T0_r -> getImpl() -> setRawPtr(array0, nb_elements);
+
+                // input1
+                T1_r->resize(dims);
+                T1_r -> getImpl() -> setRawPtr(array1, nb_elements);
+
+                // results
+                Tres_r->resize(dims);
+                Tres_r -> getImpl() -> setRawPtr(result, nb_elements);
+
+                op_r->forwardDims();
+                start = std::chrono::system_clock::now();
+                RoundBitShift->forward();
+                end = std::chrono::system_clock::now();
+                duration += std::chrono::duration_cast<std::chrono::microseconds>(end - start);
+
+                bool is_eq_round = approxEq<int>(*(op_r->getOutput(0)), *Tres_r);
+                auto Output = *(op_r->getOutput(0));
+                auto prt = Output.getImpl()->rawPtr();
+
+                REQUIRE(is_eq_round);
+
+                delete[] array0;
+                delete[] array1;
+                delete[] result;
+            }
+            Log::info("number of elements over time spent: {}\n", (number_of_operation / duration.count()));
+            Log::info("total time: {}μs\n", duration.count());
+        }
         SECTION("Test BitShift kernels with Broadcasting") {
             std::size_t number_of_operation = 0;
 
-- 
GitLab