Coverage for gemlib/distributions/discrete_markov.py: 99%
107 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-12-02 22:09 +0000
« prev ^ index » next coverage.py v7.10.3, created at 2025-12-02 22:09 +0000
1"""Functions for chain binomial simulation."""
3from collections.abc import Callable, Iterable
5import jax
6import jax.numpy as jnp
7import numpy as np
8import tensorflow_probability.substrates.jax as tfp
10from gemlib.math import multiply_no_nan
11from gemlib.util import transition_coords_tuple
# Short alias for the TensorFlow Probability (JAX substrate) distributions
# namespace; this module uses it to construct ``tfd.Binomial``.
tfd = tfp.distributions

# Public API of this module. ``chain_binomial_propagate`` and the
# underscore-prefixed helpers are not listed here.
__all__ = [
    "make_transition_prob_matrix_fn",
    "compute_state",
    "discrete_markov_log_prob",
    "discrete_markov_simulation",
]
23# scatter to transition matrix
24def _scatter_to_transition_matrix(
25 rates: Iterable[jax.Array],
26 rate_coords: tuple[tuple[int, int], ...],
27 num_states: int,
28) -> jax.Array:
29 """Build an un-normalised Markov transition rate matrix from rates and
30 coordinates
32 Args:
33 rates (Iterable[jax.Array]): An iterable (tuple, list, ...) of ``R``
34 jax.Arrays containing transition rates.
35 rate_coords (tuple[tuple[int, int],...]): A tuple of
36 ``(from_state, to_state)`` coordinate pairs for each of the ``R``
37 rates.
38 num_states (int): The number of possible states ``S``.
40 Returns:
41 jax.Array: An array of shape ``rates[0].shape + (S, S)`` representing
42 the un-normalised transition rate matrix.
43 """
45 # The each shape is (...,S) for S compartments
46 # We're going to scatter the rates into a (S,S,...) data structure
47 # to avoid strided scatter
48 matrix_shape = (num_states, num_states) + rates[0].shape
50 output = jnp.zeros(matrix_shape, dtype=rates[0].dtype)
52 # We iterate over each rate vector and coordinate pair, inserting the
53 # transition rates into the coord-th slice of the right-hand two "S"
54 # dimensions of ``output``.rate_coords
55 for rate, coord in zip(rates, rate_coords, strict=True):
56 output = output.at[coord[-2], coord[-1], ...].set(
57 rate,
58 indices_are_sorted=True,
59 unique_indices=True,
60 mode="promise_in_bounds",
61 )
63 # Transform from (S,S,...) to (...,S,S)
64 return jnp.moveaxis(output, [0, 1], [-2, -1])
67def _approx_expm(rates: jax.Array) -> jax.Array:
68 """Approximates a full Markov transition matrix from a rate matrix
69 through first-order approximation of the matrix exponential
71 Args:
72 rates (jax.Array): un-normalised square rate matrix (i.e. diagonal zero)
73 or batch of such matrices. Accepts shapes ``(B, S, S)``
74 or ``(S, S)``
76 Returns:
77 jax.Array: Approximation to Markov transition matrix of same shape as
78 ``rates``
79 """
81 total_rates = jnp.sum(rates, axis=-1, keepdims=True)
83 prob = 1.0 - jnp.exp(-total_rates)
85 partial_matrix = multiply_no_nan(prob / total_rates, rates)
87 diagonal_values = 1.0 - jnp.sum(partial_matrix, axis=-1)
88 batched_dimensions = 3
89 if rates.ndim == batched_dimensions:
90 # Batched inputs.
91 mask = jnp.eye(rates.shape[-1], dtype=bool)
92 mask = jnp.broadcast_to(mask, rates.shape)
93 else:
94 # Single input.
95 mask = jnp.eye(rates.shape[-1], dtype=bool)
97 return jnp.where(mask, diagonal_values[..., None], partial_matrix)
def make_transition_prob_matrix_fn(
    transition_rate_fn: Callable[[float, jax.Array], tuple[jax.Array, ...]],
    time_delta: float,
    incidence_matrix: np.ndarray,
) -> Callable[[float, jax.Array], jax.Array]:
    """Create a function that returns the discrete-time Markov transition
    matrix for a single timestep of length ``time_delta``.

    Args:
        transition_rate_fn (Callable[[float, jax.Array], tuple[jax.Array, ...]]):
            Maps a scalar time ``t`` and a ``state`` of shape ``(N, S)`` to a
            tuple of ``R`` transition-rate arrays, each of shape ``(N,)``.
        time_delta (float): Size of the time step.
        incidence_matrix (np.ndarray): Array of shape ``(S, R)`` describing
            the allowed transitions between compartments.

    Returns:
        Callable[[float, jax.Array], jax.Array]: A function mapping a scalar
        ``time`` and a state of shape ``(N, S)`` to a Markov transition
        matrix of shape ``(N, S, S)``.
    """
    # The (from, to) coordinates depend only on the incidence matrix, so they
    # are computed once when the closure is built.
    coords = transition_coords_tuple(incidence_matrix)

    def transition_prob_matrix_fn(t: float, state: jax.Array) -> jax.Array:
        step_rates = transition_rate_fn(t, state)
        # rate_matrix[n, i, j] is the rate of the i -> j transition in unit
        # n (for i != j); the diagonal is left at zero.
        rate_matrix = _scatter_to_transition_matrix(
            step_rates, coords, state.shape[-1]
        )
        # Scale rates to the step length, then convert to probabilities.
        return _approx_expm(rate_matrix * time_delta)

    return transition_prob_matrix_fn
