Metadata-Version: 2.4
Name: umapjax
Version: 0.0.3
Summary: UMAP, but optimized with jax.
Project-URL: Homepage, https://github.com/adamgayoso/umapjax
Project-URL: Source, https://github.com/adamgayoso/umapjax
Author: Adam Gayoso
Maintainer-email: Adam Gayoso <adamgayoso@gmail.com>
License: BSD 3-Clause License
        
        Copyright (c) 2026, Adam Gayoso
        All rights reserved.
        
        Redistribution and use in source and binary forms, with or without
        modification, are permitted provided that the following conditions are met:
        
        1. Redistributions of source code must retain the above copyright notice, this
           list of conditions and the following disclaimer.
        
        2. Redistributions in binary form must reproduce the above copyright notice,
           this list of conditions and the following disclaimer in the documentation
           and/or other materials provided with the distribution.
        
        3. Neither the name of the copyright holder nor the names of its
           contributors may be used to endorse or promote products derived from
           this software without specific prior written permission.
        
        THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
        AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
        IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
        DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
        FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
        DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
        SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
        CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
        OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
        OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
License-File: LICENSE
Classifier: Programming Language :: Python :: 3 :: Only
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Programming Language :: Python :: 3.14
Requires-Python: >=3.11
Requires-Dist: chex
Requires-Dist: jax
Requires-Dist: jaxlib
Requires-Dist: jaxtyping
Requires-Dist: umap-learn
Description-Content-Type: text/markdown

# umapjax

[![Tests][badge-tests]][tests]

[badge-tests]: https://img.shields.io/github/actions/workflow/status/adamgayoso/umapjax/test.yaml?branch=main

UMAP, but optimized with jax (experimental implementation).

`umapjax` inherits the API of [umap-learn](https://umap-learn.readthedocs.io/en/latest/). The `UmapJax` class is a drop-in replacement for `umap.UMAP`, with a few key differences:

1. `umapjax` does not support `densmap`.
2. `umapjax` does not support `output_metric` other than `euclidean`.

**Note:** `umapjax` does not exactly replicate `umap-learn`, so take care when interpreting results.

This package is intended for use with hardware accelerators such as GPUs and TPUs; there is no benefit to using `umapjax` on a CPU.
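Because `umapjax` is meant to run on an accelerator, it can help to confirm that jax actually sees one before fitting. The snippet below is a small sketch that uses only standard jax calls (`jax.default_backend()` and `jax.devices()`), not any `umapjax` API; which devices appear depends on how `jax`/`jaxlib` were installed (a CPU-only install reports `cpu`).

```python
import jax

# Platform jax will use by default, e.g. "cpu", "gpu", or "tpu"
backend = jax.default_backend()
print(f"jax default backend: {backend}")

# All devices visible to jax on this machine
devices = jax.devices()
print(devices)

if backend == "cpu":
    print("No accelerator found; umapjax offers no benefit on a CPU.")
```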

## Getting started

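```python
import numpy as np
import umapjax

# Example input: 1,000 samples with 50 features each
X = np.random.default_rng(0).normal(size=(1000, 50)).astype(np.float32)

model = umapjax.UmapJax(n_neighbors=15)
embedding = model.fit_transform(X)  # array of shape (n_samples, n_components)
```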

## Implementation details

The implementation used in `umapjax` closely follows the one in [umap-learn](https://umap-learn.readthedocs.io/en/latest/); however, rather than updating a single point per step, `umapjax` updates a batch of points in parallel using jax. In the original algorithm, edge weights control how often each edge is sampled; in `umapjax`, they instead weight the gradient of each point. If results look strange, try changing `n_epochs` or `batch_size`. The `batch_size` argument can also be used to tune performance on GPUs/TPUs.
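To make the batched, edge-weighted update concrete, here is a minimal jax sketch of one optimization step plus a simple epoch loop. It is not the `umapjax` source: the names `weighted_batch_step` and `optimize` are hypothetical, the gradient formulas follow umap-learn's Euclidean layout optimizer (curve parameters `a` and `b`, repulsion strength `gamma`, 5 negative samples per edge, per-component clipping to ±4, linearly decaying learning rate), the default `a`/`b` are approximately what umap-learn fits for `min_dist=0.1, spread=1.0`, and weighting the negative-sample repulsion by the edge weight is an assumption.

```python
import functools

import jax
import jax.numpy as jnp


def _clip(x, bound=4.0):
    # umap-learn clips each gradient component to [-4, 4]
    return jnp.clip(x, -bound, bound)


@functools.partial(jax.jit, static_argnames=("n_neg",))
def weighted_batch_step(embedding, heads, tails, weights, key,
                        a=1.577, b=0.895, gamma=1.0, alpha=1.0, n_neg=5):
    """Hypothetical: one parallel update over edges (heads[k], tails[k], weights[k])."""
    n_points = embedding.shape[0]
    head_pos = embedding[heads]                       # (B, dim)
    diff = head_pos - embedding[tails]                # (B, dim)
    d2 = jnp.sum(diff**2, axis=1)                     # squared distances, (B,)

    # Attractive coefficient; guard d2 == 0, where d2 ** (b - 1) would blow up.
    safe_d2 = jnp.maximum(d2, 1e-12)
    attr = (-2.0 * a * b * safe_d2 ** (b - 1.0)) / (a * safe_d2**b + 1.0)
    attr = jnp.where(d2 > 0.0, attr, 0.0)
    # Edge weight scales the gradient instead of setting a sampling frequency.
    attr_update = _clip(attr[:, None] * diff) * weights[:, None]

    # Repulsion from n_neg uniformly sampled points per edge.
    neg = jax.random.randint(key, (heads.shape[0], n_neg), 0, n_points)
    neg_diff = head_pos[:, None, :] - embedding[neg]  # (B, n_neg, dim)
    neg_d2 = jnp.sum(neg_diff**2, axis=-1)            # (B, n_neg)
    rep = 2.0 * gamma * b / ((0.001 + neg_d2) * (a * neg_d2**b + 1.0))
    # Assumption: repulsion is weighted by the same edge weight.
    rep_update = _clip(rep[..., None] * neg_diff).sum(axis=1) * weights[:, None]

    # Edges in a batch can share endpoints, so accumulate with scatter-add.
    delta = jnp.zeros_like(embedding)
    delta = delta.at[heads].add(attr_update + rep_update)
    delta = delta.at[tails].add(-attr_update)         # tail moves toward head
    return embedding + alpha * delta


def optimize(embedding, heads, tails, weights, n_epochs=200, batch_size=4096,
             initial_alpha=1.0, seed=0):
    """Hypothetical driver: shuffle edges each epoch and update them in batches."""
    key = jax.random.PRNGKey(seed)
    n_edges = heads.shape[0]
    for epoch in range(n_epochs):
        # Linear learning-rate decay, as in umap-learn
        alpha = initial_alpha * (1.0 - epoch / n_epochs)
        key, perm_key = jax.random.split(key)
        perm = jax.random.permutation(perm_key, n_edges)
        for start in range(0, n_edges, batch_size):
            idx = perm[start:start + batch_size]
            key, step_key = jax.random.split(key)
            embedding = weighted_batch_step(
                embedding, heads[idx], tails[idx], weights[idx], step_key, alpha=alpha
            )
    return embedding
```

Because the per-edge updates are summed with `.at[...].add`, a point with many edges in one batch can receive a large combined update in a single step, which is one plausible reason why `batch_size` and `n_epochs` can visibly change the result.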

## Installation

You need to have Python 3.11 or newer installed on your system.
If you don't have Python installed, we recommend installing [uv][].

There are two ways to install `umapjax`:

1. Install the latest release of `umapjax` from [PyPI][]:

   ```bash
   pip install umapjax
   ```

2. Install the latest development version:

   ```bash
   pip install git+https://github.com/adamgayoso/umapjax.git@main
   ```

## Release notes

See the [changelog][].

## Contact

If you find a bug, please use the [issue tracker][].

## Citation

> t.b.a

[uv]: https://github.com/astral-sh/uv
[scverse discourse]: https://discourse.scverse.org/
[issue tracker]: https://github.com/adamgayoso/umapjax/issues
[tests]: https://github.com/adamgayoso/umapjax/actions/workflows/test.yaml
[documentation]: https://umapjax.readthedocs.io
[changelog]: https://umapjax.readthedocs.io/en/latest/changelog.html
[api documentation]: https://umapjax.readthedocs.io/en/latest/api.html
[pypi]: https://pypi.org/project/umapjax
