Coverage for gemlib/mcmc/discrete_time_state_transition_model/move_events_impl.py: 96%

136 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-12-02 22:09 +0000

1"""Constrained Event Move kernel for DiscreteTimeStateTransitionModel""" 

2 

3from collections.abc import Callable 

4from typing import NamedTuple, TypeVar 

5from warnings import warn 

6 

7import jax.numpy as jnp 

8import numpy as np 

9import tensorflow_probability.substrates.jax as tfp 

10 

11from gemlib.distributions.discrete_markov import compute_state 

12from gemlib.distributions.kcategorical import UniformKCategorical 

13from gemlib.distributions.uniform_integer import UniformInteger 

14from gemlib.mcmc.sampling_algorithm import ( 

15 Position, 

16) 

17from gemlib.prng_util import sanitize_key 

18from gemlib.util import transition_coords 

19 

20tfd = tfp.distributions 

21mcmc_util = tfp.mcmc.internal.util 

22 

23__all__ = ["UncalibratedEventTimesUpdate"] 

24 

25Tensor = TypeVar("Tensor") 

26 

27 

28# Proposal mechanism 

def events_state_count_bounding_fn(dmax):
    """Create a function bounding the number of movable transition events.

    Args
    ----
    dmax: maximum possible extent in time

    Returns
    -------
    A callable ``fn(events, src_state, dest_state, timepoints, delta)``
    where `events` is a [T, M] tensor of transition events, `src_state`
    and `dest_state` are [T, M] tensors of the corresponding source and
    destination states, `timepoints` is a [M] tensor of timepoints, and
    `delta` is a [M] tensor of time deltas with `abs(delta) <= dmax`.
    The callable returns a [M] tensor of the maximum number of
    transition events that can be moved from `time` to `time+delta`
    without violating `src_state >= 0` and `dest_state >= 0`.
    """

    def fn(events, src_state, dest_state, timepoints, delta):
        # Offsets 0..dmax-1 used to build the window of affected times
        offsets = jnp.arange(dmax, dtype=np.int32)
        t_col = timepoints[:, np.newaxis]
        backward_window = t_col - offsets[np.newaxis, :]
        forward_window = t_col + offsets[np.newaxis, :] + 1
        # [M, dmax]: times whose state constrains a move of sign(delta)
        window = jnp.where(
            delta[:, np.newaxis] < 0, backward_window, forward_window
        )
        # Keep indices inside the valid time extent
        window = jnp.clip(window, min=0, max=events.shape[-2] - 1)

        # Pair each window time with its unit index for gathering
        num_units = events.shape[-1]
        unit_idx = jnp.broadcast_to(
            jnp.arange(num_units)[:, np.newaxis], window.shape
        )
        gather_idx = (window, unit_idx)

        src_vals = src_state[gather_idx]
        dest_vals = dest_state[gather_idx]

        # A backward move is limited by the source state, a forward move
        # by the destination state; take the tightest bound in the window
        per_time_bound = jnp.where(
            delta[:, np.newaxis] < 0, src_vals, dest_vals
        )
        state_bound = per_time_bound.min(axis=-1)

        # Cannot move more events than exist at the chosen timepoint
        available = events[timepoints, np.arange(num_units)]
        return jnp.minimum(state_bound, available)

    return fn

86 

87 

88def _is_within(x, low, high): 

89 """Tests `low <= x < high`""" 

90 return jnp.logical_and(jnp.greater_equal(x, low), jnp.less(x, high)) 

91 

92 

def _timepoint_selector(events, name):
    """Distribution over timepoints, uniform over times with events.

    For each unit, every timepoint with at least one event receives
    equal probability; timepoints with no events receive zero.
    Assumes each unit has at least one event (callers pre-filter units),
    otherwise the normalisation divides by zero.
    """
    mask = (events > 0.0).astype(events.dtype)
    # Normalise per-unit over time, then transpose to [M, T] batch shape
    probs = (mask / mask.sum(axis=-2, keepdims=True)).T
    per_unit = tfd.Categorical(probs=probs)
    return tfd.Independent(
        per_unit,
        reinterpreted_batch_ndims=1,
        name=name,
    )

102 

103 

