Coverage for src/pqlattice/random/_distribution.py: 70%
104 statements
coverage.py v7.11.0, created at 2026-01-12 21:36 +0100
1import math
2import random
3from abc import ABC, abstractmethod
4from typing import Any, override
6from .._utils import as_integer
7from ..typing import Matrix, Vector
class Distribution(ABC):
    """
    Abstract base class for integer samplers backed by a private RNG.

    Every instance owns its own ``random.Random`` generator, so re-seeding one
    sampler does not touch other samplers or the global ``random`` module
    (provided subclasses draw from ``self._pyrng``).

    Parameters
    ----------
    seed : int | None, optional
        seed for the instance's random number generator. ``None`` uses the
        default seeding behaviour of ``random.Random``.
    """

    def __init__(self, seed: int | None = None):
        self._pyrng = random.Random(seed)

    def set_seed(self, seed: int | None) -> None:
        """
        Re-seed the instance's random number generator.

        Parameters
        ----------
        seed : int | None
            new seed. ``None`` leaves the generator state untouched, which lets
            sampling methods forward their optional ``seed`` argument
            unconditionally.
        """
        if seed is not None:
            self._pyrng.seed(seed)

    @abstractmethod
    def sample_int(self, seed: int | None = None) -> int:
        """
        Draw a single integer from the distribution.

        Parameters
        ----------
        seed : int | None, optional
            if not None, re-seed the generator before sampling.
        """

    @abstractmethod
    def sample_vector(self, n: int, seed: int | None = None) -> Vector:
        """
        Draw ``n`` independent samples as a vector.

        Parameters
        ----------
        n : int
            number of samples.
        seed : int | None, optional
            if not None, re-seed the generator once before sampling.
        """

    @abstractmethod
    def sample_matrix(self, rows: int, cols: int | None = None, seed: int | None = None) -> Matrix:
        """
        Draw a ``rows x cols`` matrix of independent samples.

        Parameters
        ----------
        rows : int
            number of rows.
        cols : int | None, optional
            number of columns; implementations treat ``None`` as ``rows``.
        seed : int | None, optional
            if not None, re-seed the generator once before sampling.
        """

    @abstractmethod
    def set_params(self, *args: Any, **kwargs: Any) -> None:
        """
        Update the distribution's parameters.

        The accepted (usually keyword) arguments are defined by each subclass.
        """

    @abstractmethod
    def get_params(self) -> dict[str, Any]:
        """Return the current parameters as a ``name -> value`` mapping."""
