Coverage for gemlib/distributions/continuous_markov.py: 100%
87 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-12-02 22:09 +0000
« prev ^ index » next coverage.py v7.10.3, created at 2025-12-02 22:09 +0000
1"""Function for continuous time simulation"""
3from __future__ import annotations
5import math as pymath
6from collections.abc import Callable
7from typing import Any, NamedTuple
9import jax
10import jax.numpy as jnp
11import numpy as np
12import tensorflow_probability.substrates.jax as tfp
14from gemlib.util import batch_gather, transition_coords
# Aliasing for convenience
ArrayLike = jax.typing.ArrayLike  # short name for JAX's array-like type
tfd = tfp.distributions  # distributions module of the TFP JAX substrate
DTYPE = jnp.float32  # module-level float dtype constant
class EventList(NamedTuple):
    """Record of one or more events in an epidemic simulation.

    Each field may hold a scalar (a single event) or an array with one
    entry per event; entry ``i`` of every field describes the same event.

    Attributes:
        time: The time at which the event occurred.
        transition: The type (index) of the transition that occurred.
        unit: The unit involved in the event.
    """

    # Event time(s).
    time: Any
    # Transition index/indices.
    transition: Any
    # Unit index/indices.
    unit: Any
def _one_hot_expand_state(condensed_state: jax.Array) -> jax.Array:
    """Expand per-state counts into one one-hot row per individual.

    Args:
        condensed_state: counts per state (assumed 1-D and integer-valued,
            since it is used as the ``repeats`` argument of ``jnp.repeat``).

    Returns:
        A float32 array with ``len(condensed_state)`` columns in which the
        one-hot row for state ``k`` appears ``condensed_state[k]`` times,
        states in increasing order.
    """
    num_states = len(condensed_state)

    # Row ``k`` of the identity matrix is the one-hot vector of state ``k``.
    state_basis = jnp.eye(num_states, dtype=jnp.float32)

    # Duplicate each basis row according to the count of its state.
    return jnp.repeat(state_basis, condensed_state, axis=0)
def _total_flux(transition_rates, state, incidence_matrix):
    """Scale per-unit transition rates by the occupancy of their source state.

    Args
    ----
    transition_rates: a sequence of `R` arrays (one per transition), each
        with the `N` aggregation units on its last axis; they are stacked
        along a new trailing axis before use.
    state: a `[N, S]` tensor of `N` aggregation units and `S` states
        (optionally with leading dimensions).
    incidence_matrix: a `[S, R]` matrix describing the change in `S` for
        each transition `R`.

    Returns
    -------
    A `[R, N]` tensor (plus any leading dimensions) holding the total flux
    along each transition, i.e. the per-unit rate multiplied by the number
    of individuals available in the transition's source state.
    """
    # Put the transitions on the last axis: [..., N, R].
    stacked_rates = jnp.stack(transition_rates, axis=-1)

    # Column 0 of the transition coordinates is taken as the source state
    # of each transition (presumably (source, destination) pairs -- confirm
    # against `gemlib.util.transition_coords`).
    src_idx = transition_coords(np.array(incidence_matrix))[:, 0]
    src_occupancy = batch_gather(state, indices=src_idx[:, jnp.newaxis])

    # Elementwise product, then swap the unit/transition axes.
    return jnp.einsum("...nr,...nr->...rn", stacked_rates, src_occupancy)