def _delta_selector(delta_max, timepoints, events, name):
    """Distribution over nonzero time shifts in ``[-delta_max, delta_max]``.

    Shifts that would carry a timepoint outside ``[0, T)`` get zero
    probability; the remaining shifts are equally likely.
    """
    deltas = jnp.concatenate(
        [
            np.arange(-delta_max, 0, dtype=np.int32),
            np.arange(1, delta_max + 1, dtype=np.int32),
        ],
        axis=0,
    )
    # Mask out shifts that would land outside the time axis
    shifted = deltas[..., np.newaxis, :] + timepoints[..., np.newaxis]
    valid = _is_within(
        shifted,
        low=0,
        high=events.shape[-2],
    ).astype(events.dtype)

    probs = valid / valid.sum(axis=-1, keepdims=True)

    return tfd.Independent(
        tfd.FiniteDiscrete(outcomes=deltas, probs=probs),
        reinterpreted_batch_ndims=1,
        name=name,
    )

125 

126 

def _make_event_count_selector(src_dest_ids, count_bounding_fn, count_max):
    """Return a callable producing a distribution over event counts.

    The callable takes ``events``, ``state``, ``timepoints`` and
    ``delta`` and yields a uniform-integer distribution on
    ``[1, bound]`` where ``bound`` is the per-unit maximum number of
    movable events, capped at ``count_max``.
    """

    def fn(events, state, timepoints, delta):
        # State-derived per-unit upper bound on movable events
        bound = count_bounding_fn(
            events=events,
            src_state=state[:, :, src_dest_ids[0]],
            dest_state=state[:, :, src_dest_ids[1]],
            timepoints=timepoints,
            delta=delta,
        )
        # Never propose more than the configured absolute maximum
        bound = jnp.minimum(bound, count_max)

        counts = UniformInteger(
            low=np.int32(1),
            high=bound.astype(np.int32) + 1,
            float_dtype=events.dtype,
        )
        return tfd.Independent(
            counts,
            reinterpreted_batch_ndims=1,
            name="event_count",
        )

    return fn

151 

152 

def discrete_move_events_proposal(
    incidence_matrix: Tensor,
    target_transition_id: int,
    num_units: int,
    delta_max: int,
    count_max: int,
    initial_conditions: Tensor,
    events: Tensor,
    count_bounding_fn: Callable[
        [Tensor, Tensor, Tensor, Tensor, Tensor], Tensor
    ],
    name=None,
):
    """Build a proposal mechanism for moving discrete transition events

    The proposal draws, in order: a set of `num_units` units that have
    events of the target transition, a timepoint per unit, a time shift
    per unit, and a number of events to move per unit.

    Args
    ----
    incidence_matrix: a `[S,R]` matrix representing the state transition model
    target_transition_id: the id (column) of the transition to update
    num_units: number of units to update simultaneously
    delta_max: maximum time delta through which to move transitions
    count_max: absolute maximum number of transition events to move
    initial_conditions: a `[M,S]` matrix representing `M` units and `S` states
    events: a `[T, M, R]` tensor giving the number of events that occur for each
            time $t=0,...,T-1$, unit $m=0,...,M-1$, and transition $r=0,...,R-1$
    count_bounding_fn: a callable taking `events`, `src_state`, `dest_state`,
                       `timepoints`, and `delta_t` and returning tensor of
                       length `num_units` of the maximum possible events that
                       can be moved.
    name: optional name for the resulting joint distribution

    Returns
    -------
    An instance of `tfd.JointDistributionCoroutine` which samples a proposal
    """
    events_ = jnp.asarray(events)
    initial_conditions_ = jnp.asarray(initial_conditions)

    def proposal():
        target_events = events_[..., target_transition_id]
        # NOTE(review): passes the original `events` argument here rather
        # than the converted `events_` — presumably compute_state accepts
        # either; confirm intentional.
        state = compute_state(initial_conditions_, events, incidence_matrix)

        # (source, destination) state indices for the target transition
        src_dest_ids = transition_coords(incidence_matrix)[target_transition_id]

        # Select units to update if they have events; Root marks this as
        # the first, unconditioned draw of the coroutine
        nonzero_units = jnp.any(target_events > 0, axis=-2)
        units = yield tfd.JointDistribution.Root(
            UniformKCategorical(
                k=num_units,
                mask=nonzero_units,
                float_dtype=events_.dtype,
                name="unit",
            )
        )

        # Select out the units of interest
        events_subset = target_events[..., units]
        state_subset = state[..., units, :]

        # Select timepoint to update
        timepoints = yield _timepoint_selector(events_subset, "timepoint")

        # Select distance to move, clipping if a delta would move us out of
        # the valid range [0, events_subset.shape[-2])
        delta = yield _delta_selector(
            delta_max, timepoints, events_subset, "delta"
        )

        # Compute number of events to move
        count_selector = _make_event_count_selector(
            src_dest_ids, count_bounding_fn, count_max
        )
        yield count_selector(events_subset, state_subset, timepoints, delta)

    return tfd.JointDistributionCoroutine(proposal, name=name)

