Coverage for src\pqlattice\random\_distribution.py: 66%
93 statements
« prev ^ index » next coverage.py v7.11.0, created at 2026-01-07 08:10 +0100
« prev ^ index » next coverage.py v7.11.0, created at 2026-01-07 08:10 +0100
1import math
2import random
4from .._utils import as_integer
5from ..typing import Matrix, Vector
class Uniform:
    """Sampler for integers drawn uniformly from a closed interval.

    Values come from ``[range_beg, range_end]`` (both endpoints included)
    and are produced by a private ``random.Random`` instance, so seeding
    this sampler never touches the module-level ``random`` state.
    """

    def __init__(self, range_beg: int, range_end: int, seed: int | None = None):
        """Create a sampler over ``[range_beg, range_end]``.

        Args:
            range_beg: Smallest value that can be drawn.
            range_end: Largest value that can be drawn.
            seed: Optional seed for the private generator. ``None`` leaves
                the choice of seed to ``random.Random``.
        """
        self._pyrng = random.Random(seed)
        self._range_beg = range_beg
        self._range_end = range_end

    def sample_int(self, seed: int | None = None) -> int:
        """Draw one integer from the configured interval.

        Args:
            seed: If given, the generator is re-seeded before drawing.

        Returns:
            An integer ``x`` with ``range_beg <= x <= range_end``.

        Raises:
            ValueError: Raised by ``randint`` when ``range_beg > range_end``.
        """
        self.set_seed(seed)
        return self._pyrng.randint(self._range_beg, self._range_end)

    def sample_vector(self, n: int, seed: int | None = None) -> Vector:
        """Draw ``n`` independent integers.

        Args:
            n: Number of entries to draw.
            seed: If given, the generator is re-seeded once before drawing.

        Returns:
            The drawn values passed through ``as_integer``.
        """
        self.set_seed(seed)
        return as_integer([self.sample_int() for _ in range(n)])

    def sample_matrix(self, m: int, n: int | None = None, seed: int | None = None) -> Matrix:
        """Draw an ``m`` x ``n`` grid of independent integers.

        The result has ``m`` rows and ``n`` columns, the same orientation as
        ``DiscreteGaussian.sample_matrix(rows, cols)``. Earlier revisions
        built ``n`` rows of ``m`` entries, i.e. the transpose for non-square
        requests; square output is unaffected.

        Args:
            m: Number of rows.
            n: Number of columns; defaults to ``m`` (square matrix).
            seed: If given, the generator is re-seeded once before drawing.

        Returns:
            The nested row lists passed through ``as_integer``.
        """
        self.set_seed(seed)
        if n is None:
            n = m
        return as_integer(self._sample_rows(m, n))

    def _sample_rows(self, rows: int, cols: int) -> list[list[int]]:
        """Return ``rows`` lists of ``cols`` draws each, filled row by row."""
        return [[self.sample_int() for _ in range(cols)] for _ in range(rows)]

    def set_seed(self, seed: int | None) -> None:
        """Re-seed the private generator; ``None`` leaves its state untouched."""
        if seed is not None:
            self._pyrng.seed(seed)

    def set_params(self, range_beg: int | None = None, range_end: int | None = None) -> None:
        """Update one or both interval endpoints.

        Args:
            range_beg: New lower endpoint, or ``None`` to keep the current one.
            range_end: New upper endpoint, or ``None`` to keep the current one.
        """
        if range_beg is not None:
            self._range_beg = range_beg
        if range_end is not None:
            self._range_end = range_end

    def get_params(self) -> dict[str, int]:
        """Return the current endpoints under ``"range_beg"`` and ``"range_end"``."""
        return {"range_beg": self._range_beg, "range_end": self._range_end}
