Coverage for gemlib/distributions/uniform_integer.py: 96% (45 statements)

coverage.py v7.10.3, created at 2025-12-02 22:09 +0000

"""The UniformInteger distribution class"""

import jax.numpy as jnp
import tensorflow_probability.substrates.jax as tfp
from tensorflow_probability.substrates.jax.internal import (
    dtype_util,
    parameter_properties,
    samplers,
)

from gemlib.prng_util import sanitize_key

tfd = tfp.distributions


class UniformInteger(tfd.Distribution):
    def __init__(
        self,
        low,
        high,
        validate_args=False,
        allow_nan_stats=True,
        float_dtype=jnp.float32,
        name="UniformInteger",
    ):
        """Integer uniform distribution.

        Args:
        ----
            low: Integer tensor, lower boundary of the output interval. Must
                have `low < high`.
            high: Integer tensor, _exclusive_ upper boundary of the output
                interval. Must have `low < high`.
            validate_args: Python `bool`, default `False`. When `True`,
                distribution parameters are checked for validity despite
                possibly degrading runtime performance. When `False`, invalid
                inputs may silently render incorrect outputs.
            allow_nan_stats: Python `bool`, default `True`. When `True`,
                statistics (e.g., mean, mode, variance) use the value "`NaN`"
                to indicate the result is undefined. When `False`, an
                exception is raised if one or more of the statistic's batch
                members are undefined.
            float_dtype: returned float dtype of log probability.
            name: Python `str` name prefixed to Ops created by this class.

        The integer dtype of samples is inferred from `low` and `high`.

        Example 1: sampling
        ```python
        import jax.numpy as jnp
        import jax
        from gemlib.distributions.uniform_integer import UniformInteger

        key = jax.random.key(10402302)
        X = UniformInteger(0, 10)
        x = X.sample([3, 3], seed=key)
        print("samples:", x, "=", [[8, 4, 8], [2, 7, 9], [6, 0, 9]])
        ```

        Example 2: log probability
        ```python
        import jax.numpy as jnp
        from gemlib.distributions.uniform_integer import UniformInteger

        X = UniformInteger(0, 10)
        lp = X.log_prob(jnp.array([[8, 4, 8], [2, 7, 9], [6, 0, 9]]))
        total_lp = jnp.round(jnp.sum(lp) * 1e5) / 1e5
        print("total lp:", total_lp, "= -20.72327")
        ```

        Raises:
        ------
            AssertionError if `low >= high` when `validate_args=True`.

        """
        parameters = dict(locals())
        dtype = dtype_util.common_dtype([low, high], jnp.int32)
        self._float_dtype = float_dtype
        self._low = jnp.asarray(low, dtype=dtype)
        self._high = jnp.asarray(high, dtype=dtype)

        super().__init__(
            dtype=dtype,
            reparameterization_type=tfd.FULLY_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            name=name,
        )
        if validate_args:
            assert jnp.all(self._low < self._high), (
                "Condition low < high failed"
            )

    @classmethod
    def _parameter_properties(cls, dtype, num_classes=None):  # noqa: ARG003
        return {
            "low": parameter_properties.ParameterProperties(
                default_constraining_bijector_fn=parameter_properties.BIJECTOR_NOT_IMPLEMENTED,
            ),
            "high": parameter_properties.ParameterProperties(
                default_constraining_bijector_fn=parameter_properties.BIJECTOR_NOT_IMPLEMENTED,
            ),
        }

    @property
    def low(self):
        """Lower boundary of the output interval."""
        return self._low

    @property
    def high(self):
        """Upper boundary of the output interval."""
        return self._high

    @property
    def float_dtype(self):
        """Float dtype of the returned log probability."""
        return self._float_dtype

    def _event_shape_tensor(self):
        return jnp.array([], dtype=jnp.int32)

    def _event_shape(self):
        return ()

    def _sample_n(self, n, seed=None):
        seed = sanitize_key(seed)
        low = self.low
        high = self.high
        shape = (n,) + self._batch_shape_tensor(low=low, high=high)

        # Draw uniform floats in [0, 1), scale to [0, high - low), and floor
        # to obtain integers in [low, high).
        samples = samplers.uniform(shape=shape, dtype=jnp.float32, seed=seed)
        return low + jnp.floor((high - low) * samples).astype(low.dtype)

    def _prob(self, x):
        low = jnp.asarray(self.low, self.float_dtype)
        high = jnp.asarray(self.high, self.float_dtype)
        x = jnp.asarray(x, dtype=self.float_dtype)

        # NaN inputs propagate; values outside [low, high) have probability
        # zero, and values inside have probability 1 / (high - low).
        return jnp.where(
            jnp.isnan(x),
            x,
            jnp.where(
                (x < low) | (x >= high),
                jnp.zeros_like(x),
                jnp.ones_like(x) / (high - low),
            ),
        )

    def _log_prob(self, x):
        return jnp.log(self._prob(x))
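
A brief usage sketch (not part of the module above) illustrating the half-open support implemented by `_sample_n` and `_prob`: samples lie in `[low, high)`, points inside the interval have probability `1 / (high - low)`, and points outside have probability zero. The seed and sample size below are arbitrary illustrative choices.

```python
import jax
import jax.numpy as jnp

from gemlib.distributions.uniform_integer import UniformInteger

# Values in {0, ..., 9}, each with probability 1/10.
X = UniformInteger(0, 10)

# Every draw lies in [0, 10).
samples = X.sample([1000], seed=jax.random.key(0))
assert jnp.all((samples >= 0) & (samples < 10))

# log_prob is log(1/10) inside the support and -inf outside it.
print(X.log_prob(jnp.array([0, 5, 9])))  # approx. [-2.3026, -2.3026, -2.3026]
print(X.log_prob(jnp.array([10, -1])))   # [-inf, -inf]
```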