def _multinomial_log_prob_impl(total_count, probs, counts):
    """Multinomial log-pmf of ``counts`` given ``total_count`` and ``probs``.

    Computes ``sum_k counts_k * log(probs_k)`` over the last axis plus the log
    multinomial coefficient ``log(total_count! / prod_k counts_k!)`` from
    ``tfp.math.log_combinations``. ``multiply_no_nan`` is used so that (assuming
    ``tf.math.multiply_no_nan``-style semantics) zero counts contribute zero
    even where ``log(probs)`` is ``-inf``.
    """
    log_unnorm_prob = jnp.sum(
        multiply_no_nan(jnp.log(probs), counts), axis=-1
    )
    neg_log_normalizer = tfp.math.log_combinations(total_count, counts)
    return log_unnorm_prob + neg_log_normalizer


@jax.custom_vjp
def _multinomial_log_prob(total_count, probs, counts):
    """Multinomial log-probability with a hand-written gradient.

    Only ``probs`` is differentiated; ``total_count`` and ``counts`` are
    treated as data (see ``_multinomial_log_prob_bwd``).
    """
    return _multinomial_log_prob_impl(total_count, probs, counts)


# Forward pass: same value as the primal, saving what the backward pass needs.
def _multinomial_log_prob_fwd(total_count, probs, counts):
    output = _multinomial_log_prob_impl(total_count, probs, counts)
    return output, (probs, counts)


def _sum_to_shape(x, shape):
    """Sum ``x`` over the axes along which ``shape`` was broadcast to
    ``x.shape`` (the adjoint of ``jnp.broadcast_to``)."""
    num_leading = x.ndim - len(shape)
    if num_leading > 0:
        x = jnp.sum(x, axis=tuple(range(num_leading)))
    expanded_axes = tuple(
        axis
        for axis, (size, target) in enumerate(zip(x.shape, shape, strict=True))
        if target == 1 and size != 1
    )
    if expanded_axes:
        x = jnp.sum(x, axis=expanded_axes, keepdims=True)
    return x


# Backward pass
def _multinomial_log_prob_bwd(res, g):
    """Cotangent of the multinomial log-probability w.r.t. ``probs`` only.

    ``d logp / d probs_k = counts_k / probs_k``. ``nan_to_num`` replaces the
    ``inf`` from ``1 / 0`` with a large finite value, so a zero-probability
    category with zero count gets a zero (not NaN) gradient.
    """
    probs, counts = res
    grad_probs = (
        jnp.nan_to_num(1.0 / probs) * counts * jnp.expand_dims(g, axis=-1)
    )
    # The forward pass broadcasts its arguments, so the cotangent may be
    # larger than ``probs`` (e.g. one matrix shared across units); reduce it
    # back to the primal's shape and dtype as custom_vjp requires.
    grad_probs = _sum_to_shape(grad_probs, probs.shape).astype(probs.dtype)
    # ``None`` signals zero cotangents for ``total_count`` and ``counts``.
    return (None, grad_probs, None)


