Metadata-Version: 2.4
Name: srkan
Version: 0.1.1
Summary: Spiking Recurrent Kolmogorov-Arnold Networks
Author-email: Author <author@example.com>
Project-URL: Homepage, https://github.com/author/srkan
Project-URL: Issues, https://github.com/author/srkan/issues
Classifier: Programming Language :: Python :: 3
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Requires-Python: >=3.8
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch>=2.0.0
Dynamic: license-file

<div align="center">

# 🧠 SRKAN
**Spiking Recurrent Kolmogorov-Arnold Networks**

A bio-inspired neural architecture that fuses the temporal expressiveness of Spiking Neural Networks with the function-learning power of Kolmogorov-Arnold Networks.

[Installation](#installation) | [Quick Start](#quick-start) | [Architecture](#architecture) | [Benchmarks](#benchmarks) | [API Reference](#api-reference)

</div>

## Why SRKAN?

* **Standard SNNs suffer from surrogate gradient bias** — the mismatch between forward spikes and backward gradients causes training instability in deep networks. 
* **Standard KANs have no temporal memory** — they process each timestep independently, making them blind to spike timing patterns.

SRKAN addresses both problems:
* **Saltation gradient** — a biologically motivated surrogate gradient that tracks sub-threshold membrane dynamics, reducing the mismatch between forward spikes and backward gradients
* **Dendritic KAN gate** — replaces linear synaptic weights with learnable B-spline functions, letting each synapse learn its own nonlinear transfer curve
* **AdaptiveLIF + EMA deadband homeostasis** — prevents firing-rate collapse in deep stacks (the key training stability fix)
* **Chunked parallel scan** — O(√T) memory for long spike trains, with an inference latency of about 0.28 ms per sample on a standard GPU
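The README does not give the saltation gradient's formula. As a reference point only, the sketch below contrasts two standard surrogates (neither is the saltation gradient): a rectangular (boxcar) surrogate, whose gradient is zero once the membrane is far from threshold, and the fast-sigmoid surrogate popularized by SuperSpike, which stays nonzero below threshold. The function names and constants (`threshold=1.0`, `width=0.5`, `slope=10.0`) are illustrative assumptions.

```python
def heaviside_spike(v, threshold=1.0):
    """Forward pass: spike (1.0) when the membrane reaches threshold, else 0.0."""
    return 1.0 if v >= threshold else 0.0

def rectangular_surrogate(v, threshold=1.0, width=0.5):
    """Backward pass, boxcar: constant gradient in a window around threshold, zero outside."""
    return 1.0 / width if abs(v - threshold) < width / 2 else 0.0

def fast_sigmoid_surrogate(v, threshold=1.0, slope=10.0):
    """Backward pass, fast sigmoid (SuperSpike-style): smooth and nonzero for every v."""
    return 1.0 / (1.0 + slope * abs(v - threshold)) ** 2

# A sub-threshold membrane value: no spike is emitted in the forward pass
v = 0.2
print(heaviside_spike(v), rectangular_surrogate(v), fast_sigmoid_surrogate(v))
```

For a neuron sitting 0.8 below threshold, the boxcar surrogate passes no gradient back at all, while the smooth surrogate still does; this is the kind of sub-threshold signal the saltation gradient is described as tracking.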

## Installation

```bash
pip install srkan
```

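**Requirements:** Python ≥ 3.9, PyTorch ≥ 2.0, pykan. The package metadata declares only `torch` as a dependency, so install pykan separately (`pip install pykan`) if it is not already present.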
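Since only `torch` is listed under `Requires-Dist`, a quick check of which modules are importable can save a confusing traceback later. A minimal standard-library sketch; it assumes pykan's import name is `kan` and that this package imports as `srkan`, and `check_dependencies` is a hypothetical helper.

```python
import importlib.util

def check_dependencies(modules=("torch", "kan", "srkan")):
    """Return {module_name: True/False} for whether each module can be found, without importing it."""
    return {name: importlib.util.find_spec(name) is not None for name in modules}

if __name__ == "__main__":
    for name, found in check_dependencies().items():
        print(f"{name:6s} {'found' if found else 'MISSING'}")
```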

## Quick Start

```python
import torch
import torch.nn.functional as F
from srkan import SRKAN

# Build model: 700 input channels (SHD cochlear), 20 classes
model = SRKAN(
    n_in       = 700,
    n_hidden   = 256,
    n_out      = 20,
    n_layers   = 4,
    T          = 100,    # timesteps
    grid       = 5,      # B-spline grid resolution
    chunk_size = 10,     # chunked scan window (√T)
)

# Input: [Batch, Time, Channels] binary spike tensor (random placeholder data)
x       = torch.randint(0, 2, (32, 100, 700)).float()
targets = torch.randint(0, 20, (32,))              # class labels in [0, 20)

# Classification logits via membrane voltage readout, averaged over time
logits = model.readout_membrane(x).mean(dim=1)     # [B, n_out]
loss   = F.cross_entropy(logits, targets)
loss.backward()
```

**Note:** Use `model.readout_membrane(x).mean(dim=1)` for classification, not `model(x)`, which returns binary spikes. The membrane voltage is differentiable everywhere and preserves the full sub-threshold timing signal.
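Why a membrane readout yields smooth logits can be seen with a non-spiking leaky integrator. A minimal sketch for a single sample, assuming the readout integrates a weighted input with decay `beta` and has no threshold or reset; `leaky_readout`, `time_mean`, and `beta` are illustrative names, and the actual SRKAN readout layer may differ. `time_mean` plays the role of `.mean(dim=1)` on the `[B, T, n_out]` tensor.

```python
def leaky_readout(inputs, weights, beta=0.9):
    """inputs: list over time of input vectors (length n_in).
    weights: n_out rows of n_in weights.
    Returns the membrane trace: list over time of n_out voltages."""
    n_out = len(weights)
    v = [0.0] * n_out
    trace = []
    for x_t in inputs:
        for k in range(n_out):
            current = sum(w * x for w, x in zip(weights[k], x_t))
            v[k] = beta * v[k] + current   # leak, then integrate; no threshold, no reset
        trace.append(list(v))
    return trace

def time_mean(trace):
    """Average each output's voltage over time: one real-valued logit per class."""
    T = len(trace)
    return [sum(v_t[k] for v_t in trace) / T for k in range(len(trace[0]))]

spikes  = [[1, 0], [0, 1], [1, 1]]          # T=3 timesteps, n_in=2
weights = [[1.0, 0.0], [0.0, 2.0]]          # n_out=2
print(time_mean(leaky_readout(spikes, weights, beta=0.5)))
```

Because the voltage is a weighted sum of past inputs, small changes to the weights move every logit continuously, which is what makes it a convenient target for cross-entropy.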

## Architecture

```text
Input Spikes [B, T, n_in]
        │
        ▼
  ┌─────────────────────────────────┐
  │   Pre-projection  (Linear)      │  n_in → smaller width, e.g. 784 → 64 (reduces KAN OOM risk)
  └─────────────────────────────────┘
        │
        ▼
  ┌─────────────────────────────────┐  ┐
  │  Modal Synapse  (B-spline KAN)  │  │
  │  Dendritic Gate (gated signal)  │  │  × n_layers
  │  AdaptiveLIF   (spike + EMA)    │  │
  │  Chunked Scan  (O(√T) memory)   │  │
  └─────────────────────────────────┘  ┘
        │
        ▼
  Readout Layer → Membrane Voltage [B, T, n_out]
        │  .mean(dim=1)
        ▼
  Logits [B, n_out]
```
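The pre-projection's memory benefit can be made concrete by counting spline coefficients. A minimal sketch, assuming the first KAN layer maps to 256 hidden units with `grid=5` and cubic splines (`k=3`, pykan's default degree): a degree-k B-spline on `grid` intervals has `grid + k` basis functions, so every (input, output) edge stores that many coefficients. The helper names are hypothetical, and the count ignores pykan's extra per-edge base and scale weights.

```python
def kan_spline_coeffs(n_in, n_out, grid=5, k=3):
    """Spline coefficients in one KAN layer: one spline per (input, output) edge,
    each with grid + k coefficients for degree-k splines on `grid` intervals."""
    return n_in * n_out * (grid + k)

def linear_params(n_in, n_out):
    """Weights plus biases of a torch.nn.Linear-style layer."""
    return n_in * n_out + n_out

direct    = kan_spline_coeffs(784, 256)                           # KAN on the raw input
projected = linear_params(784, 64) + kan_spline_coeffs(64, 256)   # Linear 784→64, then KAN

print(direct, projected, round(direct / projected, 1))
# → 1605632 181312 8.9
```

Activation memory for evaluating the basis functions also grows linearly with the input width, so shrinking it before the first KAN layer helps on both fronts.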

### Key Components

| Component | Role | Innovation |
|---|---|---|
| ModalSynapse | B-spline synaptic transfer | Replaces linear weights with learnable nonlinear curves |
| DendriticGate | Gated signal mixing | Separates fast/slow dendritic compartments |
| SaltationGradient | Surrogate for backprop | Tracks sub-threshold dynamics, reduces bias vs. a rectangular surrogate |
| AdaptiveLIF | Leaky Integrate-and-Fire | EMA-based threshold adaptation with deadband to prevent collapse |
| ChunkedScan | Temporal recurrence | Parallel scan in O(√T) memory vs. O(T) for a naively unrolled RNN |
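ModalSynapse's learnable transfer curve is a B-spline: a weighted sum of fixed basis functions whose weights (coefficients) are trained. A minimal pure-Python sketch of one such edge, evaluated with the Cox-de Boor recursion on a uniform grid extended by k knots on each side (similar to the layout pykan uses). `make_knots`, `bspline_basis`, and `spline_edge` are hypothetical helpers; `grid=5` follows the README's default and `k=3` (cubic) is an assumption.

```python
def make_knots(lo=-1.0, hi=1.0, grid=5, k=3):
    """Uniform knots: `grid` intervals on [lo, hi], extended by k knots on each side."""
    h = (hi - lo) / grid
    return [lo + (j - k) * h for j in range(grid + 2 * k + 1)]

def bspline_basis(x, knots, i, k):
    """Cox-de Boor recursion: value at x of the i-th B-spline basis of degree k."""
    if k == 0:
        return 1.0 if knots[i] <= x < knots[i + 1] else 0.0
    left_den = knots[i + k] - knots[i]
    right_den = knots[i + k + 1] - knots[i + 1]
    left = 0.0 if left_den == 0 else (x - knots[i]) / left_den * bspline_basis(x, knots, i, k - 1)
    right = 0.0 if right_den == 0 else (knots[i + k + 1] - x) / right_den * bspline_basis(x, knots, i + 1, k - 1)
    return left + right

def spline_edge(x, coeffs, knots, k=3):
    """phi(x) = sum_i c_i * B_i(x): one synaptic transfer curve; coeffs are the trainable part."""
    return sum(c * bspline_basis(x, knots, i, k) for i, c in enumerate(coeffs))

knots = make_knots()                      # 12 knots -> 12 - 3 - 1 = 8 cubic basis functions
coeffs = [0.0, -0.5, 0.2, 1.0, 0.4, -0.3, 0.8, 0.0]
print([round(spline_edge(x, coeffs, knots), 3) for x in (-0.9, -0.3, 0.3, 0.9)])
```

With 8 coefficients per edge, training can bend each synapse's response independently, which is what "replaces linear weights with learnable nonlinear curves" amounts to.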

## Benchmarks

Evaluated on the **Spiking Heidelberg Digits (SHD)** dataset: 700-channel spike trains generated by an artificial cochlea model, 20 spoken-digit classes (digits 0 to 9 in English and German), binned to T=100 timesteps.

### Accuracy

| Model | Test Accuracy | Parameters |
|---|---|---|
| KAN (feed-forward) | 18.9% | 132K |
| MLP | 49.6% | 18.2M |
| RNN | 70.2% | 520K |
| **SRKAN (ours)** | **83.4%** | **34.8K** |

SRKAN achieves **+33.8 points over the MLP** with **more than 500× fewer parameters** (34.8K vs. 18.2M).
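The two headline numbers follow directly from the accuracy table; this short computation uses only the values listed there.

```python
srkan_acc, mlp_acc = 83.4, 49.6            # test accuracy (%) from the table
srkan_params, mlp_params = 34.8e3, 18.2e6  # parameter counts from the table

gain_points = srkan_acc - mlp_acc          # percentage-point gain over the MLP
param_ratio = mlp_params / srkan_params    # how many times fewer parameters SRKAN uses

print(f"+{gain_points:.1f} points, {param_ratio:.0f}x fewer parameters")
# → +33.8 points, 523x fewer parameters
```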

### Training Stability (Homeostasis Ablation)

| Variant | Best Acc | Final Acc | Behavior |
|---|---|---|---|
| Naive homeostasis | 50.6% | 31.6% | 19-pt collapse ❌ |
| EMA + deadband (SRKAN) | 52.8% | 48.8% | Stable ✅ |

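EMA deadband homeostasis is the **critical stability fix**: in this ablation it limits the drop from best to final accuracy to 4.0 points, versus 19.0 points for naive homeostasis. Without it, deep SNN stacks collapse during training regardless of surrogate choice.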
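The README does not spell out the homeostasis update, so the following is only a minimal sketch of what "EMA + deadband" typically means, with an assumed rule and assumed constants (`target_rate`, `deadband`, `alpha`, `eta`); `ema_deadband_thresholds` is a hypothetical helper. The EMA smooths the noisy per-step firing rate, and the deadband leaves the threshold untouched while the smoothed rate is close enough to target. A naive rule would instead react to every raw per-step rate.

```python
def ema_deadband_thresholds(spike_counts, n_neurons, target_rate=0.05, deadband=0.02,
                            alpha=0.1, eta=0.5, threshold=1.0, min_threshold=1e-3):
    """Adapt one firing threshold from per-step spike counts; return the threshold after each step."""
    ema_rate = target_rate                    # start the running estimate at the target
    history = []
    for count in spike_counts:
        rate = count / n_neurons              # population firing rate at this step
        ema_rate = (1 - alpha) * ema_rate + alpha * rate
        error = ema_rate - target_rate
        if error > deadband:                  # firing too much: raise the threshold
            threshold += eta * (error - deadband)
        elif error < -deadband:               # firing too little: lower it, but keep it positive
            threshold = max(min_threshold, threshold + eta * (error + deadband))
        history.append(threshold)             # inside the deadband: no change
    return history

print(ema_deadband_thresholds([50] * 10, n_neurons=100)[-1])   # persistent over-firing
print(ema_deadband_thresholds([0] * 10, n_neurons=100)[-1])    # silence
```

The deadband is what prevents the small, constant corrections that can push a deep stack into oscillation or collapse; only sustained deviations move the threshold.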

### Inference Speed

| Model | Latency (ms/sample) |
|---|---|
| KAN | 0.31 |
| RNN | 0.09 |
| MLP | 0.06 |
| SRKAN | 0.28 |
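Per-sample latency depends on hardware and batch size, so figures like these are easiest to compare when measured the same way on one machine. A minimal timing harness, as a sketch; `ms_per_sample` is a hypothetical helper, and for an SRKAN model `fn` would wrap a call such as `model.readout_membrane(x)` under `torch.no_grad()`.

```python
import time

def ms_per_sample(fn, batch, batch_size, warmup=3, repeats=20):
    """Average wall-clock latency of fn(batch), in milliseconds per sample.

    On a GPU, fn should end with torch.cuda.synchronize(): CUDA kernels run
    asynchronously, so without it the timer could stop before the work finishes.
    """
    for _ in range(warmup):              # warm-up calls are not timed
        fn(batch)
    start = time.perf_counter()
    for _ in range(repeats):
        fn(batch)
    elapsed_s = time.perf_counter() - start
    return 1000.0 * elapsed_s / (repeats * batch_size)

# Usage with a stand-in workload: 100 "samples" of 700 values each
batch = [[0.0] * 700 for _ in range(100)]
print(f"{ms_per_sample(lambda b: [sum(row) for row in b], batch, batch_size=len(batch)):.4f} ms/sample")
```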

## API Reference

### `SRKAN(n_in, n_hidden, n_out, n_layers, T, grid, chunk_size)`

| Parameter | Type | Default | Description |
|---|---|---|---|
| `n_in` | int | — | Input channels (e.g. 700 for SHD) |
| `n_hidden` | int | 256 | Hidden layer width |
| `n_out` | int | — | Number of output classes |
| `n_layers` | int | 4 | Number of SRKAN blocks |
| `T` | int | 100 | Number of timesteps |
| `grid` | int | 5 | B-spline grid resolution |
| `chunk_size` | int | 10 | Chunked scan window (√T recommended) |
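Why √T is the recommended window can be seen from a chunked evaluation of a leaky recurrence v_t = decay · v_{t-1} + x_t. A minimal sketch, assuming SRKAN's ChunkedScan follows this general pattern (the helper names are hypothetical): inside a chunk every position has a closed form, so the chunk can be computed in parallel, and only the state at each chunk boundary must be carried. Keeping T/C boundary states plus C in-chunk values costs O(T/C + C) memory, which is smallest at C = √T, giving O(√T).

```python
import math

def leaky_scan_sequential(inputs, decay):
    """Reference: v_t = decay * v_{t-1} + x_t, one step at a time."""
    v, out = 0.0, []
    for x in inputs:
        v = decay * v + x
        out.append(v)
    return out

def leaky_scan_chunked(inputs, decay, chunk_size=None):
    """Same recurrence, chunk by chunk. Within a chunk, position j satisfies
    v_j = decay**(j+1) * v_in + sum_{s<=j} decay**(j-s) * x_s,
    so each chunk needs only v_in, the last state of the previous chunk."""
    T = len(inputs)
    if chunk_size is None:
        chunk_size = max(1, math.isqrt(T))            # √T, as the API table recommends
    out, v_in = [], 0.0
    for start in range(0, T, chunk_size):
        chunk = inputs[start:start + chunk_size]
        # Closed form per position: positions are independent, so this part
        # could run in parallel (e.g. as a matrix product on a GPU).
        local = [sum(decay ** (j - s) * chunk[s] for s in range(j + 1))
                 for j in range(len(chunk))]
        out.extend(l + decay ** (j + 1) * v_in for j, l in enumerate(local))
        v_in = out[-1]                                 # the only state carried forward
    return out

spikes = [1.0, 0.0, 2.0, 0.5, 0.0, 3.0, 1.0, 0.0, 0.0, 4.0]
print(leaky_scan_chunked(spikes, decay=0.9))            # T=10, chunk_size=isqrt(10)=3
```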

### Methods

```python
model.forward(x)             # → binary spikes [B, T, n_out]  — hidden layers only
model.readout_membrane(x)    # → membrane voltage [B, T, n_out] — use for classification
```

## Training Example (SHD)

```python
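import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from srkan import SRKAN

model     = SRKAN(n_in=700, n_hidden=256, n_out=20, n_layers=4, T=100)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Placeholder data with SHD shapes; swap in a real SHD DataLoader
spikes_all   = torch.randint(0, 2, (256, 100, 700)).float()
labels_all   = torch.randint(0, 20, (256,))
train_loader = DataLoader(TensorDataset(spikes_all, labels_all), batch_size=32, shuffle=True)

model.train()
for spikes, labels in train_loader:           # spikes: [B, T, 700]
    logits = model.readout_membrane(spikes).mean(dim=1)   # [B, 20]
    loss   = F.cross_entropy(logits, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()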
```

## Citation

If SRKAN is useful in your research, please cite:

```text
@software{srkan2026,
  title   = {SRKAN: Spiking Recurrent Kolmogorov-Arnold Networks},
  year    = {2026},
  url     = {https://pypi.org/project/srkan/},
  note    = {PyPI package}
}
```

## License

MIT — free for academic and commercial use.

<div align="center">Made with 🧬 and way too little sleep</div>
