Metadata-Version: 2.4
Name: vbjax_dynamics
Version: 0.1.0
Summary: JAX-based integrators for ODEs, DDEs, SDEs, and SDDEs
Author-email: Abolfazl Ziaeemehr <a.ziaeemehr@gmail.com>
License: MIT
Project-URL: Homepage, https://github.com/ziaeemehr/vbjax_dynamics
Project-URL: Documentation, https://vbjax-dynamics.readthedocs.io
Project-URL: Repository, https://github.com/ziaeemehr/vbjax_dynamics
Keywords: jax,ode,dde,sde,sdde,integrator,dynamics
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Topic :: Scientific/Engineering :: Mathematics
Requires-Python: >=3.9
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: jax>=0.4.0
Requires-Dist: jaxlib>=0.4.0
Requires-Dist: numpy>=1.20.0
Provides-Extra: dev
Requires-Dist: pytest>=7.0; extra == "dev"
Requires-Dist: pytest-cov>=4.0; extra == "dev"
Requires-Dist: black>=23.0; extra == "dev"
Requires-Dist: flake8>=6.0; extra == "dev"
Requires-Dist: mypy>=1.0; extra == "dev"
Requires-Dist: scipy>=1.9.0; extra == "dev"
Requires-Dist: matplotlib>=3.5.0; extra == "dev"
Provides-Extra: docs
Requires-Dist: sphinx>=5.0; extra == "docs"
Requires-Dist: sphinx-rtd-theme>=1.0; extra == "docs"
Dynamic: license-file

# vbjax_dynamics

A JAX-based library for numerical integration of dynamical systems.

**Note:** This package contains code adapted from [vbjax](https://github.com/ins-amu/vbjax) by INS-AMU. The core integration functions in `loops.py` are derived from vbjax's implementation.

## Features

- **ODE Integration**: Ordinary Differential Equations
  - Efficient loop-based integrators with JIT compilation
  - Full support for `jax.vmap` for parallel trajectory computation
  
- **SDE Integration**: Stochastic Differential Equations  
  - `make_sde()`: Integration with pre-generated noise arrays
  - `make_sde_auto()`: Automatic noise generation from random keys
  - Euler-Maruyama scheme
  - Reproducible: the same PRNG key gives the same trajectory

- **DDE Integration**: Delay Differential Equations
  - Support for fixed delays
  - History function interpolation

- **SDDE Integration**: Stochastic Delay Differential Equations
  - Combined stochastic and delay dynamics

- **Continuation Methods**: Parameter continuation for bifurcation analysis

- **Configuration Utilities**: Helpers for inspecting and adjusting JAX settings
  - `configure_jax()`: Global configuration
  - `precision_context()`: Temporary precision changes
  - `print_jax_config()`: Diagnostic information

- **JAX-Native**: 
  - JIT compilation for speed
  - Automatic differentiation ready
  - GPU/TPU compatible
  - Pure functional approach
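
The fixed-delay case can be illustrated with a minimal plain-JAX sketch of forward Euler for a DDE. It is not the implementation in `loops.py`. The name `make_dde_euler` is hypothetical, the history is assumed constant, and the delay `tau` is assumed to be an integer multiple of `dt`, so the delayed state is read straight from a buffer and no history interpolation is needed. The buffer is threaded through `jax.lax.scan` as the loop carry, which keeps the whole loop JIT-compatible.

```python
# Illustrative sketch with a hypothetical helper, not the vbjax_dynamics API.
import jax
import jax.numpy as jnp


def make_dde_euler(rhs, dt, tau):
    """Hypothetical forward-Euler loop for dx/dt = rhs(x(t), x(t - tau), p).

    Assumes a constant history x(t) = x_hist for t <= 0 and tau = lag * dt for an
    integer lag, so no interpolation between grid points is needed.
    """
    lag = int(round(tau / dt))  # delay measured in time steps

    def loop(x_hist, n_steps, p):
        # buf[0] holds x(t - tau) and buf[-1] holds x(t); start from the constant history
        buf0 = jnp.full((lag + 1,), x_hist, dtype=jnp.float32)

        def step(buf, _):
            x_now, x_del = buf[-1], buf[0]
            x_next = x_now + dt * rhs(x_now, x_del, p)
            # drop the oldest state and append the newest one
            buf = jnp.concatenate([buf[1:], x_next[None]])
            return buf, x_next

        _, xs = jax.lax.scan(step, buf0, None, length=n_steps)
        return xs  # states at t = dt, 2*dt, ..., n_steps*dt

    # n_steps fixes the scan length, so it must be static under jit
    return jax.jit(loop, static_argnums=1)


# Delayed negative feedback dx/dt = -p * x(t - 1) with history x = 1
loop = make_dde_euler(lambda x, x_del, p: -p * x_del, dt=0.01, tau=1.0)
xs = loop(1.0, 500, 1.0)
```

For this example the exact solution on [0, 1] is x(t) = 1 - t, and Euler reproduces it up to rounding because the delayed term is constant on that interval. Copying the buffer with `jnp.concatenate` costs O(lag) per step; a circular write index avoids the copy for long delays.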

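Parameter continuation can be sketched in the same spirit. The simplest variant, natural-parameter continuation by simulation, sweeps a parameter and starts each run from the final state of the previous one, so the solution follows a single branch. The code below is that technique written in plain JAX, not the package's continuation API; the pitchfork normal form dx/dt = mu*x - x**3 and the helper names `settle` and `continuation` are illustrative assumptions.

```python
# Illustrative sketch with hypothetical helpers, not the vbjax_dynamics API.
import jax
import jax.numpy as jnp


def rhs(x, mu):
    # pitchfork normal form: stable branches x = +/-sqrt(mu) for mu > 0
    return mu * x - x**3


def settle(x0, mu, dt=0.01, n_steps=5000):
    """Run forward Euler for n_steps (t = n_steps * dt) and return the final state."""
    def step(x, _):
        return x + dt * rhs(x, mu), None

    x_final, _ = jax.lax.scan(step, x0, None, length=n_steps)
    return x_final


@jax.jit
def continuation(x0, mus):
    """Natural-parameter continuation: each mu starts from the previous end state."""
    def body(x, mu):
        x_end = settle(x, mu)
        return x_end, x_end  # carry the state forward and record it

    x_init = jnp.asarray(x0, dtype=mus.dtype)
    _, branch = jax.lax.scan(body, x_init, mus)
    return branch


mus = jnp.linspace(1.0, 0.1, 10)  # sweep mu downward along the upper branch
branch = continuation(1.0, mus)   # approximately sqrt(mus)
```

Sweeping mu downward from 1 keeps the state on the upper branch x = sqrt(mu); continuing past mu = 0 would let it decay to the trivial branch x = 0. Following a branch around a fold point needs pseudo-arclength continuation, which this sketch does not attempt.
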
## Installation

```bash
pip install vbjax_dynamics
```

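For development, install from a clone of the repository with the `dev` extras (pytest, pytest-cov, black, flake8, mypy, scipy, matplotlib):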
```bash
pip install -e ".[dev]"
```

## Quick Start

```python
import jax.numpy as jnp
from jax import random, vmap
from vbjax_dynamics.loops import make_sde_auto

# Ornstein-Uhlenbeck process: dx = -theta * x * dt + sigma * dW
def drift(x, p):
    return -p[0] * x  # -theta * x

def diffusion(x, p):
    return p[1]  # sigma

# Create integrator
dt = 0.01
step, loop = make_sde_auto(dt, drift, diffusion)

# Single trajectory
x0 = 2.0
params = (1.0, 0.5)  # (theta, sigma)
n_steps = 1000
key = random.PRNGKey(42)

trajectory = loop(x0, n_steps, params, key)
print(f"Final value: {trajectory[-1]:.4f}")

# Multiple trajectories in parallel with vmap
n_traj = 100
keys = random.split(key, n_traj)
trajectories = vmap(lambda k: loop(x0, n_steps, params, k))(keys)
print(f"Mean: {jnp.mean(trajectories[:, -1]):.4f}")
```
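
The Euler-Maruyama scheme listed under Features updates the state as `x[n+1] = x[n] + f(x[n], p) * dt + g(x[n], p) * sqrt(dt) * z[n]`, with each `z[n]` drawn from a standard normal distribution. The following plain-JAX sketch applies that update to the Ornstein-Uhlenbeck example above, which can be handy for cross-checking. It is an illustration under assumptions rather than the code in `loops.py`: `euler_maruyama` is a hypothetical name, the state is scalar, and all increments are drawn up front from one key, so individual paths will not match `make_sde_auto` sample for sample.

```python
# Illustrative sketch with a hypothetical helper, not the vbjax_dynamics API.
import jax
import jax.numpy as jnp
from jax import random


def drift(x, p):
    return -p[0] * x  # -theta * x


def diffusion(x, p):
    return p[1]  # sigma


def euler_maruyama(x0, n_steps, p, key, dt=0.01):
    """Hypothetical reference: scalar Euler-Maruyama, returns the n_steps states after x0."""
    z = random.normal(key, (n_steps,))  # one standard normal increment per step

    def step(x, z_n):
        x_next = x + dt * drift(x, p) + jnp.sqrt(dt) * diffusion(x, p) * z_n
        return x_next, x_next

    _, xs = jax.lax.scan(step, jnp.asarray(x0, dtype=z.dtype), z)
    return xs


# Ensemble of 100 paths, one key per path
keys = random.split(random.PRNGKey(42), 100)
trajs = jax.vmap(lambda k: euler_maruyama(2.0, 1000, (1.0, 0.5), k))(keys)
print(float(jnp.mean(trajs[:, -1])))
```

Because the noise is drawn differently, compare ensemble statistics instead of single paths, for example the stationary variance `sigma**2 / (2 * theta)` of the Ornstein-Uhlenbeck process.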

For more examples, see the [`examples/`](examples/) directory.
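
For the deterministic (ODE) case the same loop-plus-`vmap` pattern applies. The sketch below is a plain-JAX Heun integrator compiled with `jax.jit` and vmapped over a batch of initial conditions; Heun's method and the names `heun_loop` and `decay` are illustrative choices, not necessarily the scheme or API used in `loops.py`.

```python
# Illustrative sketch with hypothetical helpers, not the vbjax_dynamics API.
import jax
import jax.numpy as jnp


def decay(x, p):
    # linear decay dx/dt = -p * x, exact solution x0 * exp(-p * t)
    return -p * x


def heun_loop(f, dt):
    """Return a jitted loop that integrates dx/dt = f(x, p) with Heun's method."""
    def step(x, p):
        k1 = f(x, p)
        k2 = f(x + dt * k1, p)  # slope at the explicit Euler predictor
        return x + 0.5 * dt * (k1 + k2)

    def loop(x0, n_steps, p):
        def body(x, _):
            x_next = step(x, p)
            return x_next, x_next

        _, xs = jax.lax.scan(body, x0, None, length=n_steps)
        return xs  # states at t = dt, ..., n_steps * dt

    return jax.jit(loop, static_argnums=1)  # n_steps sets the scan length


loop = heun_loop(decay, dt=0.01)
x0s = jnp.linspace(0.5, 2.0, 4)  # batch of initial conditions
# vmap over x0 only; n_steps and p are shared by all trajectories
trajs = jax.vmap(lambda x0: loop(x0, 100, 1.0))(x0s)  # shape (4, 100)
```

Heun's method is second order, so halving `dt` should reduce the final-time error by roughly a factor of four.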

## Documentation

- **[Examples and Tutorials](examples/README.md)**: Complete guide with detailed examples for ODE, SDE, DDE, and SDDE integration
- **[Testing Guide](tests/README.md)**: Information about running and writing tests
- **[Acknowledgments](ACKNOWLEDGMENTS.md)**: Credits and attribution

## Acknowledgments

This package includes code adapted from [vbjax](https://github.com/ins-amu/vbjax), developed by the Institut de Neurosciences de la Timone (INS-AMU). We are grateful for their work on efficient JAX-based numerical integrators.

## License

MIT License
