Metadata-Version: 2.2
Name: hf_ehr
Version: 0.1.3
Summary: Code for Context Clues paper
Author: Michael Wornow
Project-URL: Homepage, https://github.com/som-shahlab/long_context_clues/
Requires-Python: >=3.10
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: polars==1.1
Requires-Dist: transformers==4.48.2
Requires-Dist: trl==0.10.1
Requires-Dist: datasets==2.14.5
Requires-Dist: pandas==2.2
Requires-Dist: tensorboard==2.14.1
Requires-Dist: lightning==2.2.1
Requires-Dist: numpy==1.26.0
Requires-Dist: pytorch-nlp==0.5.0
Requires-Dist: wandb==0.15.11
Requires-Dist: loguru==0.7.2
Requires-Dist: omegaconf==2.3.0
Requires-Dist: jaxtyping==0.2.22
Requires-Dist: mlflow==2.11.1
Requires-Dist: seaborn==0.13.2
Requires-Dist: matplotlib==3.9.2
Requires-Dist: hydra-core==1.3.2
Requires-Dist: lightgbm==3.3.5
Requires-Dist: scikit-learn==1.2.2
Requires-Dist: scipy==1.10.1
Requires-Dist: tqdm==4.65.0
Requires-Dist: tabulate==0.9.0
Requires-Dist: huggingface_hub==0.26.5
Requires-Dist: meds==0.3.3
Requires-Dist: meds_reader==0.1.9
Requires-Dist: meds_etl==0.3.8
Requires-Dist: femr==0.1.16
Requires-Dist: torch

# Training Long Context Models on EHR Data