class DiscreteGaussian:
    """Rejection sampler for a tail-cut discrete Gaussian over the integers.

    Candidates are drawn uniformly from the integers within
    ``bound = ceil(tail_cut * sigma)`` of the center and accepted with
    probability ``exp(-(x - center)**2 / (2 * sigma**2))``. All randomness
    comes from a private ``random.Random`` instance, so a given seed
    reproduces the same sample sequence regardless of the module-level
    ``random`` state.
    """

    def __init__(self, sigma: float, center: int | float = 0, tail_cut: float = 6.0, seed: int | None = None):
        """Create a sampler.

        Args:
            sigma: Width parameter of the Gaussian; must be positive.
            center: Center of the distribution. An ``int`` center uses the
                precomputed weight table; any other value recomputes the
                weight for every candidate.
            tail_cut: Candidates farther than ``tail_cut * sigma`` (rounded
                up) from the center are never proposed.
            seed: Optional seed for the private generator.

        Raises:
            ValueError: If ``sigma`` is not strictly positive.
        """
        self._check_sigma(sigma)
        self.center = center
        self.sigma = sigma
        self.tail_cut = tail_cut
        self._pyrng = random.Random(seed)

        self._table: dict[int, float] = {}
        self._recompute_table()

    @staticmethod
    def _check_sigma(sigma: float) -> None:
        """Reject widths for which the acceptance weights are undefined."""
        # ``not sigma > 0`` also rejects NaN, which ``sigma <= 0`` would let through.
        if not sigma > 0:
            raise ValueError(f"sigma must be positive, got {sigma!r}")

    def _recompute_table(self) -> None:
        """Rebuild ``bound`` and the zero-centered acceptance-weight table."""
        self.bound = int(math.ceil(self.tail_cut * self.sigma))
        two_sigma_sq = 2 * self.sigma**2
        self._table = {x: math.exp(-(x**2) / two_sigma_sq) for x in range(-self.bound, self.bound + 1)}

    def set_seed(self, seed: int | None) -> None:
        """Re-seed the private generator; ``None`` leaves its state untouched."""
        if seed is not None:
            self._pyrng.seed(seed)

    def sample_int(self, seed: int | None = None) -> int:
        """Draw one integer from the distribution.

        Args:
            seed: If given, the generator is re-seeded before drawing.

        Returns:
            The sampled integer.

        Raises:
            RuntimeError: If no candidate is accepted within the iteration cap.
        """
        self.set_seed(seed)

        if isinstance(self.center, int):
            # Integer center: sample around zero from the table, then shift.
            return self._sample_centered_fast() + self.center

        return self._sample_dynamic_rejection(self.center)

    def sample_vector(self, n: int, seed: int | None = None) -> Vector:
        """Draw ``n`` independent samples.

        Args:
            n: Number of entries to draw.
            seed: If given, the generator is re-seeded once before drawing.

        Returns:
            The drawn values passed through ``as_integer``.
        """
        self.set_seed(seed)
        return as_integer([self.sample_int() for _ in range(n)])

    def sample_matrix(self, rows: int, cols: int | None = None, seed: int | None = None) -> Matrix:
        """Draw a ``rows`` x ``cols`` grid of independent samples.

        Args:
            rows: Number of rows.
            cols: Number of columns; defaults to ``rows`` (square matrix).
            seed: If given, the generator is re-seeded once before drawing.

        Returns:
            The nested row lists passed through ``as_integer``.
        """
        self.set_seed(seed)
        if cols is None:
            cols = rows
        return as_integer([[self.sample_int() for _ in range(cols)] for _ in range(rows)])

    def set_params(self, sigma: float | None = None, center: float | int | None = None, tail_cut: float | None = None) -> None:
        """Update any subset of the parameters and rebuild the weight table.

        Args:
            sigma: New width, or ``None`` to keep the current one.
            center: New center, or ``None`` to keep the current one.
            tail_cut: New tail cut, or ``None`` to keep the current one.

        Raises:
            ValueError: If the new ``sigma`` is not strictly positive; the
                sampler is left unchanged in that case.
        """
        if sigma is not None:
            # Validate before mutating so a bad value cannot corrupt the state.
            self._check_sigma(sigma)
            self.sigma = sigma
        if tail_cut is not None:
            self.tail_cut = tail_cut
        if center is not None:
            self.center = center
        self._recompute_table()

    def get_params(self) -> dict[str, float]:
        """Return ``sigma``, ``center``, ``tail_cut`` and the derived ``bound``."""
        return {"sigma": self.sigma, "center": self.center, "tail_cut": self.tail_cut, "bound": self.bound}

    def _sample_centered_fast(self) -> int:
        """Rejection-sample around zero using the precomputed weight table.

        Raises:
            RuntimeError: If no candidate is accepted within 1000 attempts.
        """
        rng = self._pyrng  # instance generator, so seeding is honoured
        max_iters = 1000
        for _ in range(max_iters):
            x = rng.randint(-self.bound, self.bound)
            prob = self._table.get(x, 0.0)
            if rng.random() < prob:
                return x

        raise RuntimeError("Failed to generate sample")

    def _sample_dynamic_rejection(self, c: float) -> int:
        """Rejection-sample around an arbitrary center ``c``.

        Candidates span ``floor(c - bound)`` to ``ceil(c + bound)`` inclusive
        and the acceptance weight is computed for each candidate.

        Raises:
            RuntimeError: If no candidate is accepted within 1000 attempts.
        """
        start = int(math.floor(c - self.bound))
        end = int(math.ceil(c + self.bound))
        two_sigma_sq = 2 * self.sigma**2
        rng = self._pyrng  # instance generator, so seeding is honoured

        max_iters = 1000
        for _ in range(max_iters):
            x = rng.randint(start, end)
            dist_sq = (x - c) ** 2
            prob = math.exp(-dist_sq / two_sigma_sq)

            if rng.random() < prob:
                return x

        raise RuntimeError("Failed to generate sample")