scitex_ml.plt
Scitex centralized plotting module.
Note: Metric calculation functions (calc_*) are imported from scitex_ml.metrics but re-exported here for backward compatibility. New code should import directly from scitex_ml.metrics instead.
- scitex_ml.plt.stx_conf_mat(cm=None, y_true=None, y_pred=None, y_pred_proba=None, labels=None, sorted_labels=None, pred_labels=None, sorted_pred_labels=None, true_labels=None, sorted_true_labels=None, label_rotation_xy=(15, 15), title='Confusion Matrix', colorbar=True, x_extend_ratio=1.0, y_extend_ratio=1.0, ax=None, spath=None)
Plot confusion matrix as a heatmap with inverted y-axis.
- Parameters:
cm (array-like, optional) – Pre-computed confusion matrix
y_true (array-like, optional) – True labels
y_pred (array-like, optional) – Predicted labels
y_pred_proba (array-like, optional) – Predicted probabilities
labels (list, optional) – List of labels
sorted_labels (list, optional) – Sorted list of labels
pred_labels (list, optional) – Predicted label names
sorted_pred_labels (list, optional) – Sorted predicted label names
true_labels (list, optional) – True label names
sorted_true_labels (list, optional) – Sorted true label names
label_rotation_xy (tuple, optional) – Rotation angles for x and y labels
title (str, optional) – Title of the plot
colorbar (bool, optional) – Whether to include a colorbar
x_extend_ratio (float, optional) – Ratio to extend x-axis
y_extend_ratio (float, optional) – Ratio to extend y-axis
spath (str, optional) – Path to save the figure
plt (deprecated) – Kept for backward compatibility only
- Returns:
fig (matplotlib.figure.Figure) – The figure object containing the plot
cm (pandas.DataFrame) – The confusion matrix as a DataFrame
Example
>>> import matplotlib.pyplot as plt
>>> from scitex_ml.plt import conf_mat
>>> y_true = [0, 1, 2, 0, 1, 2]
>>> y_pred = [0, 2, 1, 0, 0, 1]
>>> labels = ['A', 'B', 'C']
>>> fig, cm = conf_mat(y_true=y_true, y_pred=y_pred, labels=labels)
>>> plt.show()
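The sketch below shows what such a plot encodes, using only NumPy, pandas and Matplotlib. It is not the library's implementation. The helper name `sketch_conf_mat` is hypothetical. The code assumes integer class indices and the convention rows = true labels, columns = predicted labels. It builds the count matrix as a DataFrame and draws it as a heatmap whose y-axis is inverted, so the first class sits at the bottom.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def sketch_conf_mat(y_true, y_pred, labels):
    """Hypothetical helper: count (true, pred) pairs and plot them as a heatmap.

    Assumes integer class indices 0..len(labels)-1, rows = true, columns = predicted.
    """
    n = len(labels)
    counts = np.zeros((n, n), dtype=int)
    for t, p in zip(y_true, y_pred):
        counts[t, p] += 1
    cm = pd.DataFrame(counts, index=labels, columns=labels)

    fig, ax = plt.subplots()
    im = ax.imshow(cm.values, cmap="Blues")
    ax.set_xticks(range(n))
    ax.set_xticklabels(labels, rotation=15)
    ax.set_yticks(range(n))
    ax.set_yticklabels(labels, rotation=15)
    ax.set_xlabel("Predicted label")
    ax.set_ylabel("True label")
    ax.invert_yaxis()  # draw row 0 (first true class) at the bottom
    for i in range(n):
        for j in range(n):
            ax.text(j, i, str(counts[i, j]), ha="center", va="center")
    fig.colorbar(im, ax=ax)
    ax.set_title("Confusion Matrix")
    return fig, cm


fig, cm = sketch_conf_mat(
    y_true=[0, 1, 2, 0, 1, 2], y_pred=[0, 2, 1, 0, 0, 1], labels=["A", "B", "C"]
)
print(cm)
```

With the example labels, the diagonal holds the correct predictions. Only two of the six samples (both of class A) are classified correctly.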
- scitex_ml.plt.conf_mat(cm=None, y_true=None, y_pred=None, y_pred_proba=None, labels=None, sorted_labels=None, pred_labels=None, sorted_pred_labels=None, true_labels=None, sorted_true_labels=None, label_rotation_xy=(15, 15), title='Confusion Matrix', colorbar=True, x_extend_ratio=1.0, y_extend_ratio=1.0, ax=None, spath=None)
Plot confusion matrix as a heatmap with inverted y-axis.
Parameters, return values, and example are identical to those of scitex_ml.plt.stx_conf_mat; see that entry.
- scitex_ml.plt.plot_learning_curve(metrics_df, keys, title='Title', max_n_ticks=4, scattersize=3, linewidth=1, yscale='linear', spath=None)
Plot learning curves from training metrics.
This is mainly used by scitex/ml/training/_LearningCurveLogger.py
- Parameters:
metrics_df (pd.DataFrame) – DataFrame with columns: step, i_global, i_epoch, i_batch, and metric columns
title (str) – Plot title
max_n_ticks (int) – Maximum number of ticks on x-axis
scattersize (float) – Size of scatter points for validation/test
linewidth (float) – Width of training line
yscale (str) – Y-axis scale (‘linear’ or ‘log’)
spath (str, optional) – Save path for the figure
- Returns:
fig – Figure containing learning curves
- Return type:
matplotlib.figure.Figure
Example
>>> print(metrics_df)
#           step  i_global  i_epoch  i_batch      loss
# 0     Training         0        0        0  0.717023
# 1     Training         1        0        1  0.703844
# ...
# [123271 rows x 5 columns]
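The following is a minimal sketch of a `metrics_df` in the documented column layout (step, i_global, i_epoch, i_batch, plus a metric column). It also shows the general shape of such a plot: training values as a line, validation values as small scatter points. The "Validation" step name and the once-per-epoch validation row are assumptions. So are the toy loss values. How `keys` selects metric columns is not documented here, so the sketch plots the single `loss` column directly with Matplotlib rather than calling `plot_learning_curve`.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Toy table in the documented layout: step, i_global, i_epoch, i_batch, <metric>.
# The "Validation" step name and one validation row per epoch are assumptions.
rng = np.random.default_rng(0)
n_epochs, n_batches = 3, 5
rows = []
i_global = 0
for i_epoch in range(n_epochs):
    for i_batch in range(n_batches):
        loss = 0.72 - 0.02 * i_global + 0.01 * rng.random()
        rows.append(("Training", i_global, i_epoch, i_batch, loss))
        i_global += 1
    # One validation point, logged at the last global step of the epoch
    rows.append(("Validation", i_global - 1, i_epoch, n_batches - 1, loss + 0.05))

metrics_df = pd.DataFrame(
    rows, columns=["step", "i_global", "i_epoch", "i_batch", "loss"]
)

# Training as a line, validation as scatter points (cf. linewidth / scattersize)
tra = metrics_df[metrics_df["step"] == "Training"]
val = metrics_df[metrics_df["step"] == "Validation"]
fig, ax = plt.subplots()
ax.plot(tra["i_global"], tra["loss"], linewidth=1, label="Training")
ax.scatter(val["i_global"], val["loss"], s=3, label="Validation")
ax.set_xlabel("i_global")
ax.set_ylabel("loss")
ax.set_yscale("linear")
ax.legend()
```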
- scitex_ml.plt.learning_curve(metrics_df, keys, title='Title', max_n_ticks=4, scattersize=3, linewidth=1, yscale='linear', spath=None)
Plot learning curves from training metrics.
Parameters, return type, and example are identical to those of scitex_ml.plt.plot_learning_curve; see that entry.
- scitex_ml.plt.optuna_study(lpath, value_str, sort=False)
Loads an Optuna study and generates various visualizations for each target metric.
Parameters and return value are identical to those of scitex_ml.plt.plot_optuna_study; see that entry.
- scitex_ml.plt.plot_optuna_study(lpath, value_str, sort=False)
Loads an Optuna study and generates various visualizations for each target metric.
- Parameters:
lpath (str) – Path to the Optuna study database
value_str (str) – Name of the column to use as the optimization target
- Returns:
None
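- scitex_ml.plt.plot_roc_curve(true_class, pred_proba, labels, ax=None, spath=None)
Plot ROC-AUC curve.
- Parameters:
true_class (array-like) – True class labels
pred_proba (array-like) – Predicted class probabilities
labels (list) – Class label names
ax (optional) – Matplotlib axes to draw on
spath (str, optional) – Path to save the figure
- Returns: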
fig (matplotlib.figure.Figure) – Figure object
metrics (dict) – ROC metrics
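As a rough illustration of what a multi-class ROC plot computes, here is a one-vs-rest sketch built on scikit-learn's `roc_curve` and `auc`. The helper name `sketch_roc` is hypothetical. The assumptions are:

- class `i` corresponds to column `i` of `pred_proba`;
- the returned dict maps label to AUC.

The structure of the `metrics` dict that `plot_roc_curve` actually returns is not documented here.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import auc, roc_curve


def sketch_roc(true_class, pred_proba, labels):
    """Hypothetical helper: one-vs-rest ROC curve and AUC for each class."""
    true_class = np.asarray(true_class)
    pred_proba = np.asarray(pred_proba)
    fig, ax = plt.subplots()
    aucs = {}
    for i, name in enumerate(labels):
        # Binary target for "class i vs. all other classes"
        y_i = (true_class == i).astype(int)
        fpr, tpr, _ = roc_curve(y_i, pred_proba[:, i])
        aucs[name] = auc(fpr, tpr)
        ax.plot(fpr, tpr, label=f"{name} (AUC = {aucs[name]:.2f})")
    ax.plot([0, 1], [0, 1], color="grey", linestyle="--")  # chance level
    ax.set_xlabel("False positive rate")
    ax.set_ylabel("True positive rate")
    ax.legend()
    return fig, aucs


true_class = np.array([0, 1, 2, 0, 1, 2])
pred_proba = np.array(
    [
        [0.8, 0.1, 0.1],
        [0.2, 0.7, 0.1],
        [0.1, 0.2, 0.7],
        [0.6, 0.3, 0.1],
        [0.3, 0.25, 0.45],
        [0.2, 0.3, 0.5],
    ]
)
fig, aucs = sketch_roc(true_class, pred_proba, labels=["A", "B", "C"])
```

Classes A and C are perfectly separated (AUC 1.0). For class B, one positive sample scores below two negatives, so its AUC is 0.75.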
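- scitex_ml.plt.plot_pre_rec_curve(true_class, pred_proba, labels, ax=None, spath=None)
Plot precision-recall curve.
- Parameters:
true_class (array-like) – True class labels
pred_proba (array-like) – Predicted class probabilities
labels (list) – Class label names
ax (optional) – Matplotlib axes to draw on
spath (str, optional) – Path to save the figure
- Returns: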
fig (matplotlib.figure.Figure) – Figure object
metrics (dict) – Precision-recall metrics
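A comparable one-vs-rest sketch for precision-recall uses scikit-learn's `precision_recall_curve` and `average_precision_score`. The helper name `sketch_pre_rec` and the label-to-AP dict are hypothetical, and `plot_pre_rec_curve`'s own metrics dict may be structured differently. The usage shows a binary problem written as two probability columns.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import average_precision_score, precision_recall_curve


def sketch_pre_rec(true_class, pred_proba, labels):
    """Hypothetical helper: one-vs-rest precision-recall curve and average precision."""
    true_class = np.asarray(true_class)
    pred_proba = np.asarray(pred_proba)
    fig, ax = plt.subplots()
    aps = {}
    for i, name in enumerate(labels):
        y_i = (true_class == i).astype(int)  # class i vs. the rest
        precision, recall, _ = precision_recall_curve(y_i, pred_proba[:, i])
        aps[name] = average_precision_score(y_i, pred_proba[:, i])
        # Step-style curve, as is conventional for precision-recall plots
        ax.step(recall, precision, where="post", label=f"{name} (AP = {aps[name]:.2f})")
    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")
    ax.legend()
    return fig, aps


# Binary example: column 1 holds the score for class "pos"
s = np.array([0.9, 0.8, 0.7, 0.4, 0.3])
fig, aps = sketch_pre_rec(
    true_class=[1, 0, 1, 1, 0],
    pred_proba=np.column_stack([1 - s, s]),
    labels=["neg", "pos"],
)
```

For "pos", the positives are ranked 1st, 3rd and 4th. Precision at those ranks is 1, 2/3 and 3/4, so AP = (1 + 2/3 + 3/4) / 3 = 29/36 ≈ 0.806.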
- scitex_ml.plt.plot_feature_importance(importance, feature_names=None, top_n=20, title='Feature Importance', xlabel='Importance', figsize=(10, 8), spath=None)
Plot feature importance as a horizontal bar chart.
- Parameters:
importance (np.ndarray or Dict[str, float]) – Feature importance values. If array, must match feature_names length. If dict, keys are feature names and values are importances.
feature_names (List[str], optional) – Names of features (required if importance is array)
top_n (int, default 20) – Number of top features to display
title (str, default "Feature Importance") – Plot title
xlabel (str, default "Importance") – X-axis label
figsize (tuple, default (10, 8)) – Figure size
spath (Union[str, Path], optional) – Path to save the figure
- Returns:
fig – The figure object
- Return type:
matplotlib.figure.Figure
Examples
>>> import numpy as np
>>> from sklearn.ensemble import RandomForestClassifier
>>> from scitex_ml.plt import plot_feature_importance
>>> X = np.random.rand(100, 5)
>>> y = np.random.randint(0, 2, 100)
>>> model = RandomForestClassifier().fit(X, y)
>>> fig = plot_feature_importance(
...     model.feature_importances_,
...     feature_names=['f1', 'f2', 'f3', 'f4', 'f5'],
...     spath='feature_importance.jpg',
... )
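The sketch below shows the top-n selection and horizontal bar layout described above. It accepts either a dict or an array plus `feature_names`, mirroring the documented input forms. The helper name `sketch_feature_importance` is hypothetical. Raising `ValueError` on a length mismatch is an assumption about how such input would be validated.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np


def sketch_feature_importance(importance, feature_names=None, top_n=20):
    """Hypothetical helper: sort importances and draw the top_n as horizontal bars."""
    if isinstance(importance, dict):
        names = list(importance.keys())
        values = np.array(list(importance.values()), dtype=float)
    else:
        values = np.asarray(importance, dtype=float)
        if feature_names is None or len(feature_names) != len(values):
            raise ValueError("feature_names must match the length of importance")
        names = list(feature_names)
    order = np.argsort(values)[::-1][:top_n]  # indices of the largest values first
    top_names = [names[i] for i in order]
    top_values = values[order]

    fig, ax = plt.subplots(figsize=(10, 8))
    # Reverse so the most important feature is drawn at the top of the chart
    ax.barh(top_names[::-1], top_values[::-1])
    ax.set_xlabel("Importance")
    ax.set_title("Feature Importance")
    return fig, top_names


fig, top = sketch_feature_importance(
    {"f1": 0.1, "f2": 0.5, "f3": 0.2, "f4": 0.15, "f5": 0.05}, top_n=3
)
print(top)
```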
- scitex_ml.plt.plot_feature_importance_cv_summary(all_importances, top_n=20, title=None, figsize=(12, 8), spath=None)
Plot feature importance summary across cross-validation folds with error bars.
- Parameters:
all_importances (List[Dict[str, float]]) – List of importance dictionaries from each fold
top_n (int, default 20) – Number of top features to display
title (str, optional) – Plot title (auto-generated if None)
figsize (tuple, default (12, 8)) – Figure size
spath (Union[str, Path], optional) – Path to save the figure
- Returns:
fig – The figure object
- Return type:
matplotlib.figure.Figure
Examples
>>> from scitex_ml.plt import plot_feature_importance_cv_summary
>>> # After cross-validation
>>> all_importances = [
...     {'feature1': 0.3, 'feature2': 0.7},
...     {'feature1': 0.4, 'feature2': 0.6},
... ]
>>> fig = plot_feature_importance_cv_summary(
...     all_importances,
...     spath='feature_importance_cv_summary.jpg',
... )
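One way to summarize per-fold importance dicts is to stack them into a DataFrame (one row per fold) and plot the mean with the standard deviation as error bars. This is a minimal sketch under several assumptions:

- the helper name `sketch_cv_summary` is hypothetical;
- using the sample standard deviation (pandas' default `ddof=1`) for the error bars is an assumption;
- features missing from a fold are simply ignored when averaging.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import pandas as pd


def sketch_cv_summary(all_importances, top_n=20):
    """Hypothetical helper: mean +/- std of per-fold importances as horizontal bars."""
    # One row per fold, one column per feature; features absent from a fold become NaN
    per_fold = pd.DataFrame(all_importances)
    # mean() and std() skip NaN; std() uses ddof=1 (sample standard deviation)
    summary = pd.DataFrame({"mean": per_fold.mean(), "std": per_fold.std()})
    summary = summary.sort_values("mean", ascending=False).head(top_n)

    fig, ax = plt.subplots(figsize=(12, 8))
    # Reverse so the highest mean importance is drawn at the top
    ax.barh(
        list(summary.index[::-1]),
        summary["mean"].to_numpy()[::-1],
        xerr=summary["std"].to_numpy()[::-1],
    )
    ax.set_xlabel("Importance (mean +/- std across folds)")
    ax.set_title(f"Feature importance over {len(per_fold)} folds")
    return fig, summary


all_importances = [
    {"feature1": 0.3, "feature2": 0.7},
    {"feature1": 0.4, "feature2": 0.6},
]
fig, summary = sketch_cv_summary(all_importances)
print(summary)
```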
- scitex_ml.plt.plot_tra(ax, metrics_df, metric_key, linewidth=1, color=None)
Plot training-phase data as a line.
- scitex_ml.plt.process_i_global(metrics_df)
Prepare metrics DataFrame with i_global as index.
- scitex_ml.plt.scatter_tes(ax, metrics_df, metric_key, markersize=3, color=None)
Plot test-phase data as a scatter plot.
- scitex_ml.plt.scatter_val(ax, metrics_df, metric_key, markersize=3, color=None)
Plot validation-phase data as a scatter plot.
- scitex_ml.plt.select_ticks(metrics_df, max_n_ticks=4)
Select representative epoch tick positions and labels.
- scitex_ml.plt.set_yaxis_for_acc(ax, metric_key)
Configure y-axis for accuracy metrics.
- scitex_ml.plt.vline_at_epochs(ax, metrics_df, color='grey')
Add vertical lines at epoch boundaries.
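The helpers above are documented only by one-line summaries. The sketch below is one plausible reading of the epoch-related steps, not the library's code:

- find the first `i_global` of each epoch;
- pick at most `max_n_ticks` evenly spaced epochs as x-ticks;
- draw a grey vertical line at every epoch boundary.

The names `epoch_starts` and `pick_epoch_ticks` are hypothetical stand-ins. They do not reproduce the real `select_ticks` or `vline_at_epochs`.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def epoch_starts(metrics_df):
    """Hypothetical helper: first i_global of every epoch (epoch boundaries)."""
    return metrics_df.groupby("i_epoch")["i_global"].min()


def pick_epoch_ticks(metrics_df, max_n_ticks=4):
    """Hypothetical helper: at most max_n_ticks evenly spaced epoch starts."""
    starts = epoch_starts(metrics_df)
    n = min(max_n_ticks, len(starts))
    idx = np.unique(np.linspace(0, len(starts) - 1, num=n).round().astype(int))
    return starts.iloc[idx].to_numpy(), starts.index[idx].tolist()


# 10 epochs x 4 batches of training steps
i_global = np.arange(40)
metrics_df = pd.DataFrame(
    {"i_global": i_global, "i_epoch": i_global // 4, "loss": np.linspace(1.0, 0.2, 40)}
)

positions, labels = pick_epoch_ticks(metrics_df, max_n_ticks=4)
fig, ax = plt.subplots()
ax.plot(metrics_df["i_global"], metrics_df["loss"], linewidth=1)
for x in epoch_starts(metrics_df):
    ax.axvline(x, color="grey", linewidth=0.5)  # vertical line at each epoch boundary
ax.set_xticks(positions)
ax.set_xticklabels(labels)
ax.set_xlabel("Epoch")
```

With 10 epochs and `max_n_ticks=4`, the ticks land on epochs 0, 3, 6 and 9 (global steps 0, 12, 24 and 36).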
- scitex_ml.plt.calc_bACC_from_conf_mat(cm)
Calculate balanced accuracy from confusion matrix.
- Parameters:
cm (np.ndarray) – Confusion matrix
- Returns:
Balanced accuracy
- Return type: