Coverage for optimates/search.py: 70%

238 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-16 11:19 -0500

1"""A generic API for solving discrete search problems.""" 

2 

3from __future__ import annotations 

4 

5from abc import ABC, abstractmethod 

6from collections.abc import Iterable 

7from dataclasses import dataclass 

8from math import exp, inf 

9import random 

10from typing import Generic, Optional, TypeVar 

11 

12from optimates.utils import TopNHeap, logger 

13 

14 

# Generic type of the nodes in a search graph.
T = TypeVar('T')

# One recorded step of an iterative search: the visited node paired with its score.
IterData = tuple[T, float]

17 

18 

19########## 

20# ERRORS # 

21########## 

22 

class EmptyNeighborSetError(ValueError):
    """Raised when a neighbor is requested for a node whose neighbor set is empty."""

26 

27########## 

28# SEARCH # 

29########## 

30 

@dataclass
class Solutions(Generic[T]):
    """Tracks the best-scoring solutions found so far, together with that maximum score.

    `score` is None until the first solution is recorded; afterwards `solutions` holds every
    solution tied at `score`."""
    score: Optional[float]  # best score seen so far (None while empty)
    solutions: set[T]  # all solutions achieving `score`

    @classmethod
    def empty(cls) -> Solutions[T]:
        """Constructs an empty set of solutions."""
        return cls(None, set())

    def add(self, solution: T, score: float) -> None:
        """Records a solution with its score.

        A strictly higher score replaces the stored solutions, an equal score joins them,
        and a lower score is ignored."""
        if (self.score is None) or (score > self.score):
            self.score = score
            self.solutions = {solution}
        elif score == self.score:
            self.solutions.add(solution)

    def merge(self, other: Solutions[T]) -> Solutions[T]:
        """Merges another Solutions set into this one (in-place) and returns self.

        `other` is never modified. When it holds the strictly better score, its solution set
        is copied rather than shared, so later additions to either object stay independent."""
        assert isinstance(other, Solutions)
        if other.score is not None:
            if (self.score is None) or (other.score > self.score):
                self.score = other.score
                # copy: sharing the set would let a later `add`/`merge` on self mutate `other`
                self.solutions = set(other.solutions)
            elif other.score == self.score:
                self.solutions |= other.solutions
        return self

    def __len__(self) -> int:
        """Gets the total number of stored solutions."""
        return len(self.solutions)

64 

65 

class SearchProblem(ABC, Generic[T]):
    """Abstract search problem, viewed as a directed graph of nodes.

    Some nodes are designated "solutions", and every node has a score; the goal is to find a
    solution whose score is maximal. A node may have directed edges to "neighbor" nodes, which
    are meant to be similar to it (ideally in score as well). Various search algorithms can
    exploit this structure when looking for an optimal solution."""

    def score(self, node: T) -> float:
        """Returns the score of `node` (higher is better).

        Subclasses are expected to override this; the base version raises NotImplementedError."""
        raise NotImplementedError

    @abstractmethod
    def initial_nodes(self) -> Iterable[T]:
        """Returns the initial nodes of the search graph, i.e. the nodes with no predecessor."""

    def default_initial_node(self) -> T:
        """Returns a canonical starting node: the first node yielded by `initial_nodes()`."""
        starts = iter(self.initial_nodes())
        return next(starts)

    @abstractmethod
    def is_solution(self, node: T) -> bool:
        """Tells whether `node` is a solution."""

    @abstractmethod
    def iter_nodes(self) -> Iterable[T]:
        """Returns an iterable over every node of the search space."""

    @abstractmethod
    def random_node(self) -> T:
        """Returns a randomly chosen node of the search space."""

    @abstractmethod
    def get_neighbors(self, node: T) -> Iterable[T]:
        """Returns the neighbors of `node`."""

    def num_neighbors(self, node: T) -> int:
        """Counts the neighbors of `node` by exhausting `get_neighbors(node)`."""
        return sum(1 for _ in self.get_neighbors(node))

    @abstractmethod
    def random_neighbor(self, node: T) -> T:
        """Returns a randomly chosen neighbor of `node`.

        Implementations should sample uniformly over the neighbor set unless they document
        otherwise, and should raise EmptyNeighborSetError when `node` has no neighbors."""

114 

115 

