Metadata-Version: 2.4
Name: threshopt
Version: 1.1.1
Summary: Automatic threshold optimization for classifiers.
Author-email: Salvatore Zizzi <salvo.zizzi@gmail.com>
License-Expression: Apache-2.0
Project-URL: Homepage, https://github.com/salvo-zizzi/threshopt
Classifier: Programming Language :: Python :: 3
Classifier: Operating System :: OS Independent
Classifier: Intended Audience :: Developers
Classifier: Topic :: Software Development :: Libraries :: Python Modules
Requires-Python: >=3.8
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: scikit-learn
Requires-Dist: matplotlib
Requires-Dist: numpy
Dynamic: license-file

![Logo](https://raw.githubusercontent.com/Salvo-zizzi/threshopt/main/images/logo.png)

[![PyPI version](https://img.shields.io/pypi/v/threshopt?cacheSeconds=1)](https://pypi.org/project/threshopt/)
[![License](https://img.shields.io/badge/license-Apache%202.0-blue.svg)](https://www.apache.org/licenses/LICENSE-2.0)
![GitHub last commit](https://img.shields.io/github/last-commit/salvo-zizzi/threshopt)

**Threshold Optimization Library for Binary and Multiclass Classification**

`threshopt` is a lightweight Python library that automatically finds the optimal decision threshold for classification models.  
Instead of relying on the default threshold (e.g. `0.5`), it tunes the cut-off to maximize a chosen evaluation metric, which can improve model performance, especially on imbalanced datasets.

The library is fully compatible with **scikit-learn estimators** and supports both **binary** and **multiclass (one-vs-rest)** scenarios.

---

## Requirements

- Python >= 3.8
- scikit-learn
- numpy
- matplotlib

---

## Installation

```bash
pip install threshopt
```

---

## Features

- Automatic optimization of decision thresholds
- Metric-driven optimization (any `sklearn`-style metric or custom metric)
- Cross-validated threshold optimization
- Works with any scikit-learn compatible classifier
- Built-in metrics for imbalanced classification
- Optional visualization utilities
- Multiclass support via one-vs-rest logic

---

## Quickstart

### Binary classification

```python
from threshopt import optimize_threshold, optimize_threshold_cv, gmean_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_breast_cancer
from sklearn.metrics import f1_score

X, y = load_breast_cancer(return_X_y=True)

model = RandomForestClassifier(random_state=42)
model.fit(X, y)

# Optimize threshold on the full dataset
best_thresh, best_metric = optimize_threshold(
    model, X, y,
    metric=f1_score,
    plot=False, cm=False
)
print(f"Best threshold: {best_thresh:.2f}")
print(f"Best F1-score: {best_metric:.4f}")

# Optimize threshold using cross-validation
best_thresh_cv, best_metric_cv = optimize_threshold_cv(
    model, X, y,
    metric=gmean_score,
    cv=5,
    plot=False, cm=False
)
print(f"CV best threshold: {best_thresh_cv:.2f}")
print(f"CV best G-Mean: {best_metric_cv:.4f}")
```
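A tuned threshold only pays off if it is also applied at prediction time, since `predict()` keeps using the implicit `0.5` cut-off. A minimal sketch of doing this by hand (the `0.35` value is a placeholder standing in for whatever `optimize_threshold` returns):

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

model = RandomForestClassifier(random_state=42).fit(X_train, y_train)

# Placeholder for the threshold returned by optimize_threshold
best_thresh = 0.35

# Apply the tuned cut-off to the positive-class probabilities
proba = model.predict_proba(X_test)[:, 1]
y_pred = (proba >= best_thresh).astype(int)
```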

### Multiclass classification

```python
from threshopt import optimize_threshold, optimize_threshold_cv
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris
from sklearn.metrics import f1_score

X, y = load_iris(return_X_y=True)

model = RandomForestClassifier(random_state=42)
model.fit(X, y)

# Optimize thresholds (one per class)
best_thresh, best_metric = optimize_threshold(
    model, X, y,
    metric=f1_score,
    multiclass=True,
    plot=False, cm=False
)
print("Best thresholds per class:", best_thresh)
print(f"Best F1-score: {best_metric:.4f}")

# Cross-validated multiclass optimization
best_thresh_cv, best_metric_cv = optimize_threshold_cv(
    model, X, y,
    metric=f1_score,
    cv=5,
    multiclass=True,
    plot=False, cm=False
)
print("CV best thresholds per class:", best_thresh_cv)
print(f"CV best F1-score: {best_metric_cv:.4f}")
```
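Conceptually, one-vs-rest optimization binarizes each class in turn and tunes a separate threshold on that class's probability column. An illustrative sketch of the idea (not the library's actual internals):

```python
import numpy as np

def ovr_thresholds(proba, y_true, metric):
    """Pick one threshold per class by scanning candidates one-vs-rest.

    proba: (n_samples, n_classes) probability matrix.
    """
    candidates = np.linspace(0.01, 0.99, 99)
    thresholds = []
    for k in range(proba.shape[1]):
        y_bin = (y_true == k).astype(int)  # current class vs. rest
        scores = [metric(y_bin, (proba[:, k] >= t).astype(int))
                  for t in candidates]
        thresholds.append(candidates[int(np.argmax(scores))])
    return np.array(thresholds)
```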

---

## API

### `optimize_threshold`

```python
optimize_threshold(model, X, y_true, metric, multiclass=False,
                   use_predict_if_no_proba=False, plot=True, cm=True, report=True)
```

Optimizes the decision threshold on the full dataset.

| Parameter | Type | Description |
|---|---|---|
| `model` | estimator | Trained classifier with `predict_proba` or `decision_function` |
| `X` | array-like | Feature matrix |
| `y_true` | array-like | True labels |
| `metric` | callable | Metric function `(y_true, y_pred) -> float` |
| `multiclass` | bool | If `True`, performs one-vs-rest optimization |
| `use_predict_if_no_proba` | bool | Fall back to `predict()` when no probability scores are available |
| `plot` | bool | If `True`, plots probability distributions |
| `cm` | bool | If `True`, displays confusion matrix |
| `report` | bool | If `True`, prints classification report |

Returns `(thresholds, best_scores)`: a scalar threshold and score for binary classification, or one value per class for multiclass.
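At its core, this kind of search amounts to scanning candidate cut-offs and keeping the best-scoring one. A self-contained sketch of the idea on synthetic scores (illustrative only, not the library's implementation):

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic imbalanced problem: 900 negatives, 100 positives
y_true = np.concatenate([np.zeros(900, dtype=int), np.ones(100, dtype=int)])
proba = np.concatenate([rng.beta(2, 5, 900), rng.beta(5, 2, 100)])

def f1(y_t, y_p):
    tp = np.sum((y_t == 1) & (y_p == 1))
    fp = np.sum((y_t == 0) & (y_p == 1))
    fn = np.sum((y_t == 1) & (y_p == 0))
    return 2 * tp / (2 * tp + fp + fn) if tp else 0.0

# Scan candidate thresholds and keep the best-scoring one
candidates = np.linspace(0.01, 0.99, 99)
scores = [f1(y_true, (proba >= t).astype(int)) for t in candidates]
best = int(np.argmax(scores))
best_thresh, best_score = candidates[best], scores[best]
```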

---

### `optimize_threshold_cv`

```python
optimize_threshold_cv(model, X, y_true, metric, cv=5, multiclass=False,
                      plot=True, cm=True, report=True)
```

Optimizes the decision threshold using cross-validation, reducing the risk of overfitting to a specific train/test split.

| Parameter | Type | Description |
|---|---|---|
| `model` | estimator | Trained classifier with `predict_proba` or `decision_function` |
| `X` | array-like | Feature matrix |
| `y_true` | array-like | True labels |
| `metric` | callable | Metric function `(y_true, y_pred) -> float` |
| `cv` | int | Number of cross-validation folds |
| `multiclass` | bool | If `True`, performs one-vs-rest optimization |
| `plot` | bool | If `True`, plots probability distributions |
| `cm` | bool | If `True`, displays confusion matrix |
| `report` | bool | If `True`, prints classification report |

Returns `(thresholds, best_scores)`: a scalar threshold and score for binary classification, or one value per class for multiclass.

---

## Metrics

### Built-in metrics

| Metric | Signature | Description |
|---|---|---|
| `gmean_score` | `(y_true, y_pred) -> float` | Geometric mean of sensitivity and specificity |
| `youden_j_stat` | `(y_true, y_pred) -> float` | Sensitivity + specificity - 1 |
| `balanced_acc_score` | `(y_true, y_pred) -> float` | Balanced accuracy (wrapper around scikit-learn) |
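For reference, the first two metrics reduce to simple confusion-matrix arithmetic. A sketch of equivalent computations, written from the definitions in the table above rather than copied from the library:

```python
import numpy as np

def _counts(y_true, y_pred):
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    tp = np.sum((y_true == 1) & (y_pred == 1))
    tn = np.sum((y_true == 0) & (y_pred == 0))
    fp = np.sum((y_true == 0) & (y_pred == 1))
    fn = np.sum((y_true == 1) & (y_pred == 0))
    return tp, tn, fp, fn

def gmean(y_true, y_pred):
    tp, tn, fp, fn = _counts(y_true, y_pred)
    sensitivity = tp / (tp + fn)  # true positive rate
    specificity = tn / (tn + fp)  # true negative rate
    return float(np.sqrt(sensitivity * specificity))

def youden_j(y_true, y_pred):
    tp, tn, fp, fn = _counts(y_true, y_pred)
    return tp / (tp + fn) + tn / (tn + fp) - 1
```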

### Custom metrics

Any function with the following signature can be passed as `metric`:

```python
metric(y_true, y_pred)  # returns a float
```
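For example, a recall-weighted F2 score fits this signature (built here from scikit-learn's `fbeta_score`; the wrapper name `f2_score` is illustrative):

```python
from sklearn.metrics import fbeta_score

def f2_score(y_true, y_pred):
    # beta=2 weights recall more heavily than precision
    return fbeta_score(y_true, y_pred, beta=2)
```

Passing `metric=f2_score` to `optimize_threshold` would then maximize F2 rather than F1.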

---

## Contributing

Contributions are welcome! Please follow these steps:

1. Fork the repository
2. Create a feature branch (`git checkout -b feature/my-feature`)
3. Commit your changes (`git commit -m 'Add my feature'`)
4. Push to the branch (`git push origin feature/my-feature`)
5. Open a pull request

For bug reports or feature requests, please open an issue.

---

## License

This project is licensed under the Apache License 2.0 — see the [LICENSE](LICENSE) file for details.
