Coverage for gemlib/distributions/continuous_markov.py: 100%

87 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-12-02 22:09 +0000

1"""Function for continuous time simulation""" 

2 

3from __future__ import annotations 

4 

5import math as pymath 

6from collections.abc import Callable 

7from typing import Any, NamedTuple 

8 

9import jax 

10import jax.numpy as jnp 

11import numpy as np 

12import tensorflow_probability.substrates.jax as tfp 

13 

14from gemlib.util import batch_gather, transition_coords 

15 

# Module-level aliases, kept short for convenience.
# `ArrayLike`: JAX's annotation for any value usable as array input.
ArrayLike = jax.typing.ArrayLike
# `tfd`: shorthand for the distributions namespace of the TFP JAX substrate.
tfd = tfp.distributions
# `DTYPE`: float32 dtype constant.
# NOTE(review): not referenced in the visible part of this file -- confirm
# whether external callers rely on it before removing.
DTYPE = jnp.float32

20 

21 

class EventList(NamedTuple):
    """Record of one or more events in an epidemic simulation.

    The three fields are indexed in parallel.  Each may hold a scalar
    (a single event, as produced by one propagation step) or an array
    (a sequence of events, as returned by the simulator).

    Attributes:
        time: The time at which each event occurred.
        transition: The index of the transition type that occurred.
        unit: The index of the unit involved in the event.
    """

    time: Any
    transition: Any
    unit: Any

34 

35 

def _one_hot_expand_state(condensed_state: jax.Array) -> jax.Array:
    """Expand per-state counts into one one-hot row per individual.

    Args:
        condensed_state: a length-`S` vector of non-negative integer counts,
            one entry per state.

    Returns:
        A `[sum(condensed_state), S]` float32 array.  Row blocks appear in
        state order: the first `condensed_state[0]` rows are one-hot for
        state 0, the next `condensed_state[1]` rows for state 1, and so on.
    """
    num_states = len(condensed_state)

    # One template row per state; this is the S x S identity matrix.
    one_hot_states = jax.nn.one_hot(
        jnp.arange(num_states),
        num_classes=num_states,
        dtype=jnp.float32,
    )

    # Repeat template row `s` exactly `condensed_state[s]` times.
    return jnp.repeat(one_hot_states, condensed_state, axis=0)

54 

55 

def _total_flux(transition_rates, state, incidence_matrix):
    """Scale per-unit transition rates by the occupancy of each source state.

    Args:
        transition_rates: a length-`R` sequence of `[..., N]` arrays of
            per-unit transition rates, one array per transition.  They are
            stacked on a new trailing axis to give `[..., N, R]`.
        state: a `[..., N, S]` tensor of `N` aggregation units and `S` states.
        incidence_matrix: a `[S, R]` matrix describing the change in `S` for
            each transition `R`.  It is converted with `np.array`, so it must
            be a concrete (non-traced) value.

    Returns:
        A `[..., R, N]` tensor of total flux along each transition, taking
        into account the availability of units in the source state.
    """
    # NOTE(review): `transition_coords` presumably returns one
    # (source, destination) state-index pair per transition; column 0 is
    # used here as the source state -- confirm against gemlib.util.
    source_state_idx = transition_coords(np.array(incidence_matrix))[:, 0]

    # NOTE(review): `batch_gather` is assumed to select, for every unit, the
    # count in each transition's source state, giving `[..., N, R]` --
    # confirm against gemlib.util.
    source_states = batch_gather(
        state, indices=source_state_idx[:, jnp.newaxis]
    )

    # R arrays of [..., N]  ->  one [..., N, R] array.
    stacked_rates = jnp.stack(transition_rates, axis=-1)

    # Element-wise product, with the last two axes swapped to [..., R, N].
    return jnp.einsum("...nr,...nr->...rn", stacked_rates, source_states)

81 

82 