This repo contains code and pretrained models for the [**Context Clues paper**](https://arxiv.org/abs/2412.16178). It is designed to enable **training any model on HuggingFace on structured EHR data.** It comes with Hydra configs + Wandb logging + PyTorch Lightning distributed training support.

It currently supports EHR data defined using the [**MEDS data standard**](https://github.com/Medical-Event-Data-Standard/) or [**FEMR package**](https://github.com/som-shahlab/femr).

### 📖 Table of Contents

1. 🤗 [Pretrained HuggingFace Models](#models)
1. 📀 [Installation](#installation)
1. 🚀 [Quick Start](#quick_start)
1. 🏋️‍♀️ [Training](#training)
1. 📊 [Evaluation](#evaluation)
1. 💊 [MEDS Demo](#meds_demo)
1. ℹ️ [Other](#other)
1. 🎓 [Citation](#citation)

<a name="models" />

## 🤗 Pretrained HuggingFace Models

Please see our [HuggingFace Collection](https://huggingface.co/collections/StanfordShahLab/context-clues-models-6757f893f6a2918c7ab809f1) to download the following models pretrained from scratch on 2 billion tokens of deidentified structured EHR data:

| Model | Context Lengths |
| ----- | ------------- |
| gpt | [512](https://huggingface.co/StanfordShahLab/gpt-base-512-clmbr), [1024](https://huggingface.co/StanfordShahLab/gpt-base-1024-clmbr), [2048](https://huggingface.co/StanfordShahLab/gpt-base-2048-clmbr), [4096](https://huggingface.co/StanfordShahLab/gpt-base-4096-clmbr) |
| llama | [512](https://huggingface.co/StanfordShahLab/llama-base-512-clmbr), [1024](https://huggingface.co/StanfordShahLab/llama-base-1024-clmbr), [2048](https://huggingface.co/StanfordShahLab/llama-base-2048-clmbr), [4096](https://huggingface.co/StanfordShahLab/llama-base-4096-clmbr) |
| mamba | [1024](https://huggingface.co/StanfordShahLab/mamba-tiny-1024-clmbr), [4096](https://huggingface.co/StanfordShahLab/mamba-tiny-4096-clmbr), [8192](https://huggingface.co/StanfordShahLab/mamba-tiny-8192-clmbr), [16384](https://huggingface.co/StanfordShahLab/mamba-tiny-16384-clmbr) |
| hyena | [1024](https://huggingface.co/StanfordShahLab/hyena-large-1024-clmbr), [4096](https://huggingface.co/StanfordShahLab/hyena-large-4096-clmbr), [8192](https://huggingface.co/StanfordShahLab/hyena-large-8192-clmbr), [16384](https://huggingface.co/StanfordShahLab/hyena-large-16384-clmbr) |

<a name="installation" />

## 📀 Installation

Direct install:
```bash
pip install hf-ehr
```

For faster Mamba runs, install:
```bash
pip install mamba-ssm causal-conv1d
```

Development install:
```bash
conda create -n hf_env python=3.10 -y
conda activate hf_env
pip install -r requirements.txt --no-cache-dir
pip install -e .

# [Optional] If you haven't already created your **Tokenizers**, run the following. If you're on Carina, then skip this step.
cd hf_ehr/scripts/tokenizers
sbatch clmbr.sh # Takes ~5 seconds
sbatch desc.sh # Takes ~30 min
sbatch cookbook.sh # Takes many hours
```

<a name="quick_start"/>

## 🚀 Quick Start

Launch a GPT training run with the ability to configure common hyperparameters:

```bash
cd hf_ehr/scripts/carina
python3 main.py --model gpt2 --size base --tokenizer clmbr --context_length 1024 --dataloader approx --dataset v8 --is_run_local --is_force_refresh
```

Launch a Llama run on a MEDS dataset:
```bash
cd hf_ehr/scripts/carina
python3 main.py --model llama --size base --tokenizer clmbr --context_length 1024 --dataloader approx --dataset meds_mimic4_demo --is_run_local --is_force_refresh
```

To launch 4 GPT-base runs on one SLURM node (in parallel), and 4 Mamba runs on another SLURM node (in parallel):

```bash
cd hf_ehr/scripts/carina

# GPT runs
sbatch parallel_gpt.sh

# Mamba runs
sbatch parallel_mamba.sh
```

<a name="training" />

## 🏋️‍♀️ Training

We use [Hydra](https://github.com/facebookresearch/hydra) to manage our configurations and [PyTorch Lightning](https://github.com/Lightning-AI/pytorch-lightning) for training. 

You can either overwrite the config files in `configs/` or pass in CLI arguments to override the defaults.

There are 3 ways to launch a training run. 

### Easy Mode

Launch multiple runs in parallel on the same SLURM node  (each job gets 1 GPU) using `hf_ehr/scripts/carina/parallel_{model}.sh`:

```bash
cd hf_ehr/scripts/carina

# Launch 4 gpt runs in parallel on the same node. See the file for the specific model versions run.
sbatch parallel_gpt.sh

# Launch 4 bert runs in parallel on the same node. See the file for the specific model versions run.
sbatch parallel_bert.sh

# Launch 4 hyena runs in parallel on the same node. See the file for the specific model versions run.
sbatch parallel_hyena.sh

# Launch 4 mamba runs in parallel on the same node. See the file for the specific model versions run.
sbatch parallel_mamba.sh
```

### Medium Mode

Launch one run on a SLURM node using `hf_ehr/scripts/carina/{model}.sh`:

```bash
cd hf_ehr/scripts/carina

# Launch GPT-2 base model on v8 dataset with CLMBRTokenizer, ApproxBatchSampler dataloader, and 2048 context length; force train from scratch and not resume prior run (even if exists)
python3 main.py --model gpt2 --size base --tokenizer clmbr --context_length 2048 --dataloader approx --dataset v8 --is_force_refresh

# Launch Mamba tiny model on v8 dataset with CookbookTokenizer, ApproxBatchSampler dataloader, and 16384 context length; resume prior run if exists
python3 main.py --model mamba --size tiny --tokenizer cookbook --context_length 16384 --dataloader approx --dataset v8

# Launch BERT-base model on v8 dataset with DescTokenizer, ApproxBatchSampler dataloader, and 4096 context length; resume prior run if exists; overwrite the default device assignment to GPU 1; give wandb run a name of `custom`
python3 main.py --model bert --size base --tokenizer desc --context_length 4096 --dataloader approx --dataset v8 --extra "+trainer.devices=[1] logging.wandb.name=custom"

# Run locally a GPT-2 large model on v8 AllTokens dataset with CLMBRTokenizer, ApproxBatchSampler dataloader, and 1024 context length
python3 main.py --model gpt2 --size large --tokenizer clmbr --context_length 2048 --dataloader approx --dataset v8-alltokens --is_run_local

# Launch Mamba tiny model on v8 dataset with CookbookTokenizer, ApproxBatchSampler dataloader, and 16384 context length; resume prior run if exists; run on 8 H100's
python3 main.py --model mamba --size tiny --tokenizer cookbook --context_length 16384 --dataloader approx --dataset v8 --partitions nigam-h100 --extra "trainer=multi_gpu trainer.devices=[0,1,2,3,4,5,6,7]"
```

General usage:
```bash
python3 main.py --model <model> --size <size> --tokenizer <tokenizer> --context_length <context_length> --dataloader <dataloader> --dataset <dataset> [--extra <extra>] [--partitions <partitions>] [--is_force_refresh] [--is_skip_base] [--is_run_local]
```

where...
- `<model>`: str -- Architecture to use. Choices are `gpt`, `bert`, `hyena`, `mamba`
- `<size>`: str -- Model size to use. Choices are `tiny`, `small`, `base`, `medium`, `large`, `huge`
- `<tokenizer>`: str -- Tokenizer to use. Choices are `clmbr`, `desc`, `cookbook`
- `<context_length>`: int -- Context length to use
- `<dataloader>`: str -- Dataloader to use. Choices are `approx`, `exact`
- `<dataset>`: str -- Dataset to use. Choices are `v8`, `v8-alltokens`, `v9`, `v9-alltokens`
- `[--extra <extra>]`: Optional[str] -- An optional string that will get appended to the end of the `python ../run.py` command verbatim
- `[--partitions <partitions>]`: Optional[str] -- An optional string that specifies the partitions to use. Defaults to `nigam-v100,gpu` for gpt2 and BERT, and `nigam-h100,nigam-a100` for HYENA and MAMBA
- `[--is_force_refresh]`: Optional -- An optional flag that triggers a force refresh of the run (i.e., delete the existing run and start from scratch)
- `[--is_skip_base]`: Optional -- An optional flag that skips running `source base.sh`. Useful when running `parallel.sh` and we don't want to reinit the conda environment multiple times
- `[--is_run_local]`: Optional -- An optional flag that runs the script locally as `python run.py` instead of as a SLURM `sbatch` command

### Advanced Mode

Directly call `run.py`, which allows maximum flexibility for configs. 

See the [Config README](hf_ehr/configs/README.md) for details on all config settings.

```bash
cd hf_ehr/scripts/carina

# Launch gpt with: size=base, dataset=v8, context_length=2048, tokenizer=CLMBRTokenizer, sampler=ApproxBatchSampler, max_tokens_per_batch=16384, use_cuda_devices=2,3, wandb_logging_name=gpt2-custom-run, force_restart_existing_run=True, save_to_path=/share/pi/nigam/mwornow/hf_ehr/cache/runs/bert-test/
python3 ../run.py \
    +data=v8 \
    +trainer=single_gpu \
    +model=gpt2-base \
    +tokenizer=clmbr \
    data.dataloader.mode=approx \
    data.dataloader.approx_batch_sampler.max_tokens=16384 \
    data.dataloader.max_length=2048 \
    model.config_kwargs.n_positions=2048 \
    trainer.devices=[2,3] \
    logging.wandb.name=gpt2-custom-run \
    main.is_force_restart=True \
    main.path_to_output_dir=/share/pi/nigam/mwornow/hf_ehr/cache/runs/bert-test/
```

### How to Configure Runs

See the [Config README](hf_ehr/configs/README.md) for details on all config settings (models, training, dataloaders, tokenizers, etc.).


<a name="evaluation"/>
    
## 📊 Evaluation

### EHRSHOT

How to use this repo with EHRSHOT.

#### 1. Generate Patient Representations
This all occurs within the `hf_ehr` repo.

1. Identify the path (`<path_to_ckpt>`) to the model checkpoint you want to evaluate.

2. Generate patient representations with your model. This will create a folder in `/share/pi/nigam/mwornow/ehrshot-benchmark/EHRSHOT_ASSETS/models` for this model checkpoint.

```bash
cd hf_ehr/scripts/eval/
sbatch ehrshot.sh <path_to_ckpt>
```

#### 2. Generate EHRSHOT Results

This all occurs within the `ehrshot-benchmark` repo.

1. Generate your model's AUROC/AUPRC results by running `7_eval.sh`:

```bash
# cd to ehrshot-benchmark/ehrshot/bash_scripts/ directory
bash 7_eval.sh --is_use_slurm
```

#### 3. Generate EHRSHOT Plots

This all occurs within the `ehrshot-benchmark` repo.

1. Generate plots by running: `8_make_results_plots.sh`. You might need to modify the `--model_heads` parameter in the file before running to specify what gets included in your plots.

```bash
# cd to ehrshot-benchmark/ehrshot/bash_scripts/ directory
bash 8_make_results_plots.sh
```

<a name="meds_demo"/>

## 💊 MEDS Demo

We support training and inference on [MEDS formatted datasets](https://github.com/Medical-Event-Data-Standard/meds/). 

Here is a quick tutorial using the publicly available **MIMIC-IV demo dataset** (inspired by [this tutorial](https://colab.research.google.com/drive/1R1LrDIzhQyWldQWM0lyfjeF_n9I_iZT3)).

1. **Download** the [MIMIC-IV demo dataset](https://physionet.org/content/mimiciv-demo/1.4/) from PhysioNet.

```bash
export PATH_TO_DOWNLOAD=mimic4_demo
export PATH_TO_MEDS=meds_mimic4_demo
export PATH_TO_MEDS_READER=meds_mimic4_demo_reader

!wget -q -r -N -c --no-host-directories --cut-dirs=1 -np -P $PATH_TO_DOWNLOAD https://physionet.org/files/mimic-iv-demo/2.2/
```

2. **Convert** the MIMIC-IV demo dataset to [**MEDS format**](https://github.com/Medical-Event-Data-Standard/meds/).

```bash
rm -rf $PATH_TO_MEDS 2>/dev/null
meds_etl_mimic $PATH_TO_DOWNLOAD $PATH_TO_MEDS
```

3. **Convert** the MEDS dataset into a [**MEDS Reader Database**](https://github.com/som-shahlab/meds_reader) (to enable faster data ingestion during training).

```bash
rm -rf $PATH_TO_MEDS_READER 2>/dev/null
meds_reader_convert $PATH_TO_MEDS $PATH_TO_MEDS_READER --num_threads 4
```

4. **Verify** everything worked.

```bash
meds_reader_verify $PATH_TO_MEDS $PATH_TO_MEDS_READER
```

5. **Create train/val/test splits** (80/10/10) by running this Python script:
```python
import meds_reader
import polars as pl
import os

database = meds_reader.SubjectDatabase(os.environ["PATH_TO_MEDS_READER"])
subject_ids = list(database)
splits = [
    ('train' if idx < 80 else 'tuning' if idx < 90 else 'held_out', subject_ids[idx])
    for idx in range(len(subject_ids))
]
df = pl.DataFrame(splits, schema=["split", "subject_id"])
df.write_parquet(os.path.join(os.environ["PATH_TO_MEDS_READER"], 'metadata', 'subject_splits.parquet'))
```

6. **Create** a **Hydra config** for your dataset.

```bash
cp hf_ehr/configs/data/meds_mimic4_demo.yaml hf_ehr/configs/meds/meds_mimic4_demo_custom.yaml
sed -i 's|/share/pi/nigam/mwornow/mimic-iv-demo-meds-reader|$PATH_TO_MEDS_READER|g' hf_ehr/configs/meds/meds_mimic4_demo_custom.yaml
```

7. **Train** a **Llama model** on the dataset.

```bash
cd hf_ehr/scripts/carina
python3 main.py --model llama --size base --tokenizer clmbr --context_length 1024 --dataloader approx --dataset meds_mimic4_demo_custom --is_run_local --is_force_refresh
``` 


<a name="other" />

## ℹ️ Other

### Llama
The **llama** model checkpoints we saved only work with transformers 4.44.2.

### Based
To get the **based** model to run, you need to do the following installations on an A100 or above node:

```bash
pip install -v \
    --disable-pip-version-check \
    --no-cache-dir \
    --no-build-isolation \
    --config-settings "--build-option=--cpp_ext" \
    --config-settings "--build-option=--cuda_ext" \
    'git+https://github.com/NVIDIA/apex@b496d85'  --no-cache-dir

pip install --no-cache-dir \
    torch==2.1.2 \
    torchvision==0.16.2 \
    torchaudio==2.1.2 \
    --index-url https://download.pytorch.org/whl/cu118 --no-cache-dir

# Install FLA triton kernel
pip install -U git+https://github.com/sustcsonglin/flash-linear-attention

pip install 'git+https://github.com/HazyResearch/flash-attention@v2.5.2' --no-build-isolation --no-cache-dir
pip install 'git+https://github.com/HazyResearch/flash-attention@v2.5.2#subdirectory=csrc/fused_dense_lib'  --no-build-isolation --no-cache-dir
pip install 'git+https://github.com/HazyResearch/flash-attention@v2.5.2#subdirectory=csrc/layer_norm' --no-build-isolation --no-cache-dir

git clone git@github.com:HazyResearch/based.git
cd based
pip install -e . --no-cache-dir
```

### 🤖 Creating a Model

Let's say we want to create a new model called `{model}` of size `{size}`.

1. Create the Hydra config YAML for your model architecture in `hf_ehr/configs/architecture/{model}.yaml`. Copy the contents of `hf_ehr/configs/architecture/bert.yaml` and modify as needed. 

2. Create the Hydra config YAML for your model instantiation in `hf_ehr/configs/models/{model}-{size}.yaml`. Copy the contents of `hf_ehr/configs/models/bert-base.yaml` and modify as needed.

3. Create the model itself by creating a new file `hf_ehr/models/{model}.py`. Copy the contents of `models/bert.py` and modify as needed.

4. Add your model to `hf_ehr/scripts/run.py` above the line `raise ValueError(f"Model `{config.model.name}` not supported.")`

### ✂️ Creating a Tokenizer

See the [Tokenizer README](hf_ehr/tokenizers/README.md) for details on creating tokenizers and how they are stored on the file system.

### 🤗 Uploading a Model to Hugging Face

See the [Hugging Face README](hf_ehr/scripts/huggingface/README.md) for details on uploading models to Hugging Face.

<a name="citation" />

### 📦 Package for PyPi

```bash
make release
```

## 🎓 Citation

If you found this work useful, please consider citing it:

```
@article{wornow2024contextclues,
      title={Context Clues: Evaluating Long Context Models for Clinical Prediction Tasks on EHRs}, 
      author={Michael Wornow and Suhana Bedi and Miguel Angel Fuentes Hernandez and Ethan Steinberg and Jason Alan Fries and Christopher Ré and Sanmi Koyejo and Nigam H. Shah},
      year={2024},
      eprint={2412.16178},
      url={https://arxiv.org/abs/2412.16178}, 
}
```
