Source code for scitex_ml.plt._plot_pre_rec_curve

#!/usr/bin/env python3
# Timestamp: "2025-10-02 19:44:06 (ywatanabe)"
# File: /ssh:sp:/home/ywatanabe/proj/scitex_repo/src/scitex/ml/plt/plot_pre_rec_curve.py
# ----------------------------------------
from __future__ import annotations
import scitex_io
import scitex_plt


import os

__FILE__ = __file__
__DIR__ = os.path.dirname(__FILE__)
# ----------------------------------------

import argparse

import numpy as np
from sklearn.metrics import average_precision_score, precision_recall_curve

from scitex_plt.colors import get_colors_from_cmap


def _solve_intersection(f1, a, b):
    """Determine intersection of line (y = ax + b) and iso-f1 curve."""
    _a = 2 * a
    _b = -a * f1 + 2 * b - f1
    _c = -b * f1

    x_f = (-_b + np.sqrt(_b**2 - 4 * _a * _c)) / (2 * _a)
    y_f = a * x_f + b

    return (x_f, y_f)


def _to_onehot(class_indices, n_classes):
    """Convert class indices to one-hot encoding."""
    eye = np.eye(n_classes, dtype=int)
    return eye[class_indices]


[docs] def plot_pre_rec_curve(true_class, pred_proba, labels, ax=None, spath=None): """ Plot precision-recall curve. Parameters ---------- true_class : array-like True class labels pred_proba : array-like Predicted probabilities labels : list Class labels ax : matplotlib axis, optional Axis to plot on. If None, creates new figure spath : str, optional Path to save figure Returns ------- fig : matplotlib.figure.Figure Figure object metrics : dict Precision-recall metrics """ import scitex as stx # Use label_binarize to be multi-label like settings n_classes = len(labels) # Handle 1D pred_proba (binary classification with only positive class probabilities) if pred_proba.ndim == 1: # Convert to 2D: [P(class=0), P(class=1)] pred_proba = np.column_stack([1 - pred_proba, pred_proba]) # Convert string labels to integer indices if needed if true_class.dtype.kind in ("U", "S", "O"): # Unicode, bytes, or object (string) label_to_idx = {label: idx for idx, label in enumerate(labels)} true_class_idx = np.array([label_to_idx[tc] for tc in true_class]) else: true_class_idx = true_class true_class_onehot = _to_onehot(true_class_idx, n_classes) # For each class precision = dict() recall = dict() threshold = dict() pre_rec_auc = dict() for i in range(n_classes): true_class_i_onehot = true_class_onehot[:, i] pred_proba_i = pred_proba[:, i] try: precision[i], recall[i], threshold[i] = precision_recall_curve( true_class_i_onehot, pred_proba_i, ) pre_rec_auc[i] = average_precision_score(true_class_i_onehot, pred_proba_i) except Exception as e: print(e) precision[i], recall[i], threshold[i], pre_rec_auc[i] = ( np.nan, np.nan, np.nan, np.nan, ) ## Average precision: micro and macro # A "micro-average": quantifying score on all classes jointly precision["micro"], recall["micro"], threshold["micro"] = precision_recall_curve( true_class_onehot.ravel(), pred_proba.ravel() ) pre_rec_auc["micro"] = average_precision_score( true_class_onehot, pred_proba, average="micro" ) # macro _pre_rec_aucs = [] for i in range(n_classes): try: _pre_rec_aucs.append( average_precision_score( true_class_onehot[:, i], pred_proba[:, i], average="macro" ) ) except Exception: print( f'\nPRE-REC-AUC for "{labels[i]}" was not defined and NaN-filled ' "for a calculation purpose (for the macro avg.)\n" ) _pre_rec_aucs.append(np.nan) pre_rec_auc["macro"] = np.nanmean(_pre_rec_aucs) # pre_rec_auc["macro"] = average_precision_score( # true_class_onehot, pred_proba, average="macro" # ) # Plot Precision-Recall curve for each class and iso-f1 curves # Use scitex color palette for consistent styling colors = get_colors_from_cmap("tab10", n_classes) if ax is None: fig, ax = scitex_plt.subplots() else: fig = ax.get_figure() ax.set_box_aspect(1) lines = [] legends = [] # iso-F1: By definition, an iso-F1 curve contains all points # in the precision/recall space whose F1 scores are the same. f_scores = np.linspace(0.2, 0.8, num=4) # for f_score in f_scores: for i_f, f_score in enumerate(f_scores): x = np.linspace(0.01, 1) # num=50 y = f_score * x / (2 * x - f_score) (l,) = ax.plot(x[y >= 0], y[y >= 0], color="gray", alpha=0.2) # ax.annotate("f1={0:0.1f}".format(f_score), xy=(0.9, y[45] + 0.02)) x_f, y_f = _solve_intersection(f_score, 0.5, 0.5) ax.annotate(f"f1={f_score:0.1f}", xy=(x_f - 0.1, y_f - 0.1 * 0.5)) # ax.annotate("f1={0:0.1f}".format(f_score), xy=(y[35] - 0.02 * (3 - i_f), 0.85)) lines.append(l) legends.append("iso-f1 curves") """ ## In this project, average precision-recall curve is not drawn. (l,) = ax.plot(recall["micro"], precision["micro"], color="gold", lw=2) lines.append(l) legends.append("micro-average\n(AUC = {0:0.2f})" "".format(pre_rec_auc["micro"])) """ ## Each Class for i in range(n_classes): (l,) = ax.plot(recall[i], precision[i], color=colors[i], lw=2) lines.append(l) legends.append(f"{labels[i]} (AUC = {pre_rec_auc[i]:0.2f})") # fig = plt.gcf() fig.subplots_adjust(bottom=0.25) ax.set_xlim([-0.01, 1.01]) ax.set_ylim([-0.01, 1.01]) ax.set_xticks([0.0, 0.5, 1.0]) ax.set_yticks([0.0, 0.5, 1.0]) ax.set_xlabel("Recall") ax.set_ylabel("Precision") ax.set_title("Precision-Recall Curve") ax.legend(lines, legends, loc="lower left") metrics = dict( pre_rec_auc=pre_rec_auc, precision=precision, recall=recall, threshold=threshold, ) # Save figure if spath is provided if spath is not None: from pathlib import Path # Resolve to absolute path to prevent _out directory creation spath_abs = Path(spath).resolve() if isinstance(spath, (str, Path)) else spath scitex_io.save(fig, str(spath_abs), use_caller_path=False) return fig, metrics
def main(args): """Demo Precision-Recall curve plotting with MNIST dataset.""" from sklearn import datasets, svm from sklearn.model_selection import train_test_split np.random.seed(42) digits = datasets.load_digits() n_samples = len(digits.images) data = digits.images.reshape((n_samples, -1)) clf = svm.SVC(gamma=0.001, probability=True) X_train, X_test, y_train, y_test = train_test_split( data, digits.target, test_size=0.5, shuffle=False ) clf.fit(X_train, y_train) predicted_proba = clf.predict_proba(X_test) n_classes = len(np.unique(digits.target)) labels = [f"Class {i}" for i in range(n_classes)] # plt.rcParams["font.size"] = 20 # plt.rcParams["legend.fontsize"] = "xx-small" # plt.rcParams["figure.figsize"] = (16 * 1.2, 9 * 1.2) fig, metrics_dict = plot_pre_rec_curve(y_test, predicted_proba, labels) import scitex scitex_io.save(fig, "plot_pre_rec_curve_demo.jpg") return 0 def parse_args() -> argparse.Namespace: """Parse command line arguments.""" parser = argparse.ArgumentParser(description="Demo Precision-Recall curve plotting") return parser.parse_args() def run_main() -> None: """Initialize scitex framework, run main function, and cleanup.""" global CONFIG, CC, sys, plt, rng import sys import matplotlib.pyplot as plt import scitex as stx args = parse_args() CONFIG, sys.stdout, sys.stderr, plt, CC, rng = stx.session.start( sys, plt, args=args, file=__FILE__, sdir_suffix=None, verbose=False, agg=True, ) exit_status = main(args) stx.session.close( CONFIG, verbose=False, notify=False, message="", exit_status=exit_status, ) if __name__ == "__main__": run_main() # EOF