SciTeX ML


scitex_ml.utils.grid_search

This module provides helpers for hyperparameter grid search: enumerating the parameter combinations in a grid and counting how many there are.

Functions

  • count_grids(params_grid) – Calculate the total number of combinations possible from the given parameter grid.
  • yield_grids(params_grid[, random]) – Generator function that yields combinations of parameters from a grid.

scitex_ml.utils.grid_search.yield_grids(params_grid, random=False)

Generator function that yields combinations of parameters from a grid.

Parameters:
  • params_grid (dict) – A dictionary where keys are parameter names and values are lists of parameter values.

  • random (bool) – If True, yields the parameter combinations in random order.

Yields:

dict – A dictionary of parameters for one set of conditions from the grid.

Example

    from scitex_ml.utils.grid_search import yield_grids

    # Parameters
    params_grid = {
        "batch_size": [2**i for i in range(7)],
        "n_chs": [2**i for i in range(7)],
        "seq_len": [2**i for i in range(15)],
        "fs": [2**i for i in range(8, 11)],
        "n_segments": [2**i for i in range(6)],
        "n_bands_pha": [2**i for i in range(7)],
        "n_bands_amp": [2**i for i in range(7)],
        "precision": ["fp16", "fp32"],
        "device": ["cpu", "cuda"],
        "package": ["tensorpac", "scitex"],
    }

    # Example of using the generator
    for param_dict in yield_grids(params_grid, random=True):
        print(param_dict)
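To make the behaviour concrete, here is a minimal stand-alone sketch of a generator like yield_grids, assuming it enumerates the Cartesian product of the value lists with itertools.product. The name yield_grids_sketch is hypothetical and this is not the library's source; in particular, the order of combinations when random=False, and how shuffling is done when random=True, may differ in scitex_ml.

```python
import itertools
import random as _random  # aliased so the `random` flag below does not shadow it
from typing import Any, Dict, Iterator, List


def yield_grids_sketch(
    params_grid: Dict[str, List[Any]], random: bool = False
) -> Iterator[Dict[str, Any]]:
    """Yield one dict per point in the Cartesian product of params_grid.

    Hypothetical illustration only; not scitex_ml's implementation.
    """
    keys = list(params_grid.keys())
    # itertools.product varies the last key fastest.
    combos = itertools.product(*(params_grid[k] for k in keys))
    if random:
        # Shuffling here materializes every combination in memory first.
        combos = list(combos)
        _random.shuffle(combos)
    for values in combos:
        yield dict(zip(keys, values))


small_grid = {"batch_size": [16, 32], "device": ["cpu", "cuda"]}
for param_dict in yield_grids_sketch(small_grid):
    print(param_dict)
# {'batch_size': 16, 'device': 'cpu'}
# {'batch_size': 16, 'device': 'cuda'}
# {'batch_size': 32, 'device': 'cpu'}
# {'batch_size': 32, 'device': 'cuda'}
```

Because this sketch builds a full list before shuffling, memory grows with the number of combinations; for a grid as large as the one in the example above, a lazy sampling strategy would be the safer design.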

scitex_ml.utils.grid_search.count_grids(params_grid)

Calculate the total number of combinations possible from the given parameter grid.

Parameters:

params_grid (dict) – A dictionary where keys are parameter names and values are lists of parameter values.

Returns:

The total number of combinations that can be generated from the parameter grid.

Return type:

int
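Since every combination picks exactly one value per key, the count is the product of the lengths of the value lists. The sketch below assumes that definition, computes it with math.prod (Python 3.8+), and applies it to the grid from the yield_grids example. count_grids_sketch is a hypothetical name for illustration, not the library's implementation.

```python
import math
from typing import Any, Dict, List


def count_grids_sketch(params_grid: Dict[str, List[Any]]) -> int:
    """Number of combinations = product of the number of values per key.

    Hypothetical illustration only; not scitex_ml's implementation.
    An empty grid yields 1 (the single empty combination).
    """
    return math.prod(len(values) for values in params_grid.values())


# The grid from the yield_grids example.
params_grid = {
    "batch_size": [2**i for i in range(7)],   # 7 values
    "n_chs": [2**i for i in range(7)],        # 7 values
    "seq_len": [2**i for i in range(15)],     # 15 values
    "fs": [2**i for i in range(8, 11)],       # 3 values
    "n_segments": [2**i for i in range(6)],   # 6 values
    "n_bands_pha": [2**i for i in range(7)],  # 7 values
    "n_bands_amp": [2**i for i in range(7)],  # 7 values
    "precision": ["fp16", "fp32"],            # 2 values
    "device": ["cpu", "cuda"],                # 2 values
    "package": ["tensorpac", "scitex"],       # 2 values
}
print(count_grids_sketch(params_grid))  # 5186160
```

That is roughly 5.2 million combinations, so checking the count before exhaustively iterating a grid is a cheap sanity check.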


© Copyright 2024-2026, Yusuke Watanabe.

Built with Sphinx using a theme provided by Read the Docs.