src.slune.searchers.grid
from typing import List, Tuple
from slune.base import BaseSearcher, BaseSaver
from slune.utils import dict_to_strings


class SearcherGrid(BaseSearcher):
    """ Searcher for grid search.

    Given a dictionary of parameters and values to try, creates a grid of all possible configurations,
    and returns them one by one for each call to next_tune.

    Attributes:
        - configs (dict): Parameters and values to create grid from.
            Structure of dictionary should be: { "--parameter_name" : [Value_1, Value_2, ...], ... }
        - runs (int): Controls search based on number of runs we want for each config.
            if runs > 0 -> run each config 'runs' times.
            if runs = 0 -> run each config once even if it already exists.
            This behavior is modified if we use check_existing_runs, see that method's description.
        - grid (list of dict): List of dictionaries, each containing one combination of argument values.
        - grid_index (int or None): Index of the current configuration in the grid,
            None before the first call to next_tune.
        - run_index (int or None): Index of the current run of the current configuration,
            None before the first call to next_tune.
        - saver_exists (function or None): The saver's exists method, used to check if there are existing runs.
            None unless check_existing_runs has been called.

    """

    def __init__(self, configs: dict, runs: int = 0):
        """ Initialises the searcher.

        Args:
            - configs (dict): Dictionary of parameters and values to try.
                Structure of dictionary should be: { "--parameter_name" : [Value_1, Value_2, ...], ... }
            - runs (int, optional): Controls search based on number of runs we want for each config.
                if runs > 0 -> run each config 'runs' times.
                if runs = 0 -> run each config once even if it already exists.
                This behavior is modified if we use check_existing_runs, see that method's description.

        """

        super().__init__()
        self.runs = runs
        self.configs = configs
        self.grid = self.get_grid(configs)
        self.grid_index = None
        # Position within the current configuration's runs; set by next_tune,
        # initialised here so the attribute always exists.
        self.run_index = None
        self.saver_exists = None

    def __len__(self):
        """ Returns the number of runs defined by the search space.

        With runs = 0 every configuration is still run once (see next_tune),
        so each configuration counts at least once.
        This may overestimate if we use check_existing_runs, as existing runs are skipped,
        see that method's description.

        Returns:
            - num_configs (int): Number of runs defined by the search space.

        """

        # max(..., 1): runs = 0 still yields one run per configuration, and __len__
        # must never be negative.
        return len(self.grid) * max(self.runs, 1)

    def get_grid(self, param_dict: dict) -> List:
        """ Creates search grid.

        Generates all possible combinations of values for each argument in the given dictionary using recursion.

        Args:
            - param_dict (dict): A dictionary where keys are argument names and values are lists of values.

        Returns:
            - all_combinations (list): A list of dictionaries, each containing one combination of argument values.

        """

        # Helper function to recursively generate combinations
        def generate_combinations(param_names, current_combination, all_combinations):
            if not param_names:
                # If there are no more parameters to combine, add the current combination to the result
                all_combinations.append(dict(current_combination))
                return

            param_name = param_names[0]
            param_values = param_dict[param_name]

            for value in param_values:
                current_combination[param_name] = value
                # Recursively generate combinations for the remaining parameters
                generate_combinations(param_names[1:], current_combination, all_combinations)

        # Start with an empty combination and generate all combinations
        all_combinations = []
        generate_combinations(list(param_dict.keys()), {}, all_combinations)

        return all_combinations

    def check_existing_runs(self, saver: BaseSaver):
        """ Stores the saver's exists method so existing runs can be checked and skipped.

        If there are n existing runs for a configuration:
            n < runs -> run the remaining runs
            n >= runs -> skip all runs

        Args:
            - saver (BaseSaver): Saver whose exists method is used to check if there are existing runs.

        Raises:
            - ValueError: If runs is not positive.

        """

        # Skipping is only meaningful against a positive target number of runs.
        if self.runs <= 0:
            raise ValueError("Won't check for existing runs if runs = 0, Set runs > 0.")
        self.saver_exists = saver.exists

    def skip_existing_runs(self, grid_index: int) -> Tuple[int, int]:
        """ Skips runs if they are in storage already.

        Starting at grid_index, checks how many runs already exist for each configuration and
        returns the first configuration that still needs runs, together with the index of its
        next run. Without a saver (check_existing_runs not called) nothing is skipped.

        Args:
            - grid_index (int): Index of the configuration to start checking from.

        Returns:
            - grid_index (int): Index of the next configuration in the grid.
            - run_index (int): Index of the next run for that configuration.

        Raises:
            - IndexError: If the end of the grid is reached.
        """
        num_configs = len(self.grid)
        if self.saver_exists is None:
            if grid_index >= num_configs:
                raise IndexError('Reached end of grid, no more configurations to try.')
            return grid_index, 0
        # Iterate rather than recurse so long stretches of finished configurations
        # cannot exhaust the interpreter's recursion limit.
        while grid_index < num_configs:
            existing_runs = self.saver_exists(dict_to_strings(self.grid[grid_index]))
            if existing_runs < self.runs:
                # Resume after the runs that are already stored.
                return grid_index, existing_runs
            grid_index += 1
        raise IndexError('Reached end of grid, no more configurations to try.')

    def next_tune(self) -> dict:
        """ Returns the next configuration to try.

        Will skip existing runs if check_existing_runs has been called.
        For more information on how this works check the descriptions of check_existing_runs and skip_existing_runs.
        Will raise an IndexError once the end of the grid has been reached. The searcher's position is only
        updated after a valid next configuration has been found, so further calls keep raising IndexError.
        To iterate through all configurations, use a for loop like so:
            for config in searcher: ...

        Returns:
            - next_config (dict): The next configuration to try.

        Raises:
            - IndexError: If there are no more configurations to try.
        """
        if self.grid_index is None:
            # First call: start at the beginning of the grid, skipping finished configurations.
            grid_index, run_index = self.skip_existing_runs(0)
        elif self.run_index < self.runs - 1:
            # More runs are wanted for the current configuration.
            grid_index, run_index = self.grid_index, self.run_index + 1
        else:
            # Current configuration is done; move on to the next unfinished one.
            grid_index, run_index = self.skip_existing_runs(self.grid_index + 1)
        if grid_index >= len(self.grid):
            raise IndexError('Reached end of grid, no more configurations to try.')
        # Commit the new position only once it is known to be valid.
        self.grid_index, self.run_index = grid_index, run_index
        return dict_to_strings(self.grid[grid_index])
