Metadata-Version: 2.4
Name: jax-neural-operators
Version: 0.2.2
Summary: JAX Neural Operators
Author-email: Leon Armbruster <leon.armbruster@iisb.fraunhofer.de>
License-Expression: EPL-2.0
Project-URL: Homepage, https://github.com/FhG-IISB/jno
Project-URL: Documentation, https://fhg-iisb.github.io/jNO_docs/
Project-URL: Source, https://github.com/FhG-IISB/jno
Project-URL: Issues, https://github.com/FhG-IISB/jno/issues
Classifier: Intended Audience :: Science/Research
Classifier: Topic :: Scientific/Engineering :: Mathematics
Requires-Python: <3.14,>=3.11
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: jax[cuda]<0.11,>=0.10.1
Requires-Dist: equinox>=0.11.0
Requires-Dist: optax>=0.2.0
Requires-Dist: orbax-checkpoint>=0.6.0
Requires-Dist: foundax>=0.1.6
Requires-Dist: pygmsh>=7.0.0
Requires-Dist: cloudpickle>=3.0.0
Requires-Dist: einops>=0.7.0
Requires-Dist: matplotlib>=3.8.0
Requires-Dist: pylotte>=0.3
Requires-Dist: shapely>=2.0
Provides-Extra: cuda
Requires-Dist: jax[cuda]>=0.4.30; extra == "cuda"
Provides-Extra: fem
Requires-Dist: feax>=0.5.1; extra == "fem"
Requires-Dist: diffrax; extra == "fem"
Requires-Dist: optimistix; extra == "fem"
Requires-Dist: lineax; extra == "fem"
Provides-Extra: dev
Requires-Dist: pytest>=7.0; extra == "dev"
Requires-Dist: black>=22.0; extra == "dev"
Requires-Dist: flake8>=4.0; extra == "dev"
Requires-Dist: mypy>=0.950; extra == "dev"
Requires-Dist: feax>=0.5.1; extra == "dev"
Requires-Dist: diffrax; extra == "dev"
Requires-Dist: optimistix; extra == "dev"
Requires-Dist: lineax; extra == "dev"
Requires-Dist: paramax; extra == "dev"
Requires-Dist: flax; extra == "dev"
Requires-Dist: xprof; extra == "dev"
Requires-Dist: pylotte; extra == "dev"
Requires-Dist: nevergrad; extra == "dev"
Requires-Dist: pyfiglet; extra == "dev"
Requires-Dist: fenics-basix; extra == "dev"
Provides-Extra: iree
Requires-Dist: iree-base-compiler; extra == "iree"
Requires-Dist: iree-base-runtime; extra == "iree"
Dynamic: license-file

<p align="center">
  <img src="assets/logo.png" alt="jNO logo" width="500"/>
</p>

<p align="center">
    <a href="https://fhg-iisb.github.io/jNO/">
        <img src="https://img.shields.io/badge/docs-GitHub%20Pages-0aa?style=for-the-badge" alt="Dev Docs"/>
    </a>
    <a href="https://github.com/FhG-IISB/jno/actions/workflows/ci.yml">
        <img src="https://img.shields.io/github/actions/workflow/status/FhG-IISB/jno/ci.yml?branch=main&style=for-the-badge&label=tests" alt="Tests"/>
    </a>
    <a href="LICENSE">
        <img src="https://img.shields.io/badge/license-EPL--2.0-2ea44f?style=for-the-badge" alt="License"/>
    </a>
    <a href="CITATION.cff">
        <img src="https://img.shields.io/badge/cite-CITATION.cff-6b5b95?style=for-the-badge" alt="Citation"/>
    </a>
    <img src="https://img.shields.io/badge/docker-image%20available-2496ED?style=for-the-badge&logo=docker&logoColor=white" alt="Docker image available"/>
</p>

Warning: This is a research-level repository. It may contain bugs and is subject to continuous change without notice.


# Install

Quick install from PyPI:

```bash
pip install jax-neural-operators
```


Foundation models and other neural operators are maintained in a separate repository, [foundax](https://github.com/FhG-IISB/foundax), so that they can also be used on their own. foundax is installed automatically as a dependency of this package.

# Example

```python
import jno
import jax
import jax.numpy as jnp
import optax
import foundax
import equinox as eqx
from jno import LearningRateSchedule as lrs
from jno.numpy import tracker

dir = jno.setup("./runs/poisson2d")

# ── Domain — rect with named boundary sides ────────────────────────────────────
dom        = jno.domain(constructor=jno.domain.rect(mesh_size=0.05, x_range=(0, 1), y_range=(0, 1)))
x,  y,  _  = dom.variable("interior")
xl, yl, _  = dom.variable("left")    # x = 0  →  soft Dirichlet

# ── Network — LoRA adapters on hidden layers, 10× LR on output layer ───────────
net = jno.nn.wrap(
    foundax.mlp(in_features=2, hidden_dims=64, num_layers=4,
                activation=jnp.tanh, key=jax.random.PRNGKey(0))
)
net.lora(rank=4, alpha=1.0)                                            # parameter-efficient training
net.optimizer(optax.adam(1), lr=lrs.exponential(1e-3, 0.5, 5_000, 1e-5))

all_false = jax.tree_util.tree_map(lambda _: False, net.module)
out_mask  = eqx.tree_at(lambda m: m.output_layer.weight, all_false, True)
net.mask(out_mask).lr(lrs.exponential(1e-2, 0.5, 5_000, 1e-4))       # output layer at 10× LR

# ── Forward pass — hard BCs on right (x=1), bottom (y=0), top (y=1) ───────────
π  = jno.np.pi
u  = net(jno.np.concat([x,  y ], axis=-1)) * (1 - x)  * y  * (1 - y)
ul = net(jno.np.concat([xl, yl], axis=-1)) * (1 - xl) * yl * (1 - yl)

# ── PDE: −∇²u = 2π²sin(πx)sin(πy),  exact u* = sin(πx)sin(πy) ───────────────
pde     = -(u.dd(x) + u.dd(y)) - 2 * π**2 * jno.np.sin(π * x) * jno.np.sin(π * y)
bc_left = ul                                          # soft: u(0, y) = 0

# ── Integral: ∫u dΩ → 4/π² ≈ 0.405 for the exact solution ────────────────────
vol_tracker = tracker(u.integrate(), interval=500)

# ── Gradient alignment — PDE vs. left-BC loss, output-layer params only ───────
J_pde      = pde.mse.grad(net.mask(out_mask))         # ∂L_pde / ∂θ_out
J_bc       = bc_left.mse.grad(net.mask(out_mask))     # ∂L_bc  / ∂θ_out
grad_align = tracker(jno.np.dot(J_pde, J_bc), interval=500)

# ── Solve ──────────────────────────────────────────────────────────────────────
crux = jno.core([pde.mse, bc_left.mse, vol_tracker, grad_align], domain=dom)
crux.solve(20_000).plot(f"{dir}/training.png")
jno.save(crux, f"{dir}/model.pkl")
```
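
The example's comments mention LoRA adapters on the hidden layers. To illustrate what `net.lora(rank=4, alpha=1.0)` buys, the plain-Python sketch below applies the standard LoRA update W' = W + (alpha / rank) · B A (Hu et al., 2021) to a single 64 × 64 hidden weight. It is not jno's implementation. The frozen base weight, the alpha / rank scaling, the zero initialization of `B`, and the helper names `matmul` and `lora_weight` are all assumptions.

```python
import random


def matmul(a, b):
    """Multiply matrices stored as lists of rows: (n x k) @ (k x m) -> (n x m)."""
    cols = list(zip(*b))  # columns of b as tuples
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]


def lora_weight(w, a, b, alpha, rank):
    """Effective weight W + (alpha / rank) * B @ A (standard LoRA scaling, assumed)."""
    scale = alpha / rank
    delta = matmul(b, a)  # low-rank update with the same shape as w
    return [[w_ij + scale * d_ij for w_ij, d_ij in zip(w_row, d_row)]
            for w_row, d_row in zip(w, delta)]


d, rank, alpha = 64, 4, 1.0  # hidden width and LoRA settings taken from the example
rng = random.Random(0)

w = [[rng.gauss(0.0, 0.1) for _ in range(d)] for _ in range(d)]     # frozen base weight (d x d)
a = [[rng.gauss(0.0, 0.1) for _ in range(d)] for _ in range(rank)]  # trainable A (rank x d)
b = [[0.0] * rank for _ in range(d)]                                # trainable B (d x rank), zero at start

w_eff = lora_weight(w, a, b, alpha, rank)
full_params = d * d                # weights trained by full fine-tuning of this layer
lora_params = rank * d + d * rank  # weights in A and B

print("unchanged at init:", w_eff == w)
print(f"trainable: {lora_params} vs {full_params} ({lora_params / full_params:.1%})")
```

Because `B` starts at zero, the adapted layer initially matches the pretrained one, and training the adapters touches 512 parameters per layer instead of 4,096 (12.5 %).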
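
The two `lrs.exponential` calls give the whole network a decaying learning rate and the masked output layer a schedule that is 10× larger. Reading the positional arguments as (initial value, decay rate, decay steps, floor) is an assumption inferred from the values used, not a documented signature, and smooth (non-staircase) decay is assumed as well. Under those assumptions, the sketch below, with a hypothetical `exponential_lr` helper, shows that both schedules shrink by the same factor, so the 10× ratio holds for the whole run.

```python
import math


def exponential_lr(step, init, decay_rate, decay_steps, floor):
    """Smooth exponential decay init * decay_rate**(step / decay_steps), clipped at floor."""
    return max(init * decay_rate ** (step / decay_steps), floor)


base = dict(init=1e-3, decay_rate=0.5, decay_steps=5_000, floor=1e-5)  # all parameters
head = dict(init=1e-2, decay_rate=0.5, decay_steps=5_000, floor=1e-4)  # masked output layer

for step in (0, 5_000, 10_000, 20_000):
    lr_base = exponential_lr(step, **base)
    lr_head = exponential_lr(step, **head)
    print(f"step {step:>6}: base {lr_base:.3e}  output layer {lr_head:.3e}  "
          f"ratio {lr_head / lr_base:.1f}")
```

After 20,000 steps the rates are 1e-3 · 0.5⁴ = 6.25e-5 and 6.25e-4, both still above their floors of 1e-5 and 1e-4.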
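
The example is a manufactured-solution test: the right-hand side is chosen so that u* = sin(πx)sin(πy) is the exact solution. The stdlib-only sketch below is independent of jno, and its helper names are hypothetical. It checks three claims made in the comments. First, the factor `(1 - x) * y * (1 - y)` vanishes on the right, bottom and top edges but not on the left edge, which is why the left boundary needs the soft penalty `bc_left`. Second, u* satisfies −∇²u = 2π²sin(πx)sin(πy), checked with central finite differences. Third, ∫u* dΩ = (2/π)² = 4/π² ≈ 0.405, checked with a midpoint rule.

```python
import math

PI = math.pi


def u_exact(x, y):
    """Manufactured solution u*(x, y) = sin(pi x) sin(pi y)."""
    return math.sin(PI * x) * math.sin(PI * y)


def hard_bc_factor(x, y):
    """Multiplier from the example's ansatz; vanishes on x = 1, y = 0 and y = 1."""
    return (1 - x) * y * (1 - y)


def pde_residual(x, y, h=1e-3):
    """-(u_xx + u_yy) - 2 pi^2 sin(pi x) sin(pi y) for u*, using central differences."""
    u_xx = (u_exact(x + h, y) - 2 * u_exact(x, y) + u_exact(x - h, y)) / h**2
    u_yy = (u_exact(x, y + h) - 2 * u_exact(x, y) + u_exact(x, y - h)) / h**2
    return -(u_xx + u_yy) - 2 * PI**2 * math.sin(PI * x) * math.sin(PI * y)


def midpoint_integral(f, n=200):
    """Midpoint-rule approximation of the integral of f over the unit square."""
    h = 1.0 / n
    return h * h * sum(f((i + 0.5) * h, (j + 0.5) * h) for i in range(n) for j in range(n))


ts = [0.0, 0.25, 0.5, 0.75, 1.0]
edge_values = ([hard_bc_factor(1.0, t) for t in ts]     # right edge
               + [hard_bc_factor(t, 0.0) for t in ts]   # bottom edge
               + [hard_bc_factor(t, 1.0) for t in ts])  # top edge

print("max |factor| on right, bottom, top edges:", max(abs(v) for v in edge_values))
print("factor on left edge at y = 0.5:", hard_bc_factor(0.0, 0.5))  # nonzero, so no hard BC there
print("max |residual| at interior points:",
      max(abs(pde_residual(x, y)) for x in ts[1:-1] for y in ts[1:-1]))
print("integral of u*:", midpoint_integral(u_exact), "vs 4/pi^2 =", 4 / PI**2)
```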