class Uniform(Distribution):
    """Uniform sampler over the closed integer range ``[range_beg, range_end]``."""

    def __init__(self, range_beg: int, range_end: int, seed: int | None = None):
        """
        Creates a uniform sampler from range [range_beg; range_end].

        Parameters
        ----------
        range_beg : int
            begin of sampling range. Inclusive
        range_end : int
            end of sampling range. Inclusive
        seed : int | None, optional
            seed for random number generator.

        Notes
        -----
        The range is not validated here: if ``range_beg > range_end`` the
        sampling methods raise ``ValueError`` from ``random.Random.randint``.
        """
        super().__init__(seed)
        self._range_beg = range_beg
        self._range_end = range_end

    @override
    def sample_int(self, seed: int | None = None) -> int:
        """
        Draw one integer uniformly from ``[range_beg, range_end]`` (both inclusive).

        Parameters
        ----------
        seed : int | None, optional
            if not None, re-seed the generator before sampling; None does nothing.

        Returns
        -------
        int
            random integer from ``[range_beg, range_end]``.
        """
        self.set_seed(seed)
        return self._pyrng.randint(self._range_beg, self._range_end)

    @override
    def sample_vector(self, n: int, seed: int | None = None) -> Vector:
        """
        Draw ``n`` independent uniform integers.

        Parameters
        ----------
        n : int
            number of entries.
        seed : int | None, optional
            if not None, re-seed the generator once before drawing all entries.

        Returns
        -------
        Vector
            the samples, converted with ``as_integer``.
        """
        self.set_seed(seed)
        # The per-entry calls pass no seed, so the generator is not re-seeded
        # between entries.
        return as_integer([self.sample_int() for _ in range(n)])

    @override
    def sample_matrix(self, rows: int, cols: int | None = None, seed: int | None = None) -> Matrix:
        """
        Draw a ``rows x cols`` matrix of independent uniform integers.

        Parameters
        ----------
        rows : int
            number of rows.
        cols : int | None, optional
            number of columns; ``None`` (default) produces a square matrix.
        seed : int | None, optional
            if not None, re-seed the generator once before drawing all entries.

        Returns
        -------
        Matrix
            the samples (row-major nested list), converted with ``as_integer``.
        """
        self.set_seed(seed)
        cols = rows if cols is None else cols
        return as_integer([[self.sample_int() for _ in range(cols)] for _ in range(rows)])

    @override
    def set_params(self, range_beg: int | None = None, range_end: int | None = None) -> None:  # type: ignore
        """
        Update the sampling range.

        Parameters
        ----------
        range_beg : int | None, optional
            new inclusive lower end; None keeps the current value.
        range_end : int | None, optional
            new inclusive upper end; None keeps the current value.
        """
        self._range_beg = self._range_beg if range_beg is None else range_beg
        self._range_end = self._range_end if range_end is None else range_end

    @override
    def get_params(self) -> dict[str, int]:
        """
        Return the current sampling range.

        Returns
        -------
        dict[str, int]
            ``{"range_beg": ..., "range_end": ...}`` (both inclusive).
        """
        return {"range_beg": self._range_beg, "range_end": self._range_end}
