Source code for scitex_ml.plt._stx_conf_mat

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Timestamp: "2025-10-02 07:15:10 (ywatanabe)"
# File: /ssh:sp:/home/ywatanabe/proj/scitex_repo/src/scitex/ml/plt/stx_conf_mat.py
# ----------------------------------------
from __future__ import annotations

import scitex_context
import scitex_io
import scitex_plt
from scitex_dev import try_import_optional

# `figrecipe` is the underlying impl that the umbrella exposes as
# `scitex.plt` — `ax`, `configure_mpl`, `_subplots`, etc. live there.
# Use the peer standalone directly (PA304-clean); fall back to None
# when the optional plotting layer isn't installed.
_umbrella_plt = try_import_optional("figrecipe", extra="plt", pkg="scitex-ml")

import os

import scitex_str

# Path of this module and the directory that contains it.
# `__FILE__` is handed to the session start-up code in `run_main`.
__FILE__ = __file__
__DIR__ = os.path.dirname(__FILE__)
# ----------------------------------------

import argparse

import matplotlib
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import ticker
from mpl_toolkits.axes_grid1 import make_axes_locatable
from sklearn.metrics import confusion_matrix as sklearn_confusion_matrix

# Import metric calculation from centralized location (SoC: metrics in scitex_ml.metrics)
from scitex_ml.metrics import calc_bacc_from_conf_mat

# Alias for backward compatibility: keeps the older mixed-case spelling
# `calc_bACC_from_conf_mat` available under this module's namespace while the
# implementation itself lives in `scitex_ml.metrics`.
calc_bACC_from_conf_mat = calc_bacc_from_conf_mat


def _expand_cm_to_labels(cm, y_true, y_pred, labels):
    """Spread a caller-supplied confusion matrix over the full ``labels`` grid.

    NOTE(review): assumes the rows / columns of ``cm`` are ordered like
    ``np.unique(y_true)`` / ``np.unique(y_pred)`` -- this is the assumption the
    previous inline implementation made; confirm against callers.
    """
    cm = np.asarray(cm)
    unique_true = np.unique(y_true)
    unique_pred = np.unique(y_pred)

    def _first_index(observed, label):
        # Position of `label` among the observed classes, or None if absent.
        hits = np.flatnonzero(observed == label)
        return int(hits[0]) if hits.size else None

    # Look every label up once instead of once per matrix cell.
    row_of = [_first_index(unique_true, label) for label in labels]
    col_of = [_first_index(unique_pred, label) for label in labels]

    full_cm = np.zeros((len(labels), len(labels)))
    for idx_true, src_row in enumerate(row_of):
        if src_row is None:
            continue
        for idx_pred, src_col in enumerate(col_of):
            if src_col is not None:
                full_cm[idx_true, idx_pred] = cm[src_row, src_col]
    return full_cm


