Metadata-Version: 2.4
Name: picardax
Version: 0.0.2
Summary: Time-Implicit Differentiable Finite Difference Solvers in JAX.
Project-URL: repository, https://github.com/Ceyron/picardax
Author: Felix Koehler
License-File: LICENSE
Keywords: deep-learning,jax,pde,sciml
Requires-Python: ~=3.10
Requires-Dist: equinox>=0.11.3
Requires-Dist: jax>=0.4.13
Requires-Dist: jaxtyping>=0.2.20
Provides-Extra: docs
Requires-Dist: black==26.1.0; extra == 'docs'
Requires-Dist: griffe==1.15.0; extra == 'docs'
Requires-Dist: mkdocs-jupyter>=0.25.0; extra == 'docs'
Requires-Dist: mkdocs-material==9.7.1; extra == 'docs'
Requires-Dist: mkdocs==1.6.1; extra == 'docs'
Requires-Dist: mkdocstrings-python==2.0.1; extra == 'docs'
Requires-Dist: mkdocstrings==1.0.3; extra == 'docs'
Provides-Extra: gpu
Requires-Dist: jax[cuda12]>=0.8.0; extra == 'gpu'
Provides-Extra: test
Requires-Dist: jaxlib; extra == 'test'
Requires-Dist: pytest; extra == 'test'
Requires-Dist: pytest-cov; extra == 'test'
Description-Content-Type: text/markdown

<p align="center">
  <img src="docs/imgs/logo.svg" alt="picardax logo" width="200">
</p>
<h4 align="center">Time-Implicit Differentiable PDE Solvers in <a href="https://github.com/google/jax" target="_blank">JAX</a>.</h4>

<p align="center">
<a href="https://pypi.org/project/picardax/">
  <img src="https://img.shields.io/pypi/v/picardax.svg" alt="PyPI">
</a>
<a href="https://github.com/ceyron/picardax/actions/workflows/test.yml">
  <img src="https://github.com/ceyron/picardax/actions/workflows/test.yml/badge.svg" alt="Tests">
</a>
<a href="https://fkoehler.site/picardax/">
  <img src="https://img.shields.io/badge/docs-latest-green" alt="docs-latest">
</a>
<a href="https://github.com/ceyron/picardax/blob/main/LICENSE">
  <img src="https://img.shields.io/badge/license-MIT-blue" alt="License">
</a>
</p>

<p align="center">
  <a href="#installation">Installation</a> •
  <a href="#quickstart">Quickstart</a> •
  <a href="#design">Design</a> •
  <a href="#features">Features</a> •
  <a href="#documentation">Documentation</a> •
  <a href="#background">Background</a>
</p>

**🚨 This is a pre-release for me to cite in my dissertation. It is based on
slightly rewritten and enhanced code from the [PRDP
paper](https://github.com/tum-pbs/PRDP) and the [Emulator Superiority
paper](https://github.com/tum-pbs/emulator-superiority). Please check back at
the end of May 2026 for a full release with extended features, documentation,
and examples.**

`Picardax` is a sister project to
[Exponax](https://github.com/Ceyron/exponax) (which provides *explicit* and
exponential time steppers). Where Exponax sidesteps nonlinear solves entirely,
Picardax embraces them: it solves the implicit system arising from theta-method
(and similar) time discretizations using iterative nonlinear solvers (Picard
iteration, Newton's method) that themselves rely on iterative linear solvers
(CG, GMRES, BiCGStab). Built on [JAX](https://github.com/google/jax) and
[Equinox](https://github.com/patrick-kidger/equinox), every solver is
automatically differentiable, JIT-compilable, and GPU/TPU-ready.
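
For reference, the standard theta-method applied to $\partial_t u = f(u)$ yields the residual below; $\theta = 0$, $\tfrac{1}{2}$ and $1$ give forward Euler, Crank-Nicolson and backward Euler. The Picard and Newton updates are the textbook forms, and the Picard line assumes $f(u) = \mathcal{L}u + \mathcal{N}(u)$ splits into a linear part $\mathcal{L}$ and a nonlinear part $\mathcal{N}$. Whether `ThetaResiduum` and `Picard` use exactly this sign convention and linearization is an assumption.

```math
\begin{aligned}
\mathcal{R}(u^{n+1}) &= u^{n+1} - u^{n} - \Delta t \left[ \theta\, f(u^{n+1}) + (1 - \theta)\, f(u^{n}) \right] = 0, \\[4pt]
\text{Picard:}\quad \left(I - \theta\, \Delta t\, \mathcal{L}\right) u^{(k+1)} &= u^{n} + (1 - \theta)\, \Delta t\, f(u^{n}) + \theta\, \Delta t\, \mathcal{N}\big(u^{(k)}\big), \\[4pt]
\text{Newton:}\quad \left(I - \theta\, \Delta t\, \frac{\partial f}{\partial u}\Big|_{u^{(k)}}\right) \delta u &= -\mathcal{R}\big(u^{(k)}\big), \qquad u^{(k+1)} = u^{(k)} + \delta u.
\end{aligned}
```

Each update is a linear system, which is where the iterative linear solvers (CG, GMRES, BiCGStab) come in.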

## Installation

```bash
pip install picardax
```

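Requires Python 3.10+ and JAX 0.4.13+. See the [JAX install guide](https://jax.readthedocs.io/en/latest/installation.html) for platform-specific setup. On CUDA 12 machines, the `gpu` extra (`pip install "picardax[gpu]"`) pulls in `jax[cuda12]>=0.8.0`.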

## Quickstart

Use a convenience stepper for common PDEs:

```python
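import jax.numpy as jnp
from picardax.periodic.stepper import Diffusion

stepper = Diffusion(
    num_spatial_dims=1,
    domain_extent=1.0,
    num_points=64,
    dt=0.01,
    diffusivity=0.01,
)

# Example initial state on the periodic grid (right endpoint excluded).
# The (1, num_points) channel-first layout is assumed, as in Exponax.
x = jnp.linspace(0.0, 1.0, 64, endpoint=False)
u = jnp.sin(2 * jnp.pi * x)[None, :]

u_next = stepper(u)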
```
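
Because a stepper maps one state to the next, it can be rolled out over many steps with `jax.lax.scan`, jit-compiled, and differentiated end to end. The sketch below illustrates this pattern in plain JAX. It assumes only that the stepper is a pure function of `u`, and it uses a hypothetical stand-in, `implicit_diffusion_step` (backward Euler on the periodic 3-point Laplacian, solved exactly via FFT), so that it runs without picardax. The names `rollout` and `implicit_diffusion_step` are illustrative and not part of the picardax API.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in parameters mirroring the quickstart configuration.
N, DOMAIN_EXTENT, DT, NU = 64, 1.0, 0.01, 0.01
DX = DOMAIN_EXTENT / N

# Eigenvalues of the periodic 3-point Laplacian for each rfft wavenumber.
_k = jnp.arange(N // 2 + 1)
_LAMBDA = -4.0 / DX**2 * jnp.sin(jnp.pi * _k / N) ** 2


def implicit_diffusion_step(u):
    """Hypothetical stand-in stepper: backward Euler for u_t = nu * u_xx.

    Solves (I - dt * nu * L) u_next = u exactly in Fourier space, where L is
    the second-order periodic finite-difference Laplacian.
    """
    u_hat = jnp.fft.rfft(u)
    return jnp.fft.irfft(u_hat / (1.0 - DT * NU * _LAMBDA), n=N)


def rollout(stepper, u0, num_steps):
    """Apply `stepper` num_steps times; returns shape (num_steps + 1, *u0.shape)."""

    def body(u, _):
        u_next = stepper(u)
        return u_next, u_next

    _, trajectory = jax.lax.scan(body, u0, None, length=num_steps)
    return jnp.concatenate([u0[None], trajectory], axis=0)


# The stepper and the step count are static (Python-level) arguments under jit.
rollout_jit = jax.jit(rollout, static_argnums=(0, 2))

x = jnp.linspace(0.0, DOMAIN_EXTENT, N, endpoint=False)
u0 = jnp.sin(2 * jnp.pi * x) + 0.5 * jnp.cos(6 * jnp.pi * x)
trajectory = rollout_jit(implicit_diffusion_step, u0, 10)


# Gradient of a scalar loss on the final state with respect to the initial state.
def final_energy(u_init):
    return jnp.sum(rollout(implicit_diffusion_step, u_init, 10)[-1] ** 2)


grad_u0 = jax.grad(final_energy)(u0)
```

Swapping in a picardax stepper should in principle follow the same pattern, provided it is a pure, JAX-traceable callable; this sketch does not verify that.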

Or assemble components manually for full control:

```python
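import jax.numpy as jnp
from picardax import CG, Picard, Stepper, ThetaResiduum
from picardax.periodic import CollocatedDerivatives
from picardax.periodic.rhs import Diffusion

derivatives = CollocatedDerivatives(
    num_spatial_dims=1, num_points=64, domain_extent=1.0
)
rhs = Diffusion(derivatives=derivatives, diffusivity=0.01)
# theta=0.5 is the Crank-Nicolson choice in the standard theta-method
residuum = ThetaResiduum(rhs=rhs, theta=0.5, dt=0.01)
# Picard (fixed-point) iteration; each iteration relies on a CG linear solve
solver = Picard(linear_solver=CG(), maxiter=10, tol=1e-6)
stepper = Stepper(residuum=residuum, solver=solver)

# Example initial state; the (1, num_points) channel-first layout is assumed
x = jnp.linspace(0.0, 1.0, 64, endpoint=False)
u = jnp.sin(2 * jnp.pi * x)[None, :]

u_next = stepper(u)

# Diagnose convergence
n_iters = stepper.diagnose(u)
errors = stepper.get_error_iterates(u)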
```
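
To make the roles of a theta residual, a Picard solver and an inner CG solve concrete, the sketch below rebuilds one such step from scratch in plain JAX. It is not picardax's implementation. It assumes a second-order periodic finite-difference Laplacian, a hypothetical Fisher-KPP reaction term `u * (1 - u)` added for nonlinearity, and a Picard iteration that keeps the linear diffusion implicit while lagging the nonlinear term, with `jax.scipy.sparse.linalg.cg` as the linear solver. All names (`picard_step`, `theta_residual`, `reaction`, ...) are illustrative.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

# Hypothetical problem setup mirroring the quickstart: 1D periodic grid, 64 points.
N, DOMAIN_EXTENT = 64, 1.0
DX = DOMAIN_EXTENT / N
NU, DT, THETA = 0.01, 0.01, 0.5


def laplacian(u):
    # Second-order central difference with periodic wrap-around.
    return (jnp.roll(u, -1) - 2.0 * u + jnp.roll(u, 1)) / DX**2


def reaction(u):
    # Hypothetical nonlinear term (Fisher-KPP type), chosen only for illustration.
    return u * (1.0 - u)


def rhs(u):
    return NU * laplacian(u) + reaction(u)


def theta_residual(u_new, u_old):
    # R(u_new) = u_new - u_old - dt * [theta * f(u_new) + (1 - theta) * f(u_old)]
    return u_new - u_old - DT * (THETA * rhs(u_new) + (1.0 - THETA) * rhs(u_old))


def picard_step(u_old, maxiter=20, tol=1e-6):
    """One theta-method step: Picard iteration on R(u) = 0 with CG inside."""

    def linear_operator(v):
        # Implicit linear part (I - theta * dt * nu * L) v; symmetric positive definite.
        return v - THETA * DT * NU * laplacian(v)

    explicit_part = u_old + (1.0 - THETA) * DT * rhs(u_old)

    def not_converged(state):
        k, _, change = state
        return (k < maxiter) & (change > tol)

    def picard_update(state):
        k, u, _ = state
        # Lag the nonlinear term at the current iterate, then solve the linear system.
        b = explicit_part + THETA * DT * reaction(u)
        u_next, _ = cg(linear_operator, b, x0=u, tol=1e-6)
        return k + 1, u_next, jnp.max(jnp.abs(u_next - u))

    init = (jnp.int32(0), u_old, jnp.array(jnp.inf, dtype=u_old.dtype))
    num_iters, u_new, _ = jax.lax.while_loop(not_converged, picard_update, init)
    return u_new, num_iters


x = jnp.linspace(0.0, DOMAIN_EXTENT, N, endpoint=False)
u0 = 0.5 + 0.25 * jnp.sin(2 * jnp.pi * x)
u1, n_iters = jax.jit(picard_step)(u0)
```

Lagging only the nonlinear term keeps every linear solve symmetric positive definite, which is why CG suffices here. A Newton update would instead solve with the full Jacobian $I - \theta \Delta t\, \partial f / \partial u$, which is non-symmetric for advection-type terms and then calls for GMRES or BiCGStab.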

## Design

TODO

## Features

TODO

## Documentation

Documentation is available at [fkoehler.site/picardax](https://fkoehler.site/picardax/).

## Background

TODO

## Related

TODO

## License

MIT; see [LICENSE](https://github.com/Ceyron/picardax/blob/main/LICENSE).

---

> [fkoehler.site](https://fkoehler.site/) &nbsp;&middot;&nbsp;
> GitHub [@ceyron](https://github.com/ceyron) &nbsp;&middot;&nbsp;
> X [@felix_m_koehler](https://twitter.com/felix_m_koehler) &nbsp;&middot;&nbsp;
> LinkedIn [Felix Koehler](https://www.linkedin.com/in/felix-koehler)
