Metadata-Version: 2.4
Name: mnistvit
Version: 1.3.0
Summary: A vision transformer for training on MNIST
Author-email: Arno Onken <asnelt@asnelt.org>
License-Expression: GPL-3.0-only
Project-URL: Homepage, https://github.com/asnelt/mnistvit
Project-URL: Bug Tracker, https://github.com/asnelt/mnistvit/issues
Keywords: vision,transformer,mnist
Classifier: Development Status :: 5 - Production/Stable
Classifier: Intended Audience :: Science/Research
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.10
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch>=2.6
Requires-Dist: torchvision>=0.21
Provides-Extra: tune
Requires-Dist: ray[tune]>=2.50; extra == "tune"
Requires-Dist: optuna>=4.5; extra == "tune"
Dynamic: license-file

# Python package mnistvit

A PyTorch-only implementation of a vision transformer (ViT) for training on MNIST,
achieving 99.65% test accuracy with default parameters and without pre-training.
Both the ViT architecture and the training parameters are configurable, and code for
hyperparameter optimization is included as well.
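
A ViT treats an image as a sequence of tokens by cutting it into non-overlapping square patches and flattening each one. The stdlib-only sketch below shows that step for a 28 by 28 MNIST-sized image. The patch size of 7 is an illustrative assumption, not necessarily the default used by mnistvit, and `image_to_patches` is a hypothetical helper rather than part of the package's API.

```python
def image_to_patches(image, patch_size):
    """Split a 2D image (list of rows) into flattened, non-overlapping patches.

    Patches are returned in row-major order, the usual token order in a ViT.
    This is an illustrative helper, not mnistvit's implementation.
    """
    height = len(image)
    width = len(image[0])
    if height % patch_size or width % patch_size:
        raise ValueError("image height and width must be divisible by patch_size")
    patches = []
    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            # Flatten the patch row by row into one token vector.
            patch = [
                image[row][col]
                for row in range(top, top + patch_size)
                for col in range(left, left + patch_size)
            ]
            patches.append(patch)
    return patches


# A synthetic 28x28 "image" whose pixel value encodes its position.
mnist_like = [[row * 28 + col for col in range(28)] for row in range(28)]
tokens = image_to_patches(mnist_like, patch_size=7)
print(len(tokens), len(tokens[0]))  # 16 patches of 49 pixels each
```

In the original ViT design, each flattened patch is then linearly projected to the embedding dimension, a learnable class token is prepended (17 tokens in this example), and position embeddings are added before the sequence enters the transformer encoder.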


## Requirements

The package requires Python 3.10 or greater together with the `torch` (2.6 or later) and
`torchvision` (0.21 or later) packages.  Hyperparameter optimization additionally
requires `ray[tune]` (2.50 or later) and `optuna` (4.5 or later), which can be installed
through the `tune` extra.  The ViT model itself depends on `torch` only.
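
A small stdlib-only check can confirm that an environment meets the minimum versions listed in the package metadata before training. The helpers below are hypothetical and not part of mnistvit. Their version comparison is deliberately simplified: it ignores pre-release and local tags, which `packaging.version` would handle properly.

```python
import re
from importlib.metadata import PackageNotFoundError, version

# Minimum versions from the package metadata (Requires-Dist entries).
CORE_MINIMUMS = {"torch": "2.6", "torchvision": "0.21"}
TUNE_MINIMUMS = {"ray": "2.50", "optuna": "4.5"}  # needed only for mnistvit.tune


def release_tuple(text):
    """Return the leading numeric release, e.g. "2.6.0+cu124" -> (2, 6, 0).

    Simplification: pre-release and local suffixes are ignored.
    """
    match = re.match(r"\d+(?:\.\d+)*", text)
    if match is None:
        raise ValueError(f"unparsable version: {text!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def meets_minimum(installed, minimum):
    """Compare release tuples after padding the shorter one with zeros."""
    have, need = release_tuple(installed), release_tuple(minimum)
    length = max(len(have), len(need))
    have += (0,) * (length - len(have))
    need += (0,) * (length - len(need))
    return have >= need


def check_requirements(include_tune=False):
    """Map each distribution name to (installed version or None, requirement met)."""
    minimums = dict(CORE_MINIMUMS)
    if include_tune:
        minimums.update(TUNE_MINIMUMS)
    report = {}
    for name, minimum in minimums.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            report[name] = (None, False)
        else:
            report[name] = (installed, meets_minimum(installed, minimum))
    return report


if __name__ == "__main__":
    for name, (installed, ok) in check_requirements(include_tune=True).items():
        status = "ok" if ok else "missing or too old"
        print(f"{name}: {installed or 'not installed'} ({status})")
```

Note that `2.50` is compared as minor release 50, so `ray` 2.9 correctly fails the `>=2.50` requirement.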


## Installation

To install the mnistvit package from PyPI, run:

```
pip install mnistvit
```


## Usage

To train a model with default parameters:
```
python -m mnistvit.train
```

The script writes the model configuration to `config.json` and the trained model to
`model.pt`.  Use the `-h` argument for a list of options.

To evaluate the test set accuracy of the model stored in `model.pt` with the
configuration in `config.json`:
```
python -m mnistvit.predict --use-accuracy
```
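
Test set accuracy is the fraction of test images whose predicted label matches the true label. MNIST has 10,000 test images, so the 99.65% quoted above corresponds to 9,965 correct predictions. The minimal sketch below shows the metric itself; the `accuracy` function is hypothetical and not the package's evaluation code.

```python
def accuracy(predictions, labels):
    """Fraction of positions where the predicted class equals the true class."""
    if len(predictions) != len(labels):
        raise ValueError("predictions and labels must have the same length")
    if not labels:
        raise ValueError("cannot compute accuracy of an empty set")
    correct = sum(p == y for p, y in zip(predictions, labels))
    return correct / len(labels)


# Worked example: 9,965 of the 10,000 MNIST test images classified correctly.
labels = [0] * 10_000
predictions = [0] * 9_965 + [1] * 35
print(f"{accuracy(predictions, labels):.2%}")  # 99.65%
```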

To predict the class of the digit stored in the file `sample.jpg`:
```
python -m mnistvit.predict --image-file sample.jpg
```
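
For a single image, a digit classifier outputs one score (logit) per class, and the predicted class is the index of the largest score; a softmax turns the scores into probabilities. The plain-Python sketch below illustrates that final step with made-up logits. It is an assumption about how such predictions are typically derived, not code taken from `mnistvit.predict`, and `predict_digit` is a hypothetical name.

```python
import math


def softmax(logits):
    """Numerically stable softmax: subtract the maximum before exponentiating."""
    largest = max(logits)
    exps = [math.exp(z - largest) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict_digit(logits):
    """Return (predicted digit, its probability) for ten class scores."""
    if len(logits) != 10:
        raise ValueError("expected one logit per digit 0-9")
    probs = softmax(logits)
    # max() returns the first index among ties, i.e. the smallest digit.
    digit = max(range(10), key=probs.__getitem__)
    return digit, probs[digit]


# Made-up logits in which class 7 has the highest score.
example_logits = [-1.2, 0.3, 0.8, -0.5, 0.1, -2.0, -0.7, 4.1, 1.0, -0.3]
digit, confidence = predict_digit(example_logits)
print(digit, round(confidence, 3))
```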

For hyperparameter optimization with default search parameters (requires the `tune` extra):
```
python -m mnistvit.tune
```
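
The tuning script relies on Ray Tune and Optuna. The core loop of hyperparameter optimization (sample a configuration, score it, keep the best) can be illustrated with plain random search. This sketch is a stand-in under stated assumptions: random search is a simpler technique than the samplers Optuna provides, and the search space and the toy `evaluate` function are hypothetical placeholders for training a ViT and measuring validation accuracy.

```python
import math
import random

# Hypothetical search space; mnistvit's actual tuning parameters may differ.
SEARCH_SPACE = {
    "learning_rate": ("log_uniform", 1e-5, 1e-2),
    "patch_size": ("choice", [4, 7]),
    "num_layers": ("choice", [2, 4, 6]),
}


def sample_config(rng):
    """Draw one configuration from SEARCH_SPACE."""
    config = {}
    for name, spec in SEARCH_SPACE.items():
        if spec[0] == "log_uniform":
            # Sample uniformly in log space so each decade is equally likely.
            low, high = math.log(spec[1]), math.log(spec[2])
            config[name] = math.exp(rng.uniform(low, high))
        else:
            config[name] = rng.choice(spec[1])
    return config


def evaluate(config):
    """Toy stand-in for "train and return validation accuracy" (higher is better)."""
    lr_penalty = abs(math.log10(config["learning_rate"]) + 3.0)  # best near 1e-3
    return 1.0 - 0.01 * lr_penalty - 0.002 * abs(config["num_layers"] - 4)


def random_search(num_trials, seed=0):
    """Evaluate num_trials random configurations and return the best one."""
    rng = random.Random(seed)
    trials = []
    for _ in range(num_trials):
        config = sample_config(rng)
        trials.append((evaluate(config), config))
    best_score, best_config = max(trials, key=lambda trial: trial[0])
    return best_score, best_config, trials


best_score, best_config, trials = random_search(num_trials=20)
print(f"best score {best_score:.4f} with {best_config}")
```

Model-based samplers such as those in Optuna use the scores of earlier trials to propose later ones, which usually finds good configurations with fewer expensive training runs than this uninformed sampling.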


## License

mnistvit is released under the GPLv3 license (GPL-3.0-only); see the [LICENSE](LICENSE) file.