class DiscreteGaussian(Distribution):
    """
    Discrete Gaussian sampler over the integers, using rejection sampling.

    A candidate ``x`` is drawn uniformly from a window of half-width
    ``bound = ceil(tail_cut * sigma)`` around ``center`` and accepted with
    probability ``exp(-(x - center)**2 / (2 * sigma**2))``. Values outside that
    window are never produced (tail cut).

    All randomness comes from the instance's own generator, so a fixed ``seed``
    gives a reproducible stream of samples.
    """

    # Maximum number of rejection-sampling attempts before giving up.
    _MAX_REJECTION_ITERS = 1000

    def __init__(self, sigma: float, center: int | float = 0, tail_cut: float = 6.0, seed: int | None = None):
        """
        Parameters
        ----------
        sigma : float
            width parameter of the Gaussian; must be > 0.
        center : int | float, optional
            center of the distribution, by default 0. An ``int`` center uses the
            precomputed weight table; any other center computes the acceptance
            probabilities on the fly.
        tail_cut : float, optional
            number of ``sigma``s kept on each side of the center, by default 6.0;
            must be >= 0.
        seed : int | None, optional
            seed for the instance's random number generator, by default None.

        Raises
        ------
        ValueError
            if ``sigma <= 0`` or ``tail_cut < 0``.
        """
        self._validate(sigma, tail_cut)
        super().__init__(seed)
        self.center = center
        self.sigma = sigma
        self.tail_cut = tail_cut
        self._recompute_table()

    @staticmethod
    def _validate(sigma: float, tail_cut: float) -> None:
        """Reject parameters for which the sampler is undefined."""
        # Written as `not (...)` so NaN is rejected as well.
        if not sigma > 0:
            raise ValueError(f"sigma must be positive, got {sigma}")
        if not tail_cut >= 0:
            raise ValueError(f"tail_cut must be non-negative, got {tail_cut}")

    def _recompute_table(self) -> None:
        """Recompute ``bound`` and the unnormalized weights of ``[-bound, bound]``."""
        self.bound = int(math.ceil(self.tail_cut * self.sigma))
        two_sigma_sq = 2 * self.sigma**2
        self._table: dict[int, float] = {
            x: math.exp(-(x**2) / two_sigma_sq) for x in range(-self.bound, self.bound + 1)
        }

    @override
    def sample_int(self, seed: int | None = None) -> int:
        """
        Draw one integer from the discrete Gaussian.

        Parameters
        ----------
        seed : int | None, optional
            if not None, re-seed the generator before sampling.

        Returns
        -------
        int
            a sample within ``bound`` (rounded outward for non-integer centers)
            of ``center``.

        Raises
        ------
        RuntimeError
            if no candidate is accepted within ``_MAX_REJECTION_ITERS`` attempts.
        """
        self.set_seed(seed)

        if isinstance(self.center, int):
            # Integer center: sample around 0 with the cached table, then shift.
            return self._sample_centered_fast() + self.center

        return self._sample_dynamic_rejection(self.center)

    @override
    def sample_vector(self, n: int, seed: int | None = None) -> Vector:
        """
        Draw ``n`` independent samples.

        Parameters
        ----------
        n : int
            number of entries.
        seed : int | None, optional
            if not None, re-seed the generator once before drawing all entries.

        Returns
        -------
        Vector
            the samples, converted with ``as_integer``.
        """
        self.set_seed(seed)
        return as_integer([self.sample_int() for _ in range(n)])

    @override
    def sample_matrix(self, rows: int, cols: int | None = None, seed: int | None = None) -> Matrix:
        """
        Draw a ``rows x cols`` matrix of independent samples.

        Parameters
        ----------
        rows : int
            number of rows.
        cols : int | None, optional
            number of columns; ``None`` (default) produces a square matrix.
        seed : int | None, optional
            if not None, re-seed the generator once before drawing all entries.

        Returns
        -------
        Matrix
            the samples (row-major nested list), converted with ``as_integer``.
        """
        self.set_seed(seed)
        cols = rows if cols is None else cols
        return as_integer([[self.sample_int() for _ in range(cols)] for _ in range(rows)])

    @override
    def set_params(self, sigma: float | None = None, center: float | int | None = None, tail_cut: float | None = None) -> None:  # type: ignore
        """
        Update the distribution parameters and rebuild the weight table.

        Parameters
        ----------
        sigma : float | None, optional
            new width parameter (> 0); None keeps the current value.
        center : float | int | None, optional
            new center; None keeps the current value.
        tail_cut : float | None, optional
            new tail cut (>= 0); None keeps the current value.

        Raises
        ------
        ValueError
            if the resulting ``sigma`` or ``tail_cut`` is invalid. The object is
            left unchanged in that case.
        """
        sigma = self.sigma if sigma is None else sigma
        center = self.center if center is None else center
        tail_cut = self.tail_cut if tail_cut is None else tail_cut

        # Validate before mutating so a bad call cannot leave a half-updated sampler.
        self._validate(sigma, tail_cut)

        self.sigma = sigma
        self.tail_cut = tail_cut
        self.center = center
        self._recompute_table()

    @override
    def get_params(self) -> dict[str, float]:
        """
        Return the current parameters.

        Returns
        -------
        dict[str, float]
            ``sigma``, ``center``, ``tail_cut`` and the derived ``bound``.
        """
        return {"sigma": self.sigma, "center": self.center, "tail_cut": self.tail_cut, "bound": self.bound}

    def _sample_centered_fast(self) -> int:
        """Rejection-sample around 0 using the precomputed weight table."""
        rng = self._pyrng  # instance RNG, so seeding is honoured
        bound = self.bound
        table = self._table
        for _ in range(self._MAX_REJECTION_ITERS):
            x = rng.randint(-bound, bound)
            if rng.random() < table.get(x, 0.0):
                return x

        raise RuntimeError("Failed to generate sample")

    def _sample_dynamic_rejection(self, c: float) -> int:
        """Rejection-sample around a (possibly non-integer) center ``c``."""
        rng = self._pyrng  # instance RNG, so seeding is honoured
        start = int(math.floor(c - self.bound))
        end = int(math.ceil(c + self.bound))
        two_sigma_sq = 2 * self.sigma**2

        for _ in range(self._MAX_REJECTION_ITERS):
            x = rng.randint(start, end)
            prob = math.exp(-((x - c) ** 2) / two_sigma_sq)
            if rng.random() < prob:
                return x

        raise RuntimeError("Failed to generate sample")