diff --git a/include/aidge/backend/cpu/operator/BitShiftImpl.hpp b/include/aidge/backend/cpu/operator/BitShiftImpl.hpp index 0baf6874f10eeb3f174e0c98ca7e315f1f28165c..ad4bdc6aace7225f1eab9a9bc1bdf9edff6d691a 100644 --- a/include/aidge/backend/cpu/operator/BitShiftImpl.hpp +++ b/include/aidge/backend/cpu/operator/BitShiftImpl.hpp @@ -25,10 +25,10 @@ namespace Aidge { // compute kernel registry for forward and backward class BitShiftImplForward_cpu - : public Registrable<BitShiftImplForward_cpu, std::tuple<DataType, DataType, DataType>, void(const BitShiftDirection,const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, const void*, const void*,void*)> { + : public Registrable<BitShiftImplForward_cpu, std::tuple<DataType, DataType, DataType>, void(const BitShift_Op::BitShiftDirection,const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, const void*, const void*,void*)> { }; class BitShiftImplBackward_cpu - : public Registrable<BitShiftImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const BitShiftDirection,const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, const void*, const void*, void*)> { + : public Registrable<BitShiftImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const BitShift_Op::BitShiftDirection,const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, const void*, const void*, void*)> { }; class BitShiftImpl_cpu : public OperatorImpl { diff --git a/include/aidge/backend/cpu/operator/BitShiftImpl_forward_kernels.hpp b/include/aidge/backend/cpu/operator/BitShiftImpl_forward_kernels.hpp index bd841968c4481f5b27359747180472bdeaf77b31..b8b809b69bd479ee122b9981d17817b91c68410a 100644 --- a/include/aidge/backend/cpu/operator/BitShiftImpl_forward_kernels.hpp +++ b/include/aidge/backend/cpu/operator/BitShiftImpl_forward_kernels.hpp @@ -25,7 +25,7 @@ namespace Aidge { template <class I1, class I2, class O> void BitShiftImpl_cpu_forward_kernel( - const BitShiftDirection direction, + const BitShift_Op::BitShiftDirection direction, const std::vector<std::size_t>& input1Dims, const std::vector<std::size_t>& input2Dims, const std::vector<std::size_t>& outputDims, @@ -45,7 +45,7 @@ void BitShiftImpl_cpu_forward_kernel( std::vector<size_t> indexes = getMultiDimIndices(outputDims, oIndex); std::size_t idx1 = getFlattenedIndex(input1Dims, indexes); std::size_t idx2 = getFlattenedIndex(input2Dims, indexes); - if(direction == BitShiftDirection::right) + if(direction == BitShift_Op::BitShiftDirection::right) { output[oIndex]= input_1[idx1] >> input_2[idx2]; } diff --git a/src/operator/BitShiftImpl.cpp b/src/operator/BitShiftImpl.cpp index 16e13f8f294f4bd54142a1e431e005e1c8544f85..a3ffa89dc4904fbcc0b6a2e313b450e8d6f57ed9 100644 --- a/src/operator/BitShiftImpl.cpp +++ b/src/operator/BitShiftImpl.cpp @@ -42,7 +42,7 @@ void Aidge::BitShiftImpl_cpu::forward() { const std::vector<std::size_t> inputDims1 = getBroadcastedDims(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dims(), std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->dims()); - BitShiftDirection direction = op_.direction(); + BitShift_Op::BitShiftDirection direction = op_.direction(); // Call kernel kernelFunc( diff --git a/unit_tests/operator/Test_BitShift.cpp b/unit_tests/operator/Test_BitShift.cpp index 59e2bca5e96b6f875b8a57d97267bc6886a490dd..a52990bc7991a325ce151cf6634b0d5a831992c8 100644 --- a/unit_tests/operator/Test_BitShift.cpp +++ b/unit_tests/operator/Test_BitShift.cpp @@ -34,11 +34,11 @@ TEST_CASE("[cpu/operator] BitShift_TEST", "[BitShift][CPU]") { std::uniform_int_distribution<std::size_t> nbDimsDist(std::size_t(1), std::size_t(3)); std::uniform_int_distribution<int> boolDist(0,1); - BitShiftDirection direction = BitShiftDirection::left; + BitShift_Op::BitShiftDirection direction = BitShift_Op::BitShiftDirection::left; if(valueDist(gen) % 2 == 0) { - direction = BitShiftDirection::right; + direction = BitShift_Op::BitShiftDirection::right; } // Create BitShift Operator @@ -90,7 +90,7 @@ TEST_CASE("[cpu/operator] BitShift_TEST", "[BitShift][CPU]") { for (std::size_t i = 0; i < nb_elements; ++i) { array0[i] = valueDist(gen); array1[i] = std::abs(valueDist(gen)); // bitshift is impossible with negative value - if(direction == BitShiftDirection::left) + if(direction == BitShift_Op::BitShiftDirection::left) { result[i] = array0[i] << array1[i]; } @@ -188,7 +188,7 @@ TEST_CASE("[cpu/operator] BitShift_TEST", "[BitShift][CPU]") { std::size_t idx1 = idx1_0 + strides1[2] * ((dims1[2] > 1) ? c : 0) + ((dims1[3] > 1) ? d : 0); - if(direction == BitShiftDirection::left) + if(direction == BitShift_Op::BitShiftDirection::left) { result[idx_out + d] = array0[idx0] << array1[idx1]; }