Metadata-Version: 2.3
Name: jax-optix
Version: 0.1.1
Summary: Zero-overhead functional lensing for JAX PyTrees
Project-URL: Homepage, https://github.com/jonkhler/optix
Author-email: Jonas Köhler <jonas.koehler.ks@gmail.com>
License-Expression: MIT
License-File: LICENSE.md
Keywords: equinox,functional,jax,lens,optics
Requires-Python: >=3.12
Requires-Dist: equinox>=0.11.8
Requires-Dist: jax>=0.4.35
Requires-Dist: pytest>=8.3.3
Description-Content-Type: text/markdown

# Optix 🔍

A functional lensing library for JAX and Equinox that lets you focus on and update nested values inside PyTree structures without mutating them. Under compilation, Optix produces the same HLO as equivalent hand-written access, so lenses add no runtime overhead.

## Features

- Type-safe lenses for any JAX PyTree structure
- Zero runtime overhead (generates HLO identical to direct access)
- Intuitive API for accessing and modifying nested values
- Complete static typing support

## Example

```python
import equinox as eqx
import jax
import jax.numpy as jnp
from optix import focus


# Example PyTree structure built from Equinox modules
class NestedStruct(eqx.Module):
    y: jax.Array


class MyStruct(eqx.Module):
    x: jax.Array
    nested: NestedStruct


# Create a nested PyTree structure
data = MyStruct(
    x=jnp.array([1.0, 2.0]),
    nested=NestedStruct(y=jnp.array(3.0)),
)

# Focus on a nested value and apply a function to it; returns a new MyStruct
result = focus(data).at(lambda s: s.nested.y).apply(jnp.square)

print(result.x)         # [1. 2.]  (unchanged)
print(result.nested.y)  # 9.0
```

## Installation

```bash
pip install jax-optix
```

The distribution is named `jax-optix`; the import name is `optix`.

## License

MIT License

## Credits

Special thanks to [Patrick Kidger](https://kidger.site/) for providing helpful hints and the [Equinox](https://github.com/patrick-kidger/equinox) library, which this project builds upon.
