Metadata-Version: 2.4
Name: mlx-bessel
Version: 0.1.0
Summary: GPU-accelerated spherical Bessel functions for Apple Silicon using MLX
Author-email: Sheng-Kai Huang <akai@fawstudio.com>
License-Expression: MIT
Project-URL: Homepage, https://github.com/akaiHuang/mlx-bessel
Keywords: bessel,spherical,gpu,mlx,apple-silicon,chebyshev
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Science/Research
Classifier: Operating System :: MacOS
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Scientific/Engineering :: Mathematics
Classifier: Topic :: Scientific/Engineering :: Physics
Requires-Python: >=3.10
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: numpy
Requires-Dist: mlx>=0.20.0
Requires-Dist: scipy
Provides-Extra: benchmark
Requires-Dist: scipy; extra == "benchmark"
Provides-Extra: test
Requires-Dist: scipy; extra == "test"
Requires-Dist: pytest; extra == "test"
Dynamic: license-file

# mlx-bessel

GPU-accelerated spherical Bessel functions j_l(x) and j_l'(x) for Apple Silicon, using [MLX](https://github.com/ml-explore/mlx).

## What it does

Evaluates spherical Bessel functions of the first kind, j_l(x), and their derivatives j_l'(x) on the Apple Silicon GPU via piecewise Chebyshev interpolation. The coefficient table is built once on the CPU (a hybrid of forward recurrence and scipy), then stored as GPU tensors for fast repeated evaluation.
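
The CPU-side table build mentioned above relies on the standard upward recurrence `j_{l+1}(x) = (2l+1)/x * j_l(x) - j_{l-1}(x)`, which is accurate only when x is large compared with l. The NumPy sketch below illustrates that idea and is not the package's code. `spherical_jn_upward` and `stable_mask` are hypothetical names, the 1.5 factor is the cutoff listed under Method, and the scipy transition zone is left out.

```python
import numpy as np


def spherical_jn_upward(l_max: int, x) -> np.ndarray:
    """Return j_0..j_{l_max} at each x (> 0) via the upward recurrence.

    Shape: (l_max + 1, len(x)). Rows are only trustworthy where x is
    comfortably larger than l; below that, rounding errors grow like y_l.
    """
    x = np.atleast_1d(np.asarray(x, dtype=np.float64))
    out = np.empty((l_max + 1, x.size))
    out[0] = np.sin(x) / x                                  # j_0
    if l_max >= 1:
        out[1] = np.sin(x) / x**2 - np.cos(x) / x           # j_1
    for l in range(1, l_max):
        out[l + 1] = (2 * l + 1) / x * out[l] - out[l - 1]  # j_{l+1}
    return out


def stable_mask(l_max: int, x, factor: float = 1.5) -> np.ndarray:
    """True where x > factor * l, i.e. where the upward recurrence is used."""
    x = np.atleast_1d(np.asarray(x, dtype=np.float64))
    ells = np.arange(l_max + 1)[:, None]
    return x[None, :] > factor * ells


# Check l = 2 against its closed form in the stable regime.
x = np.array([5.0, 10.0, 20.0])
j2_closed = (3 / x**2 - 1) * np.sin(x) / x - 3 * np.cos(x) / x**2
print(np.allclose(spherical_jn_upward(2, x)[2], j2_closed))   # → True

# Far outside the stable regime (l = 30, x = 2) the recurrence blows up,
# although the true j_30(2) is smaller than 1e-32.
print(abs(spherical_jn_upward(30, [2.0])[30, 0]) > 1.0)       # → True
```

In a full hybrid build, entries where `stable_mask` is False would instead be filled from scipy in the transition zone or set to zero in the evanescent regime.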

## Installation

```bash
pip install mlx-bessel
```

Requires macOS with Apple Silicon (M1/M2/M3/M4).

## Quick start

```python
import numpy as np
from mlx_bessel import BesselTable

ells = np.arange(0, 2001)               # multipole values
table = BesselTable(ells, x_max=5500)    # one-time table build on CPU (see Performance)
x = np.linspace(1.0, 5000.0, 10000)     # evaluation points

jl = table.eval_jl(x)                   # shape (2001, 10000), on GPU
jl, jlp = table.eval_jl_jlp(x)          # j_l and j_l' together
```

Results are returned as `mlx.core.array`. Convert to numpy with `np.array(jl)`.

## Performance

Benchmarked on Apple M1 Max. Median of 5 runs, warm-up excluded.

| N_ell | N_x   | scipy    | GPU eval | Speedup (eval) | Speedup (incl. build) |
|------:|------:|---------:|---------:|---------------:|----------------------:|
|   100 |  1000 |   0.05 s | 0.001 s  |            34x |                  0.8x |
|   100 |  5000 |   0.22 s | 0.004 s  |            57x |                  3.8x |
|   200 |  5000 |   1.20 s | 0.007 s  |           174x |                  3.9x |
|   200 | 10000 |   2.35 s | 0.013 s  |           181x |                  7.7x |
|   500 |  5000 |   8.95 s | 0.016 s  |           557x |                  1.7x |
|   500 | 10000 |  17.82 s | 0.031 s  |           567x |                  3.4x |
|   525 | 10000 |  19.63 s | 0.034 s  |           579x |                  3.1x |

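The table build is a one-time CPU cost (~0.05 s for 100 ells, ~5 s for 500 ells). Once built, the table can be reused for any x-array within [0, x_max], and each evaluation runs entirely on the GPU, more than 500x faster than scipy for the largest cases above. Including the build, the end-to-end speedup for a single evaluation is between 0.8x and 7.7x, so the GPU path pays off most when one table serves many evaluations.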
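
The trade-off between the one-time build and the per-call savings reduces to simple arithmetic. The sketch below is not part of mlx-bessel: `speedup_including_build` and `break_even_calls` are hypothetical helpers, and the inputs are the rounded timings from the table plus the approximate build times quoted above.

```python
import math


def speedup_including_build(t_build: float, t_scipy: float, t_eval: float, n_calls: int = 1) -> float:
    """End-to-end speedup when one table is built once and evaluated n_calls times."""
    return (n_calls * t_scipy) / (t_build + n_calls * t_eval)


def break_even_calls(t_build: float, t_scipy: float, t_eval: float) -> int:
    """Smallest n with t_build + n * t_eval < n * t_scipy (strict inequality)."""
    if t_eval >= t_scipy:
        raise ValueError("GPU evaluation must be faster than scipy to break even")
    return math.floor(t_build / (t_scipy - t_eval)) + 1


# Rounded timings in seconds; build times are the approximate values quoted above.
print(break_even_calls(t_build=0.05, t_scipy=0.05, t_eval=0.001))    # 100 ells, 1000 x  → 2
print(break_even_calls(t_build=5.0, t_scipy=17.82, t_eval=0.031))    # 500 ells, 10000 x → 1
print(round(speedup_including_build(0.05, 0.05, 0.001, n_calls=10), 1))  # → 8.3
```

With these rounded inputs, the smallest case needs two evaluations to beat scipy, which is consistent with its sub-1x single-call figure in the table, while the 500-ell case pays off on the first call.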

## Accuracy

Tested against `scipy.special.spherical_jn` across l = 0..2000, x = 0.5..5000 (155 sampled ells, 5000 x-points):

| Metric                            | Value    |
|-----------------------------------|----------|
| Max absolute error                | 6.3e-07  |
| Median absolute error             | 1.1e-08  |
| Max relative error (\|j_l\| > 1e-5)  | 1.1e-02  |
| Median relative error             | 5.1e-05  |
| P99 relative error                | 2.4e-03  |
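
The metrics in this table can be computed from any pair of arrays. A minimal NumPy sketch follows; `accuracy_metrics` is a hypothetical helper, and applying the |j_l| > 1e-5 filter to every relative-error row (the table states it only for the maximum) is an assumption.

```python
import numpy as np


def accuracy_metrics(approx, ref, rel_threshold=1e-5):
    """Summarize |approx - ref| in the same terms as the accuracy table.

    Relative errors are computed only where |ref| > rel_threshold
    (assumed here for every relative-error row).
    """
    approx = np.asarray(approx, dtype=np.float64)
    ref = np.asarray(ref, dtype=np.float64)
    abs_err = np.abs(approx - ref)
    mask = np.abs(ref) > rel_threshold               # skip values too close to zero
    rel_err = abs_err[mask] / np.abs(ref[mask])
    return {
        "max_abs": float(abs_err.max()),
        "median_abs": float(np.median(abs_err)),
        "max_rel": float(rel_err.max()),
        "median_rel": float(np.median(rel_err)),
        "p99_rel": float(np.percentile(rel_err, 99)),
    }


ref = np.array([1.0, 2.0, -4.0, 1e-6])
approx = ref + np.array([0.1, -0.1, 0.2, 1e-7])
m = accuracy_metrics(approx, ref)
print(round(m["max_rel"], 6), round(m["median_rel"], 6))  # → 0.1 0.05
```

With the quick-start arrays, one would pass `np.array(jl)` as `approx` and `scipy.special.spherical_jn(ells[:, None], x[None, :])` as `ref`, ideally for a subset of ells to keep the scipy reference fast.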

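Float32 GPU precision limits relative accuracy near zero-crossings of j_l: the absolute error stays at the float32 level everywhere (max 6.3e-07), so dividing it by a small |j_l| inflates the relative error. For physically relevant values (|j_l| > 1e-5), the relative error stays below 1.2% (max 1.1e-02).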
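
The effect is easy to reproduce without a GPU. The NumPy sketch below is an illustration, not the package's evaluation path: it evaluates j_0(x) = sin(x)/x in float32 and in float64, then compares median relative errors close to and far from the zero-crossings. The 1e-3 and 5e-2 band edges are arbitrary choices.

```python
import numpy as np

x = np.linspace(0.5, 20.0, 200_001)
ref = np.sin(x) / x                                   # j_0 in float64
x32 = x.astype(np.float32)
approx = (np.sin(x32) / x32).astype(np.float64)      # same formula in float32

abs_err = np.abs(approx - ref)
near = (np.abs(ref) > 1e-5) & (np.abs(ref) < 1e-3)   # close to a zero-crossing
far = np.abs(ref) > 5e-2                              # well away from zero-crossings

rel_near = np.median(abs_err[near] / np.abs(ref[near]))
rel_far = np.median(abs_err[far] / np.abs(ref[far]))
print(f"max abs error: {abs_err.max():.1e}")
print(f"median relative error near zeros: {rel_near:.1e}, far from zeros: {rel_far:.1e}")
```

The absolute error is of similar size across the whole range, so the much larger relative error near the crossings comes purely from dividing by a small |j_0|.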

## Method

1. **Piecewise segments**: [0, x_max] is divided into segments of width ~80.
2. **Chebyshev nodes**: 64 Chebyshev nodes per segment.
3. **Hybrid table build** (CPU):
   - Forward recurrence for the stable regime (x > 1.5l)
   - scipy for the transition zone (x ~ l, ~14% of node pairs)
   - Set to zero in the evanescent regime (x << l), where j_l is negligibly small
4. **DCT to coefficients**: Discrete cosine transform converts node values to Chebyshev expansion coefficients.
5. **GPU evaluation**: Segment lookup + Chebyshev basis matrix multiply, fully vectorized on GPU.
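
Steps 1, 2, 4 and 5 can be sketched in NumPy as follows. This is an illustration under stated assumptions, not the package's implementation. `build_coeffs`, `eval_table` and `j0_j1` are hypothetical names, and segments are assumed to have equal width (the README says only "~80"). Nodes are first-kind Chebyshev points, and the DCT is written out as an explicit cosine sum (`scipy.fft.dct` with `type=2` computes the same sum faster). A gather followed by `np.einsum` stands in for the GPU matrix multiply, and everything runs in float64 on the CPU rather than float32 on the GPU.

```python
import numpy as np


def build_coeffs(func, x_max, seg_width=80.0, n_nodes=64):
    """Sample func at Chebyshev nodes on each segment of [0, x_max] and
    convert the samples to Chebyshev coefficients (steps 1, 2 and 4).

    func maps a 1-D array of x to an array of shape (n_ell, len(x)).
    Returns coeffs with shape (n_ell, n_seg, n_nodes) and the segment width.
    """
    n_seg = int(np.ceil(x_max / seg_width))
    width = x_max / n_seg                                   # equal-width segments
    k = np.arange(n_nodes)
    theta = np.pi * (k + 0.5) / n_nodes
    t = np.cos(theta)                                       # first-kind nodes on [-1, 1]
    left = np.arange(n_seg) * width                         # left edge of each segment
    x_nodes = left[:, None] + 0.5 * width * (t[None, :] + 1.0)
    vals = np.reshape(func(x_nodes.ravel()), (-1, n_seg, n_nodes))
    # DCT-II: c_j = (2 / n) * sum_k f(t_k) cos(j * theta_k), with c_0 halved
    basis_at_nodes = np.cos(np.outer(k, theta))             # [j, k] = T_j(t_k)
    coeffs = (2.0 / n_nodes) * vals @ basis_at_nodes.T
    coeffs[..., 0] *= 0.5
    return coeffs, width


def eval_table(coeffs, width, x):
    """Evaluate every ell at x: segment lookup, Chebyshev basis, then a
    contraction with that segment's coefficients (step 5)."""
    n_ell, n_seg, n = coeffs.shape
    x = np.asarray(x, dtype=np.float64)
    seg = np.clip((x // width).astype(np.int64), 0, n_seg - 1)   # segment of each x
    t = 2.0 * (x - seg * width) / width - 1.0                    # local coordinate in [-1, 1]
    T = np.empty((x.size, n))
    T[:, 0] = 1.0
    T[:, 1] = t
    for j in range(2, n):
        T[:, j] = 2.0 * t * T[:, j - 1] - T[:, j - 2]            # T_j = 2t T_{j-1} - T_{j-2}
    # result[l, i] = sum_j T[i, j] * coeffs[l, seg[i], j]
    return np.einsum("ij,lij->li", T, coeffs[:, seg, :])


def j0_j1(x):
    """Closed-form j_0 and j_1 stacked to shape (2, len(x)); stand-in table values."""
    return np.stack([np.sin(x) / x, np.sin(x) / x**2 - np.cos(x) / x])


coeffs, width = build_coeffs(j0_j1, x_max=400.0)
x = np.linspace(0.5, 399.0, 5000)
max_err = np.abs(eval_table(coeffs, width, x) - j0_j1(x)).max()
print(max_err < 1e-6)  # → True
```

Because j_l(x) oscillates with period close to 2π, a segment of width 80 spans about 13 periods, which 64 nodes resolve comfortably in float64; in float32 the rounding floor, not the interpolation, dominates the error.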

## Running benchmarks

```bash
python -m mlx_bessel.benchmark
```

## Running tests

From a clone of the repository:

```bash
pip install pytest scipy
pytest tests/ -v
```

## Author

Sheng-Kai Huang (akai@fawstudio.com)

## License

MIT
