Commit 9a3803c2 authored by Jerome Hue

Add an (unregistered) backward kernel function for Sub
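For context, subtraction C = A - B has constant local gradients, so the backward pass only redistributes the output gradient, summing over any axes that were broadcast in the forward pass:

    \frac{\partial L}{\partial A} = \sum_{\text{broadcast axes of } A} \frac{\partial L}{\partial C},
    \qquad
    \frac{\partial L}{\partial B} = -\sum_{\text{broadcast axes of } B} \frac{\partial L}{\partial C}

This is the identity the new kernel below implements.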

parent 839fa0aa
Merge request !138: Implement backward function for Sub
@@ -42,6 +42,7 @@ void sub_contiguous_arrays(const std::size_t input1size,
namespace Aidge {
template <class I1, class I2, class O>
void SubImpl_cpu_forward_kernel(std::vector<std::size_t> dims0,
                                std::vector<std::size_t> dims1,
@@ -149,6 +150,55 @@ void SubImpl_cpu_forward_kernel(std::vector<std::size_t> dims0,
}
}
template <class I1, class I2, class O>
void SubImpl_cpu_backward_kernel(const std::size_t input0Length,
                                 const std::size_t input1Length,
                                 const std::size_t gradOutputLength,
                                 const std::vector<std::size_t>& dims0,
                                 const std::vector<std::size_t>& dims1,
                                 const std::vector<std::size_t>& outputDims,
                                 const void* input0_,
                                 const void* input1_,
                                 const void* grad_output_,
                                 void* gradientInput0_,
                                 void* gradientInput1_)
{
    // Note: the forward input values are not needed here, because the local
    // gradients of subtraction are the constants +1 and -1.
    const I1* input0 = static_cast<const I1*>(input0_);
    const I2* input1 = static_cast<const I2*>(input1_);
    const O* grad_output = static_cast<const O*>(grad_output_);
    auto* grad_input_0 = static_cast<I1*>(gradientInput0_);
    auto* grad_input_1 = static_cast<I2*>(gradientInput1_);

    // Zero-initialize both input gradients: they are accumulated below, since
    // several output elements can map to the same (broadcast) input element.
    std::fill_n(grad_input_0, input0Length, static_cast<I1>(0));
    std::fill_n(grad_input_1, input1Length, static_cast<I2>(0));

    // Pad the input shapes to the output rank so they can be compared axis by axis.
    auto broadcastedDims0 = getBroadcastedDims(outputDims, dims0);
    auto broadcastedDims1 = getBroadcastedDims(outputDims, dims1);

    for (std::size_t i = 0; i < gradOutputLength; ++i) {
        auto idxOutputGrad = getMultiDimIndices(outputDims, i);
        std::vector<std::size_t> idxInput0(broadcastedDims0.size());
        std::vector<std::size_t> idxInput1(broadcastedDims1.size());

        // On broadcast axes (size 1), every output index maps to input index 0.
        for (std::size_t dimension = 0; dimension < broadcastedDims0.size(); ++dimension) {
            idxInput0[dimension] = (broadcastedDims0[dimension] == 1) ? 0 : idxOutputGrad[dimension];
        }
        for (std::size_t dimension = 0; dimension < broadcastedDims1.size(); ++dimension) {
            idxInput1[dimension] = (broadcastedDims1[dimension] == 1) ? 0 : idxOutputGrad[dimension];
        }

        auto idx0 = getFlattenedIndex(broadcastedDims0, idxInput0);
        auto idx1 = getFlattenedIndex(broadcastedDims1, idxInput1);

        // For subtraction C = A - B: dC/dA = +1, so the first input receives +grad_output...
        grad_input_0[idx0] += static_cast<I1>(grad_output[i]);
        // ...and dC/dB = -1, so the second input receives -grad_output.
        grad_input_1[idx1] += static_cast<I2>(-grad_output[i]);
    }
}
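Below is a minimal, self-contained sketch of the same broadcast-and-accumulate logic, runnable outside Aidge. The three helpers (getBroadcastedDims, getMultiDimIndices, getFlattenedIndex) are simplified stand-ins with assumed semantics (left-padding with 1s, row-major indexing), not the library's actual implementations:

// Standalone sketch of the backward broadcast logic for Sub.
// The helpers below are simplified stand-ins for Aidge's internal utilities.
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

// Pad `dims` with leading 1s so it has the same rank as `outputDims` (assumed semantics).
static std::vector<std::size_t> getBroadcastedDims(
        const std::vector<std::size_t>& outputDims,
        const std::vector<std::size_t>& dims) {
    std::vector<std::size_t> result(outputDims.size(), 1);
    std::copy(dims.begin(), dims.end(),
              result.begin() + (outputDims.size() - dims.size()));
    return result;
}

// Convert a flat index into per-dimension indices (row-major order).
static std::vector<std::size_t> getMultiDimIndices(
        const std::vector<std::size_t>& dims, std::size_t flat) {
    std::vector<std::size_t> idx(dims.size());
    for (std::size_t d = dims.size(); d-- > 0;) {
        idx[d] = flat % dims[d];
        flat /= dims[d];
    }
    return idx;
}

// Convert per-dimension indices back into a flat index (row-major order).
static std::size_t getFlattenedIndex(const std::vector<std::size_t>& dims,
                                     const std::vector<std::size_t>& idx) {
    std::size_t flat = 0;
    for (std::size_t d = 0; d < dims.size(); ++d) {
        flat = flat * dims[d] + idx[d];
    }
    return flat;
}

int main() {
    // C = A - B with A of shape {2, 3} and B of shape {3} (B broadcast along axis 0).
    const std::vector<std::size_t> dimsA{2, 3}, dimsB{3}, dimsC{2, 3};
    const std::vector<float> gradC{1, 2, 3, 4, 5, 6}; // dL/dC, row-major

    std::vector<float> gradA(6, 0.f), gradB(3, 0.f);
    const auto bA = getBroadcastedDims(dimsC, dimsA); // {2, 3}
    const auto bB = getBroadcastedDims(dimsC, dimsB); // {1, 3}

    for (std::size_t i = 0; i < gradC.size(); ++i) {
        const auto idxC = getMultiDimIndices(dimsC, i);
        std::vector<std::size_t> idxA(bA.size()), idxB(bB.size());
        for (std::size_t d = 0; d < bA.size(); ++d)
            idxA[d] = (bA[d] == 1) ? 0 : idxC[d];
        for (std::size_t d = 0; d < bB.size(); ++d)
            idxB[d] = (bB[d] == 1) ? 0 : idxC[d];
        gradA[getFlattenedIndex(bA, idxA)] += gradC[i];  // d(A-B)/dA = +1
        gradB[getFlattenedIndex(bB, idxB)] += -gradC[i]; // d(A-B)/dB = -1
    }

    // gradB sums -gradC over the broadcast axis; gradA mirrors gradC.
    for (float v : gradB) std::cout << v << ' ';
    std::cout << '\n';
}

Running it prints -5 -7 -9: each element of grad B is the negated sum of grad C over the broadcast axis, while grad A simply mirrors grad C.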
// Kernel registration to the implementation entry point
REGISTRAR(SubImpl_cpu,
    {DataType::Float32},
    ...