diff --git a/include/aidge/backend/cpu/operator/BitShiftImpl.hpp b/include/aidge/backend/cpu/operator/BitShiftImpl.hpp index 52400baad4922d5729ea4b13d260f04a2836ed59..0baf6874f10eeb3f174e0c98ca7e315f1f28165c 100644 --- a/include/aidge/backend/cpu/operator/BitShiftImpl.hpp +++ b/include/aidge/backend/cpu/operator/BitShiftImpl.hpp @@ -25,10 +25,10 @@ namespace Aidge { // compute kernel registry for forward and backward class BitShiftImplForward_cpu - : public Registrable<BitShiftImplForward_cpu, std::tuple<DataType, DataType, DataType>, void(const Direction,const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, const void*, const void*,void*)> { + : public Registrable<BitShiftImplForward_cpu, std::tuple<DataType, DataType, DataType>, void(const BitShiftDirection,const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, const void*, const void*,void*)> { }; class BitShiftImplBackward_cpu - : public Registrable<BitShiftImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const Direction,const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, const void*, const void*, void*)> { + : public Registrable<BitShiftImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const BitShiftDirection,const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, const void*, const void*, void*)> { }; class BitShiftImpl_cpu : public OperatorImpl { diff --git a/include/aidge/backend/cpu/operator/BitShiftImpl_forward_kernels.hpp b/include/aidge/backend/cpu/operator/BitShiftImpl_forward_kernels.hpp index 98b34b9d5d388fd9cabb562f8bd7d311644b73c1..bd841968c4481f5b27359747180472bdeaf77b31 100644 --- a/include/aidge/backend/cpu/operator/BitShiftImpl_forward_kernels.hpp +++ b/include/aidge/backend/cpu/operator/BitShiftImpl_forward_kernels.hpp @@ -25,7 +25,7 @@ namespace Aidge { template <class I1, class I2, class O> void BitShiftImpl_cpu_forward_kernel( - const Direction direction, + const BitShiftDirection direction, const std::vector<std::size_t>& input1Dims, const std::vector<std::size_t>& input2Dims, const std::vector<std::size_t>& outputDims, @@ -34,27 +34,22 @@ void BitShiftImpl_cpu_forward_kernel( void* output_ ) { - //Cast des entrées en classes I / O const I1* input_1 = static_cast<const I1*>(input1_); const I2* input_2 = static_cast<const I2*>(input2_); O* output = static_cast<O*>(output_); - size_t totalElements = 1; - for (size_t dimSize : outputDims) { - totalElements *= dimSize; - } - + const size_t totalElements = std::accumulate(outputDims.begin(), outputDims.end(), std::size_t(1), std::multiplies<std::size_t>()); + for (std::size_t oIndex = 0; oIndex < totalElements; ++oIndex) { std::vector<size_t> indexes = getMultiDimIndices(outputDims, oIndex); std::size_t idx1 = getFlattenedIndex(input1Dims, indexes); std::size_t idx2 = getFlattenedIndex(input2Dims, indexes); - if(direction == Direction::right) + if(direction == BitShiftDirection::right) { - //BitShift ne fonctionne pas sur les types à virgule flottante output[oIndex]= input_1[idx1] >> input_2[idx2]; } - else if(direction == Direction::left) + else { output[oIndex] = input_1[idx1] << input_2[idx2]; } diff --git a/src/operator/BitShiftImpl.cpp b/src/operator/BitShiftImpl.cpp index 600c3a4bc930e71ecd9e72e2af7c5fedfa4a1de2..16e13f8f294f4bd54142a1e431e005e1c8544f85 100644 --- a/src/operator/BitShiftImpl.cpp +++ b/src/operator/BitShiftImpl.cpp @@ -42,7 +42,7 @@ void Aidge::BitShiftImpl_cpu::forward() { const std::vector<std::size_t> inputDims1 = getBroadcastedDims(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dims(), std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->dims()); - Direction direction = op_.direction(); + BitShiftDirection direction = op_.direction(); // Call kernel kernelFunc( diff --git a/unit_tests/operator/Test_BitShift.cpp b/unit_tests/operator/Test_BitShift.cpp index 33c2bd1d879f7ea22d8d8f4448a90fb95ed2e349..59e2bca5e96b6f875b8a57d97267bc6886a490dd 100644 --- a/unit_tests/operator/Test_BitShift.cpp +++ b/unit_tests/operator/Test_BitShift.cpp @@ -29,20 +29,20 @@ TEST_CASE("[cpu/operator] BitShift_TEST", "[BitShift][CPU]") { // Create a random number generator std::random_device rd; std::mt19937 gen(rd()); - std::uniform_int_distribution<int> valueDist(-15, 15); // Random int distribution between -15 and 15 + std::uniform_int_distribution<int> valueDist(-15, 15); std::uniform_int_distribution<std::size_t> dimSizeDist(std::size_t(2), std::size_t(5)); std::uniform_int_distribution<std::size_t> nbDimsDist(std::size_t(1), std::size_t(3)); std::uniform_int_distribution<int> boolDist(0,1); - Direction direction = Direction::left; + BitShiftDirection direction = BitShiftDirection::left; if(valueDist(gen) % 2 == 0) { - direction = Direction::right; + direction = BitShiftDirection::right; } // Create BitShift Operator - std::shared_ptr<Node> myBitShift = BitShift(direction); // Left opérator to start + std::shared_ptr<Node> myBitShift = BitShift(direction); auto op = std::static_pointer_cast<OperatorTensor>(myBitShift-> getOperator()); op->setDataType(DataType::Int32); op->setBackend("cpu"); @@ -69,7 +69,7 @@ TEST_CASE("[cpu/operator] BitShift_TEST", "[BitShift][CPU]") { std::chrono::duration<double, std::micro> duration{}; SECTION("BitShiftImpl_cpu::forward()") { - SECTION("+1-D Tensor / +1-D Tensor - same dimensions") { + SECTION("Test Forward Kernel with same dimensions") { std::size_t number_of_operation = 0; for (std::uint16_t trial = 0; trial < NBTRIALS; ++trial) { @@ -90,7 +90,7 @@ TEST_CASE("[cpu/operator] BitShift_TEST", "[BitShift][CPU]") { for (std::size_t i = 0; i < nb_elements; ++i) { array0[i] = valueDist(gen); array1[i] = std::abs(valueDist(gen)); // bitshift is impossible with negative value - if(direction == Direction::left) + if(direction == BitShiftDirection::left) { result[i] = array0[i] << array1[i]; } @@ -134,7 +134,7 @@ TEST_CASE("[cpu/operator] BitShift_TEST", "[BitShift][CPU]") { std::cout << "number of elements over time spent: " << (number_of_operation / duration.count())<< std::endl; std::cout << "total time: " << duration.count() << "μs" << std::endl; } - SECTION("+1-D Tensor / +1-D Tensor - broadcasting") { + SECTION("Test BitShift kernels with Broadcasting") { std::size_t number_of_operation = 0; for (std::uint16_t trial = 0; trial < NBTRIALS; ++trial) { @@ -188,7 +188,7 @@ TEST_CASE("[cpu/operator] BitShift_TEST", "[BitShift][CPU]") { std::size_t idx1 = idx1_0 + strides1[2] * ((dims1[2] > 1) ? c : 0) + ((dims1[3] > 1) ? d : 0); - if(direction == Direction::left) + if(direction == BitShiftDirection::left) { result[idx_out + d] = array0[idx0] << array1[idx1]; }