Coverage for gemlib/distributions/discrete_markov.py: 99%

107 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-12-02 22:09 +0000

1"""Functions for chain binomial simulation.""" 

2 

3from collections.abc import Callable, Iterable 

4 

5import jax 

6import jax.numpy as jnp 

7import numpy as np 

8import tensorflow_probability.substrates.jax as tfp 

9 

10from gemlib.math import multiply_no_nan 

11from gemlib.util import transition_coords_tuple 

12 

# Shorthand for the TensorFlow Probability (JAX substrate) distributions.
tfd = tfp.distributions

# Names exported by a wildcard import of this module.
__all__ = [
    "make_transition_prob_matrix_fn",
    "compute_state",
    "discrete_markov_log_prob",
    "discrete_markov_simulation",
]

21 

22 

23# scatter to transition matrix 

24def _scatter_to_transition_matrix( 

25 rates: Iterable[jax.Array], 

26 rate_coords: tuple[tuple[int, int], ...], 

27 num_states: int, 

28) -> jax.Array: 

29 """Build an un-normalised Markov transition rate matrix from rates and 

30 coordinates 

31 

32 Args: 

33 rates (Iterable[jax.Array]): An iterable (tuple, list, ...) of ``R`` 

34 jax.Arrays containing transition rates. 

35 rate_coords (tuple[tuple[int, int],...]): A tuple of 

36 ``(from_state, to_state)`` coordinate pairs for each of the ``R`` 

37 rates. 

38 num_states (int): The number of possible states ``S``. 

39 

40 Returns: 

41 jax.Array: An array of shape ``rates[0].shape + (S, S)`` representing 

42 the un-normalised transition rate matrix. 

43 """ 

44 

45 # The each shape is (...,S) for S compartments 

46 # We're going to scatter the rates into a (S,S,...) data structure 

47 # to avoid strided scatter 

48 matrix_shape = (num_states, num_states) + rates[0].shape 

49 

50 output = jnp.zeros(matrix_shape, dtype=rates[0].dtype) 

51 

52 # We iterate over each rate vector and coordinate pair, inserting the 

53 # transition rates into the coord-th slice of the right-hand two "S" 

54 # dimensions of ``output``.rate_coords 

55 for rate, coord in zip(rates, rate_coords, strict=True): 

56 output = output.at[coord[-2], coord[-1], ...].set( 

57 rate, 

58 indices_are_sorted=True, 

59 unique_indices=True, 

60 mode="promise_in_bounds", 

61 ) 

62 

63 # Transform from (S,S,...) to (...,S,S) 

64 return jnp.moveaxis(output, [0, 1], [-2, -1]) 

65 

66 

def _approx_expm(rates: jax.Array) -> jax.Array:
    """Approximate a Markov transition probability matrix from a rate matrix.

    For each row ``i``, the probability of leaving state ``i`` is
    ``1 - exp(-sum_j rates[..., i, j])``. That probability is split between
    destinations in proportion to their rates, and the diagonal receives
    the remaining probability of staying in state ``i``.

    Args:
        rates (jax.Array): un-normalised square rate matrix (i.e. diagonal
            zero) or batch of such matrices, with shape ``(..., S, S)``.

    Returns:
        jax.Array: Approximation to the Markov transition matrix, of the same
        shape as ``rates``.
    """
    total_rates = jnp.sum(rates, axis=-1, keepdims=True)

    # ``-expm1(-x)`` equals ``1 - exp(-x)`` but does not cancel to exactly
    # zero when ``x`` is tiny relative to 1 (e.g. x < ~6e-8 in float32).
    # Such cancellation would assign zero probability to a possible
    # transition.
    prob = -jnp.expm1(-total_rates)

    # Rows with zero total rate give 0/0 = nan in ``prob / total_rates``;
    # ``multiply_no_nan`` is presumably relied on to return 0 where ``rates``
    # is 0 in that case.
    partial_matrix = multiply_no_nan(prob / total_rates, rates)

    diagonal_values = 1.0 - jnp.sum(partial_matrix, axis=-1)

    # An (S, S) identity mask broadcasts against any leading batch
    # dimensions in ``jnp.where``, so batched input needs no special case.
    mask = jnp.eye(rates.shape[-1], dtype=bool)
    return jnp.where(mask, diagonal_values[..., None], partial_matrix)

