Metadata-Version: 2.4
Name: cassm
Version: 0.1.0
Summary: Computation-aware State-Space Models
Author-email: Jonathan Wenger <jw4246@columbia.edu>, Jonathan Huml <jrhuml@gmail.com>
License: MIT
Project-URL: github, https://github.com/jonathanhuml/cassm
Keywords: state-space-models,probabilistic-numerics,filtering
Platform: any
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Requires-Python: <3.12,>=3.8
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: geotorch
Requires-Dist: gpytorch
Requires-Dist: jaxtyping>=0.2.31
Requires-Dist: linear-operator
Requires-Dist: matplotlib>=3.4.3
Requires-Dist: numpy>=1.21.3
Requires-Dist: prettytable>=3.0
Requires-Dist: scipy>=1.7
Requires-Dist: torch>=2.1.2
Provides-Extra: dev
Requires-Dist: black; extra == "dev"
Requires-Dist: isort; extra == "dev"
Requires-Dist: pylint; extra == "dev"
Requires-Dist: pytest<8,>=4.6; extra == "dev"
Requires-Dist: pytest-cases<4,>=3.6.9; extra == "dev"
Dynamic: license-file

[![arXiv](https://img.shields.io/badge/arXiv-2606.01468-b31b1b.svg)](https://arxiv.org/abs/2606.01468)

# CASSM

![Lorenz Dynamics](./lorenz_dynamics.png)

Computation-aware state-space models for neural data.

CASSM provides PyTorch implementations of computation-aware filtering and
smoothing models for high-dimensional neural time series. The package includes
synthetic data utilities, model implementations, metrics, and a short tutorial
notebook for training on Lorenz-generated spike trains.
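
As background for readers new to filtering, the sketch below is a plain NumPy Kalman filter predict/update step for a linear-Gaussian state-space model. It is not CASSM's implementation. The optional matrix `S` makes the update condition on the projected observation `S.T @ y` instead of on every channel. This is still exact Bayesian conditioning, just on less information, and it only needs a k-by-k solve rather than one whose size equals the number of observed neurons. The names `kalman_step` and `S` are hypothetical. Whether CASSM's `projection_dim` works this way is an assumption we have not verified.

```python
import numpy as np


def kalman_step(m, P, y, A, Q, H, R, S=None):
    """One predict/update step of a linear-Gaussian Kalman filter.

    Dynamics:     x_t = A x_{t-1} + w,  w ~ N(0, Q)
    Observation:  y_t = H x_t + v,      v ~ N(0, R)
    If S (obs_dim x k) is given, the update conditions only on S.T @ y.
    """
    # Predict: push the previous posterior through the linear dynamics.
    m_pred = A @ m
    P_pred = A @ P @ A.T + Q

    if S is None:
        S = np.eye(y.shape[0])  # no projection: use every observed channel
    # Projected observation model: S.T @ y = (S.T @ H) x + S.T @ v.
    z = S.T @ y
    H_s = S.T @ H
    R_s = S.T @ R @ S

    # Update: Kalman gain K = P_pred H_s^T C^{-1}, with C the innovation covariance.
    C = H_s @ P_pred @ H_s.T + R_s
    K = np.linalg.solve(C, H_s @ P_pred).T  # valid because C and P_pred are symmetric
    m_post = m_pred + K @ (z - H_s @ m_pred)
    P_post = P_pred - K @ H_s @ P_pred
    return m_post, P_post


rng = np.random.default_rng(0)
latent_dim, obs_dim, k = 3, 50, 5
A = 0.95 * np.eye(latent_dim)
Q = 0.1 * np.eye(latent_dim)
H = rng.normal(size=(obs_dim, latent_dim))
R = 0.5 * np.eye(obs_dim)
y = rng.normal(size=obs_dim)
m0, P0 = np.zeros(latent_dim), np.eye(latent_dim)

# Condition on all 50 channels, then on only k = 5 random projections of them.
m_full, P_full = kalman_step(m0, P0, y, A, Q, H, R)
S = rng.normal(size=(obs_dim, k))
m_proj, P_proj = kalman_step(m0, P0, y, A, Q, H, R, S=S)
print(np.trace(P_proj) >= np.trace(P_full))  # True
```

With `S` set to the identity the step reduces to the standard Kalman update. With fewer projections the posterior covariance can only grow, because `S.T @ y` carries no more information about the state than `y` does.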

## Installation

Install from PyPI:

```bash
pip install cassm
```

Alternatively, to work from source, clone the repository and install the package in editable mode from the repository root:

```bash
pip install -e .
```

To run the tests or contribute, install the optional `dev` extras (black, isort, pylint, pytest, and pytest-cases):

```bash
pip install -e ".[dev]"
```

## Quick Start

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from cassm.datasets.synthetic_data import LorenzData
from cassm.models import CASSM
from cassm.utils.preprocessing import smooth_firing_rate

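device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Generate synthetic spike trains driven by Lorenz dynamics. Only the training
# data, validation data, and validation ground truth are used below.
train_data, valid_data, _, valid_truth, _, _ = LorenzData(
    num_inits=4,
    neurons=120,
    num_trials=4,
    device="cpu",  # keep on CPU so that .numpy() below works
    seed=2,
)

# Smooth the spike counts into firing-rate estimates and move them to `device`.
train_data = smooth_firing_rate(train_data.numpy()).to(device)
valid_data = smooth_firing_rate(valid_data.numpy()).to(device)
valid_truth = valid_truth.to(device)

# Training batches hold data only; validation batches pair data with ground truth.
train_loader = DataLoader(train_data, batch_size=4, shuffle=True)
test_loader = DataLoader(TensorDataset(valid_data, valid_truth), batch_size=4)

# Axis 1 of the data indexes time steps and the last axis indexes neurons.
model = CASSM(
    projection_dim=10,
    nneurons=train_data.shape[-1],
    timesteps=train_data.shape[1],
    dt=0.01,
    device=device,
    save_model=False,
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=5e-2)
# A short demo run; real experiments will typically need more epochs.
model.train_model(
    epochs=3,
    optimizer=optimizer,
    train_loader=train_loader,
    test_loader=test_loader,
    valid_truth=valid_truth,
    clip_value=300,
)

# Predicted firing rates (plus the model's noise output) for the validation data.
predicted_rate, predicted_noise = model.predict_rate(valid_data)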
```
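
The quick start passes the raw spike trains through `smooth_firing_rate` before training, and it reads `shape[1]` as the number of time steps and `shape[-1]` as the number of neurons, which suggests a (trials, timesteps, neurons) layout. We have not checked how `smooth_firing_rate` works internally. A common approach is Gaussian-kernel smoothing along the time axis, sketched below with a hypothetical `gaussian_smooth` helper; the kernel width `sigma` (in time bins) is an assumption.

```python
import numpy as np


def gaussian_smooth(counts, sigma=2.0, axis=1):
    """Smooth spike counts with a normalized Gaussian kernel along `axis`.

    `sigma` is measured in time bins. Near the edges the result is divided
    by the kernel mass that falls inside the recording, so boundary bins
    are not pulled toward zero.
    """
    counts = np.asarray(counts, dtype=float)
    radius = int(np.ceil(3 * sigma))
    offsets = np.arange(-radius, radius + 1)
    kernel = np.exp(-0.5 * (offsets / sigma) ** 2)
    kernel /= kernel.sum()
    if counts.shape[axis] < kernel.size:
        raise ValueError("time axis is shorter than the smoothing kernel")

    def conv(x):
        # mode="same" keeps the output aligned with, and as long as, the input.
        return np.convolve(x, kernel, mode="same")

    smoothed = np.apply_along_axis(conv, axis, counts)
    edge_mass = conv(np.ones(counts.shape[axis]))
    shape = [1] * counts.ndim
    shape[axis] = -1  # broadcast the edge correction along the time axis only
    return smoothed / edge_mass.reshape(shape)


rng = np.random.default_rng(0)
spikes = rng.poisson(0.3, size=(4, 100, 120))  # (trials, timesteps, neurons)
rates = gaussian_smooth(spikes, sigma=2.0)
print(rates.shape)  # (4, 100, 120)
```

Dividing by the in-window kernel mass means a constant input stays constant all the way to the first and last bins, rather than tapering off where the kernel runs past the recording.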

See [tutorials/cassm_lorenz_tutorial.ipynb](tutorials/cassm_lorenz_tutorial.ipynb)
for a fuller walkthrough.
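
The tutorial builds its training data with `LorenzData`. For intuition about what Lorenz-driven spike trains look like, the sketch below follows one common recipe: Euler-integrate the Lorenz system, map the 3-D latent state to per-neuron log-rates with a random linear readout, and draw Poisson counts. It is not the implementation of `LorenzData`. The function `simulate_lorenz_spikes`, the readout scale, and `base_rate` are hypothetical choices for illustration; σ = 10, ρ = 28, β = 8/3 are the classic chaotic Lorenz parameters.

```python
import numpy as np


def simulate_lorenz_spikes(timesteps=200, neurons=120, dt=0.01,
                           sigma=10.0, rho=28.0, beta=8.0 / 3.0,
                           base_rate=5.0, seed=0):
    """Draw Poisson spike counts whose rates are driven by a Lorenz attractor.

    Returns (latents, rates, spikes) with shapes (timesteps, 3),
    (timesteps, neurons), and (timesteps, neurons).
    """
    rng = np.random.default_rng(seed)
    x = np.empty((timesteps, 3))
    x[0] = rng.uniform(-10.0, 10.0, size=3)
    # Forward-Euler integration of the Lorenz equations with step dt.
    for t in range(1, timesteps):
        a, b, c = x[t - 1]
        deriv = np.array([sigma * (b - a), a * (rho - c) - b, a * b - beta * c])
        x[t] = x[t - 1] + dt * deriv

    # Standardize each latent, then map to log-rates with a random linear readout.
    z = (x - x.mean(axis=0)) / x.std(axis=0)
    readout = rng.normal(scale=0.5, size=(3, neurons))
    log_rates = np.log(base_rate) + z @ readout
    # Expected counts per bin; dt doubles as the bin width (a simplification).
    rates = np.exp(log_rates) * dt
    spikes = rng.poisson(rates)
    return x, rates, spikes


latents, rates, spikes = simulate_lorenz_spikes()
print(latents.shape, spikes.shape)  # (200, 3) (200, 120)
```

The exponential link keeps every rate positive, which the Poisson draw requires, and standardizing the latents keeps the log-rates in a moderate range regardless of the attractor's raw scale.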

## Repository Layout

- `src/cassm`: package source code
- `tests`: package tests
- `tutorials`: user-facing notebooks
- `pyproject.toml`: package metadata and tooling configuration

## Testing

Run the test suite from the repository root after installing the `dev` extras:

```bash
pytest
```

## License

CASSM is released under the MIT License. See [LICENSE](LICENSE).