@dataclass
class FilteredSearchProblem(SearchProblem[T]):
    """Search problem whose space is another problem's space restricted by a predicate.

    Wraps `problem` and keeps only the nodes for which `is_element` (which subclasses must
    implement) returns True. Scores are delegated unchanged; every node-producing method
    filters the wrapped problem's output."""
    problem: SearchProblem[T]  # the unfiltered problem being wrapped

    # Maximum number of rejection-sampling draws in `random_neighbor` before falling back to
    # enumerating the filtered neighborhood. (Unannotated, so it is not a dataclass field.)
    max_neighbor_draws = 1000

    @abstractmethod
    def is_element(self, node: T) -> bool:
        """Returns True if `node` belongs to the filtered search space."""

    def score(self, node: T) -> float:
        return self.problem.score(node)

    def initial_nodes(self) -> Iterable[T]:
        return filter(self.is_element, self.problem.initial_nodes())

    def is_solution(self, node: T) -> bool:
        return self.is_element(node) and self.problem.is_solution(node)

    def iter_nodes(self) -> Iterable[T]:
        # NOTE: this scans the whole unfiltered space, so you should preferably override it
        return filter(self.is_element, self.problem.iter_nodes())

    def random_node(self) -> T:
        # NOTE: rejection sampling; this can be inefficient (or non-terminating if no node
        # passes the filter), so you should preferably override it
        while True:
            node = self.problem.random_node()
            if self.is_element(node):
                return node

    def get_neighbors(self, node: T) -> Iterable[T]:
        return filter(self.is_element, self.problem.get_neighbors(node))

    def random_neighbor(self, node: T) -> T:
        """Returns a random neighbor of `node` that passes the filter.

        Rejection-samples from the wrapped problem first. If no valid neighbor turns up within
        `max_neighbor_draws` draws, picks uniformly among the filtered neighbors instead.

        Raises:
            EmptyNeighborSetError: if no neighbor of `node` passes the filter (it may also
                propagate from the wrapped problem's `random_neighbor` when `node` has no
                neighbors at all)."""
        for _ in range(self.max_neighbor_draws):
            nbr = self.problem.random_neighbor(node)
            if self.is_element(nbr):
                return nbr
        # Unbounded rejection sampling would never return if every neighbor is filtered out.
        candidates = list(self.get_neighbors(node))
        if not candidates:
            raise EmptyNeighborSetError(f'no neighbor of {node!r} satisfies the filter')
        return random.choice(candidates)

154 

155 

class _Search(ABC, Generic[T]):
    """Minimal interface shared by every search: an object that can be `run`."""

    @abstractmethod
    def run(self, initial: Optional[T] = None) -> Solutions[T]:
        """Executes the search, optionally from a given initial node, returning the best solutions."""

161 

162 

@dataclass
class Search(_Search[T]):
    """Base class for search algorithms over a `SearchProblem`.

    Starting from some initial node, a search tries to find a global maximum, possibly making
    use of the neighborhood structure of the search problem. Subclasses implement `_run`."""
    problem: SearchProblem[T]  # the problem being searched

    @abstractmethod
    def _run(self, node: T) -> Solutions[T]:
        """Searches starting from `node`; subclasses must override this.

        Should return a Solutions object holding the best solutions found and their score."""

    def run(self, initial: Optional[T] = None) -> Solutions[T]:
        """Runs the search from `initial`, or from the problem's default initial node if None."""
        start = self.problem.default_initial_node() if initial is None else initial
        return self._run(start)

180 

181 

@dataclass
class SearchWithRestarts(_Search[T]):
    """Wraps another search and runs it several times, keeping the best solutions overall.

    With `random_restart`, each run starts from a fresh random node of the problem (the
    `initial` argument is then unused); otherwise every run starts from `initial`."""
    search: Search[T]  # the search to repeat
    num_restarts: int = 25  # number of runs
    random_restart: bool = True  # whether each run starts from a random node

    def run(self, initial: Optional[T] = None) -> Solutions[T]:
        combined: Solutions[T] = Solutions.empty()
        for t in range(self.num_restarts):
            logger.verbose(f'RESTART #{t + 1}', 1)
            if self.random_restart:
                start = self.search.problem.random_node()
            else:
                start = initial
            combined.merge(self.search.run(start))
        return combined

197 

198 

