scitex_ml.training

Training utilities.

class scitex_ml.training.EarlyStopping(patience=7, verbose=False, delta=1e-05, direction='minimize')

Stops training early if the validation score does not improve within a given patience period.

__init__(patience=7, verbose=False, delta=1e-05, direction='minimize')
Parameters:
  • patience (int) – How long to wait after the last improvement in validation score before stopping. Default: 7

  • verbose (bool) – If True, prints a message for each validation score improvement. Default: False

  • delta (float) – Minimum change in the monitored quantity to qualify as an improvement. Default: 1e-05

  • direction (str) – Optimization direction of the validation score; with 'minimize', a decrease counts as an improvement. Default: 'minimize'

is_best(val_score)
save(current_score, models_spaths_dict, i_global)

Saves the model(s) when the validation score decreases.
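
The patience/delta logic can be illustrated with a minimal, self-contained sketch. It assumes that each validation check without an improvement larger than delta increments a counter, that training stops once the counter reaches patience, and that 'maximize' is the alternative to direction='minimize'. SimpleEarlyStopping and its step() method are hypothetical names; this is not the scitex_ml implementation, and verbose and model saving are omitted.

```python
class SimpleEarlyStopping:
    """Hypothetical sketch of patience-based early stopping (not scitex_ml's code)."""

    def __init__(self, patience=7, delta=1e-5, direction="minimize"):
        # Assumption: 'maximize' is the only alternative to 'minimize'.
        if direction not in ("minimize", "maximize"):
            raise ValueError("direction must be 'minimize' or 'maximize'")
        self.patience = patience
        self.delta = delta
        self.direction = direction
        self.best_score = None
        self.counter = 0  # checks since the last improvement

    def is_best(self, score):
        """True if score beats the best so far by more than delta."""
        if self.best_score is None:
            return True  # the first score is always the best so far
        if self.direction == "minimize":
            return score < self.best_score - self.delta
        return score > self.best_score + self.delta

    def step(self, score):
        """Record one validation score; return True once training should stop."""
        if self.is_best(score):
            self.best_score = score
            self.counter = 0  # improvement: reset the patience counter
        else:
            self.counter += 1
        return self.counter >= self.patience


# Usage: the loss stalls after epoch 1, so with patience=2 training stops at epoch 3.
stopper = SimpleEarlyStopping(patience=2)
stopped_at = None
for epoch, val_loss in enumerate([1.0, 0.8, 0.8, 0.81]):
    if stopper.step(val_loss):
        stopped_at = epoch
        break
```

Counting non-improving checks rather than comparing against a fixed epoch keeps the rule independent of how often validation runs.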

class scitex_ml.training.LearningCurveLogger

Records and visualizes learning metrics during model training.

Example

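>>> from scitex_ml.training import LearningCurveLogger
>>> logger = LearningCurveLogger()
>>> # pred_proba: predicted class probabilities; labels: true class labels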
>>> metrics = {
...     "loss_plot": 0.5,
...     "balanced_ACC_plot": 0.8,
...     "pred_proba": pred_proba,
...     "true_class": labels,
...     "i_fold": 0,
...     "i_epoch": 1,
...     "i_global": 100
... }
>>> logger(metrics, "Training")
>>> fig = logger.plot_learning_curves()

__init__()
property dfs: Dict[str, DataFrame]

Returns DataFrames of logged metrics.

Returns:

Dictionary of DataFrames for each step

Return type:

Dict[str, pd.DataFrame]

to_metrics_df()

Converts the logged data to a metrics DataFrame for use with scitex_ml.plt.plot_learning_curve.

Returns:

DataFrame with columns: step, i_global, i_epoch, i_batch, and metric columns

Return type:

pd.DataFrame
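
A rough sketch of how per-step records could map onto dfs and to_metrics_df, assuming the logger keeps a list of scalar metric rows per step and drops non-scalar entries such as pred_proba. MiniCurveLogger is a hypothetical class, not the scitex_ml implementation; the real to_metrics_df also reports i_batch, which this example does not log.

```python
from collections import defaultdict

import pandas as pd


class MiniCurveLogger:
    """Hypothetical sketch: collects scalar metrics per step (training phase)."""

    def __init__(self):
        self._records = defaultdict(list)  # step name -> list of row dicts

    def __call__(self, metrics, step):
        # Keep only Python int/float entries; arrays such as pred_proba are skipped.
        row = {k: v for k, v in metrics.items() if isinstance(v, (int, float))}
        self._records[step].append(row)

    @property
    def dfs(self):
        # One DataFrame per step, mirroring the documented Dict[str, DataFrame].
        return {step: pd.DataFrame(rows) for step, rows in self._records.items()}

    def to_metrics_df(self):
        # Stack all steps into one long table and move "step" to the first column.
        frames = [df.assign(step=step) for step, df in self.dfs.items()]
        out = pd.concat(frames, ignore_index=True)
        return out[["step"] + [c for c in out.columns if c != "step"]]


logger = MiniCurveLogger()
logger({"loss_plot": 0.5, "i_epoch": 0, "i_global": 10}, "Training")
logger({"loss_plot": 0.4, "i_epoch": 1, "i_global": 20}, "Training")
logger({"loss_plot": 0.6, "i_epoch": 1, "i_global": 20}, "Validation")
df = logger.to_metrics_df()
```

A single long-format table with a step column is convenient for plotting, since each step can be selected with one boolean mask.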

plot_learning_curves(title=None, max_n_ticks=4, linewidth=1, scattersize=3, yscale='linear', spath=None)

Plots learning curves from logged metrics.

Delegates to scitex_ml.plt.plot_learning_curve for consistent plotting.

Parameters:
  • title (str, optional) – Plot title

  • max_n_ticks (int) – Maximum number of ticks on axes

  • linewidth (float) – Width of plot lines

  • scattersize (float) – Size of scatter points

  • yscale (str) – Y-axis scale (‘linear’ or ‘log’)

  • spath (str, optional) – Save path for the figure

Returns:

Figure containing learning curves

Return type:

matplotlib.figure.Figure
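
The parameters above map naturally onto Matplotlib calls. The sketch below is a hypothetical stand-in (plot_curves is not part of scitex_ml and does not reproduce scitex_ml.plt.plot_learning_curve): it plots a single metric against i_global with one line per step, assuming each logged row carries step and i_global keys.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator


def plot_curves(rows, metric, title=None, max_n_ticks=4, linewidth=1,
                scattersize=3, yscale="linear", spath=None):
    """Hypothetical sketch: plot one metric against i_global, one line per step."""
    fig, ax = plt.subplots()
    for step in sorted({r["step"] for r in rows}):
        # Sort this step's points by i_global so the line is drawn left to right.
        pts = sorted((r["i_global"], r[metric]) for r in rows if r["step"] == step)
        xs = [p[0] for p in pts]
        ys = [p[1] for p in pts]
        ax.plot(xs, ys, linewidth=linewidth, label=step)
        ax.scatter(xs, ys, s=scattersize)
    ax.set_yscale(yscale)  # 'linear' or 'log'
    # nbins is the maximum number of tick intervals, so this caps ticks near max_n_ticks.
    ax.xaxis.set_major_locator(MaxNLocator(nbins=max_n_ticks))
    ax.set_xlabel("i_global")
    ax.set_ylabel(metric)
    if title is not None:
        ax.set_title(title)
    ax.legend()
    if spath is not None:
        fig.savefig(spath)
    return fig


rows = [
    {"step": "Training", "i_global": 0, "loss_plot": 1.0},
    {"step": "Training", "i_global": 10, "loss_plot": 0.5},
    {"step": "Validation", "i_global": 10, "loss_plot": 0.7},
]
fig = plot_curves(rows, "loss_plot", title="Loss", yscale="log")
```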

get_x_of_i_epoch(x, step, i_epoch)

Gets metric values for a specific epoch.

Parameters:
  • x (str) – Name of metric to retrieve

  • step (str) – Training phase

  • i_epoch (int) – Epoch number

Returns:

Array of metric values for specified epoch

Return type:

np.ndarray
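
A minimal sketch of the lookup, assuming each step's DataFrame (as returned by dfs) has an i_epoch column. The standalone function below is hypothetical; the documented method takes only x, step and i_epoch and reads the logger's own data.

```python
import numpy as np
import pandas as pd


def get_x_of_i_epoch(dfs, x, step, i_epoch) -> np.ndarray:
    """Hypothetical sketch: values of metric x logged during one epoch of one step."""
    df = dfs[step]
    # Boolean mask on i_epoch, then select column x and convert to a NumPy array.
    return df.loc[df["i_epoch"] == i_epoch, x].to_numpy()


dfs = {
    "Training": pd.DataFrame(
        {"i_epoch": [0, 0, 1, 1], "loss_plot": [0.9, 0.8, 0.6, 0.5]}
    )
}
vals = get_x_of_i_epoch(dfs, "loss_plot", "Training", 1)
print(vals)  # [0.6 0.5]
```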

print(step)

Prints metrics for given step.

Parameters:

step (str) – Training phase to print metrics for

Return type:

None