Coverage for gemlib/distributions/continuous_time_state_transition_model.py: 97%
61 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-12-02 22:09 +0000
1"""Continuous time state transition model"""
3from __future__ import annotations
5import math as pymath
6from collections.abc import Callable, Sequence
8import jax
9import jax.numpy as jnp
10import numpy as np
11import tensorflow_probability.substrates.jax as tfp
13from gemlib.distributions.continuous_markov import (
14 EventList,
15 compute_state,
16 continuous_markov_simulation,
17 continuous_time_log_likelihood,
18)
19from gemlib.func_util import maybe_combine_fn
20from gemlib.prng_util import sanitize_key
21from gemlib.tensor_util import broadcast_fn_to
# Short-hand names used throughout this module.
tfd = tfp.distributions  # TFP (JAX substrate) distributions namespace

Array = jax.Array  # concrete JAX array type
ArrayLike = jax.typing.ArrayLike  # anything JAX can turn into an array
NDArray = np.typing.NDArray  # NumPy ndarray type, used for static inputs
class ContinuousTimeStateTransitionModel(tfd.Distribution):
    """Continuous time state transition model.

    Example:
        A homogeneously mixing SIR model implementation::

            import jax.numpy as jnp
            import numpy as np
            from gemlib.distributions import ContinuousTimeStateTransitionModel

            # Initial state, counts per compartment (S, I, R), for one
            # population
            initial_state = jnp.array([[99, 1, 0]], jnp.float32)

            # Note that the incidence_matrix is treated as a static parameter,
            # and so is supplied as a numpy ndarray, not a jax Array.
            incidence_matrix = np.array(
                [
                    [-1, 0],
                    [1, -1],
                    [0, 1],
                ],
                dtype=np.float32,
            )


            def si_rate(t, state):
                return 0.28 * state[:, 1] / jnp.sum(state, axis=-1)


            def ir_rate(t, state):
                return 0.14


            # Instantiate model
            sir = ContinuousTimeStateTransitionModel(
                transition_rate_fn=(si_rate, ir_rate),
                incidence_matrix=incidence_matrix,
                initial_state=initial_state,
                num_steps=100,
            )

            # One realisation of the epidemic process
            sim = sir.sample(seed=0)

            # Compute the log probability of observing `sim`
            sir.log_prob(sim)
    """

    def __init__(
        self,
        transition_rate_fn: Sequence[Callable[[float, ArrayLike], ArrayLike]]
        | Callable[[float, ArrayLike], tuple[ArrayLike, ...]],
        incidence_matrix: NDArray,
        initial_state: Array,
        num_steps: int,
        initial_time: float = 0.0,
        validate_args: bool = False,
        allow_nan_stats: bool = True,
        name: str = "ContinuousTimeStateTransitionModel",
    ):
        """
        Initializes a ContinuousTimeStateTransitionModel object.

        Args:
            transition_rate_fn: Either a list of callables of the form
                :code:`fn(t: float, state: Tensor) -> Tensor` or a Python
                callable of the form
                :code:`fn(t: float, state: Tensor) -> tuple(Tensor,...)`.
                In the first (preferred) form, each callable in the list
                corresponds to the respective transition in
                :code:`incidence_matrix`. In the second form, the callable
                should return a :code:`tuple` of transition rate tensors
                corresponding to transitions in :code:`incidence_matrix`.
                **Note**: the second form will be deprecated in future
                releases of :code:`gemlib`.
            incidence_matrix: Matrix representing the incidence of transitions
                between states. In the class example, rows index states and
                columns index transitions.
            initial_state: A :code:`[N, S]` tensor containing the initial state
                of the population of :code:`N` units in :code:`S`
                epidemiological classes.
            num_steps: the number of Markov jumps simulated in a single
                realisation; this also sets the event shape.
            initial_time: Initial time of the model. Defaults to 0.0.
            validate_args: Passed through to :code:`tfd.Distribution`.
                Defaults to False.
            allow_nan_stats: Passed through to :code:`tfd.Distribution`.
                Defaults to True.
            name: Name of the model. Defaults to
                "ContinuousTimeStateTransitionModel".
        """
        # Must run first so that `locals()` holds only the constructor args.
        parameters = dict(locals())

        self._incidence_matrix = jnp.asarray(incidence_matrix)
        self._initial_state = initial_state
        self._initial_time = jnp.asarray(initial_time)

        # Normalise either accepted rate-function form into a single callable.
        self._transition_rate_fn = maybe_combine_fn(transition_rate_fn)

        dtype = EventList(
            time=self._initial_time.dtype,
            transition=jnp.int32,
            unit=jnp.int32,
        )

        super().__init__(
            dtype=dtype,
            # Realisations contain discrete components (`transition`, `unit`)
            # drawn by a stochastic simulation, so samples are not
            # differentiable w.r.t. the model parameters.
            reparameterization_type=tfd.NOT_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            name=name,
        )

    @property
    def transition_rate_fn(self):
        """Transition rate function(s) exactly as passed to the constructor."""
        return self._parameters["transition_rate_fn"]

    @property
    def incidence_matrix(self):
        """Incidence matrix for the model, as passed to the constructor."""
        return self._parameters["incidence_matrix"]

    @property
    def initial_state(self):
        """Initial state of the model, as passed to the constructor."""
        return self._parameters["initial_state"]

    @property
    def num_steps(self):
        """Number of events to simulate."""
        return self._parameters["num_steps"]

    @property
    def initial_time(self):
        """Initial wall clock for the model. Sets the time scale."""
        return self._parameters["initial_time"]

    def compute_state(
        self, event_list: EventList, include_final_state: bool = False
    ) -> jax.Array:
        """Compute state timeseries given an event list

        Args:
            event_list: the event list, assumed to be sorted by time.
            include_final_state: should the final state be included in the
                returned timeseries? If `True`, then the time dimension of
                the returned tensor will be 1 greater than the length of the
                event list. If `False` (default) these will be equal.

        Returns:
            A `[T, N, S]` tensor where `N` is the number of units, `S` is the
            number of states, and `T` is the number of events (or the number
            of events plus one when `include_final_state` is `True`).
        """
        return compute_state(
            self.incidence_matrix,
            self.initial_state,
            event_list,
            include_final_state,
        )

    def _sample_n(self, n: int, seed: int | Array | None = None) -> EventList:
        """
        Samples `n` realisations of the continuous time Markov process.

        Args:
            n: The number of independent realisations to draw. Each
                realisation is simulated with its own PRNG key split from
                `seed`.
            seed: A seed or PRNG key for random number generation, normalised
                via `sanitize_key`. Defaults to None.

        Returns:
            An `EventList` whose leaves have a leading dimension of size `n`
            followed by the event dimension (see `_event_shape`).
        """
        keys = jax.random.split(sanitize_key(seed), n)

        # The broadcast rate function is identical for every realisation, so
        # build it once instead of inside the vmapped closure.
        rate_fn = broadcast_fn_to(
            self._transition_rate_fn,
            self._initial_state.shape[:-1],
        )

        def one_sample(sample_key):
            return continuous_markov_simulation(
                transition_rate_fn=rate_fn,
                incidence_matrix=self._incidence_matrix,
                initial_state=self._initial_state,
                initial_time=self._initial_time,
                num_markov_jumps=self.num_steps,
                seed=sample_key,
            )

        return jax.vmap(one_sample)(keys)

    def _log_prob(self, value: EventList) -> Array:
        """
        Computes the log probability of the given outcome(s).

        Args:
            value: an `EventList` whose leaves are expected to share the shape
                of `value.time`, i.e. `batch_shape + (num_events,)`. Any
                leading dimensions are treated as a batch of independent
                realisations.

        Returns:
            An array of shape `batch_shape` holding the log probability of
            each realisation.
        """
        value = jax.tree_util.tree_map(jnp.asarray, value)

        batch_shape = value.time.shape[:-1]
        # Collapse all batch dimensions into one so a single vmap covers them
        # (an unbatched event list becomes a batch of size 1).
        flat_shape = (
            pymath.prod(batch_shape),
            value.time.shape[-1],
        )
        flat_value = jax.tree_util.tree_map(
            lambda x: x.reshape(flat_shape), value
        )

        rate_fn = broadcast_fn_to(
            self._transition_rate_fn,
            self._initial_state.shape[:-1],
        )

        def one_log_prob(event_list):
            # NOTE(review): unlike `_sample_n`, this passes the constructor
            # arguments as supplied (e.g. the numpy incidence matrix) rather
            # than the jnp-converted copies. The class docstring describes
            # the incidence matrix as a static numpy parameter, so this is
            # left unchanged — confirm which form the likelihood requires.
            return continuous_time_log_likelihood(
                transition_rate_fn=rate_fn,
                incidence_matrix=self.incidence_matrix,
                initial_state=self.initial_state,
                initial_time=self.initial_time,
                event_list=event_list,
            )

        log_probs = jax.vmap(one_log_prob)(flat_value)

        return jnp.reshape(log_probs, batch_shape)

    def _event_shape(self) -> EventList:
        """Per-component event shape: one entry per simulated Markov jump."""
        return EventList(
            time=(self.num_steps,),
            transition=(self.num_steps,),
            unit=(self.num_steps,),
        )

    def _batch_shape(self) -> tuple:
        """The model is unbatched: its batch shape is always empty."""
        return ()