GPU Neural Networks via Vulkan

Spiking Neural Networks

Train biologically inspired neurons that communicate via discrete spikes, with surrogate-gradient backpropagation on the GPU.

What Are SNNs?

Unlike standard neural networks that pass continuous activations, spiking neural networks process information through discrete binary events (spikes) over time. Each neuron maintains a membrane potential that accumulates input charge. When the potential crosses a threshold, the neuron emits a spike and resets.
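
This threshold-and-reset loop is easy to see in plain NumPy; a minimal sketch for a single neuron, with illustrative constants and independent of any framework:

python
import numpy as np

v = 0.0                        # membrane potential
v_threshold, v_reset = 1.0, 0.0

inputs = np.array([0.3, 0.4, 0.5, 0.2, 0.9], dtype=np.float32)
for t, x in enumerate(inputs):
    v += x                     # accumulate input charge
    spike = v >= v_threshold   # emit a spike when the threshold is crossed
    if spike:
        v = v_reset            # reset after spiking
    print(f"t={t}: v={v:.2f} spike={int(spike)}")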

This makes SNNs inherently temporal — input data has an extra time dimension (T, N, features) where T is the number of timesteps.

Why SNNs? SNNs are more biologically plausible than standard ANNs, naturally handle temporal and event-based data, and run efficiently on neuromorphic hardware (Intel Loihi, IBM TrueNorth). Among GPU-accelerated libraries, a built-in SNN training stack like Grilly's is rare.

Neuron Models

Grilly provides three neuron types, each implementing a different membrane dynamics model:

Integrate-and-Fire (IF)

python
from grilly.nn import IFNode

# Simplest neuron: no leak, just accumulates
neuron = IFNode(v_threshold=1.0, v_reset=0.0, step_mode='s')
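
With no leak, the implied membrane update is pure accumulation; the spike and reset rules match the LIF dynamics shown below:

H[t] = V[t-1] + X[t]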

Leaky Integrate-and-Fire (LIF)

The most commonly used spiking neuron. Membrane potential leaks toward rest between timesteps:

python
from grilly.nn import LIFNode

neuron = LIFNode(
    tau=2.0,            # membrane time constant (higher = slower leak)
    decay_input=False,  # recommended for deep SNNs
    v_threshold=1.0,
    v_reset=0.0,        # None = soft reset (subtract threshold)
    step_mode='m',      # 'm' = multi-step (processes all T at once)
)

LIF dynamics (decay_input=False):

H[t] = V[t-1] × (1 − 1/τ) + X[t]
S[t] = 1  if  H[t] ≥ V_th,  else  0
V[t] = V_reset  if  S[t] = 1,  else  H[t]
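
These three updates are easy to check by hand; a minimal NumPy transcription, with illustrative constants and independent of Grilly:

python
import numpy as np

tau, v_th, v_reset = 2.0, 1.0, 0.0
V = 0.0
X = np.array([0.6, 0.8, 0.3, 0.7], dtype=np.float32)

for t, x in enumerate(X):
    H = V * (1.0 - 1.0 / tau) + x    # leak, then integrate input
    S = 1 if H >= v_th else 0        # spike when threshold is crossed
    V = v_reset if S else H          # hard reset on spike
    print(f"t={t}: H={H:.2f} S={S} V={V:.2f}")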

Parametric LIF

python
from grilly.nn import ParametricLIFNode

# tau is a learnable parameter (included in model.parameters())
neuron = ParametricLIFNode(init_tau=2.0, step_mode='m')
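
Because tau is registered as a parameter, it is trained alongside the weights. A quick check, reusing the optimizer setup from the training example later in this guide:

python
import grilly.optim as optim

# tau is updated by the optimizer together with any weights
optimizer = optim.AdamW(neuron.parameters(), lr=1e-3)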

Neuron              Leak         Learnable τ   Best For
IFNode              None         No            Simple, fast prototyping
LIFNode             Exponential  No            Standard SNN training
ParametricLIFNode   Exponential  Yes           Task-adaptive time constants

Surrogate Gradients

The spike function (Heaviside step) has zero gradient almost everywhere, which blocks standard backpropagation. Surrogate gradient functions replace the true gradient with a smooth approximation:

python
from grilly.nn import LIFNode, ATan, Sigmoid, FastSigmoid

# ATan: default and recommended
lif = LIFNode(tau=2.0, surrogate_function=ATan(alpha=2.0))

# Sigmoid: smoother gradient, slower convergence
lif_sig = LIFNode(tau=2.0, surrogate_function=Sigmoid(alpha=4.0))

# FastSigmoid: cheapest to compute
lif_fast = LIFNode(tau=2.0, surrogate_function=FastSigmoid(alpha=2.0))
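
For intuition, ATan-style surrogates give the backward pass a smooth bump centered at the threshold. A hedged NumPy sketch of the common definition from the SNN literature (not Grilly's internals):

python
import numpy as np

def atan_surrogate_grad(x, alpha=2.0):
    # Common ATan surrogate gradient, evaluated at
    # x = membrane potential minus threshold; peaks at x = 0.
    return alpha / (2.0 * (1.0 + (np.pi / 2.0 * alpha * x) ** 2))

print(atan_surrogate_grad(np.array([-1.0, 0.0, 1.0])))  # maximum at 0.0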

Tip: The alpha parameter controls the sharpness of the surrogate gradient: higher alpha is closer to the true step function but harder to train. Start with ATan(alpha=2.0) and tune from there.

SNN Containers

Standard ANN layers (Linear, Conv2d, BatchNorm) don't know about the time dimension. Two container classes bridge the gap:

SeqToANNContainer

Reshapes (T, N, ...) → (T*N, ...), runs the ANN layer, then reshapes back. This lets you use any standard layer inside an SNN:

python
from grilly.nn import SeqToANNContainer, LIFNode
import grilly.nn as nn

# Linear layer that handles temporal input
snn = nn.Sequential(
    SeqToANNContainer(nn.Linear(784, 256)),  # (T,N,784) -> (T,N,256)
    LIFNode(tau=2.0, step_mode='m'),         # spiking layer
    SeqToANNContainer(nn.Linear(256, 10)),   # (T,N,256) -> (T,N,10)
    LIFNode(tau=2.0, step_mode='m'),         # output spikes
)
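
Under the hood this is just a reshape round-trip; a conceptual sketch, not Grilly's actual implementation:

python
import numpy as np

def seq_to_ann_apply(layer, x_seq):
    # Fold time into the batch, apply the stateless layer, unfold again
    T, N = x_seq.shape[0], x_seq.shape[1]
    out = layer(x_seq.reshape(T * N, *x_seq.shape[2:]))
    return out.reshape(T, N, *out.shape[1:])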

MultiStepContainer

Wraps a single-step module to loop over timesteps. Useful for modules that don't natively support multi-step mode:

python
import grilly.nn as nn
from grilly.nn import MultiStepContainer

# Pool each timestep independently
temporal_pool = MultiStepContainer(nn.MaxPool2d(2, 2))
# Input: (T, N, C, H, W) -> Output: (T, N, C, H/2, W/2)
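
Equivalently, the wrapper loops over the leading time axis; a sketch of the idea:

python
import numpy as np

def multi_step_apply(module, x_seq):
    # Apply a single-step module to each timestep independently
    return np.stack([module(x_seq[t]) for t in range(x_seq.shape[0])], axis=0)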

Essential Functions

Two functions you must call in every SNN training loop:

python
import grilly.functional as F

# Reset all membrane potentials to zero
# MUST be called between independent sequences/batches
F.reset_net(snn)

# Set all neuron nodes to multi-step mode
F.set_step_mode(snn, 'm')

Note: Membrane potentials persist across forward calls within a sequence; that's how temporal information flows. But between independent sequences (e.g. different training batches), you must call F.reset_net(model) to clear the state. Forgetting this causes information leakage between unrelated samples.

Complete SNN Training Example

Train an SNN classifier using Poisson-encoded spike trains and rate coding for the output:

python
import numpy as np
import grilly.nn as nn
import grilly.functional as F
import grilly.optim as optim
from grilly.nn import LIFNode, SeqToANNContainer

T = 16    # timesteps per sample
N = 32    # batch size
D = 784   # input features (28x28 flattened)

# Build SNN
snn = nn.Sequential(
    SeqToANNContainer(nn.Linear(D, 256)),
    LIFNode(tau=2.0, step_mode='m'),
    SeqToANNContainer(nn.Linear(256, 10)),
    LIFNode(tau=2.0, step_mode='m'),
)
F.set_step_mode(snn, 'm')

optimizer = optim.AdamW(snn.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(10):
    # Encode input as Poisson spike trains: (T, N, D)
    # Each pixel's firing rate is proportional to its intensity
    rates = np.random.rand(N, D).astype(np.float32)
    x_seq = (np.random.rand(T, N, D) < rates).astype(np.float32)

    # Reset membrane state between batches
    F.reset_net(snn)

    # Forward: get output spike trains
    spikes_out = snn(x_seq)         # (T, N, 10)

    # Rate coding: average spikes over time -> class logits
    firing_rate = spikes_out.mean(axis=0)  # (N, 10)

    # Loss and backprop (same as ANN)
    labels = np.random.randint(0, 10, N).astype(np.int64)
    loss = loss_fn(firing_rate, labels)
    grad = loss_fn.backward(np.ones_like(loss), firing_rate, labels)

    snn.zero_grad()
    snn.backward(grad)
    optimizer.step()

    print(f"Epoch {epoch+1:2d}: loss={float(np.mean(loss)):.4f}  avg_fire_rate={firing_rate.mean():.3f}")

Output
Epoch  1: loss=2.3048  avg_fire_rate=0.312
Epoch  2: loss=2.2831  avg_fire_rate=0.287
Epoch  3: loss=2.2592  avg_fire_rate=0.263
Epoch  4: loss=2.2314  avg_fire_rate=0.241
Epoch  5: loss=2.1987  avg_fire_rate=0.224
Epoch  6: loss=2.1601  avg_fire_rate=0.209
Epoch  7: loss=2.1152  avg_fire_rate=0.198
Epoch  8: loss=2.0639  avg_fire_rate=0.189
Epoch  9: loss=2.0061  avg_fire_rate=0.182
Epoch 10: loss=1.9420  avg_fire_rate=0.176

Functional SNN API

For lower-level control, use the functional spiking neuron step functions directly:

python
import grilly.functional as F
import numpy as np

# Manual LIF step-by-step
v = np.zeros(128, dtype=np.float32)  # membrane potential

for t in range(10):
    x = np.random.randn(128).astype(np.float32)
    spike, v = F.lif_step(
        x, v,
        tau=2.0,
        v_threshold=1.0,
        v_reset=0.0,
    )
    print(f"t={t}: {int(spike.sum())} spikes")

ANN-to-SNN Conversion

Convert a pre-trained ANN to an SNN by replacing activations with spiking neurons:

python
import grilly.nn as nn
from grilly.nn import Converter, VoltageScaler

# Train a standard ANN first...
ann = nn.Sequential(
    nn.Linear(784, 256), nn.ReLU(),
    nn.Linear(256, 10),
)

# Convert ReLU activations to LIF neurons
converter = Converter(ann, mode='max')
snn_model = converter.convert()

# The converted SNN uses rate coding to approximate ANN outputs

Tip: ANN-to-SNN conversion is useful for deploying to neuromorphic hardware without retraining from scratch. However, directly training the SNN (as shown above) typically achieves better accuracy.
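
The mode='max' argument suggests activation-based normalization, a standard conversion recipe (Diehl et al., 2015): scale each layer so its maximum activation on calibration data maps to a firing rate of 1. A simplified sketch of that idea, using a hypothetical helper rather than the Converter internals:

python
import numpy as np

def max_norm_weights(weights, calib_activations):
    # Divide each layer's weights by the maximum ReLU activation observed
    # on calibration data so rate-coded firing rates stay within [0, 1]
    # and approximate the ANN's activations.
    scaled, prev_max = [], 1.0
    for W, acts in zip(weights, calib_activations):
        cur_max = max(float(np.max(acts)), 1e-8)
        scaled.append(W * prev_max / cur_max)
        prev_max = cur_max
    return scaled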