Coverage for gemlib/distributions/hypergeometric_sampler.py: 100%
51 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-12-02 22:09 +0000
1"""Hypergeometric sampling algorithm"""
3# ruff: noqa: N803
4import jax
5import jax.numpy as jnp
6import numpy as np
7import tensorflow_probability.substrates.jax as tfp
8from tensorflow_probability.substrates.jax.internal import (
9 dtype_util,
10 samplers,
11)
# Short aliases for the TFP-on-JAX namespaces used in this module.
tfd = tfp.distributions
# `prefer_static` resolves shapes at trace time where possible (JIT-friendly).
ps = tfp.internal.prefer_static
# Batched rejection-sampling utilities; provides `batched_las_vegas_algorithm`.
brs = tfp.internal.batched_rejection_sampler
def sample_hypergeometric(num_samples, N, K, n, seed=None):
    """Draw samples from a Hypergeometric(N, K, n) distribution.

    Vectorized rejection sampler: parameters are first folded by symmetry
    into the regime ``K <= N/2`` and ``n <= N/2``, candidates are drawn
    under a hat function and accepted/rejected in batch via TFP's
    Las Vegas driver, and finally the symmetry transforms are undone.

    NOTE(review): the hat construction (``a``, ``h``, ``X = a + h*(V-0.5)/(1-U)``,
    accept when ``2*log(1-U) <= T``) mirrors a ratio-of-uniforms-style
    rejection scheme — confirm against the original reference before
    relying on this description.

    Args:
        num_samples: Leading number of independent draws per parameter set.
        N: Population size. Broadcast against ``K`` and ``n``.
        K: Number of "success" states in the population.
        n: Number of draws without replacement.
        seed: PRNG seed understood by TFP's JAX ``samplers`` API.

    Returns:
        Tensor of shape ``[num_samples] + broadcast(shape(N), shape(K),
        shape(n))``. Entries for invalid parameter combinations are ``nan``.
    """
    # Promote all parameters to a common floating dtype (default float32);
    # all subsequent arithmetic (lgamma etc.) is done in floating point.
    dtype = dtype_util.common_dtype([N, K, n], jnp.float32)
    N = jnp.asarray(N, dtype)
    K = jnp.asarray(K, dtype)
    n = jnp.asarray(n, dtype)
    # Valid parameters require a non-degenerate population with K and n
    # inside it. (NOTE(review): K >= 0 and n >= 0 are not checked here —
    # presumably guaranteed by the caller; verify.)
    good_params_mask = (N >= 1.0) & (N >= K) & (n <= N)
    # Substitute benign dummy parameters (N=100, K=n=50) where invalid so
    # the sampler still runs; those entries are overwritten with nan below.
    N = jnp.where(good_params_mask, N, 100.0)
    K = jnp.where(good_params_mask, K, 50.0)
    n = jnp.where(good_params_mask, n, 50.0)
    # One leading sample axis over the broadcast parameter batch shape.
    sample_shape = ps.concat(
        [
            [num_samples],
            ps.broadcast_shape(
                ps.broadcast_shape(ps.shape(N), ps.shape(K)), ps.shape(n)
            ),
        ],
        axis=0,
    )

    # First Transform N, K, n such that
    # N / 2 >= K, N / 2 >= n
    # (hypergeometric symmetries: HG(N, N-K, n) and HG(N, K, N-n) relate
    # to HG(N, K, n) by linear maps, inverted at the end of the function).
    is_k_small = 0.5 * N >= K
    is_n_small = n <= 0.5 * N
    # Keep the untransformed values; needed to invert the maps below.
    previous_K = K
    previous_n = n
    K = jnp.where(is_k_small, K, N - K)
    n = jnp.where(is_n_small, n, N - n)

    # log of the (unnormalized) reciprocal hypergeometric pmf term:
    # log[ x! (K-x)! (n-x)! (N-K-n+x)! ] via lgamma.
    # TODO: Can we write this in a more numerically stable way?
    def _log_hypergeometric_coeff(x):
        return (
            jax.lax.lgamma(x + 1.0)
            + jax.lax.lgamma(K - x + 1.0)
            + jax.lax.lgamma(n - x + 1.0)
            + jax.lax.lgamma(N - K - n + x + 1.0)
        )

    # Constants of the hat/envelope function.
    p = K / N
    q = 1 - p
    # a: center of the proposal; roughly the mean n*p shifted by 0.5.
    a = n * p + 0.5
    # c: scale term ~ sqrt of the variance of the proposal.
    c = jnp.sqrt(2.0 * a * q * (1.0 - n / N))
    # k: the distribution's mode; g is the log-coefficient at the mode,
    # used as the normalizer in the acceptance test.
    k = jnp.floor((n + 1) * (K + 1) / (N + 2))
    g = _log_hypergeometric_coeff(k)
    # diff: support point where the hat is anchored; the jnp.where below
    # nudges it up by one when the pmf ratio test indicates the next
    # integer gives a tighter hat.
    diff = jnp.floor(a - c)
    x = (a - diff - 1) / (a - diff)
    diff = jnp.where(
        (n - diff) * (p - diff / N) * jnp.square(x)
        > (diff + 1.0) * (q - (n - diff - 1) / N),
        diff + 1.0,
        diff,
    )
    # h: width of the hat at the anchor point.
    # TODO: Can we write this difference of lgammas more numerically stably?
    h = (a - diff) * jnp.exp(
        0.5 * (g - _log_hypergeometric_coeff(diff)) + jnp.log(2.0)
    )
    # b: right truncation bound; valid samples lie in [0, b).
    b = jnp.minimum(jnp.minimum(n, K) + 1, jnp.floor(a + 5 * c))

    def generate_and_test_samples(seed):
        """One vectorized propose-and-test round for the Las Vegas loop."""
        v_seed, u_seed = samplers.split_seed(seed)
        U = samplers.uniform(sample_shape, dtype=dtype, seed=u_seed)
        V = samplers.uniform(sample_shape, dtype=dtype, seed=v_seed)
        # Guard against 0. (NOTE(review): `samplers.uniform` draws from
        # [0, 1), so 1 - U > 0 and the division below is safe — confirm.)
        X = a + h * (V - 0.5) / (1.0 - U)
        samples = jnp.floor(X)
        # Reject proposals outside the distribution's support.
        good_sample_mask = (samples >= 0.0) & (samples < b)
        # T = log pmf(sample) - log pmf(mode), always <= 0.
        T = g - _log_hypergeometric_coeff(samples)
        # Accept when (1-U)^2 <= pmf ratio, i.e. 2*log(1-U) <= T.
        # Uses slow pass since we are trying to do this in a vectorized way.
        good_sample_mask = good_sample_mask & (2 * jnp.log1p(-U) <= T)
        return samples, good_sample_mask

    # Re-propose per element until every position has an accepted sample.
    samples = brs.batched_las_vegas_algorithm(
        generate_and_test_samples, seed=seed
    )[0]
    # Invalid parameter combinations yield nan.
    samples = jnp.where(good_params_mask, samples, np.nan)
    # Now transform the samples depending on if we constrained N and / or k:
    #   both K and n flipped:  X = X' + K + n - N
    #   only K flipped:        X = n - X'
    #   only n flipped:        X = K - X'
    samples = jnp.where(
        ~is_k_small & ~is_n_small,
        samples + previous_K + previous_n - N,
        jnp.where(
            ~is_k_small,
            previous_n - samples,
            jnp.where(~is_n_small, previous_K - samples, samples),
        ),
    )
    return samples