def stx_conf_mat(
    cm=None,
    y_true=None,
    y_pred=None,
    y_pred_proba=None,
    labels=None,
    sorted_labels=None,
    pred_labels=None,
    sorted_pred_labels=None,
    true_labels=None,
    sorted_true_labels=None,
    label_rotation_xy=(15, 15),
    title="Confusion Matrix",
    colorbar=True,
    x_extend_ratio=1.0,
    y_extend_ratio=1.0,
    ax=None,
    spath=None,
):
    """
    Plot confusion matrix as a heatmap with inverted y-axis.

    Either ``cm`` or both ``y_true`` and ``y_pred`` (or ``y_pred_proba``)
    must be given.

    Parameters
    ----------
    cm : array-like, optional
        Pre-computed confusion matrix (rows: true, columns: predicted)
    y_true : array-like, optional
        True labels
    y_pred : array-like, optional
        Predicted labels
    y_pred_proba : array-like, optional
        Predicted probabilities; used only when ``y_pred`` is None, in which
        case ``y_pred`` is the argmax over the last axis
    labels : list, optional
        Class labels. Passed to ``sklearn.metrics.confusion_matrix`` when the
        matrix is computed here, and used as row/column names
    sorted_labels : list, optional
        Order in which rows and columns are displayed; must contain the same
        items as ``labels``
    pred_labels : list, optional
        Column (predicted) names; takes precedence over ``labels``
    sorted_pred_labels : list, optional
        Currently unused; accepted for backward compatibility
    true_labels : list, optional
        Row (true) names; takes precedence over ``labels``
    sorted_true_labels : list, optional
        Currently unused; accepted for backward compatibility
    label_rotation_xy : tuple, optional
        Rotation angles for x and y tick labels
    title : str, optional
        Title of the plot; the balanced accuracy is appended as
        ``" (bACC = 0.xxx)"``
    colorbar : bool, optional
        Whether to include a colorbar
    x_extend_ratio : float, optional
        Ratio to extend x-axis
    y_extend_ratio : float, optional
        Ratio to extend y-axis
    ax : axes object, optional
        Axes to draw on. A new figure and axes are created with
        ``scitex_plt.subplots()`` when omitted
    spath : str or pathlib.Path, optional
        Path to save the figure

    Returns
    -------
    fig : figure object
        The figure containing the plot (from ``scitex_plt.subplots()`` or
        ``ax.get_figure()``)
    cm : pandas.DataFrame
        The confusion matrix as a DataFrame

    Raises
    ------
    AssertionError
        If neither ``cm`` nor ``y_true``/``y_pred`` is available, or if
        ``sorted_labels`` does not match ``labels``
    ImportError
        If an extend ratio other than 1.0 is requested while the optional
        ``figrecipe`` package is not installed

    Example
    -------
    y_true = ["A", "B", "C", "A", "B", "C"]
    y_pred = ["A", "C", "B", "A", "A", "B"]
    labels = ["A", "B", "C"]
    fig, cm = stx_conf_mat(y_true=y_true, y_pred=y_pred, labels=labels)
    """
    import matplotlib.pyplot as mpl_plt

    if (y_pred_proba is not None) and (y_pred is None):
        # np.asarray lets nested lists be passed as well as ndarrays.
        y_pred = np.asarray(y_pred_proba).argmax(axis=-1)

    assert (cm is not None) or ((y_true is not None) and (y_pred is not None)), (
        "Provide either `cm` or both `y_true` and `y_pred` (or `y_pred_proba`)."
    )

    cm_was_supplied = cm is not None
    if not cm_was_supplied:
        with scitex_context.suppress_output():
            cm = sklearn_confusion_matrix(y_true, y_pred, labels=labels)

    bacc = calc_bacc_from_conf_mat(cm)
    title = f"{title} (bACC = {bacc:.3f})"

    if labels is not None and y_true is not None and y_pred is not None:
        if cm_was_supplied:
            # Legacy behaviour for a caller-supplied matrix: pad it out to the
            # full label grid.
            cm = _expand_cm_to_labels(cm, y_true, y_pred, labels)
        else:
            # The matrix computed above with `labels=labels` is already laid
            # out in `labels` order; re-indexing it by position in
            # np.unique(y_true) / np.unique(y_pred) would misplace counts
            # whenever a label is missing from the data. Only the float dtype
            # of the former reconstruction is kept.
            cm = np.asarray(cm, dtype=float)

    cm = pd.DataFrame(data=cm).copy()

    def _to_latex(names):
        if names is None:
            return None
        return [scitex_str.to_latex_style(name) for name in names]

    pred_labels = _to_latex(pred_labels)
    true_labels = _to_latex(true_labels)
    labels = _to_latex(labels)
    sorted_labels = _to_latex(sorted_labels)

    # Columns are predictions, rows are ground truth; the dedicated names win
    # over the shared `labels`.
    if pred_labels is not None:
        cm.columns = pred_labels
    elif labels is not None:
        cm.columns = labels

    if true_labels is not None:
        cm.index = true_labels
    elif labels is not None:
        cm.index = labels

    if sorted_labels is not None:
        assert labels is not None and set(sorted_labels) == set(labels), (
            "`sorted_labels` must contain exactly the same items as `labels`."
        )
        cm = cm.reindex(index=sorted_labels, columns=sorted_labels)

    # Use scitex_plt.subplots() for CSV export functionality
    # The AxisWrapper provides export_as_csv method for automatic data export
    if ax is None:
        fig, ax = scitex_plt.subplots()
    else:
        fig = ax.get_figure()

    # Draw on the requested axes explicitly; without `ax=` seaborn falls back
    # to the current axes, which silently ignores a caller-supplied `ax`.
    heatmap_ax = getattr(ax, "_axis_mpl", ax)
    res = sns.heatmap(
        cm,
        annot=True,
        annot_kws={"size": mpl_plt.rcParams["font.size"]},
        fmt=".0f",
        cmap="Blues",
        cbar=False,  # the colorbar is added separately below
        vmin=0,
        ax=heatmap_ax,
    )

    # Add thousands separators to the annotated counts.
    for text in res.texts:
        text.set_text(f"{int(text.get_text()):,d}")

    res.invert_yaxis()

    for spine in res.spines.values():
        spine.set_visible(False)

    ax.set_xlabel("Predicted label")
    ax.set_ylabel("True label")
    ax.set_title(title)

    if _umbrella_plt is not None:
        ax = _umbrella_plt.ax.extend(ax, x_extend_ratio, y_extend_ratio)
    elif x_extend_ratio != 1.0 or y_extend_ratio != 1.0:
        # The optional plotting layer is missing; fail with an actionable
        # message instead of an AttributeError on None.
        raise ImportError(
            "x_extend_ratio / y_extend_ratio other than 1.0 require the "
            "optional 'figrecipe' package (the 'plt' extra of scitex-ml)."
        )

    if cm.shape[0] == cm.shape[1]:
        ax.set_box_aspect(1)

    ax.set_xticklabels(
        ax.get_xticklabels(),
        rotation=label_rotation_xy[0],
        fontdict={"verticalalignment": "top"},
    )
    ax.set_yticklabels(
        ax.get_yticklabels(),
        rotation=label_rotation_xy[1],
        fontdict={"horizontalalignment": "right"},
    )

    if colorbar:
        # Extract underlying matplotlib axes if wrapped by SciTeX
        ax_mpl = getattr(ax, "_axis_mpl", ax)
        divider = make_axes_locatable(ax_mpl)
        cax = divider.append_axes("right", size="5%", pad=0.15)

        # Get the underlying figure
        fig_mpl = getattr(fig, "_fig_mpl", fig)

        vmax = np.array(cm).max().astype(int)
        norm = matplotlib.colors.Normalize(vmin=0, vmax=vmax)
        cbar = fig_mpl.colorbar(
            matplotlib.cm.ScalarMappable(norm=norm, cmap="Blues"),
            cax=cax,
        )
        cbar.locator = ticker.MaxNLocator(nbins=4)
        cbar.update_ticks()
        cbar.outline.set_edgecolor("white")

    if spath is not None:
        from pathlib import Path

        # Resolve to absolute path to prevent _out directory creation
        spath_abs = Path(spath).resolve() if isinstance(spath, (str, Path)) else spath
        scitex_io.save(fig, str(spath_abs), use_caller_path=False)

    return fig, cm