227 

228 

class EventTimesKernelResults(NamedTuple):
    """Auxiliary results from `UncalibratedEventTimesUpdate.one_step`."""

    # Metropolis-Hastings correction: log q(reverse) - log q(forward)
    log_acceptance_correction: float
    # Target log density evaluated at the proposed events
    target_log_prob: float
    # Sample drawn from the forward proposal distribution
    fwd_proposed_move: NamedTuple
    # The move that would undo `fwd_proposed_move`
    rev_proposed_move: NamedTuple
    # PRNG seed used for this step
    # NOTE(review): `sanitize_key` may return a JAX PRNG key rather than
    # list[int] — confirm this annotation against prng_util.
    seed: list[int]

235 

236 

237def _apply_move(event_tensor, event_id, move): 

238 """Apply `move` to `event_tensor` 

239 

240 Args 

241 ---- 

242 event_tensor: shape [T, M, R] 

243 event_id: the event id to move 

244 move: a data structure with fields ["unit", "timepoint", "delta", 

245 "counts"] 

246 

247 Returns 

248 ------- 

249 a copy of `event_tensor` with `move` applied 

250 """ 

251 event_tensor = jnp.array(event_tensor) 

252 

253 # Subtract `count` from the `event_tensor[timepoint, :, event_id]` 

254 new_state = event_tensor.at[move.timepoint, move.unit, event_id].subtract( 

255 move.event_count 

256 ) 

257 

258 # Add `count` to [move.timpoint+delta, :, event_id] 

259 new_state = new_state.at[ 

260 move.timepoint + move.delta, move.unit, event_id 

261 ].add(move.event_count) 

262 

263 return new_state 

264 

265 

266def _reverse_move(move): 

267 return move._replace( 

268 timepoint=move.timepoint + move.delta, delta=-move.delta 

269 ) 

270 

271 

