Coverage for optimates/search.py: 70%
238 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-16 11:19 -0500
1"""A generic API for solving discrete search problems."""
3from __future__ import annotations
5from abc import ABC, abstractmethod
6from collections.abc import Iterable
7from dataclasses import dataclass
8from math import exp, inf
9import random
10from typing import Generic, Optional, TypeVar
12from optimates.utils import TopNHeap, logger
# Type of the nodes of a search graph.
T = TypeVar('T')
# A (node, score) pair, as recorded at each step of a search trajectory.
IterData = tuple[T, float]
19##########
20# ERRORS #
21##########
class EmptyNeighborSetError(ValueError):
    """Raised when a neighbor is requested for a node that has no neighbors."""
27##########
28# SEARCH #
29##########
@dataclass
class Solutions(Generic[T]):
    """Stores the set of best solutions found for a search problem, together with their shared maximum score.

    A score of None means that no solution has been recorded yet (the solution set is then empty)."""
    score: Optional[float]
    solutions: set[T]

    @classmethod
    def empty(cls) -> Solutions[T]:
        """Constructs an empty set of solutions (score None, no nodes)."""
        return Solutions(None, set())

    def add(self, solution: T, score: float) -> None:
        """Adds a new solution to the set, along with its score.

        A strictly higher score replaces the stored solutions, an equal score joins them,
        and a lower score is ignored."""
        if (self.score is None) or (score > self.score):
            self.score = score
            self.solutions = {solution}
        elif score == self.score:
            self.solutions.add(solution)

    def merge(self, other: Solutions[T]) -> Solutions[T]:
        """Merges another Solutions set into this one (in-place) and returns self.

        `other` is never modified, and afterwards the two objects do not share a set,
        so later updates to either one cannot leak into the other."""
        assert isinstance(other, Solutions)
        if other.score is not None:
            if (self.score is None) or (other.score > self.score):
                self.score = other.score
                # copy rather than alias: a later add()/merge() on self mutates
                # self.solutions in place and must not change other.solutions
                self.solutions = set(other.solutions)
            elif other.score == self.score:
                self.solutions |= other.solutions
        return self

    def __len__(self) -> int:
        """Gets the total number of stored solutions."""
        return len(self.solutions)
class SearchProblem(ABC, Generic[T]):
    """Generic discrete search problem.

    The problem is modeled as a directed graph whose vertices ("nodes") are the elements to be searched.
    Some of the nodes are designated as "solutions."
    Every node has a score, and the goal is to find a solution node whose score is maximal.
    Directed edges link a node to its "neighbor" nodes, which are meant to resemble it in some way (ideally in score as well).
    Many different algorithms can be run on this graph in search of an optimal solution."""

    def score(self, node: T) -> float:
        """Returns the score of a node in the search graph."""
        raise NotImplementedError

    @abstractmethod
    def initial_nodes(self) -> Iterable[T]:
        """Returns the initial nodes of the search graph, i.e. those nodes that have no predecessor."""

    def default_initial_node(self) -> T:
        """Returns a "canonical default" initial node: the first one produced by initial_nodes()."""
        candidates = iter(self.initial_nodes())
        return next(candidates)

    @abstractmethod
    def is_solution(self, node: T) -> bool:
        """Tells whether a node is a solution."""

    @abstractmethod
    def iter_nodes(self) -> Iterable[T]:
        """Returns an iterable over every node of the search space."""

    @abstractmethod
    def random_node(self) -> T:
        """Returns a random node of the search space."""

    @abstractmethod
    def get_neighbors(self, node: T) -> Iterable[T]:
        """Returns the neighbors of a node."""

    def num_neighbors(self, node: T) -> int:
        """Counts a node's neighbors by iterating over get_neighbors (override when a cheaper count exists)."""
        return sum(1 for _ in self.get_neighbors(node))

    @abstractmethod
    def random_neighbor(self, node: T) -> T:
        """Returns a random neighbor of a node.
        By default, the choice should be uniform over the neighbor set.
        If the node has no neighbors, raises an EmptyNeighborSetError."""