# Registering the custom forward and backward functions for
# _multinomial_log_prob
_multinomial_log_prob.defvjp(
    _multinomial_log_prob_fwd, _multinomial_log_prob_bwd
)
def compute_state(
    initial_state: jax.Array,
    events: jax.Array,
    incidence_matrix: np.ndarray,
    closed: bool = False,
):
    """Compute the state array at multiple points in time from the initial
    state and event array.

    Args:
        initial_state (jax.Array): Initial state with shape ``(N, S)``.
        events (jax.Array): Time series of events with shape ``(T, N, R)``
            (optionally with extra leading batch dimensions).
        incidence_matrix (np.ndarray): A matrix with shape ``(S, R)``
            describing how each transition updates the state.
        closed (bool): If ``True``, return the state at times ``0..T``
            (including the state after the last timestep); otherwise at
            times ``0..T-1``.

    Returns:
        jax.Array: Array of shape ``(T, N, S)`` if ``closed=False`` or
        ``(T+1, N, S)`` if ``closed=True``, describing the state of the
        system at times ``0..T-1`` or ``0..T``, respectively. With
        ``closed=True`` and ``T == 0`` the result contains just
        ``initial_state``.
    """
    # Net change in each state caused by the events of each timestep,
    # shape (..., T, N, S).
    increments = jnp.einsum("...tmr,sr->...tms", events, incidence_matrix)

    # Running totals with an explicit leading zero, so entry t is the change
    # accumulated *before* timestep t: shape (..., T + 1, N, S). The zero is
    # built from the shape (not sliced from ``increments``) so it still
    # exists when T == 0.
    leading_zero = jnp.zeros(
        (*increments.shape[:-3], 1, *increments.shape[-2:]),
        dtype=increments.dtype,
    )
    cum_increments = jnp.concatenate(
        (leading_zero, jnp.cumsum(increments, axis=-3)), axis=-3
    )

    if not closed:
        # Drop the state after the final timestep.
        cum_increments = cum_increments[..., :-1, :, :]

    return cum_increments + jnp.expand_dims(initial_state, axis=-3)
def chain_binomial_propagate(
    transition_matrix_fn: Callable[[float, jax.Array], jax.Array],
) -> Callable[[float, jax.Array, jax.Array], tuple[jax.Array, jax.Array]]:
    """Build a one-step stochastic update for a discrete-time Markov model.

    Args:
        transition_matrix_fn: A function ``(t, state) -> Array`` mapping a
            time ``t`` and a state of shape ``(N, S)`` to a Markov transition
            probability matrix of shape ``(N, S, S)``, whose ``[n, i, j]``
            entry is the probability of moving from state ``i`` to state
            ``j`` in unit ``n``. A suitable function is produced by
            ``make_transition_prob_matrix_fn``.

    Returns:
        A function ``(t, state, seed) -> (events, new_state)``, where
        ``seed`` is a JAX PRNG key (it is passed to ``jax.random.split``).
        ``events`` has shape ``(N, S, S)`` and counts the individuals moving
        from state ``i`` to state ``j``; ``new_state`` has shape ``(N, S)``.
    """

    def propagate_fn(t, state, seed):
        trans_matrix = transition_matrix_fn(t, state)
        num_states = trans_matrix.shape[-1]

        # Each row is a multinomial draw, sampled as a chain of conditional
        # binomials: destination j receives
        # Binomial(remaining, p_j / (1 - sum_{k<j} p_k)), and whatever is
        # left after the first S-1 destinations goes to the last one.
        remaining = state.astype(trans_matrix.dtype)
        mass_used = jnp.zeros_like(trans_matrix[..., :, 0])
        columns = []
        for dest, key in enumerate(jax.random.split(seed, num_states - 1)):
            p_dest = trans_matrix[..., :, dest]
            # 1e-10 guards against division by zero once a row's mass is
            # exhausted; clip keeps rounding error inside [0, 1].
            cond_prob = jnp.clip(
                p_dest / (1.0 - mass_used + 1e-10), 0.0, 1.0
            )
            draw = tfd.Binomial(
                total_count=remaining, probs=cond_prob
            ).sample(seed=key)
            columns.append(draw)
            remaining = remaining - draw
            mass_used = mass_used + p_dest
        columns.append(remaining)

        # counts[..., i, j]: number moving from state i to state j.
        counts = jnp.stack(columns, axis=-1)
        # New occupancy of each state = everything that arrived in it.
        return counts, jnp.sum(counts, axis=-2)

    return propagate_fn