class SearcherGrid(BaseSearcher):
    """ Searcher for grid search.

    Given a dictionary of parameters and values to try, creates a grid of all possible configurations,
    and returns them one by one for each call to next_tune.

    Attributes:
        - configs (dict): Parameters and values to create grid from.
            Structure of dictionary should be: { "--parameter_name" : [Value_1, Value_2, ...], ... }
        - runs (int): Controls search based on number of runs we want for each config.
            if runs > 0 -> run each config 'runs' times.
            if runs = 0 -> run each config once even if it already exists.
            This behavior is modified if we use check_existing_runs, see that method's description.
        - grid (list of dict): List of dictionaries, each containing one combination of argument values.
        - grid_index (int or None): Index of the current configuration in the grid,
            None before the first call to next_tune.
        - run_index (int or None): Index of the current run of the current configuration,
            None before the first call to next_tune.
        - saver_exists (function or None): The saver's exists method, used to check if there are existing runs.
            None unless check_existing_runs has been called.

    """

    def __init__(self, configs: dict, runs: int = 0):
        """ Initialises the searcher.

        Args:
            - configs (dict): Dictionary of parameters and values to try.
                Structure of dictionary should be: { "--parameter_name" : [Value_1, Value_2, ...], ... }
            - runs (int, optional): Controls search based on number of runs we want for each config.
                if runs > 0 -> run each config 'runs' times.
                if runs = 0 -> run each config once even if it already exists.
                This behavior is modified if we use check_existing_runs, see that method's description.

        """

        super().__init__()
        self.runs = runs
        self.configs = configs
        self.grid = self.get_grid(configs)
        self.grid_index = None
        # Position within the current configuration's runs; set by next_tune,
        # initialised here so the attribute always exists.
        self.run_index = None
        self.saver_exists = None

    def __len__(self):
        """ Returns the number of runs defined by the search space.

        With runs = 0 every configuration is still run once (see next_tune),
        so each configuration counts at least once.
        This may overestimate if we use check_existing_runs, as existing runs are skipped,
        see that method's description.

        Returns:
            - num_configs (int): Number of runs defined by the search space.

        """

        # max(..., 1): runs = 0 still yields one run per configuration, and __len__
        # must never be negative.
        return len(self.grid) * max(self.runs, 1)

    def get_grid(self, param_dict: dict) -> List:
        """ Creates search grid.

        Generates all possible combinations of values for each argument in the given dictionary using recursion.

        Args:
            - param_dict (dict): A dictionary where keys are argument names and values are lists of values.

        Returns:
            - all_combinations (list): A list of dictionaries, each containing one combination of argument values.

        """

        # Helper function to recursively generate combinations
        def generate_combinations(param_names, current_combination, all_combinations):
            if not param_names:
                # If there are no more parameters to combine, add the current combination to the result
                all_combinations.append(dict(current_combination))
                return

            param_name = param_names[0]
            param_values = param_dict[param_name]

            for value in param_values:
                current_combination[param_name] = value
                # Recursively generate combinations for the remaining parameters
                generate_combinations(param_names[1:], current_combination, all_combinations)

        # Start with an empty combination and generate all combinations
        all_combinations = []
        generate_combinations(list(param_dict.keys()), {}, all_combinations)

        return all_combinations

    def check_existing_runs(self, saver: BaseSaver):
        """ Stores the saver's exists method so existing runs can be checked and skipped.

        If there are n existing runs for a configuration:
            n < runs -> run the remaining runs
            n >= runs -> skip all runs

        Args:
            - saver (BaseSaver): Saver whose exists method is used to check if there are existing runs.

        Raises:
            - ValueError: If runs is not positive.

        """

        # Skipping is only meaningful against a positive target number of runs.
        if self.runs <= 0:
            raise ValueError("Won't check for existing runs if runs = 0, Set runs > 0.")
        self.saver_exists = saver.exists

    def skip_existing_runs(self, grid_index: int) -> Tuple[int, int]:
        """ Skips runs if they are in storage already.

        Starting at grid_index, checks how many runs already exist for each configuration and
        returns the first configuration that still needs runs, together with the index of its
        next run. Without a saver (check_existing_runs not called) nothing is skipped.

        Args:
            - grid_index (int): Index of the configuration to start checking from.

        Returns:
            - grid_index (int): Index of the next configuration in the grid.
            - run_index (int): Index of the next run for that configuration.

        Raises:
            - IndexError: If the end of the grid is reached.
        """
        num_configs = len(self.grid)
        if self.saver_exists is None:
            if grid_index >= num_configs:
                raise IndexError('Reached end of grid, no more configurations to try.')
            return grid_index, 0
        # Iterate rather than recurse so long stretches of finished configurations
        # cannot exhaust the interpreter's recursion limit.
        while grid_index < num_configs:
            existing_runs = self.saver_exists(dict_to_strings(self.grid[grid_index]))
            if existing_runs < self.runs:
                # Resume after the runs that are already stored.
                return grid_index, existing_runs
            grid_index += 1
        raise IndexError('Reached end of grid, no more configurations to try.')

    def next_tune(self) -> dict:
        """ Returns the next configuration to try.

        Will skip existing runs if check_existing_runs has been called.
        For more information on how this works check the descriptions of check_existing_runs and skip_existing_runs.
        Will raise an IndexError once the end of the grid has been reached. The searcher's position is only
        updated after a valid next configuration has been found, so further calls keep raising IndexError.
        To iterate through all configurations, use a for loop like so:
            for config in searcher: ...

        Returns:
            - next_config (dict): The next configuration to try.

        Raises:
            - IndexError: If there are no more configurations to try.
        """
        if self.grid_index is None:
            # First call: start at the beginning of the grid, skipping finished configurations.
            grid_index, run_index = self.skip_existing_runs(0)
        elif self.run_index < self.runs - 1:
            # More runs are wanted for the current configuration.
            grid_index, run_index = self.grid_index, self.run_index + 1
        else:
            # Current configuration is done; move on to the next unfinished one.
            grid_index, run_index = self.skip_existing_runs(self.grid_index + 1)
        if grid_index >= len(self.grid):
            raise IndexError('Reached end of grid, no more configurations to try.')
        # Commit the new position only once it is known to be valid.
        self.grid_index, self.run_index = grid_index, run_index
        return dict_to_strings(self.grid[grid_index])