@dataclass
class FilteredSearchProblem(SearchProblem[T]):
    """A search problem whose search space is the subset of another problem's space selected by a predicate.

    A subclass should override the `is_element` method. Every other method delegates to the wrapped
    `problem`, discarding nodes for which `is_element` is False."""
    problem: SearchProblem[T]

    # Number of rejection-sampling attempts random_neighbor makes before falling back to
    # enumerating the filtered neighborhood. Deliberately unannotated: a plain class
    # attribute (overridable by subclasses), not a dataclass field.
    max_neighbor_tries = 100

    @abstractmethod
    def is_element(self, node: T) -> bool:
        """This method checks whether a node is a valid element of the search problem."""

    def score(self, node: T) -> float:
        return self.problem.score(node)

    def initial_nodes(self) -> Iterable[T]:
        return filter(self.is_element, self.problem.initial_nodes())

    def is_solution(self, node: T) -> bool:
        return self.is_element(node) and self.problem.is_solution(node)

    def iter_nodes(self) -> Iterable[T]:
        # NOTE: this can be inefficient, so you should preferably override it
        return filter(self.is_element, self.problem.iter_nodes())

    def random_node(self) -> T:
        # NOTE: this can be inefficient (or non-terminating), so you should preferably override it
        while True:
            node = self.problem.random_node()
            if self.is_element(node):
                return node

    def get_neighbors(self, node: T) -> Iterable[T]:
        return filter(self.is_element, self.problem.get_neighbors(node))

    def random_neighbor(self, node: T) -> T:
        """Gets a random neighbor of `node` that satisfies `is_element`.

        First tries rejection sampling from the wrapped problem's random_neighbor (cheap when most
        neighbors pass the filter). After `max_neighbor_tries` failures it enumerates the filtered
        neighborhood instead and picks from it with random.choice. The fallback keeps the call from
        looping forever when no neighbor passes.
        Raises EmptyNeighborSetError if no neighbor passes the filter. The wrapped problem's
        random_neighbor may also raise it."""
        for _ in range(self.max_neighbor_tries):
            nbr = self.problem.random_neighbor(node)
            if self.is_element(nbr):
                return nbr
        nbrs = list(self.get_neighbors(node))
        if not nbrs:
            raise EmptyNeighborSetError(f'node {node!r} has no neighbors satisfying the filter')
        return random.choice(nbrs)
class _Search(ABC, Generic[T]):
    """Abstract interface shared by every search strategy: a single `run` entry point."""

    @abstractmethod
    def run(self, initial: Optional[T] = None) -> Solutions[T]:
        """Carries out the search, optionally seeded with an initial node, and returns the best solutions found."""
@dataclass
class Search(_Search[T]):
    """Base class for search algorithms over a SearchProblem.

    Starting from an initial node, a concrete search tries to find a global maximum, possibly
    exploiting the neighborhood structure of the search problem. Subclasses implement `_run`."""
    problem: SearchProblem[T]  # the problem being searched

    @abstractmethod
    def _run(self, node: T) -> Solutions[T]:
        """Performs the search starting from `node`; subclasses must override this.

        Should return a Solutions object holding the best solutions found and their score."""

    def run(self, initial: Optional[T] = None) -> Solutions[T]:
        """Performs the search from `initial`, using the problem's default initial node when `initial` is None."""
        start = self.problem.default_initial_node() if initial is None else initial
        return self._run(start)
@dataclass
class SearchWithRestarts(_Search[T]):
    """Wraps another search algorithm, runs it several times, and keeps the best solutions over all runs.

    When `random_restart` is True (the default), each run starts from a fresh `random_node()` of the
    wrapped search's problem, and the `initial` argument goes unused. Otherwise every run starts from
    `initial`, which is passed straight to the wrapped search's `run`."""
    search: Search[T]  # the search to repeat
    num_restarts: int = 25  # number of runs
    random_restart: bool = True  # start each run from a random node instead of `initial`

    def run(self, initial: Optional[T] = None) -> Solutions[T]:
        combined: Solutions[T] = Solutions.empty()
        for attempt in range(1, self.num_restarts + 1):
            logger.verbose(f'RESTART #{attempt}', 1)
            if self.random_restart:
                start = self.search.problem.random_node()
            else:
                start = initial
            combined.merge(self.search.run(start))
        return combined
