Coverage for gemlib/distributions/hypergeometric.py: 75%
65 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-12-02 22:09 +0000
1"""Hypergeometric random variable"""
3# ruff: noqa: N803, N802
4import jax
5import jax.numpy as jnp
6import tensorflow_probability.substrates.jax as tfp
7from tensorflow_probability.substrates.jax.internal import (
8 dtype_util,
9 parameter_properties,
10)
12from gemlib.distributions.hypergeometric_sampler import sample_hypergeometric
14tfd = tfp.distributions
16__all__ = ["Hypergeometric"]
19def _log_factorial(x):
20 """Computes x!"""
21 return jax.lax.lgamma(x + 1.0)
24def _log_choose(n, k):
25 """Computes nCk"""
26 return _log_factorial(n) - _log_factorial(k) - _log_factorial(n - k)
class Hypergeometric(tfd.Distribution):
    """Hypergeometric distribution.

    Models the number of units of interest observed when drawing `n` units
    without replacement from a population of size `N` containing `K` units
    of interest.
    """

    def __init__(
        self,
        N,
        K,
        n,
        validate_args=False,
        allow_nan_stats=True,
        name="Hypergeometric",
    ):
        """Hypergeometric distribution

        Args:
        ----
        N: Population size
        K: number of units of interest in the population
        n: size of sample drawn from the population
        validate_args: should arguments be validated for correctness
        allow_nan_stats: allow NaN to be returned for mode, mean, variance...
        name: name of the distribution

        Raises:
        ------
        ValueError: if `validate_args` is True and any of N < 0, K > N,
            or n > N.
        """
        parameters = dict(locals())
        dtype = dtype_util.common_dtype([N, K, n], jnp.float32)
        self._N = jnp.asarray(N, dtype=dtype)
        self._K = jnp.asarray(K, dtype=dtype)
        self._n = jnp.asarray(n, dtype=dtype)
        # True where the population is non-degenerate (N > 0); used by
        # _sample_n to force samples from empty populations to 0.
        # Use the converted array (not the raw argument) so list/tuple
        # inputs for N are also supported.
        self._n_positive_mask = self._N > 0.0
        super().__init__(
            dtype=dtype,
            # NOTE(review): a discrete distribution is normally
            # NOT_REPARAMETERIZED; left unchanged so callers relying on
            # the current gradient behaviour are unaffected -- confirm.
            reparameterization_type=tfd.FULLY_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            name=name,
        )
        if validate_args:
            # Raise rather than assert: asserts are stripped under `-O`.
            if not jnp.all(self._N >= 0.0):
                raise ValueError("N must be non-negative")
            if not jnp.all(self._K <= self._N):
                raise ValueError("K must be <= N")
            if not jnp.all(self._n <= self._N):
                raise ValueError("n must be <= N")

    @classmethod
    def _parameter_properties(cls, dtype, num_classes=None):  # noqa: ARG003
        # No constraining bijectors are provided for any parameter.
        return {
            "N": parameter_properties.ParameterProperties(
                default_constraining_bijector_fn=parameter_properties.BIJECTOR_NOT_IMPLEMENTED
            ),
            "K": parameter_properties.ParameterProperties(),
            "n": parameter_properties.ParameterProperties(),
        }

    @staticmethod
    def _param_shapes(sample_shape):
        # All three parameters share the sample shape.
        return dict(
            zip(
                ("N", "K", "n"),
                ([jnp.array(sample_shape, dtype=jnp.int32)] * 3),
                strict=False,
            )
        )

    @classmethod
    def _params_event_ndims(cls):
        # All parameters are scalar per event.
        return {"N": 0, "K": 0, "n": 0}

    @property
    def N(self):
        """Population size"""
        return self._parameters["N"]

    @property
    def K(self):
        """Number of units of interest in population"""
        return self._parameters["K"]

    @property
    def n(self):
        """Sample size"""
        return self._parameters["n"]

    def _default_event_space_bijector(self):
        # Discrete support: no event-space bijector is defined.
        return None

    def _event_shape_tensor(self):
        return jnp.array([], dtype=jnp.int32)

    def _event_shape(self):
        # Scalar event shape.
        return ()

    def _sample_n(self, n, seed=None):
        """Draw `n` samples; returns 0 for batch members with N <= 0."""
        sample = sample_hypergeometric(n, self.N, self.K, self.n, seed=seed)
        sample = jnp.where(self._n_positive_mask, sample, 0.0)
        return sample

    def _log_prob(self, x):
        """Log PMF: log[ C(K, x) * C(N-K, n-x) / C(N, n) ]."""
        numerator = _log_choose(self._K, x) + _log_choose(
            self._N - self._K, self._n - x
        )
        denominator = _log_choose(self._N, self._n)
        return numerator - denominator

    def _mode(self):
        # floor((n + 1)(K + 1) / (N + 2)) -- standard hypergeometric mode.
        return jnp.floor((self._n + 1) * (self._K + 1) / (self._N + 2))

    def _mean(self):
        # E[X] = n * K / N
        return self._n * self._K / self._N

    def _variance(self):
        # Var[X] = n * (K/N) * ((N-K)/N) * ((N-n)/(N-1)),
        # the binomial variance times the finite-population correction.
        n = self._n
        N = self._N
        K = self._K

        return n * (K / N) * (N - K) / N * (N - n) / (N - 1)