Metadata-Version: 2.4
Name: mushin-py
Version: 0.1.0
Summary: Boilerplate-free, reproducible ML experiment workflows built on PyTorch Lightning and hydra-zen. Carved out of MIT-LL's responsible-ai-toolbox.
Author: Massachusetts Institute of Technology (original rai_toolbox.mushin authors)
License: MIT
Project-URL: Upstream (origin), https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox
Keywords: machine learning,reproducibility,pytorch-lightning,hydra,hydra-zen
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Programming Language :: Python :: 3.14
Classifier: License :: OSI Approved :: MIT License
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.9
Description-Content-Type: text/markdown
License-File: LICENSE.txt
Requires-Dist: pytorch-lightning>=1.5.0
Requires-Dist: hydra-zen>=0.9.0
Requires-Dist: hydra-core>=1.2.0
Requires-Dist: torch<2.3,>=1.13.0; sys_platform == "darwin" and platform_machine == "x86_64"
Requires-Dist: torch>=1.13.0; sys_platform != "darwin" or platform_machine != "x86_64"
Requires-Dist: numpy<2,>=1.24; sys_platform == "darwin" and platform_machine == "x86_64"
Requires-Dist: numpy>=2; sys_platform != "darwin" or platform_machine != "x86_64"
Requires-Dist: omegaconf>=2.1.1
Requires-Dist: xarray>=0.19.0
Requires-Dist: typing-extensions>=4.1.0
Requires-Dist: torchmetrics>=1.0
Requires-Dist: scipy>=1.10
Requires-Dist: pandas>=1.5
Provides-Extra: viz
Requires-Dist: matplotlib>=3.3; extra == "viz"
Provides-Extra: netcdf
Requires-Dist: netCDF4>=1.5.8; extra == "netcdf"
Dynamic: license-file

<p align="center">
  <picture>
    <source media="(prefers-color-scheme: dark)" srcset="logos/mushin-dark.png">
    <img src="logos/mushin-light.png" alt="mushin logo" width="200">
  </picture>
</p>

<h1 align="center">mushin</h1>

[![CI](https://github.com/martinez-hub/mushin/actions/workflows/ci.yml/badge.svg)](https://github.com/martinez-hub/mushin/actions/workflows/ci.yml)
[![PyPI](https://img.shields.io/pypi/v/mushin-py.svg)](https://pypi.org/project/mushin-py/)
[![Python versions](https://img.shields.io/pypi/pyversions/mushin-py.svg)](https://pypi.org/project/mushin-py/)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](LICENSE.txt)

Boilerplate-free, reproducible machine-learning experiment workflows built on
[PyTorch Lightning](https://lightning.ai/) and
[hydra-zen](https://github.com/mit-ll-responsible-ai/hydra-zen).

`mushin` is a standalone carve-out of the `rai_toolbox.mushin` subpackage from
MIT Lincoln Laboratory's
[responsible-ai-toolbox](https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox).
The upstream toolbox is no longer maintained (last release May 2023), but the
`mushin` workflow layer still works against current versions of its
dependencies. This package extracts just that layer so it can be maintained and
used on its own.

## Quickstart: run a sweep, get a dataset

Define your experiment as a function, sweep over parameters, and get the results
back as a labeled `xarray.Dataset` — not rows in a dashboard you have to export.

```python
import torch as tr
from mushin import multirun
from mushin.workflows import MultiRunMetricsWorkflow

class LRSweep(MultiRunMetricsWorkflow):
    @staticmethod
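    def task(lr: float, seed: int) -> dict:
        tr.manual_seed(seed)
        # ... train a model with this lr/seed, then evaluate it ...
        # Stand-in so the snippet runs: replace with your model's real accuracy.
        acc = float(tr.rand(()))
        return dict(accuracy=acc)  # whatever you return becomes a data variable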

wf = LRSweep()
wf.run(lr=multirun([0.01, 0.1, 1.0]), seed=multirun([0, 1, 2]))  # 9 runs

ds = wf.to_xarray()
# <xarray.Dataset> Dimensions: (lr: 3, seed: 3)
#   Data variables: accuracy (lr, seed)

ds["accuracy"].mean("seed")   # average over seeds, per learning rate
```

The full runnable version is in [`examples/sweep_to_dataset.py`](examples/sweep_to_dataset.py):

```bash
uv run python examples/sweep_to_dataset.py
```
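Because `wf.to_xarray()` hands back an `xarray.Dataset`, the usual xarray tools apply directly to sweep results. The sketch below builds a dataset by hand with the same `(lr, seed)` layout as the quickstart, so it runs without training anything. The accuracy values are placeholders (0.0 to 0.8), not real results. The exact variables and coordinates mushin produces for your own task may differ.

```python
import numpy as np
import xarray as xr

# Hand-built stand-in for wf.to_xarray(): same dims/coords as the 3x3 sweep,
# but the accuracy values are placeholders (0.0 .. 0.8), not experimental results.
lrs = [0.01, 0.1, 1.0]
seeds = [0, 1, 2]
ds = xr.Dataset(
    data_vars={"accuracy": (("lr", "seed"), np.arange(9.0).reshape(3, 3) / 10)},
    coords={"lr": lrs, "seed": seeds},
)

# Reduce over a named dimension instead of an axis number.
mean_acc = ds["accuracy"].mean("seed")  # one value per learning rate
spread = ds["accuracy"].std("seed")     # seed-to-seed variability per learning rate

# Label-based lookup: the learning rate with the best mean accuracy.
best_lr = float(mean_acc.idxmax("lr"))

# Select a single run by coordinate value rather than by position.
one_run = float(ds["accuracy"].sel(lr=0.1, seed=2))

print(mean_acc.values, best_lr, one_run)
```

Reducing and selecting by name (`"seed"`, `lr=0.1`) rather than by axis position is the main payoff of getting results back as labeled data.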

## What it provides

- `BaseWorkflow`, `MultiRunMetricsWorkflow`, `RobustnessCurve` — declarative,
  reproducible experiment workflows that record configs, checkpoints, and
  metrics, and load results back as labeled `xarray` datasets.
- `MetricsCallback` — a Lightning callback for capturing metrics.
- `HydraDDP` — a Hydra/Lightning strategy for multi-GPU (DDP) launches.
- `multirun`, `hydra_list`, `load_experiment`, `load_from_checkpoint` — helpers.
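The sweep semantics behind `multirun` can be illustrated without Hydra. The sketch below is a hypothetical, pure-Python model: `Sweep` and `expand_sweep` are made-up names, not part of mushin. Values marked for sweeping expand into the Cartesian product of jobs, while everything else, including a plain list, is passed to every job as one value. That second case is what `hydra_list` is for, assuming (as its name suggests) that it marks a list that should not be swept over. In mushin itself the expansion is done by Hydra's multirun machinery, not by code like this.

```python
from itertools import product
from typing import Any


class Sweep(list):
    """Hypothetical marker: a list whose items are swept over (one job per item)."""


def expand_sweep(**params: Any) -> list[dict[str, Any]]:
    """Expand Sweep-valued params into the Cartesian product of jobs.

    Any value that is not a Sweep, including a plain list, is passed to
    every job unchanged as a single value.
    """
    names = list(params)
    # Non-swept values become a one-item axis so product() leaves them alone.
    axes = [v if isinstance(v, Sweep) else [v] for v in params.values()]
    return [dict(zip(names, combo)) for combo in product(*axes)]


jobs = expand_sweep(
    lr=Sweep([0.01, 0.1, 1.0]),
    seed=Sweep([0, 1, 2]),
    layer_sizes=[64, 32],  # one list-valued setting, in the spirit of hydra_list
)
print(len(jobs))  # 3 learning rates x 3 seeds
print(jobs[0])
```

Two swept parameters with three values each give 3 x 3 = 9 jobs, matching the "9 runs" in the quickstart, while `layer_sizes` reaches every job intact.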

## Install

From PyPI (the distribution is named `mushin-py`; you still `import mushin`):

```bash
pip install mushin-py
```

For a development environment (runtime deps + dev tooling), this project uses
[uv](https://docs.astral.sh/uv/):

```bash
uv sync
```

Optional runtime extras are `viz` (matplotlib, for `RobustnessCurve` plotting)
and `netcdf` (netCDF4). Install one with, e.g., `pip install "mushin-py[viz]"`.
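The `netcdf` extra is presumably there so sweep results can be persisted with xarray's netCDF writer. That is an assumption, since its purpose isn't spelled out above. A minimal round-trip sketch on a placeholder dataset (values are not results) looks like this:

```python
import tempfile
from pathlib import Path

import numpy as np
import xarray as xr

# Placeholder dataset shaped like a 3x3 lr/seed sweep (values are not results).
ds = xr.Dataset(
    {"accuracy": (("lr", "seed"), np.linspace(0.0, 0.8, 9).reshape(3, 3))},
    coords={"lr": [0.01, 0.1, 1.0], "seed": np.array([0, 1, 2], dtype=np.int32)},
)

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "lr_sweep.nc"
    # xarray prefers the netCDF4 engine when it is installed (e.g. via the
    # `netcdf` extra) and otherwise falls back to another available backend,
    # such as scipy's netCDF-3 writer.
    ds.to_netcdf(path)
    # load_dataset reads everything into memory and closes the file.
    reloaded = xr.load_dataset(path)

print(reloaded["accuracy"].sel(lr=0.1).values)
```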

## Develop

```bash
uv run pytest tests/ --hypothesis-profile fast   # tests (DDP test needs >=2 GPUs)
uv run ruff check .                              # lint
uv run ruff format .                             # format
uv run codespell src tests                       # spell check
```

Or use the `make` shortcuts (`make help` to list them): `make check` runs
lint + format-check + spell + tests (what CI runs); `make test-py PYTHON=3.12`
runs the suite on a specific Python version.

Supported Python versions: 3.9 – 3.14.

## Relationship to upstream

This is a fork/extraction, not a replacement endorsed by MIT-LL. The configuration
engine it depends on, `hydra-zen`, is actively maintained by the same group. See
`LICENSE.txt` for attribution; the original MIT copyright is retained.
