Coverage for gemlib/distributions/kcategorical.py: 92%

37 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-12-02 22:09 +0000

1import jax 

2import jax.numpy as jnp 

3import tensorflow_probability.substrates.jax as tfp 

4from tensorflow_probability.substrates.jax.internal import ( 

5 parameter_properties, 

6 tensorshape_util, 

7) 

8 

9tfd = tfp.distributions 

10 

11 

12def _log_choose(N, k): # noqa: N803 

13 return ( 

14 jax.lax.lgamma(N + 1.0) 

15 - jax.lax.lgamma(k + 1.0) 

16 - jax.lax.lgamma(N - k + 1.0) 

17 ) 

18 

19 

class UniformKCategorical(tfd.Distribution):
    """Uniform distribution over size-``k`` subsets of the valid indices.

    A draw is a vector of ``k`` distinct integer indices chosen uniformly at
    random, without replacement, from the positions where ``mask`` is
    ``True``.
    """

    def __init__(
        self,
        k,
        mask,
        float_dtype=jnp.float32,
        validate_args=False,
        allow_nan_stats=True,
        name="UniformKCategorical",
    ):
        """Uniform K-Categorical distribution.

        Given a set of items indexed $0,...,n-1$ and a boolean mask of the
        same shape, sample $k$ of the valid indices without replacement.

        :param k: the number of indices to sample.  Must be a Python ``int``
            because it is used as a static size by ``jax.lax.top_k``.
        :param mask: a boolean mask with `True` where an element is valid,
            otherwise `False`.  Any leading dimensions are treated as batch
            dimensions; the last dimension indexes the items.
        :param float_dtype: floating-point dtype of the values returned by
            ``log_prob``
        :param validate_args: Whether to validate args
        :param allow_nan_stats: allow nan stats
        :param name: name of the distribution

        Example 1: Generate 4 samples of size k given a mask

            from jax import random
            from gemlib.distributions.kcategorical import UniformKCategorical

            # The mask determines which indices are valid and can be
            # returned by sample().  Below, combinations of the indices
            # 0, 3, 4, and 6 will be realised when sampling.
            mask = [True, False, False, True, True, False, True]
            X = UniformKCategorical(k=3, mask=mask)
            x = X.sample(4, seed=random.PRNGKey(0))  # TFP-on-JAX needs a seed
            print(x)

        Example 2: Probability of a given sample

            import jax.numpy as jnp
            from gemlib.distributions.kcategorical import UniformKCategorical

            # Every unordered sample of size k drawn from the N valid items
            # (N = number of True entries in the mask) has probability
            # 1 / C(N, k).
            mask = [True, False, False, True, True, False, True]
            sample = jnp.array([4, 3, 6])
            X = UniformKCategorical(k=sample.shape[-1], mask=mask)
            lp = X.log_prob(sample)
            print('prob:', jnp.exp(lp))  # 1 / C(4, 3) = 0.25
        """
        # Must be the first statement so it captures exactly the
        # constructor arguments.
        parameters = dict(locals())
        self._mask = jnp.array(mask)
        self._k = k
        self._float_dtype = float_dtype

        super().__init__(
            dtype=jnp.int32,
            validate_args=validate_args,
            # Samples are integer indices produced by `top_k`; no gradient
            # can flow from a sample back to the parameters.
            reparameterization_type=tfd.NOT_REPARAMETERIZED,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            name=name,
        )

    @classmethod
    def _parameter_properties(cls, dtype, num_classes=None):  # noqa: ARG003
        """Describe the constructor parameters to TFP.

        ``mask`` has one event dimension (the item axis) and no constraining
        bijector, because it is a boolean selector rather than a real-valued
        parameter.
        """
        return {
            "k": parameter_properties.ParameterProperties(),
            "mask": parameter_properties.ParameterProperties(
                default_constraining_bijector_fn=parameter_properties.BIJECTOR_NOT_IMPLEMENTED,
                event_ndims=1,
            ),
        }

    @property
    def k(self):
        """The number of indices drawn per sample, as passed to the constructor."""
        return self._parameters["k"]

    @property
    def mask(self):
        """The validity mask, exactly as passed to the constructor."""
        return self._parameters["mask"]

    def _batch_shape(self):
        # All mask dimensions except the last (item) axis are batch dims.
        return jnp.shape(self._mask)[:-1]

    def _event_shape(self):
        # Each event is a length-k vector of indices.
        return tensorshape_util.constant_value_as_shape(
            jnp.expand_dims(self._k, axis=0)
        )

    def _sample_n(self, n, seed=None):
        """Draw ``n`` size-``k`` subsets of the valid indices.

        Every item receives an independent Uniform(0, 1) key, and invalid
        items are forced to ``-inf``.  The indices of the ``k`` largest keys
        are then a uniformly random subset of the valid indices.
        """
        seed = tfp.random.sanitize_seed(seed, salt="KCategorical._sample_n")
        u = tfd.Uniform(
            low=jnp.zeros(self._mask.shape, dtype=jnp.float32),
            high=jnp.ones(self._mask.shape, dtype=jnp.float32),
        ).sample((n,), seed=seed)

        # Use the array-converted mask: the raw constructor argument may be
        # a Python list, which cannot be combined arithmetically with a JAX
        # array.  Masking with -inf (rather than multiplying, which maps
        # invalid items to 0) guarantees that an invalid index can never
        # tie with or outrank a valid key that happens to be exactly 0.
        keys = jnp.where(self._mask, u, -jnp.inf)

        _, x = jax.lax.top_k(keys, k=self._k)
        return x

    def _log_prob(self, x):
        """Log-probability of an unordered size-``k`` subset.

        Every size-``k`` subset of the ``N`` valid indices is equally likely,
        so the log-probability is ``-log C(N, k)`` whatever indices ``x``
        holds.  The result is broadcast over the sample dimensions of ``x``
        (all but its last axis), so ``log_prob`` returns one value per draw.

        NOTE(review): ``x`` is not checked to lie in the support (distinct,
        valid indices); out-of-support values also get ``-log C(N, k)``.
        """
        n_valid = jnp.count_nonzero(self._mask, axis=-1)
        lp = -_log_choose(n_valid, self._k).astype(self._float_dtype)
        out_shape = jnp.broadcast_shapes(jnp.shape(x)[:-1], jnp.shape(lp))
        return jnp.broadcast_to(lp, out_shape)