def compute_state(
    incidence_matrix: np.ndarray,
    initial_state: jax.Array,
    event_list: EventList,
    include_final_state: bool = False,
):
    """Given an event list `event_list`, compute a timeseries
    of state given the model.

    Args
    ----
    incidence_matrix: a `[S,R]` graph incidence matrix for `S`
        compartments and `R` transitions.
    initial_state: a `[N,S]` array representing the initial state of `N`
        units by `S` compartments.  It is used as given (not converted),
        so it must already expose `.shape` and `.dtype`.
    event_list: the event list, assumed to be sorted by time.  Its fields
        share a shape `[..., T]`; leading dimensions are treated as batch
        dimensions.
    include_final_state: should the final state be included in the
        returned timeseries?  If truthy, the time dimension of
        the returned tensor will be 1 greater than the length of the
        event list.  If falsy (default) these will be equal.

    Return
    ------
    A `[..., T, N, S]` tensor where `T` is the number of events (plus one
    if `include_final_state`), `N` is the number of units, `S` is the
    number of states, and `...` are the batch dimensions of `event_list`.
    """
    # Rebuild the event list with every field converted to a JAX array.
    event_list = event_list.__class__(*[jnp.array(x) for x in event_list])
    incidence_matrix = jnp.array(incidence_matrix)

    num_times = event_list.time.shape[-1]
    num_units = initial_state.shape[-2]
    num_reactions = incidence_matrix.shape[-1] + 1  # R + pad for ghost events

    # The technique to handle batch dimensions in the event_list is to
    # flatten the batch dimensions.  That way we can guarantee
    # that the output dense events have shape [B, T, N, R].
    batch_dims = event_list.time.shape[:-1]
    flat_batch_shape = int(pymath.prod(batch_dims))

    # Compute one-hot encoding of event timeseries
    event_tensor_shape = (
        flat_batch_shape,
        num_times,
        num_units,
        num_reactions,
    )  # [B, T, N, R + 1]

    # One (batch, time, unit, transition) coordinate per event.
    hot_indices = jnp.stack(
        [
            jnp.repeat(jnp.arange(flat_batch_shape), num_times),  # B 0,0,..,1,1
            jnp.tile(jnp.arange(num_times), flat_batch_shape),  # T 0,1,..,0,1
            event_list.unit.reshape(-1),  # N (flat_batch_shape * num_times)
            event_list.transition.reshape(-1),  # R (same length)
        ],
        axis=-1,
    )

    # Scatter a 1 at each event coordinate, then slice off the padding
    # column that collects ghost events (transition index == R).
    event_tensor = (
        jnp.zeros(event_tensor_shape, dtype=initial_state.dtype)
        .at[tuple(hot_indices.T)]
        .add(1.0)[..., :-1]
    )

    # Compute deltas and cumsum over the state
    delta = jnp.matmul(event_tensor, incidence_matrix.T)

    # Truthiness test rather than `is False`, so that falsy non-`bool`
    # flags (e.g. `np.False_`, `0`) also drop the final state.
    if not include_final_state:
        delta = delta[..., :-1, :, :]

    batched_initial_state = jnp.broadcast_to(
        initial_state[jnp.newaxis, jnp.newaxis, ...],  # Add B and T dimensions
        shape=(flat_batch_shape,) + (1,) + initial_state.shape,
    )
    state = jnp.cumsum(
        jnp.concatenate([batched_initial_state, delta], axis=-3), axis=-3
    )

    # Restore the original batch dimensions in front of [T, N, S].
    return jnp.reshape(state, batch_dims + state.shape[-3:])