@dataclass
class HillClimb(Search[T]):
    """General neighborhood-based optimizer that captures many well-known algorithms.

    At each iteration, the neighbors of the current node (as produced by `get_neighbors`) are
    scored, and any neighbor that is a solution is recorded. The `accept` criterion then
    selects an accepted subset. If that subset is empty, the search either stops (when
    `terminate_early()` is True) or stays at the current node. Otherwise it moves to a
    uniformly random choice among the highest-scoring accepted neighbors. The loop also ends
    after `max_iters` iterations (no limit when None)."""
    max_iters: Optional[int] = None

    def reset(self) -> None:
        """Clears per-run state before a new search."""
        # best solutions seen during the current run
        self.solutions: Solutions[T] = Solutions.empty()

    def terminate_early(self) -> bool:
        """Whether to stop as soon as an iteration accepts no neighbor (True by default)."""
        return True

    def get_neighbors(self, node: T) -> Iterable[T]:
        """Neighbors to evaluate from `node`; by default the problem's own neighborhood.

        Subclasses may override this to sample or otherwise reshape the neighborhood."""
        return self.problem.get_neighbors(node)

    @abstractmethod
    def accept(self, cur_score: float, nbr_score: float) -> bool:
        """Decides whether a neighbor scored `nbr_score` is acceptable from a node scored `cur_score`."""

    def iterate_search(self, initial: T) -> tuple[Solutions[T], list[IterData[T]]]:
        """Runs the hill climb from `initial`.

        Returns a pair: the best solutions among every node scored during the run, and the
        sequence of (node, score) pairs visited. That sequence holds the initial node plus one
        entry per iteration that did not terminate the search."""
        self.reset()
        prob = self.problem
        solns = self.solutions
        max_iters = inf if self.max_iters is None else self.max_iters
        cur_node, cur_score = initial, prob.score(initial)
        pairs: list[IterData[T]] = [(cur_node, cur_score)]
        if prob.is_solution(cur_node):
            solns.add(cur_node, cur_score)
        t = 1
        while t <= max_iters:
            logger.verbose(f'\tIteration #{t}', 1)
            logger.verbose(f'\t\tcurrent node = {cur_node}, score = {cur_score}', 2)
            # best-scoring neighbors among those accepted at this iteration
            local_solns: Solutions[T] = Solutions.empty()
            num_nbrs = num_accepted = 0
            for nbr in self.get_neighbors(cur_node):
                num_nbrs += 1
                nbr_score = prob.score(nbr)
                if prob.is_solution(nbr):
                    solns.add(nbr, nbr_score)
                if not self.accept(cur_score, nbr_score):
                    continue
                num_accepted += 1
                local_solns.add(nbr, nbr_score)
            num_best_accepted = len(local_solns)
            logger.verbose(f'\t\tnum_neighbors = {num_nbrs}, num_accepted = {num_accepted}, num_best accepted = {num_best_accepted}', 3)
            if num_best_accepted > 0:
                # move to a random member of the best accepted neighbors
                cur_node = random.choice(list(local_solns.solutions))
                cur_score = local_solns.score  # type: ignore
            elif self.terminate_early():
                logger.verbose(f'\tNo neighbors accepted: terminating at iteration #{t}.', 1)
                break
            # when nothing was accepted and we did not terminate, the current node is kept
            pairs.append((cur_node, cur_score))
            t += 1
        else:
            logger.verbose(f'\tTerminating after max_iters ({max_iters}) iterations reached.', 1)
        logger.verbose(f'\t\tcurrent node = {cur_node}, score = {cur_score}', 1)
        return solns, pairs

    def _run(self, node: T) -> Solutions[T]:
        best, _trajectory = self.iterate_search(node)
        return best

272 

273 

class ExhaustiveSearch(HillClimb[T]):
    """Brute-force search over every node yielded by the problem's `iter_nodes`.

    Expressed as a hill climb whose only "neighbor" at each step is the next node from a
    shared iterator, and which accepts every neighbor. The walk therefore traverses the whole
    space, stopping once the iterator is exhausted (or `max_iters` is reached)."""

    def reset(self) -> None:
        super().reset()
        # one iterator over the whole space, shared by all steps of a run
        self._node_gen = iter(self.problem.iter_nodes())

    def get_neighbors(self, node: T) -> Iterable[T]:
        # advance the shared iterator, skipping any node equal to the current one
        for candidate in self._node_gen:
            if candidate != node:
                return [candidate]
        return []

    def accept(self, cur_score: float, nbr_score: float) -> bool:
        return True

294 

295 

class BlindRandomSearch(HillClimb[T]):
    """Blind random search: each step jumps to a freshly sampled random node of the problem,
    ignoring the neighborhood structure entirely."""

    def get_neighbors(self, node: T) -> Iterable[T]:
        # the sole "neighbor" is an independent random draw from the whole space
        sample = self.problem.random_node()
        return [sample]

    def accept(self, cur_score: float, nbr_score: float) -> bool:
        # every sampled node is accepted, whatever its score
        return True

304 

305 

@dataclass
class StochasticLocalSearch(HillClimb[T]):
    """Stochastic local search: each step samples one random neighbor of the current node.

    The neighbor is accepted when its score is at least the current score, or strictly greater
    when `strict_improvement` is True. A rejected sample does not end the search."""
    strict_improvement: bool = False  # if True, neighbors tying the current score are rejected

    def get_neighbors(self, node: T) -> Iterable[T]:
        return [self.problem.random_neighbor(node)]

    def accept(self, cur_score: float, nbr_score: float) -> bool:
        return (nbr_score > cur_score) if self.strict_improvement else (nbr_score >= cur_score)

    def terminate_early(self) -> bool:
        # a rejected sample only means staying put this iteration
        return False

