Coverage for gemlib/distributions/discrete_time_state_transition_model.py: 96%

84 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-12-02 22:09 +0000

1"""Describes a DiscreteTimeStateTransitionModel.""" 

2 

3from __future__ import annotations 

4 

5from collections.abc import Callable, Sequence 

6 

7import jax 

8import jax.numpy as jnp 

9import numpy as np 

10import tensorflow_probability.substrates.jax as tfp 

11 

12from gemlib.distributions.discrete_markov import ( 

13 compute_state, 

14 discrete_markov_log_prob, 

15 discrete_markov_simulation, 

16 make_transition_prob_matrix_fn, 

17) 

18from gemlib.func_util import maybe_combine_fn 

19from gemlib.math import cumsum_np 

20from gemlib.prng_util import sanitize_key 

21from gemlib.tensor_util import broadcast_fn_to 

22from gemlib.util import ( 

23 batch_gather, 

24 transition_coords, 

25) 

26 

27Array = jax.Array 

28ArrayLike = jax.typing.ArrayLike 

29tfd = tfp.distributions 

30 

31 

class DiscreteTimeStateTransitionModel(tfd.Distribution):
    """Discrete-time state transition model.

    A discrete-time state transition model assumes a population of
    individuals is divided into a number of mutually exclusive states,
    where transitions between states occur according to a Markov process.
    Such models are commonly found in epidemiological and ecological
    applications, where rapid implementation and modification is necessary.

    This class provides a programmable implementation of the discrete-time
    state transition model, compatible with TensorFlow Probability.


    Example
    -------
    A homogeneously mixing SIR model implementation::

        import jax.numpy as jnp
        import numpy as np
        from gemlib.distributions import DiscreteTimeStateTransitionModel

        # Initial state, counts per compartment (S, I, R), for one
        # population
        initial_state = jnp.array([[99, 1, 0]], jnp.float32)

        # Note that the incidence_matrix is treated as a static parameter,
        # and so is supplied as a numpy ndarray, not a jax Array.
        incidence_matrix = np.array(
            [
                [-1, 0],
                [1, -1],
                [0, 1],
            ],
            dtype=np.float32,
        )


        def si_rate(t, state):
            return 0.28 * state[:, 1] / jnp.sum(state, axis=-1)


        def ir_rate(t, state):
            return 0.14


        # Instantiate model
        sir = DiscreteTimeStateTransitionModel(
            transition_rate_fn=(si_rate, ir_rate),
            incidence_matrix=incidence_matrix,
            initial_state=initial_state,
            num_steps=100,
        )

        # One realisation of the epidemic process
        sim = sir.sample(seed=0)
    """

    def __init__(
        self,
        transition_rate_fn: Sequence[Callable[[float, ArrayLike], ArrayLike]]
        | Callable[[float, ArrayLike], tuple[ArrayLike, ...]],
        incidence_matrix: np.typing.NDArray,
        initial_state: ArrayLike,
        num_steps: int,
        initial_step: int = 0,
        time_delta: float = 1.0,
        validate_args: bool = False,
        allow_nan_stats: bool = True,
        name: str = "DiscreteTimeStateTransitionModel",
    ):
        """Initialise a discrete-time state transition model.

        Args:
          transition_rate_fn: Either a sequence of ``R`` callables
            ``(t: float, state: ArrayLike) -> ArrayLike``, each returning an
            Array of rates of shape ``(N,)``, or a callable
            ``(t: float, state: ArrayLike) -> tuple[ArrayLike, ...]``,
            returning a tuple of ``R`` vectors of shape ``(N,)``.
            Here ``N`` is the number of units in the simulation.
            In the first (preferred) form, each callable returns the
            respective transition rate.
            In the second form, the single callable returns ``R`` transition
            rates.
            **Note**: the second form will be
            deprecated in future releases of ``gemlib``.
          incidence_matrix: incidence matrix of shape ``(S, R)``
            for ``S`` states and ``R`` transitions between states.
          initial_state: Initial state of the model, of shape
            ``(N, S)`` containing the counts in each of ``N`` units and
            ``S`` states.
            We require ``initial_state.shape[-1]==incidence_matrix.shape[-2]``
          num_steps: the number of time steps simulated by the model.
          initial_step: time ``t`` at the start of the simulation.
          time_delta: the duration of the discretized time step.
          validate_args: forwarded unchanged to ``tfd.Distribution``.
          allow_nan_stats: forwarded unchanged to ``tfd.Distribution``.
          name: name of the distribution, forwarded to ``tfd.Distribution``.
        """
        # Must stay the first statement: it snapshots the constructor
        # arguments before any further local variable is created.
        parameters = dict(locals())

        # The incidence matrix is held as a NumPy array.
        self._incidence_matrix = np.asarray(incidence_matrix)

        self._source_states = _compute_source_states(self._incidence_matrix)

        initial_state = jnp.asarray(initial_state)
        # The rate function(s) are combined into a single callable and
        # broadcast to the leading (unit) dimensions of the initial state.
        self._transition_prob_matrix_fn = make_transition_prob_matrix_fn(
            broadcast_fn_to(
                maybe_combine_fn(transition_rate_fn),
                jnp.shape(initial_state)[:-1],
            ),
            time_delta,
            self._incidence_matrix,
        )
        super().__init__(
            dtype=initial_state.dtype,
            reparameterization_type=tfd.FULLY_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            name=name,
        )

    @property
    def transition_rate_fn(self):
        """The ``transition_rate_fn`` argument as passed to the constructor."""
        return self._parameters["transition_rate_fn"]

    @property
    def incidence_matrix(self):
        """The ``incidence_matrix`` argument as passed to the constructor."""
        return self._parameters["incidence_matrix"]

    @property
    def initial_state(self):
        """The ``initial_state`` argument as passed to the constructor."""
        return self._parameters["initial_state"]

    @property
    def initial_step(self):
        """The ``initial_step`` argument as passed to the constructor."""
        return self._parameters["initial_step"]

    @property
    def source_states(self):
        """Source-state index of each transition, computed at construction."""
        return self._source_states

    @property
    def time_delta(self):
        """The ``time_delta`` argument as passed to the constructor."""
        return self._parameters["time_delta"]

    @property
    def num_steps(self):
        """The ``num_steps`` argument as passed to the constructor."""
        return self._parameters["num_steps"]

    @property
    def num_units(self):
        """Number of units ``N``: second-to-last axis of the initial state."""
        return jnp.array(self._parameters["initial_state"]).shape[-2]

    @property
    def num_states(self):
        """Number of states ``S``: last axis of the initial state."""
        return jnp.array(self._parameters["initial_state"]).shape[-1]

    def _batch_shape(self):
        """Return an empty batch shape."""
        return []

    def _event_shape(self):
        """Return the shape ``(T, N, R)`` of a single event array."""
        shape = (
            int(self.num_steps),  # T
            jnp.shape(self.initial_state)[-2],  # N
            # Last axis of the (S, R) incidence matrix: transitions, not
            # states.
            jnp.shape(self.incidence_matrix)[-1],  # R
        )
        return shape

    def compute_state(
        self, events: Array, include_final_state: bool = False
    ) -> jax.Array:
        """Computes a state timeseries from a sequence of transition events.

        Args:
          events: an array of events of shape ``(T, N, R)`` where
            ``T`` is the number of timesteps,
            ``N=self.num_units`` the number of units and
            ``R=self.incidence_matrix.shape[-1]``
            the number of transitions.

          include_final_state: If ``False`` (default)
            the result does not include the final state.
            Otherwise the result includes the final state.

        Returns:
          An array of shape ``(T, N, S)`` if
          ``include_final_state==False`` or ``(T+1, N, S)``
          if ``include_final_state==True``, where ``S=self.num_states``,
          giving the number of individuals in each state at each time
          point for each unit.
        """
        return compute_state(
            incidence_matrix=self._incidence_matrix,
            initial_state=self.initial_state,
            events=events,
            closed=include_final_state,
        )

    def transition_prob_matrix(self, events: None | ArrayLike) -> Array:
        """Compute the Markov transition probability matrix.

        Args:
          events: None, or an array of shape ``(T, N, R)``, where
            ``T`` is the number of timesteps,
            ``N=self.num_units`` the number of units and
            ``R=self.incidence_matrix.shape[-1]``
            the number of transitions.

        Returns:
          Transition probability matrix. If ``events`` is None,
          this is the matrix of shape ``(N, S, S)`` associated with the
          initial state (``self.initial_state``, of shape ``(N, S)``),
          where ``S=self.num_states`` is the number of distinct states.
          Otherwise, this matrix is of shape ``(T, N, S, S)``,
          representing the transition probability matrix at each timestep.
        """
        if events is None:
            return self._transition_prob_matrix_fn(
                self.initial_step, self.initial_state
            )

        state = self.compute_state(events)
        # Build the time grid from an integer range rather than a
        # floating-point ``arange(start, stop, step)``: with a non-integer
        # ``time_delta`` the latter can round to one element too many,
        # which would break the ``vmap`` below.
        times = self.initial_step + self.time_delta * jnp.arange(
            self.num_steps
        )

        return jax.vmap(self._transition_prob_matrix_fn)(times, state)

    def _sample_n(self, n: int, seed: int | Array | None = None) -> Array:
        """Runs `n` simulations of the epidemic model.

        Args:
          n: number of simulations to run.
          seed: an integer or a JAX PRNG key.

        Returns:
          an array of shape ``(n, T, N, R)``
          for ``n`` samples, ``T=self.num_steps`` time steps,
          ``N=self.num_units`` units and
          ``R=self.incidence_matrix.shape[-1]`` transitions, containing the
          number of events for each timestep, unit and transition.
        """
        key = sanitize_key(seed)
        # NOTE(review): the key is sanitized a second time here; this looks
        # redundant -- confirm ``sanitize_key`` is idempotent before
        # removing it.
        keys = jax.random.split(sanitize_key(key), num=n)

        def one_sample(key):
            # Only the events are kept; the first returned value is
            # discarded.
            _, events = discrete_markov_simulation(
                transition_prob_matrix_fn=self._transition_prob_matrix_fn,
                state=self.initial_state,
                start=self.initial_step,
                end=self.initial_step + self.num_steps * self.time_delta,
                time_step=self.time_delta,
                seed=key,
            )
            return events

        # One independent simulation per PRNG key.
        sim = jax.vmap(one_sample)(keys)
        indices = transition_coords(self._incidence_matrix)

        # Select the simulated entries at the transition coordinates
        # derived from the incidence matrix.
        return batch_gather(sim, indices)

    def _log_prob(self, y):
        """Evaluate the log probability of one or more event arrays.

        Args:
          y: events whose trailing three axes form one event array; any
            leading axes are treated as batch dimensions.

        Returns:
          The log probabilities, reshaped to the leading (batch) shape of
          ``y``.
        """
        y = jnp.asarray(y)

        # Flatten every leading batch axis into one so a single ``vmap``
        # covers all event arrays; the batch shape is restored on return.
        batch_shape = y.shape[:-3]
        y_flat = y.reshape((-1,) + y.shape[-3:])

        def one_log_prob(y):
            # Model parameters are cast to the dtype of the events.
            return discrete_markov_log_prob(
                events=y,
                init_state=jnp.asarray(self.initial_state, dtype=y.dtype),
                init_step=jnp.asarray(self.initial_step, dtype=y.dtype),
                time_delta=jnp.asarray(self.time_delta, dtype=y.dtype),
                transition_prob_matrix_fn=self._transition_prob_matrix_fn,
                incidence_matrix=self._incidence_matrix,
            )

        log_probs = jax.vmap(one_log_prob)(y_flat)
        return log_probs.reshape(batch_shape)

312 

313 

# NOTE(review): it is unclear why this does not simply reuse
# gemlib.util.transition_coords.
def _compute_source_states(
    incidence_matrix: np.ndarray, dtype=np.int32
) -> np.ndarray:
    """Return the index of the source state of every transition.

    Args:
      incidence_matrix: NumPy array of shape ``(S, R)`` for ``S`` states
        and ``R`` transitions between states.
      dtype: dtype the resulting index array is cast to.

    Returns:
      A NumPy array of shape ``(R,)`` containing the indices of the source
      states.
    """
    # Work transition-by-transition: the state axis becomes the last axis.
    per_transition = np.transpose(incidence_matrix)

    # Negate and clip to [0, 1]: negative incidence entries become
    # positive, every other entry becomes 0.
    outflow_mask = np.clip(-per_transition, a_min=0, a_max=1)

    # Reverse, exclusive cumulative sum along the state axis, then total
    # it per transition.
    trailing_totals = cumsum_np(
        outflow_mask,
        axis=-1,
        reverse=True,
        exclusive=True,
    )
    state_index = np.sum(trailing_totals, axis=-1)

    return state_index.astype(dtype)

338 ) 

339 return source_states.astype(dtype)