def exponential_propogate(
    transition_rate_fn: Callable, incidence_matrix: jax.Array
) -> Callable:
    """Generates a function for propogating an epidemic forward in time

    Closure over the transition rate function and the incidence matrix
    which outline the epidemic dynamics and model structure.  The returned
    function can be used to simulate the epidemic forward in time one step.

    Args:
        transition_rate_fn (Callable): a function called as
            `transition_rate_fn(time, state)` that returns the transition
            rates for each unit/meta-population.
        incidence_matrix (tensor): A `[S, R]` matrix that describes the graph
            structure of the state transition model.  The rows correspond to
            the `S` states and the columns to the `R` transitions (its
            transpose is indexed by transition to obtain a state update).

    Returns:
        Callable: `propogate_fn(time, state, seed)`, which samples the next
        event and returns `(next_time, next_state, event)`.
    """
    # [R, S]: row `r` is the state change caused by transition `r`.
    tr_incidence_matrix = incidence_matrix.T

    def propogate_fn(time: float, state: jax.Array, seed: Any) -> tuple:
        """Propogates the state of the epidemic forward in time

        Args:
            time (float): Wall clock of the epidemic - can easily recover
                the time delta
            state (tensor): `[N,S]` representing the current state.
            seed: PRNG seed, split via `tfp.random.split_seed`.

        Returns:
            tuple: `(next_time, next_state, event)` where `event` is an
            `EventList` holding the time, transition index and unit index
            of the sampled event.
        """
        seed_exp, seed_cat = tfp.random.split_seed(seed, n=2)
        num_units = state.shape[-2]

        # compute event rates for all possible events: [R, N]
        transition_rates = _total_flux(
            transition_rate_fn(time, state), state, incidence_matrix
        )

        # simulate the waiting time until the next event
        t_next = tfd.Exponential(rate=jnp.sum(transition_rates)).sample(
            seed=seed_exp
        )

        # use categorical distribution to get event type and individual id;
        # the rates are flattened row-major, so event_id = r * N + n
        event_id = tfd.Categorical(
            probs=jnp.reshape(transition_rates, (-1,)),
            dtype=jnp.int32,
        ).sample(seed=seed_cat)

        unit_idx = jnp.mod(event_id, num_units)
        transition_idx = jnp.floor_divide(event_id, num_units)

        # update the state of the affected unit
        new_state = state.at[unit_idx].add(tr_incidence_matrix[transition_idx])

        next_time = time + t_next
        return (
            next_time,
            new_state,
            EventList(next_time, transition_idx, unit_idx),
        )

    return propogate_fn
def continuous_markov_simulation(
    transition_rate_fn: Callable,
    initial_state: jax.Array,
    incidence_matrix: jax.Array,
    num_markov_jumps: int,
    initial_time: float = 0.0,
    seed: jax.Array | None = None,
) -> EventList:
    """
    Simulates a continuous-time Markov process

    Args:
        transition_rate_fn (Callable): A function called as
            `transition_rate_fn(time, state)` that computes the transition
            rates given the current time and state.
        initial_state (Tensor): A [N, S] tensor, representing a population
            of N units and S states.
        incidence_matrix (Tensor): The `[S,R]` incidence matrix representing
            the state transition model with S states and R transitions.
        num_markov_jumps (int): The maximum number of events to simulate.
        initial_time (float): The time at which the simulation starts.
        seed: A JAX PRNG key (it is split with `jax.random.split`).
            Required, despite having a default for signature compatibility.

    Returns:
        EventList: An object containing the simulated epidemic events, each
        field of length `num_markov_jumps`.  If the total event rate reaches
        zero early, unused entries are padded with `inf` (time), `R`
        (transition) and the maximum int32 value (unit).

    Raises:
        TypeError: if `seed` is not supplied.
    """
    # The previous default was the `jnp.ndarray` class itself, which failed
    # obscurely inside the loop; fail early with a clear message instead.
    if seed is None:
        raise TypeError(
            "continuous_markov_simulation() requires a `seed` "
            "(a JAX PRNG key)"
        )

    dtype = initial_state.dtype
    num_transitions = incidence_matrix.shape[1]
    ghost_unit = jnp.iinfo(np.int32).max

    propagate_fn = exponential_propogate(transition_rate_fn, incidence_matrix)

    # Pre-allocated output buffers, initialised with the padding values.
    accum = EventList(
        time=jnp.ones((num_markov_jumps,), dtype=dtype) * jnp.inf,
        transition=jnp.full(
            (num_markov_jumps,), num_transitions, dtype=jnp.int32
        ),
        unit=jnp.full((num_markov_jumps,), ghost_unit, dtype=jnp.int32),
    )

    def cond(args):
        # Continue while jumps remain and at least one event can still occur.
        i, time, state, seed, accum = args
        transition_rates = _total_flux(
            transition_rate_fn(time, state), state, incidence_matrix
        )
        cont = (i < num_markov_jumps) & (jnp.sum(transition_rates) > 0.0)
        return cont

    def body(args):
        # Sample one event and write it into slot `i` of every buffer.
        i, time, state, seed, accum = args
        next_seed, subkey = jax.random.split(seed)
        next_time, next_state, event = propagate_fn(time, state, subkey)
        accum = jax.tree_util.tree_map(
            lambda acc, x: acc.at[i].set(x), accum, event
        )

        return (
            i + 1,
            next_time,
            next_state,
            next_seed,
            accum,
        )

    actual_markov_jumps, _, _, _, accum = jax.lax.while_loop(
        cond, body, (0, initial_time, initial_state, seed, accum)
    )

    times, transitions, units = accum

    # Pad unused parts of the output if the loop terminates before
    # num_markov_jumps
    mask = jnp.arange(num_markov_jumps) < actual_markov_jumps
    output = EventList(
        time=jnp.where(mask, times, jnp.inf),
        transition=jnp.where(mask, transitions, num_transitions),
        unit=jnp.where(mask, units, ghost_unit),
    )

    return output
