Coverage for src\pqlattice\random\_distribution.py: 66%

93 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2026-01-07 08:10 +0100

1import math 

2import random 

3 

4from .._utils import as_integer 

5from ..typing import Matrix, Vector 

6 

7 

8class Uniform: 

9 def __init__(self, range_beg: int, range_end: int, seed: int | None = None): 

10 self._pyrng = random.Random(seed) 

11 self._range_beg = range_beg 

12 self._range_end = range_end 

13 

14 def sample_int(self, seed: int | None = None) -> int: 

15 self.set_seed(seed) 

16 return self._pyrng.randint(self._range_beg, self._range_end) 

17 

18 def sample_vector(self, n: int, seed: int | None = None) -> Vector: 

19 self.set_seed(seed) 

20 return as_integer([self.sample_int() for _ in range(n)]) 

21 

22 def sample_matrix(self, m: int, n: int | None = None, seed: int | None = None) -> Matrix: 

23 self.set_seed(seed) 

24 if n is None: 

25 n = m 

26 return as_integer([[self.sample_int() for _ in range(m)] for _ in range(n)]) 

27 

28 def set_seed(self, seed: int | None) -> None: 

29 if seed is not None: 

30 self._pyrng.seed(seed) 

31 

32 def set_params(self, range_beg: int | None = None, range_end: int | None = None) -> None: 

33 if range_beg is None: 

34 range_beg = self._range_beg 

35 if range_end is None: 

36 range_end = self._range_end 

37 

38 self._range_beg = range_beg 

39 self._range_end = range_end 

40 

41 def get_params(self) -> dict[str, int]: 

42 return {"range_beg": self._range_beg, "range_end": self._range_end} 

43 

44 

45class DiscreteGaussian: 

46 def __init__(self, sigma: float, center: int | float = 0, tail_cut: float = 6.0, seed: int | None = None): 

47 self.center = center 

48 self.sigma = sigma 

49 self.tail_cut = tail_cut 

50 self._pyrng = random.Random(seed) 

51 

52 self._table: dict[int, float] = {} 

53 self._recompute_table() 

54 

55 def _recompute_table(self) -> None: 

56 self.bound = int(math.ceil(self.tail_cut * self.sigma)) 

57 self._table: dict[int, float] = {} 

58 for x in range(-self.bound, self.bound + 1): 

59 self._table[x] = math.exp(-(x**2) / (2 * self.sigma**2)) 

60 

61 def set_seed(self, seed: int | None) -> None: 

62 if seed is not None: 

63 self._pyrng.seed(seed) 

64 

65 def sample_int(self, seed: int | None = None) -> int: 

66 self.set_seed(seed) 

67 

68 if isinstance(self.center, int): 

69 return self._sample_centered_fast() + self.center 

70 

71 return self._sample_dynamic_rejection(self.center) 

72 

73 def sample_vector(self, n: int, seed: int | None = None) -> Vector: 

74 self.set_seed(seed) 

75 return as_integer([self.sample_int() for _ in range(n)]) 

76 

77 def sample_matrix(self, rows: int, cols: int | None = None, seed: int | None = None) -> Matrix: 

78 self.set_seed(seed) 

79 if cols is None: 

80 cols = rows 

81 return as_integer([[self.sample_int() for _ in range(cols)] for _ in range(rows)]) 

82 

83 def set_params(self, sigma: float | None = None, center: float | int | None = None, tail_cut: float | None = None) -> None: 

84 if sigma is None: 

85 sigma = self.sigma 

86 if center is None: 

87 center = self.center 

88 if tail_cut is None: 

89 tail_cut = self.tail_cut 

90 

91 self.sigma = sigma 

92 self.tail_cut = tail_cut 

93 self.center = center 

94 self._recompute_table() 

95 

96 def get_params(self) -> dict[str, float]: 

97 return {"sigma": self.sigma, "center": self.center, "tail_cut": self.tail_cut, "bound": self.bound} 

98 

99 def _sample_centered_fast(self) -> int: 

100 max_iters = 1000 

101 for _ in range(max_iters): 

102 x = random.randint(-self.bound, self.bound) 

103 prob = self._table.get(x, 0.0) 

104 if random.random() < prob: 

105 return x 

106 

107 raise RuntimeError("Failed to generate sample") 

108 

109 def _sample_dynamic_rejection(self, c: float) -> int: 

110 start = int(math.floor(c - self.bound)) 

111 end = int(math.ceil(c + self.bound)) 

112 

113 max_iters = 1000 

114 for _ in range(max_iters): 

115 x = random.randint(start, end) 

116 dist_sq = (x - c) ** 2 

117 prob = math.exp(-dist_sq / (2 * self.sigma**2)) 

118 

119 if random.random() < prob: 

120 return x 

121 

122 raise RuntimeError("Failed to generate sample")