[add] Element wise backward
Context
Fix #37 (closed)
Modified Files
-
MulImpl.hpp
andMulImpl_kernels.hpp
: Refactor backward pass ofMul
, see explantions below. MulImpl.cpp
-
Test_MulImpl.cpp
: Test the new backward pass. The tests consists of 4 tests with fixed values, and one with random values (but fixed dimensions).
Description of the changes
Overview
The backward kernel computes gradients for the multiplication operation while handling broadcasting. For each element in the output gradient tensor, it:
- Determines which input elements contributed to this output (considering broadcasting)
- Applies the chain rule for multiplication to compute gradients
grad_input_0[i] += grad_output[k] * input1[j]
grad_input_1[j] += grad_output[k] * input0[i]
where i,j are the contributing indices from input0 and input1 for output position k
Example
`input0` [1,3]: [[1,2,3]]
`input1` [4,3]: [[4,5,6],
[7,8,9],
[1,2,3],
[4,5,6]]
Output shape: [4,3], thus `outputDims` = [4,3], `dims0` = [1,3], and `dims1` = [4,3].
Broadcasted shapes:
- getBroadcastedDims(outputDims, dims0): [1,3]
- getBroadcastedDims(outputDims, dims1): [4,3]
For instance, assume we multiply two tensors A and B, with A of shape [1,3] and B of shape [4,3].
input0 = [[1,2,3]]
input1 = [[4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6]]
Let's also assumes that the output gradient is full of 1s.
Then we iterate for i = 0 to output.size()
:
-
At iteration 0 / index [0,0]
→ The element at index [0,0] of input0 contributed to the output value
→ Thelement at index [0,0] of input1 contributed to the output value.
grad_input_0[0] += grad_output[0,0] * input0[0,0]
grad_input_1[0] += grad_output[0,0] * input1[0,0]
-
At iteration 1 / index [0,1]
→ The element at index [0,1] of input0 contributed to the output value.
→ The element at index [0,1] of input1 contributed to the output value.
grad_input_0[1] += grad_output[0,1] * input1[0,1]
grad_input_1[1] += grad_output[0,1] * input0[0,1]
-
At Iteration 2 / index [0,2]
→ The element at index [0,2] of input0 contributed to the output value.
→ The element at index [0,2] of input1 contributed to the output value.
grad_input_0[2] += grad_output[0,2] * input1[0,2]
grad_input_1[2] += grad_output[0,2] * input0[0,2]
-
At Iteration 3 / index [1,0]
→ The element at index [0,0] of input0 contributed to the output value (due to broadcasting!). This can be determined by looking at the dimensions of
getBroadcastedDims(outputDims, dims0)
, which is1
for index0
. It indicates that the dimension has been broacasted for the output, or that both dimensions are equal to1
.→ The element at index [1,0] of input1 contributed to the output value.
// Updates:
grad_input_0[0] += grad_output[1,0] * input1[1,0]
grad_input_1[3] += grad_output[1,0] * input0[0,0]
etc.