Source code for scitex_ml.plt._plot_roc_curve

#!/usr/bin/env python3
# Timestamp: "2025-10-02 19:44:13 (ywatanabe)"
# File: /ssh:sp:/home/ywatanabe/proj/scitex_repo/src/scitex/ml/plt/plot_roc_curve.py
# ----------------------------------------
from __future__ import annotations
import scitex_io
import scitex_plt


import os

__FILE__ = __file__
__DIR__ = os.path.dirname(__FILE__)
# ----------------------------------------

import argparse

import numpy as np
from sklearn.metrics import roc_auc_score, roc_curve

from scitex_plt.colors import get_colors_from_cmap


def _to_onehot(class_indices, n_classes):
    """Convert class indices to one-hot encoding."""
    eye = np.eye(n_classes, dtype=int)
    return eye[class_indices]


[docs] def plot_roc_curve(true_class, pred_proba, labels, ax=None, spath=None): """ Plot ROC-AUC curve. Parameters ---------- true_class : array-like True class labels pred_proba : array-like Predicted probabilities labels : list Class labels ax : matplotlib axis, optional Axis to plot on. If None, creates new figure spath : str, optional Path to save figure Returns ------- fig : matplotlib.figure.Figure Figure object metrics : dict ROC metrics """ import scitex as stx # Use label_binarize to be multi-label like settings n_classes = len(labels) # Handle 1D pred_proba (binary classification with only positive class probabilities) if pred_proba.ndim == 1: # Convert to 2D: [P(class=0), P(class=1)] pred_proba = np.column_stack([1 - pred_proba, pred_proba]) # Convert string labels to integer indices if needed if true_class.dtype.kind in ("U", "S", "O"): # Unicode, bytes, or object (string) label_to_idx = {label: idx for idx, label in enumerate(labels)} true_class_idx = np.array([label_to_idx[tc] for tc in true_class]) else: true_class_idx = true_class true_class_onehot = _to_onehot(true_class_idx, n_classes) # For each class fpr = dict() tpr = dict() threshold = dict() roc_auc = dict() for i in range(n_classes): true_class_i_onehot = true_class_onehot[:, i] pred_proba_i = pred_proba[:, i] try: fpr[i], tpr[i], threshold[i] = roc_curve(true_class_i_onehot, pred_proba_i) roc_auc[i] = roc_auc_score(true_class_i_onehot, pred_proba_i) except Exception as e: print(e) fpr[i], tpr[i], threshold[i], roc_auc[i] = ( [np.nan], [np.nan], [np.nan], np.nan, ) ## Average fpr: micro and macro # A "micro-average": quantifying score on all classes jointly fpr["micro"], tpr["micro"], threshold["micro"] = roc_curve( true_class_onehot.ravel(), pred_proba.ravel() ) roc_auc["micro"] = roc_auc_score(true_class_onehot, pred_proba, average="micro") # macro _roc_aucs = [] for i in range(n_classes): try: _roc_aucs.append( roc_auc_score( true_class_onehot[:, i], pred_proba[:, i], average="macro" ) ) except Exception: print( f'\nROC-AUC for "{labels[i]}" was not defined and NaN-filled ' "for a calculation purpose (for the macro avg.)\n" ) _roc_aucs.append(np.nan) roc_auc["macro"] = np.nanmean(_roc_aucs) # Plot FPR-TPR curve for each class and iso-f1 curves # Use scitex color palette for consistent styling colors = get_colors_from_cmap("tab10", n_classes) if ax is None: fig, ax = scitex_plt.subplots() else: fig = ax.get_figure() ax.set_box_aspect(1) lines = [] legends = [] ## Chance Level (the diagonal line) (l,) = ax.plot( np.linspace(0.01, 1), np.linspace(0.01, 1), color="gray", lw=2, linestyle="--", alpha=0.8, ) lines.append(l) legends.append("Chance") ## Each Class for i in range(n_classes): (l,) = ax.plot(fpr[i], tpr[i], color=colors[i], lw=2) lines.append(l) legends.append(f"{labels[i]} (AUC = {roc_auc[i]:0.2f})") # fig = plt.gcf() fig.subplots_adjust(bottom=0.25) ax.set_xlim([-0.01, 1.01]) ax.set_ylim([-0.01, 1.01]) ax.set_xticks([0.0, 0.5, 1.0]) ax.set_yticks([0.0, 0.5, 1.0]) ax.set_xlabel("FPR") ax.set_ylabel("TPR") ax.set_title("ROC Curve") ax.legend(lines, legends, loc="lower right") metrics = dict(roc_auc=roc_auc, fpr=fpr, tpr=tpr, threshold=threshold) # Save figure if spath is provided if spath is not None: from pathlib import Path # Resolve to absolute path to prevent _out directory creation spath_abs = Path(spath).resolve() if isinstance(spath, (str, Path)) else spath scitex_io.save(fig, str(spath_abs), use_caller_path=False) return fig, metrics
def main(args): """Demo ROC AUC plotting with MNIST dataset.""" from sklearn import datasets, svm from sklearn.model_selection import train_test_split np.random.seed(42) digits = datasets.load_digits() n_samples = len(digits.images) data = digits.images.reshape((n_samples, -1)) clf = svm.SVC(gamma=0.001, probability=True) X_train, X_test, y_train, y_test = train_test_split( data, digits.target, test_size=0.5, shuffle=False ) clf.fit(X_train, y_train) predicted_proba = clf.predict_proba(X_test) n_classes = len(np.unique(digits.target)) labels = [f"Class {i}" for i in range(n_classes)] # plt.rcParams["font.size"] = 20 # plt.rcParams["legend.fontsize"] = "xx-small" # plt.rcParams["figure.figsize"] = (16 * 1.2, 9 * 1.2) y_test[y_test == 9] = 8 # override 9 as 8 fig, metrics_dict = plot_roc_curve( true_class=y_test, pred_proba=predicted_proba, labels=labels ) scitex_io.save(fig, "plot_roc_curve_demo.jpg") return 0 def parse_args() -> argparse.Namespace: """Parse command line arguments.""" parser = argparse.ArgumentParser(description="Demo ROC AUC plotting") return parser.parse_args() def run_main() -> None: """Initialize scitex framework, run main function, and cleanup.""" global CONFIG, CC, sys, plt, rng import sys import matplotlib.pyplot as plt import scitex as stx args = parse_args() CONFIG, sys.stdout, sys.stderr, plt, CC, rng = stx.session.start( sys, plt, args=args, file=__FILE__, sdir_suffix=None, verbose=False, agg=True, ) exit_status = main(args) stx.session.close( CONFIG, verbose=False, notify=False, message="", exit_status=exit_status, ) if __name__ == "__main__": run_main() # Backward compatibility alias roc_auc = plot_roc_curve # EOF