Metadata-Version: 2.4
Name: nlls-gram
Version: 0.2.0
Summary: Gram/dual-form Levenberg-Marquardt nonlinear least-squares solvers for JAX
Project-URL: Homepage, https://github.com/HighDimensionalEconLab/nlls_gram
Project-URL: Repository, https://github.com/HighDimensionalEconLab/nlls_gram
Project-URL: Documentation, https://highdimensionaleconlab.github.io/nlls_gram/
Project-URL: Issues, https://github.com/HighDimensionalEconLab/nlls_gram/issues
Author-email: Jesse Perla <jesseperla@gmail.com>
License-Expression: MIT
License-File: LICENSE
Keywords: jax,levenberg-marquardt,nonlinear-least-squares,optimization
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Science/Research
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Topic :: Scientific/Engineering :: Mathematics
Requires-Python: >=3.11
Requires-Dist: jax>=0.7.0
Description-Content-Type: text/markdown

# nlls_gram

[![CI](https://github.com/HighDimensionalEconLab/nlls_gram/actions/workflows/ci.yml/badge.svg)](https://github.com/HighDimensionalEconLab/nlls_gram/actions/workflows/ci.yml)
[![Docs](https://github.com/HighDimensionalEconLab/nlls_gram/actions/workflows/docs.yml/badge.svg)](https://highdimensionaleconlab.github.io/nlls_gram/)
[![PyPI](https://img.shields.io/pypi/v/nlls-gram.svg)](https://pypi.org/project/nlls-gram/)
[![Python versions](https://img.shields.io/pypi/pyversions/nlls-gram.svg)](https://pypi.org/project/nlls-gram/)
[![License: MIT](https://img.shields.io/github/license/HighDimensionalEconLab/nlls_gram)](https://github.com/HighDimensionalEconLab/nlls_gram/blob/main/LICENSE)
[![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff)

Gram/dual-form Levenberg-Marquardt nonlinear least-squares solvers for JAX.

`GramLevenbergMarquardt` minimizes `||r(params)||^2` for a user-supplied
`residual_fn(params, batch)`, where `params` is any JAX pytree (a flat array, a
dict, `nnx.state(model, nnx.Param)`, ...). It follows an `init`/`update` protocol:
`update(params, state, batch)` returns the **new params pytree** (same structure),
the next state, and an `LMInfo`. For overparameterized systems (many more
parameters `p` than residual rows `n`) it factors the small `n x n` gram (dual)
system instead of the `p x p` normal equations.
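
The dual form rests on the push-through identity `(J^T J + λ I_p)^{-1} J^T r = J^T (J J^T + λ I_n)^{-1} r`, where `J` is the `n x p` Jacobian of the residual and `λ` is the damping. The sketch below shows one damped step computed that way on an arbitrary pytree. `gram_lm_step` is a hypothetical helper written only for illustration. It is not `nlls_gram`'s implementation or API, and it uses a fixed damping with no step acceptance test.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree
from jax.scipy.linalg import cho_factor, cho_solve

jax.config.update("jax_enable_x64", True)


def gram_lm_step(residual_fn, params, batch, damping):
    """One fixed-damping LM step solved in gram/dual form (hypothetical helper)."""
    # Flatten the pytree so the Jacobian is a plain (n, p) matrix.
    flat, unravel = ravel_pytree(params)

    def flat_residual(v):
        return residual_fn(unravel(v), batch)

    r = flat_residual(flat)  # shape (n,)
    # Reverse mode costs n VJPs, which is cheap when p >> n.
    J = jax.jacrev(flat_residual)(flat)  # shape (n, p)
    # Factor the small, symmetric positive definite n x n system.
    gram = J @ J.T + damping * jnp.eye(r.shape[0], dtype=r.dtype)
    alpha = cho_solve(cho_factor(gram), r)
    # Map back to parameter space: step = J^T alpha.
    step = J.T @ alpha
    return unravel(flat - step)  # same pytree structure as `params`


# Overparameterized toy problem: n = 3 residual rows, p = 8 + 3 = 11 parameters.
key_a, key_y = jax.random.split(jax.random.PRNGKey(0))
A = jax.random.normal(key_a, (3, 8))
y = jax.random.normal(key_y, (3,))
params = {"w": jnp.zeros(8), "b": jnp.zeros(3)}


def linear_residual(params, batch):
    A, y = batch
    return A @ params["w"] + params["b"] - y


new_params = gram_lm_step(linear_residual, params, (A, y), damping=1e-3)
```

Only an `n x n` matrix is ever factored here; the `p x p` matrix `J^T J` is never formed. With `damping` at zero and `J` of full row rank, the same formula gives the minimum-norm Gauss-Newton step.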

The solver depends only on `jax` — it knows nothing about `flax`/`nnx`/`optax`. It
performs no float casts: dtypes flow from your `params`/residual and JAX decides
`float32` vs `float64` via `jax_enable_x64`.
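
Because nothing is cast, a step runs at whatever precision ordinary JAX type promotion gives your inputs. The plain-JAX snippet below does not call `nlls_gram`. It shows that with `jax_enable_x64` on, default floats are `float64`, and that `float32` params combined with `float64` data promote the residual to `float64`.

```python
import jax
import jax.numpy as jnp

# Enable x64 at startup, before creating any arrays.
jax.config.update("jax_enable_x64", True)


def residual_fn(params, batch):
    x, y = batch
    return params["a"] * jnp.exp(params["b"] * x) - y


x = jnp.linspace(0.0, 2.0, 20)  # default float dtype is float64 under x64
y = 2.0 * jnp.exp(-1.0 * x)
params32 = {"a": jnp.asarray(1.0, jnp.float32), "b": jnp.asarray(0.0, jnp.float32)}

# float32 params with float32 data keep the residual in float32 ...
r32 = residual_fn(params32, (x.astype(jnp.float32), y.astype(jnp.float32)))
# ... but float64 data promotes it to float64.
r64 = residual_fn(params32, (x, y))
print(x.dtype, r32.dtype, r64.dtype)  # float64 float32 float64
```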

## Install

```bash
pip install nlls-gram
```

## Minimal example

Fit `y = a * exp(b * x)` to noise-free data generated from `(a, b) = (2, -1)`,
using a plain dict pytree of parameters:

```python
import jax
import jax.numpy as jnp

from nlls_gram import GramLevenbergMarquardt

jax.config.update("jax_enable_x64", True)


# residual_fn(params, batch) -> 1-D residual array; the solver minimizes its SSQ.
def residual_fn(params, batch):
    x, y = batch
    return params["a"] * jnp.exp(params["b"] * x) - y


x = jnp.linspace(0.0, 2.0, 20)
y = 2.0 * jnp.exp(-1.0 * x)

params = {"a": jnp.asarray(1.0), "b": jnp.asarray(0.0)}
solver = GramLevenbergMarquardt(residual_fn, init_damping=1e-2)
lm_state = solver.init()


# The solver does not jit internally; wrap the train step yourself.
@jax.jit
def train_step(params, lm_state, batch):
    return solver.update(params, lm_state, batch)


for _ in range(50):
    params, lm_state, info = train_step(params, lm_state, (x, y))

print(params["a"], params["b"])  # ~2.0, ~-1.0
```

`params` can be any pytree. With Flax NNX, pass `nnx.state(model, nnx.Param)` as
`params` and write `residual_fn(state, batch)` using `nnx.merge`; the solver itself
stays NNX-agnostic.

## Filtering / freezing parameters

`update` optimizes exactly the `params` pytree you pass, so freezing is just
"pass fewer params": keep the frozen values in `residual_fn`'s closure and hand
the solver only the trainable subset. Frozen leaves get no Jacobian column and
never move — no `wrt`/masking argument needed.

```python
# Optimize only "a"; "b" is frozen at its current value.
frozen = {"b": jnp.asarray(-1.0)}


def residual_fn(trainable, batch):
    x, y = batch
    params = {**frozen, **trainable}  # frozen from the closure, trainable optimized
    return params["a"] * jnp.exp(params["b"] * x) - y


trainable = {"a": jnp.asarray(1.0)}
solver = GramLevenbergMarquardt(residual_fn, init_damping=1e-2)
lm_state = solver.init()
for _ in range(50):
    trainable, lm_state, info = solver.update(trainable, lm_state, (x, y))
# trainable["a"] -> ~2.0; "b" stayed -1.0
```
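
To see why no mask is needed, differentiate the closure-based `residual_fn` directly. JAX returns a Jacobian pytree with one entry per *trainable* leaf, and the frozen `b` never appears. This is a plain-JAX illustration that does not call `nlls_gram`, and it assumes nothing about how the solver forms its Jacobian internally.

```python
import jax
import jax.numpy as jnp

frozen = {"b": jnp.asarray(-1.0)}


def residual_fn(trainable, batch):
    x, y = batch
    params = {**frozen, **trainable}  # frozen leaves come from the closure
    return params["a"] * jnp.exp(params["b"] * x) - y


x = jnp.linspace(0.0, 2.0, 20)
y = 2.0 * jnp.exp(-1.0 * x)
trainable = {"a": jnp.asarray(1.0)}

# Jacobian w.r.t. the first argument only: a dict keyed like `trainable`.
jac = jax.jacrev(residual_fn)(trainable, (x, y))
print(list(jac))  # ['a']
print(jac["a"].shape)  # (20,): one column, dr/da = exp(b * x)
```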

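With Flax NNX, split the model into frozen and trainable states with a filter and
merge them back inside `residual_fn`. `freeze_filter` is any NNX `Filter` (a
`Variable` type, a path-based filter such as `nnx.PathContains("bias")`, or a
predicate) selecting the variables to hold fixed; `...` captures everything else
as trainable. Note that `...` matches *all* remaining state, including any
non-`nnx.Param` variables such as batch statistics: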

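```python
from flax import nnx

from nlls_gram import GramLevenbergMarquardt

# Assumes `model`, `freeze_filter`, and `batch = (x, y)` already exist, and that
# `model(x)` returns predictions with the same shape as `y`.
graphdef, frozen, trainable = nnx.split(model, freeze_filter, ...)


def residual_fn(trainable, batch):
    x, y = batch
    m = nnx.merge(graphdef, frozen, trainable)
    return (m(x) - y).ravel()  # flatten to the 1-D residual the solver expects


solver = GramLevenbergMarquardt(residual_fn, init_damping=1e-2)
lm_state = solver.init()
trainable, lm_state, info = solver.update(trainable, lm_state, batch)
new_model = nnx.merge(graphdef, frozen, trainable)
```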

## Documentation

Full docs: https://highdimensionaleconlab.github.io/nlls_gram/

## License

MIT
