From 05ca42758e7055f74f7d465a0b057f664e9c085c Mon Sep 17 00:00:00 2001
From: hrouis <houssemeddine.rouis92@gmail.com>
Date: Fri, 26 Jan 2024 11:47:50 +0100
Subject: [PATCH] add broadcasting for Sub operator

---
 .../aidge/backend/cpu/operator/SubImpl.hpp    |  4 +-
 .../cpu/operator/SubImpl_forward_kernels.hpp  | 42 +++++++++----------
 src/operator/SubImpl.cpp                      | 11 ++++-
 3 files changed, 30 insertions(+), 27 deletions(-)

diff --git a/include/aidge/backend/cpu/operator/SubImpl.hpp b/include/aidge/backend/cpu/operator/SubImpl.hpp
index 2d4c22f0..b329ec6e 100644
--- a/include/aidge/backend/cpu/operator/SubImpl.hpp
+++ b/include/aidge/backend/cpu/operator/SubImpl.hpp
@@ -25,10 +25,10 @@ namespace Aidge {
 
 // compute kernel registry for forward and backward
 class SubImplForward_cpu
-    : public Registrable<SubImplForward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::size_t, const std::size_t, const void*, const void*,void*)> {
+    : public Registrable<SubImplForward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, const void*, const void*,void*)> {
 };
 class SubImplBackward_cpu
-    : public Registrable<SubImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::size_t, const std::size_t, const void*, const void*, void*)> {
+    : public Registrable<SubImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, const void*, const void*, void*)> {
 };
 
 class SubImpl_cpu : public OperatorImpl {
diff --git a/include/aidge/backend/cpu/operator/SubImpl_forward_kernels.hpp b/include/aidge/backend/cpu/operator/SubImpl_forward_kernels.hpp
index 08f2e24f..19b0bd21 100644
--- a/include/aidge/backend/cpu/operator/SubImpl_forward_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/SubImpl_forward_kernels.hpp
@@ -14,39 +14,35 @@
 
 #include "aidge/utils/Registrar.hpp"
 
+#include "aidge/backend/cpu/data/Broadcasting.hpp"
 #include "aidge/backend/cpu/operator/SubImpl.hpp"
 
+
 namespace Aidge {
 template <class I1, class I2, class O>
-void SubImpl_cpu_forward_kernel(std::size_t input1Length,
-                                std::size_t input2Length,
-                                const void* input1_,
-                                const void* input2_,
-                                void* output_) {
+void SubImpl_cpu_forward_kernel(const std::vector<std::size_t>& input1Dims,
+                                const std::vector<std::size_t>& input2Dims,
+                                const std::vector<std::size_t>& outputDims,
+                                const void* input1_,
+                                const void* input2_,
+                                void* output_) {
 
     const I1* input_1 = static_cast<const I1*>(input1_);
     const I2* input_2 = static_cast<const I2*>(input2_);
     O* output = static_cast<O*>(output_);
 
-    if (input2Length == input1Length)
-    {
-        for (std::size_t i = 0; i < input1Length; ++i) {
-            output[i] = input_1[i] - input_2[i];
-        }
-    }
-    else if (input2Length == 1)
-    {
-        for (std::size_t i = 0; i < input1Length; ++i) {
-            output[i] = input_1[i] - input_2[0];
-        }
-    }
-    else // input_2 is 1d and of size the number of channels of input_1
-    {
-        for (std::size_t i = 0; i < input1Length; ++i) {
-            std::size_t channelIdx = i % input2Length;
-            output[i] = input_1[i] - input_2[channelIdx];
-        }
+    size_t totalElements = 1;
+    for (size_t dimSize : outputDims) {
+        totalElements *= dimSize;
     }
+
+    for (std::size_t oIndex = 0; oIndex < totalElements; ++oIndex)
+    {
+        std::vector<size_t> indexes = getMultiDimIndices(outputDims, oIndex);
+        std::size_t idx1 = getFlattenedIndex(input1Dims, indexes);
+        std::size_t idx2 = getFlattenedIndex(input2Dims, indexes);
+        output[oIndex] = input_1[idx1] - input_2[idx2];
+    }
 }
 
 namespace {
diff --git a/src/operator/SubImpl.cpp b/src/operator/SubImpl.cpp
index 038a1154..475f8cb8 100644
--- a/src/operator/SubImpl.cpp
+++ b/src/operator/SubImpl.cpp
@@ -17,6 +17,7 @@
 #include "aidge/operator/Sub.hpp"
 #include "aidge/utils/Types.h"
 
+#include "aidge/backend/cpu/data/Broadcasting.hpp"
 #include "aidge/backend/cpu/data/GetCPUPtr.h"
 #include "aidge/backend/cpu/operator/SubImpl.hpp"
 
@@ -35,9 +36,15 @@ void Aidge::SubImpl_cpu::forward() {
         std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->dataType(),
         std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()});
 
+    const std::vector<std::size_t> inputDims0 = getBroadcastedDims(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dims(),
+                                                                   std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dims());
+    const std::vector<std::size_t> inputDims1 = getBroadcastedDims(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dims(),
+                                                                   std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->dims());
+
     // Call kernel
-    kernelFunc(std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->size(),
-        std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->size(),
+    kernelFunc(inputDims0,
+        inputDims1,
+        std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dims(),
         getCPUPtr(mOp.getRawInput(0)),
         getCPUPtr(mOp.getRawInput(1)),
         getCPUPtr(mOp.getRawOutput(0)));
-- 
GitLab
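Reviewer note (sketch, not part of the patch): the new kernel relies on the Broadcasting.hpp helpers getMultiDimIndices() and getFlattenedIndex(), whose implementations are not shown in this diff. The standalone program below is a minimal sketch of the index mapping they are assumed to perform: an output flat index is expanded into multi-dimensional coordinates, and those coordinates are collapsed back into a flat index of a (possibly broadcast) input, where size-1 dimensions are pinned to coordinate 0 so they repeat. Function names and behaviour here are assumptions for illustration only.

// Reviewer sketch (assumption): plausible behaviour of the Broadcasting.hpp
// helpers used by the new Sub kernel. Not the library's actual code.
#include <cstddef>
#include <iostream>
#include <vector>

// Convert a flat output index into multi-dimensional coordinates (row-major).
std::vector<std::size_t> getMultiDimIndices(const std::vector<std::size_t>& dims,
                                            std::size_t flatIdx) {
    std::vector<std::size_t> indices(dims.size());
    for (std::size_t i = dims.size(); i-- > 0;) {
        indices[i] = flatIdx % dims[i];
        flatIdx /= dims[i];
    }
    return indices;
}

// Flatten coordinates into an index of a (possibly broadcast) tensor:
// dimensions of size 1 always map to coordinate 0, so their data repeats.
std::size_t getFlattenedIndex(const std::vector<std::size_t>& dims,
                              const std::vector<std::size_t>& indices) {
    std::size_t flatIdx = 0;
    for (std::size_t i = 0; i < dims.size(); ++i) {
        const std::size_t coord = (dims[i] == 1) ? 0 : indices[i];
        flatIdx = flatIdx * dims[i] + coord;
    }
    return flatIdx;
}

int main() {
    // Output shape {2, 3}, second input broadcast from shape {1, 3}
    // (as produced by getBroadcastedDims padding to the output rank).
    const std::vector<std::size_t> outDims{2, 3}, in2Dims{1, 3};
    for (std::size_t o = 0; o < 6; ++o) {
        const std::vector<std::size_t> idx = getMultiDimIndices(outDims, o);
        std::cout << "out " << o << " -> in2 " << getFlattenedIndex(in2Dims, idx) << '\n';
        // Prints 0,1,2,0,1,2: the single row of input 2 repeats over both output rows.
    }
    return 0;
}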