Metadata-Version: 2.2
Name: jax-metallib
Version: 0.9.0
Summary: JAX backend for Apple M-series chips
Keywords: jax,metal,mps,apple,gpu,machine-learning
License: Apache-2.0
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: Apache Software License
Classifier: Operating System :: MacOS
Classifier: Programming Language :: Python :: 3
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Project-URL: Homepage, https://github.com/erfanzar/jax-metallib
Project-URL: Repository, https://github.com/erfanzar/jax-metallib
Requires-Python: >=3.11
Requires-Dist: jax<0.10,>=0.9.0
Requires-Dist: jaxlib<0.10,>=0.9.0
Description-Content-Type: text/markdown

# jax-silicon

`jax-silicon` is a PJRT plugin that lets JAX run on the GPUs of Apple Silicon
Macs through Apple Metal (MPS).

## Requirements

- macOS 13+ on Apple Silicon
- Python 3.11+
- `jax` and `jaxlib` 0.9.x

## Install

```bash
pip install jax-metallib
```

## Build From Source

```bash
brew install cmake ninja
uv pip install -e .
```

On the first build, any missing native dependencies are bootstrapped
automatically by running `scripts/setup_deps.sh`.

To disable the automatic bootstrap, pass the CMake option through `CMAKE_ARGS`:

```bash
CMAKE_ARGS="-DJAX_SILICON_AUTO_SETUP_DEPS=OFF" uv pip install -e .
```

## Use

The plugin still registers with JAX under the platform name `mps` (not
`silicon`), so select it with `JAX_PLATFORMS=mps`:

```bash
JAX_PLATFORMS=mps python -c "import jax; print(jax.devices())"
```
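For tooling that launches JAX from Python rather than a shell, the same platform selection can be expressed as a minimal sketch. The helpers `mps_env` and `print_mps_devices` are hypothetical names; only the `JAX_PLATFORMS=mps` setting comes from this README.

```python
from __future__ import annotations

import os
import subprocess
import sys


def mps_env(base: dict[str, str] | None = None) -> dict[str, str]:
    """Return a copy of `base` (default: os.environ) with JAX_PLATFORMS=mps."""
    env = dict(os.environ if base is None else base)
    # Restrict JAX to the plugin's `mps` platform, as in the shell example.
    env["JAX_PLATFORMS"] = "mps"
    return env


def print_mps_devices() -> None:
    """Run the shell one-liner's Python snippet in a child interpreter."""
    # JAX_PLATFORMS must be set before JAX starts, so a fresh process is used
    # instead of changing the environment of an already-running JAX program.
    subprocess.run(
        [sys.executable, "-c", "import jax; print(jax.devices())"],
        env=mps_env(),
        check=True,
    )
```

Calling `print_mps_devices()` mirrors the shell command. Within a single script, setting `os.environ["JAX_PLATFORMS"] = "mps"` before the first `import jax` should have the same effect.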

To load the plugin's native library from a custom path, set one of these
optional environment variables:

- `JAX_SILICON_LIBRARY_PATH`
- `JAX_MPS_LIBRARY_PATH` (legacy compatibility)
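The README does not say which variable wins if both are set. The hypothetical helper `plugin_library_env` below sidesteps that question by setting only `JAX_SILICON_LIBRARY_PATH` and removing the legacy variable. It also checks that the file exists before any JAX process starts. Treat it as a sketch for pointing a child process at a locally built library, not as the plugin's own lookup logic.

```python
from __future__ import annotations

import os
from pathlib import Path


def plugin_library_env(
    library_path: str | os.PathLike[str],
    base: dict[str, str] | None = None,
) -> dict[str, str]:
    """Return an environment copy that points the plugin at a custom library."""
    path = Path(library_path).expanduser().resolve()
    # Fail early here rather than during JAX plugin initialization.
    if not path.is_file():
        raise FileNotFoundError(f"plugin library not found: {path}")
    env = dict(os.environ if base is None else base)
    env["JAX_SILICON_LIBRARY_PATH"] = str(path)
    # Design choice of this sketch: drop the legacy override so only one
    # variable is in effect and precedence never matters.
    env.pop("JAX_MPS_LIBRARY_PATH", None)
    env["JAX_PLATFORMS"] = "mps"
    return env
```

Pass the result as `env=` to `subprocess.run` when launching the JAX program.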

## Test

```bash
uv run pytest
```

## Repository Layout

- `src/jax_plugins/silicon/`: Python plugin entrypoint
- `src/pjrt_plugin/`: C++/Objective-C++ PJRT implementation
- `scripts/setup_deps.sh`: dependency bootstrap script
- `tests/`: test suite
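JAX discovers PJRT plugins by scanning modules in the `jax_plugins` namespace package (in addition to the `jax_plugins` entry-point group). That is presumably why the Python entrypoint lives under `src/jax_plugins/silicon/`. The sketch below lists what the namespace scan would find. It assumes the installed module is named `jax_plugins.silicon`, and `discovered_jax_plugins` is a hypothetical helper name.

```python
from __future__ import annotations

import importlib
import pkgutil


def discovered_jax_plugins() -> list[str]:
    """List modules found under the `jax_plugins` namespace package."""
    try:
        namespace = importlib.import_module("jax_plugins")
    except ImportError:
        # No plugin is installed, so the namespace package does not exist.
        return []
    prefix = namespace.__name__ + "."
    return sorted(
        name for _, name, _ in pkgutil.iter_modules(namespace.__path__, prefix)
    )


if __name__ == "__main__":
    # With this package installed, the output should include `jax_plugins.silicon`.
    print(discovered_jax_plugins())
```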

## License

Apache-2.0.

This repository is a derivative of `tillahoffmann/jax-mps`.
