Metadata-Version: 2.4
Name: stablehlo-coreml
Version: 0.1.1
Summary: Convert StableHLO models into Apple Core ML format
Project-URL: Homepage, https://github.com/kasper0406/stablehlo-coreml
Project-URL: Issues, https://github.com/kasper0406/stablehlo-coreml/issues
Author-email: Kasper Nielsen <kasper0406@gmail.com>
License-File: LICENSE
Keywords: converter,coreml,coremltools,hlo,machinelearning,ml,neural,stablehlo,xla
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: End Users/Desktop
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: MacOS :: MacOS X
Classifier: Programming Language :: Python
Classifier: Topic :: Scientific/Engineering
Classifier: Topic :: Software Development
Requires-Python: >=3.9
Requires-Dist: coremltools>=9.0; python_version >= '3.10' and python_version <= '3.13'
Requires-Dist: jax==0.9.1
Requires-Dist: numpy~=2.0
Description-Content-Type: text/markdown

# stablehlo-coreml

Convert [StableHLO](https://github.com/openxla/stablehlo) models into Apple Core ML format.

StableHLO is the portability layer used by ML frameworks like [JAX](https://github.com/jax-ml/jax) and [PyTorch](https://pytorch.org/). This library converts StableHLO programs into Apple's [Core ML](https://developer.apple.com/documentation/coreml) format via [coremltools](https://github.com/apple/coremltools), enabling deployment on Apple hardware (iOS, macOS, etc.).

## Installation

```bash
pip install stablehlo-coreml
```

Requires Python 3.9+ and targets iOS/macOS 18+. Note that the `coremltools` dependency is only declared for Python 3.10 through 3.13.

## Supported Frameworks

Models can be exported from any framework that produces StableHLO:

- **JAX / Flax / Equinox** — via `jax.export`
- **PyTorch** — via [torchax](https://github.com/google/torchax) to trace the model into JAX, then `jax.export` to StableHLO

The test suite validates against a broad set of models, including full HuggingFace Transformers such as TinyLlama, T5, DistilBERT, GPT-2, BERT, and Whisper, as well as vision models like ResNet, EfficientNet, ViT, ConvNeXt, and more.

For a real-world example, see [gemma-coreml-chat](https://github.com/kasper0406/gemma-coreml-chat), which exports Google's Gemma 4 model to Core ML using this library.

## Converting a Model

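Conversion is a two-step process: `convert` translates a StableHLO module into a MIL program, and `ct.convert` turns that program into a Core ML model: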

```python
import coremltools as ct
from stablehlo_coreml.converter import convert
from stablehlo_coreml import DEFAULT_HLO_PIPELINE

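# hlo_module: an MLIR module containing StableHLO (see the JAX example below)
mil_program = convert(hlo_module, minimum_deployment_target=ct.target.iOS18)
# Turn the MIL program into a Core ML model using the library's pass pipeline
cml_model = ct.convert(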
    mil_program,
    source="milinternal",
    minimum_deployment_target=ct.target.iOS18,
    pass_pipeline=DEFAULT_HLO_PIPELINE,
)
```

### Obtaining a StableHLO Module from JAX

```python
import jax
from jax._src.lib.mlir import ir
from jax._src.interpreters import mlir as jax_mlir
from jax.export import export

import jax.numpy as jnp

def jax_function(a, b):
    return jnp.einsum("ij,jk -> ik", a, b)

context = jax_mlir.make_ir_context()
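# Example arguments; only their shapes and dtypes matter for export
input_shapes = (jnp.zeros((2, 4)), jnp.zeros((4, 3)))
jax_exported = export(jax.jit(jax_function))(*input_shapes)
# Parse the exported StableHLO text into an MLIR module that `convert` accepts
hlo_module = ir.Module.parse(jax_exported.mlir_module(), context=context)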
```

The JAX example additionally requires the `absl-py` and `flatbuffers` packages (`pip install absl-py flatbuffers`).
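Before handing a module to the converter, it can help to inspect what JAX actually exported. The sketch below uses only the public `jax.export` API (`Exported.in_avals`, `Exported.out_avals`, `Exported.mlir_module()`) and the same matmul function as above; it is a sanity check, not a required step.

```python
import jax
import jax.numpy as jnp
from jax.export import export


def jax_function(a, b):
    return jnp.einsum("ij,jk -> ik", a, b)


jax_exported = export(jax.jit(jax_function))(jnp.zeros((2, 4)), jnp.zeros((4, 3)))

# Abstract signatures (shape + dtype) recorded at export time
print(jax_exported.in_avals)
print(jax_exported.out_avals)

# The textual StableHLO that ir.Module.parse consumes in the example above
module_text = jax_exported.mlir_module()
print(module_text)
```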

## Dynamic / Symbolic Shapes

JAX models exported with symbolic dimensions are supported. Symbolic dimensions
are propagated automatically through `GetDimensionSizeOp`,
`DynamicBroadcastInDimOp`, `DynamicIotaOp`, and shape-assertion `CustomCallOp`s,
producing Core ML models with flexible input shapes.

```python
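import jax
import jax.numpy as jnp
from jax.export import export, symbolic_shape


def jax_function(a, b):
    return jnp.einsum("ij,jk -> ik", a, b)


# "batch" is a symbolic dimension; all other dimensions are fixed
jax_exported = export(jax.jit(jax_function))(
    jax.ShapeDtypeStruct(symbolic_shape("batch, 4"), jnp.float32),
    jax.ShapeDtypeStruct((4, 3), jnp.float32),
)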
```
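A symbolic export can be exercised directly in JAX before conversion. The sketch below calls the exported artifact with `Exported.call` at a few arbitrarily chosen batch sizes to confirm that the single export accepts all of them; it reuses the matmul function from above and assumes only the public `jax.export` API.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.export import export, symbolic_shape


def jax_function(a, b):
    return jnp.einsum("ij,jk -> ik", a, b)


jax_exported = export(jax.jit(jax_function))(
    jax.ShapeDtypeStruct(symbolic_shape("batch, 4"), jnp.float32),
    jax.ShapeDtypeStruct((4, 3), jnp.float32),
)

# The leading input dimension is a symbolic expression; the others are ints
print(jax_exported.in_avals[0].shape, jax_exported.out_avals[0].shape)

# The same exported artifact runs for any batch size >= 1
b = np.ones((4, 3), dtype=np.float32)
for batch in (1, 5, 32):
    a = np.ones((batch, 4), dtype=np.float32)
    out = jax_exported.call(a, b)
    assert out.shape == (batch, 3)
```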

When converting to a Core ML model, specify a `ct.RangeDim` for each symbolic
dimension so the model accepts a range of sizes at inference time:

```python
cml_model = ct.convert(
    mil_program,
    source="milinternal",
    minimum_deployment_target=ct.target.iOS18,
    pass_pipeline=DEFAULT_HLO_PIPELINE,
    inputs=[
        # RangeDim(lower_bound, upper_bound, default): batch sizes 1 to 2048
        ct.TensorType(name="_arg0", shape=(ct.RangeDim(1, 2048, 1), 4)),
        ct.TensorType(name="_arg1", shape=(4, 3)),
    ],
)
```

See [`tests/test_symbolic_shapes.py`](tests/test_symbolic_shapes.py) for
symbolic matmul, batched einsum, and multi-axis patterns (for example,
transformer-style projections).
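To check which of the dynamic-shape ops listed at the start of this section a particular export contains, search the exported StableHLO text. The function `row_index_plus_one` below is a made-up example that builds an iota and a broadcast over the symbolic shape; the loop simply reports which op names appear, and the exact set depends on the JAX version.

```python
import jax
import jax.numpy as jnp
from jax.export import export, symbolic_shape


def row_index_plus_one(x):
    # Iota over x.shape, whose row count is the symbolic "batch" dimension
    rows = jax.lax.broadcasted_iota(x.dtype, x.shape, 0)
    # jnp.ones over a symbolic shape broadcasts a scalar to (batch, 4)
    return x + rows + jnp.ones(x.shape, x.dtype)


exported = export(jax.jit(row_index_plus_one))(
    jax.ShapeDtypeStruct(symbolic_shape("batch, 4"), jnp.float32),
)
module_text = exported.mlir_module()

# Report which dynamic-shape ops the converter will encounter
for op in (
    "stablehlo.get_dimension_size",
    "stablehlo.dynamic_iota",
    "stablehlo.dynamic_broadcast_in_dim",
    "stablehlo.custom_call",
):
    print(op, op in module_text)
```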

## Examples in the Test Suite

The [`tests/`](tests/) directory has end-to-end export and conversion examples:

- **PyTorch (torchax)** — [`tests/pytorch/test_pytorch.py`](tests/pytorch/test_pytorch.py): `export_to_stablehlo_module`, HuggingFace Transformers, and torchvision models.
- **JAX** — [`tests/test_jax.py`](tests/test_jax.py)
- **Flax / Equinox** — [`tests/test_flax.py`](tests/test_flax.py), [`tests/test_equinox.py`](tests/test_equinox.py)

## Development

* `coremltools` supports Python only up to 3.13, so do not run hatch with a newer Python version.
  The interpreter hatch uses can be set with, e.g., `export HATCH_PYTHON=python3.13`.
* Run the tests with `hatch run test:pytest tests`.
