Coverage for gemlib/distributions/kcategorical.py: 92%
37 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-12-02 22:09 +0000
« prev ^ index » next coverage.py v7.10.3, created at 2025-12-02 22:09 +0000
1import jax
2import jax.numpy as jnp
3import tensorflow_probability.substrates.jax as tfp
4from tensorflow_probability.substrates.jax.internal import (
5 parameter_properties,
6 tensorshape_util,
7)
9tfd = tfp.distributions
def _log_choose(N, k):  # noqa: N803
    """Return the log binomial coefficient ``log C(N, k)``.

    Evaluated with log-gamma functions as
    ``lgamma(N + 1) - lgamma(k + 1) - lgamma(N - k + 1)``.

    :param N: population size (scalar or array)
    :param k: number of items chosen (scalar or array broadcastable
        against ``N``)
    :returns: ``log(N! / (k! (N - k)!))``
    """
    # Adding the float literal 1.0 promotes integer inputs to floating
    # point before they reach `lgamma`.
    log_n_factorial = jax.lax.lgamma(N + 1.0)
    log_k_factorial = jax.lax.lgamma(k + 1.0)
    log_rest_factorial = jax.lax.lgamma(N - k + 1.0)
    return log_n_factorial - log_k_factorial - log_rest_factorial
class UniformKCategorical(tfd.Distribution):
    """Uniform distribution over size-``k`` subsets of unmasked indices."""

    def __init__(
        self,
        k,
        mask,
        float_dtype=jnp.float32,
        validate_args=False,
        allow_nan_stats=True,
        name="UniformKCategorical",
    ):
        """Uniform K-Categorical distribution.

        Given a set of items indexed $0,...,n-1$ and a boolean mask of the
        same shape, sample $k$ indices without replacement from the
        positions where the mask is `True`.

        :param k: the number of indices to sample
        :param mask: a boolean mask with `True` where an element is valid,
                     otherwise `False`
        :param float_dtype: floating point dtype of the values returned by
                            `log_prob`
        :param validate_args: Whether to validate args
        :param allow_nan_stats: allow nan stats
        :param name: name of the distribution

        Example 1: Generate 4 samples of size k given a mask

            import jax
            import jax.numpy as jnp
            from jax import random
            from gemlib.distributions.kcategorical import UniformKCategorical

            # Mask determines which indices are valid and returned by
            # sample().  Below, combinations of the indices 0, 3, 4, and 6
            # will be realised when sampling.
            mask = [True, False, False, True, True, False, True]
            X = UniformKCategorical(k=3, mask=mask)
            x = X.sample(4, seed=random.PRNGKey(0))
            print(x)

        Example 2: Probability of a given sample

            import jax
            import jax.numpy as jnp
            from jax import random
            from gemlib.distributions.kcategorical import UniformKCategorical

            # Probability of drawing an unordered sample of size k from N
            # items, where N is the number of True entries in the mask.
            mask = [True, False, False, True, True, False, True]
            sample = jnp.array([4, 3, 6])
            X = UniformKCategorical(k=sample.shape[-1], mask=mask)
            lp = X.log_prob(sample)
            print('prob:', jnp.exp(lp))

        """
        # Must stay the first statement so that only the constructor
        # arguments are captured.
        parameters = dict(locals())
        self._mask = jnp.array(mask)
        self._k = k
        self._float_dtype = float_dtype
        dtype = jnp.int32

        # NOTE(review): samples are integer indices, yet the distribution is
        # declared FULLY_REPARAMETERIZED; NOT_REPARAMETERIZED looks more
        # appropriate -- confirm before changing, callers may inspect it.
        super().__init__(
            dtype=dtype,
            validate_args=validate_args,
            reparameterization_type=tfd.FULLY_REPARAMETERIZED,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            name=name,
        )

    @classmethod
    def _parameter_properties(cls, dtype, num_classes=None):  # noqa: ARG003
        """Describe the parameters: scalar `k`, rank-1 event `mask`."""
        return {
            "k": parameter_properties.ParameterProperties(),
            "mask": parameter_properties.ParameterProperties(
                default_constraining_bijector_fn=parameter_properties.BIJECTOR_NOT_IMPLEMENTED,
                event_ndims=1,
            ),
        }

    @property
    def k(self):
        """Number of indices drawn per sample, as passed to the constructor."""
        return self._parameters["k"]

    @property
    def mask(self):
        """Validity mask, exactly as passed to the constructor."""
        return self._parameters["mask"]

    def _batch_shape(self):
        """Batch shape: every mask dimension except the last."""
        return jnp.shape(self._mask)[:-1]

    def _event_shape(self):
        """Event shape `[k]`: one entry per sampled index."""
        return tensorshape_util.constant_value_as_shape(
            jnp.expand_dims(self._k, axis=0)
        )

    def _sample_n(self, n, seed=None):
        """Draw `n` samples of `k` distinct indices.

        One uniform variate is drawn per mask position and the positions of
        the `k` largest variates along the last axis are returned.

        :param n: number of samples to draw
        :param seed: PRNG seed, sanitised with `tfp.random.sanitize_seed`
        :returns: integer indices of shape `(n,) + mask.shape[:-1] + (k,)`
        """
        seed = tfp.random.sanitize_seed(seed, salt="KCategorical._sample_n")
        u = tfd.Uniform(
            low=jnp.zeros(self._mask.shape, dtype=jnp.float32),
            high=jnp.ones(self._mask.shape, dtype=jnp.float32),
        ).sample((n,), seed=seed)

        # Use the array form of the mask: the raw constructor argument may
        # be a Python list, which JAX array operators do not accept.
        # Masked positions get a score strictly below every possible
        # uniform draw, so a valid position whose draw is exactly 0.0 can
        # never tie with (and lose to) a masked one.
        valid = jnp.asarray(self._mask, dtype=bool)
        u = jnp.where(valid, u, -1.0)

        # NOTE(review): if fewer than k positions are valid, top_k has to
        # return masked indices as well; nothing here guards against that.
        _, x = jax.lax.top_k(u, k=self._k)
        return x

    def _log_prob(self, _):
        """Log-probability of any size-`k` subset: `-log C(N, k)`.

        `N` is the number of non-zero mask entries along the last axis.

        NOTE(review): the sample argument is ignored, so membership of the
        sampled indices in the mask is not checked and the result has the
        batch shape of the mask rather than being broadcast over the
        sample's leading dimensions -- confirm callers expect this.
        """
        num_valid = jnp.count_nonzero(self._mask, axis=-1)
        return -_log_choose(num_valid, self._k).astype(self._float_dtype)