scitex_ml.classification.CrossValidationExperiment

class scitex_ml.classification.CrossValidationExperiment(name, model_fn, cv=None, output_dir=None, metrics=None, save_models=True, verbose=True)

Streamlined cross-validation experiment runner.

This class handles:
  • Cross-validation splitting
  • Model training and evaluation
  • Automatic metric calculation
  • Hyperparameter tracking
  • Progress monitoring
  • Report generation

Parameters:
  • name (str) – Experiment name

  • model_fn (Callable) – Function that returns a model instance

  • cv (BaseCrossValidator, optional) – Cross-validation splitter (default: 5-fold stratified)

  • output_dir (Union[str, Path], optional) – Output directory for results

  • metrics (List[str], optional) – List of metrics to calculate

  • save_models (bool) – Whether to save trained models

  • verbose (bool) – Whether to print progress
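The constructor takes a model factory (model_fn) rather than a fitted model, presumably so that every fold can start from an untrained instance. A minimal sketch of the two main inputs follows. It assumes that model_fn is called with no arguments and returns a scikit-learn-compatible estimator, and that the documented "5-fold stratified" default behaves like scikit-learn's StratifiedKFold with five splits. Neither assumption is confirmed by this page.

```python
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold


def model_fn():
    """Return a new, unfitted estimator on every call (assumed zero-argument factory)."""
    return LogisticRegression(max_iter=1000)


# Explicit splitter, assumed comparable to the documented "5-fold stratified" default.
# shuffle/random_state are illustrative choices that make the folds reproducible.
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
```

These would then be passed as `CrossValidationExperiment(name="logreg_baseline", model_fn=model_fn, cv=cv, output_dir="results/logreg_baseline")`. The name and directory are placeholders.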

__init__(name, model_fn, cv=None, output_dir=None, metrics=None, save_models=True, verbose=True)

Methods

  • __init__(name, model_fn[, cv, output_dir, ...])

  • describe_dataset(X, y[, feature_names, ...]) – Record dataset information.

  • get_summary() – Get summary statistics across folds.

  • get_validation_report() – Get validation report.

  • run(X, y[, feature_names, class_names, ...]) – Run complete cross-validation experiment.

  • set_hyperparameters(**kwargs) – Set hyperparameters for tracking.

set_hyperparameters(**kwargs)

Set hyperparameters for tracking.

Parameters:

  • **kwargs – Hyperparameter key-value pairs

Return type:

None

describe_dataset(X, y, feature_names=None, class_names=None)

Record dataset information.

Parameters:
  • X (np.ndarray) – Features

  • y (np.ndarray) – Labels

  • feature_names (List[str], optional) – Feature names

  • class_names (List[str], optional) – Class names

Return type:

None
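describe_dataset only records metadata, and this page does not list which fields it stores. The sketch below uses a hypothetical helper, summarize_dataset, which is not part of scitex_ml. It shows the kind of information these arguments make available and the consistency checks they imply. It assumes X is a 2-D array of shape (n_samples, n_features) and y holds one label per sample.

```python
import numpy as np


def summarize_dataset(X, y, feature_names=None, class_names=None):
    """Hypothetical helper: basic dataset facts derivable from describe_dataset's arguments."""
    X = np.asarray(X)
    y = np.asarray(y)
    if X.ndim != 2:
        raise ValueError(f"X must be 2-D, got shape {X.shape}")
    if len(y) != X.shape[0]:
        raise ValueError(f"X has {X.shape[0]} rows but y has {len(y)} labels")
    if feature_names is not None and len(feature_names) != X.shape[1]:
        raise ValueError("feature_names must have one entry per column of X")

    labels, counts = np.unique(y, return_counts=True)
    if class_names is not None and len(class_names) != len(labels):
        raise ValueError("class_names must have one entry per distinct label")

    return {
        "n_samples": int(X.shape[0]),
        "n_features": int(X.shape[1]),
        # Map each distinct label to its number of samples (class balance)
        "class_counts": {str(label): int(c) for label, c in zip(labels, counts)},
        "feature_names": list(feature_names) if feature_names is not None else None,
        "class_names": list(class_names) if class_names is not None else None,
    }
```

For example, `summarize_dataset(np.zeros((6, 2)), [0, 0, 0, 1, 1, 1], feature_names=["f1", "f2"])` reports 6 samples, 2 features and class counts `{"0": 3, "1": 3}`.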

run(X, y, feature_names=None, class_names=None, calculate_curves=True)

Run complete cross-validation experiment.

Parameters:
  • X (np.ndarray) – Features

  • y (np.ndarray) – Labels

  • feature_names (List[str], optional) – Feature names

  • class_names (List[str], optional) – Class names

  • calculate_curves (bool) – Whether to calculate and plot ROC/PR curves

Returns:

Experiment results and paths

Return type:

Dict[str, Any]
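run bundles splitting, fitting and scoring into a single call. The loop it automates can be sketched with scikit-learn directly. Several parts of this sketch are assumptions: run_cv_sketch is a hypothetical name, and the metrics (accuracy, plus ROC AUC and average precision for binary tasks) are illustrative choices. The returned dictionary only loosely mirrors the documented "results and paths", because nothing is written to disk. When calculate_curves is true, the sketch computes the areas under the ROC and precision-recall curves instead of plotting the curves.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, average_precision_score, roc_auc_score
from sklearn.model_selection import StratifiedKFold


def run_cv_sketch(X, y, model_fn, cv=None, calculate_curves=True):
    """Hypothetical stand-in for the loop that run() automates; writes no files."""
    cv = cv if cv is not None else StratifiedKFold(n_splits=5)
    folds = []
    for fold, (train_idx, test_idx) in enumerate(cv.split(X, y)):
        model = model_fn()  # fresh, unfitted model for every fold
        model.fit(X[train_idx], y[train_idx])
        y_pred = model.predict(X[test_idx])
        record = {"fold": fold, "accuracy": accuracy_score(y[test_idx], y_pred)}
        # Curve-based scores need class probabilities; only the binary case is handled here
        if calculate_curves and hasattr(model, "predict_proba") and len(np.unique(y)) == 2:
            y_score = model.predict_proba(X[test_idx])[:, 1]
            record["roc_auc"] = roc_auc_score(y[test_idx], y_score)
            record["average_precision"] = average_precision_score(y[test_idx], y_score)
        folds.append(record)
    return {"folds": folds}


# Usage on synthetic binary data
X, y = make_classification(n_samples=200, n_features=5, random_state=0)
results = run_cv_sketch(X, y, lambda: LogisticRegression(max_iter=1000))
```

Calling model_fn inside the loop ensures that no fold starts from weights fitted on another fold's training data.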

get_summary()

Get summary statistics across folds.

Return type:

DataFrame
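get_summary returns a pandas DataFrame of statistics across folds, but this page does not state which statistics it includes or how the frame is laid out. Below is a minimal sketch of the usual aggregation. It assumes per-fold scores are arranged with one row per fold and one column per metric. The scores are placeholder inputs, not real results.

```python
import pandas as pd

# Placeholder per-fold scores: one row per fold, one column per metric
per_fold = pd.DataFrame(
    {"accuracy": [0.80, 0.85, 0.90], "roc_auc": [0.88, 0.91, 0.94]},
    index=pd.Index([0, 1, 2], name="fold"),
)

# Mean and sample standard deviation (ddof=1) of each metric across folds,
# transposed so each metric is a row and each statistic a column
summary = per_fold.agg(["mean", "std"]).T
print(summary)
```

Reporting the spread alongside the mean shows how much a metric depends on the particular split, which a single averaged score hides.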

get_validation_report()

Get validation report.

Return type:

Dict[str, Any]