# Backward compatibility alias
conf_mat = stx_conf_mat

# Metric calculation functions have been moved to scitex_ml.metrics
# They are imported above for use in this module
# This maintains SoC: plotting in scitex_ml.plt, metrics in scitex_ml.metrics

# NOTE: the commented-out legacy `conf_mat(plt, ...)` implementation that used
# to trail this file duplicated `stx_conf_mat` and contained a stray live
# `fig, ax = scitex_plt.subplots()` statement; it has been dropped. Refer to
# version control history if the old code is needed.


def main(args):
    """Demo confusion matrix plotting with Iris dataset.

    Trains a linear SVC on the Iris data, plots the confusion matrix of the
    held-out split and saves it as ``plot_conf_mat_demo.jpg``.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command line arguments (not used by the demo itself).

    Returns
    -------
    int
        Exit status; always 0.
    """
    from sklearn import datasets, svm
    from sklearn.model_selection import train_test_split

    iris = datasets.load_iris()
    X = iris.data
    y = iris.target

    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

    classifier = svm.SVC(kernel="linear", C=0.01, random_state=42).fit(X_train, y_train)

    y_pred = classifier.predict(X_test)
    cm_data = sklearn_confusion_matrix(y_test, y_pred)
    # Cube the counts in place -- presumably to get large numbers that exercise
    # the thousands separators and the colorbar formatting.
    cm_data **= 3

    fig, cm_result = conf_mat(
        cm=cm_data,
        pred_labels=["A", "B", "C"],
        true_labels=["a", "b", "c"],
        label_rotation_xy=(60, 60),
        x_extend_ratio=1.0,
        colorbar=True,
    )

    # The scientific-notation helper lives in the optional plotting layer;
    # skip it when that layer is not installed instead of failing on None.
    if _umbrella_plt is not None:
        # NOTE(review): looks like sci_note is expected to act on the last
        # axes (the colorbar); whether the assignment back into `fig.axes`
        # has any effect depends on the figure type -- confirm.
        fig.axes[-1] = _umbrella_plt.ax.sci_note(
            fig.axes[-1],
            fformat="%3.1f",
            y=True,
        )

    scitex_io.save(fig, "plot_conf_mat_demo.jpg")
    return 0


def parse_args() -> argparse.Namespace:
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(description="Demo confusion matrix plotting")
    return parser.parse_args()


def run_main() -> None:
    """Initialize scitex framework, run main function, and cleanup."""
    # The names below are rebound at module level by the session start-up.
    global CONFIG, CC, sys, plt, rng

    import sys

    import matplotlib.pyplot as plt
    import scitex as stx

    args = parse_args()

    CONFIG, sys.stdout, sys.stderr, plt, CC, rng = stx.session.start(
        sys,
        plt,
        args=args,
        file=__FILE__,
        sdir_suffix=None,
        verbose=False,
        agg=True,
    )

    exit_status = main(args)

    stx.session.close(
        CONFIG,
        verbose=False,
        notify=False,
        message="",
        exit_status=exit_status,
    )


if __name__ == "__main__":
    run_main()

# EOF