98 

99 

def make_transition_prob_matrix_fn(
    transition_rate_fn: Callable[[float, jax.Array], tuple[jax.Array, ...]],
    time_delta: float,
    incidence_matrix: np.ndarray,
) -> Callable[[float, jax.Array], jax.Array]:
    """Create a function computing the discrete-time Markov transition matrix.

    Args:
        transition_rate_fn (Callable[[float, jax.Array], tuple[jax.Array, ...]]):
            maps a scalar time ``t`` and a ``state`` of shape ``(N, S)`` to a
            tuple of ``R`` transition-rate arrays, each of shape ``(N,)``.
        time_delta (float): length of one time step. Rates are multiplied by
            it before being converted to probabilities.
        incidence_matrix (np.ndarray): array of shape ``(S, R)`` describing
            the allowed transitions between compartments.

    Returns:
        Callable[[float, jax.Array], jax.Array]: a function of ``(t, state)``,
        with ``state`` of shape ``(N, S)``, returning a Markov transition
        matrix of shape ``(N, S, S)``.
    """
    # Computed once when the factory is called and captured by the closure.
    coords = transition_coords_tuple(incidence_matrix)

    def _transition_prob_matrix(t: float, state: jax.Array) -> jax.Array:
        # rate_matrix[n, i, j] holds the rate of moving from state i to
        # state j in unit n, for the (i, j) pairs given by ``coords``.
        rate_matrix = _scatter_to_transition_matrix(
            transition_rate_fn(t, state), coords, state.shape[-1]
        )
        return _approx_expm(rate_matrix * time_delta)

    return _transition_prob_matrix

140 

141 

def _multinomial_log_prob_impl(total_count, probs, counts):
    """Shared value computation for ``_multinomial_log_prob`` and its fwd."""
    log_unnorm_prob = jnp.sum(multiply_no_nan(jnp.log(probs), counts), axis=-1)
    neg_log_normalizer = tfp.math.log_combinations(total_count, counts)
    return log_unnorm_prob + neg_log_normalizer


@jax.custom_vjp
def _multinomial_log_prob(total_count, probs, counts):
    """Multinomial log-probability of ``counts`` over the last axis.

    The value is ``sum(counts * log(probs), axis=-1)`` plus
    ``log_combinations(total_count, counts)``. Per TFP's documented
    ``log_combinations``, the latter term is the log multinomial coefficient.
    ``multiply_no_nan`` is used for the first term, presumably so that
    zero counts paired with zero probabilities contribute 0 rather than nan.
    A custom VJP supplies the gradient with respect to ``probs`` only.
    """
    return _multinomial_log_prob_impl(total_count, probs, counts)


# Forward pass: the value plus the residuals needed by the backward pass.
def _multinomial_log_prob_fwd(total_count, probs, counts):
    value = _multinomial_log_prob_impl(total_count, probs, counts)
    return value, (probs, counts)


# Backward pass: d/dprobs of sum(counts * log(probs)) is counts / probs; the
# normaliser does not depend on ``probs``. ``nan_to_num`` replaces the inf
# produced by 1/0 with the largest finite value of the dtype.
def _multinomial_log_prob_bwd(res, g):
    probs, counts = res
    counts = jnp.broadcast_to(counts, probs.shape)
    inv_probs = jnp.nan_to_num(1.0 / probs)
    grad_probs = inv_probs * counts * jnp.expand_dims(g, axis=-1)
    # ``None`` marks a zero cotangent for ``total_count`` and ``counts``.
    return (None, grad_probs, None)


# Register the custom forward and backward passes with JAX.
_multinomial_log_prob.defvjp(
    _multinomial_log_prob_fwd, _multinomial_log_prob_bwd
)

172 

173 

