Metadata-Version: 2.4
Name: canvit-nnx
Version: 0.1.0
Summary: CanViT (Canvas Vision Transformer) -- JAX / Flax NNX
Author-email: Yohaï-Eliel Berreby <me@yberreby.com>
License-Expression: MIT
Requires-Python: >=3.12
Requires-Dist: flax>=0.12
Requires-Dist: huggingface-hub
Requires-Dist: jax>=0.9
Requires-Dist: jaxlib>=0.9
Requires-Dist: safetensors
Provides-Extra: convert
Requires-Dist: canvit-pytorch>=0.1.9; extra == 'convert'
Requires-Dist: torch>=2.0; extra == 'convert'
Provides-Extra: demo
Requires-Dist: pillow; extra == 'demo'
Provides-Extra: dev
Requires-Dist: canvit-pytorch>=0.1.9; extra == 'dev'
Requires-Dist: ipython; extra == 'dev'
Requires-Dist: pillow; extra == 'dev'
Requires-Dist: pypatree; extra == 'dev'
Requires-Dist: pytest; extra == 'dev'
Requires-Dist: torch>=2.0; extra == 'dev'
Requires-Dist: torchvision; extra == 'dev'
Provides-Extra: verify
Requires-Dist: canvit-pytorch>=0.1.9; extra == 'verify'
Requires-Dist: pillow; extra == 'verify'
Requires-Dist: torch>=2.0; extra == 'verify'
Description-Content-Type: text/markdown

# CanViT-NNX

[JAX](https://github.com/jax-ml/jax) / [Flax NNX](https://flax.readthedocs.io/) implementation of the [Canvas Vision Transformer (CanViT)](https://arxiv.org/abs/2603.22570).

> **Experimental.** This is a port of the reference implementation, [CanViT-PyTorch](https://github.com/m2b3/CanViT-PyTorch), and may break at any time.

Checkpoints: [canvit/canvit-nnx](https://huggingface.co/collections/canvit/canvit-nnx-69d4864ff6ee17d87e7b91d4) on HuggingFace.

## Install

```bash
uv add canvit-nnx
```

We recommend [`uv`](https://docs.astral.sh/uv/) for dependency management.

## Usage

### Classification

```python
from canvit_nnx import CanViTForImageClassification, Viewpoint, sample_at_viewpoint

clf = CanViTForImageClassification.from_pretrained(
    "canvit/canvitb16-add-vpe-finetune-g128px-s512px-in1k-2026-04-06-nnx"
)
state = clf.init_state(batch_size=1, canvas_grid_size=32)

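# `image`: a batched input image; scripts/demo.py shows one way to load and prepare it.
vp = Viewpoint.full_scene(batch_size=1)  # viewpoint covering the whole scene
glimpse = sample_at_viewpoint(spatial=image, viewpoint=vp, glimpse_size_px=128)
logits, state = clf(glimpse, state, vp)  # class logits and updated state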
```
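`logits` holds unnormalized class scores. A minimal sketch for turning them into top-k predictions with plain JAX, assuming `logits` has shape `(batch, num_classes)`. `top_k_predictions` is a hypothetical helper, not part of `canvit_nnx`, and mapping indices to label names is left out:

```python
import jax
import jax.numpy as jnp


def top_k_predictions(logits: jax.Array, k: int = 5) -> tuple[jax.Array, jax.Array]:
    """Return (probabilities, class indices) of the k highest-scoring classes.

    Hypothetical helper: assumes `logits` has shape (batch, num_classes).
    """
    probs = jax.nn.softmax(logits, axis=-1)       # normalize scores per example
    top_probs, top_idx = jax.lax.top_k(probs, k)  # top k along the last axis, descending
    return top_probs, top_idx


# Usage with dummy logits standing in for the classifier output.
dummy_logits = jnp.array([[0.1, 2.0, -1.0, 0.5]])
probs, idx = top_k_predictions(dummy_logits, k=2)
```

`jax.lax.top_k` operates on the last axis and returns values in descending order, so `idx[:, 0]` is the most likely class for each image.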

### Pretrained backbone (dense features)

```python
import jax.numpy as jnp
from canvit_nnx import from_pretrained, Viewpoint, sample_at_viewpoint

model = from_pretrained(
    "canvit/canvitb16-add-vpe-pretrain-g128px-s512px-in21k-dv3b16-2026-02-02-nnx"
)
state = model.init_state(batch_size=1, canvas_grid_size=32)

# `image`: a batched input image; scripts/demo.py shows one way to load and prepare it.
vp = Viewpoint.full_scene(batch_size=1)
glimpse = sample_at_viewpoint(spatial=image, viewpoint=vp, glimpse_size_px=128)
out = model(glimpse, state, vp)
state = out.state

# Normalize canvas features (parameter-free LayerNorm over the channel axis)
# before downstream use (PCA, probing, etc.)
canvas = model.get_spatial(state.canvas)
mean = canvas.mean(axis=-1, keepdims=True)
canvas = (canvas - mean) / jnp.sqrt(canvas.var(axis=-1, keepdims=True) + 1e-5)
```
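PCA is one of the downstream uses named above. A minimal sketch of a common visualization: project one image's normalized canvas features onto their top three principal components and rescale each component to [0, 1] for display as RGB. `pca_rgb` is a hypothetical helper; it only assumes channels are on the last axis (as in the normalization code), so it accepts either `(H, W, D)` or `(N, D)` features.

```python
import jax.numpy as jnp


def pca_rgb(features: jnp.ndarray, n_components: int = 3) -> jnp.ndarray:
    """Project features (..., D) onto their top principal components, scaled to [0, 1].

    Hypothetical helper: `features` holds one image's normalized canvas
    features with channels on the last axis, e.g. (H, W, D) or (N, D).
    """
    lead_shape = features.shape[:-1]
    x = features.reshape(-1, features.shape[-1])  # (num_positions, D)
    x = x - x.mean(axis=0, keepdims=True)         # center each channel across positions
    # Rows of vt are principal directions, ordered by decreasing singular value.
    _, _, vt = jnp.linalg.svd(x, full_matrices=False)
    proj = x @ vt[:n_components].T                # (num_positions, n_components)
    # Min-max scale each component independently to [0, 1] for display.
    lo = proj.min(axis=0, keepdims=True)
    hi = proj.max(axis=0, keepdims=True)
    proj = (proj - lo) / jnp.maximum(hi - lo, 1e-8)
    return proj.reshape(*lead_shape, n_components)
```

For example, `rgb = pca_rgb(canvas[0])` for the first image in the batch. Principal-component signs are arbitrary, so colors may flip between runs or images.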

## Structure

```
canvit_nnx/
  model.py            — CanViT architecture (self-contained, no external model deps)
  classification.py   — CanViTForImageClassification
  viewpoint.py        — sample_at_viewpoint (bilinear glimpse extraction)
  hub.py              — HuggingFace Hub loading (JAX-native safetensors, no torch)

scripts/
  demo.py                  — classification demo
  verify.py                — cross-framework numerical verification
  convert_from_pytorch.py  — one-time PyTorch -> JAX checkpoint conversion
  push_to_hub.py           — convert + push to HuggingFace
```
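`viewpoint.py` is described as bilinear glimpse extraction. The sketch below illustrates that general technique in plain JAX and is not the package's implementation: it assumes a hypothetical viewpoint given by a center `(cy, cx)` in normalized [-1, 1] scene coordinates and a `scale` in (0, 1] (1 = full scene), applied to a channels-last `(H, W, C)` image. The real `Viewpoint` fields and image layout may differ.

```python
import jax
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates


def bilinear_glimpse(image, center, scale, glimpse_size_px):
    """Bilinearly sample a square glimpse from a (H, W, C) image.

    Hypothetical parameterization: `center` = (cy, cx) in [-1, 1],
    `scale` in (0, 1] is the glimpse side as a fraction of the scene.
    """
    h, w, _ = image.shape
    # Glimpse pixel centers in normalized [-1, 1] coordinates.
    t = (jnp.arange(glimpse_size_px) + 0.5) / glimpse_size_px * 2.0 - 1.0
    ny = center[0] + scale * t
    nx = center[1] + scale * t
    # Convert to pixel-index coordinates (pixel centers at integer positions).
    py = (ny + 1.0) / 2.0 * h - 0.5
    px = (nx + 1.0) / 2.0 * w - 0.5
    yy, xx = jnp.meshgrid(py, px, indexing="ij")  # (g, g) sample grids
    # order=1 is bilinear interpolation; "nearest" clamps samples at the border.
    sample = lambda ch: map_coordinates(ch, [yy, xx], order=1, mode="nearest")
    # Map over the channel axis and stack results back as (g, g, C).
    return jax.vmap(sample, in_axes=2, out_axes=2)(image)
```

With `scale=1` and `glimpse_size_px` equal to the image side, the sample points land exactly on pixel centers and the glimpse reproduces the image. Plain bilinear sampling does no anti-aliasing when a glimpse is smaller than the region it covers.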

## Tests

```bash
uv run --extra dev pytest tests/ -v
```

## Verify against PyTorch

```bash
uv run --extra verify python scripts/verify.py
uv run --extra verify python scripts/verify.py --image path/to/image.jpg
```
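The exact metrics and tolerances used by `scripts/verify.py` are not documented here. As an illustration of the general idea only, a cross-framework check can compare two outputs for the same input element-wise with NumPy; `compare_outputs` and its default tolerances are assumptions, not the script's logic:

```python
import numpy as np


def compare_outputs(a, b, atol: float = 1e-4, rtol: float = 1e-4) -> dict:
    """Compare two arrays element-wise (e.g. JAX vs. PyTorch outputs as NumPy).

    Hypothetical helper; tolerances are illustrative, not those of verify.py.
    """
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    abs_err = np.abs(a - b)
    return {
        "max_abs_err": float(abs_err.max()),
        # Relative error w.r.t. b, guarded against division by zero.
        "max_rel_err": float((abs_err / np.maximum(np.abs(b), 1e-12)).max()),
        "allclose": bool(np.allclose(a, b, atol=atol, rtol=rtol)),
    }
```

A JAX array converts directly with `np.asarray`; a PyTorch tensor needs `tensor.detach().cpu().numpy()` first. If the two implementations use different tensor layouts, transpose one before comparing.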

## Demo

```bash
uv run --extra demo python scripts/demo.py --image path/to/image.jpg
```

## Citation

```bibtex
@article{berreby2026canvit,
  title={CanViT: Toward Active-Vision Foundation Models},
  author={Berreby, Yoha{\"i}-Eliel and Du, Sabrina and Durand, Audrey and Krishna, B. Suresh},
  year={2026},
  eprint={2603.22570},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}
```