def compute_state(
    incidence_matrix: np.ndarray,
    initial_state: jax.Array,
    event_list: EventList,
    include_final_state: bool = False,
):
    """Given an event list `event_list`, compute a timeseries
    of state given the model.

    Args:
        incidence_matrix: a `[S,R]` graph incidence matrix for `S`
            compartments and `R` transitions.
        initial_state: a `[N,S]` array representing the initial state of
            `N` units by `S` compartments.
        event_list: the event list, assumed to be sorted by time.  Its
            fields have shape `[..., T]`; any leading dimensions are
            treated as batch dimensions.
        include_final_state: should the final state be included in the
            returned timeseries?  If truthy, the time dimension of the
            returned tensor is 1 greater than the length of the event
            list.  If falsy (default) they are equal.

    Returns:
        A `[..., T, N, S]` tensor (`[..., T + 1, N, S]` when
        `include_final_state` is set) whose entry `t` is the state
        immediately before event `t`; entry 0 is `initial_state`.
    """
    # Rebuild the same tuple type with every field as a JAX array.
    event_list = event_list.__class__(*[jnp.array(x) for x in event_list])
    incidence_matrix = jnp.array(incidence_matrix)

    num_times = event_list.time.shape[-1]
    num_units = initial_state.shape[-2]
    # One extra column catches "ghost" (padding) events, whose transition
    # index is R; that column is sliced off again below.
    num_reactions = incidence_matrix.shape[-1] + 1

    # Batch dimensions of the event list are flattened into a single axis
    # so the dense event tensor always has shape [B, T, N, R].
    batch_dims = event_list.time.shape[:-1]
    flat_batch_shape = int(pymath.prod(batch_dims))

    event_tensor_shape = (
        flat_batch_shape,
        num_times,
        num_units,
        num_reactions,
    )  # [B, T, N, R + 1]

    # One (b, t, n, r) coordinate per event.
    hot_indices = jnp.stack(
        [
            jnp.repeat(jnp.arange(flat_batch_shape), num_times),  # B 0,0,..,1,1
            jnp.tile(jnp.arange(num_times), flat_batch_shape),  # T 0,1,..,0,1
            event_list.unit.reshape(-1),  # N (flat_batch_shape * num_times)
            event_list.transition.reshape(-1),  # R (same length)
        ],
        axis=-1,
    )

    # Scatter a 1 at each event coordinate, then drop the ghost column.
    # NOTE(review): ghost events may carry an out-of-range unit index; this
    # relies on JAX's default scatter mode skipping out-of-bounds updates.
    event_tensor = (
        jnp.zeros(event_tensor_shape, dtype=initial_state.dtype)
        .at[tuple(hot_indices.T)]
        .add(1.0)[..., :-1]
    )

    # Per-event state change: [B, T, N, R] @ [R, S] -> [B, T, N, S].
    delta = jnp.matmul(event_tensor, incidence_matrix.T)
    # Truthiness rather than `is False`, so `np.False_` / `0` behave like
    # the Python literal instead of silently including the final state.
    if not include_final_state:
        delta = delta[..., :-1, :, :]

    batched_initial_state = jnp.broadcast_to(
        initial_state[jnp.newaxis, jnp.newaxis, ...],  # add B and T dims
        shape=(flat_batch_shape,) + (1,) + initial_state.shape,
    )

    # Prepend the initial state and accumulate the deltas along time.
    state = jnp.cumsum(
        jnp.concatenate([batched_initial_state, delta], axis=-3), axis=-3
    )

    # Restore the caller's original batch dimensions.
    return jnp.reshape(state, batch_dims + state.shape[-3:])

167 

168 

