Spiking Neural Networks
Train biologically inspired neurons that communicate via discrete spikes, using surrogate-gradient backpropagation on the GPU.
What Are SNNs?
Unlike standard neural networks that pass continuous activations, spiking neural networks process information through discrete binary events (spikes) over time. Each neuron maintains a membrane potential that accumulates input charge. When the potential crosses a threshold, the neuron emits a spike and resets.
This makes SNNs inherently temporal: input data has an extra time dimension, with shape (T, N, features), where T is the number of timesteps and N is the batch size.
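As a minimal illustration in plain NumPy (no Grilly calls), there are two ways to turn a static batch of shape (N, features) into a (T, N, features) sequence. The first repeats the batch at every timestep, which is often called direct encoding. The second samples Bernoulli spikes whose probability equals each feature's intensity, which is the same rate encoding the training example later on this page uses. The helper names `direct_encode` and `rate_encode` are hypothetical.

```python
import numpy as np

def direct_encode(x: np.ndarray, T: int) -> np.ndarray:
    """Repeat a static (N, features) batch at every timestep -> (T, N, features)."""
    return np.broadcast_to(x, (T,) + x.shape).copy()

def rate_encode(x: np.ndarray, T: int, rng: np.random.Generator) -> np.ndarray:
    """Bernoulli spikes with P(spike) = intensity in [0, 1] -> (T, N, features)."""
    return (rng.random((T,) + x.shape) < x).astype(np.float32)

rng = np.random.default_rng(0)
images = rng.random((4, 784)).astype(np.float32)  # stand-in batch: N=4, 784 features
x_direct = direct_encode(images, T=16)
x_spikes = rate_encode(images, T=16, rng=rng)
print(x_direct.shape, x_spikes.shape)  # (16, 4, 784) (16, 4, 784)
```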
Neuron Models
Grilly provides three neuron types, each implementing a different membrane dynamics model:
Integrate-and-Fire (IF)
```python
from grilly.nn import IFNode

# Simplest neuron: no leak, just accumulates
# step_mode='s' = single-step (one timestep per call)
neuron = IFNode(v_threshold=1.0, v_reset=0.0, step_mode='s')
```
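The IF update can be sketched in NumPy as follows. This is not Grilly's implementation. It assumes the usual charge, fire, hard-reset order that the LIF equations on this page also follow, and `if_step` is a hypothetical helper name. With a constant input of 0.25 and a threshold of 1.0, the neuron needs four steps to reach threshold, so it fires every fourth step.

```python
import numpy as np

def if_step(x, v, v_threshold=1.0, v_reset=0.0):
    """One Integrate-and-Fire step: charge, fire, hard reset (no leak)."""
    h = v + x                                      # charge: add input, nothing decays
    spike = (h >= v_threshold).astype(np.float32)  # fire where threshold is reached
    v_next = np.where(spike == 1.0, v_reset, h)    # reset neurons that fired
    return spike, v_next

v = np.zeros(1, dtype=np.float32)
spikes = []
for t in range(8):
    s, v = if_step(np.full(1, 0.25, dtype=np.float32), v)
    spikes.append(int(s[0]))
print(spikes)  # [0, 0, 0, 1, 0, 0, 0, 1]
```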
Leaky Integrate-and-Fire (LIF)
The most commonly used spiking neuron. Membrane potential leaks toward rest between timesteps:
```python
from grilly.nn import LIFNode

neuron = LIFNode(
    tau=2.0,            # membrane time constant (higher = slower leak)
    decay_input=False,  # recommended for deep SNNs
    v_threshold=1.0,
    v_reset=0.0,        # hard reset to 0.0; None = soft reset (subtract threshold)
    step_mode='m',      # 'm' = multi-step (processes all T at once)
)
```
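Here H[t] is the membrane potential after charging at timestep t (before the threshold check), S[t] is the output spike, and V[t] is the potential carried into the next timestep:

$$
S[t] = \begin{cases} 1, & H[t] \ge V_{\mathrm{th}} \\ 0, & \text{otherwise} \end{cases}
\qquad
V[t] = \begin{cases} V_{\mathrm{reset}}, & S[t] = 1 \\ H[t], & \text{otherwise} \end{cases}
$$

With v_reset=None (soft reset), the threshold is subtracted instead:

$$
V[t] = H[t] - V_{\mathrm{th}}\, S[t]
$$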
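The equations above leave implicit the charge step that produces H[t]. Grilly's parameter names (tau, decay_input, v_reset) mirror SpikingJelly. Assuming Grilly also follows the SpikingJelly charge rule, which this page does not confirm, the two decay_input settings would give:

$$
H[t] = V[t-1] - \frac{1}{\tau}\bigl(V[t-1] - V_{\mathrm{reset}}\bigr) + X[t] \qquad (\texttt{decay\_input=False})
$$

$$
H[t] = V[t-1] + \frac{1}{\tau}\bigl(X[t] - (V[t-1] - V_{\mathrm{reset}})\bigr) \qquad (\texttt{decay\_input=True})
$$

As a worked example, take decay_input=False, τ = 2, V_reset = 0, V[t-1] = 0.8 and X[t] = 0.5. Then H[t] = 0.8 - 0.4 + 0.5 = 0.9. This is below V_th = 1.0, so S[t] = 0 and V[t] = 0.9.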
Parametric LIF
```python
from grilly.nn import ParametricLIFNode

# tau is a learnable parameter (included in model.parameters())
neuron = ParametricLIFNode(init_tau=2.0, step_mode='m')
```
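SpikingJelly's PLIF neuron makes τ learnable while keeping it above 1. It learns an unconstrained scalar w, sets 1/τ = sigmoid(w), and initializes w = -ln(init_tau - 1). Whether ParametricLIFNode uses exactly this parameterization is an assumption. The NumPy sketch below only shows the mapping, and `init_w` and `tau_from_w` are hypothetical helper names.

```python
import numpy as np

def init_w(init_tau: float) -> float:
    """Unconstrained parameter w such that 1 / sigmoid(w) == init_tau (needs init_tau > 1)."""
    return -np.log(init_tau - 1.0)

def tau_from_w(w: float) -> float:
    """Map the learnable scalar w to tau = 1 / sigmoid(w) = 1 + exp(-w), always > 1."""
    return 1.0 + np.exp(-w)

w = init_w(2.0)
print(tau_from_w(w))  # 2.0

# Any real w gives a valid time constant, so gradient steps on w cannot push tau below 1
for w in (-5.0, 0.0, 5.0):
    print(w, tau_from_w(w))
```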
| Neuron | Leak | Learnable τ | Best for |
|---|---|---|---|
| IFNode | None | No | Simple, fast prototyping |
| LIFNode | Exponential | No | Standard SNN training |
| ParametricLIFNode | Exponential | Yes | Task-adaptive time constants |
Surrogate Gradients
The spike function (Heaviside step) has zero gradient almost everywhere, which blocks standard backpropagation. Surrogate gradient functions replace the true gradient with a smooth approximation:
```python
from grilly.nn import LIFNode, ATan, Sigmoid, FastSigmoid

# ATan: default and recommended
lif = LIFNode(tau=2.0, surrogate_function=ATan(alpha=2.0))

# Sigmoid: smoother gradient, slower convergence
lif_sig = LIFNode(tau=2.0, surrogate_function=Sigmoid(alpha=4.0))

# FastSigmoid: cheapest to compute
lif_fast = LIFNode(tau=2.0, surrogate_function=FastSigmoid(alpha=2.0))
```
The alpha parameter controls the sharpness of the surrogate gradient. A higher alpha brings the surrogate closer to the true step function and concentrates the gradient near the threshold, but makes training harder. Start with ATan(alpha=2.0) and tune from there.
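To see what alpha does, the NumPy sketch below evaluates the forward Heaviside step and two surrogate derivatives at a few distances from the threshold. The formulas are the ATan and Sigmoid surrogates as defined in SpikingJelly:

- ATan: g'(x) = (α/2) / (1 + (π α x / 2)²)
- Sigmoid: g'(x) = α σ(αx)(1 - σ(αx))

That Grilly's ATan and Sigmoid classes use the same scaling is an assumption.

```python
import numpy as np

def heaviside(x):
    """Forward pass of the spike function: 1 where x >= 0 (x = H - V_th), else 0."""
    return (x >= 0).astype(np.float32)

def atan_grad(x, alpha=2.0):
    """Derivative of the ATan surrogate g(x) = arctan(pi/2 * alpha * x) / pi + 1/2."""
    return (alpha / 2) / (1 + (np.pi / 2 * alpha * x) ** 2)

def sigmoid_grad(x, alpha=4.0):
    """Derivative of the Sigmoid surrogate g(x) = sigmoid(alpha * x)."""
    s = 1.0 / (1.0 + np.exp(-alpha * x))
    return alpha * s * (1.0 - s)

x = np.array([-1.0, -0.1, 0.0, 0.1, 1.0])  # distance from threshold
print(heaviside(x))             # [0. 0. 1. 1. 1.]
print(atan_grad(x, alpha=2.0))  # peak of alpha/2 = 1.0 at x = 0
print(atan_grad(x, alpha=8.0))  # taller peak (4.0), but smaller at x = +/-1
print(sigmoid_grad(x, alpha=4.0))
```

The backward pass uses the surrogate derivative in place of the step function's zero gradient. That is why a very large alpha starves neurons far from threshold of any learning signal.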
SNN Containers
Standard ANN layers (Linear, Conv2d, BatchNorm) expect inputs without a time dimension. Two container classes bridge the gap:
SeqToANNContainer
Reshapes (T, N, ...) → (T*N, ...), runs the ANN layer, then reshapes back. This lets you use any standard layer inside an SNN:
```python
from grilly.nn import SeqToANNContainer, LIFNode
import grilly.nn as nn

# Linear layer that handles temporal input
snn = nn.Sequential(
    SeqToANNContainer(nn.Linear(784, 256)),  # (T,N,784) -> (T,N,256)
    LIFNode(tau=2.0, step_mode='m'),         # spiking layer
    SeqToANNContainer(nn.Linear(256, 10)),   # (T,N,256) -> (T,N,10)
    LIFNode(tau=2.0, step_mode='m'),         # output spikes
)
```
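The reshape trick behind SeqToANNContainer can be sketched in NumPy, with a plain function standing in for the ANN layer. `seq_to_ann` is a hypothetical helper, not Grilly's implementation. The check shows that folding T into the batch axis gives the same result as applying the layer at each timestep. This holds for any layer that treats batch samples independently. A BatchNorm layer wrapped this way would instead compute its statistics over all T·N samples.

```python
import numpy as np

def seq_to_ann(layer, x_seq):
    """Apply a time-agnostic layer to (T, N, ...) input by folding T into the batch axis."""
    T, N = x_seq.shape[:2]
    y_flat = layer(x_seq.reshape(T * N, *x_seq.shape[2:]))  # (T*N, ...) -> (T*N, ...)
    return y_flat.reshape(T, N, *y_flat.shape[1:])          # unfold back to (T, N, ...)

rng = np.random.default_rng(0)
W = rng.standard_normal((784, 256))
linear = lambda x: x @ W  # stand-in for nn.Linear(784, 256) without bias

x_seq = rng.random((8, 32, 784))  # (T, N, 784)
y_seq = seq_to_ann(linear, x_seq)
print(y_seq.shape)  # (8, 32, 256)

# Same result as looping over timesteps one by one
y_loop = np.stack([linear(x_seq[t]) for t in range(x_seq.shape[0])])
print(np.allclose(y_seq, y_loop))  # True
```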
MultiStepContainer
Wraps a single-step module to loop over timesteps. Useful for modules that don't natively support multi-step mode:
```python
import grilly.nn as nn
from grilly.nn import MultiStepContainer

# Pool each timestep independently
temporal_pool = MultiStepContainer(nn.MaxPool2d(2, 2))
# Input: (T, N, C, H, W) -> Output: (T, N, C, H/2, W/2)
```
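Conceptually, MultiStepContainer runs the wrapped module once per timestep and stacks the results. The NumPy sketch below shows that loop, with a hand-written 2×2 max pool standing in for nn.MaxPool2d(2, 2). Both `multi_step` and `max_pool_2x2` are hypothetical helpers. Grilly may implement the container differently, for example by reshaping as SeqToANNContainer does.

```python
import numpy as np

def max_pool_2x2(x):
    """2x2 max pool with stride 2 on (N, C, H, W); assumes H and W are even."""
    N, C, H, W = x.shape
    return x.reshape(N, C, H // 2, 2, W // 2, 2).max(axis=(3, 5))

def multi_step(module, x_seq):
    """Call a single-step module on each x_seq[t] and stack along a new time axis."""
    return np.stack([module(x_seq[t]) for t in range(x_seq.shape[0])], axis=0)

rng = np.random.default_rng(0)
x_seq = rng.random((4, 2, 3, 8, 8))  # (T, N, C, H, W)
y_seq = multi_step(max_pool_2x2, x_seq)
print(y_seq.shape)  # (4, 2, 3, 4, 4)
```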
Essential Functions
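Two utility functions manage neuron state and step mode. Call F.set_step_mode once after building the model, and call F.reset_net before every independent sequence or batch: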
```python
import grilly.functional as F

# Reset all membrane potentials to zero
# MUST be called between independent sequences/batches
F.reset_net(snn)

# Set all neuron nodes to multi-step mode
F.set_step_mode(snn, 'm')
```
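Spiking neurons are stateful: their membrane potentials persist across forward calls. Always call F.reset_net(model) to clear the state before processing a new, unrelated batch. Forgetting this causes information to leak between unrelated samples.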
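The leak is easy to reproduce with a toy stateful neuron. The NumPy sketch below uses a hypothetical `ToyLIF` class (leaky charge with the input added undecayed, hard reset), whose `reset()` plays the role of F.reset_net. It illustrates why resetting matters and is not Grilly's neuron code. A weak input fires the neuron only when residual potential from the previous batch is still present.

```python
import numpy as np

class ToyLIF:
    """Minimal stateful LIF layer: the membrane potential v persists between calls."""

    def __init__(self, tau=2.0, v_threshold=1.0, v_reset=0.0):
        self.tau, self.v_threshold, self.v_reset = tau, v_threshold, v_reset
        self.v = None

    def reset(self):
        self.v = None  # analogous to F.reset_net: forget all membrane state

    def __call__(self, x_seq):  # x_seq: (T, N)
        if self.v is None:
            self.v = np.full(x_seq.shape[1:], self.v_reset)
        spikes = []
        for x in x_seq:
            h = self.v - (self.v - self.v_reset) / self.tau + x  # leak, then charge
            s = (h >= self.v_threshold).astype(float)
            self.v = np.where(s == 1.0, self.v_reset, h)          # hard reset
            spikes.append(s)
        return np.stack(spikes)

neuron = ToyLIF()
batch_a = np.full((3, 1), 0.9)  # strong input leaves residual potential (0.9) behind
batch_b = np.full((1, 1), 0.6)  # weak input: fires only if state leaked from batch_a

neuron(batch_a)
leaked = neuron(batch_b)  # no reset in between
neuron.reset()
clean = neuron(batch_b)
print(leaked.ravel(), clean.ravel())  # [1.] [0.]
```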
Complete SNN Training Example
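Train an SNN classifier using Poisson-encoded spike trains as input and rate coding for the output. The example uses random inputs and labels as stand-ins for a real dataset of flattened 28x28 images: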
```python
import numpy as np
import grilly.nn as nn
import grilly.functional as F
import grilly.optim as optim
from grilly.nn import LIFNode, SeqToANNContainer

T = 16   # timesteps per sample
N = 32   # batch size
D = 784  # input features (28x28 flattened)

# Build SNN
snn = nn.Sequential(
    SeqToANNContainer(nn.Linear(D, 256)),
    LIFNode(tau=2.0, step_mode='m'),
    SeqToANNContainer(nn.Linear(256, 10)),
    LIFNode(tau=2.0, step_mode='m'),
)
F.set_step_mode(snn, 'm')

optimizer = optim.AdamW(snn.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(10):
    # Encode input as Poisson spike trains: (T, N, D)
    # Each pixel's firing rate is proportional to its intensity
    rates = np.random.rand(N, D).astype(np.float32)
    x_seq = (np.random.rand(T, N, D) < rates).astype(np.float32)

    # Reset membrane state between batches
    F.reset_net(snn)

    # Forward: get output spike trains
    spikes_out = snn(x_seq)  # (T, N, 10)

    # Rate coding: average spikes over time -> class logits
    firing_rate = spikes_out.mean(axis=0)  # (N, 10)

    # Loss and its gradient w.r.t. the firing rates (same as ANN)
    labels = np.random.randint(0, 10, N).astype(np.int64)
    loss = loss_fn(firing_rate, labels)
    grad = loss_fn.backward(np.ones_like(loss), firing_rate, labels)  # (N, 10)

    # Backprop through the time average: each timestep receives grad / T,
    # so the gradient matches the network's (T, N, 10) output shape
    grad_seq = np.broadcast_to(grad / T, spikes_out.shape).astype(np.float32)

    snn.zero_grad()
    snn.backward(grad_seq)
    optimizer.step()

    print(f"Epoch {epoch+1:2d}: loss={float(np.mean(loss)):.4f} avg_fire_rate={firing_rate.mean():.3f}")
```
Sample output (values vary between runs because the inputs and labels are random):

```text
Epoch  1: loss=2.3048 avg_fire_rate=0.312
Epoch  2: loss=2.2831 avg_fire_rate=0.287
Epoch  3: loss=2.2592 avg_fire_rate=0.263
Epoch  4: loss=2.2314 avg_fire_rate=0.241
Epoch  5: loss=2.1987 avg_fire_rate=0.224
Epoch  6: loss=2.1601 avg_fire_rate=0.209
Epoch  7: loss=2.1152 avg_fire_rate=0.198
Epoch  8: loss=2.0639 avg_fire_rate=0.189
Epoch  9: loss=2.0061 avg_fire_rate=0.182
Epoch 10: loss=1.9420 avg_fire_rate=0.176
```
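The two coding schemes in this example can be checked in isolation with NumPy. The sketch below uses hypothetical helper names and confirms three things:

- Bernoulli sampling per timestep, the usual approximation to a Poisson spike train and the scheme used in the training loop, produces empirical firing rates close to the pixel intensities as T grows.
- Rate decoding is a plain mean over the time axis.
- The gradient of that mean hands each timestep 1/T of the incoming gradient.

```python
import numpy as np

def poisson_encode(intensity, T, rng):
    """Bernoulli spikes per timestep with P(spike) = intensity -> shape (T, *intensity.shape)."""
    return (rng.random((T,) + intensity.shape) < intensity).astype(np.float32)

def rate_decode(spikes):
    """Average spikes over the time axis: (T, N, C) -> (N, C)."""
    return spikes.mean(axis=0)

def rate_decode_backward(grad_rate, T):
    """Gradient of the time average: each timestep receives grad_rate / T."""
    return np.broadcast_to(grad_rate / T, (T,) + grad_rate.shape)

rng = np.random.default_rng(0)
intensity = np.array([[0.1, 0.5, 0.9]])  # one sample, three "pixels"
spikes = poisson_encode(intensity, T=5000, rng=rng)
print(rate_decode(spikes))  # close to [[0.1, 0.5, 0.9]]

grad_seq = rate_decode_backward(np.ones((1, 3)), T=16)
print(grad_seq.shape, grad_seq[0, 0, 0])  # (16, 1, 3) 0.0625
```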
Functional SNN API
For lower-level control, use the functional spiking neuron step functions directly:
```python
import grilly.functional as F
import numpy as np

# Manual LIF step-by-step
v = np.zeros(128, dtype=np.float32)  # membrane potential
for t in range(10):
    x = np.random.randn(128).astype(np.float32)
    spike, v = F.lif_step(
        x, v,
        tau=2.0,
        v_threshold=1.0,
        v_reset=0.0,
    )
    print(f"t={t}: {int(spike.sum())} spikes")
```
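A NumPy reference is handy for testing, or for understanding what a step function like F.lif_step computes. The version below is a sketch that assumes the SpikingJelly charge rules for both decay_input settings, followed by the threshold and reset equations from the LIF section. `lif_step_ref` is a hypothetical name, and exact agreement with F.lif_step is not guaranteed.

```python
import numpy as np

def lif_step_ref(x, v, tau=2.0, v_threshold=1.0, v_reset=0.0, decay_input=False):
    """One LIF timestep: charge, fire, reset. v_reset=None selects soft reset."""
    v_rest = 0.0 if v_reset is None else v_reset  # potential the leak decays toward
    if decay_input:
        h = v + (x - (v - v_rest)) / tau          # input is scaled by 1/tau too
    else:
        h = v - (v - v_rest) / tau + x            # input added without decay
    spike = (h >= v_threshold).astype(np.float32)
    if v_reset is None:
        v_next = h - spike * v_threshold          # soft reset: subtract threshold
    else:
        v_next = np.where(spike == 1.0, v_reset, h)  # hard reset to v_reset
    return spike, v_next

# Worked check: tau=2, v=0.8, x=0.5 -> h = 0.8 - 0.4 + 0.5 = 0.9 (no spike)
s, v = lif_step_ref(np.array([0.5]), np.array([0.8]))
print(s, v)  # [0.] [0.9]

# Soft reset keeps the overshoot: v=0.8, x=1.0 -> h = 1.4, spike, v = 1.4 - 1.0 = 0.4
s, v = lif_step_ref(np.array([1.0]), np.array([0.8]), v_reset=None)
```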
ANN-to-SNN Conversion
Convert a pre-trained ANN to an SNN by replacing activations with spiking neurons:
```python
import grilly.nn as nn
from grilly.nn import Converter, VoltageScaler

# Train a standard ANN first...
ann = nn.Sequential(
    nn.Linear(784, 256), nn.ReLU(),
    nn.Linear(256, 10),
)

# Convert ReLU activations to LIF neurons
converter = Converter(ann, mode='max')
snn_model = converter.convert()

# The converted SNN uses rate coding to approximate ANN outputs
```
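A single neuron shows why conversion works. An IF neuron with soft reset, driven by a constant input a, fires at a rate close to min(max(a, 0), 1), which is a clipped ReLU. Scaling activations by their maximum over calibration data keeps them inside the range the neuron can represent. That this is what mode='max' does is an assumption based on its name. The NumPy sketch below uses hypothetical helpers and is not Grilly's Converter.

```python
import numpy as np

def if_rate(a, T=100, v_threshold=1.0):
    """Firing rate of soft-reset IF neurons driven by constant inputs a over T steps."""
    v = np.zeros_like(a, dtype=float)
    count = np.zeros_like(a, dtype=float)
    for _ in range(T):
        v = v + a                             # integrate (no leak)
        s = (v >= v_threshold).astype(float)
        v = v - s * v_threshold               # soft reset keeps the remainder
        count += s
    return count / T

a = np.array([-0.5, 0.25, 0.5, 0.75, 1.2])
print(if_rate(a))  # rates 0, 0.25, 0.5, 0.75, 1.0: negative input never fires, input > 1 saturates

# Max normalization: divide by the largest activation so rates stay in [0, 1],
# then multiply the rates back by the same factor to approximate the ReLU outputs
acts = np.maximum(np.array([0.3, 1.5, 2.4, 0.0]), 0.0)  # ReLU activations
lam = acts.max()                                         # 2.4
approx = if_rate(acts / lam, T=200) * lam
print(np.round(approx, 2))  # close to acts; the error shrinks as T grows
```

The approximation error is roughly lam / T. This is the usual trade-off in converted SNNs: more timesteps buy accuracy at the cost of latency.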