def compute_state(
    initial_state: jax.Array,
    events: jax.Array,
    incidence_matrix: np.ndarray,
    closed: bool = False,
):
    """Compute the state array at multiple points in time from the initial
    state and event array.

    Args:
        initial_state (jax.Array): Initial state with shape ``(N, S)``.
        events (jax.Array): Time series of events with shape ``(T, N, R)``,
            optionally with leading batch dimensions.
        incidence_matrix (np.ndarray): a matrix with shape ``(S, R)``
            describing how transitions update the state.
        closed (bool): if ``True``, return state at ``0..T``, otherwise
            ``0..T-1``.

    Returns:
        jax.Array: Array of shape ``(T, N, S)`` if ``closed=False`` or
        ``(T+1, N, S)`` if ``closed=True``, describing the state of the
        system at times ``0..T-1`` or ``0..T``, respectively. ``T`` may be
        zero, in which case the closed result holds only the initial state.
    """
    # Per-timestep change in each compartment, shape (..., T, N, S).
    increments = jnp.einsum("...tmr,sr->...tms", events, incidence_matrix)

    # Zero increment for time 0. It is built from the shape explicitly,
    # rather than by slicing ``increments``, so that it still has length 1
    # along the time axis when ``T == 0``.
    padding = jnp.zeros(
        increments.shape[:-3] + (1,) + increments.shape[-2:],
        dtype=increments.dtype,
    )
    cum_increments = jnp.concatenate(
        (padding, jnp.cumsum(increments, axis=-3)), axis=-3
    )
    if not closed:
        # Drop the state after the final timestep.
        cum_increments = cum_increments[..., :-1, :, :]

    return cum_increments + jnp.expand_dims(initial_state, axis=-3)

210 

211 

def chain_binomial_propagate(
    transition_matrix_fn: Callable[[float, jax.Array], jax.Array],
) -> Callable[[float, jax.Array, jax.Array], tuple[jax.Array, jax.Array]]:
    """Build a one-step stochastic propagator for discrete-time dynamics.

    Args:
        transition_matrix_fn: a function ``(t: float, state: Array) -> Array``
            taking a time ``t`` and a state of shape ``(N, S)`` and returning
            a Markov transition probability matrix of shape ``(N, S, S)``.
            ``make_transition_prob_matrix_fn`` generates a suitable function.

    Returns:
        A function ``(t, state, seed) -> (events, new_state)``. ``seed`` is
        passed to ``jax.random.split``, so it must be a JAX PRNG key.
        ``events`` has shape ``(N, S, S)``, with ``events[n, i, j]`` the
        number moving from state ``i`` to state ``j`` in unit ``n``.
        ``new_state`` has shape ``(N, S)`` and is the state at the next step.
    """

    def propagate_fn(t, state, seed):
        matrix = transition_matrix_fn(t, state)
        n_states = matrix.shape[-1]

        # Count not yet allocated in each row, and probability mass already
        # used by destination columns before the current one.
        remaining = state.astype(matrix.dtype)
        used_probs = jnp.zeros_like(matrix[..., :, 0])

        # One independent key per sequential binomial draw.
        subkeys = jax.random.split(seed, n_states - 1)

        # Each row is sampled as a chain of conditional binomials. For each
        # destination column i except the last, draw from the remaining count
        # with probability p_i / (1 - sum_{j<i} p_j).
        columns = []
        for i in range(n_states - 1):
            col_probs = matrix[..., :, i]
            cond_probs = jnp.clip(
                col_probs / (1.0 - used_probs + 1e-10), 0.0, 1.0
            )
            draw = tfd.Binomial(
                total_count=remaining, probs=cond_probs
            ).sample(seed=subkeys[i])
            columns.append(draw)
            remaining = remaining - draw
            used_probs = used_probs + col_probs

        # Whatever is left in each row goes to the last destination column.
        columns.append(remaining)
        counts = jnp.stack(columns, axis=-1)

        # New occupancy of state j is the sum over source states i.
        return counts, jnp.sum(counts, axis=-2)

    return propagate_fn

269 

270 

