Coverage for gemlib/distributions/hypergeometric_sampler.py: 100%

1"""Hypergeometric sampling algorithm""" 

2 

3# ruff: noqa: N803 

4import jax 

5import jax.numpy as jnp 

6import numpy as np 

7import tensorflow_probability.substrates.jax as tfp 

8from tensorflow_probability.substrates.jax.internal import ( 

9 dtype_util, 

10 samplers, 

11) 

12 

13tfd = tfp.distributions 

14ps = tfp.internal.prefer_static 

15brs = tfp.internal.batched_rejection_sampler 

16 

17 

def sample_hypergeometric(num_samples, N, K, n, seed=None):
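    """Draw samples from a Hypergeometric(N, K, n) distribution.

    Documentation added for clarity; the parameter roles below are inferred
    from the validation and the mean/mode calculations in the body.

    Args:
        num_samples: Number of independent draws per parameter configuration.
        N: Total population size (requires N >= 1 and N >= K).
        K: Number of 'success' items in the population.
        n: Number of draws without replacement (requires n <= N).
        seed: PRNG seed understood by the TFP JAX substrate samplers.

    Returns:
        Float samples with shape [num_samples] followed by the broadcast of
        the parameter shapes. Entries for invalid parameter combinations are
        NaN.
    """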

    dtype = dtype_util.common_dtype([N, K, n], jnp.float32)
    N = jnp.asarray(N, dtype)
    K = jnp.asarray(K, dtype)
    n = jnp.asarray(n, dtype)
    # Mark invalid parameter combinations and substitute safe dummy values so
    # the sampler below still runs; the corresponding draws are replaced with
    # NaN at the end.
    good_params_mask = (N >= 1.0) & (N >= K) & (n <= N)
    N = jnp.where(good_params_mask, N, 100.0)
    K = jnp.where(good_params_mask, K, 50.0)
    n = jnp.where(good_params_mask, n, 50.0)
    # Draw shape: [num_samples] followed by the broadcast batch shape of the
    # parameters.
    sample_shape = ps.concat(
        [
            [num_samples],
            ps.broadcast_shape(
                ps.broadcast_shape(ps.shape(N), ps.shape(K)), ps.shape(n)
            ),
        ],
        axis=0,
    )

    # First transform K and n so that K <= N / 2 and n <= N / 2, replacing
    # them with their complements (N - K, N - n) where necessary.  The
    # original values are kept so the samples can be mapped back afterwards.
    is_k_small = 0.5 * N >= K
    is_n_small = n <= 0.5 * N
    previous_K = K
    previous_n = n
    K = jnp.where(is_k_small, K, N - K)
    n = jnp.where(is_n_small, n, N - n)

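    # Illustrative example (added commentary, not in the original source):
    # with N=10, K=8, n=7 neither condition holds, so the sampler works with
    # K'=2, n'=3 and draws x' ~ Hypergeometric(10, 2, 3).  The back-transform
    # at the end of this function then returns x = x' + 8 + 7 - 10, i.e. the
    # count of successes among the 7 original draws.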

    # _log_hypergeometric_coeff(x) is the log of the product of factorials
    # x! (K - x)! (n - x)! (N - K - n + x)!, i.e. the x-dependent part of
    # -log pmf(x) for the hypergeometric distribution.
    # TODO: Can we write this in a more numerically stable way?
    def _log_hypergeometric_coeff(x):
        return (
            jax.lax.lgamma(x + 1.0)
            + jax.lax.lgamma(K - x + 1.0)
            + jax.lax.lgamma(n - x + 1.0)
            + jax.lax.lgamma(N - K - n + x + 1.0)
        )

    # Constants for the rejection sampler.  The construction below appears to
    # follow the classic ratio-of-uniforms ("HRUA"-style) approach for
    # hypergeometric variates.
    p = K / N
    q = 1 - p
    # a is the mean n * K / N shifted by one half; c is roughly sqrt(2) times
    # the standard deviation; k is the mode.
    a = n * p + 0.5
    c = jnp.sqrt(2.0 * a * q * (1.0 - n / N))
    k = jnp.floor((n + 1) * (K + 1) / (N + 2))
    g = _log_hypergeometric_coeff(k)
    diff = jnp.floor(a - c)
    x = (a - diff - 1) / (a - diff)
    diff = jnp.where(
        (n - diff) * (p - diff / N) * jnp.square(x)
        > (diff + 1.0) * (q - (n - diff - 1) / N),
        diff + 1.0,
        diff,
    )
    # TODO: Can we write this difference of lgammas more numerically stably?
    # h sets the width of the proposal around a; b bounds the acceptable
    # sample values.
    h = (a - diff) * jnp.exp(
        0.5 * (g - _log_hypergeometric_coeff(diff)) + jnp.log(2.0)
    )
    b = jnp.minimum(jnp.minimum(n, K) + 1, jnp.floor(a + 5 * c))
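    # Sketch of the acceptance test used below (added commentary, not in the
    # original source): since
    #     log pmf(x) = const - _log_hypergeometric_coeff(x),
    # the quantity T = g - _log_hypergeometric_coeff(floor(X)) equals
    #     log( pmf(floor(X)) / pmf(k) ),
    # the log-probability of the proposal relative to the mode.  A proposal
    # X = a + h * (V - 0.5) / (1 - U) is then accepted when
    #     2 * log(1 - U) <= T,
    # which is essentially the ratio-of-uniforms acceptance condition.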


    def generate_and_test_samples(seed):
        v_seed, u_seed = samplers.split_seed(seed)
        U = samplers.uniform(sample_shape, dtype=dtype, seed=u_seed)
        V = samplers.uniform(sample_shape, dtype=dtype, seed=v_seed)
        # Guard against 0 in the denominator: U lies in [0, 1), so 1 - U > 0.
        X = a + h * (V - 0.5) / (1.0 - U)
        samples = jnp.floor(X)
        good_sample_mask = (samples >= 0.0) & (samples < b)
        T = g - _log_hypergeometric_coeff(samples)
        # Always evaluate the exact ("slow") acceptance test: the computation
        # is vectorised, so a fast-accept squeeze step would not save work.
        good_sample_mask = good_sample_mask & (2 * jnp.log1p(-U) <= T)
        return samples, good_sample_mask

    # batched_las_vegas_algorithm repeatedly proposes and tests batches until
    # every element has an accepted sample.
    samples = brs.batched_las_vegas_algorithm(
        generate_and_test_samples, seed=seed
    )[0]
    samples = jnp.where(good_params_mask, samples, np.nan)
    # Now undo the initial reflection, depending on whether K and / or n were
    # replaced by their complements.
    samples = jnp.where(
        ~is_k_small & ~is_n_small,
        samples + previous_K + previous_n - N,
        jnp.where(
            ~is_k_small,
            previous_n - samples,
            jnp.where(~is_n_small, previous_K - samples, samples),
        ),
    )
    return samples
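

# ---------------------------------------------------------------------------
# Usage sketch (added commentary, not part of the original module).  It
# assumes that `seed` may be a `jax.random.PRNGKey`, which is the form
# `samplers.split_seed` in the TFP JAX substrate normally expects, and that
# the parameters broadcast against each other:
#
#   import jax
#
#   draws = sample_hypergeometric(
#       num_samples=5,
#       N=jnp.array([500.0, 1000.0]),   # population sizes
#       K=jnp.array([50.0, 200.0]),     # 'success' items in each population
#       n=jnp.array([100.0, 300.0]),    # draws without replacement
#       seed=jax.random.PRNGKey(42),
#   )
#   # draws.shape == (5, 2); invalid parameter settings would yield NaN.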