def exponential_propogate(
    transition_rate_fn: Callable, incidence_matrix: jax.Array
) -> Callable:
    """Generates a function for propagating an epidemic forward in time

    Closure over the transition rate function and the incidence matrix
    which outline the epidemic dynamics and model structure. The returned
    function can be used to simulate the epidemic forward in time one step.

    Args:
        transition_rate_fn (Callable): a function called as
            `transition_rate_fn(time, state)` that returns the transition
            rates for each unit/meta-population.
        incidence_matrix (tensor): A `[S, R]` matrix that describes the graph
            structure of the state transition model.  The rows correspond
            to the `S` states and the columns to the `R` transitions.

    Returns:
        Callable: `propogate_fn(time, state, seed)`, which returns the tuple
        `(next_time, next_state, event)` where `event` is an `EventList`
        describing the single event that occurred.
    """
    # Row `r` of the transpose is the `[S]` state change of transition `r`.
    tr_incidence_matrix = incidence_matrix.T

    def propogate_fn(time: float, state: jax.Array, seed: jax.Array) -> tuple:
        """Propagates the state of the epidemic forward by one event

        Args:
            time (float): Wall clock of the epidemic - can easily recover
                the time delta
            state (tensor): `[N,S]` representing the current state.
            seed: a PRNG seed accepted by `tfp.random.split_seed` (the
                simulation loop passes a JAX PRNG key).

        Returns:
            tuple: `(time + dt, new_state, EventList(time + dt, transition,
            unit))` for the sampled event.
        """
        # Independent randomness for the waiting time and the event choice.
        seed_exp, seed_cat = tfp.random.split_seed(seed, n=2)
        num_units = state.shape[-2]

        # Event rates for all possible (transition, unit) pairs: [R, N].
        transition_rates = _total_flux(
            transition_rate_fn(time, state), state, incidence_matrix
        )

        # Waiting time to the next event, at the total event rate.
        t_next = tfd.Exponential(rate=jnp.sum(transition_rates)).sample(
            seed=seed_exp
        )

        # Choose which event fires, with probability proportional to its
        # rate, over the row-major flattening of the [R, N] rate table.
        event_id = tfd.Categorical(
            probs=jnp.reshape(transition_rates, (-1,)),
            dtype=jnp.int32,
        ).sample(seed=seed_cat)

        # Invert the flattening: event_id == transition_idx * N + unit_idx.
        unit_idx = jnp.mod(event_id, num_units)
        transition_idx = jnp.floor_divide(event_id, num_units)

        # Apply the chosen transition's state change to the chosen unit.
        new_state = state.at[unit_idx].add(tr_incidence_matrix[transition_idx])

        return (
            time + t_next,
            new_state,
            EventList(time + t_next, transition_idx, unit_idx),
        )

    return propogate_fn

235 

236 

def continuous_markov_simulation(
    transition_rate_fn: Callable,
    initial_state: jax.Array,
    incidence_matrix: jax.Array,
    num_markov_jumps: int,
    initial_time: float = 0.0,
    seed: jax.Array | None = None,
) -> EventList:
    """
    Simulates a continuous-time Markov process

    Args:
        transition_rate_fn (Callable): A function called as
            `transition_rate_fn(time, state)` that computes the transition
            rates given the current time and state.
        initial_state (Tensor): A [N, S] tensor, representing a population
            of N units and S states.
        incidence_matrix (Tensor): The `[S,R]` incidence matrix representing
            the state transition model with S states and R transitions.
        num_markov_jumps (int): The maximum number of events to simulate.
        initial_time (float): The time at which the simulation starts.
        seed: A JAX PRNG key (it is split with `jax.random.split`).
            Required, despite having a default for signature compatibility.

    Returns:
        EventList: The simulated events, each field of length
        `num_markov_jumps`.  If the total event rate reaches zero before
        `num_markov_jumps` events have occurred, the remaining entries are
        padding: `time=inf`, `transition=R` and `unit=int32 max`.

    Raises:
        ValueError: if `seed` is not supplied.
    """
    if seed is None:
        raise ValueError("`seed` must be provided (a JAX PRNG key)")

    dtype = initial_state.dtype
    ghost_transition = incidence_matrix.shape[1]  # one past the last index
    ghost_unit = jnp.iinfo(np.int32).max

    propagate_fn = exponential_propogate(transition_rate_fn, incidence_matrix)

    # Fixed-size output buffers, pre-filled with the padding sentinels.
    accum = EventList(
        time=jnp.ones((num_markov_jumps,), dtype=dtype) * jnp.inf,
        transition=jnp.full(
            (num_markov_jumps,), ghost_transition, dtype=jnp.int32
        ),
        unit=jnp.full((num_markov_jumps,), ghost_unit, dtype=jnp.int32),
    )

    def cond(args):
        # Continue while there is buffer space and some event can occur.
        i, time, state, seed, accum = args
        transition_rates = _total_flux(
            transition_rate_fn(time, state), state, incidence_matrix
        )
        return (i < num_markov_jumps) & (jnp.sum(transition_rates) > 0.0)

    def body(args):
        i, time, state, seed, accum = args
        # Carry one half of the key forward; spend the other on this step.
        next_seed, subkey = jax.random.split(seed)
        next_time, next_state, event = propagate_fn(time, state, subkey)
        # Write the new event into slot `i` of every buffer.
        accum = jax.tree_util.tree_map(
            lambda acc, x: acc.at[i].set(x), accum, event
        )

        return (
            i + 1,
            next_time,
            next_state,
            next_seed,
            accum,
        )

    actual_markov_jumps, _, _, _, accum = jax.lax.while_loop(
        cond, body, (0, initial_time, initial_state, seed, accum)
    )

    times, transitions, units = accum

    # Re-apply the padding to every slot at or beyond the number of events
    # actually simulated, in case the loop terminated early.
    mask = jnp.arange(num_markov_jumps) < actual_markov_jumps
    output = EventList(
        time=jnp.where(mask, times, jnp.inf),
        transition=jnp.where(mask, transitions, ghost_transition),
        unit=jnp.where(mask, units, ghost_unit),
    )

    return output

