From a7a80e6daa692685b0029554350b282e1cf011ba Mon Sep 17 00:00:00 2001
From: Noam ZERAH <noam.zerah@cea.fr>
Date: Tue, 17 Sep 2024 12:12:43 +0000
Subject: [PATCH] Ajusting BitShiftDirection to be part of BitShift_Op

---
 include/aidge/backend/cpu/operator/BitShiftImpl.hpp       | 4 ++--
 .../backend/cpu/operator/BitShiftImpl_forward_kernels.hpp | 4 ++--
 src/operator/BitShiftImpl.cpp                             | 2 +-
 unit_tests/operator/Test_BitShift.cpp                     | 8 ++++----
 4 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/include/aidge/backend/cpu/operator/BitShiftImpl.hpp b/include/aidge/backend/cpu/operator/BitShiftImpl.hpp
index 0baf6874..ad4bdc6a 100644
--- a/include/aidge/backend/cpu/operator/BitShiftImpl.hpp
+++ b/include/aidge/backend/cpu/operator/BitShiftImpl.hpp
@@ -25,10 +25,10 @@ namespace Aidge {
 
 // compute kernel registry for forward and backward
 class BitShiftImplForward_cpu
-    : public Registrable<BitShiftImplForward_cpu, std::tuple<DataType, DataType, DataType>, void(const BitShiftDirection,const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, const void*, const void*,void*)> {
+    : public Registrable<BitShiftImplForward_cpu, std::tuple<DataType, DataType, DataType>, void(const BitShift_Op::BitShiftDirection,const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, const void*, const void*,void*)> {
 };
 class BitShiftImplBackward_cpu
-    : public Registrable<BitShiftImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const BitShiftDirection,const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, const void*, const void*, void*)> {
+    : public Registrable<BitShiftImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const BitShift_Op::BitShiftDirection,const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, const void*, const void*, void*)> {
 };
 
 class BitShiftImpl_cpu : public OperatorImpl {
diff --git a/include/aidge/backend/cpu/operator/BitShiftImpl_forward_kernels.hpp b/include/aidge/backend/cpu/operator/BitShiftImpl_forward_kernels.hpp
index bd841968..b8b809b6 100644
--- a/include/aidge/backend/cpu/operator/BitShiftImpl_forward_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/BitShiftImpl_forward_kernels.hpp
@@ -25,7 +25,7 @@
 namespace Aidge {
 template <class I1, class I2, class O>
 void BitShiftImpl_cpu_forward_kernel(
-                                const BitShiftDirection direction,
+                                const BitShift_Op::BitShiftDirection direction,
                                 const std::vector<std::size_t>& input1Dims,
                                 const std::vector<std::size_t>& input2Dims,
                                 const std::vector<std::size_t>& outputDims,
@@ -45,7 +45,7 @@ void BitShiftImpl_cpu_forward_kernel(
         std::vector<size_t> indexes = getMultiDimIndices(outputDims, oIndex);
         std::size_t idx1 = getFlattenedIndex(input1Dims, indexes);
         std::size_t idx2 = getFlattenedIndex(input2Dims, indexes);
-        if(direction == BitShiftDirection::right)
+        if(direction == BitShift_Op::BitShiftDirection::right)
         {
                 output[oIndex]= input_1[idx1] >> input_2[idx2];
         }
diff --git a/src/operator/BitShiftImpl.cpp b/src/operator/BitShiftImpl.cpp
index 16e13f8f..a3ffa89d 100644
--- a/src/operator/BitShiftImpl.cpp
+++ b/src/operator/BitShiftImpl.cpp
@@ -42,7 +42,7 @@ void Aidge::BitShiftImpl_cpu::forward() {
     const std::vector<std::size_t> inputDims1 = getBroadcastedDims(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dims(),
                                                                    std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->dims());
 
-    BitShiftDirection direction = op_.direction();
+    BitShift_Op::BitShiftDirection direction = op_.direction();
 
     // Call kernel
     kernelFunc(
diff --git a/unit_tests/operator/Test_BitShift.cpp b/unit_tests/operator/Test_BitShift.cpp
index 59e2bca5..a52990bc 100644
--- a/unit_tests/operator/Test_BitShift.cpp
+++ b/unit_tests/operator/Test_BitShift.cpp
@@ -34,11 +34,11 @@ TEST_CASE("[cpu/operator] BitShift_TEST", "[BitShift][CPU]") {
     std::uniform_int_distribution<std::size_t> nbDimsDist(std::size_t(1), std::size_t(3));
     std::uniform_int_distribution<int> boolDist(0,1);
 
-    BitShiftDirection direction = BitShiftDirection::left;
+    BitShift_Op::BitShiftDirection direction = BitShift_Op::BitShiftDirection::left;
 
     if(valueDist(gen) % 2 == 0)
     {
-        direction = BitShiftDirection::right;
+        direction = BitShift_Op::BitShiftDirection::right;
     }
 
     // Create BitShift Operator
@@ -90,7 +90,7 @@ TEST_CASE("[cpu/operator] BitShift_TEST", "[BitShift][CPU]") {
                 for (std::size_t i = 0; i < nb_elements; ++i) {
                     array0[i] = valueDist(gen);
                     array1[i] = std::abs(valueDist(gen)); // bitshift is impossible with negative value
-                    if(direction == BitShiftDirection::left)
+                    if(direction == BitShift_Op::BitShiftDirection::left)
                     {
                         result[i] = array0[i] << array1[i];
                     }
@@ -188,7 +188,7 @@ TEST_CASE("[cpu/operator] BitShift_TEST", "[BitShift][CPU]") {
                                 std::size_t idx1 = idx1_0
                                                     + strides1[2] * ((dims1[2] > 1) ? c : 0)
                                                     + ((dims1[3] > 1) ? d : 0);
-                                if(direction == BitShiftDirection::left)
+                                if(direction == BitShift_Op::BitShiftDirection::left)
                                 {
                                     result[idx_out + d] = array0[idx0] << array1[idx1];
                                 }
-- 
GitLab