Add partial support for Spiking Neural Networks backward in Aidge
Context
Improve support for Spiking Neural Networks in Aidge.
This MR does several things:

- It adds the objects necessary to perform the backward pass of such operators:
  - `Context`: a class used to save (before the forward pass) and restore (before the backward pass) the inputs of operators. This is needed, for instance, in the Heaviside operator, which uses a surrogate gradient defined as $\frac{\partial \tilde{S}}{\partial U} = \frac{1}{\pi}\,\frac{1}{1+(U\pi)^2}$, with $U$ the input potential of the operator (see the sketch after this list).
  - A way to distinguish Leaky operators that are network outputs from the others. This is necessary because an output neuron only outputs its spikes, whereas a non-output neuron can output both its spikes and its membrane potential.
- Testing cannot be done in the `aidge_core` module, so it is done in a separate module.
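For illustration, here is a minimal sketch of that surrogate gradient as a standalone function (the name and signature are hypothetical, not the actual Aidge kernel):

```cpp
#include <cmath>

// Arctan surrogate gradient of the Heaviside step function:
// dS~/dU = (1/pi) * 1 / (1 + (U*pi)^2), with U the input potential.
// Hypothetical helper for illustration, not the actual Aidge implementation.
float heavisideSurrogateGrad(float u) {
    const float pi = static_cast<float>(M_PI);
    return 1.0f / (pi * (1.0f + (u * pi) * (u * pi)));
}
```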
Modified files
- include/aidge/graph/Context.hpp
- include/aidge/operator/MetaOperatorDefs.hpp
- include/aidge/scheduler/Scheduler.hpp
- include/aidge/aidge.hpp
- python_binding/data/pybind_Tensor.cpp
- python_binding/operator/pybind_MetaOperatorDefs.cpp
- python_binding/scheduler/pybind_Scheduler.cpp
- src/backend/generic/operator/MemorizeImpl.cpp
- src/backend/generic/operator/StackImpl.cpp
- src/graph/Context.cpp
- src/operator/MetaOperator/Leaky.cpp
- src/scheduler/Scheduler.cpp
- src/scheduler/SequentialScheduler.cpp
Detailed major modifications
Context
During the backward pass, we sometimes need the input values of an operator to compute its gradients. Since these values are not saved by default in Aidge, I implemented a class, `Context`, that holds a stack storing the successive inputs for each time step and restores them before each call to `operator.backward()`.

To use this class, we have to create instances of it, which is handled by the scheduler.
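The actual interface lives in `include/aidge/graph/Context.hpp`; as a rough sketch of the save/restore idea (all names and types below are illustrative, not the real API):

```cpp
#include <memory>
#include <stack>
#include <vector>

#include "aidge/data/Tensor.hpp"

// Illustrative sketch only; the real class is in include/aidge/graph/Context.hpp.
class ContextSketch {
public:
    // Called before each forward step: snapshot the operator's current inputs.
    void save(const std::vector<std::shared_ptr<Aidge::Tensor>>& inputs) {
        std::vector<std::shared_ptr<Aidge::Tensor>> snapshot;
        for (const auto& in : inputs) {
            // Deep copy, assuming Aidge::Tensor is copy-constructible:
            // the live tensors will be overwritten at the next time step.
            snapshot.push_back(std::make_shared<Aidge::Tensor>(*in));
        }
        mStack.push(std::move(snapshot));
    }

    // Called before each backward step: pop the inputs of the matching forward
    // step (LIFO order matches backpropagation through time).
    std::vector<std::shared_ptr<Aidge::Tensor>> restore() {
        auto snapshot = std::move(mStack.top());
        mStack.pop();
        return snapshot;
    }

private:
    std::stack<std::vector<std::shared_ptr<Aidge::Tensor>>> mStack;
};
```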
output argument
It is also necessary to determine when a Leaky node is an output (an `output` argument also exists in snnTorch, for instance). To do so, an `output` argument is added to the `Leaky()` constructor, which in turn modifies how the Leaky node is built. For output Leaky nodes, we assume that we only want to look at the spike output and that we won't use the membrane potential output.
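A hypothetical usage sketch (the parameter names, order, and values below are assumptions about the `Leaky()` factory; only the `output` flag is the addition described here):

```cpp
#include "aidge/operator/MetaOperatorDefs.hpp"

// Hidden Leaky neuron: exposes both its spikes and its membrane potential.
auto hidden = Aidge::Leaky(/*nbTimeSteps=*/10, /*beta=*/0.9f,
                           /*threshold=*/1.0f, /*output=*/false, "lif_hidden");

// Output Leaky neuron: only its spike output is kept.
auto readout = Aidge::Leaky(/*nbTimeSteps=*/10, /*beta=*/0.9f,
                            /*threshold=*/1.0f, /*output=*/true, "lif_out");
```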
backward pass
The backpropagation through time (BPTT) algorithm is handled directly by the backward pass of the graph:
$$
\frac{\partial \mathcal{L}}{\partial W} = \sum_{t=1}^T \Big( X[t] \times \sum_{i=t}^T \big( \beta^{i-t} \frac{\partial \mathcal{L}[i]}{\partial S[i]} \frac{\partial \tilde{S}[i]}{\partial U[i]} \big) \Big)
$$
The Memorize node is responsible for passing the value $\beta \times \dfrac{\partial \mathcal{L}[i]}{\partial S[i]} \dfrac{\partial \tilde{S}[i]}{\partial U[i]}$ to the previous time step.
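For intuition, here is an illustrative scalar unrolling of that sum in plain C++, for a single synapse (a sketch of the math above, not the actual graph-based implementation):

```cpp
#include <cstddef>
#include <vector>

// Scalar version of the BPTT weight gradient above, for a single synapse.
// x[t]: input X[t]; dL_dS[t]: dL[t]/dS[t]; dS_dU[t]: surrogate dS~[t]/dU[t].
float bpttWeightGrad(const std::vector<float>& x,
                     const std::vector<float>& dL_dS,
                     const std::vector<float>& dS_dU,
                     float beta) {
    float dL_dW = 0.0f;
    // carry accumulates sum_{i>=t} beta^{i-t} * dL_dS[i] * dS_dU[i];
    // beta * carry corresponds, accumulated over steps, to the value the
    // Memorize node passes back to the previous time step in the graph.
    float carry = 0.0f;
    for (std::size_t t = x.size(); t-- > 0;) {
        carry = dL_dS[t] * dS_dU[t] + beta * carry;
        dL_dW += x[t] * carry;
    }
    return dL_dW;
}
```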