Metadata-Version: 2.4
Name: jax-bounded-while
Version: 0.1
Summary: Bounded while loop in JAX.
Project-URL: Bug Tracker, https://github.com/GalacticDynamics/jax-bounded-while/issues
Project-URL: Changelog, https://github.com/GalacticDynamics/jax-bounded-while/releases
Project-URL: Discussions, https://github.com/GalacticDynamics/jax-bounded-while/discussions
Project-URL: Homepage, https://github.com/GalacticDynamics/jax-bounded-while
Author-email: Nathaniel Starkman <nstarman@users.noreply.github.com>
License: Copyright 2024 Nathaniel Starkman
        
        Permission is hereby granted, free of charge, to any person obtaining a copy of
        this software and associated documentation files (the "Software"), to deal in
        the Software without restriction, including without limitation the rights to
        use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
        of the Software, and to permit persons to whom the Software is furnished to do
        so, subject to the following conditions:
        
        The above copyright notice and this permission notice shall be included in all
        copies or substantial portions of the Software.
        
        THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
        IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
        FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
        AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
        LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
        OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
        SOFTWARE.
License-File: LICENSE
Classifier: Development Status :: 1 - Planning
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3 :: Only
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Programming Language :: Python :: 3.14
Classifier: Topic :: Scientific/Engineering
Classifier: Typing :: Typed
Requires-Python: >=3.10
Requires-Dist: equinox>=0.11.0
Requires-Dist: jax>=0.4.20
Description-Content-Type: text/markdown

<h1 align='center'> jax-bounded-while </h1>
<h3 align="center">Bounded while loop in JAX.</h3>

<p align="center">
    <a href="https://pypi.org/project/jax-bounded-while/"> <img alt="PyPI version" src="https://img.shields.io/pypi/v/jax-bounded-while" /> </a>
    <a href="https://pypi.org/project/jax-bounded-while/"> <img alt="PyPI platforms" src="https://img.shields.io/pypi/pyversions/jax-bounded-while" /> </a>
    <a href="https://github.com/GalacticDynamics/jax-bounded-while/actions"> <img alt="Actions status" src="https://github.com/GalacticDynamics/jax-bounded-while/workflows/CI/badge.svg" /> </a>
</p>

This is a micro-package containing a single function, `bounded_while_loop`: a reverse-mode-friendly (i.e. compatible with `jax.grad`), bounded `while_loop` implemented via `lax.scan`.
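The idea behind a scan-based bounded loop can be sketched as follows. This is a minimal illustration under our own assumptions, not the package's actual implementation. The name `bounded_while_sketch` is hypothetical, and the masking strategy is one common way to do it: run `body_fun` on every step and use `jnp.where` to keep the old carry once `cond_fun` turns false. Because `lax.scan` has a static trip count (`max_steps`), JAX can differentiate through it in reverse mode.

```python
import jax
import jax.numpy as jnp
from jax import lax


def bounded_while_sketch(cond_fun, body_fun, init_val, max_steps):
    """Hypothetical sketch: apply `body_fun` while `cond_fun` holds, at most `max_steps` times."""

    def step(carry, _):
        # Evaluate the loop condition on the current carry.
        keep_going = cond_fun(carry)
        # scan has a fixed trip count, so the body is evaluated on every step...
        new_carry = body_fun(carry)
        # ...but the update is only kept while the condition is still true.
        # Once it is false the carry is frozen, so the condition stays false.
        carry = jax.tree_util.tree_map(
            lambda new, old: jnp.where(keep_going, new, old), new_carry, carry
        )
        return carry, None

    final, _ = lax.scan(step, init_val, xs=None, length=max_steps)
    return final


def f(x0):
    # Double x while it is below 10: 1.5 -> 3 -> 6 -> 12.
    return bounded_while_sketch(lambda x: x < 10.0, lambda x: x * 2.0, x0, max_steps=10)


x0 = jnp.asarray(1.5)
value = f(x0)
slope = jax.grad(f)(x0)  # reverse mode works: d(8 * x0)/dx0 = 8
print(float(value), float(slope))  # 12.0 8.0
```

A design note on this sketch: the body runs all `max_steps` times regardless of when the condition fails, so `max_steps` should be a reasonably tight bound. Also, with `jnp.where` masking, a discarded step that produces `inf` or `nan` can still contaminate gradients.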

## Installation


```bash
pip install jax-bounded-while
```

## Examples

Simple loop over a scalar:

```python
import jax.numpy as jnp
from jax_bounded_while import bounded_while_loop


def cond_fn(x):
    return x < 5


def body_fn(x):
    return x + 1


result = bounded_while_loop(cond_fn, body_fn, jnp.asarray(0), max_steps=10)
print(result)  # 5
```
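For comparison, a loop of this kind can also be written with JAX's built-in `jax.lax.while_loop`, which needs no step bound. JAX supports forward-mode differentiation (`jax.jvp`) through it but raises a `ValueError` under reverse mode (`jax.grad`), which is the gap a scan-based bounded loop is meant to fill. The doubling loop below is an illustrative example of our own, not taken from this package.

```python
import jax
import jax.numpy as jnp
from jax import lax


def doubling_loop(x0):
    # Unbounded loop: keep doubling x while it is below 10.
    return lax.while_loop(lambda x: x < 10.0, lambda x: x * 2.0, x0)


x0 = jnp.asarray(1.5)
value = doubling_loop(x0)  # 1.5 -> 3 -> 6 -> 12

# Forward mode (JVP) through lax.while_loop is supported.
_, tangent = jax.jvp(doubling_loop, (x0,), (jnp.asarray(1.0),))

# Reverse mode (jax.grad) through lax.while_loop raises a ValueError.
try:
    jax.grad(doubling_loop)(x0)
    reverse_ok = True
except ValueError:
    reverse_ok = False

print(float(value), float(tangent), reverse_ok)  # 12.0 8.0 False
```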

PyTree carry (tuple):

```python
import jax.numpy as jnp
from jax_bounded_while import bounded_while_loop


def cond_fn(state):
    x, _ = state
    return x < 3


def body_fn(state):
    x, y = state
    return x + 1, y * 2


result = bounded_while_loop(
    cond_fn, body_fn, (jnp.asarray(0), jnp.asarray(1)), max_steps=5
)
x, y = result
print(x, y)  # 3 8
```