@dataclass
class HillClimb(Search[T]):
    """General hill-climbing framework for discrete optimization, which subsumes many well-known algorithms.

    Each iteration generates candidate neighbors of the current node (initially the starting node) and
    scores them. The acceptance criterion (`accept`) then selects a subset of the candidates. If that
    subset is nonempty, the search moves to one of its highest-scoring members, chosen at random.
    Otherwise it either stops (when `terminate_early()` is True) or stays at the current node.
    The search also stops after `max_iters` iterations (no cap when None).
    Every solution node met along the way is recorded, and the best of them are returned."""
    max_iters: Optional[int] = None  # iteration cap; None means unlimited

    def reset(self) -> None:
        """Clears per-run state before a new search begins."""
        # tracker for the best solutions seen during the current run
        self.solutions: Solutions[T] = Solutions.empty()

    def terminate_early(self) -> bool:
        """Whether to stop as soon as an iteration accepts no neighbor (True by default).
        If False, the search stays at the current node and keeps iterating."""
        return True

    def get_neighbors(self, node: T) -> Iterable[T]:
        """Returns the candidate neighbors examined from `node`.

        The default passes through the underlying problem's neighborhood unchanged; a subclass may
        generate candidates differently."""
        return self.problem.get_neighbors(node)

    @abstractmethod
    def accept(self, cur_score: float, nbr_score: float) -> bool:
        """Decides, from the current node's score and a neighbor's score, whether that neighbor is accepted."""

    def iterate_search(self, initial: T) -> tuple[Solutions[T], list[IterData[T]]]:
        """Runs the hill climb from `initial`.

        Returns (best solutions found, trajectory). The trajectory holds the (node, score) pair of the
        starting node, followed by the current node after each completed iteration."""
        self.reset()
        problem, best = self.problem, self.solutions
        limit = inf if self.max_iters is None else self.max_iters
        node = initial
        node_score = problem.score(initial)
        trajectory = [(node, node_score)]
        if problem.is_solution(node):
            best.add(node, node_score)
        step = 1
        while step <= limit:
            logger.verbose(f'\tIteration #{step}', 1)
            logger.verbose(f'\t\tcurrent node = {node}, score = {node_score}', 2)
            # highest-scoring neighbors among those accepted in this iteration
            accepted: Solutions[T] = Solutions.empty()
            seen = n_accepted = 0
            for cand in self.get_neighbors(node):
                seen += 1
                cand_score = problem.score(cand)
                if problem.is_solution(cand):
                    best.add(cand, cand_score)
                if not self.accept(node_score, cand_score):
                    continue
                n_accepted += 1
                accepted.add(cand, cand_score)
            n_best = len(accepted)
            logger.verbose(f'\t\tnum_neighbors = {seen}, num_accepted = {n_accepted}, num_best accepted = {n_best}', 3)
            if n_best:
                # move to a random member of the best accepted set
                node = random.choice(list(accepted.solutions))
                node_score = accepted.score  # type: ignore
            elif self.terminate_early():
                logger.verbose(f'\tNo neighbors accepted: terminating at iteration #{step}.', 1)
                break
            # reaching here with nothing accepted means we stay at the current node
            trajectory.append((node, node_score))
            step += 1
        else:
            logger.verbose(f'\tTerminating after max_iters ({limit}) iterations reached.', 1)
        logger.verbose(f'\t\tcurrent node = {node}, score = {node_score}', 1)
        return (best, trajectory)

    def _run(self, node: T) -> Solutions[T]:
        return self.iterate_search(node)[0]
class ExhaustiveSearch(HillClimb[T]):
    """Exhaustive search: visits every node of the search space, one per iteration."""

    def reset(self) -> None:
        super().reset()
        # one iterator over the whole search space, consumed a node at a time across iterations
        self._node_gen = iter(self.problem.iter_nodes())

    def get_neighbors(self, node: T) -> Iterable[T]:
        # Pull the next node that differs from the current one. Once the iterator is exhausted there
        # are no neighbors left, and the hill climb stops (HillClimb.terminate_early returns True).
        for candidate in self._node_gen:
            if candidate != node:
                return [candidate]
        return []

    def accept(self, cur_score: float, nbr_score: float) -> bool:
        # every node is accepted, so the search simply walks the whole space
        return True
class BlindRandomSearch(HillClimb[T]):
    """Blind random search: at every step, jumps to a freshly sampled random node of the search space."""

    def get_neighbors(self, node: T) -> Iterable[T]:
        # the current node is ignored; the "neighborhood" is one random node of the whole space
        sample = self.problem.random_node()
        return [sample]

    def accept(self, cur_score: float, nbr_score: float) -> bool:
        # every sampled node is accepted regardless of its score
        return True
@dataclass
class StochasticLocalSearch(HillClimb[T]):
    """Stochastic local search: each step proposes a single random neighbor of the current node.

    The proposal is accepted when its score is at least the current score. When `strict_improvement`
    is True, it must be strictly higher instead. A rejected proposal leaves the search at the current
    node rather than ending it."""
    strict_improvement: bool = False  # require a strictly better score to move

    def get_neighbors(self, node: T) -> Iterable[T]:
        proposal = self.problem.random_neighbor(node)
        return [proposal]

    def accept(self, cur_score: float, nbr_score: float) -> bool:
        return (nbr_score > cur_score) if self.strict_improvement else (nbr_score >= cur_score)

    def terminate_early(self) -> bool:
        # keep iterating even after a rejected proposal
        return False
class GreedyLocalSearch(HillClimb[T]):
    """Greedy local search: scores the whole neighborhood and moves to its best-scoring neighbor,
    provided that neighbor strictly improves on the current score.

    Ties among the best neighbors are broken at random."""

    def accept(self, cur_score: float, nbr_score: float) -> bool:
        # only strict improvements are eligible
        return cur_score < nbr_score