def discrete_markov_simulation(
    transition_prob_matrix_fn: Callable[[float, jax.Array], jax.Array],
    state: jax.Array,
    start: float,
    end: float,
    time_step: float,
    seed: jax.Array,
):
    """Simulates from a discrete time Markov state transition model using
    multinomial sampling across rows of the transition matrix.

    Args:
        transition_prob_matrix_fn (Callable[[float, jax.Array], jax.Array]):
            A function that takes a scalar ``time`` and state of shape
            ``(N, S)`` to a Markov transition matrix of shape ``(N, S, S)``.
            Here ``N`` is the number of units and ``S`` the dimension of the
            state of each unit.
        state (jax.Array): Initial state of shape ``(N, S)``. A 1-D state of
            shape ``(S,)`` is treated as a single unit (``N == 1``).
        start (float): Start time of simulation.
        end (float): End time of simulation (exclusive).
        time_step (float): Duration of discrete timestep.
        seed (jax.Array): JAX PRNG key for sampling.

    Returns:
        tuple[jax.Array, jax.Array]: A vector of shape ``(T,)`` of the times
        of the start of each timestep (``arange(start, end, time_step)``),
        and an array of shape ``(T, N, S, S)`` whose ``[t, n, i, j]`` entry
        counts the individuals moving from state ``i`` to ``j`` in unit ``n``
        during timestep ``t``.
    """
    state = jnp.asarray(state)
    if state.ndim == 1:
        state = state[None, ...]

    propagate = chain_binomial_propagate(transition_prob_matrix_fn)

    # A floating state keeps its dtype for the time grid. An integer state
    # must NOT force an integer dtype on ``arange``: with a fractional
    # ``time_step`` that yields wrong times (e.g. all zeros). Let ``arange``
    # infer the dtype from start/end/time_step instead.
    if jnp.issubdtype(state.dtype, jnp.floating):
        time_dtype = state.dtype
    else:
        time_dtype = None
    times = jnp.arange(start, end, time_step, dtype=time_dtype)
    keys = jax.random.split(seed, times.shape[0])

    def scan_fn(current_state, elems):
        t, step_key = elems
        event_counts, new_state = propagate(t, current_state, step_key)
        # Keep the scan carry dtype fixed across iterations.
        new_state = new_state.astype(current_state.dtype)
        return new_state, event_counts

    _, all_events = jax.lax.scan(scan_fn, state, (times, keys))

    return times, all_events
def discrete_markov_log_prob(
    events: jax.Array,
    init_state: jax.Array,
    init_step: float,
    time_delta: float,
    transition_prob_matrix_fn: Callable[[float, jax.Array], jax.Array],
    incidence_matrix: np.ndarray,
) -> jax.Array:
    """Log-probability of an event time series under a discrete-time Markov
    state transition model.

    Each timestep and unit contributes one multinomial term per source
    state: the individuals in state ``i`` are split between staying and each
    allowed ``i -> j`` transition according to the transition matrix.

    Args:
        events (jax.Array): A ``(T, N, R)`` array of transition events, for
            ``T`` timesteps, ``N`` units and ``R`` transitions.
        init_state (jax.Array): Array of shape ``(N, S)`` giving the initial
            state of the epidemic, where ``S`` is the number of states.
        init_step (float): The initial time.
        time_delta (float): The duration of the discrete time step.
        transition_prob_matrix_fn (Callable[[float, jax.Array], jax.Array]):
            Maps a scalar time and a state of shape ``(N, S)`` to a Markov
            transition matrix of shape ``(N, S, S)``.
        incidence_matrix (np.ndarray): A matrix with shape ``(S, R)``
            describing how transitions update the state.

    Returns:
        jax.Array: Scalar log-probability of ``events`` given the initial
        state and model, summed over all timesteps, units and states.
    """
    n_steps = events.shape[-3]
    n_states = init_state.shape[-1]

    # State at the start of every timestep, shape (T, N, S).
    states = compute_state(init_state, events, incidence_matrix)

    # Transition probability matrix for every timestep, shape (T, N, S, S).
    step_times = init_step + time_delta * jnp.arange(n_steps)
    prob_matrices = jax.vmap(transition_prob_matrix_fn)(step_times, states)

    coords = transition_coords_tuple(incidence_matrix)

    # Move counts of shape (T, N, S, S): entry [t, n, i, j] is the number of
    # individuals moving from state i to state j at timestep t in unit n.
    move_counts = _scatter_to_transition_matrix(
        jnp.unstack(events, axis=-1), coords, n_states
    )

    # Individuals who did not leave their state sit on the diagonal.
    diag = jnp.arange(move_counts.shape[-1])
    stayers = states - move_counts.sum(axis=-1)
    move_counts = move_counts.at[..., diag, diag].set(stayers)

    per_row = _multinomial_log_prob(states, prob_matrices, move_counts)
    return per_row.sum()