Metadata-Version: 2.4
Name: tracetorch
Version: 0.4.0
Summary: A strict, ergonomic and "just works" Spiking Neural Network library for PyTorch.
Author-email: Yegor Men <yegor.mn@gmail.com>
License-Expression: MIT
Project-URL: Homepage, https://github.com/Yegor-men/tracetorch
Project-URL: Repository, https://github.com/Yegor-men/tracetorch
Project-URL: Issues, https://github.com/Yegor-men/tracetorch/issues
Project-URL: Documentation, https://yegor-men.github.io/tracetorch/
Classifier: Programming Language :: Python :: 3
Classifier: Operating System :: OS Independent
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.8
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch
Requires-Dist: numpy
Requires-Dist: matplotlib
Dynamic: license-file

![traceTorch Banner](media/tracetorch_banner.png)

[![License](https://img.shields.io/badge/License-MIT-purple.svg)](https://opensource.org/license/mit)
[![PyPI](https://img.shields.io/badge/PyPI-v0.4.0-blue.svg)](https://pypi.org/project/tracetorch/)

# traceTorch

A strict, ergonomic and "just works" Spiking Neural Network library for PyTorch.

traceTorch is designed to eliminate the boilerplate, shape errors, and gradient issues common in SNN development. It
treats spiking neurons as first-class PyTorch citizens that handle their own state, broadcasting, and constraints
automatically.

## Why traceTorch?

Existing SNN libraries often feel restrictive or require verbose state management. traceTorch follows a different
philosophy:

- **Automatic State Management:** There is no need to manually pass hidden states through `.forward()`. Each layer
  manages its own hidden states, and calling `.zero_states()` on a traceTorch model recursively clears _all_ the hidden
  states the model uses, no matter how deeply nested they are.
- **Lazy Initialization:** Hidden states are initialized as `None` and allocated dynamically based on the input shape.
  This eliminates "Batch Size Mismatch" errors during inference, since the state is sized from the input it receives.
- **Smooth Constraints:** Parameters like decay ($\beta$) and thresholds are constrained via Sigmoid and Softplus
  respectively. No hard clamping, meaning that gradients flow smoothly and accurately everywhere.
- **Dimension Agnostic:** Whether you are working with `[Time, Batch, Features]` or `[Batch, Channels, Height, Width]`
  tensors, layers _just_ work. A single `dim` argument, set during layer initialization, indicates the dimension the
  layer acts on. It defaults to `-1`, which suits MLPs, while `-3` suits CNNs (channels are 3rd last in both
  `[B, C, H, W]` and `[C, H, W]`).
- **Rank Based Parameters:** Instead of messy flags like `*_is_vector` or `is_shared`, traceTorch uses a single `*_rank`
  integer to define the parameter scope: 0 for a scalar (one value shared across the layer), 1 for a vector (one value
  per neuron), 2 for a matrix (dense all-to-all connections for recurrent layers).
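
The first three points above can be illustrated with a toy, pure-Python sketch. It is not traceTorch's implementation:
`ToyLeaky`, its helper functions, and the subtractive reset are hypothetical choices made only for illustration. The
sketch stores unconstrained raw values and maps them through a sigmoid (decay) and softplus (threshold), allocates its
state lazily from the first input, and resets it to `None` with `zero_states()`.

```python
import math


def sigmoid(x: float) -> float:
    # Maps any real number into (0, 1); used here for the decay.
    return 1.0 / (1.0 + math.exp(-x))


def softplus(x: float) -> float:
    # log(1 + e^x), always positive; used here for the threshold.
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def logit(p: float) -> float:
    # Inverse of sigmoid, so a desired decay can be stored as a raw value.
    return math.log(p / (1.0 - p))


def inverse_softplus(y: float) -> float:
    # Inverse of softplus for y > 0.
    return math.log(math.expm1(y))


class ToyLeaky:
    """Hypothetical toy layer for illustration; not part of traceTorch."""

    def __init__(self, beta: float = 0.9, threshold: float = 1.0):
        self.raw_beta = logit(beta)
        self.raw_threshold = inverse_softplus(threshold)
        self.mem = None  # allocated on the first forward call

    def forward(self, x):
        beta = sigmoid(self.raw_beta)
        threshold = softplus(self.raw_threshold)
        if self.mem is None:
            self.mem = [0.0] * len(x)  # sized from the input
        self.mem = [beta * m + xi for m, xi in zip(self.mem, x)]
        spikes = [1.0 if m >= threshold else 0.0 for m in self.mem]
        # Subtractive reset (an assumption of this sketch).
        self.mem = [m - s * threshold for m, s in zip(self.mem, spikes)]
        return spikes

    def zero_states(self):
        self.mem = None


layer = ToyLeaky(beta=0.9, threshold=1.0)
first = layer.forward([0.6, 0.0, 2.0])   # [0.0, 0.0, 1.0]
second = layer.forward([0.6, 0.0, 0.0])  # [1.0, 0.0, 0.0]
layer.zero_states()
third = layer.forward([0.0] * 5)         # a new input size works after the reset
```

Because the raw values are unconstrained, an optimizer could move them freely while the effective decay stays in
$(0, 1)$ and the effective threshold stays positive.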

## Installation

traceTorch is available on [PyPI](https://pypi.org/project/tracetorch/). Requirements are listed in
`requirements.txt`.

```bash
pip install tracetorch
```

## Quick Start

Building a traceTorch model is barely different from building a PyTorch model. Here's how:

### 1. The "zero-boilerplate" module

Inherit from `tracetorch.snn.TTModule` instead of `torch.nn.Module`. This gives your model powerful recursive methods
like `zero_states()` and `detach_states()` for free, while still integrating with other PyTorch `nn.Module`s.

```python
import torch
from torch import nn
import tracetorch as tt
from tracetorch import snn


class ConvSNN(snn.TTModule):
	def __init__(self):
		super().__init__()
		self.net = nn.Sequential(
			nn.Conv2d(1, 32, 3),
			# dim=-3 tells the layer that the 3rd-to-last dimension is the channel dim.
			# This works for (B, C, H, W) AND unbatched (C, H, W) inputs automatically.
			snn.LIF(num_neurons=32, beta=0.9, dim=-3),

			nn.Flatten(),
			# 26 * 26 assumes 28x28 inputs: the unpadded 3x3 conv maps 28x28 to 26x26.
			nn.Linear(32 * 26 * 26, 10),

			# Readout layer with learnable scalar decay
			snn.Readout(num_neurons=10, beta=0.8, beta_rank=0)
		)

	def forward(self, x):
		return self.net(x)
```
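
To see why a single negative `dim` can serve both batched and unbatched inputs, here is a minimal sketch in plain
PyTorch. It is an assumption about one way such broadcasting can be done, not traceTorch's internal code, and
`broadcast_along` is a hypothetical helper: it reshapes a per-neuron vector so that it lines up with the chosen
dimension counted from the end.

```python
import torch


def broadcast_along(param: torch.Tensor, dim: int) -> torch.Tensor:
    """Reshape a 1-D per-neuron parameter so it broadcasts along a negative `dim`."""
    if dim >= 0:
        raise ValueError("this sketch expects a negative dim, counted from the end")
    trailing = -dim - 1  # number of dimensions to the right of `dim`
    return param.reshape(-1, *([1] * trailing))


beta = torch.linspace(0.1, 0.9, steps=32)     # one decay value per channel
per_channel = broadcast_along(beta, dim=-3)   # shape [32, 1, 1]

batched = torch.ones(8, 32, 26, 26)           # [B, C, H, W]
unbatched = torch.ones(32, 26, 26)            # [C, H, W]

out_batched = per_channel * batched           # shape [8, 32, 26, 26]
out_unbatched = per_channel * unbatched       # shape [32, 26, 26]
```

Counting from the end is what makes the leading batch dimension optional: the `[32, 1, 1]` view aligns with the last
three dimensions regardless of what precedes them.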

### 2. The Training Loop

State management is easily handled outside the forward pass. Simply call `.zero_states()` on the model to reset all
hidden states to `None`, or call `.detach_states()` to detach the current hidden states (used in truncated BPTT or for
online learning).

```python
model = ConvSNN().cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.005)
loss_fn = tt.loss.soft_cross_entropy  # Handles non-onehot targets gracefully

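# Training step. `loader` (yielding input/target batches) and `num_timesteps`
# are assumed to be defined elsewhere.
for x, y in loader: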
	x, y = x.cuda(), y.cuda()
	model.train()

	# Time loop
	spikes = []
	for step in range(num_timesteps):
		# Just pass x. No state tuples to manage.
		spikes.append(model(x))

	# Stack output and compute loss
	output = torch.stack(spikes)
	loss = loss_fn(output.mean(0), y)  # Rate coding example

	loss.backward()
	optimizer.step()
	optimizer.zero_grad()

	# Crucial: Reset hidden states for the next batch
	model.zero_states() 
```
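
For truncated BPTT, the state is detached rather than reset. The following is a minimal sketch in plain PyTorch, not
traceTorch code: `LeakyReadout` is a hypothetical stand-in for a stateful layer, and its `detach_states()` and
`zero_states()` methods only mirror the names used above. The window length, optimizer, and loss are arbitrary
assumptions.

```python
import torch
from torch import nn

torch.manual_seed(0)


class LeakyReadout(nn.Module):
    """Hypothetical stand-in for a stateful layer; not part of traceTorch."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.raw_beta = nn.Parameter(torch.tensor(2.0))  # sigmoid(2.0) is about 0.88
        self.mem = None

    def forward(self, x):
        current = self.linear(x)
        if self.mem is None:
            self.mem = torch.zeros_like(current)  # lazily sized from the input
        self.mem = torch.sigmoid(self.raw_beta) * self.mem + current
        return self.mem

    def detach_states(self):
        # Keep the state's values but cut the graph behind them.
        if self.mem is not None:
            self.mem = self.mem.detach()

    def zero_states(self):
        self.mem = None


model = LeakyReadout(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
sequence = torch.randn(12, 3, 4)  # [Time, Batch, Features]
target = torch.zeros(3, 2)
window = 4                        # backpropagate through at most 4 steps

for t, x_t in enumerate(sequence, start=1):
    out = model(x_t)
    if t % window == 0:
        loss = (out - target).pow(2).mean()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        model.detach_states()     # state carries over, gradients do not

state_after_sequence = model.mem  # still holds values, with no graph attached
model.zero_states()               # full reset before an unrelated sequence
```

Detaching keeps memory use bounded by the window length while the state itself persists across windows; a full reset
is only needed when the next input is unrelated to the previous one.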

## Documentation

The online documentation can be found [here](https://yegor-men.github.io/tracetorch/). It primarily focuses on how the
modules work on the backend, although there are a few tutorials there that recreate the code found in `examples/`.

## Authors

- [@Yegor-men](https://github.com/Yegor-men)

## Contributing

Contributions are always welcome. Feel free to fork, submit pull requests, or report issues. I will check in on the
repository occasionally.
