Coverage for gemlib/distributions/discrete_time_state_transition_model.py: 96%
84 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-12-02 22:09 +0000
« prev ^ index » next coverage.py v7.10.3, created at 2025-12-02 22:09 +0000
1"""Describes a DiscreteTimeStateTransitionModel."""
3from __future__ import annotations
5from collections.abc import Callable, Sequence
7import jax
8import jax.numpy as jnp
9import numpy as np
10import tensorflow_probability.substrates.jax as tfp
12from gemlib.distributions.discrete_markov import (
13 compute_state,
14 discrete_markov_log_prob,
15 discrete_markov_simulation,
16 make_transition_prob_matrix_fn,
17)
18from gemlib.func_util import maybe_combine_fn
19from gemlib.math import cumsum_np
20from gemlib.prng_util import sanitize_key
21from gemlib.tensor_util import broadcast_fn_to
22from gemlib.util import (
23 batch_gather,
24 transition_coords,
25)
27Array = jax.Array
28ArrayLike = jax.typing.ArrayLike
29tfd = tfp.distributions
class DiscreteTimeStateTransitionModel(tfd.Distribution):
    """Discrete-time state transition model.

    A discrete-time state transition model assumes a population of
    individuals is divided into a number of mutually exclusive states,
    where transitions between states occur according to a Markov process.
    Such models are commonly found in epidemiological and ecological
    applications, where rapid implementation and modification is necessary.

    This class provides a programmable implementation of the discrete-time
    state transition model, compatible with TensorFlow Probability.

    Example
    -------
    A homogeneously mixing SIR model implementation::

        import jax.numpy as jnp
        import numpy as np
        from gemlib.distributions import DiscreteTimeStateTransitionModel

        # Initial state, counts per compartment (S, I, R), for one
        # population
        initial_state = jnp.array([[99, 1, 0]], jnp.float32)

        # Note that the incidence_matrix is treated as a static parameter,
        # and so is supplied as a numpy ndarray, not a jax Array.
        incidence_matrix = np.array(
            [
                [-1, 0],
                [1, -1],
                [0, 1],
            ],
            dtype=np.float32,
        )


        def si_rate(t, state):
            return 0.28 * state[:, 1] / jnp.sum(state, axis=-1)


        def ir_rate(t, state):
            return 0.14


        # Instantiate model
        sir = DiscreteTimeStateTransitionModel(
            transition_rate_fn=(si_rate, ir_rate),
            incidence_matrix=incidence_matrix,
            initial_state=initial_state,
            num_steps=100,
        )

        # One realisation of the epidemic process
        sim = sir.sample(seed=0)
    """

    def __init__(
        self,
        transition_rate_fn: Sequence[Callable[[float, ArrayLike], ArrayLike]]
        | Callable[[float, ArrayLike], tuple[ArrayLike, ...]],
        incidence_matrix: np.typing.NDArray,
        initial_state: ArrayLike,
        num_steps: int,
        initial_step: int = 0,
        time_delta: float = 1.0,
        validate_args: bool = False,
        allow_nan_stats: bool = True,
        name: str = "DiscreteTimeStateTransitionModel",
    ):
        """Initialise a discrete-time state transition model.

        Args:
            transition_rate_fn: Either a sequence of ``R`` callables
                ``(t: float, state: ArrayLike) -> ArrayLike``, each
                returning an array of rates of shape ``(N,)``, or a callable
                ``(t: float, state: ArrayLike) -> tuple[ArrayLike, ...]``,
                returning a tuple of ``R`` vectors of shape ``(N,)``.
                Here ``N`` is the number of units in the simulation.
                In the first (preferred) form, each callable returns the
                respective transition rate.
                In the second form, the single callable returns ``R``
                transition rates.
                **Note**: the second form will be deprecated in future
                releases of ``gemlib``.
            incidence_matrix: incidence matrix of shape ``(S, R)``
                for ``S`` states and ``R`` transitions between states.
            initial_state: Initial state of the model, of shape
                ``(N, S)`` containing the counts in each of ``N`` units and
                ``S`` states.  We require
                ``initial_state.shape[-1] == incidence_matrix.shape[-2]``.
            num_steps: the number of time steps simulated by the model.
            initial_step: time ``t`` at the start of the simulation.
            time_delta: the duration of the discretized time step.
            validate_args: forwarded to ``tfd.Distribution``.
            allow_nan_stats: forwarded to ``tfd.Distribution``.
            name: name of the distribution, forwarded to
                ``tfd.Distribution``.
        """
        # Must remain the first statement: it snapshots the constructor
        # arguments before any other local names are bound.
        parameters = dict(locals())

        # The incidence matrix is held as a (static) NumPy array.
        self._incidence_matrix = np.asarray(incidence_matrix)

        self._source_states = _compute_source_states(self._incidence_matrix)

        initial_state = jnp.asarray(initial_state)
        # The rate function(s) are combined into one callable and broadcast
        # to the leading (unit) dimensions of the initial state.
        self._transition_prob_matrix_fn = make_transition_prob_matrix_fn(
            broadcast_fn_to(
                maybe_combine_fn(transition_rate_fn),
                jnp.shape(initial_state)[:-1],
            ),
            time_delta,
            self._incidence_matrix,
        )
        # NOTE(review): samples are event counts -- confirm that
        # FULLY_REPARAMETERIZED is the intended reparameterization type.
        super().__init__(
            dtype=initial_state.dtype,
            reparameterization_type=tfd.FULLY_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            name=name,
        )

    @property
    def transition_rate_fn(self):
        """Transition rate function(s), as passed to the constructor."""
        return self._parameters["transition_rate_fn"]

    @property
    def incidence_matrix(self):
        """Incidence matrix, as passed to the constructor."""
        return self._parameters["incidence_matrix"]

    @property
    def initial_state(self):
        """Initial state, as passed to the constructor."""
        return self._parameters["initial_state"]

    @property
    def initial_step(self):
        """Time at the start of the simulation."""
        return self._parameters["initial_step"]

    @property
    def source_states(self):
        """Source-state index of each transition (from the incidence matrix)."""
        return self._source_states

    @property
    def time_delta(self):
        """Duration of one discretized time step."""
        return self._parameters["time_delta"]

    @property
    def num_steps(self):
        """Number of time steps simulated by the model."""
        return self._parameters["num_steps"]

    @property
    def num_units(self):
        """Number of units ``N`` (second-to-last axis of the initial state)."""
        # jnp.shape reads the shape without materialising a new array.
        return jnp.shape(self._parameters["initial_state"])[-2]

    @property
    def num_states(self):
        """Number of states ``S`` (last axis of the initial state)."""
        return jnp.shape(self._parameters["initial_state"])[-1]

    def _batch_shape(self):
        return []

    def _event_shape(self):
        shape = (
            int(self.num_steps),  # T
            jnp.shape(self.initial_state)[-2],  # N
            jnp.shape(self.incidence_matrix)[-1],  # R
        )
        return shape

    def compute_state(
        self, events: Array, include_final_state: bool = False
    ) -> jax.Array:
        """Computes a state timeseries from a sequence of transition events.

        Args:
            events: an array of events of shape ``(T, N, R)`` where
                ``T`` is the number of timesteps,
                ``N=self.num_units`` the number of units and
                ``R=self.incidence_matrix.shape[1]``
                the number of transitions.
            include_final_state: If ``False`` (default)
                the result does not include the final state.
                Otherwise the result includes the final state.

        Returns:
            An array of shape ``(T, N, S)`` if
            ``include_final_state==False`` or ``(T+1, N, S)``
            if ``include_final_state==True``, where ``S=self.num_states``,
            giving the number of individuals in each state at each time
            point for each unit.
        """
        return compute_state(
            incidence_matrix=self._incidence_matrix,
            initial_state=self.initial_state,
            events=events,
            closed=include_final_state,
        )

    def transition_prob_matrix(self, events: None | ArrayLike) -> Array:
        """Compute the Markov transition probability matrix.

        Args:
            events: None, or an array of shape ``(T, N, R)``, where
                ``T`` is the number of timesteps,
                ``N=self.num_units`` the number of units and
                ``R=self.incidence_matrix.shape[1]``
                the number of transitions.

        Returns:
            Transition probability matrix.  If ``events`` is None, this is
            the matrix of shape ``(N, S, S)`` associated with the initial
            state (``self.initial_state``, of shape ``(N, S)``), where
            ``S=self.num_states`` is the number of distinct states.
            Otherwise, it is of shape ``(T, N, S, S)``, representing the
            transition probability matrix at each timestep.
        """
        if events is None:
            return self._transition_prob_matrix_fn(
                self.initial_step, self.initial_state
            )

        state = self.compute_state(events)
        # Build the time grid from an integer range so that it always has
        # exactly ``num_steps`` entries.  A floating-point
        # ``arange(start, stop, step)`` can gain or lose an element through
        # rounding when ``time_delta`` is not an integer, which would break
        # the vmap below with mismatched leading axes.
        times = self.initial_step + self.time_delta * jnp.arange(
            self.num_steps
        )

        return jax.vmap(self._transition_prob_matrix_fn)(times, state)

    def _sample_n(self, n: int, seed: int | Array | None = None) -> Array:
        """Runs `n` simulations of the epidemic model.

        Args:
            n: number of simulations to run.
            seed: an integer or a JAX PRNG key.

        Returns:
            an array of shape ``(n, T, N, R)``
            for ``n`` samples, ``T=self.num_steps`` time steps,
            ``N=self.num_units`` units and
            ``R=self.incidence_matrix.shape[1]`` transitions, containing the
            number of events for each timestep, unit and transition.
        """
        key = sanitize_key(seed)
        # NOTE(review): the second sanitize_key call looks redundant --
        # confirm it is idempotent before removing it.
        keys = jax.random.split(sanitize_key(key), num=n)

        def one_sample(key):
            # Only the second element of the simulation result is used.
            _, events = discrete_markov_simulation(
                transition_prob_matrix_fn=self._transition_prob_matrix_fn,
                state=self.initial_state,
                start=self.initial_step,
                end=self.initial_step + self.num_steps * self.time_delta,
                time_step=self.time_delta,
                seed=key,
            )
            return events

        # One independent simulation per PRNG key.
        sim = jax.vmap(one_sample)(keys)
        # Presumably the simulation yields events indexed by
        # (source, destination) state, from which the entries belonging to
        # the R modelled transitions are gathered -- confirm against
        # gemlib.util.
        indices = transition_coords(self._incidence_matrix)

        return batch_gather(sim, indices)

    def _log_prob(self, y):
        """Log probability of event array(s) ``y``.

        The trailing three axes of ``y`` are treated as one event; any
        leading axes are treated as batch axes and are preserved in the
        result.
        """
        y = jnp.asarray(y)

        # Collapse all leading batch axes into one so a single vmap suffices.
        batch_shape = y.shape[:-3]
        y_flat = y.reshape((-1,) + y.shape[-3:])

        def one_log_prob(y):
            return discrete_markov_log_prob(
                events=y,
                init_state=jnp.asarray(self.initial_state, dtype=y.dtype),
                init_step=jnp.asarray(self.initial_step, dtype=y.dtype),
                time_delta=jnp.asarray(self.time_delta, dtype=y.dtype),
                transition_prob_matrix_fn=self._transition_prob_matrix_fn,
                incidence_matrix=self._incidence_matrix,
            )

        log_probs = jax.vmap(one_log_prob)(y_flat)
        # Restore the original batch axes.
        return log_probs.reshape(batch_shape)
# NOTE: it is unclear why this does not simply reuse
# ``gemlib.util.transition_coords``.
def _compute_source_states(
    incidence_matrix: np.ndarray, dtype=np.int32
) -> np.ndarray:
    """Find the index of the source state of every transition.

    Args:
        incidence_matrix: NumPy array of shape ``(S, R)`` for ``S`` states
            and ``R`` transitions between states.
        dtype: dtype of the returned index array.

    Returns:
        A NumPy array of shape ``(R,)`` containing the index of the source
        state of each transition.
    """
    # One row per transition.  Negating and clipping to [0, 1] keeps only
    # the entries that the transition decrements (its source state).
    outflow = np.clip(-np.transpose(incidence_matrix), a_min=0, a_max=1)

    # Assuming ``cumsum_np`` follows the usual reverse/exclusive cumulative
    # sum semantics, a single marker at position ``k`` contributes 1 to each
    # of the ``k`` positions before it, so the row total equals ``k``.
    trailing = cumsum_np(outflow, axis=-1, reverse=True, exclusive=True)

    return np.sum(trailing, axis=-1).astype(dtype)