Metadata-Version: 2.4
Name: mdot-tnt
Version: 1.2.0
Summary: A fast, GPU-parallel, PyTorch-compatible optimal transport solver.
Author-email: Mete Kemertas <kemertas@cs.toronto.edu>
License-Expression: BSD-3-Clause
Project-URL: Homepage, https://github.com/metekemertas/mdot_tnt
Project-URL: Documentation, https://mdot-tnt.readthedocs.io
Project-URL: Repository, https://github.com/metekemertas/mdot_tnt
Project-URL: Issues, https://github.com/metekemertas/mdot_tnt/issues
Keywords: optimal-transport,sinkhorn,entropy-regularization,machine-learning,pytorch,gpu
Classifier: Development Status :: 5 - Production/Stable
Classifier: Intended Audience :: Science/Research
Classifier: Intended Audience :: Developers
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Scientific/Engineering :: Mathematics
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Operating System :: OS Independent
Classifier: Typing :: Typed
Requires-Python: >=3.8
Description-Content-Type: text/markdown
License-File: LICENSE
Provides-Extra: dev
Requires-Dist: pytest>=7.0; extra == "dev"
Requires-Dist: pytest-cov>=4.0; extra == "dev"
Requires-Dist: ruff>=0.1.0; extra == "dev"
Requires-Dist: pre-commit>=3.0; extra == "dev"
Requires-Dist: numpy>=1.20; extra == "dev"
Dynamic: license-file

# MDOT-TNT

<img src="assets/logo.png" alt="MDOT-TNT Logo" width="180" align="right"/>

**A Truncated Newton Method for Optimal Transport**

[![PyPI version](https://badge.fury.io/py/mdot-tnt.svg)](https://badge.fury.io/py/mdot-tnt)
[![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/)
[![License](https://img.shields.io/badge/license-BSD--3--Clause-02B36C)](LICENSE)

A fast, GPU-accelerated solver for entropic-regularized optimal transport (OT) problems. MDOT-TNT combines mirror descent with a truncated Newton projection method to achieve high numerical precision while remaining stable under weak regularization.

<br clear="right"/>

## Features

- **High Precision**: Stable under extremely weak regularization (γ up to 2¹⁸), enabling highly precise approximations of unregularized OT
- **GPU Accelerated**: Fully compatible with CUDA for fast computation on large problems
- **Multi-GPU Support**: Distribute computation across multiple GPUs via `devices` or `num_gpus` arguments for near-linear speedup
- **Batched Solving**: Solve multiple OT problems simultaneously in batched mode
- **Memory Efficient**: O(nk) working memory mode — never materialise the full cost or transport plan; process columns in blocks of size k
- **Point Cloud Support**: Pass source/target point clouds directly with a custom cost function; achieves true O(nk + (n+m)d) memory
- **PyTorch Native**: Seamless integration with PyTorch, supporting autograd-compatible inputs

## Color Transfer Example

OT naturally solves the color transfer problem: find the optimal map between two color palettes and use it to recolor an image. MDOT-TNT's low-memory point-cloud solver makes this practical for full-resolution images without ever materialising an `n × m` cost matrix.

<table>
  <tr>
    <th align="center">Source image (<code>1.webp</code>)</th>
    <th align="center">Palette donor (<code>2.webp</code>)</th>
    <th align="center">Result: 1 recolored with 2's palette</th>
  </tr>
  <tr>
    <td align="center"><img src="color_transfer/assets/1.webp" width="240"/></td>
    <td align="center"><img src="color_transfer/assets/2.webp" width="240"/></td>
    <td align="center"><img src="color_transfer/assets/out.png" width="240"/></td>
  </tr>
  <tr>
    <th align="center">Source image (<code>2.webp</code>)</th>
    <th align="center">Palette donor (<code>1.webp</code>)</th>
    <th align="center">Result: 2 recolored with 1's palette</th>
  </tr>
  <tr>
    <td align="center"><img src="color_transfer/assets/2.webp" width="240"/></td>
    <td align="center"><img src="color_transfer/assets/1.webp" width="240"/></td>
    <td align="center"><img src="color_transfer/assets/out_reverse.png" width="240"/></td>
  </tr>
</table>

Run the bundled example (from the repo root):

```bash
cd color_transfer
python run.py --both-directions
```

Key options:

| Flag | Default | Description |
|------|---------|-------------|
| `--gamma-f` | `1024` | OT precision (higher = sharper color match) |
| `--color-space` | `lab` | Working color space (`lab` or `rgb`) |
| `--max-pixels` | `8192` | Max pixels sampled per image |
| `--block-size` | `512` | Memory/speed trade-off for the low-mem solver |
| `--both-directions` | off | Also produce the reverse transfer |

## Installation

**Prerequisites**: Install [PyTorch](https://pytorch.org/get-started/locally/) for your system configuration first.

```bash
pip install mdot-tnt
```

For development:

```bash
git clone https://github.com/metekemertas/mdot_tnt.git
cd mdot_tnt
pip install -e ".[dev]"
```

## Quick Start

### Single Problem

```python
import torch
import mdot_tnt

device = "cuda" if torch.cuda.is_available() else "cpu"

# Create marginals (probability distributions)
n, m = 512, 512
r = torch.rand(n, device=device, dtype=torch.float64)
r = r / r.sum()
c = torch.rand(m, device=device, dtype=torch.float64)
c = c / c.sum()

# Cost matrix (e.g., pairwise distances)
C = torch.rand(n, m, device=device, dtype=torch.float64)

# Solve for optimal transport cost
cost = mdot_tnt.solve_OT(r, c, C, gamma_f=1024)

# Or get the full transport plan
plan = mdot_tnt.solve_OT(r, c, C, gamma_f=1024, return_plan=True)
```

### Batched Solving

When solving multiple OT problems, use the batched solver for a significant speedup over solving them one at a time:

```python
import torch
import mdot_tnt

device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size, n, m = 32, 512, 512

# Multiple marginal pairs
r = torch.rand(batch_size, n, device=device, dtype=torch.float64)
r = r / r.sum(-1, keepdim=True)
c = torch.rand(batch_size, m, device=device, dtype=torch.float64)
c = c / c.sum(-1, keepdim=True)

# Shared cost matrix (or per-problem: shape [batch_size, n, m])
C = torch.rand(n, m, device=device, dtype=torch.float64)

# Solve all problems at once
costs = mdot_tnt.solve_OT_batched(r, c, C, gamma_f=1024)  # Returns (batch_size,) tensor
```

The batched solver achieves its speedup by amortizing GPU synchronization overhead across all problems in the batch.

### Low-Memory / Point Cloud Solving

When `n` or `m` is large, use `solve_OT_lowmem` to avoid materialising the full `n × m` cost matrix. Cost blocks are computed on-the-fly, giving O(nk) working memory (k = `block_size`):

```python
import torch
from mdot_tnt import solve_OT_lowmem

device = "cuda" if torch.cuda.is_available() else "cpu"

# Point cloud mode — no n×m matrix is ever allocated
n, m, d = 10000, 10000, 64
X = torch.rand(n, d, device=device, dtype=torch.float64)
Y = torch.rand(m, d, device=device, dtype=torch.float64)
r = torch.ones(n, device=device, dtype=torch.float64) / n
c = torch.ones(m, device=device, dtype=torch.float64) / m

cost = solve_OT_lowmem(r, c, X=X, Y=Y, gamma_f=1024, block_size=512)

# Dense matrix mode — full C provided, but only k columns in memory at once
C = torch.rand(n, m, device=device, dtype=torch.float64)
cost = solve_OT_lowmem(r, c, C=C, gamma_f=1024, block_size=512)
```

The `block_size` parameter controls the memory / speed trade-off:
- `block_size = m` (default): fastest, equivalent to `solve_OT`
- `block_size = sqrt(m)`: good balance
- `block_size = 1`: minimum memory (slowest)
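
To get a rough feel for this trade-off, the sketch below estimates the dominant working-memory terms using the O(nk + (n+m)d) figure quoted in the feature list. It is back-of-the-envelope arithmetic only: the helper names are hypothetical, constant factors and solver-internal buffers are ignored, and float64 (8 bytes per element) is assumed.

```python
import math

BYTES_PER_FLOAT64 = 8  # assumption: all tensors are float64


def dense_cost_bytes(n: int, m: int) -> int:
    """Bytes needed to hold a full n x m cost matrix."""
    return n * m * BYTES_PER_FLOAT64


def blocked_bytes(n: int, m: int, d: int, block_size: int) -> int:
    """Rough O(nk + (n+m)d) estimate: one n x k cost block plus both point clouds."""
    return (n * block_size + (n + m) * d) * BYTES_PER_FLOAT64


def balanced_block_size(m: int) -> int:
    """The 'good balance' heuristic from the list above: k = floor(sqrt(m)), at least 1."""
    return max(1, math.isqrt(m))


n, m, d = 10_000, 10_000, 64
k = balanced_block_size(m)
print(f"dense C:        {dense_cost_bytes(n, m) / 1e6:.1f} MB")
print(f"blocked, k={k}: {blocked_bytes(n, m, d, k) / 1e6:.1f} MB")
```

For the point-cloud example above (n = m = 10000, d = 64), this gives about 800 MB for a dense float64 cost matrix versus about 18 MB with `block_size = 100`.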

### Multi-GPU

All three solvers accept `devices` (list of GPU indices or `torch.device` objects) or `num_gpus` (int). Passing a single device falls back to the standard single-GPU path.

```python
from mdot_tnt import solve_OT, solve_OT_batched, solve_OT_lowmem

# r, c, C, X, Y: inputs shaped as in the corresponding examples above

# Single large problem — columns distributed across GPUs
cost = solve_OT(r, c, C, gamma_f=1024, devices=[0, 1, 2])

# Batched — batch split across GPUs (near-linear speedup)
costs = solve_OT_batched(r, c, C, gamma_f=1024, num_gpus=4)

# Low-memory / point cloud — column blocks distributed across GPUs
#   (3.2× speedup at n=8192, m=65536, γ=1024 on 3× RTX 2080 Ti)
cost = solve_OT_lowmem(r, c, X=X, Y=Y, gamma_f=1024, devices=[0, 1, 2])
```
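
The `devices` / `num_gpus` convention can be mimicked in a few lines. The helper below is a hypothetical illustration of the documented semantics (the two arguments are mutually exclusive, and `num_gpus=k` means GPU indices 0 to k-1). It is not the library's internal code, it handles integer indices only, and how MDOT-TNT actually validates these arguments is an assumption.

```python
from typing import List, Optional, Sequence


def resolve_gpu_indices(
    devices: Optional[Sequence[int]] = None,
    num_gpus: Optional[int] = None,
) -> Optional[List[int]]:
    """Hypothetical helper: turn `devices` / `num_gpus` into a list of GPU indices.

    Returns None when neither argument is given (the single-device default).
    """
    if devices is not None and num_gpus is not None:
        raise ValueError("pass either `devices` or `num_gpus`, not both")
    if num_gpus is not None:
        if num_gpus < 1:
            raise ValueError("`num_gpus` must be a positive integer")
        return list(range(num_gpus))  # GPUs are counted from index 0
    if devices is not None:
        return list(devices)
    return None


print(resolve_gpu_indices(num_gpus=4))         # [0, 1, 2, 3]
print(resolve_gpu_indices(devices=[0, 1, 2]))  # [0, 1, 2]
```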

## API Reference

### `solve_OT`

```python
mdot_tnt.solve_OT(r, c, C, gamma_f=1024., return_plan=False, round=True, log=False,
                  devices=None, num_gpus=None)
```

| Parameter | Type | Description |
|-----------|------|-------------|
| `r` | `Tensor` | Row marginal of shape `(n,)`, must sum to 1 |
| `c` | `Tensor` | Column marginal of shape `(m,)`, must sum to 1 |
| `C` | `Tensor` | Cost matrix of shape `(n, m)`, recommended to normalize to [0, 1] |
| `gamma_f` | `float` | Temperature parameter (inverse regularization). Higher = more accurate. Default: 1024 |
| `return_plan` | `bool` | If True, return transport plan instead of cost |
| `round` | `bool` | If True, round solution onto feasible set |
| `log` | `bool` | If True, also return optimization logs |
| `devices` | `list` | GPU indices or `torch.device` objects to distribute across. Mutually exclusive with `num_gpus` |
| `num_gpus` | `int` | Number of GPUs to use (starting from index 0). Mutually exclusive with `devices` |

**Returns**: Transport cost (scalar) or plan `(n, m)`, optionally with logs dict.
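
Normalizing `C` to [0, 1], as recommended in the table above, changes the scale of the returned cost. The sketch below shows one way to normalize and then map a cost back to the original units. The helper names are hypothetical, and min-max scaling is only one possible choice (dividing by `C.max()` is another). The shift can be undone exactly because a transport plan sums to 1.

```python
import torch


def normalize_cost(C: torch.Tensor):
    """Min-max scale C to [0, 1]; return the scaled matrix plus (shift, scale)."""
    shift = C.min()
    scale = C.max() - shift
    if scale == 0:  # constant cost matrix: avoid division by zero
        scale = torch.ones_like(scale)
    return (C - shift) / scale, shift, scale


def denormalize_cost(cost_normalized, shift, scale):
    """Map a transport cost computed on the scaled matrix back to original units."""
    return cost_normalized * scale + shift


# Check the identity on a feasible plan (the independent coupling r c^T).
torch.manual_seed(0)
n, m = 4, 5
r = torch.full((n,), 1.0 / n, dtype=torch.float64)
c = torch.full((m,), 1.0 / m, dtype=torch.float64)
C = 10.0 * torch.rand(n, m, dtype=torch.float64) + 3.0
P = torch.outer(r, c)

C_norm, shift, scale = normalize_cost(C)
cost_original = (P * C).sum()
cost_recovered = denormalize_cost((P * C_norm).sum(), shift, scale)
print(torch.allclose(cost_original, cost_recovered))
```

In practice `C_norm` would be passed to `solve_OT` and the returned cost passed through `denormalize_cost`. Keep in mind that `gamma_f` then applies to the normalized cost scale.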

### `solve_OT_batched`

```python
mdot_tnt.solve_OT_batched(r, c, C, gamma_f=1024., return_plan=False, round=True, log=False,
                           devices=None, num_gpus=None)
```

Same parameters as `solve_OT`, but with batched inputs:
- `r`: Shape `(batch, n)`
- `c`: Shape `(batch, m)`
- `C`: Shape `(n, m)` for shared cost, or `(batch, n, m)` for per-problem costs

With `devices` or `num_gpus`, the batch is split across GPUs and each sub-batch is solved independently.

**Returns**: Costs `(batch,)` or plans `(batch, n, m)`.

### `solve_OT_lowmem`

```python
mdot_tnt.solve_OT_lowmem(r, c, C=None, X=None, Y=None, cost_fn=None,
                          gamma_f=1024., block_size=None,
                          drop_tiny=False, return_plan=False, round=True, log=False,
                          devices=None, num_gpus=None)
```

Exactly one of `C` or `(X, Y)` must be provided.

| Parameter | Type | Description |
|-----------|------|-------------|
| `r` | `Tensor` | Row marginal of shape `(n,)`, must sum to 1 |
| `c` | `Tensor` | Column marginal of shape `(m,)`, must sum to 1 |
| `C` | `Tensor` | Dense cost matrix `(n, m)`. Mutually exclusive with `X`/`Y` |
| `X` | `Tensor` | Source points `(n, d)`. Use together with `Y` |
| `Y` | `Tensor` | Target points `(m, d)`. Use together with `X` |
| `cost_fn` | `callable` | `(X, Y_block) -> C_block` of shape `(n, k)`. Defaults to squared Euclidean |
| `gamma_f` | `float` | Temperature parameter. Default: 1024 |
| `block_size` | `int` | Columns per block (controls memory/speed trade-off). Default: `m` |
| `drop_tiny` | `bool` | Drop tiny marginal entries for speedup with sparse marginals |
| `return_plan` | `bool` | If True, return transport plan instead of cost |
| `round` | `bool` | If True, round solution onto feasible set |
| `log` | `bool` | If True, also return optimization logs |
| `devices` | `list` | GPU indices or `torch.device` objects to distribute column blocks across |
| `num_gpus` | `int` | Number of GPUs to use (starting from index 0). Mutually exclusive with `devices` |

**Returns**: Transport cost (scalar) or plan `(n, m)`, optionally with logs dict.
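
A custom `cost_fn` only needs to follow the `(X, Y_block) -> C_block` contract from the table above. The sketch below uses a scaled L1 (Manhattan) distance as an illustration. The function name is hypothetical, and dividing by `d` keeps the cost in [0, 1] only under the assumption that all coordinates lie in [0, 1].

```python
import torch


def scaled_l1_cost(X: torch.Tensor, Y_block: torch.Tensor) -> torch.Tensor:
    """Pairwise L1 distance between rows of X (n, d) and Y_block (k, d), divided by d.

    Returns a tensor of shape (n, k).
    """
    d = X.shape[1]
    return torch.cdist(X, Y_block, p=1) / d


torch.manual_seed(0)
n, m, d, k = 8, 6, 3, 2
X = torch.rand(n, d, dtype=torch.float64)
Y = torch.rand(m, d, dtype=torch.float64)

C_block = scaled_l1_cost(X, Y[:k])  # cost against the first k target points
print(C_block.shape)  # torch.Size([8, 2])

# Intended use, following the signature documented above:
# cost = solve_OT_lowmem(r, c, X=X, Y=Y, cost_fn=scaled_l1_cost, block_size=k)
```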

## Performance Tips

1. **Use float64** for `gamma_f > 1024` (lower-precision inputs are converted automatically, with a warning)
2. **Normalize cost matrices** to [0, 1] for numerical stability
3. **Use batched solver** when solving multiple problems with shared structure
4. **Increase `gamma_f`** for higher precision (error scales as O(log n / γ) in the worst case, but can be much better)
5. **Use multi-GPU** (`devices` / `num_gpus`) for large point-cloud problems or large batches — `solve_OT_lowmem` with column-parallel blocking achieves near-linear speedup (3.2× on 3 GPUs at n=8192, m=65536)
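
To see what the O(log n / γ) scaling in tip 4 implies, the snippet below evaluates log(n) / γ for a few values of `gamma_f`. The constant hidden by the O(·) notation is not stated here, so the numbers indicate only how the worst-case bound shrinks as γ grows, not actual errors. The function name is hypothetical.

```python
import math


def worst_case_scale(n: int, gamma_f: float) -> float:
    """log(n) / gamma: the worst-case error scaling, up to an unknown constant."""
    return math.log(n) / gamma_f


n = 512
for gamma_f in (2**6, 2**10, 2**14, 2**18):
    print(f"gamma_f = {gamma_f:>6}: log(n)/gamma = {worst_case_scale(n, gamma_f):.2e}")
```

Each 16-fold increase in γ shrinks this bound by the same factor, while the dependence on `n` is only logarithmic.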

## Citation

If you use MDOT-TNT in your research, please cite:

```bibtex
@inproceedings{kemertas2025truncated,
  title={A Truncated Newton Method for Optimal Transport},
  author={Kemertas, Mete and Farahmand, Amir-massoud and Jepson, Allan Douglas},
  booktitle={The Thirteenth International Conference on Learning Representations},
  year={2025},
  url={https://openreview.net/forum?id=gWrWUaCbMa}
}
```

## License

This code is released under the [BSD 3-Clause license](LICENSE).

## Contact

For questions or issues, please [open an issue](https://github.com/metekemertas/mdot_tnt/issues) or email: kemertas [at] cs [dot] toronto [dot] edu