Searcher for grid search.
Given a dictionary of parameters and values to try, creates a grid of all possible configurations and returns them one by one on each call to next_tune.
Attributes:
- configs (dict): Parameters and values to create the grid from. The dictionary should be structured as: { "--parameter_name" : [Value_1, Value_2, ...], ... }
- runs (int): Controls the search based on the number of runs wanted for each config. If runs > 0, each config is run 'runs' times; if runs = 0, each config is run once, even if it already exists. This behavior changes when check_existing_runs is used; see that method's description.
- grid (list of dict): List of dictionaries, each containing one combination of argument values.
- grid_index (int): Index of the current configuration in the grid.
- saver_exists (function): Pointer to the saver's exists method, used to check whether there are existing runs.
def __init__(self, configs: dict, runs: int = 0):
    """ Initialises the searcher.

    Builds the full grid of configurations up front and resets the iteration state.

    Args:
        - configs (dict): Dictionary of parameters and values to try.
            Structure of dictionary should be: { "--parameter_name" : [Value_1, Value_2, ...], ... }
        - runs (int, optional): Controls search based on number of runs we want for each config.
            if runs > 0 -> run each config 'runs' times.
            if runs = 0 -> run each config once even if it already exists.
            This behavior is modified if we use check_existing_runs, see that method's description.

    """

    super().__init__()
    self.runs = runs
    self.configs = configs
    self.grid = self.get_grid(configs)
    # None means next_tune has not been called yet.
    self.grid_index = None
    # Position within the current configuration's runs; set by next_tune,
    # initialised here so the attribute always exists.
    self.run_index = None
    # Populated by check_existing_runs; None means existing runs are not checked.
    self.saver_exists = None