def discrete_markov_simulation(
    transition_prob_matrix_fn: Callable[[float, jax.Array], jax.Array],
    state: jax.Array,
    start: float,
    end: float,
    time_step: float,
    seed: jax.Array,
):
    """Simulates from a discrete time Markov state transition model using
    multinomial sampling across rows of the transition matrix.

    Args:
        transition_prob_matrix_fn (Callable[[float, jax.Array], jax.Array]):
            A function that takes a scalar ``time`` and state of shape
            ``(N, S)`` to a Markov transition matrix of shape ``(N, S, S)``.
            Here ``N`` is the number of units and ``S`` the dimension of the
            state of each unit.
        state (jax.Array): Initial state of shape ``(N, S)``. A 1-D state of
            shape ``(S,)`` is treated as a single unit.
        start (float): Start time of simulation.
        end (float): End time of simulation.
        time_step (float): Duration of discrete timestep.
        seed (jax.Array) : Random seed (PRNG key) for sampling.

    Returns:
        tuple[jax.Array, jax.Array]: A vector of shape ``(T,)`` of the times
        of the start of each timestep, and an array of shape
        ``(T, N, S, S)`` containing the transitions within each timestep.
        ``times`` has the dtype of ``state``, except that an integer state
        with a fractional ``start`` or ``time_step`` yields float times.
    """
    state = jnp.asarray(state)
    if state.ndim == 1:
        state = state[None, ...]

    propagate = chain_binomial_propagate(transition_prob_matrix_fn)

    # ``jnp.arange`` with an integer dtype and a fractional start/step
    # delegates to ``np.arange``, which produces wrong (truncated) values.
    # Use the default float dtype in that case only, so that existing
    # callers keep their time dtype.
    time_dtype = state.dtype
    if not jnp.issubdtype(time_dtype, jnp.floating) and not (
        float(start).is_integer() and float(time_step).is_integer()
    ):
        time_dtype = jnp.result_type(float)

    times = jnp.arange(start, end, time_step, dtype=time_dtype)
    keys = jax.random.split(seed, times.shape[0])

    def scan_fn(carry_state, elems):
        t, current_key = elems
        event_counts, new_state = propagate(t, carry_state, current_key)
        # ``jax.lax.scan`` requires the carry to keep a fixed dtype.
        return new_state.astype(carry_state.dtype), event_counts

    _, all_events = jax.lax.scan(scan_fn, state, (times, keys))

    return times, all_events

318 

319 

def discrete_markov_log_prob(
    events: jax.Array,
    init_state: jax.Array,
    init_step: float,
    time_delta: float,
    transition_prob_matrix_fn: Callable[[float, jax.Array], jax.Array],
    incidence_matrix: np.ndarray,
) -> jax.Array:
    """Log-probability of a discrete-time event series under a Markov model.

    The occupants of each state at the start of each timestep are scored as
    a multinomial draw across the corresponding row of that timestep's
    transition probability matrix. The result sums these log-probabilities
    over timesteps, units and source states.

    Args:
        events (jax.Array): ``(T, N, R)`` array of transition events, for
            ``T`` timesteps, ``N`` units and ``R`` transitions.
        init_state (jax.Array): ``(N, S)`` array giving the initial state of
            the epidemic, where ``S`` is the number of states.
        init_step (float): the initial time.
        time_delta (float): the duration of the discrete time step.
        transition_prob_matrix_fn (Callable[[float, jax.Array], jax.Array]):
            maps a scalar time and an ``(N, S)`` state to an ``(N, S, S)``
            Markov transition matrix.
        incidence_matrix (np.ndarray): ``(S, R)`` matrix describing how each
            transition updates the state.

    Returns:
        jax.Array: A scalar log probability for these events given the
        initial state and model.
    """
    n_steps = events.shape[-3]
    n_states = init_state.shape[-1]

    # State at the start of every timestep, shape ``(T, N, S)``.
    states = compute_state(init_state, events, incidence_matrix)

    # Transition probability matrix for every timestep, ``(T, N, S, S)``.
    step_times = init_step + time_delta * jnp.arange(n_steps)
    probs = jax.vmap(transition_prob_matrix_fn)(step_times, states)

    # ``counts[t, n, i, j]`` is the number moving from state ``i`` to state
    # ``j`` during timestep ``t`` in unit ``n``. The scatter fills only the
    # entries named by the incidence matrix's transitions.
    counts = _scatter_to_transition_matrix(
        jnp.unstack(events, axis=-1),
        transition_coords_tuple(incidence_matrix),
        n_states,
    )

    # Those that do not leave state ``i`` stay in it: fill the diagonal.
    stayers = states - counts.sum(axis=-1)
    diag = jnp.arange(counts.shape[-1])
    counts = counts.at[..., diag, diag].set(stayers)

    return _multinomial_log_prob(states, probs, counts).sum()