Coverage for test.py: 98%
56 statements
Generated by coverage.py v7.6.7 on 2024-11-16 at 10:49 -0500.
1"""Unit tests for the optimates library."""
3from dataclasses import dataclass
4import random
5from typing import Iterable
7import pytest
9from optimates.search import BlindRandomSearch, EmptyNeighborSetError, ExhaustiveSearch, GreedyLocalSearch, HillClimb, SearchProblem, SimulatedAnnealing, StochasticLocalSearch
@dataclass
class RangeProblem(SearchProblem[int]):
    """Toy search problem over the integers 0, 1, ..., n - 1, scored by identity.

    Every node counts as a solution and a node's score is simply its own value,
    so the highest-scoring node is ``n - 1``. For a node inside the range, its
    neighbors are the adjacent integers ``node - 1`` and ``node + 1`` that also
    lie inside the range.
    """
    n: int  # size of the search space; nodes are 0, 1, ..., n - 1

    def score(self, node: int) -> float:
        """Score a node by its own value, converted to float."""
        return float(node)

    def initial_nodes(self) -> Iterable[int]:
        """Start every search from node 0."""
        return [0]

    def is_solution(self, node: int) -> bool:
        """Accept every node as a valid solution."""
        return True

    def iter_nodes(self) -> Iterable[int]:
        """Enumerate the whole search space in ascending order."""
        return range(self.n)

    def random_node(self) -> int:
        """Draw a node uniformly at random from the search space."""
        return random.choice(range(self.n))

    def get_neighbors(self, node: int) -> Iterable[int]:
        """Return the adjacent integers of ``node``, lower neighbor first.

        ``node - 1`` is included only when ``node > 0``; ``node + 1`` only when
        ``node < n - 1``.
        """
        below = [node - 1] if node > 0 else []
        above = [node + 1] if node < self.n - 1 else []
        return below + above

    def random_neighbor(self, node: int) -> int:
        """Pick one neighbor of ``node`` uniformly at random.

        Raises:
            EmptyNeighborSetError: if ``node`` has no neighbors (e.g. n == 1).
        """
        candidates = self.get_neighbors(node)
        if not candidates:
            raise EmptyNeighborSetError()
        return random.choice(list(candidates))
class ReverseRangeProblem(RangeProblem):
    """Trivial search space where i = {0, ..., n - 1} and score(i) = -i.

    Scores decrease as nodes grow, so node 0 is the highest-scoring node.
    Searches start away from it (at node 5, clamped into range for small n).
    """

    def score(self, node: int) -> float:
        """Score a node by its negated value, so smaller nodes score higher."""
        return -float(node)

    def initial_nodes(self) -> Iterable[int]:
        """Start from node 5, clamped to the last node (n - 1) for small n.

        An unclamped 5 lies outside {0, ..., n - 1} whenever n <= 5. Clamping
        keeps the start inside the search space for any n >= 1, while leaving
        every n >= 6 (including the n = 10 used by these tests) unchanged.
        """
        return [min(5, self.n - 1)]
# Both fixture problems search the integers 0..9.
range_problem = RangeProblem(10)
rev_range_problem = ReverseRangeProblem(10)

# Cases for test_search, as (seed, search object, expected results) triples.
# Expected keys: 'num_steps' (length of the step trace), 'monotonic' (whether
# the sequence of step scores is non-decreasing), and 'score' / 'solutions'
# (compared with the search result's .score and .solutions).
TESTS = [
    # RangeProblem: score(i) = i, so node 9 scores highest.
    (0, ExhaustiveSearch(range_problem, max_iters=None),
     {'num_steps': 10, 'monotonic': True, 'score': 9, 'solutions': {9}}),
    (0, BlindRandomSearch(range_problem, max_iters=5),
     {'num_steps': 6, 'monotonic': False, 'score': 8, 'solutions': {8}}),
    (0, StochasticLocalSearch(range_problem, max_iters=20),
     {'num_steps': 21, 'monotonic': True, 'score': 9, 'solutions': {9}}),
    (0, StochasticLocalSearch(range_problem, max_iters=10),
     {'num_steps': 11, 'monotonic': True, 'score': 6, 'solutions': {6}}),
    (0, GreedyLocalSearch(range_problem, max_iters=None),
     {'num_steps': 10, 'monotonic': True, 'score': 9, 'solutions': {9}}),
    (0, SimulatedAnnealing(range_problem, max_iters=10),
     {'num_steps': 11, 'monotonic': False, 'score': 5, 'solutions': {5}}),
    # ReverseRangeProblem: score(i) = -i, so node 0 scores highest.
    (0, ExhaustiveSearch(rev_range_problem, max_iters=None),
     {'num_steps': 11, 'monotonic': False, 'score': 0, 'solutions': {0}}),
    (0, BlindRandomSearch(rev_range_problem, max_iters=10),
     {'num_steps': 11, 'monotonic': False, 'score': 0, 'solutions': {0}}),
    (0, StochasticLocalSearch(rev_range_problem, max_iters=20),
     {'num_steps': 21, 'monotonic': True, 'score': 0, 'solutions': {0}}),
    (0, StochasticLocalSearch(rev_range_problem, max_iters=10),
     {'num_steps': 11, 'monotonic': True, 'score': -3, 'solutions': {3}}),
    (0, GreedyLocalSearch(rev_range_problem, max_iters=None),
     {'num_steps': 6, 'monotonic': True, 'score': 0, 'solutions': {0}}),
    (0, SimulatedAnnealing(rev_range_problem, max_iters=10),
     {'num_steps': 11, 'monotonic': False, 'score': -4, 'solutions': {4}}),
    (0, SimulatedAnnealing(rev_range_problem, max_iters=22),
     {'num_steps': 23, 'monotonic': False, 'score': 0, 'solutions': {0}}),
]
@pytest.mark.parametrize(['seed', 'search_obj', 'result'], TESTS)
def test_search(seed, search_obj, result):
    """Run one parametrized search case and check it against its expectations.

    Args:
        seed: value passed to ``random.seed`` so stochastic searches are
            reproducible.
        search_obj: the search under test. It must be a ``HillClimb``, because
            the test starts it from the problem's default initial node.
        result: expected outcome. The optional keys ``'num_steps'`` (length of
            the step trace) and ``'monotonic'`` (True: step scores must be
            non-decreasing; False: they must not be) are checked when present.
            ``'score'`` and ``'solutions'`` are always compared with the search
            result.
    """
    random.seed(seed)
    # Without this check a non-HillClimb case would leave `initial` unbound
    # and fail with an obscure UnboundLocalError.
    assert isinstance(search_obj, HillClimb), (
        f'expected a HillClimb search, got {type(search_obj).__name__}'
    )
    initial = search_obj.problem.default_initial_node()
    res, steps = search_obj.iterate_search(initial)
    if 'num_steps' in result:
        assert len(steps) == result['num_steps']
    if 'monotonic' in result:
        # step[1] is the step's score (presumably a (node, score) pair --
        # TODO confirm against optimates' iterate_search).
        scores = [step[1] for step in steps]
        is_mon = all(prev <= cur for prev, cur in zip(scores, scores[1:]))
        if result['monotonic']:
            assert is_mon, f'score sequence is not non-decreasing: {scores}'
        else:
            # a non-monotonic run must contain at least one score decrease
            assert not is_mon, f'score sequence is unexpectedly non-decreasing: {scores}'
    assert result['score'] == res.score
    assert result['solutions'] == res.solutions