Coverage for test.py: 98%

56 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-16 10:49 -0500

1"""Unit tests for the optimates library.""" 

2 

3from dataclasses import dataclass 

4import random 

5from typing import Iterable 

6 

7import pytest 

8 

9from optimates.search import BlindRandomSearch, EmptyNeighborSetError, ExhaustiveSearch, GreedyLocalSearch, HillClimb, SearchProblem, SimulatedAnnealing, StochasticLocalSearch 

10 

11 

@dataclass
class RangeProblem(SearchProblem[int]):
    """Toy search problem over the nodes 0, 1, ..., n - 1, where a node's score is its own value.

    Every node counts as a solution, the only initial node is 0, and a node's
    neighbors are the adjacent integers inside the range.
    """
    n: int  # size of the space; the nodes are range(n)

    def score(self, node: int) -> float:
        """Return the node value itself, as a float."""
        return float(node)

    def initial_nodes(self) -> Iterable[int]:
        """The only initial node is 0."""
        return [0]

    def is_solution(self, node: int) -> bool:
        """Accept every node as a solution."""
        return True

    def iter_nodes(self) -> Iterable[int]:
        """Enumerate the whole space in increasing order."""
        return range(self.n)

    def random_node(self) -> int:
        """Draw a node uniformly at random from range(n)."""
        all_nodes = range(self.n)
        return random.choice(all_nodes)

    def get_neighbors(self, node: int) -> Iterable[int]:
        """Return node - 1 (when node > 0) followed by node + 1 (when node < n - 1)."""
        neighbors = [node - 1] if node > 0 else []
        if node < self.n - 1:
            neighbors.append(node + 1)
        return neighbors

    def random_neighbor(self, node: int) -> int:
        """Pick one of the node's neighbors uniformly at random.

        Raises:
            EmptyNeighborSetError: if get_neighbors returns an empty collection.
        """
        nbrs = self.get_neighbors(node)
        if not nbrs:
            raise EmptyNeighborSetError()
        return random.choice(list(nbrs))

45 

46 

class ReverseRangeProblem(RangeProblem):
    """Variant of RangeProblem in which node i scores -i and the search starts at node 5.

    Over the nodes 0, ..., n - 1 the best score therefore belongs to node 0.
    """

    def score(self, node: int) -> float:
        """Return the negated node value, as a float."""
        as_float = float(node)
        return -as_float

    def initial_nodes(self) -> Iterable[int]:
        """The only initial node is 5, whatever n is."""
        # NOTE(review): this assumes n > 5; for smaller n the start node lies
        # outside range(n).
        start = 5
        return [start]

55 

56 

range_problem = RangeProblem(10)
rev_range_problem = ReverseRangeProblem(10)

# Each case is (random seed, search object, expected outcome); test_search
# consumes them. Keys of the expected-outcome dict:
#   num_steps -- exact length of the step sequence from iterate_search
#   monotonic -- True: step scores must be non-decreasing;
#                False: they must NOT be non-decreasing
#   score     -- compared against the final result's score
#   solutions -- compared against the final result's solution set
TESTS = [
    # --- increasing scores (score(i) = i, start at 0) ---
    (0, ExhaustiveSearch(range_problem, max_iters=None),
     dict(num_steps=10, monotonic=True, score=9, solutions={9})),
    (0, BlindRandomSearch(range_problem, max_iters=5),
     dict(num_steps=6, monotonic=False, score=8, solutions={8})),
    (0, StochasticLocalSearch(range_problem, max_iters=20),
     dict(num_steps=21, monotonic=True, score=9, solutions={9})),
    (0, StochasticLocalSearch(range_problem, max_iters=10),
     dict(num_steps=11, monotonic=True, score=6, solutions={6})),
    (0, GreedyLocalSearch(range_problem, max_iters=None),
     dict(num_steps=10, monotonic=True, score=9, solutions={9})),
    (0, SimulatedAnnealing(range_problem, max_iters=10),
     dict(num_steps=11, monotonic=False, score=5, solutions={5})),
    # --- decreasing scores (score(i) = -i, start at 5) ---
    (0, ExhaustiveSearch(rev_range_problem, max_iters=None),
     dict(num_steps=11, monotonic=False, score=0, solutions={0})),
    (0, BlindRandomSearch(rev_range_problem, max_iters=10),
     dict(num_steps=11, monotonic=False, score=0, solutions={0})),
    (0, StochasticLocalSearch(rev_range_problem, max_iters=20),
     dict(num_steps=21, monotonic=True, score=0, solutions={0})),
    (0, StochasticLocalSearch(rev_range_problem, max_iters=10),
     dict(num_steps=11, monotonic=True, score=-3, solutions={3})),
    (0, GreedyLocalSearch(rev_range_problem, max_iters=None),
     dict(num_steps=6, monotonic=True, score=0, solutions={0})),
    (0, SimulatedAnnealing(rev_range_problem, max_iters=10),
     dict(num_steps=11, monotonic=False, score=-4, solutions={4})),
    (0, SimulatedAnnealing(rev_range_problem, max_iters=22),
     dict(num_steps=23, monotonic=False, score=0, solutions={0})),
]

88 

@pytest.mark.parametrize(['seed', 'search_obj', 'result'], TESTS)
def test_search(seed, search_obj, result):
    """Run a seeded search and check its step sequence and final result.

    Args:
        seed: value passed to random.seed so the run is reproducible.
        search_obj: the search to run; it must be a HillClimb.
        result: expected outcome. Optional keys: 'num_steps' (exact number
            of steps) and 'monotonic' (True: step scores must be
            non-decreasing; False: they must decrease at least once).
            Required keys: 'score' and 'solutions' (compared against the
            final result).
    """
    random.seed(seed)
    # Make a TESTS entry that is not a HillClimb fail with a clear message,
    # instead of skipping the checks or failing later on an unbound name.
    assert isinstance(search_obj, HillClimb), (
        f'expected a HillClimb search, got {type(search_obj).__name__}'
    )
    initial = search_obj.problem.default_initial_node()
    res, steps = search_obj.iterate_search(initial)
    if 'num_steps' in result:
        assert len(steps) == result['num_steps']
    if 'monotonic' in result:
        # Index 1 of each step is the score being compared.
        scores = [step[1] for step in steps]
        non_decreasing = all(prev <= nxt for prev, nxt in zip(scores, scores[1:]))
        if result['monotonic']:
            assert non_decreasing, f'scores are not non-decreasing: {scores}'
        else:
            assert not non_decreasing, f'scores are unexpectedly non-decreasing: {scores}'
    assert result['score'] == res.score
    assert result['solutions'] == res.solutions