Metadata-Version: 2.4
Name: diffsol-jax
Version: 0.0.1
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: POSIX :: Linux
Classifier: Operating System :: MacOS :: MacOS X
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Programming Language :: Rust
Classifier: Topic :: Scientific/Engineering :: Mathematics
Requires-Dist: equinox>=0.13.7
Requires-Dist: jax>=0.4.31
License-File: LICENSE
Summary: JAX bindings for the diffsol Rust ODE solver, callable inside jax.jit and differentiable with jax.grad.
Keywords: jax,ode,differential-equations,diffsol,scientific-computing,autodiff
Author-email: Jack Montgomery <jackbmontgomery2001@gmail.com>
License: MIT
Requires-Python: >=3.11
Description-Content-Type: text/markdown; charset=UTF-8; variant=GFM
Project-URL: Homepage, https://github.com/jackbmontgomery/diffsol-jax
Project-URL: Issues, https://github.com/jackbmontgomery/diffsol-jax/issues
Project-URL: Repository, https://github.com/jackbmontgomery/diffsol-jax

# diffsol-jax

JAX wrapper around [diffsol](https://github.com/martinjrobins/diffsol), a Rust ODE solver library.
Exposes diffsol's ODE solvers via `jax.ffi` so they can be called from inside `jax.jit` and
differentiated with `jax.grad`.

You write the right-hand side (RHS) of the ODE as a Python function; diffsol-jax lowers it to a
[DiffSL](https://martinjrobins.github.io/diffsl/) source string, which is compiled on the first call.
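
The RHS is lowered to DiffSL rather than executed by XLA, so it can only use operations the lowerer understands (see Limitations). One way to vet a candidate RHS before building an `ODEProblem` is to trace it with `jax.make_jaxpr` and look for the primitives the Limitations section names as unsupported. This is a minimal sketch. It assumes, as those primitive names suggest, that the lowerer works from the traced jaxpr. `primitive_names`, `unsupported_primitives` and `linear_rhs` are hypothetical helpers, not part of diffsol-jax.

```python
import jax
import jax.numpy as jnp

# Primitives the Limitations section names as unsupported. That list is not
# exhaustive, so an empty result does not guarantee the lowerer accepts fn.
KNOWN_UNSUPPORTED = {"dot_general", "reduce_sum", "concatenate"}


def primitive_names(fn, *example_args):
    """Return the set of primitive names in the jaxpr traced from fn."""
    names = set()

    def walk(jaxpr):
        for eqn in jaxpr.eqns:
            names.add(eqn.primitive.name)
            # Recurse into nested jaxprs (e.g. bodies of jit-wrapped calls).
            for value in eqn.params.values():
                for sub in value if isinstance(value, (tuple, list)) else (value,):
                    if hasattr(sub, "jaxpr") and hasattr(sub.jaxpr, "eqns"):
                        walk(sub.jaxpr)  # ClosedJaxpr
                    elif hasattr(sub, "eqns"):
                        walk(sub)  # plain Jaxpr

    walk(jax.make_jaxpr(fn)(*example_args).jaxpr)
    return names


def unsupported_primitives(fn, *example_args):
    """Sorted list of known-unsupported primitives that fn uses."""
    return sorted(primitive_names(fn, *example_args) & KNOWN_UNSUPPORTED)


def lotka_volterra(t, y, p):
    x, yy = y[0], y[1]
    alpha, beta, delta, gamma = p[0], p[1], p[2], p[3]
    return (alpha * x - beta * x * yy, delta * x * yy - gamma * yy)


def linear_rhs(t, y, p):
    # A matrix-vector product, the building block of a neural-ODE layer.
    return p.reshape(2, 2) @ y


t0 = 0.0
y0 = jnp.array([1.0, 0.5])
params = jnp.array([1.5, 1.0, 0.75, 3.0])
print(unsupported_primitives(lotka_volterra, t0, y0, params))  # []
print(unsupported_primitives(linear_rhs, t0, y0, params))      # ['dot_general']
```

The elementwise Lotka-Volterra RHS traces to primitives such as `mul` and `sub`. The matrix-vector version introduces `dot_general`, which is why neural ODEs are out of reach for now.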

## Architecture

```
Python rhs fn  ->  DiffSL string  ->  XLA FFI call
                                          |
                                       C++ shim
                                          |
                                       diffsol
```

## Example usage

```python
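import jax
import jax.numpy as jnp
from diffsol_jax import ODEProblem

# diffsol-jax is f64-only, and JAX creates float32 arrays unless 64-bit mode is enabled.
jax.config.update("jax_enable_x64", True)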

def lotka_volterra(t, y, p):
    x, yy = y[0], y[1]
    alpha, beta, delta, gamma = p[0], p[1], p[2], p[3]
    return (alpha * x - beta * x * yy, delta * x * yy - gamma * yy)

params = jnp.array([1.5, 1.0, 0.75, 3.0])
y0     = jnp.array([1.0, 0.5])

problem = ODEProblem(lotka_volterra, y0=y0, params=params)

t_span = jnp.array([0.0, 10.0])
ts, ys = jax.jit(lambda p: problem.solve(p, t_span))(params)
# ts: float64[200],  ys: float64[200, 2]
```

## Limitations

- CPU only, f64 only.
- No dedicated `vmap` batching rule. With `vmap_method="sequential"`, batched calls give correct results
  but are solved one element at a time in a loop.
- The DiffSL lowerer handles elementwise operations, which covers common systems such as Lotka-Volterra
  and Lorenz. Primitives such as `dot_general`, `reduce_sum`, and `concatenate` are not yet supported,
  which rules out neural ODEs for now.
- Differentiation uses JAX-level forward sensitivities (no separate adjoint pass).
- DiffSL modules are JIT-compiled with Cranelift and cached per source string: the first call compiles,
  and subsequent calls reuse the compiled module, updating parameters via `set_params`.
- Does not work with `jax.pmap`.
- Higher-order derivatives are not supported.
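
The forward-sensitivity idea mentioned above can be illustrated in a few lines of plain JAX. The sketch integrates the state `y` together with its sensitivity matrix `S = dy/dp`, using `dS/dt = (df/dy) S + df/dp` with `S(0) = 0`, since `y0` does not depend on `p`. It is a stand-in that does not call diffsol. Fixed-step RK4 is used only for brevity, and `rk4_step`, `integrate`, `solve` and `solve_with_sensitivities` are hypothetical helpers, not the package's API.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import tree_map

# Match diffsol-jax's f64-only setting.
jax.config.update("jax_enable_x64", True)


def lotka_volterra(t, y, p):
    # Same RHS as the example above, but returning an array so Jacobians are matrices.
    x, yy = y[0], y[1]
    alpha, beta, delta, gamma = p[0], p[1], p[2], p[3]
    return jnp.stack([alpha * x - beta * x * yy, delta * x * yy - gamma * yy])


def rk4_step(f, t, state, h):
    # One classical RK4 step; tree_map lets `state` be an array or a tuple of arrays.
    def shift(base, k, c):
        return tree_map(lambda b, d: b + c * d, base, k)

    k1 = f(t, state)
    k2 = f(t + h / 2, shift(state, k1, h / 2))
    k3 = f(t + h / 2, shift(state, k2, h / 2))
    k4 = f(t + h, shift(state, k3, h))
    return tree_map(
        lambda s, a, b, c, d: s + h / 6 * (a + 2 * b + 2 * c + d),
        state, k1, k2, k3, k4,
    )


def integrate(f, state0, t0, t1, n_steps):
    # Fixed-step integration from t0 to t1; returns the state at t1.
    h = (t1 - t0) / n_steps
    return jax.lax.fori_loop(
        0, n_steps, lambda i, s: rk4_step(f, t0 + i * h, s, h), state0
    )


def solve(rhs, y0, p, t0=0.0, t1=10.0, n_steps=2000):
    return integrate(lambda t, y: rhs(t, y, p), y0, t0, t1, n_steps)


def solve_with_sensitivities(rhs, y0, p, t0=0.0, t1=10.0, n_steps=2000):
    # Augmented system: dy/dt = f(t, y, p), dS/dt = (df/dy) S + df/dp.
    def augmented(t, state):
        y, S = state
        dfdy = jax.jacfwd(rhs, argnums=1)(t, y, p)  # shape (n_y, n_y)
        dfdp = jax.jacfwd(rhs, argnums=2)(t, y, p)  # shape (n_y, n_p)
        return rhs(t, y, p), dfdy @ S + dfdp

    # y0 does not depend on p, so the initial sensitivity is zero.
    S0 = jnp.zeros((y0.shape[0], p.shape[0]), dtype=y0.dtype)
    return integrate(augmented, (y0, S0), t0, t1, n_steps)


params = jnp.array([1.5, 1.0, 0.75, 3.0])
y0 = jnp.array([1.0, 0.5])
y_T, S_T = solve_with_sensitivities(lotka_volterra, y0, params)  # S_T: shape (2, 4)
```

The RK4 stages for `S` are exactly the linearization of the stages for `y`. As a result, `S_T` should agree with `jax.jacfwd` applied to `solve` up to rounding error, which makes a convenient self-check.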

