scitex_ml.training
Training utilities.
- class scitex_ml.training.EarlyStopping(patience=7, verbose=False, delta=1e-05, direction='minimize')
Stops training early if the validation score does not improve for patience consecutive evaluations.
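The patience mechanism can be illustrated with a minimal sketch. The class name EarlyStoppingSketch and its step() method are hypothetical, not the scitex_ml API. The sketch also assumes that delta is the minimum change that counts as an improvement and that direction='minimize' means lower scores are better. The verbose flag, and anything the real class does on improvement (such as saving a checkpoint), is not modeled.

```python
class EarlyStoppingSketch:
    """Hypothetical patience-based early stopping; not scitex_ml's implementation."""

    def __init__(self, patience=7, delta=1e-5, direction="minimize"):
        if direction not in ("minimize", "maximize"):
            raise ValueError("direction must be 'minimize' or 'maximize'")
        self.patience = patience
        self.delta = delta  # assumed: minimum change that counts as improvement
        self.direction = direction
        self.best_score = None
        self.counter = 0  # consecutive evaluations without sufficient improvement

    def _improved(self, score):
        # The first score seen always becomes the baseline
        if self.best_score is None:
            return True
        if self.direction == "minimize":
            return score < self.best_score - self.delta
        return score > self.best_score + self.delta

    def step(self, score):
        """Record one validation score; return True once training should stop."""
        if self._improved(score):
            self.best_score = score
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience


# Usage: with patience=2, two non-improving epochs in a row trigger a stop
stopper = EarlyStoppingSketch(patience=2, delta=0.01)
history = [0.9, 0.8, 0.795, 0.81, 0.7]
stopped_at = None
for epoch, val_loss in enumerate(history):
    if stopper.step(val_loss):
        stopped_at = epoch
        break
```

Here 0.795 does not beat 0.8 by more than delta=0.01, so it counts as a non-improvement. The loop therefore stops at epoch 3 without ever seeing 0.7.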
- class scitex_ml.training.LearningCurveLogger
Records and visualizes learning metrics during model training.
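Example
>>> from scitex_ml.training import LearningCurveLogger
>>> logger = LearningCurveLogger()
>>> metrics = {
...     "loss_plot": 0.5,
...     "balanced_ACC_plot": 0.8,
...     "pred_proba": pred_proba,  # predicted probabilities (assumed defined earlier)
...     "true_class": labels,  # ground-truth labels (assumed defined earlier)
...     "i_fold": 0,
...     "i_epoch": 1,
...     "i_global": 100,
... }
>>> logger(metrics, "Training")
>>> fig = logger.plot_learning_curves()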
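To clarify what "recording metrics per step" means, here is a stdlib-only sketch of a logger that keeps one list of metric dicts per step name. CurveLoggerSketch and scalar_history() are hypothetical names, not part of scitex_ml. The "Validation" step name is also an assumption, since the source example only shows "Training". The real class exposes its records through the dfs property as pandas DataFrames. This sketch keeps plain lists, which pd.DataFrame(...) could turn into a table.

```python
from collections import defaultdict


class CurveLoggerSketch:
    """Hypothetical stand-in: accumulates metric dicts under each step name."""

    def __init__(self):
        self.logged = defaultdict(list)  # step name -> list of metric dicts

    def __call__(self, metrics, step):
        # Store a copy so later changes to the caller's dict do not alter history
        self.logged[step].append(dict(metrics))

    def scalar_history(self, step, key):
        """Return (i_global, value) pairs for one metric within one step."""
        return [(m["i_global"], m[key]) for m in self.logged[step] if key in m]


logger = CurveLoggerSketch()
logger({"loss_plot": 0.5, "i_fold": 0, "i_epoch": 1, "i_global": 100}, "Training")
logger({"loss_plot": 0.4, "i_fold": 0, "i_epoch": 2, "i_global": 200}, "Training")
logger({"loss_plot": 0.6, "i_fold": 0, "i_epoch": 1, "i_global": 100}, "Validation")
print(logger.scalar_history("Training", "loss_plot"))  # [(100, 0.5), (200, 0.4)]
```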
- property dfs: Dict[str, DataFrame]
Returns DataFrames of logged metrics.
- Returns:
Dictionary mapping each step name (e.g. "Training") to a DataFrame of the metrics logged for that step
- Return type:
Dict[str, pd.DataFrame]
- to_metrics_df()
Converts the logged data into a single metrics DataFrame suitable for scitex_ml.plt.plot_learning_curve.
- Returns:
DataFrame with the columns step, i_global, i_epoch and i_batch, plus one column per logged metric
- Return type:
pd.DataFrame
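The reshaping that to_metrics_df describes, from per-step records to one table with a step column, can be sketched with the standard library. to_metrics_rows is a hypothetical helper, not the library's code. It assumes array-valued entries such as pred_proba are dropped and only scalar metrics are kept. Passing the resulting rows to pd.DataFrame(rows) would yield a long-format table with one row per logged record.

```python
def to_metrics_rows(logged):
    """Flatten {step: [metric dicts]} into rows that each carry a 'step' column.

    Hypothetical helper: keeps only int/float values, so array-valued entries
    (for example predicted probabilities) are left out of the rows.
    """
    rows = []
    for step, records in logged.items():
        for record in records:
            row = {"step": step}
            for key, value in record.items():
                # bool is a subclass of int, so exclude it explicitly
                if isinstance(value, (int, float)) and not isinstance(value, bool):
                    row[key] = value
            rows.append(row)
    return rows


logged = {
    "Training": [
        {"loss_plot": 0.5, "i_epoch": 1, "i_global": 100, "pred_proba": [0.2, 0.8]},
    ],
    "Validation": [
        {"loss_plot": 0.6, "i_epoch": 1, "i_global": 100},
    ],
}
rows = to_metrics_rows(logged)
# rows[0] == {"step": "Training", "loss_plot": 0.5, "i_epoch": 1, "i_global": 100}
```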
- plot_learning_curves(title=None, max_n_ticks=4, linewidth=1, scattersize=3, yscale='linear', spath=None)
Plots learning curves from logged metrics.
Delegates to scitex_ml.plt.plot_learning_curve for consistent plotting.
- Parameters:
- Returns:
Figure containing learning curves
- Return type:
matplotlib.figure.Figure