Metadata-Version: 2.4
Name: jact
Version: 0.1.1
Summary: JAX-based transition probability computation for multi-state models
Author-email: Lucas Støjko Andersen <93032255+stojl@users.noreply.github.com>
License-Expression: Apache-2.0
Project-URL: Homepage, https://github.com/stojl/jact
Project-URL: Repository, https://github.com/stojl/jact
Project-URL: Issues, https://github.com/stojl/jact/issues
Project-URL: Documentation, https://github.com/stojl/jact/blob/main/docs/index.md
Keywords: jax,markov-model,multi-state-model,semi-markov,transition-probabilities
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Science/Research
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Scientific/Engineering
Classifier: Topic :: Scientific/Engineering :: Mathematics
Requires-Python: >=3.10
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: jax>=0.4
Requires-Dist: jaxlib>=0.4
Provides-Extra: dev
Requires-Dist: build>=1.2.0; extra == "dev"
Requires-Dist: pytest>=7.0; extra == "dev"
Requires-Dist: pyright>=1.1.0; extra == "dev"
Requires-Dist: ruff>=0.4.0; extra == "dev"
Requires-Dist: twine>=5.0.0; extra == "dev"
Provides-Extra: notebook
Requires-Dist: matplotlib>=3.8; extra == "notebook"
Dynamic: license-file

# jact

JAX-based transition probability computation for multi-state models with duration-dependent transition intensities.

## What is jact?

`jact` computes transition probabilities in semi-Markov multi-state models. It takes fitted intensity models (parametric functions, GLMs, neural networks, or any other JIT-compatible callable) and produces transition probabilities for thousands of individuals in a single vectorized pass, with computations optimized for JIT-compiled execution on GPU.

## Quick example

```python
import jax.numpy as jnp
import jact

# Define the state space
state_space = jact.StateSpace(
    states=["healthy", "disabled", "dead"],
    transitions=[
        ("healthy", "disabled"),
        ("healthy", "dead"),
        ("disabled", "dead"),
    ],
)

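# Build a model with intensity functions.
# onset_fn, mortality_fn, and disabled_mort_fn are user-supplied,
# JIT-compatible intensity callables; they are not defined in this snippet.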
model = state_space.build(
    transitions={
        ("healthy", "disabled"): onset_fn,
        ("healthy", "dead"): mortality_fn,
        ("disabled", "dead"): disabled_mort_fn,
    }
)

# Compute transition probabilities for 1000 individuals
ages = jnp.linspace(30, 80, 1_000)
result = model.solve(initial="healthy", horizon=30, steps_per_unit=12, age=ages)
```
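
The quick example leaves `onset_fn`, `mortality_fn`, and `disabled_mort_fn` undefined. As a rough idea of what a JIT-compatible intensity callable could look like, here is a minimal Gompertz sketch in plain JAX. The parameter values are made up, and the single-argument `(age)` signature is an assumption for illustration only; the signature jact actually expects is described in the API specification.

```python
import jax
import jax.numpy as jnp

# Hypothetical Gompertz parameters, chosen only for illustration.
A, B = 5e-5, 0.09


@jax.jit
def gompertz_intensity(age):
    """Gompertz hazard A * exp(B * age).

    The (age) signature is an assumption, not jact's documented contract.
    """
    return A * jnp.exp(B * age)


# Evaluate the intensity for a whole batch of ages in one vectorized call.
ages = jnp.linspace(30.0, 80.0, 1_000)
rates = gompertz_intensity(ages)
```

Because the function uses only `jax.numpy` operations, it can be traced by `jax.jit` and applied to a batch of individuals without a Python loop.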

## Key features

- **Plug in any model**: Gompertz, GLM, neural network — anything that's JIT-compatible.
- **Swap and compare**: Same `StateSpace`, different intensity models. Experiment easily.
- **Compute only what's needed**: The solver reduces to states reachable from the initial state.
- **Exact seeded starts**: Initial point masses preserve per-individual starting duration `d_0` exactly.
- **Batch-first**: Designed for 100K+ individuals in a single pass.
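
The reachability reduction mentioned in the list above can be pictured as a graph search over the declared transitions. The sketch below is illustrative only: `reachable_states` is a hypothetical helper, not part of jact's API, and jact's own reduction may be implemented differently.

```python
from collections import deque


def reachable_states(transitions, initial):
    """Return the set of states reachable from `initial` (hypothetical helper)."""
    # Build an adjacency list from the directed (source, target) pairs.
    successors = {}
    for src, dst in transitions:
        successors.setdefault(src, []).append(dst)

    # Breadth-first search starting at the initial state.
    seen = {initial}
    queue = deque([initial])
    while queue:
        state = queue.popleft()
        for nxt in successors.get(state, []):
            if nxt not in seen:
                seen.add(nxt)
                queue.append(nxt)
    return seen


transitions = [
    ("healthy", "disabled"),
    ("healthy", "dead"),
    ("disabled", "dead"),
]

# Starting from "disabled", the "healthy" state can never be visited,
# so a solver would only need to track "disabled" and "dead".
print(sorted(reachable_states(transitions, "disabled")))
```

With the three-state model from the quick example, starting in `"healthy"` keeps all three states, while starting in `"disabled"` leaves only two.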

## Documentation

See the [documentation index](https://github.com/stojl/jact/blob/main/docs/index.md)
for the public documentation set.
For the full API contract, use the
[API specification](https://github.com/stojl/jact/blob/main/docs/api_spec.md).
For a runnable walkthrough of the main workflow, see the
[example notebook](https://github.com/stojl/jact/blob/main/docs/example_notebook.ipynb).

## Namespace

The recommended user API is the top-level `jact` namespace:
`jact.StateSpace`, `jact.InitialDistribution`, `jact.solve`, and the
cashflow declarations such as `jact.StateRate` and `jact.Total`.
Advanced callback state objects remain available from submodules, for
example `jact.callbacks.PointMass` and `jact.model.ReducedModel`.

## Installation

```bash
pip install jax jaxlib
pip install jact
```

For local development from this repository:

```bash
pip install -e '.[dev]'
pytest
```

The package uses a `src/` layout, so an editable install is the intended local
workflow.

To run the example notebook with plotting support from a local checkout:

```bash
pip install -e '.[dev,notebook]'
```

## Release checks

Before cutting a PyPI release:

```bash
rm -rf build dist src/*.egg-info
python -m build --no-isolation
python -m twine check dist/*
pytest tests/test_solver.py tests/test_initial_distribution_integration.py -q
```

The tag-driven publish flow is documented in [RELEASING.md](RELEASING.md).

## Requirements

- Python >= 3.10
- JAX >= 0.4
- jaxlib >= 0.4

## License

Apache-2.0