Initialises the searcher.
Arguments:
- configs (dict): Dictionary of parameters and values to try. The dictionary should be structured as: { "--parameter_name" : [Value_1, Value_2, ...], ... }
- runs (int, optional): Controls the search based on the number of runs wanted for each config. If runs > 0, each config is run 'runs' times; if runs = 0, each config is run once, even if it already exists. This behavior changes when check_existing_runs is used; see that method's description.
def get_grid(self, param_dict: dict) -> List:
    """ Creates search grid.

    Builds every combination of the value lists in param_dict one parameter at a time:
    each partial combination built so far is extended with every value of the next
    parameter, so earlier parameters vary slowest.

    Args:
        - param_dict (dict): A dictionary where keys are argument names and values are lists of values.

    Returns:
        - all_combinations (list): A list of dictionaries, each containing one combination of argument values.

    """

    # Seed with one empty combination: an empty param_dict therefore yields [{}].
    combos = [{}]
    for name, values in param_dict.items():
        # A parameter with no values leaves no combination that can be completed.
        combos = [{**partial, name: value} for partial in combos for value in values]
    return combos
Creates search grid.
Generates all possible combinations of values for each argument in the given dictionary using recursion.
Arguments:
- param_dict (dict): A dictionary where keys are argument names and values are lists of values.
Returns:
- all_combinations (list): A list of dictionaries, each containing one combination of argument values.
def check_existing_runs(self, saver: BaseSaver):
    """ Stores the saver's exists method so existing runs can be checked and skipped.

    If there are n existing runs for a configuration:
        n < runs -> run the remaining runs
        n >= runs -> skip all runs

    Args:
        - saver (BaseSaver): Saver whose exists method is used to check if there are existing runs.

    Raises:
        - ValueError: If runs is not positive.

    """

    # Skipping is only meaningful against a positive target number of runs;
    # a negative target would otherwise silently skip every configuration.
    if self.runs <= 0:
        raise ValueError("Won't check for existing runs if runs = 0, Set runs > 0.")
    self.saver_exists = saver.exists