def continuous_time_log_likelihood(
    transition_rate_fn: Callable,
    incidence_matrix: jax.Array,
    initial_state: jax.Array,
    initial_time: float,
    event_list: EventList,
) -> float:
    """
    Computes the log-likelihood of a continuous-time Markov process
    given the transition rate function, incidence matrix, initial state,
    initial time and event data.

    Args:
        transition_rate_fn (Callable): A function called as
            `transition_rate_fn(time, state)` that computes the
            transition rates given the current time and state.
        incidence_matrix: The incidence matrix representing
            the connections between states in `[S,R]` format.
        initial_state: The initial state of the process as a `[N,S]` array.
        initial_time (float): The time at which the process starts.
        event_list (EventList): The event data containing the times,
            transitions and units.  Entries with non-finite times (e.g.
            padding) contribute zero to the log-likelihood.

    Returns:
        Tensor: The log-likelihood of the continuous-time Markov process.
    """
    # construct the epidemic states [T, N, S] (state before each event)
    states = compute_state(
        incidence_matrix=incidence_matrix,
        initial_state=initial_state,
        event_list=event_list,
    )

    # compute the transition rates for each of the states in the event
    time = jnp.concatenate(
        [jnp.array([initial_time]), event_list.time], axis=-1
    )
    rates = jax.vmap(transition_rate_fn)(
        time[:-1], states
    )  # R-tuple of [T,N] arrays

    total_flux = _total_flux(rates, states, incidence_matrix)  # [T, R, N]

    indices = jnp.stack(
        [
            jnp.arange(event_list.time.shape[0]),
            jnp.clip(  # Clip due to ghost events (values zeroed later)
                event_list.transition,
                0,
                incidence_matrix.shape[-1] - 1,
            ),
            event_list.unit,
        ],
        axis=-1,
    )

    # compute event specific rate - get indices of the event that happened
    event_rate = total_flux[indices[:, -3], indices[:, -2], indices[:, -1]]

    # compute total rate per timestep
    total_rate = jnp.sum(total_flux, axis=(-2, -1))

    # compute time deltas
    time_delta = time[1:] - time[:-1]

    # Entries with non-finite time deltas are padding (inf event times).
    is_real_event = jnp.isfinite(time_delta)

    # Sanitise the inputs of padded entries BEFORE the arithmetic.  Masking
    # only the result keeps the forward value finite, but reverse-mode
    # differentiation would still multiply a zero cotangent by `inf`/`nan`
    # and yield NaN gradients.
    safe_time_delta = jnp.where(
        is_real_event, time_delta, jnp.zeros_like(time_delta)
    )
    safe_event_rate = jnp.where(
        is_real_event, event_rate, jnp.ones_like(event_rate)
    )

    # compute the log-likelihood
    loglik_t = -total_rate * safe_time_delta + jnp.log(safe_event_rate)

    # Zero out for any inf times (i.e. possible padding of event_list chunk);
    # kept in case the rates themselves are non-finite at padded entries.
    loglik_t = jnp.where(is_real_event, loglik_t, jnp.zeros_like(loglik_t))

    return jnp.sum(loglik_t)