Metadata-Version: 2.4
Name: mdot-tnt
Version: 1.0.0
Summary: A fast, GPU-parallel, PyTorch-compatible optimal transport solver.
Author-email: Mete Kemertas <kemertas@cs.toronto.edu>
License: # Non-Commercial Research License (NCRL-1.0)
        
        Copyright (C) 2025 Mete Kemertas
        
        ## 1. License Grant
        Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to use, copy, modify, merge, publish, and distribute copies of the Software **solely for non-commercial research, educational, and personal purposes**, subject to the following conditions:
        
        ## 2. Restrictions
        ### 2.1 **Non-Commercial Use Only**
        - The Software **may NOT** be used for any commercial purpose without explicit written permission from the Licensor.
        - "Commercial purpose" includes, but is not limited to:
          - Selling or licensing the Software.
          - Using the Software in proprietary products or services.
          - Offering the Software as part of a paid or revenue-generating service.
        
        ### 2.2 **No Warranty & Liability**
        THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY ARISING FROM THE USE OF THE SOFTWARE.
        
        ### 2.3 **Commercial Licensing**
        For commercial use, a separate license must be obtained from the Licensor. To inquire about licensing, please contact: **kemertas@cs.toronto.edu**.
        
        ## 3. Termination
        This license automatically terminates if the Licensee breaches any of its terms. Upon termination, all rights granted under this license are revoked, and the Licensee must cease using and distributing the Software.
        
        ## 4. Governing Law and Enforcement
        This license shall be governed by and construed in accordance with the laws of Ontario, Canada. However, violations of this license may also be pursued under applicable copyright laws in the jurisdiction where infringement occurs.
        
        ## 5. Contact
        For licensing inquiries, please contact: **kemertas@cs.toronto.edu**.
        
Project-URL: Homepage, https://github.com/metekemertas/mdot_tnt
Project-URL: Documentation, https://mdot-tnt.readthedocs.io
Project-URL: Repository, https://github.com/metekemertas/mdot_tnt
Project-URL: Issues, https://github.com/metekemertas/mdot_tnt/issues
Keywords: optimal-transport,sinkhorn,entropy-regularization,machine-learning,pytorch,gpu
Classifier: Development Status :: 5 - Production/Stable
Classifier: Intended Audience :: Science/Research
Classifier: Intended Audience :: Developers
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Scientific/Engineering :: Mathematics
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Operating System :: OS Independent
Classifier: Typing :: Typed
Requires-Python: >=3.8
Description-Content-Type: text/markdown
License-File: LICENSE
Provides-Extra: dev
Requires-Dist: pytest>=7.0; extra == "dev"
Requires-Dist: pytest-cov>=4.0; extra == "dev"
Requires-Dist: ruff>=0.1.0; extra == "dev"
Requires-Dist: pre-commit>=3.0; extra == "dev"
Requires-Dist: numpy>=1.20; extra == "dev"
Dynamic: license-file

# MDOT-TNT

<img src="assets/logo.png" alt="MDOT-TNT Logo" width="180" align="right"/>

**A Truncated Newton Method for Optimal Transport**

[![PyPI version](https://badge.fury.io/py/mdot-tnt.svg)](https://badge.fury.io/py/mdot-tnt)
[![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/)
[![License](https://img.shields.io/badge/license-Non--Commercial-green.svg)](LICENSE)

A fast, GPU-accelerated solver for entropically regularized optimal transport (OT) problems. MDOT-TNT combines mirror descent with a truncated Newton projection method to achieve high numerical precision while remaining stable under weak regularization.

<br clear="right"/>
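For reference, the standard entropically regularized OT problem over the transport polytope $U(r, c)$ is shown below. This is the common textbook formulation, not necessarily the paper's exact objective. We assume `gamma_f` plays the role of $\gamma$, the inverse regularization weight.

```math
\begin{aligned}
&\min_{P \in U(r, c)} \; \langle P, C \rangle - \frac{1}{\gamma} H(P), \qquad H(P) = -\sum_{i,j} P_{ij} \log P_{ij}, \\
&U(r, c) = \left\{ P \in \mathbb{R}_{\ge 0}^{n \times m} : P \mathbf{1}_m = r, \; P^\top \mathbf{1}_n = c \right\}.
\end{aligned}
```

The minimizer has the form $P_{ij} = u_i \, e^{-\gamma C_{ij}} \, v_j$ for positive scaling vectors $u, v$. As $\gamma \to \infty$, the entropy term vanishes and the problem approaches unregularized OT, which is why larger $\gamma$ yields more precise approximations.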

## Features

- **High Precision**: Stable under extremely weak regularization (γ up to 2¹⁸), enabling highly precise approximations of unregularized OT
- **GPU Accelerated**: Fully compatible with CUDA for fast computation on large problems
- **Batched Solving**: Solve multiple OT problems simultaneously in batched mode
- **Memory Efficient**: Log-domain computations and efficient rounding avoid storing full transport plans
- **PyTorch Native**: Seamless integration with PyTorch, supporting autograd-compatible inputs
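The sketch below makes "log-domain computation" concrete with plain log-domain Sinkhorn for the entropic problem with inverse regularization `gamma`, written in NumPy. This is a standard baseline, **not** the MDOT-TNT algorithm. The helper names `logsumexp` and `sinkhorn_log` are hypothetical and exist only for this illustration. Keeping the scalings as logarithms avoids underflow in `exp(-gamma * C)` when `gamma` is large. Plain Sinkhorn does, however, tend to need many more iterations as `gamma` grows.

```python
import numpy as np


def logsumexp(a, axis):
    """Numerically stable log(sum(exp(a))) along `axis`."""
    amax = np.max(a, axis=axis, keepdims=True)
    out = amax + np.log(np.sum(np.exp(a - amax), axis=axis, keepdims=True))
    return np.squeeze(out, axis=axis)


def sinkhorn_log(r, c, C, gamma, n_iter=2000):
    """Plain log-domain Sinkhorn baseline (NOT MDOT-TNT).

    Approximately solves min_P <P, C> - H(P) / gamma s.t. P 1 = r, P^T 1 = c,
    returning P_ij = exp(f_i + g_j - gamma * C_ij). Assumes r, c > 0.
    """
    log_r, log_c = np.log(r), np.log(c)
    log_K = -gamma * C      # log of the Gibbs kernel exp(-gamma * C)
    f = np.zeros_like(r)    # log of the row scaling u
    g = np.zeros_like(c)    # log of the column scaling v
    for _ in range(n_iter):
        # Match row sums, then column sums, entirely in log space
        f = log_r - logsumexp(log_K + g[None, :], axis=1)
        g = log_c - logsumexp(log_K + f[:, None], axis=0)
    return np.exp(log_K + f[:, None] + g[None, :])


rng = np.random.default_rng(0)
n, m = 64, 48
r = rng.random(n)
r /= r.sum()
c = rng.random(m)
c /= c.sum()
C = rng.random((n, m))  # costs already in [0, 1]

P = sinkhorn_log(r, c, C, gamma=5.0)
print("row error:", np.abs(P.sum(1) - r).max())
print("col error:", np.abs(P.sum(0) - c).max())
print("transport cost <P, C>:", (P * C).sum())
```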

## Installation

**Prerequisites**: Install [PyTorch](https://pytorch.org/get-started/locally/) for your system configuration first.

```bash
pip install mdot-tnt
```

For development:

```bash
git clone https://github.com/metekemertas/mdot_tnt.git
cd mdot_tnt
pip install -e ".[dev]"
```

## Quick Start

### Single Problem

```python
import torch
import mdot_tnt

device = "cuda" if torch.cuda.is_available() else "cpu"

# Create marginals (probability distributions)
n, m = 512, 512
r = torch.rand(n, device=device, dtype=torch.float64)
r = r / r.sum()
c = torch.rand(m, device=device, dtype=torch.float64)
c = c / c.sum()

# Cost matrix (e.g., pairwise distances)
C = torch.rand(n, m, device=device, dtype=torch.float64)

# Solve for optimal transport cost
cost = mdot_tnt.solve_OT(r, c, C, gamma_f=1024)

# Or get the full transport plan
plan = mdot_tnt.solve_OT(r, c, C, gamma_f=1024, return_plan=True)
```
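You can sanity-check a returned plan by comparing its marginals with `r` and `c` and by recomputing the cost as the inner product ⟨P, C⟩. The helper `check_plan` below is hypothetical. It uses only operations shared by NumPy arrays and PyTorch tensors. So that the sketch runs without the solver, it uses the independent coupling P = r cᵀ, which is always feasible but generally not optimal. In practice you would pass the `plan` returned by `solve_OT(..., return_plan=True)`.

```python
import numpy as np


def check_plan(P, r, c, C):
    """Hypothetical helper: marginal violations and transport cost of plan P.

    Uses only operations shared by NumPy arrays and PyTorch tensors.
    """
    row_err = float(abs(P.sum(1) - r).max())  # max |P 1 - r|
    col_err = float(abs(P.sum(0) - c).max())  # max |P^T 1 - c|
    cost = float((P * C).sum())               # <P, C>
    return row_err, col_err, cost


rng = np.random.default_rng(0)
n, m = 512, 512
r = rng.random(n)
r /= r.sum()
c = rng.random(m)
c /= c.sum()
C = rng.random((n, m))

# Independent coupling r c^T: feasible by construction, but not optimal
P = np.outer(r, c)
row_err, col_err, cost = check_plan(P, r, c, C)
print(f"row err {row_err:.1e}, col err {col_err:.1e}, cost {cost:.4f}")
```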

### Batched Solving

When solving multiple OT problems, use the batched solver for a significant speedup over solving them one at a time:

```python
import torch
import mdot_tnt

device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size, n, m = 32, 512, 512

# Multiple marginal pairs
r = torch.rand(batch_size, n, device=device, dtype=torch.float64)
r = r / r.sum(-1, keepdim=True)
c = torch.rand(batch_size, m, device=device, dtype=torch.float64)
c = c / c.sum(-1, keepdim=True)

# Shared cost matrix (or per-problem: shape [batch_size, n, m])
C = torch.rand(n, m, device=device, dtype=torch.float64)

# Solve all problems at once
costs = mdot_tnt.solve_OT_batched(r, c, C, gamma_f=1024)  # Returns (batch_size,) tensor
```

The batched solver gets its speedup by amortizing GPU synchronization overhead across all problems in the batch.

## API Reference

### `solve_OT`

```python
mdot_tnt.solve_OT(r, c, C, gamma_f=1024., return_plan=False, round=True, log=False)
```

| Parameter | Type | Description |
|-----------|------|-------------|
| `r` | `Tensor` | Row marginal of shape `(n,)`, must sum to 1 |
| `c` | `Tensor` | Column marginal of shape `(m,)`, must sum to 1 |
| `C` | `Tensor` | Cost matrix of shape `(n, m)`; normalizing it to [0, 1] is recommended |
| `gamma_f` | `float` | Inverse regularization strength γ (inverse temperature). Higher = weaker regularization and a more accurate approximation of unregularized OT. Default: 1024 |
| `return_plan` | `bool` | If True, return transport plan instead of cost |
| `round` | `bool` | If True, round the solution onto the feasible set (plans whose row and column sums match `r` and `c`) |
| `log` | `bool` | If True, also return optimization logs |

**Returns**: The transport cost (scalar), or the plan of shape `(n, m)` if `return_plan=True`. If `log=True`, a dict of optimization logs is also returned.
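The `round` option maps an approximate solution onto the feasible set. This README does not spell out the procedure. A widely used choice is the rounding algorithm of Altschuler, Weed and Rigollet (2017), sketched below in NumPy as a hypothetical `round_to_feasible` helper. Treat it as an illustration of the idea, not as mdot_tnt's implementation. It shrinks overfull rows and columns and then adds a rank-one correction that restores exactly the missing mass.

```python
import numpy as np


def round_to_feasible(P, r, c):
    """Round a positive, near-feasible plan P onto U(r, c).

    Sketch of the rounding step of Altschuler, Weed & Rigollet (2017);
    assumes P has strictly positive row and column sums and r, c sum to 1.
    """
    # 1. Scale down rows whose mass exceeds r
    x = np.minimum(r / P.sum(1), 1.0)
    P = x[:, None] * P
    # 2. Scale down columns whose mass exceeds c
    y = np.minimum(c / P.sum(0), 1.0)
    P = P * y[None, :]
    # 3. Redistribute the missing mass with a rank-one correction
    err_r = r - P.sum(1)  # >= 0 (up to round-off) after steps 1-2
    err_c = c - P.sum(0)  # >= 0 (up to round-off) after step 2
    total = err_r.sum()   # equals err_c.sum(), since both are 1 - P.sum()
    if total > 0:
        P = P + np.outer(err_r, err_c) / total
    return P


rng = np.random.default_rng(1)
n, m = 6, 4
r = rng.random(n)
r /= r.sum()
c = rng.random(m)
c /= c.sum()
# Perturbed, infeasible approximation of a plan
P_approx = np.outer(r, c) * rng.uniform(0.8, 1.2, size=(n, m))
P = round_to_feasible(P_approx, r, c)
print(np.abs(P.sum(1) - r).max(), np.abs(P.sum(0) - c).max())
```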

### `solve_OT_batched`

```python
mdot_tnt.solve_OT_batched(r, c, C, gamma_f=1024., return_plan=False, round=True, log=False)
```

Same parameters as `solve_OT`, but with batched inputs:
- `r`: Shape `(batch, n)`
- `c`: Shape `(batch, m)`
- `C`: Shape `(n, m)` for shared cost, or `(batch, n, m)` for per-problem costs

**Returns**: Costs `(batch,)` or plans `(batch, n, m)`.
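The batched shapes follow standard broadcasting: a shared `(n, m)` cost behaves like the same matrix repeated for every problem in the batch. The NumPy sketch below is illustrative only and does not call mdot_tnt. It uses independent couplings as stand-in plans of shape `(batch, n, m)`. It then shows that the per-problem costs ⟨P_b, C⟩ agree whether the shared cost is broadcast implicitly or expanded explicitly to `(batch, n, m)`.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, n, m = 4, 8, 6
r = rng.random((batch, n))
r /= r.sum(-1, keepdims=True)                           # (batch, n)
c = rng.random((batch, m))
c /= c.sum(-1, keepdims=True)                           # (batch, m)

C_shared = rng.random((n, m))                           # (n, m)
C_per = np.broadcast_to(C_shared, (batch, n, m))        # (batch, n, m) read-only view

# Stand-in plans: one independent coupling r_b c_b^T per problem
P = r[:, :, None] * c[:, None, :]                       # (batch, n, m)

costs_shared = np.einsum("bij,ij->b", P, C_shared)      # shape (batch,)
costs_per = np.einsum("bij,bij->b", P, C_per)           # shape (batch,)
print(costs_shared.shape, np.allclose(costs_shared, costs_per))
```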

## Performance Tips

1. **Use float64** for `gamma_f > 1024` (lower-precision inputs are converted automatically, with a warning)
2. **Normalize cost matrices** to [0, 1] for numerical stability
3. **Use the batched solver** when solving multiple problems with shared structure
4. **Increase `gamma_f`** for higher precision (the worst-case error scales as O(log n / γ), and the error can be much smaller in practice)
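Every feasible plan has total mass 1, so an affine rescaling of the cost matrix changes the transport cost by the same affine map. A cost computed on a [0, 1]-normalized matrix can therefore be mapped back exactly. The minimal NumPy sketch below uses min-max normalization. The helpers `normalize_cost` and `unnormalize_cost` are hypothetical, not part of mdot_tnt.

```python
import numpy as np


def normalize_cost(C):
    """Hypothetical helper: min-max scale C to [0, 1]; returns (C01, lo, scale)."""
    lo, hi = float(C.min()), float(C.max())
    scale = hi - lo if hi > lo else 1.0  # guard against a constant cost matrix
    return (C - lo) / scale, lo, scale


def unnormalize_cost(cost01, lo, scale):
    """Map <P, C01> back to <P, C>; exact whenever P sums to 1."""
    return cost01 * scale + lo


rng = np.random.default_rng(0)
n, m = 100, 80
X, Y = rng.normal(size=(n, 2)), rng.normal(size=(m, 2))
C = ((X[:, None, :] - Y[None, :, :]) ** 2).sum(-1)  # squared Euclidean costs
r = np.full(n, 1.0 / n)
c = np.full(m, 1.0 / m)

C01, lo, scale = normalize_cost(C)
P = np.outer(r, c)  # any feasible plan works for this identity
cost01 = (P * C01).sum()
print(np.isclose(unnormalize_cost(cost01, lo, scale), (P * C).sum()))
```

Dividing `C` by `scale` at a fixed `gamma_f` is equivalent to solving the original problem with inverse regularization γ / `scale`. The precision implied by a given `gamma_f` therefore refers to the normalized costs.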

## Citation

If you use MDOT-TNT in your research, please cite:

```bibtex
@inproceedings{kemertas2025truncated,
  title={A Truncated Newton Method for Optimal Transport},
  author={Kemertas, Mete and Farahmand, Amir-massoud and Jepson, Allan Douglas},
  booktitle={The Thirteenth International Conference on Learning Representations},
  year={2025},
  url={https://openreview.net/forum?id=gWrWUaCbMa}
}
```

## License

This code is released under a [non-commercial use license](LICENSE). For commercial licensing inquiries, please contact the authors.

## Contact

For questions or issues, please [open an issue](https://github.com/metekemertas/mdot_tnt/issues) or email: kemertas [at] cs [dot] toronto [dot] edu
