Source code for scitex_ml.training._LearningCurveLogger

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Timestamp: "2024-11-20 08:49:50 (ywatanabe)"
# File: /home/ywatanabe/proj/scitex_repo/src/scitex/ml/training/_LearningCurveLogger.py
# ----------------------------------------
from __future__ import annotations

import os

__FILE__ = __file__
__DIR__ = os.path.dirname(__FILE__)
# ----------------------------------------

"""
Functionality:
    - Records and visualizes learning curves during model training
    - Supports tracking of multiple metrics across training/validation/test phases
    - Generates plots showing training progress over iterations and epochs
    - Delegates plotting to scitex_ml.plt.plot_learning_curve for consistency

Input:
    - Training metrics dictionary containing loss, accuracy, predictions etc.
    - Step information (Training/Validation/Test)

Output:
    - Learning curve plots via scitex_ml.plt.plot_learning_curve
    - DataFrames with recorded metrics
    - Training progress prints

Prerequisites:
    - PyTorch
    - scikit-learn
    - matplotlib
    - pandas
    - numpy
    - scitex
"""

import re
import warnings
from collections import defaultdict
from pprint import pprint
from typing import Any, Dict, List, Optional, Union

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


[docs] class LearningCurveLogger: """Records and visualizes learning metrics during model training. Example ------- >>> logger = LearningCurveLogger() >>> metrics = { ... "loss_plot": 0.5, ... "balanced_ACC_plot": 0.8, ... "pred_proba": pred_proba, ... "true_class": labels, ... "i_fold": 0, ... "i_epoch": 1, ... "i_global": 100 ... } >>> logger(metrics, "Training") >>> fig = logger.plot_learning_curves() """
[docs] def __init__(self) -> None: self.logged_dict: Dict[str, Dict] = defaultdict(dict) warnings.warn( '\n"gt_label" will be removed in the future. Please use "true_class" instead.\n', DeprecationWarning, )
def __call__(self, dict_to_log: Dict[str, Any], step: str) -> None: """Logs metrics for a training step. Parameters ---------- dict_to_log : Dict[str, Any] Dictionary containing metrics to log step : str Phase of training ('Training', 'Validation', or 'Test') """ # Handle deprecated gt_label if "gt_label" in dict_to_log: dict_to_log["true_class"] = dict_to_log.pop("gt_label") for k_to_log in dict_to_log: try: self.logged_dict[step][k_to_log].append(dict_to_log[k_to_log]) except KeyError: self.logged_dict[step][k_to_log] = [dict_to_log[k_to_log]] @property def dfs(self) -> Dict[str, pd.DataFrame]: """Returns DataFrames of logged metrics. Returns ------- Dict[str, pd.DataFrame] Dictionary of DataFrames for each step """ return self._to_dfs_pivot(self.logged_dict, pivot_column=None)
[docs] def to_metrics_df(self) -> pd.DataFrame: """Convert logged data to metrics DataFrame for plot_learning_curve. Returns ------- pd.DataFrame DataFrame with columns: step, i_global, i_epoch, i_batch, and metric columns """ all_rows = [] for step, metrics in self.logged_dict.items(): n_samples = len(metrics["i_global"]) for i in range(n_samples): row = {"step": step} for key, values in metrics.items(): # Only include scalar metrics for the DataFrame if key.endswith("_plot") or key in [ "i_global", "i_epoch", "i_batch", "i_fold", ]: row[key] = values[i] all_rows.append(row) df = pd.DataFrame(all_rows) # Rename _plot columns df.columns = [ col.replace("_plot", "") if col.endswith("_plot") else col for col in df.columns ] # Ensure i_batch exists (create dummy if not logged) if "i_batch" not in df.columns: df["i_batch"] = df.groupby(["step", "i_epoch"]).cumcount() return df
[docs] def plot_learning_curves( self, title: Optional[str] = None, max_n_ticks: int = 4, linewidth: float = 1, scattersize: float = 3, yscale: str = "linear", spath: Optional[str] = None, ) -> matplotlib.figure.Figure: """Plots learning curves from logged metrics. Delegates to scitex_ml.plt.plot_learning_curve for consistent plotting. Parameters ---------- title : str, optional Plot title max_n_ticks : int Maximum number of ticks on axes linewidth : float Width of plot lines scattersize : float Size of scatter points yscale : str Y-axis scale ('linear' or 'log') spath : str, optional Save path for the figure Returns ------- matplotlib.figure.Figure Figure containing learning curves """ from scitex_ml.plt import plot_learning_curve # Convert to metrics DataFrame metrics_df = self.to_metrics_df() # Find keys to plot (exclude metadata columns) keys_to_plot = [ col for col in metrics_df.columns if col not in ["step", "i_global", "i_epoch", "i_batch", "i_fold"] ] if len(keys_to_plot) == 0: # Create empty plot when no plot keys found fig, ax = plt.subplots(1, 1) ax.set_xlabel("Iteration #") ax.set_ylabel("No metrics to plot") if title: fig.text(0.5, 0.95, title, ha="center") return fig # Delegate to centralized plotting function fig = plot_learning_curve( metrics_df=metrics_df, keys=keys_to_plot, title=title or "Learning Curves", max_n_ticks=max_n_ticks, scattersize=scattersize, linewidth=linewidth, yscale=yscale, spath=spath, ) return fig
[docs] def get_x_of_i_epoch(self, x: str, step: str, i_epoch: int) -> np.ndarray: """Gets metric values for a specific epoch. Parameters ---------- x : str Name of metric to retrieve step : str Training phase i_epoch : int Epoch number Returns ------- np.ndarray Array of metric values for specified epoch """ indi = np.array(self.logged_dict[step]["i_epoch"]) == i_epoch x_all_arr = np.array(self.logged_dict[step][x]) assert len(indi) == len(x_all_arr) return x_all_arr[indi]
[docs] def print(self, step: str) -> None: """Prints metrics for given step. Parameters ---------- step : str Training phase to print metrics for """ df_pivot_i_epoch = self._to_dfs_pivot(self.logged_dict, pivot_column="i_epoch") df_pivot_i_epoch_step = df_pivot_i_epoch[step] df_pivot_i_epoch_step.columns = self._rename_if_key_to_plot( df_pivot_i_epoch_step.columns ) print("\n----------------------------------------\n") print(f"\n{step}: (mean of batches)\n") pprint(df_pivot_i_epoch_step) print("\n----------------------------------------\n")
@staticmethod def _rename_if_key_to_plot(x: Union[str, pd.Index]) -> Union[str, pd.Index]: """Rename metric keys for plotting. Parameters ---------- x : str or pd.Index Metric name(s) to rename Returns ------- str or pd.Index Renamed metric name(s) """ if isinstance(x, str): if re.search("_plot$", x): return x.replace("_plot", "") else: return x else: return x.str.replace("_plot", "") @staticmethod def _to_dfs_pivot( logged_dict: Dict[str, Dict], pivot_column: Optional[str] = None, ) -> Dict[str, pd.DataFrame]: """Convert logged dictionary to pivot DataFrames. Parameters ---------- logged_dict : Dict[str, Dict] Dictionary of logged metrics pivot_column : str, optional Column to pivot on Returns ------- Dict[str, pd.DataFrame] Dictionary of pivot DataFrames """ dfs_pivot = {} for step_k in logged_dict.keys(): if pivot_column is None: df = pd.DataFrame(logged_dict[step_k]) else: df = ( pd.DataFrame(logged_dict[step_k]) .groupby(pivot_column) .mean() .reset_index() .set_index(pivot_column) ) dfs_pivot[step_k] = df return dfs_pivot
"""Functions & Classes""" def main(args): """Demo learning curve logger with MNIST training.""" import torch import torch.nn as nn from sklearn.metrics import balanced_accuracy_score from torch.utils.data import DataLoader, TensorDataset from torch.utils.data.dataset import Subset from torchvision import datasets import scitex ################################################################################ ## NN ################################################################################ class Perceptron(nn.Module): def __init__(self): super().__init__() self.l1 = nn.Linear(28 * 28, 50) self.l2 = nn.Linear(50, 10) def forward(self, x): x = x.view(-1, 28 * 28) x = self.l1(x) x = self.l2(x) return x ################################################################################ ## Prepare demo data ################################################################################ ## Downloads _ds_tra_val = datasets.MNIST("/tmp/mnist", train=True, download=True) _ds_tes = datasets.MNIST("/tmp/mnist", train=False, download=True) ## Training-Validation splitting n_samples = len(_ds_tra_val) train_size = int(n_samples * 0.8) subset1_indices = list(range(0, train_size)) subset2_indices = list(range(train_size, n_samples)) _ds_tra = Subset(_ds_tra_val, subset1_indices) _ds_val = Subset(_ds_tra_val, subset2_indices) ## to tensors ds_tra = TensorDataset( _ds_tra.dataset.data.to(torch.float32), _ds_tra.dataset.targets, ) ds_val = TensorDataset( _ds_val.dataset.data.to(torch.float32), _ds_val.dataset.targets, ) ds_tes = TensorDataset( _ds_tes.data.to(torch.float32), _ds_tes.targets, ) ## to dataloaders batch_size = 64 dl_tra = DataLoader( dataset=ds_tra, batch_size=batch_size, shuffle=True, drop_last=True, ) dl_val = DataLoader( dataset=ds_val, batch_size=batch_size, shuffle=False, drop_last=True, ) dl_tes = DataLoader( dataset=ds_tes, batch_size=batch_size, shuffle=False, drop_last=True, ) ################################################################################ ## Preparation ################################################################################ model = Perceptron() loss_func = nn.CrossEntropyLoss() optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) softmax = nn.Softmax(dim=-1) ################################################################################ ## Main ################################################################################ lc_logger = LearningCurveLogger() i_global = 0 i_fold = 0 max_epochs = 3 for i_epoch in range(max_epochs): step = "Validation" for i_batch, batch in enumerate(dl_val): X, T = batch logits = model(X) pred_proba = softmax(logits) pred_class = pred_proba.argmax(dim=-1) loss = loss_func(logits, T) with warnings.catch_warnings(): warnings.simplefilter("ignore", UserWarning) bACC = balanced_accuracy_score(T, pred_class) dict_to_log = { "loss_plot": float(loss), "balanced_ACC_plot": float(bACC), "pred_proba": pred_proba.detach().cpu().numpy(), "true_class": T.cpu().numpy(), "i_fold": i_fold, "i_epoch": i_epoch, "i_batch": i_batch, "i_global": i_global, } lc_logger(dict_to_log, step) lc_logger.print(step) step = "Training" for i_batch, batch in enumerate(dl_tra): optimizer.zero_grad() X, T = batch logits = model(X) pred_proba = softmax(logits) pred_class = pred_proba.argmax(dim=-1) loss = loss_func(logits, T) loss.backward() optimizer.step() with warnings.catch_warnings(): warnings.simplefilter("ignore", UserWarning) bACC = balanced_accuracy_score(T, pred_class) dict_to_log = { "loss_plot": float(loss), "balanced_ACC_plot": float(bACC), "pred_proba": pred_proba.detach().cpu().numpy(), "true_class": T.cpu().numpy(), "i_fold": i_fold, "i_epoch": i_epoch, "i_batch": i_batch, "i_global": i_global, } lc_logger(dict_to_log, step) i_global += 1 lc_logger.print(step) step = "Test" for i_batch, batch in enumerate(dl_tes): X, T = batch logits = model(X) pred_proba = softmax(logits) pred_class = pred_proba.argmax(dim=-1) loss = loss_func(logits, T) with warnings.catch_warnings(): warnings.simplefilter("ignore", UserWarning) bACC = balanced_accuracy_score(T, pred_class) dict_to_log = { "loss_plot": float(loss), "balanced_ACC_plot": float(bACC), "pred_proba": pred_proba.detach().cpu().numpy(), "true_class": T.cpu().numpy(), "i_fold": i_fold, "i_epoch": i_epoch, "i_batch": i_batch, "i_global": i_global, } lc_logger(dict_to_log, step) lc_logger.print(step) # Plot using refactored method fig = lc_logger.plot_learning_curves( title=f"fold#{i_fold}", linewidth=1, scattersize=50, spath="learning_curve_logger_demo.jpg", ) return 0 import argparse def parse_args() -> argparse.Namespace: """Parse command line arguments.""" parser = argparse.ArgumentParser(description="Demo learning curve logger") return parser.parse_args() def run_main() -> None: """Initialize scitex framework, run main function, and cleanup.""" global CONFIG, CC, sys, plt, rng import sys import matplotlib.pyplot as plt import scitex as stx args = parse_args() CONFIG, sys.stdout, sys.stderr, plt, CC, rng = stx.session.start( sys, plt, args=args, file=__FILE__, sdir_suffix=None, verbose=False, agg=True, ) exit_status = main(args) stx.session.close( CONFIG, verbose=False, notify=False, message="", exit_status=exit_status, ) if __name__ == "__main__": run_main() # EOF