Efficient event-based delay learning

Myelination

The brain enables the organism to learn and adapt its behavior to an ever-changing environment. Most studies try to understand the underpinnings of this plasticity through synaptic weight changes. Myelin, however, is also plastic, and there is experimental evidence that, similarly to synaptic strength, it is modulated by neural activity. A great example of the benefit of having a diverse set of delays in a network is that sequence detection can be done in quite simple circuits.
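
To make that intuition concrete, here is a minimal toy sketch (plain Python, with all numbers invented for illustration): a downstream "coincidence detector" neuron only crosses threshold when its two inputs, after their respective conduction delays, arrive together, so it responds to one spike ordering but not the other.

```python
# Toy example: two input neurons project onto one coincidence-detector neuron.
# Input A fires at t = 0 ms, input B fires at t = 5 ms.
spike_times = {"A": 0.0, "B": 5.0}

# Conduction delays (ms), e.g. set by myelination: A's spike is delayed
# so that both spikes arrive at the detector at the same time.
delays = {"A": 5.0, "B": 0.0}

arrivals = {name: spike_times[name] + delays[name] for name in spike_times}

# The detector sums coincident EPSPs; a single EPSP is sub-threshold.
epsp = 0.6                # arbitrary units
threshold = 1.0
coincidence_window = 1.0  # ms

coincident = abs(arrivals["A"] - arrivals["B"]) <= coincidence_window
peak = 2 * epsp if coincident else epsp

print(f"arrival times: {arrivals}, peak potential: {peak}")
print("detector fires" if peak >= threshold else "detector stays silent")
# The A-then-B sequence is detected; the reverse order (B first, A 5 ms later)
# would make the delayed arrivals miss the coincidence window.
```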

How could we study the benefits of myelin plasticity in artificial neural networks? One option is to introduce optimisable delays in Spiking Neural Networks.

Delay learning

We are not the first to be interested in bringing such mechanisms into machine learning setups. There has been a recent upsurge in interest since Hammouamri et al. achieved SOTA results on spiking benchmarks by introducing temporal 1D dilated convolutions with learnable spacings (DCLS) as a delay learning algorithm.
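
For intuition only, here is a rough sketch of the general idea behind interpolation-based learnable delays (not the authors' DCLS implementation): each delay is a continuous parameter that positions a narrow Gaussian bump along the time axis of a causal 1D convolution kernel, so the delay can be shifted by gradient descent.

```python
import numpy as np

def delay_kernel(delay, kernel_size, sigma=0.5):
    """Gaussian bump centred at a real-valued delay, so that the kernel
    (and hence the effective delay) is differentiable w.r.t. `delay`."""
    t = np.arange(kernel_size)
    k = np.exp(-0.5 * ((t - delay) / sigma) ** 2)
    return k / k.sum()

# One presynaptic spike train (1 = spike) over 20 timesteps.
x = np.zeros(20)
x[3] = 1.0

# A "learnable" delay of 6.2 timesteps, represented as a kernel position.
kernel = delay_kernel(delay=6.2, kernel_size=15)

# Causal temporal convolution: the spike mass ends up around t = 3 + 6.2.
y = np.convolve(x, kernel)[: len(x)]
print(np.round(y, 2))
```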

A major drawback of their method is its limited amenability to neuromorphic hardware. While discretising the whole network is not necessarily an issue for inference, during training every timestep needs to be stored in memory, which limits network size as well as temporal precision and/or sequence length. This issue is general to backpropagation through time (BPTT).
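
To get a feel for why this matters, here is a back-of-the-envelope estimate with entirely illustrative numbers: BPTT activation memory grows with both network size and sequence length, whereas an event-based scheme only needs to keep the (sparse) spike times.

```python
# Illustrative activation-memory estimate for BPTT (all numbers made up).
batch_size    = 256
num_neurons   = 1024   # hidden units
num_timesteps = 1000   # e.g. 1 s of input at 1 ms resolution
state_vars    = 2      # e.g. membrane potential + synaptic current
bytes_per_val = 4      # float32

bptt_bytes = batch_size * num_neurons * num_timesteps * state_vars * bytes_per_val
print(f"BPTT activation memory: {bptt_bytes / 1e9:.1f} GB")   # ~2.1 GB

# An event-based backward pass only needs the recorded spike times,
# which for sparse activity is orders of magnitude smaller.
spikes_per_neuron = 20
event_bytes = batch_size * num_neurons * spikes_per_neuron * bytes_per_val
print(f"Spike-time storage: {event_bytes / 1e6:.0f} MB")       # ~21 MB
```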

Wunderlich et al. proposed to stay in continuous time before calculating gradients, and Nowotny et al. showed how this could be scaled up to more complex tasks. Recently, it has even been implemented on SpiNNaker 2 hardware. This is thanks to the backward pass having essentially the same computational and memory requirements as the forward pass. It seems like the perfect framework to introduce delays into!
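
Roughly, and glossing over the exact jump conditions and signs given in Wunderlich et al.'s EventProp paper, the forward membrane dynamics and the backward adjoint dynamics are both simple ODEs, and weight gradients are read out only at spike times:

```latex
% Forward LIF dynamics between spikes (membrane potential V, synaptic current I):
\tau_\text{mem}\,\dot V = -V + I, \qquad \tau_\text{syn}\,\dot I = -I

% Adjoint dynamics, integrated backwards in time,
% with jumps applied only at the recorded spike times:
\tau_\text{mem}\,\dot\lambda_V = -\lambda_V, \qquad
\tau_\text{syn}\,\dot\lambda_I = -\lambda_I + \lambda_V

% Weight gradients are accumulated only at the presynaptic spike times t_k:
\frac{\partial \mathcal{L}}{\partial w_{ji}}
  = -\tau_\text{syn} \sum_{t_k \,\in\, \text{spikes of } i} \lambda_{I,j}(t_k)
```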

Results

After extensive calculations, we arrive at a similarly efficient delay learning update. Relying on the same ordinary differential equations, we can compute both synaptic weight and delay gradients by sampling the correct terms only at spike times. The result is a delay learning algorithm that, for the first time, also allows for recurrent connections!
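
As a purely structural illustration (not our actual update rules: the adjoint terms below are placeholders), such an event-based rule walks backwards over the recorded spikes and lets each spike contribute one term to the weight gradient and one to the delay gradient of the synapse it travelled through, so no per-timestep state ever needs to be stored.

```python
from collections import defaultdict

def accumulate_gradients(spikes, adjoint):
    """Illustrative event-based gradient accumulation.

    spikes  : list of (time, pre_neuron, post_neuron) tuples recorded
              during the forward pass.
    adjoint : callable returning hypothetical adjoint terms of the
              postsynaptic neuron, evaluated at the spike arrival time.
    """
    grad_w = defaultdict(float)
    grad_d = defaultdict(float)
    for t, pre, post in sorted(spikes, reverse=True):  # backwards in time
        lam_I, lam_V_dot = adjoint(post, t)
        grad_w[(pre, post)] += -lam_I      # weight gradient term (placeholder)
        grad_d[(pre, post)] += -lam_V_dot  # delay gradient term (placeholder)
    return grad_w, grad_d

# Dummy usage with fabricated numbers:
spikes = [(1.0, 0, 2), (3.5, 1, 2)]
grad_w, grad_d = accumulate_gradients(spikes, lambda post, t: (0.1 * t, -0.05 * t))
print(dict(grad_w), dict(grad_d))
```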

We test our implementations on the Yin-Yang, Spiking Heidelberg Digits, Spiking Speech Commands, and Braille letter reading datasets. We find that delays are always a useful addition and, interestingly, they become particularly useful in small recurrent networks.

With the right implementation, this method becomes significantly more efficient than discretisation-based (i.e. BPTT) methods. With fixed network sizes and maximum delay timesteps, we outperform the DCLS-based method both in terms of memory requirements and speed.

We also tested our method on Loihi 2 for inference. For this, we had to quantise our parameters and limit our maximum delay timesteps. While this led to a slight performance decrease, in terms of energy usage Loihi 2 significantly outperforms GPU-based implementations, even with delays.
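
For context, here is a minimal sketch of the kind of fixed-point quantisation such a deployment typically requires; the bit-widths, scales, and delay cap below are invented for illustration, not taken from our Loihi 2 configuration.

```python
import numpy as np

def quantise(values, n_bits, max_abs):
    """Uniform symmetric quantisation to signed integers of width n_bits."""
    qmax = 2 ** (n_bits - 1) - 1
    scale = max_abs / qmax
    q = np.clip(np.round(values / scale), -qmax - 1, qmax).astype(np.int32)
    return q, scale

weights = np.random.randn(256, 256).astype(np.float32)
delays  = np.random.uniform(0, 25.0, size=(256, 256))  # ms, continuous in training

# Hypothetical hardware constraints: 8-bit weights, integer delays
# capped at a maximum number of timesteps.
q_weights, w_scale = quantise(weights, n_bits=8, max_abs=np.abs(weights).max())
max_delay_steps = 16
q_delays = np.clip(np.round(delays), 0, max_delay_steps).astype(np.int32)

print(q_weights.dtype, q_weights.min(), q_weights.max())
print(q_delays.max(), "timesteps max delay after clipping")
```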