Coverage for gemlib/distributions/uniform_integer.py: 96%
45 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-12-02 22:09 +0000
« prev ^ index » next coverage.py v7.10.3, created at 2025-12-02 22:09 +0000
1"""The UniformInteger distribution class"""
3import jax.numpy as jnp
4import tensorflow_probability.substrates.jax as tfp
5from tensorflow_probability.substrates.jax.internal import (
6 dtype_util,
7 parameter_properties,
8 samplers,
9)
11from gemlib.prng_util import sanitize_key
13tfd = tfp.distributions
16class UniformInteger(tfd.Distribution):
17 def __init__(
18 self,
19 low,
20 high,
21 validate_args=False,
22 allow_nan_stats=True,
23 float_dtype=jnp.float32,
24 name="UniformInteger",
25 ):
26 """Integer uniform distribution.
28 Args:
29 ----
30 low: Integer tensor, lower boundary of the output interval. Must have
31 `low <= high`.
32 high: Integer tensor, _inclusive_ upper boundary of the output
33 interval. Must have `low <= high`.
34 validate_args: Python `bool`, default `False`. When `True`
35 distribution parameters are checked for validity despite possibly
36 degrading runtime performance. When `False` invalid inputs may
37 silently render incorrect outputs.
38 allow_nan_stats: Python `bool`, default `True`. When `True`,
39 statistics (e.g., mean, mode, variance) use the value "`NaN`" to
40 indicate the result is undefined. When `False`, an exception is
41 raised if one or more of the statistic's batch members are undefined.
42 dtype: returned integer dtype when sampling.
43 float_dtype: returned float dtype of log probability.
44 name: Python `str` name prefixed to Ops created by this class.
46 Example 1: sampling
47 ```python
48 import jax.numpy as jnp
49 import jax
50 from gemlib.distributions.uniform_integer import UniformInteger
52 key = jax.random.key(10402302)
53 X = UniformInteger(0, 10, dtype=jnp.int32)
54 x = X.sample([3, 3], seed=key)
55 print("samples:", x, "=", [[8, 4, 8], [2, 7, 9], [6, 0, 9]])
56 ```
58 Example 2: log probability
59 ```python
60 import jax.numpy as jnp
61 from gemlib.distributions.uniform_integer import UniformInteger
63 X = UniformInteger(0, 10)
64 lp = X.log_prob(jnp.array([[8, 4, 8], [2, 7, 9], [6, 0, 9]]))
65 total_lp = jnp.round(tf.math.reduce_sum(lp) * 1e5) / 1e5
66 print("total lp:", total_lp, "= -20.72327")
67 ```
69 Raises:
70 ------
71 InvalidArgument if `low > high` and `validate_args=False`.
73 """
74 parameters = dict(locals())
75 dtype = dtype_util.common_dtype([low, high], jnp.int32)
76 self._float_dtype = float_dtype
77 self._low = jnp.asarray(low, dtype=dtype)
78 self._high = jnp.asarray(high, dtype=dtype)
80 super().__init__(
81 dtype=dtype,
82 reparameterization_type=tfd.FULLY_REPARAMETERIZED,
83 validate_args=validate_args,
84 allow_nan_stats=allow_nan_stats,
85 parameters=parameters,
86 name=name,
87 )
88 if validate_args is True:
89 assert jnp.all(self._low < self._high), (
90 "Condition low < high failed"
91 )
93 @classmethod
94 def _parameter_properties(cls, dtype, num_classes=None): # noqa: ARG003
95 return {
96 "low": parameter_properties.ParameterProperties(
97 default_constraining_bijector_fn=parameter_properties.BIJECTOR_NOT_IMPLEMENTED,
98 ),
99 "high": parameter_properties.ParameterProperties(
100 default_constraining_bijector_fn=parameter_properties.BIJECTOR_NOT_IMPLEMENTED,
101 ),
102 }
104 @property
105 def low(self):
106 """Lower boundary of the output interval."""
107 return self._low
109 @property
110 def high(self):
111 """Upper boundary of the output interval."""
112 return self._high
114 @property
115 def float_dtype(self):
116 return self._float_dtype
118 def _event_shape_tensor(self):
119 return jnp.array([], dtype=jnp.int32)
121 def _event_shape(self):
122 return ()
124 def _sample_n(self, n, seed=None):
125 seed = sanitize_key(seed)
126 low = self.low
127 high = self.high
128 shape = (n,) + self._batch_shape_tensor(low=low, high=high)
130 samples = samplers.uniform(shape=shape, dtype=jnp.float32, seed=seed)
131 return low + jnp.floor((high - low) * samples).astype(low.dtype)
133 def _prob(self, x):
134 low = jnp.asarray(self.low, self.float_dtype)
135 high = jnp.asarray(self.high, self.float_dtype)
136 x = jnp.asarray(x, dtype=self.float_dtype)
138 return jnp.where(
139 jnp.isnan(x),
140 x,
141 jnp.where(
142 (x < low) | (x >= high),
143 jnp.zeros_like(x),
144 jnp.ones_like(x) / (high - low),
145 ),
146 )
148 def _log_prob(self, x):
149 return jnp.log(self._prob(x))