Metadata-Version: 2.4
Name: blocktrix
Version: 0.0.4
Summary: A JAX library for solving block tri-diagonal matrix systems
Author-email: Philip Mocz <philip.mocz@gmail.com>
License-Expression: Apache-2.0
Project-URL: Documentation, https://github.com/pmocz/blocktrix
Project-URL: Homepage, https://github.com/pmocz/blocktrix
Requires-Python: >=3.11
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: jax==0.6.0
Provides-Extra: dev
Requires-Dist: ruff; extra == "dev"
Requires-Dist: pytest; extra == "dev"
Requires-Dist: pytest-cov; extra == "dev"
Requires-Dist: coverage-badge; extra == "dev"
Requires-Dist: sphinx-lint>=1.0.0; extra == "dev"
Requires-Dist: typos; extra == "dev"
Provides-Extra: cuda12
Requires-Dist: jax[cuda12]==0.6.0; extra == "cuda12"
Dynamic: license-file

# blocktrix

A JAX library for efficiently solving block tri-diagonal matrix systems

Author: [Philip Mocz (@pmocz)](https://github.com/pmocz/)

It implements both a serial block **Thomas** solver and a parallel block cyclic reduction (**B-cyclic**) solver.
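The two strategies differ in how block rows are eliminated. Below is a rough sketch of each in plain `jax.numpy`. It is not blocktrix's code, and the names `block_thomas` and `block_cyclic_reduction` and the storage layout are assumptions for illustration. Block row `i` is taken to read `A_i x_{i-1} + B_i x_i + C_i x_{i+1} = d_i`. The blocks are stored as `lower[i] = A_i`, `diag[i] = B_i`, `upper[i] = C_i` (shape `(n, m, m)`) and `rhs[i] = d_i` (shape `(n, m)`), with `lower[0]` and `upper[-1]` ignored. The Thomas solver is a forward and a backward `jax.lax.scan`, where each step depends on the previous one. Cyclic reduction instead eliminates every odd-indexed unknown in one batched step. This leaves a half-size block tridiagonal system for the even unknowns, so the method needs only about `log2(n)` levels of independent, vectorizable work.

```python
import jax
import jax.numpy as jnp


def _solve_vec(a, v):
    # Batched a^{-1} v for a: (..., m, m), v: (..., m). The explicit trailing
    # axis keeps jnp.linalg.solve from guessing vector vs. matrix semantics.
    return jnp.linalg.solve(a, v[..., None])[..., 0]


def _right_solve(x, a):
    # Batched x @ inv(a), computed as (a^T \ x^T)^T.
    y_t = jnp.linalg.solve(jnp.swapaxes(a, -1, -2), jnp.swapaxes(x, -1, -2))
    return jnp.swapaxes(y_t, -1, -2)


def _matvec(a, v):
    # Batched matrix-vector product: out[..., i] = sum_j a[..., i, j] v[..., j].
    return jnp.einsum("...ij,...j->...i", a, v)


def _prev(arr):
    return jnp.roll(arr, 1, axis=0)  # _prev(arr)[i] == arr[i - 1], wrapping at i = 0


def _next(arr):
    return jnp.roll(arr, -1, axis=0)  # _next(arr)[i] == arr[i + 1], wrapping at the end


def block_thomas(lower, diag, upper, rhs):
    """Serial block Thomas algorithm (hypothetical sketch): forward then backward sweep."""
    m = diag.shape[-1]
    zero_mat = jnp.zeros((m, m), diag.dtype)
    zero_vec = jnp.zeros((m,), diag.dtype)

    def forward(carry, blocks):
        c_prev, d_prev = carry
        a, b, c, d = blocks
        s = b - a @ c_prev                     # pivot block S_i = B_i - A_i C'_{i-1}
        c_new = jnp.linalg.solve(s, c)         # C'_i = S_i^{-1} C_i
        d_new = _solve_vec(s, d - a @ d_prev)  # d'_i = S_i^{-1} (d_i - A_i d'_{i-1})
        return (c_new, d_new), (c_new, d_new)

    # Zero initial carries make lower[0] irrelevant for the first block row.
    init = (zero_mat, zero_vec)
    _, (c_p, d_p) = jax.lax.scan(forward, init, (lower, diag, upper, rhs))

    def backward(x_next, blocks):
        c_prime, d_prime = blocks
        x = d_prime - c_prime @ x_next         # x_i = d'_i - C'_i x_{i+1}
        return x, x

    # Starting from x_n = 0 gives x_{n-1} = d'_{n-1}, so upper[-1] has no effect.
    _, x = jax.lax.scan(backward, zero_vec, (c_p, d_p), reverse=True)
    return x


def block_cyclic_reduction(lower, diag, upper, rhs):
    """Block cyclic reduction (hypothetical sketch): drop odd unknowns, then recurse."""
    n = rhs.shape[0]
    if n == 1:
        return _solve_vec(diag, rhs)
    # Zero the unused corner blocks so the wrap-around in _prev/_next adds nothing.
    lower = lower.at[0].set(0.0)
    upper = upper.at[-1].set(0.0)
    alpha = _right_solve(lower, _prev(diag))  # A_i B_{i-1}^{-1}
    gamma = _right_solve(upper, _next(diag))  # C_i B_{i+1}^{-1}
    # Substitute the odd rows into every row in one batched step; keep the even rows.
    new_lower = -alpha @ _prev(lower)
    new_diag = diag - alpha @ _prev(upper) - gamma @ _next(lower)
    new_upper = -gamma @ _next(upper)
    new_rhs = rhs - _matvec(alpha, _prev(rhs)) - _matvec(gamma, _next(rhs))
    x_even = block_cyclic_reduction(
        new_lower[::2], new_diag[::2], new_upper[::2], new_rhs[::2]
    )
    # Back-substitute the odd unknowns: x_j = B_j^{-1} (d_j - A_j x_{j-1} - C_j x_{j+1}).
    x = jnp.zeros_like(rhs).at[::2].set(x_even)
    resid = rhs - _matvec(lower, _prev(x)) - _matvec(upper, _next(x))
    x_all = _solve_vec(diag, resid)
    return x.at[1::2].set(x_all[1::2])
```

Both sketches assume that every pivot block they encounter is invertible. This holds, for example, for strictly diagonally dominant systems. Neither sketch pivots across blocks.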

⚠️ **Warning: Work in Progress**

This library is still under active development and is not guaranteed to work at this point XXX.


## Installation

```bash
pip install blocktrix
```

## Usage


```python
import jax
from blocktrix import solve_block_tridiagonal_bcyclic, random_block_tridiagonal

# Generate a random test system
key = jax.random.PRNGKey(42)
n_blocks, block_size = 8, 4

lower, diag, upper, rhs = random_block_tridiagonal(key, n_blocks, block_size)

# Solve the system
x = solve_block_tridiagonal_bcyclic(n_blocks, lower, diag, upper, rhs)
```
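To sanity-check a solution, one option is to apply the matrix blockwise and look at the relative residual. The sketch below assumes a layout that blocktrix may or may not use. In this layout, `lower[i]`, `diag[i]`, `upper[i]` hold the blocks that multiply `x[i-1]`, `x[i]`, `x[i+1]` in block row `i` (shape `(n_blocks, block_size, block_size)`), `x` and `rhs` have shape `(n_blocks, block_size)`, and `lower[0]` and `upper[-1]` are ignored. The helper names `block_tridiagonal_matvec` and `relative_residual` are hypothetical.

```python
import jax.numpy as jnp


def block_tridiagonal_matvec(lower, diag, upper, x):
    """Compute y = M x for a block tridiagonal M without forming it densely.

    Hypothetical layout: lower[i] = A_i, diag[i] = B_i, upper[i] = C_i with
    shape (n, m, m); x has shape (n, m). lower[0] and upper[-1] are ignored.
    """
    y = jnp.einsum("nij,nj->ni", diag, x)                           # B_i x_i
    y = y.at[1:].add(jnp.einsum("nij,nj->ni", lower[1:], x[:-1]))   # + A_i x_{i-1}
    y = y.at[:-1].add(jnp.einsum("nij,nj->ni", upper[:-1], x[1:]))  # + C_i x_{i+1}
    return y


def relative_residual(lower, diag, upper, x, rhs):
    # ||M x - d|| / ||d||: a quick check on a computed solution.
    r = block_tridiagonal_matvec(lower, diag, upper, x) - rhs
    return jnp.linalg.norm(r) / jnp.linalg.norm(rhs)
```

If the layout matches, `relative_residual(lower, diag, upper, x, rhs)` after the snippet above should be on the order of float32 round-off for a well-conditioned system. A large value points to a layout mismatch or an ill-conditioned system. Applying the operator blockwise costs `O(n·m²)` work and never forms the dense `(n·m) × (n·m)` matrix.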

![Quickstart](examples/quickstart/solution.png)


## It's fast!

![Timing comparison](examples/timing/timing.png)

![Speedup](examples/timing/speedup.png)


## Links

* [Code repository](https://github.com/pmocz/blocktrix)
* [Documentation](https://github.com/pmocz/blocktrix)


## Cite this repository

If you use this software, please cite it as below.

XXX