@dataclass
class SimulatedAnnealing(HillClimb[T]):
    """Simulated annealing: a stochastic hill climb that can escape local maxima.

    Each step proposes one random neighbor. Improvements are always accepted. A neighbor that is no
    better (score difference delta <= 0) is accepted with probability exp(delta / T), where T is the
    current temperature.
    The temperature starts at T0 and is multiplied by `decay` after every call to `accept` (once per
    iteration, since a single neighbor is proposed). The search is therefore exploratory early on and
    increasingly greedy as it cools.
    NOTE: terminate_early returns False, so the run only ends when max_iters is reached; set max_iters."""
    T0: float = 1.0  # initial temperature (must be >= 0)
    decay: float = 0.99  # per-step multiplicative cooling factor in (0, 1); higher means slower cooling

    def __post_init__(self) -> None:
        # raise instead of assert so the check survives `python -O`
        if not (0.0 < self.decay < 1.0):
            raise ValueError('temperature decay coefficient must be in (0, 1)')
        if self.T0 < 0.0:
            raise ValueError('initial temperature T0 must be nonnegative')

    def reset(self) -> None:
        super().reset()
        self.T = self.T0

    def get_neighbors(self, node: T) -> Iterable[T]:
        return [self.problem.random_neighbor(node)]

    def accept(self, cur_score: float, nbr_score: float) -> bool:
        delta = nbr_score - cur_score
        logger.verbose(f'\t\tcurrent temperature = {self.T}', 3)
        logger.verbose(f'\t\tneighbor score = {nbr_score}, delta = {delta}', 3)
        if delta > 0:  # accept any improvement
            acc = True
            logger.verbose('\t\tscore increased', 2)
        else:  # accept a worse solution with some probability (likelier with high temperature)
            if self.T > 0:
                p = exp(delta / self.T)
            else:
                # Zero temperature (T0 == 0, or T underflowed after many decays): dividing would raise
                # ZeroDivisionError, so use the T -> 0+ limit of exp(delta / T) instead.
                p = 1.0 if delta == 0 else 0.0
            acc = random.random() < p
            logger.verbose('\t\tscore decreased', 3)
            logger.verbose(f'\t\tacceptance probability = {p}', 3)
        logger.verbose('\t\t' + ('accepted' if acc else 'rejected') + ' neighbor', 3)
        # decay the temperature
        self.T *= self.decay
        return acc

    def terminate_early(self) -> bool:
        return False
class ExhaustiveDFS(Search[T]):
    """An exhaustive depth-first search.
    This employs dynamic programming: starting from the initial node, recursively finds best solutions from each descendant node.
    NOTE: this may fail to terminate if the search graph has cycles, and it will visit nodes repeatedly if the graph is not a tree."""

    def _run(self, node: T) -> Solutions[T]:
        prob = self.problem
        solns: Solutions[T]
        # only solution nodes contribute to the result, so score only those
        # (scoring a non-solution node would be wasted work)
        if prob.is_solution(node):
            solns = Solutions(prob.score(node), {node})
        else:
            solns = Solutions.empty()
        # compute solutions for each neighbor and fold them in
        for nbr in prob.get_neighbors(node):
            solns.merge(self._run(nbr))
        return solns
@dataclass
class BeamSearch(Search[T]):
    """A beam search: a depth-first search that, from each node, descends only into the `beam_width`
    highest-scoring neighbors.

    Each node is scored once: the score computed while ranking a neighbor is reused when that
    neighbor is expanded.
    NOTE: like ExhaustiveDFS, this may fail to terminate if the search graph has cycles."""
    beam_width: int = 10  # number of top-scoring neighbors expanded from each node

    def _run(self, node: T) -> Solutions[T]:
        return self._expand(node, self.problem.score(node))

    def _expand(self, node: T, score: float) -> Solutions[T]:
        """Returns the best solutions reachable through the beam from `node`, whose score is already known."""
        prob = self.problem
        solns: Solutions[T] = Solutions(score, {node}) if prob.is_solution(node) else Solutions.empty()
        # Heap of top-scoring neighbors (not just the best ones). The negated enumeration index breaks
        # score ties, so the nodes themselves are never compared (they need not be orderable).
        # NOTE(review): assumes TopNHeap keeps the N largest items by natural tuple order, so among
        # tied scores the earlier neighbor is preferred -- confirm against optimates.utils.
        best_nbrs: TopNHeap[tuple[float, int, T]] = TopNHeap(N = self.beam_width)
        for (i, nbr) in enumerate(prob.get_neighbors(node)):
            best_nbrs.push((prob.score(nbr), -i, nbr))
        # for the top-scoring neighbors, compute their best solutions, then merge
        for (nbr_score, _, nbr) in best_nbrs:
            solns.merge(self._expand(nbr, nbr_score))
        return solns