323 

324 

class GreedyLocalSearch(HillClimb[T]):
    """Greedy local search: moves to the best-scoring neighbor, breaking ties at random.

    Only strict improvements are accepted, so the search stops once no neighbor beats the
    current node."""

    def accept(self, cur_score: float, nbr_score: float) -> bool:
        return cur_score < nbr_score

331 

332 

@dataclass
class SimulatedAnnealing(HillClimb[T]):
    """Simulated annealing: a stochastic local search that sometimes accepts worse neighbors.

    Each step samples one random neighbor. An improving neighbor is always accepted; otherwise
    it is accepted with probability exp(delta / T), where delta <= 0 is the score change and T
    is the current temperature. The temperature starts at `T0` and is multiplied by `decay`
    after every acceptance test. The search therefore starts off exploratory and gradually
    behaves more like a pure hill climb."""
    T0: float = 1.0  # initial temperature
    decay: float = 0.99  # exponential decay coefficient (higher means slower cooling)

    def __post_init__(self) -> None:
        # explicit exceptions rather than `assert`, which is stripped under `python -O`
        if not (0.0 < self.decay < 1.0):
            raise ValueError('temperature decay coefficient must be in (0, 1)')
        if not (self.T0 >= 0.0):
            raise ValueError('initial temperature must be nonnegative')

    def reset(self) -> None:
        super().reset()
        self.T = self.T0

    def get_neighbors(self, node: T) -> Iterable[T]:
        return [self.problem.random_neighbor(node)]

    def accept(self, cur_score: float, nbr_score: float) -> bool:
        delta = nbr_score - cur_score
        logger.verbose(f'\t\tcurrent temperature = {self.T}', 3)
        logger.verbose(f'\t\tneighbor score = {nbr_score}, delta = {delta}', 3)
        if delta > 0:  # accept any improvement
            acc = True
            logger.verbose('\t\tscore increased', 2)
        else:  # accept a worse solution with some probability (likelier with high temperature)
            if self.T > 0.0:
                p = exp(delta / self.T)
            else:
                # Repeated decay can underflow T to exactly 0.0 (e.g. with decay <= 0.5), where
                # delta / T would raise ZeroDivisionError; use the T -> 0+ limit instead.
                p = 1.0 if delta == 0 else 0.0
            acc = random.random() < p
            logger.verbose('\t\tscore decreased', 3)
            logger.verbose(f'\t\tacceptance probability = {p}', 3)
        logger.verbose('\t\t' + ('accepted' if acc else 'rejected') + ' neighbor', 3)
        # decay the temperature
        self.T *= self.decay
        return acc

    def terminate_early(self) -> bool:
        # a rejected neighbor only means staying put this iteration
        return False

367 

368 

class ExhaustiveDFS(Search[T]):
    """Exhaustive depth-first search.

    From the starting node, it recursively searches every neighbor and merges the best
    solutions found in each subtree with the node itself (when it is a solution).
    NOTE: there is no visited set or memoization, so this may fail to terminate if the search
    graph has cycles, and it revisits nodes reachable along several paths."""

    def _run(self, node: T) -> Solutions[T]:
        node_score = self.problem.score(node)
        best = Solutions(node_score, {node}) if self.problem.is_solution(node) else Solutions.empty()
        # fold in the best solutions reachable from each neighbor
        for child in self.problem.get_neighbors(node):
            best.merge(self._run(child))
        return best

384 

385 

@dataclass
class BeamSearch(Search[T]):
    """Recursive beam-style search that only expands the most promising neighbors.

    From each node, all neighbors are scored, and only the `beam_width` top-scoring ones (as
    retained by a TopNHeap) are searched recursively. The best solutions found at the node
    itself and in those subtrees are merged. There is no visited set, so cycles among expanded
    neighbors lead to unbounded recursion."""
    beam_width: int = 10  # number of top-scoring neighbors expanded from each node

    def _run(self, node: T) -> Solutions[T]:
        node_score = self.problem.score(node)
        if self.problem.is_solution(node):
            solns: Solutions[T] = Solutions(node_score, {node})
        else:
            solns = Solutions.empty()
        # heap of top-scoring neighbors (not just the best ones)
        # NOTE(review): tuples compare element-wise, so neighbors with equal scores are compared
        # as nodes; presumably node types must be orderable when ties occur — confirm TopNHeap.
        best_nbrs: TopNHeap[tuple[float, T]] = TopNHeap(N = self.beam_width)
        for nbr in self.problem.get_neighbors(node):
            nbr_score = self.problem.score(nbr)
            best_nbrs.push((nbr_score, nbr))
        # for the top-scoring neighbors, compute their best solutions, then merge
        for (_, nbr) in best_nbrs:
            solns.merge(self._run(nbr))
        return solns