
Add partial support for Spiking Neural Networks backward in Aidge

Jerome Hue requested to merge jeromeh/aidge_core:skip-backward into dev

Context

Improve support for Spiking Neural Networks in Aidge.

This MR does multiple things:

  • Add objects that are necessary to perform the backward pass of such operators:
    • Context: a class used to save (before the forward pass) and restore (before the backward pass) the inputs of operators. This is needed, for instance, in the Heaviside operator, which uses a surrogate gradient defined as \frac{\partial \tilde{S}}{\partial U} = \frac{1}{\pi}\frac{1}{(1+[U\pi]^2)}, with U the input potential of the operator (a sketch is given after this list).
    • A way to distinguish Leaky operators that are network outputs from the others. This is necessary because an output neuron only outputs its spikes, whereas a non-output neuron may output both its spikes and its membrane potential.
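
To make the Heaviside case concrete, here is a minimal, self-contained sketch of what such a backward kernel could look like with this surrogate. The function and variable names are hypothetical and are not taken from the Aidge code base; the point is only that the saved input U is required at backward time.

```cpp
#include <cstddef>

// Hypothetical sketch of a Heaviside backward kernel using the arctan
// surrogate gradient. `input` holds the saved pre-activation U (restored
// by the Context before backward), `gradOutput` is dL/dS, and `gradInput`
// receives dL/dU = dL/dS * dS~/dU.
void heavisideBackwardSurrogate(const float* input,
                                const float* gradOutput,
                                float* gradInput,
                                std::size_t size)
{
    constexpr float pi = 3.14159265358979323846f;
    for (std::size_t i = 0; i < size; ++i) {
        const float u = input[i];
        // Surrogate gradient: 1 / (pi * (1 + (u * pi)^2))
        const float surrogate = 1.0f / (pi * (1.0f + (u * pi) * (u * pi)));
        gradInput[i] = gradOutput[i] * surrogate;
    }
}
```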

Testing cannot be done in the aidge_core module, so it is done in a separate module.

Modified files

  • include/aidge/graph/Context.hpp
  • include/aidge/operator/MetaOperatorDefs.hpp
  • include/aidge/scheduler/Scheduler.hpp
  • include/aidge/aidge.hpp
  • python_binding/data/pybind_Tensor.cpp
  • python_binding/operator/pybind_MetaOperatorDefs.cpp
  • python_binding/scheduler/pybind_Scheduler.cpp
  • src/backend/generic/operator/MemorizeImpl.cpp
  • src/backend/generic/operator/StackImpl.cpp
  • src/graph/Context.cpp
  • src/operator/MetaOperator/Leaky.cpp
  • src/scheduler/Scheduler.cpp
  • src/scheduler/SequentialScheduler.cpp

Detailed major modifications

Context

During the backward pass, we sometimes need the input value of an operator to compute its gradients. Since these values are not saved by default in Aidge, I implemented a class "Context" that holds a stack storing the successive inputs for each time step and restores them before each call to operator.backward().

To use this class, we have to create instances of it, which is handled by the scheduler.
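
Below is a minimal sketch of the save/restore idea, with hypothetical types and method names (the actual Context class in this MR may differ): the scheduler pushes the operator inputs during the forward pass, one entry per time step, and pops them back in reverse order before each operator.backward() call.

```cpp
#include <memory>
#include <stack>
#include <vector>

// Hypothetical sketch of a per-operator save/restore stack.
// `TensorPtr` stands in for a shared Tensor handle.
using TensorPtr = std::shared_ptr<std::vector<float>>;

class Context {
public:
    // Called before Operator::forward(): save the current inputs.
    void save(const std::vector<TensorPtr>& inputs) {
        mSaved.push(inputs);
    }

    // Called before Operator::backward(): restore the inputs of the
    // matching time step (last saved, first restored).
    std::vector<TensorPtr> restore() {
        std::vector<TensorPtr> inputs = mSaved.top();
        mSaved.pop();
        return inputs;
    }

    bool empty() const { return mSaved.empty(); }

private:
    // One entry per time step, pushed during the forward pass.
    std::stack<std::vector<TensorPtr>> mSaved;
};
```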

output argument

It is also necessary to determine when a Leaky node is an output (an equivalent output argument exists in snnTorch, for instance). To do so, an output argument is added to the "Leaky()" constructor, which in turn modifies how the Leaky node is created. For output Leaky nodes, we assume that we only want to look at the spike output and that we won't use the membrane potential output.
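
Purely as an illustration of this design choice, here is a self-contained sketch; the types, names, and defaults below are hypothetical and are not the actual Leaky factory from this MR.

```cpp
#include <memory>
#include <string>
#include <vector>

// Hypothetical sketch: a single `output` flag decides which outputs a
// Leaky node exposes to the rest of the graph.
struct LeakyNode {
    std::vector<std::string> exposedOutputs;
};

std::shared_ptr<LeakyNode> makeLeaky(bool output = false) {
    auto node = std::make_shared<LeakyNode>();
    if (output) {
        // Output neuron: only the spike train is visible; the membrane
        // potential stays internal to the node.
        node->exposedOutputs = {"spikes"};
    } else {
        // Hidden neuron: both spikes and membrane potential are exposed,
        // e.g. so a downstream operator can consume either one.
        node->exposedOutputs = {"spikes", "potential"};
    }
    return node;
}
```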

backward pass

The backpropagation-through-time algorithm is handled directly by the backward pass of the graph.

\frac{\partial \mathcal{L}}{\partial W} = \sum_{t=1}^T \Big(  X[t] \times \sum_{i=t}^T \big( \beta^{i-t} \frac{\partial \mathcal{L}[i]}{\partial S[i]} \frac{\partial \tilde{S}[i]}{\partial U[i]} \big) \Big)

The Memorize node is responsible for passing the value \beta \times \dfrac{\partial \mathcal{L}[i]}{\partial S[i]} \dfrac{\partial \tilde{S}[i]}{\partial U[i]} to the previous time step.
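
One way to read the sum above is to define

\delta[t] = \sum_{i=t}^T \beta^{i-t} \frac{\partial \mathcal{L}[i]}{\partial S[i]} \frac{\partial \tilde{S}[i]}{\partial U[i]}

which satisfies the recursion

\delta[t] = \frac{\partial \mathcal{L}[t]}{\partial S[t]} \frac{\partial \tilde{S}[t]}{\partial U[t]} + \beta \, \delta[t+1], \qquad \delta[T+1] = 0

so that \frac{\partial \mathcal{L}}{\partial W} = \sum_{t=1}^T X[t] \times \delta[t]. The full sum can therefore be built one time step at a time by carrying a \beta-scaled gradient term backwards, which is the role played by the Memorize node.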

Edited by Jerome Hue
