Metadata-Version: 2.4
Name: mtl_uncertainty_loss
Version: 0.1.0
Summary: Uncertainty-weighted multi-task loss for PyTorch (Kendall et al. 2018)
Author-email: elna4os <shiriusu@ya.ru>
License: MIT
Project-URL: Homepage, https://github.com/elna4os/mtl_uncertainty_loss
Project-URL: Repository, https://github.com/elna4os/mtl_uncertainty_loss
Classifier: Programming Language :: Python :: 3
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Requires-Python: >=3.8
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch>=2.0.0
Requires-Dist: numpy>=1.24.0
Provides-Extra: dev
Requires-Dist: pytest>=7.0.0; extra == "dev"
Dynamic: license-file

# mtl_uncertainty_loss

Uncertainty-weighted multi-task loss for PyTorch, based on Kendall et al. (2018):
"Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics" ([arXiv:1705.07115](https://arxiv.org/abs/1705.07115))
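For reference, the paper's objective for $T$ regression-type tasks, written with one learnable $s_i = \log \sigma_i$ per task, is given here. This is the form stated in the paper. The exact weighting this package applies is not spelled out in this README. For example, the paper uses $1/\sigma_i^2$ rather than $1/(2\sigma_i^2)$ for classification likelihoods.

```math
\mathcal{L}_{\text{total}}
  = \sum_{i=1}^{T} \left( \frac{1}{2\sigma_i^2}\,\mathcal{L}_i + \log \sigma_i \right)
  = \sum_{i=1}^{T} \left( \tfrac{1}{2}\, e^{-2 s_i}\,\mathcal{L}_i + s_i \right),
  \qquad s_i = \log \sigma_i
```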

## Features

- Arbitrary number of tasks
- Learnable uncertainty parameters (log σ per task)
- PyTorch `nn.Module` interface: `log_sigmas` are registered parameters, so they are updated during training once `loss_fn.parameters()` is passed to the optimizer
- Easy to monitor uncertainties via `get_sigmas()`
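The following is a minimal sketch of how such a module can be built. It assumes the regression-form weighting $\mathcal{L}_i/(2\sigma_i^2) + \log\sigma_i$ and zero-initialized log σ values. `UncertaintyLossSketch` is a hypothetical class for illustration only, not the package's source. The real `UncertaintyWeightedMultiTaskLoss` may differ in initialization, weighting, and input validation.

```python
import torch
import torch.nn as nn


class UncertaintyLossSketch(nn.Module):
    """Hypothetical illustration of uncertainty weighting; not the package source."""

    def __init__(self, num_tasks: int) -> None:
        super().__init__()
        # One learnable log(sigma) per task; zeros mean sigma = 1 at start (assumption)
        self.log_sigmas = nn.Parameter(torch.zeros(num_tasks))

    def forward(self, losses) -> torch.Tensor:
        losses = torch.stack(list(losses))  # scalar losses -> shape (num_tasks,)
        if losses.shape != self.log_sigmas.shape:
            raise ValueError("expected one scalar loss per task")
        # Regression form: L_i / (2 * sigma_i^2) + log(sigma_i), summed over tasks
        weights = 0.5 * torch.exp(-2.0 * self.log_sigmas)
        return (weights * losses + self.log_sigmas).sum()

    def get_sigmas(self) -> torch.Tensor:
        # Detached copy of sigma = exp(log sigma) for logging
        return torch.exp(self.log_sigmas).detach()


loss_fn = UncertaintyLossSketch(num_tasks=2)
total = loss_fn([torch.tensor(2.0), torch.tensor(4.0)])
total.backward()
print(total.item())  # 3.0
```

Storing log σ rather than σ keeps σ = exp(log σ) strictly positive without constraining the parameter. This is why the parameterization is common.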

## Installation

```bash
pip install mtl_uncertainty_loss
```

## Usage

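```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from mtl_uncertainty_loss import UncertaintyWeightedMultiTaskLoss

model = nn.Linear(10, 3)  # toy model: one output column per task
loss_fn = UncertaintyWeightedMultiTaskLoss(num_tasks=3)
# Include loss_fn.parameters() so the log σ values are trained too
optimizer = torch.optim.Adam([*model.parameters(), *loss_fn.parameters()], lr=1e-3)

x, y = torch.randn(8, 10), torch.randn(8, 3)
pred = model(x)
# Pass already-reduced (scalar) per-task losses
losses = [F.mse_loss(pred[:, i], y[:, i]) for i in range(3)]
total_loss = loss_fn(losses)
optimizer.zero_grad()
total_loss.backward()
optimizer.step()

print(loss_fn.get_sigmas())  # tensor of current σ values
```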
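`get_sigmas()` reports σ rather than the weight actually applied to each loss, so converting one into the other can make monitoring easier. The helper below assumes the paper's regression form $\mathcal{L}_i/(2\sigma_i^2) + \log\sigma_i$; check the package source for the form it actually uses. `implied_weights` is a hypothetical name, not part of the package. Under that form, setting the derivative with respect to $\log\sigma_i$ to zero gives $\sigma_i^2 = \mathcal{L}_i$. A task whose loss stays large therefore drifts toward a large σ and a small weight.

```python
import torch


def implied_weights(sigmas: torch.Tensor) -> torch.Tensor:
    """Per-task weight 1 / (2 * sigma^2), assuming the regression form."""
    return 1.0 / (2.0 * sigmas.pow(2))


# Made-up values of the kind get_sigmas() returns, for illustration only
sigmas = torch.tensor([1.0, 2.0, 0.5])
weights = implied_weights(sigmas)
print(weights)  # tensor([0.5000, 0.1250, 2.0000])
```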

## Testing

```bash
pip install -e ".[dev]"
pytest tests
```

## License

MIT