314 

315 

def continuous_time_log_likelihood(
    transition_rate_fn: Callable,
    incidence_matrix: jax.Array,
    initial_state: jax.Array,
    initial_time: float,
    event_list: EventList,
) -> float:
    """
    Computes the log-likelihood of a continuous-time Markov process
    given the transition rate function, incidence matrix, initial state,
    initial time and an observed event list.

    Args:
        transition_rate_fn (Callable): A function called as
            `transition_rate_fn(time, state)` that computes the
            transition rates given the current time and state.
        incidence_matrix: The incidence matrix representing
            the connections between states in `[S,R]` format.
        initial_state: The initial state of the process as a `[N,S]` array.
        initial_time (float): The time at which the process starts.
        event_list (EventList): The observed events, each field of length
            `T`.  Entries with non-finite time are treated as padding and
            contribute zero.

    Returns:
        Tensor: The scalar log-likelihood, i.e. the sum over events of
        `-total_rate * time_delta + log(event_rate)`.
    """
    # State immediately before each event: [T, N, S].
    states = compute_state(
        incidence_matrix=incidence_matrix,
        initial_state=initial_state,
        event_list=event_list,
    )

    # Interval start times are `time[:-1]`; interval end times `time[1:]`.
    time = jnp.concatenate(
        [jnp.array([initial_time]), event_list.time], axis=-1
    )
    rates = jax.vmap(transition_rate_fn)(
        time[:-1], states
    )  # R-tuple of [T,N] arrays

    total_flux = _total_flux(rates, states, incidence_matrix)  # [T, R, N]

    # Coordinates of the event that actually happened in each interval.
    event_idx = jnp.arange(event_list.time.shape[0])
    transition_idx = jnp.clip(  # clip ghost events (values zeroed later)
        event_list.transition,
        min=0,
        max=incidence_matrix.shape[-1] - 1,
    )
    # NOTE(review): ghost events may carry an out-of-range unit index; this
    # relies on JAX's default gather clamping out-of-bounds indices.
    unit_idx = jnp.asarray(event_list.unit)

    # Rate of the observed event, and total rate of all possible events.
    event_rate = total_flux[event_idx, transition_idx, unit_idx]
    total_rate = jnp.sum(total_flux, axis=(-2, -1))

    time_delta = time[1:] - time[:-1]

    # Padded entries have non-finite time deltas.  Substitute benign values
    # *before* the arithmetic so that `inf * rate` and `log(rate)` are never
    # evaluated there: masking only afterwards yields the same forward value
    # but lets NaN leak into reverse-mode gradients.
    is_real_event = jnp.isfinite(time_delta)
    safe_time_delta = jnp.where(is_real_event, time_delta, 0.0)
    safe_event_rate = jnp.where(is_real_event, event_rate, 1.0)

    loglik_t = -total_rate * safe_time_delta + jnp.log(safe_event_rate)

    # Zero out padded entries (also covers a non-finite `total_rate` there).
    loglik_t = jnp.where(is_real_event, loglik_t, jnp.zeros_like(loglik_t))

    return jnp.sum(loglik_t)