Commit 50a2eb35 authored by Jerome Hue, committed by Maxence Naud

Add an (unregistered) backward kernel function for Sub

parent e8e3f535
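
For reference, the math the new kernel implements: with O = I_0 - I_1,

\frac{\partial L}{\partial I_0} = \frac{\partial L}{\partial O},
\qquad
\frac{\partial L}{\partial I_1} = -\frac{\partial L}{\partial O},

with gradients summed over any dimensions where an input was broadcast to the output shape.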
@@ -42,6 +42,7 @@ void sub_contiguous_arrays(const std::size_t input1size,
namespace Aidge {
template <class I1, class I2, class O>
void SubImpl_cpu_forward_kernel(std::vector<std::size_t> dims0,
                                std::vector<std::size_t> dims1,
@@ -149,6 +150,55 @@ void SubImpl_cpu_forward_kernel(std::vector<std::size_t> dims0,
    }
}

template <class I1, class I2, class O>
void SubImpl_cpu_backward_kernel(const std::size_t input0Length,
                                 const std::size_t input1Length,
                                 const std::size_t gradOutputLength,
                                 const std::vector<std::size_t>& dims0,
                                 const std::vector<std::size_t>& dims1,
                                 const std::vector<std::size_t>& outputDims,
                                 const void* input0_,
                                 const void* input1_,
                                 const void* grad_output_,
                                 void* gradientInput0_,
                                 void* gradientInput1_)
{
    // input0/input1 are unused below: for Sub the local derivatives are the
    // constants +1 and -1, so the gradient does not depend on input values.
    const I1* input0 = static_cast<const I1*>(input0_);
    const I2* input1 = static_cast<const I2*>(input1_);
    const O* grad_output = static_cast<const O*>(grad_output_);
    auto* grad_input_0 = static_cast<I1*>(gradientInput0_);
    auto* grad_input_1 = static_cast<I2*>(gradientInput1_);

    // Gradients are accumulated below, so zero both buffers first.
    std::fill_n(grad_input_0, input0Length, static_cast<I1>(0));
    std::fill_n(grad_input_1, input1Length, static_cast<I2>(0));

    // Pad each input shape with leading 1s to the output rank so the
    // per-dimension broadcast test below is well defined.
    auto broadcastedDims0 = getBroadcastedDims(outputDims, dims0);
    auto broadcastedDims1 = getBroadcastedDims(outputDims, dims1);

    for (std::size_t i = 0; i < gradOutputLength; ++i) {
        auto idxOutputGrad = getMultiDimIndices(outputDims, i);
        std::vector<std::size_t> idxInput0(broadcastedDims0.size());
        std::vector<std::size_t> idxInput1(broadcastedDims1.size());

        // Map the output coordinate back onto each input; broadcast
        // dimensions (of size 1) always map to index 0.
        for (std::size_t dimension = 0; dimension < broadcastedDims0.size(); ++dimension) {
            idxInput0[dimension] = (broadcastedDims0[dimension] == 1) ? 0 : idxOutputGrad[dimension];
        }
        for (std::size_t dimension = 0; dimension < broadcastedDims1.size(); ++dimension) {
            idxInput1[dimension] = (broadcastedDims1[dimension] == 1) ? 0 : idxOutputGrad[dimension];
        }

        auto idx0 = getFlattenedIndex(broadcastedDims0, idxInput0);
        auto idx1 = getFlattenedIndex(broadcastedDims1, idxInput1);

        // For subtraction: gradient of the first input is +1 * grad_output,
        // gradient of the second input is -1 * grad_output.
        grad_input_0[idx0] += static_cast<I1>(grad_output[i]);
        grad_input_1[idx1] += static_cast<I2>(-grad_output[i]);
    }
}
// Kernels registration to implementation entry point
REGISTRAR(SubImpl_cpu,
    {DataType::Float32},
...
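
A minimal, hypothetical usage sketch (not part of the commit), showing the broadcast behaviour of the new backward kernel. getBroadcastedDims, getMultiDimIndices and getFlattenedIndex are Aidge utilities; the stubs below only assume their apparent semantics (left-pad a shape with 1s to the output rank; row-major flat/multi-dimensional index conversion) so the example compiles standalone once SubImpl_cpu_backward_kernel from the diff above is pasted in.

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

// Assumed semantics: left-pad `dims` with 1s to the rank of `outputDims`.
std::vector<std::size_t> getBroadcastedDims(const std::vector<std::size_t>& outputDims,
                                            const std::vector<std::size_t>& dims) {
    std::vector<std::size_t> res(outputDims.size(), 1);
    std::copy(dims.begin(), dims.end(), res.end() - dims.size());
    return res;
}

// Assumed semantics: row-major flat index -> per-dimension indices.
std::vector<std::size_t> getMultiDimIndices(const std::vector<std::size_t>& dims,
                                            std::size_t flatIdx) {
    std::vector<std::size_t> idx(dims.size());
    for (std::size_t d = dims.size(); d-- > 0;) {
        idx[d] = flatIdx % dims[d];
        flatIdx /= dims[d];
    }
    return idx;
}

// Assumed semantics: per-dimension indices -> row-major flat index.
std::size_t getFlattenedIndex(const std::vector<std::size_t>& dims,
                              const std::vector<std::size_t>& idx) {
    std::size_t flat = 0;
    for (std::size_t d = 0; d < dims.size(); ++d) {
        flat = flat * dims[d] + idx[d];
    }
    return flat;
}

// ... SubImpl_cpu_backward_kernel from the diff above goes here ...

int main() {
    // O = I0 - I1 with I0: {2,3} and I1: {3} (I1 broadcast along axis 0).
    const std::vector<std::size_t> dims0{2, 3}, dims1{3}, outDims{2, 3};
    std::vector<float> in0(6), in1(3), grad0(6), grad1(3);
    const std::vector<float> gradOut{1, 2, 3, 4, 5, 6};

    SubImpl_cpu_backward_kernel<float, float, float>(
        in0.size(), in1.size(), gradOut.size(), dims0, dims1, outDims,
        in0.data(), in1.data(), gradOut.data(), grad0.data(), grad1.data());

    // grad0 mirrors gradOut; grad1 sums -gradOut over the broadcast axis.
    for (float g : grad1) { std::printf("%g ", g); }  // expected: -5 -7 -9
    std::printf("\n");
    return 0;
}

Note that the kernel never reads the input values: for subtraction the local derivatives are the constants +1 and -1, so only the shapes matter for routing and negating grad_output.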