Metadata-Version: 2.4
Name: transformer_flows
Version: 0.0.8
Summary: Implementation of Apple ML's TARFlow in JAX.
Author-email: Jed Homer <jedhmr@gmail.com>
License: MIT License
        
        Copyright (c) [2024] [Jed Homer]
        
        Permission is hereby granted, free of charge, to any person obtaining a copy
        of this software and associated documentation files (the "Software"), to deal
        in the Software without restriction, including without limitation the rights
        to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
        copies of the Software, and to permit persons to whom the Software is
        furnished to do so, subject to the following conditions:
        
        The above copyright notice and this permission notice shall be included in all
        copies or substantial portions of the Software.
        
        THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
        IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
        FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
        AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
        LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
        OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
        SOFTWARE.
License-File: LICENSE
Keywords: deep-learning,generative-models,jax,machine-learning,normalizing-flows
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Natural Language :: English
Classifier: Programming Language :: Python :: 3
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.10
Requires-Dist: beartype
Requires-Dist: datasets
Requires-Dist: einops
Requires-Dist: equinox
Requires-Dist: ipykernel>=6.29.5
Requires-Dist: jax
Requires-Dist: jaxtyping
Requires-Dist: matplotlib
Requires-Dist: ml-collections
Requires-Dist: optax
Requires-Dist: pip>=25.1.1
Requires-Dist: tqdm
Description-Content-Type: text/markdown

<h1 align='center'>Transformer flows</h1>

Implementation of Apple ML's Transformer Flow (TARFlow) from [Normalizing Flows are Capable Generative Models](https://arxiv.org/pdf/2412.06329) in `jax` and `equinox`.
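At its core, TARFlow is an autoregressive affine flow over a sequence of tokens (image patches). A causal network predicts a shift `mu_t` and log-scale `alpha_t` for token `t` from tokens `< t`. The forward map `z_t = (x_t - mu_t) * exp(-alpha_t)` runs in one parallel pass and has log-determinant `-sum(alpha)`. The inverse (sampling) has to decode one token at a time. The sketch below shows a single such block under a big simplification: a strictly causal *linear* map stands in for the causal transformer, and the `tanh` bound on `alpha` is an assumption for stability. The names `causal_params`, `forward` and `inverse` are hypothetical and are not this package's API.

```python
import jax
import jax.numpy as jnp

def causal_params(x, W_mu, W_alpha):
    # Hypothetical stand-in for the causal transformer. x has shape (T, D).
    # A strictly lower-triangular mask means token t only sees tokens < t.
    T = x.shape[0]
    mask = jnp.tril(jnp.ones((T, T)), k=-1)
    h = mask @ x                       # (T, D) summary of previous tokens
    mu = h @ W_mu
    alpha = jnp.tanh(h @ W_alpha)      # bounded log-scale (assumption)
    return mu, alpha

def forward(x, W_mu, W_alpha):
    # Data -> latent in one parallel pass, plus log|det dz/dx|.
    mu, alpha = causal_params(x, W_mu, W_alpha)
    z = (x - mu) * jnp.exp(-alpha)
    return z, -alpha.sum()

def inverse(z, W_mu, W_alpha):
    # Latent -> data is sequential: token t needs x_{<t} already decoded.
    x = jnp.zeros_like(z)
    for t in range(z.shape[0]):
        mu, alpha = causal_params(x, W_mu, W_alpha)
        x = x.at[t].set(z[t] * jnp.exp(alpha[t]) + mu[t])
    return x
```

The first token passes through unchanged, because it has no predecessors. In the paper, several such blocks are stacked with the token order flipped between them, so every token is eventually conditioned on every other.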

Features:
- layers are constructed with `jax.vmap` and applied in the forward pass with `jax.lax.scan`, for fast compilation and execution,
- multi-device training, inference and sampling,
- score-based denoising step (see paper),
- conditioning via class embedding (for discrete class labels) or adaptive layer-normalisation (for continuous variables, like in DiT),
- array-typed to the teeth with `jaxtyping` and `beartype` for dependable execution.
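The "build with `vmap`, run with `scan`" pattern from the first bullet can be sketched in plain `jax`. This package builds its layers as `equinox` modules. Here, a dict of weight arrays for a residual MLP block is a hypothetical stand-in for a transformer layer, and `init_layer`, `init_stack` and `apply_stack` are illustrative names.

```python
import jax
import jax.numpy as jnp

def init_layer(key, dim):
    # Hypothetical layer: weights of a residual two-matrix MLP block.
    k1, k2 = jax.random.split(key)
    scale = 1.0 / jnp.sqrt(dim)
    return {
        "w1": jax.random.normal(k1, (dim, dim)) * scale,
        "w2": jax.random.normal(k2, (dim, dim)) * scale,
    }

def apply_layer(params, h):
    # Residual update; h has shape (tokens, dim).
    return h + jnp.tanh(h @ params["w1"]) @ params["w2"]

def init_stack(key, n_layers, dim):
    # vmap over per-layer keys gives every leaf a leading (n_layers,) axis,
    # so all layers are created by one batched call, not a Python loop.
    keys = jax.random.split(key, n_layers)
    return jax.vmap(lambda k: init_layer(k, dim))(keys)

@jax.jit
def apply_stack(stacked, h):
    # scan walks the leading layer axis, so the layer body is traced once
    # regardless of depth, which keeps compile time flat as layers are added.
    def step(carry, params):
        return apply_layer(params, carry), None
    h, _ = jax.lax.scan(step, h, stacked)
    return h
```

Unrolling the same stack in a Python loop gives the same result. The difference is that the loop traces one copy of the layer per depth step, so it compiles more slowly.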

To implement:
- [x] Guidance
- [x] Denoising
- [x] Mixed precision
- [x] EMA
- [x] AdaLayerNorm
- [x] Class embedding
- [x] Hyperparameter/model saving
- [x] Uniform and Gaussian noise for dequantisation
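The EMA item can be sketched generically. An exponential moving average keeps a smoothed copy of the weights, which is typically the copy used for sampling. The version below assumes the parameters form a pytree whose leaves are all arrays; for an `equinox` module you would first split out the array leaves. The name `ema_update` and the default decay are illustrative assumptions, not this package's settings.

```python
import jax
import jax.numpy as jnp

def ema_update(ema_params, params, decay=0.999):
    # Leaf by leaf: ema <- decay * ema + (1 - decay) * params.
    return jax.tree_util.tree_map(
        lambda e, p: decay * e + (1.0 - decay) * p, ema_params, params
    )
```

In a training loop you would initialise `ema_params` as a copy of the initial parameters and call `ema_update` after every optimiser step. A larger `decay` gives a smoother, slower-moving average.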

#### Usage 

```
pip install transformer-flows
```

```python
from transformer_flow import TransformerFlow
```

#### Samples

I haven't optimised anything here (the authors mention varying the variance of the noise used to dequantise the images), nor have I trained for very long. Slight artifacts from the dequantisation noise are visible.
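The dequantisation noise and the score-based denoising step fit together as follows. The model is trained on pixels with added noise of scale `sigma`. After sampling, the paper removes that noise with a single Tweedie-style step, `x = y + sigma**2 * grad log p(y)`, where `p` is the model's density over noisy data. The sketch below is a hedged illustration. The `[-1, 1]` pixel scaling, the noise levels and the helper names `dequantise` and `denoise` are assumptions. A closed-form 1-D Gaussian density stands in for the flow's log-likelihood so that the result can be checked exactly.

```python
import jax
import jax.numpy as jnp

def dequantise(key, x_uint8, sigma=0.05, kind="gaussian"):
    # Map integer pixels in [0, 255] to [-1, 1], then add continuous noise.
    x = x_uint8.astype(jnp.float32) / 127.5 - 1.0
    if kind == "uniform":
        # Uniform noise across one quantisation bin (bin width 2/255 here).
        return x + jax.random.uniform(key, x.shape, minval=-1 / 255, maxval=1 / 255)
    return x + sigma * jax.random.normal(key, x.shape)

def denoise(log_prob, y, sigma):
    # Tweedie's formula for Gaussian noise: E[x | y] = y + sigma^2 * grad log p(y).
    # log_prob must return a scalar (e.g. a summed log-likelihood).
    return y + sigma**2 * jax.grad(log_prob)(y)
```

Tweedie's formula assumes Gaussian noise, so the denoising step pairs with `kind="gaussian"`. With uniform noise it is only an approximation.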

<p align="center">
  <picture>
    <img src="assets/mnist_warp.gif" alt="MNIST samples (animation)">
  </picture>
</p>

<p align="center">
  <picture>
    <img src="assets/cifar10_warp.gif" alt="CIFAR-10 samples (animation)">
  </picture>
</p>

#### Citation 

```bibtex
@misc{zhai2024normalizingflowscapablegenerative,
      title={Normalizing Flows are Capable Generative Models}, 
      author={Shuangfei Zhai and Ruixiang Zhang and Preetum Nakkiran and David Berthelot and Jiatao Gu and Huangjie Zheng and Tianrong Chen and Miguel Angel Bautista and Navdeep Jaitly and Josh Susskind},
      year={2024},
      eprint={2412.06329},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2412.06329}, 
}
```