diff --git a/include/aidge/backend/cpu/operator/SubImpl_kernels.hpp b/include/aidge/backend/cpu/operator/SubImpl_kernels.hpp index 1d789c3c8886d35ce6597d5704c76060bad196c1..a195477652001f2f900a96c3fe36a9b6ddfd5f5f 100644 --- a/include/aidge/backend/cpu/operator/SubImpl_kernels.hpp +++ b/include/aidge/backend/cpu/operator/SubImpl_kernels.hpp @@ -42,6 +42,7 @@ void sub_contiguous_arrays(const std::size_t input1size, namespace Aidge { + template <class I1, class I2, class O> void SubImpl_cpu_forward_kernel(std::vector<std::size_t> dims0, std::vector<std::size_t> dims1, @@ -149,6 +150,55 @@ void SubImpl_cpu_forward_kernel(std::vector<std::size_t> dims0, } } +template <class I1, class I2, class O> +void SubImpl_cpu_backward_kernel(const std::size_t input0Length, + const std::size_t input1Length, + const std::size_t gradOutputLength, + const std::vector<std::size_t>& dims0, + const std::vector<std::size_t>& dims1, + const std::vector<std::size_t>& outputDims, + const void* input0_, + const void* input1_, + const void* grad_output_, + void* gradientInput0_, + void* gradientInput1_) +{ + const I1* input0 = static_cast<const I1*>(input0_); + const I2* input1 = static_cast<const I2*>(input1_); + const O* grad_output = static_cast<const O*>(grad_output_); + auto* grad_input_0 = static_cast<I1*>(gradientInput0_); + auto* grad_input_1 = static_cast<I2*>(gradientInput1_); + + std::fill_n(grad_input_0, input0Length, static_cast<I1>(0)); + std::fill_n(grad_input_1, input1Length, static_cast<I2>(0)); + + auto broadcastedDims0 = getBroadcastedDims(outputDims, dims0); + auto broadcastedDims1 = getBroadcastedDims(outputDims, dims1); + + for (std::size_t i = 0; i < gradOutputLength; ++i) { + auto idxOutputGrad = getMultiDimIndices(outputDims, i); + std::vector<std::size_t> idxInput0(broadcastedDims0.size()); + std::vector<std::size_t> idxInput1(broadcastedDims1.size()); + + for (std::size_t dimension = 0; dimension < broadcastedDims0.size(); ++dimension) { + idxInput0[dimension] = (broadcastedDims0[dimension] == 1) ? 0 : idxOutputGrad[dimension]; + } + + for (std::size_t dimension = 0; dimension < broadcastedDims1.size(); ++dimension) { + idxInput1[dimension] = (broadcastedDims1[dimension] == 1) ? 0 : idxOutputGrad[dimension]; + } + + auto idx0 = getFlattenedIndex(broadcastedDims0, idxInput0); + auto idx1 = getFlattenedIndex(broadcastedDims1, idxInput1); + + // For subtraction: gradient of first input is 1 * grad_output + grad_input_0[idx0] += static_cast<I1>(grad_output[i]); + // For subtraction: gradient of second input is -1 * grad_output + grad_input_1[idx1] += static_cast<I2>(-grad_output[i]); + } +} + + // Kernels registration to implementation entry point REGISTRAR(SubImpl_cpu, {DataType::Float32},