Metadata-Version: 2.3
Name: dnax
Version: 2024.0.0a1
Summary: DNA models implemented in JAX
Project-URL: Repository, https://github.com/gtca/dnax.git
Project-URL: Bug Tracker, https://github.com/gtca/dnax/issues
Author-email: Danila Bredikhin <danila@stanford.edu>
Maintainer-email: Danila Bredikhin <danila@stanford.edu>
License: BSD3
Keywords: ATAC,CNN,DNA,chromatin
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: BSD License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Scientific/Engineering :: Bio-Informatics
Requires-Python: >=3.10
Requires-Dist: flax
Requires-Dist: jax
Requires-Dist: optax
Provides-Extra: test
Requires-Dist: pytest; extra == 'test'
Requires-Dist: tensorflow; extra == 'test'
Requires-Dist: tensorflow-probability; extra == 'test'
Requires-Dist: torch; extra == 'test'
Description-Content-Type: text/markdown

<img src="./docs/img/logo.png" width="150"/>

# dnax — DNA models in JAX

>[!WARNING]
>This is an experimental implementation.

`dnax` provides JAX-based implementations of models such as [BPNet](https://github.com/kundajelab/bpnet), [ChromBPNet](https://github.com/kundajelab/chrombpnet), and [DragoNNFruit](https://github.com/jmschrei/dragonnfruit).
The code is heavily based on the original implementations ([chrombpnet](https://github.com/kundajelab/chrombpnet), [bpnet-lite](https://github.com/jmschrei/bpnet-lite))
but aims to be more readable, accessible, and maintainable.

## Installation

```bash
pip install dnax

# or install the latest version from GitHub

pip install git+https://github.com/gtca/dnax.git
```

## Usage

Vanilla ChromBPNet:

```python
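from dnax.models.bpnet import BPNet  # module path assumed
from dnax.models.chrombpnet import ChromBPNet

# Two BPNet-style networks: one models enzyme (Tn5/DNase) sequence bias,
# the other models the bias-corrected accessibility signal
bias = BPNet(n_filters=512, n_layers=8)
accessibility = BPNet(n_filters=512, n_layers=8)

model = ChromBPNet(bias, accessibility)

x = ...  # one-hot encoded sequences, shape (batch, 2114, 4)
profile, counts = model(x)  # base-resolution profile and total counts predictions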
```
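The `x = ...` placeholder stands for one-hot encoded DNA. A minimal sketch of producing such an array with JAX is shown below. The helper name `one_hot_encode` is hypothetical (it is not part of `dnax`), and the sketch assumes the A, C, G, T channel order used by the original ChromBPNet code, with ambiguous bases such as `N` encoded as all-zero rows.

```python
import jax
import jax.numpy as jnp

# Assumed channel order (A, C, G, T), as in the original ChromBPNet/BPNet code.
BASES = "ACGT"
_LOOKUP = {base: i for i, base in enumerate(BASES)}


def one_hot_encode(seqs: list[str]) -> jax.Array:
    """Hypothetical helper: one-hot encode equal-length DNA strings.

    Returns a float32 array of shape (batch, length, 4). Any character other
    than A/C/G/T (e.g. N) becomes an all-zero row.
    """
    # Map bases to indices 0-3; everything else goes to an extra index 4.
    idx = jnp.array(
        [[_LOOKUP.get(base, 4) for base in seq.upper()] for seq in seqs],
        dtype=jnp.int32,
    )
    # Encode with 5 classes, then drop the extra "unknown base" column.
    return jax.nn.one_hot(idx, num_classes=5, dtype=jnp.float32)[..., :4]
```

For the models above, each input sequence would need to be 2114 bp long, so `one_hot_encode(sequences)` yields an array of shape `(batch, 2114, 4)`.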

For inference, you can load existing ChromBPNet models:

```python
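from dnax.io import load_chrombpnet_model

bias_file = ...           # path to a trained bias model
accessibility_file = ...  # path to a trained accessibility model

model = load_chrombpnet_model(bias_file, accessibility_file)

profile, counts = model(x)  # x: one-hot sequences, shape (batch, 2114, 4)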
```
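Both examples return a `profile` and a `counts` output. In the original ChromBPNet, the bias and accessibility outputs are merged by adding the profile logits and taking the log-sum-exp of the log-counts; per-base predictions are then the softmax of the profile logits scaled by the total count. The sketch below reproduces that arithmetic in JAX, assuming `dnax` follows the same convention (profile as logits of shape `(batch, length)`, counts as natural-log totals of shape `(batch, 1)`). Both function names are hypothetical and not part of the `dnax` API.

```python
import jax
import jax.numpy as jnp


def combine_outputs(bias_profile, bias_counts, acc_profile, acc_counts):
    """Hypothetical: merge bias and accessibility outputs ChromBPNet-style.

    Profile logits are added; log-counts combine as log(exp(a) + exp(b)).
    """
    profile = bias_profile + acc_profile
    counts = jnp.logaddexp(bias_counts, acc_counts)
    return profile, counts


def expected_counts(profile_logits, log_counts):
    """Hypothetical: per-base expected counts from the two outputs.

    profile_logits: (batch, length); log_counts: (batch, 1), natural log.
    """
    probs = jax.nn.softmax(profile_logits, axis=-1)  # sums to 1 along length
    return probs * jnp.exp(log_counts)  # broadcast the total over positions
```

If a counts head was trained on `log(1 + total)` rather than `log(total)`, `jnp.expm1` would replace `jnp.exp` when recovering counts.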

## Implementation

`dnax` currently uses `flax` and its [NNX API](https://flax.readthedocs.io/en/latest/guides/linen_to_nnx.html).

