Coverage for gemlib/distributions/hypergeometric.py: 75%

65 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-12-02 22:09 +0000

1"""Hypergeometric random variable""" 

2 

3# ruff: noqa: N803, N802 

4import jax 

5import jax.numpy as jnp 

6import tensorflow_probability.substrates.jax as tfp 

7from tensorflow_probability.substrates.jax.internal import ( 

8 dtype_util, 

9 parameter_properties, 

10) 

11 

12from gemlib.distributions.hypergeometric_sampler import sample_hypergeometric 

13 

14tfd = tfp.distributions 

15 

16__all__ = ["Hypergeometric"] 

17 

18 

19def _log_factorial(x): 

20 """Computes x!""" 

21 return jax.lax.lgamma(x + 1.0) 

22 

23 

def _log_choose(n, k):
    """Return log(nCk), the log binomial coefficient, via log-factorials."""
    log_total = _log_factorial(n)
    log_chosen = _log_factorial(k)
    log_rest = _log_factorial(n - k)
    return log_total - log_chosen - log_rest

27 

28 

class Hypergeometric(tfd.Distribution):
    """Hypergeometric distribution.

    Models the number of "success" units observed in a sample of size ``n``
    drawn without replacement from a population of ``N`` units, ``K`` of
    which are successes.
    """

    def __init__(
        self,
        N,
        K,
        n,
        validate_args=False,
        allow_nan_stats=True,
        name="Hypergeometric",
    ):
        """Hypergeometric distribution

        Args:
        ----
            N: Population size
            K: number of units of interest in the population
            n: size of sample drawn from the population
            validate_args: should arguments be validated for correctness
            allow_nan_stats: allow NaN to be returned for mode, mean, variance...

        Raises:
        ------
            ValueError: if ``validate_args`` is True and any of N < 0,
                K > N, or n > N holds.

        """
        parameters = dict(locals())
        # Promote all parameters to a common floating dtype (at least float32).
        dtype = dtype_util.common_dtype([N, K, n], jnp.float32)
        self._N = jnp.asarray(N, dtype=dtype)
        self._K = jnp.asarray(K, dtype=dtype)
        self._n = jnp.asarray(n, dtype=dtype)
        # Mask of batch members with a non-empty population; _sample_n uses
        # it to force a 0 draw wherever N == 0.  Compare against the
        # canonical converted array rather than the raw ctor argument.
        self._n_positive_mask = self._N > 0.0
        super().__init__(
            dtype=dtype,
            reparameterization_type=tfd.FULLY_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            name=name,
        )
        if validate_args:
            # Bare `assert` is stripped under `python -O`, silently
            # disabling validation -- raise explicitly instead.
            if not jnp.all(self._N >= 0.0):
                raise ValueError("N must be non-negative")
            if not jnp.all(self._K <= self._N):
                raise ValueError("K must be <= N")
            if not jnp.all(self._n <= self._N):
                raise ValueError("n must be <= N")

    @classmethod
    def _parameter_properties(cls, dtype, num_classes=None):  # noqa: ARG003
        # NOTE(review): only `N` declares an explicit (not-implemented)
        # constraining bijector while `K` and `n` use defaults -- confirm
        # this asymmetry is intentional.
        return {
            "N": parameter_properties.ParameterProperties(
                default_constraining_bijector_fn=parameter_properties.BIJECTOR_NOT_IMPLEMENTED
            ),
            "K": parameter_properties.ParameterProperties(),
            "n": parameter_properties.ParameterProperties(),
        }

    @staticmethod
    def _param_shapes(sample_shape):
        # All three parameters share the given sample shape.
        shape = jnp.array(sample_shape, dtype=jnp.int32)
        return {"N": shape, "K": shape, "n": shape}

    @classmethod
    def _params_event_ndims(cls):
        # Every parameter has a scalar (rank-0) event.
        return {"N": 0, "K": 0, "n": 0}

    @property
    def N(self):
        """Population size"""
        return self._parameters["N"]

    @property
    def K(self):
        """Number of units of interest in population"""
        return self._parameters["K"]

    @property
    def n(self):
        """Sample size"""
        return self._parameters["n"]

    def _default_event_space_bijector(self):
        # Discrete support: no default bijector (implicitly returns None).
        return

    def _event_shape_tensor(self):
        return jnp.array([], dtype=jnp.int32)

    def _event_shape(self):
        return ()

    def _sample_n(self, n, seed=None):
        """Draw `n` samples; batch members with N == 0 are forced to 0."""
        sample = sample_hypergeometric(n, self.N, self.K, self.n, seed=seed)
        return jnp.where(self._n_positive_mask, sample, 0.0)

    def _log_prob(self, x):
        """log P(X=x) = log[C(K,x) * C(N-K, n-x) / C(N,n)].

        NOTE(review): out-of-support ``x`` is not masked to -inf here;
        lgamma of negative arguments can return finite values -- confirm
        callers only evaluate in-support points.
        """
        numerator = _log_choose(self._K, x) + _log_choose(
            self._N - self._K, self._n - x
        )
        denominator = _log_choose(self._N, self._n)
        return numerator - denominator

    def _mode(self):
        # Standard hypergeometric mode: floor((n+1)(K+1) / (N+2)).
        return jnp.floor((self._n + 1) * (self._K + 1) / (self._N + 2))

    def _mean(self):
        # E[X] = n * K / N.
        return self._n * self._K / self._N

    def _variance(self):
        n = self._n
        N = self._N
        K = self._K

        # Binomial variance n*p*(1-p) with p = K/N, times the
        # finite-population correction factor (N - n)/(N - 1).
        return n * (K / N) * (N - K) / N * (N - n) / (N - 1)