Fix #37 (closed)
MulImpl.hpp
and MulImpl_kernels.hpp
: Refactor backward pass of Mul
, see explantions below.MulImpl.cpp
Test_MulImpl.cpp
: Test the new backward pass. The tests consists of 4 tests with fixed values, and one with random values (but fixed dimensions).The backward kernel computes gradients for the multiplication operation while handling broadcasting. For each element in the output gradient tensor, it:
grad_input_0[i] += grad_output[k] * input1[j]
grad_input_1[j] += grad_output[k] * input0[i]
where i,j are the contributing indices from input0 and input1 for output position k
`input0` [1,3]: [[1,2,3]]
`input1` [4,3]: [[4,5,6],
[7,8,9],
[1,2,3],
[4,5,6]]
Output shape: [4,3], thus `outputDims` = [4,3], `dims0` = [1,3], and `dims1` = [4,3].
Broadcasted shapes:
- getBroadcastedDims(outputDims, dims0): [1,3]
- getBroadcastedDims(outputDims, dims1): [4,3]
For instance, assume we multiply two tensors A and B, with A of shape [1,3] and B of shape [4,3].
input0 = [[1,2,3]]
input1 = [[4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6]]
Let's also assumes that the output gradient is full of 1s.
Then we iterate for i = 0 to output.size()
:
At iteration 0 / index [0,0]
→ The element at index [0,0] of input0 contributed to the output value
→ Thelement at index [0,0] of input1 contributed to the output value.
grad_input_0[0] += grad_output[0,0] * input0[0,0]
grad_input_1[0] += grad_output[0,0] * input1[0,0]
At iteration 1 / index [0,1]
→ The element at index [0,1] of input0 contributed to the output value.
→ The element at index [0,1] of input1 contributed to the output value.
grad_input_0[1] += grad_output[0,1] * input1[0,1]
grad_input_1[1] += grad_output[0,1] * input0[0,1]
At Iteration 2 / index [0,2]
→ The element at index [0,2] of input0 contributed to the output value.
→ The element at index [0,2] of input1 contributed to the output value.
grad_input_0[2] += grad_output[0,2] * input1[0,2]
grad_input_1[2] += grad_output[0,2] * input0[0,2]
At Iteration 3 / index [1,0]
→ The element at index [0,0] of input0 contributed to the output value (due to broadcasting!).
This can be determined by looking at the dimensions of getBroadcastedDims(outputDims, dims0)
, which is 1
for index 0
. It indicates that the dimension has been broacasted for the output, or that both dimensions are equal to 1
.
→ The element at index [1,0] of input1 contributed to the output value.
// Updates:
grad_input_0[0] += grad_output[1,0] * input1[1,0]
grad_input_1[3] += grad_output[1,0] * input0[0,0]
etc.
Copyright © Eclipse Foundation, Inc. All Rights Reserved. Privacy Policy | Terms of Use | Copyright Agent