Metadata-Version: 2.4
Name: mosaic-harmonize
Version: 0.1.0
Summary: Multi-site Optimal-transport Shift Alignment with Interval Calibration for clinical data harmonization
Author-email: Peigen Chen <chenpg5@mail.sysu.edu.cn>
License-Expression: Apache-2.0
Project-URL: Homepage, https://github.com/chenpg2/mosaic-harmonize
Project-URL: Repository, https://github.com/chenpg2/mosaic-harmonize
Project-URL: Documentation, https://github.com/chenpg2/mosaic-harmonize#readme
Project-URL: Bug Tracker, https://github.com/chenpg2/mosaic-harmonize/issues
Keywords: harmonization,optimal-transport,conformal-prediction,multi-center,clinical-data,batch-effect,domain-adaptation
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Science/Research
Classifier: Intended Audience :: Healthcare Industry
Classifier: Topic :: Scientific/Engineering :: Medical Science Apps.
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Operating System :: OS Independent
Requires-Python: >=3.9
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: numpy>=1.24
Requires-Dist: pandas>=2.0
Requires-Dist: scipy>=1.10
Requires-Dist: scikit-learn>=1.3
Provides-Extra: boosting
Requires-Dist: lightgbm>=4.0; extra == "boosting"
Requires-Dist: anchorboosting>=0.3; extra == "boosting"
Provides-Extra: conformal
Requires-Dist: mapie>=1.0; extra == "conformal"
Provides-Extra: viz
Requires-Dist: matplotlib>=3.7; extra == "viz"
Requires-Dist: seaborn>=0.12; extra == "viz"
Provides-Extra: full
Requires-Dist: mosaic-harmonize[boosting,conformal,viz]; extra == "full"
Provides-Extra: dev
Requires-Dist: pytest>=7.0; extra == "dev"
Requires-Dist: pytest-cov; extra == "dev"
Requires-Dist: mosaic-harmonize[full]; extra == "dev"
Dynamic: license-file

# MOSAIC

**Multi-site Optimal-transport Shift Alignment with Interval Calibration**

[![PyPI](https://img.shields.io/pypi/v/mosaic-harmonize)](https://pypi.org/project/mosaic-harmonize/)
[![Python](https://img.shields.io/pypi/pyversions/mosaic-harmonize)](https://pypi.org/project/mosaic-harmonize/)
[![License](https://img.shields.io/badge/license-Apache%202.0-blue)](LICENSE)
<!-- [![DOI](https://img.shields.io/badge/DOI-10.1038%2Fs41746--XXX-blue)](https://doi.org/10.1038/s41746-XXX) -->

MOSAIC is a Python package for harmonizing clinical tabular data collected across multiple sites. It combines 1-D optimal transport for distribution alignment, anchor regression for domain-robust prediction, and weighted conformal inference for uncertainty quantification. The three components can be used independently or chained through a single pipeline.

The package was developed for multi-center IVF (in vitro fertilization) outcome prediction, but the methods are general and apply to other multi-site clinical or biomedical datasets with batch effects.

## Overview

MOSAIC has three tiers, each usable on its own:

| Tier | Class | What it does |
|------|-------|-------------|
| 1. Harmonization | `OTHarmonizer` | Per-feature quantile-based optimal transport mapping to a reference distribution. Reduces cross-center distribution shift while preserving within-center rank order. |
| 2. Robust learning | `AnchorEstimator` | Wraps any sklearn estimator with anchor regression (via [anchorboosting](https://github.com/mlondschien/anchorboosting)) or V-REx reweighting. Penalizes predictions that rely on center-specific patterns. |
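| 3. Uncertainty | `ConformalCalibrator` | Split conformal prediction with optional covariate-shift correction (Tibshirani et al., NeurIPS 2019). Produces prediction intervals (regression) or prediction sets (classification) with finite-sample marginal coverage guarantees. These guarantees assume exchangeable data for the standard variant and correctly specified covariate-shift weights for the weighted variant. |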

`MOSAICPipeline` chains all three into a single `fit` / `predict` interface.
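The Tier 1 idea, per-feature 1-D optimal transport via quantiles, can be illustrated with plain NumPy. The following is a minimal sketch, not the `OTHarmonizer` implementation. It assumes each center's values for one continuous feature are mapped onto a single reference sample (for example pooled data, as a stand-in for `reference="global"`). `quantile_ot_map` is a hypothetical helper.

```python
import numpy as np


def quantile_ot_map(source, reference, n_quantiles=1000):
    """Return a monotone 1-D map that pushes `source` onto `reference`.

    In one dimension the optimal transport map under a convex cost such as
    squared distance is x -> F_ref^{-1}(F_src(x)). Both CDFs are approximated
    on a grid of `n_quantiles` quantile levels (assumes continuous data).
    """
    levels = np.linspace(0.0, 1.0, n_quantiles)
    src_q = np.quantile(source, levels)
    ref_q = np.quantile(reference, levels)

    def transport(x):
        # source value -> approximate quantile level -> reference value.
        # np.interp with sorted knots is non-decreasing in x, so the
        # rank order of the source values is preserved.
        return np.interp(x, src_q, ref_q)

    return transport


rng = np.random.default_rng(0)
center_a = rng.normal(loc=5.0, scale=2.0, size=2000)   # one center's feature values
reference = rng.normal(loc=0.0, scale=1.0, size=5000)  # e.g. pooled reference data

to_reference = quantile_ot_map(center_a, reference)
aligned = to_reference(center_a)
print(f"before: mean={center_a.mean():.2f}, sd={center_a.std():.2f}")
print(f"after:  mean={aligned.mean():.2f}, sd={aligned.std():.2f}")
```

Because the map is monotone, sorting a center's rows by the raw feature or by the aligned feature gives the same order. This is the "preserving within-center rank order" property described for Tier 1.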

## Installation

Core install (OT harmonization + anchor regression + conformal; depends on NumPy, pandas, SciPy, and scikit-learn):

```bash
pip install mosaic-harmonize
```

With all optional dependencies (LightGBM, anchorboosting, MAPIE, matplotlib, seaborn):

```bash
pip install "mosaic-harmonize[full]"  # quotes keep shells such as zsh from expanding the brackets
```

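Individual extras: `boosting` (LightGBM, anchorboosting), `conformal` (MAPIE), `viz` (matplotlib, seaborn). For development: `dev` (pytest, pytest-cov, plus all extras). The examples below use LightGBM, which requires the `boosting` extra.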

Requires Python 3.9+.

## Quick start

### Full pipeline

```python
from mosaic import MOSAICPipeline
from lightgbm import LGBMRegressor  # requires the "boosting" extra

pipe = MOSAICPipeline(
    harmonizer="ot",
    robust_learner="anchor",
    uncertainty="weighted_conformal",
    base_estimator=LGBMRegressor(),
)

# center_ids: array of site labels, one per row
pipe.fit(X_train, y_train, center_ids=train_centers)

result = pipe.predict(X_test, center_id="new_hospital")
print(result.prediction)       # point predictions
print(result.lower, result.upper)  # 90% prediction intervals
```

### Individual components

```python
from lightgbm import LGBMRegressor  # requires the "boosting" extra
from mosaic import OTHarmonizer, AnchorEstimator, ConformalCalibrator

# Tier 1: align distributions
ot = OTHarmonizer(n_quantiles=1000, reference="global")
X_harmonized = ot.fit_transform(X_train, center_ids=train_centers)

# Inspect shift reduction
print(ot.wasserstein_distances())
print(ot.feature_shift_report())

# Tier 2: train a domain-robust model on the harmonized features
anchor = AnchorEstimator(base_estimator=LGBMRegressor(), task_type="regression")
anchor.fit(X_harmonized, y_train, anchors=train_centers)
print(f"Best gamma: {anchor.best_gamma_}")
print(f"Cross-center stability: {anchor.stability_score_:.3f}")

# Tier 3: calibrate with conformal prediction.
# The model was trained on harmonized inputs, so map the calibration and
# test rows with the fitted harmonizer first (their centers are assumed
# here to be among those seen during `fit`).
X_cal_h = ot.transform(X_cal, center_ids=cal_centers)
X_test_h = ot.transform(X_test, center_ids=test_centers)

cal = ConformalCalibrator(method="weighted", alpha=0.10)
cal.calibrate(anchor, X_cal_h, y_cal, X_test=X_test_h)
result = cal.predict(X_test_h)
print(f"Mean interval width: {(result.upper - result.lower).mean():.2f}")
```

### Save and load

```python
pipe.save("model.mosaic")
pipe = MOSAICPipeline.load("model.mosaic")
```

### Register a new center at inference time

```python
pipe.register_center("hospital_B", X_new_center)
result = pipe.predict(X_query, center_id="hospital_B")
```

## API reference

### OTHarmonizer

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `n_quantiles` | int | 1000 | Number of quantile points for the OT map |
| `features` | list[str] or None | None | Columns to harmonize (None = all numeric) |
| `reference` | str | "global" | Reference distribution: "global" or a center name |
| `min_samples` | int | 50 | Minimum non-null samples to build a map |

Methods: `fit(X, center_ids)`, `transform(X, center_id=..., center_ids=...)`, `fit_transform(X, center_ids)`, `wasserstein_distances()`, `feature_shift_report()`.
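The README does not define exactly what `wasserstein_distances()` reports. One natural per-feature shift measure, assumed here for illustration, is the 1-D Wasserstein-1 distance between each center's values and the reference, computed before and after quantile mapping. The following sketch uses SciPy's `scipy.stats.wasserstein_distance` on synthetic data. The center names and distributions are hypothetical.

```python
import numpy as np
from scipy.stats import wasserstein_distance

rng = np.random.default_rng(1)
# Hypothetical single feature measured at three centers with different
# offsets, scales, and shapes.
centers = {
    "A": rng.normal(0.0, 1.0, 1500),
    "B": rng.normal(1.5, 1.3, 1200),
    "C": rng.lognormal(0.0, 0.5, 900),
}
reference = np.concatenate(list(centers.values()))  # pooled reference (assumption)
levels = np.linspace(0.0, 1.0, 1000)
ref_q = np.quantile(reference, levels)

report = {}
for name, values in centers.items():
    # 1-D OT map by quantile matching: center quantiles -> reference quantiles
    mapped = np.interp(values, np.quantile(values, levels), ref_q)
    report[name] = (
        wasserstein_distance(values, reference),  # shift before harmonization
        wasserstein_distance(mapped, reference),  # shift after harmonization
    )

for name, (before, after) in report.items():
    print(f"center {name}: W1 before={before:.3f}, after={after:.3f}")
```

W1 is in the feature's own units, so values are comparable across centers for one feature but not across features with different scales.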

### AnchorEstimator

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `base_estimator` | sklearn estimator or None | None | Base learner (None = Ridge/LogisticRegression) |
| `gammas` | list[float] or None | [1.5, 3.0, 7.0] | Anchor penalty strengths to search |
| `task_type` | str | "auto" | "auto", "regression", "binary", or "multiclass" |
| `n_vrex_rounds` | int | 5 | V-REx reweighting iterations (fallback mode) |

Methods: `fit(X, y, anchors, X_val=None, y_val=None)`, `predict(X)`, `predict_proba(X)`. Properties: `best_gamma_`, `stability_score_`.
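MOSAIC delegates anchor regression to anchorboosting. The role of the penalty strength `gamma` is easiest to see in the closed-form *linear* anchor regression of Rothenhäusler et al. (2021), which minimizes ||(I - P_A) r||² + γ ||P_A r||² over residuals r and equals OLS after the transform W = I - (1 - √γ) P_A. The following is a minimal sketch under the assumption that the anchors are one-hot center indicators, as with `anchors=train_centers`. `anchor_transform` and `fit_linear_anchor` are hypothetical helpers, not part of MOSAIC.

```python
import numpy as np


def anchor_transform(M, groups, gamma):
    """Apply W = I - (1 - sqrt(gamma)) * P_A to the rows of M.

    With one-hot center indicators as the anchor A, P_A replaces each row by
    its center mean, so W rescales only the between-center component of M.
    """
    M = np.asarray(M, dtype=float)
    out = M.copy()
    for g in np.unique(groups):
        mask = groups == g
        out[mask] -= (1.0 - np.sqrt(gamma)) * M[mask].mean(axis=0)
    return out


def fit_linear_anchor(X, y, groups, gamma):
    """Linear anchor regression = OLS on W-transformed data (with intercept)."""
    design = np.column_stack([np.ones(len(y)), X])
    coef, *_ = np.linalg.lstsq(
        anchor_transform(design, groups, gamma),
        anchor_transform(np.asarray(y, dtype=float)[:, None], groups, gamma).ravel(),
        rcond=None,
    )
    return coef  # [intercept, slope_1, ..., slope_p]


rng = np.random.default_rng(0)
n_per_center, n_centers = 400, 5
groups = np.repeat(np.arange(n_centers), n_per_center)
offsets = rng.normal(0.0, 2.0, n_centers)[groups]  # center-level shift
x = offsets + rng.normal(0.0, 1.0, groups.size)
# The outcome also carries a center-specific offset, so the pooled (OLS)
# slope mixes within-center and between-center associations.
y = 1.0 * x + 3.0 * offsets + rng.normal(0.0, 0.5, groups.size)

fits = {gamma: fit_linear_anchor(x[:, None], y, groups, gamma)
        for gamma in (0.0, 1.0, 7.0)}
for gamma, (intercept, slope) in fits.items():
    print(f"gamma={gamma}: intercept={intercept:.2f}, slope={slope:.2f}")
```

Setting γ = 1 reproduces ordinary least squares. Setting γ = 0 uses only within-center variation (a center fixed-effects fit). Values of γ > 1, such as those in the default grid [1.5, 3.0, 7.0], penalize residual variation explained by the center more heavily than OLS does.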

### ConformalCalibrator

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `method` | str | "weighted" | "weighted", "standard", or "lac" |
| `alpha` | float | 0.10 | Miscoverage level (0.10 = 90% target coverage) |

Methods: `calibrate(model, X_cal, y_cal, X_test=None)`, `predict(X_test)` returning `ConformalResult`.
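Weighted split conformal (Tibshirani et al., 2019) reweights calibration residuals by the likelihood ratio between test and calibration covariates. A point mass at +∞ carries the test point's own weight. The following is a minimal NumPy sketch of that procedure for regression, not `ConformalCalibrator`'s implementation. To keep it exact, the sketch assumes the density ratio is known. In practice it must be estimated, and how MOSAIC estimates it is not described here. All function names are hypothetical.

```python
import numpy as np


def weighted_split_conformal(residuals, cal_w, test_w, alpha=0.10):
    """Interval half-widths for weighted split conformal regression.

    residuals: |y - y_hat| on the calibration set
    cal_w, test_w: likelihood ratios dP_test(x) / dP_cal(x) evaluated at
    the calibration and test covariates
    """
    order = np.argsort(residuals)
    s_sorted = residuals[order]
    cum_w = np.cumsum(cal_w[order])
    half_widths = np.empty(len(test_w))
    for j, w_test in enumerate(test_w):
        # Normalize the calibration weights together with the test point's
        # own weight, which is placed on a point mass at +inf.
        cdf = cum_w / (cum_w[-1] + w_test)
        k = np.searchsorted(cdf, 1.0 - alpha)  # first k with cdf[k] >= 1 - alpha
        half_widths[j] = s_sorted[k] if k < len(s_sorted) else np.inf
    return half_widths


def noise_sd(x):
    return 0.5 + 0.5 * np.abs(x)  # heteroscedastic noise


def predict(x):
    return x  # stand-in for a fitted regression model


def density_ratio(x):
    # Known ratio N(1, 1) / N(0, 1) = exp(x - 0.5). In practice it would be
    # estimated, e.g. with a classifier separating calibration from test rows.
    return np.exp(x - 0.5)


rng = np.random.default_rng(0)
x_cal = rng.normal(0.0, 1.0, 2000)
y_cal = x_cal + noise_sd(x_cal) * rng.normal(size=x_cal.size)
x_test = rng.normal(1.0, 1.0, 2000)  # covariate shift at the test site
y_test = x_test + noise_sd(x_test) * rng.normal(size=x_test.size)

q = weighted_split_conformal(np.abs(y_cal - predict(x_cal)),
                             density_ratio(x_cal), density_ratio(x_test))
lower, upper = predict(x_test) - q, predict(x_test) + q
coverage = np.mean((y_test >= lower) & (y_test <= upper))
print(f"empirical coverage on the shifted test set: {coverage:.3f}")
```

With all weights equal, the threshold reduces to the ⌈(n+1)(1-α)⌉-th smallest residual, i.e. standard split conformal.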

### MOSAICPipeline

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `harmonizer` | str or None | "ot" | "ot" or None |
| `robust_learner` | str or None | "anchor" | "anchor" or None |
| `uncertainty` | str or None | "weighted_conformal" | "weighted_conformal", "standard", "lac", or None |
| `base_estimator` | sklearn estimator or None | None | Base learner passed to AnchorEstimator |

Methods: `fit(X_train, y_train, center_ids, X_cal=None, y_cal=None)`, `predict(X, center_id=..., center_ids=...)`, `register_center(name, X_new)`, `diagnose(X, center_id)`, `save(path)`, `load(path)`.
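For classification, `"lac"` most likely refers to the Least Ambiguous set-valued Classifier score (Sadinle et al., 2019), as used for example in MAPIE. That reading is an assumption, not something this README states. The score is 1 minus the predicted probability of the true class, and a class enters the prediction set when its score is below a conformal threshold. The following is a minimal NumPy sketch using a synthetic, perfectly calibrated stand-in model. All names are hypothetical.

```python
import numpy as np


def lac_threshold(cal_probs, cal_labels, alpha=0.10):
    """Conformal threshold on the LAC score 1 - p_hat(true class | x)."""
    n = len(cal_labels)
    scores = 1.0 - cal_probs[np.arange(n), cal_labels]
    rank = int(np.ceil((n + 1) * (1.0 - alpha)))  # finite-sample correction
    if rank > n:
        return np.inf  # too few calibration points: every class is included
    return np.sort(scores)[rank - 1]


def lac_sets(test_probs, threshold):
    """Boolean mask: class k is in row i's set iff 1 - p_hat_k <= threshold."""
    return (1.0 - test_probs) <= threshold


rng = np.random.default_rng(0)
n_classes = 4


def sample(n):
    """Synthetic class probabilities plus labels drawn from them."""
    logits = rng.normal(0.0, 2.0, size=(n, n_classes))
    probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
    labels = np.array([rng.choice(n_classes, p=p) for p in probs])
    return probs, labels


cal_probs, cal_labels = sample(2000)
test_probs, test_labels = sample(2000)

threshold = lac_threshold(cal_probs, cal_labels, alpha=0.10)
sets = lac_sets(test_probs, threshold)
coverage = sets[np.arange(len(test_labels)), test_labels].mean()
print(f"coverage={coverage:.3f}, mean set size={sets.sum(axis=1).mean():.2f}")
```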

## Benchmarks

Evaluated on a multi-center IVF dataset (334K rows, 5 centers, 15 prediction targets). Full results in `benchmarks/results/`.

### Ablation (Exp 1): contribution of each tier

| Target (R²) | Baseline | +OT | +OT+Anchor | Full MOSAIC |
|--------|------------|-----|------------|-------------|
| HCG_Day_E2 | -0.665 | 0.127 | 0.229 | 0.229 |
| egg_num | 0.210 | 0.205 | 0.352 | 0.352 |
| HCG_Day_Endo | -0.775 | -0.177 | 0.120 | 0.120 |

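OT harmonization gives the largest gains on the high-shift targets (HCG_Day_E2: R² from -0.665 to 0.127; HCG_Day_Endo: -0.775 to -0.177) and leaves egg_num essentially unchanged (0.210 to 0.205). Anchor regression adds further gains on all three targets (e.g., egg_num: 0.205 to 0.352). The Full MOSAIC column matches +OT+Anchor because the conformal tier adds prediction intervals without changing the point predictions.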

### Cross-center generalization gap (Exp 2)

On an external test center unseen during training, MOSAIC reduces the gap between validation and external-test performance by 42-76% for high-shift features, for example HCG_Day_E2 (75%), HCG_Day_P (52%), and HCG_Day_Endo (72%).

### Conformal coverage (Exp 3)

Across the 11 regression targets, empirical coverage at the 90% nominal level ranges from 81% to 93%, so targets at the low end fall short of the nominal level. Weighted conformal consistently produces narrower intervals than standard split conformal at comparable coverage.
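The two quantities reported above, empirical coverage and mean interval width, can be computed directly from the `lower` and `upper` arrays of a `ConformalResult`. The following is a minimal sketch. `interval_metrics` is a hypothetical helper, not part of the MOSAIC API, and the toy arrays are illustrative only.

```python
import numpy as np


def interval_metrics(y_true, lower, upper):
    """Return (empirical coverage, mean width) of prediction intervals."""
    y_true, lower, upper = map(np.asarray, (y_true, lower, upper))
    covered = (y_true >= lower) & (y_true <= upper)
    return covered.mean(), (upper - lower).mean()


coverage, width = interval_metrics(
    y_true=[1.0, 2.0, 3.0, 4.0],
    lower=[0.5, 2.5, 2.0, 3.0],
    upper=[1.5, 3.5, 4.0, 5.0],
)
print(coverage, width)  # 0.75 1.5
```

Comparing methods "at comparable coverage" means checking that coverage is close to 1 - alpha for both before comparing widths, since a method can always shrink intervals by under-covering.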

### Comparison with existing methods (Exp 5)

| Target (metric) | No harmonization | Z-score | ComBat | MOSAIC (OT) |
|---------|-----------------|---------|--------|-------------|
| HCG_Day_E2 (R²) | -0.665 | 0.069 | -0.383 | 0.127 |
| Clinical_pregnancy (AUC) | 0.838 | 0.837 | 0.837 | 0.839 |
| total_Gn (R²) | -0.121 | -0.327 | -0.034 | -0.022 |

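MOSAIC (OT) scores highest on all three targets. The margin is large for HCG_Day_E2 (0.127 vs. 0.069 for Z-score and -0.383 for ComBat) and small for total_Gn (-0.022 vs. -0.034 for ComBat). On Clinical_pregnancy, all methods fall within 0.002 AUC of each other, so harmonization neither helps nor hurts this low-shift target.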

## Citation

If you use MOSAIC in your research, please cite:

```
@article{chen2026mosaic,
  title={MOSAIC: Multi-site Optimal-transport Shift Alignment with Interval
         Calibration for Clinical Data Harmonization},
  author={Chen, Peigen},
  journal={npj Digital Medicine},
  year={2026},
  note={Manuscript in preparation}
}
```

## License

Apache-2.0. See [LICENSE](LICENSE) for details.