Saves a pointer to the saver's exists method so that existing runs can be checked.
If there are n existing runs:
n < runs -> run the remaining runs; n >= runs -> skip all runs.
Arguments:
- saver (BaseSaver): Saver whose exists method is stored and later used to check whether there are existing runs.
def skip_existing_runs(self, grid_index: int) -> Tuple[int, int]:
    """ Skips runs if they are in storage already.

    Starting at grid_index, checks how many runs already exist for each configuration and
    returns the first configuration that still needs runs, together with the index of its
    next run. Without a saver (check_existing_runs not called) nothing is skipped.

    Args:
        - grid_index (int): Index of the configuration to start checking from.

    Returns:
        - grid_index (int): Index of the next configuration in the grid.
        - run_index (int): Index of the next run for that configuration.

    Raises:
        - IndexError: If the end of the grid is reached.
    """
    num_configs = len(self.grid)
    if self.saver_exists is None:
        if grid_index >= num_configs:
            raise IndexError('Reached end of grid, no more configurations to try.')
        return grid_index, 0
    # Iterate rather than recurse so long stretches of finished configurations
    # cannot exhaust the interpreter's recursion limit, and stop cleanly at the
    # end of the grid instead of indexing past it.
    while grid_index < num_configs:
        existing_runs = self.saver_exists(dict_to_strings(self.grid[grid_index]))
        if existing_runs < self.runs:
            # Resume after the runs that are already stored.
            return grid_index, existing_runs
        grid_index += 1
    raise IndexError('Reached end of grid, no more configurations to try.')
Skips runs if they are in storage already.
Checks whether there are existing runs for the current configuration; if so, tallies them and skips configurations (or individual runs of a configuration) based on the number of runs wanted for each configuration.
Arguments:
- grid_index (int): Index of the current configuration in the grid.
Returns:
- grid_index (int): Index of the next configuration in the grid.
- run_index (int): Index of the next run for the current configuration.
def next_tune(self) -> dict:
    """ Returns the next configuration to try.

    Will skip existing runs if check_existing_runs has been called.
    For more information on how this works check the descriptions of check_existing_runs and skip_existing_runs.
    Will raise an IndexError once the end of the grid has been reached. The searcher's position is only
    updated after a valid next configuration has been found, so further calls keep raising IndexError.
    To iterate through all configurations, use a for loop like so:
        for config in searcher: ...

    Returns:
        - next_config (dict): The next configuration to try.

    Raises:
        - IndexError: If there are no more configurations to try.
    """
    if self.grid_index is None:
        # First call: start at the beginning of the grid, skipping finished configurations.
        grid_index, run_index = self.skip_existing_runs(0)
    elif self.run_index < self.runs - 1:
        # More runs are wanted for the current configuration.
        grid_index, run_index = self.grid_index, self.run_index + 1
    else:
        # Current configuration is done; move on to the next unfinished one.
        grid_index, run_index = self.skip_existing_runs(self.grid_index + 1)
    if grid_index >= len(self.grid):
        raise IndexError('Reached end of grid, no more configurations to try.')
    # Commit the new position only once it is known to be valid, so an exhausted
    # searcher is not left pointing past the end of the grid.
    self.grid_index, self.run_index = grid_index, run_index
    return dict_to_strings(self.grid[grid_index])
Returns the next configuration to try.
Skips existing runs if check_existing_runs has been called. For more information on how this works, see the descriptions of check_existing_runs and skip_existing_runs. Raises an error once the end of the grid has been reached. To iterate through all configurations, use a for loop: for config in searcher: ...
Returns:
- next_config (dict): The next configuration to try.