class UncalibratedEventTimesUpdate(tfp.mcmc.TransitionKernel):
    """UncalibratedEventTimesUpdate

    An uncalibrated MCMC transition kernel that proposes moving events of
    one transition type through time.  "Uncalibrated" means no
    accept/reject step is applied here (`is_calibrated` is False); wrap
    this kernel in an accept/reject scheme such as
    `tfp.mcmc.MetropolisHastings` to obtain a valid sampler.
    """

    def __init__(
        self,
        target_log_prob_fn: Callable[[Position], float],
        incidence_matrix: Tensor,
        initial_conditions: Tensor,
        target_transition_id: int,
        num_units: int,
        delta_max: int,
        count_max: int,
        name: str | None = None,
    ):
        """An uncalibrated random walk for event times.

        :param target_log_prob_fn: the log density of the target distribution
        :param incidence_matrix: a `[S, R]` matrix representing the state
            transition model
        :param initial_conditions: a `[M, S]` matrix of `M` units and `S`
            states at the first timepoint
        :param target_transition_id: the id (column of `incidence_matrix`)
            of the transition whose events are moved
        :param num_units: number of units to update simultaneously
        :param delta_max: maximum time delta through which to move events
        :param count_max: absolute maximum number of events to move per step
        :param name: the name of the update step
        """
        self._name = name
        self._parameters = {
            "target_log_prob_fn": target_log_prob_fn,
            "incidence_matrix": incidence_matrix,
            "initial_conditions": initial_conditions,
            "target_transition_id": target_transition_id,
            "num_units": num_units,
            "delta_max": delta_max,
            "count_max": count_max,
            "name": name,
        }

    @property
    def target_log_prob_fn(self):
        return self._parameters["target_log_prob_fn"]

    @property
    def incidence_matrix(self):
        return self._parameters["incidence_matrix"]

    @property
    def initial_conditions(self):
        return self._parameters["initial_conditions"]

    @property
    def target_transition_id(self):
        return self._parameters["target_transition_id"]

    @property
    def num_units(self):
        return self._parameters["num_units"]

    @property
    def delta_max(self):
        return self._parameters["delta_max"]

    @property
    def count_max(self):
        return self._parameters["count_max"]

    @property
    def name(self):
        return self._parameters["name"]

    @property
    def parameters(self):
        """Return `dict` of ``__init__`` arguments and their values."""
        return self._parameters

    @property
    def is_calibrated(self):
        # Requires an outer accept/reject step (e.g. MetropolisHastings)
        return False

    def _proposal(self, events):
        # Build the joint proposal distribution for the current events
        return discrete_move_events_proposal(
            incidence_matrix=self.incidence_matrix,
            target_transition_id=self.target_transition_id,
            num_units=self.num_units,
            delta_max=self.delta_max,
            count_max=self.count_max,
            initial_conditions=self.initial_conditions,
            events=events,
            count_bounding_fn=events_state_count_bounding_fn(self.delta_max),
            name=self.name,
        )

    def one_step(self, current_events, previous_kernel_results, seed):  # noqa: ARG002
        """One update of event times.
        :param current_events: a [T, M, X] tensor containing number of events
                               per time t, metapopulation m,
                               and transition x.
        :param previous_kernel_results: an object of type
                                        UncalibratedRandomWalkResults.
        :param seed: a PRNG seed accepted by `sanitize_key`
        :returns: a tuple containing new_state and UncalibratedRandomWalkResults
        """
        seed = sanitize_key(seed)
        step_events = current_events
        if mcmc_util.is_list_like(current_events):
            step_events = current_events[0]
            warn(
                "Batched event times updates are not supported. Using \
                first event item only.",
                stacklevel=2,
            )
        # Draw the forward move and its proposal log density
        fwd_proposal_dist = self._proposal(step_events)
        fwd_proposed_move = fwd_proposal_dist.sample(seed=seed)
        fwd_prob = fwd_proposal_dist.log_prob(fwd_proposed_move)

        # Propagate the proposal into events
        proposed_events = _apply_move(
            event_tensor=step_events,
            event_id=self.target_transition_id,
            move=fwd_proposed_move,
        )

        next_target_log_prob = self.target_log_prob_fn(proposed_events)

        # Compute the reverse move
        rev_proposed_move = _reverse_move(fwd_proposed_move)

        # MH correction: log q(reverse | proposed) - log q(forward | current)
        rev_prob = self._proposal(proposed_events).log_prob(rev_proposed_move)
        log_acceptance_correction = rev_prob - fwd_prob

        # Preserve list-like structure of the input state
        if mcmc_util.is_list_like(current_events):
            proposed_events = [proposed_events]

        return (
            proposed_events,
            EventTimesKernelResults(
                log_acceptance_correction=log_acceptance_correction,
                target_log_prob=next_target_log_prob,
                fwd_proposed_move=fwd_proposed_move,
                rev_proposed_move=rev_proposed_move,
                seed=seed,
            ),
        )

    def bootstrap_results(self, init_state):
        """Initialise kernel results for `init_state`.

        Samples a throwaway proposal (fixed seed 0) purely to give the
        results structure concretely-shaped move fields.
        """
        if mcmc_util.is_list_like(init_state):
            init_state = init_state[0]
            warn(
                "Batched event times updates are not supported. Using \
                first event item only.",
                stacklevel=2,
            )

        init_target_log_prob = self.target_log_prob_fn(init_state)

        proposal_example = self._proposal(init_state).sample(
            seed=sanitize_key(0)
        )

        return EventTimesKernelResults(
            log_acceptance_correction=jnp.zeros_like(init_target_log_prob),
            target_log_prob=init_target_log_prob,
            fwd_proposed_move=proposal_example,
            rev_proposed_move=proposal_example,
            seed=sanitize_key(0),
        )