Coverage for gemlib/mcmc/discrete_time_state_transition_model/move_events_impl.py: 96%
136 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-12-02 22:09 +0000
« prev ^ index » next coverage.py v7.10.3, created at 2025-12-02 22:09 +0000
1"""Constrained Event Move kernel for DiscreteTimeStateTransitionModel"""
3from collections.abc import Callable
4from typing import NamedTuple, TypeVar
5from warnings import warn
7import jax.numpy as jnp
8import numpy as np
9import tensorflow_probability.substrates.jax as tfp
11from gemlib.distributions.discrete_markov import compute_state
12from gemlib.distributions.kcategorical import UniformKCategorical
13from gemlib.distributions.uniform_integer import UniformInteger
14from gemlib.mcmc.sampling_algorithm import (
15 Position,
16)
17from gemlib.prng_util import sanitize_key
18from gemlib.util import transition_coords
20tfd = tfp.distributions
21mcmc_util = tfp.mcmc.internal.util
23__all__ = ["UncalibratedEventTimesUpdate"]
25Tensor = TypeVar("Tensor")
28# Proposal mechanism
def events_state_count_bounding_fn(dmax):
    """Build a function bounding how many events can be moved in time.

    Args
    ----
    dmax: maximum possible extent in time.  It fixes the number of
          neighbouring timepoints inspected for every unit.

    Returns
    -------
    A callable `fn(events, src_state, dest_state, timepoints, delta)` with

    events: a [T, M] tensor of transition events
    src_state: a [T, M] tensor of corresponding source state
    dest_state: a [T, M] tensor of corresponding destination state
    timepoints: a [M] tensor of timepoints
    delta: a [M] tensor of time deltas where `abs(delta) <= dmax`

    which returns a [M] tensor of the maximum number of transition events
    that can be moved from `time` to `time+delta` without violating
    `src_state >= 0` and `dest_state >= 0`.
    """

    def fn(events, src_state, dest_state, timepoints, delta):
        num_times = events.shape[-2]
        num_units = events.shape[-1]

        moving_back = delta[:, np.newaxis] < 0  # [M, 1]
        origin = timepoints[:, np.newaxis]  # [M, 1]
        offsets = jnp.arange(dmax, dtype=np.int32)[np.newaxis, :]  # [1, dmax]

        # Backward moves look at the origin and the steps before it;
        # forward moves look at the steps strictly after the origin.
        window = jnp.where(
            moving_back,
            origin - offsets,
            origin + offsets + 1,
        )  # [M, dmax]

        # Keep every time index inside the time extent of `events`
        window = jnp.clip(window, min=0, max=num_times - 1)

        # Pair each windowed time index with the column of its own unit
        unit_columns = jnp.broadcast_to(
            jnp.arange(num_units)[:, np.newaxis],
            window.shape,
        )
        gather_index = (window, unit_columns)

        # Backward moves are limited by the source state, forward moves
        # by the destination state; the tightest value in the window wins.
        limiting_state = jnp.where(
            moving_back,
            src_state[gather_index],
            dest_state[gather_index],
        )
        tightest_state = jnp.min(limiting_state, axis=-1)

        # Never move more events than exist at the chosen timepoint
        available = events[timepoints, np.arange(num_units)]
        return jnp.minimum(tightest_state, available)

    return fn
88def _is_within(x, low, high):
89 """Tests `low <= x < high`"""
90 return jnp.logical_and(jnp.greater_equal(x, low), jnp.less(x, high))
def _timepoint_selector(events, name):
    """Build a distribution choosing one timepoint per unit.

    Timepoints with `events > 0` are equally likely for a unit and all
    other timepoints get probability zero.  The second-to-last axis of
    `events` is treated as time: probabilities are normalised along it and
    the result is transposed so that time becomes the categorical axis.
    """
    occupied = (events > 0.0).astype(events.dtype)
    occupied_per_unit = occupied.sum(axis=-2, keepdims=True)
    timepoint_probs = (occupied / occupied_per_unit).T

    per_unit_choice = tfd.Categorical(probs=timepoint_probs)
    return tfd.Independent(
        per_unit_choice,
        reinterpreted_batch_ndims=1,
        name=name,
    )
def _delta_selector(delta_max, timepoints, events, name):
    """Build a distribution choosing a non-zero time shift per timepoint.

    Candidate shifts are `-delta_max, ..., -1, 1, ..., delta_max`.  A shift
    gets probability zero for a given timepoint when `timepoint + shift`
    falls outside `[0, events.shape[-2])`; the remaining shifts are equally
    likely.
    """
    backward = np.arange(-delta_max, 0, dtype=np.int32)
    forward = np.arange(1, delta_max + 1, dtype=np.int32)
    candidate_deltas = jnp.concatenate([backward, forward], axis=0)

    # Time index each candidate shift would land on, one row per timepoint
    landing_times = (
        candidate_deltas[..., np.newaxis, :] + timepoints[..., np.newaxis]
    )
    weights = _is_within(
        landing_times,
        low=0,
        high=events.shape[-2],
    ).astype(events.dtype)
    probs = weights / weights.sum(axis=-1, keepdims=True)

    per_timepoint_choice = tfd.FiniteDiscrete(
        outcomes=candidate_deltas, probs=probs
    )
    return tfd.Independent(
        per_timepoint_choice,
        reinterpreted_batch_ndims=1,
        name=name,
    )
127def _make_event_count_selector(src_dest_ids, count_bounding_fn, count_max):
128 def fn(events, state, timepoints, delta):
129 bound = count_bounding_fn(
130 events=events,
131 src_state=state[:, :, src_dest_ids[0]],
132 dest_state=state[:, :, src_dest_ids[1]],
133 timepoints=timepoints,
134 delta=delta,
135 )
137 bound = jnp.minimum(bound, count_max) # Upper bound on number of events
139 # Select number of events to move
140 return tfd.Independent(
141 UniformInteger(
142 low=np.int32(1),
143 high=bound.astype(np.int32) + 1,
144 float_dtype=events.dtype,
145 ),
146 reinterpreted_batch_ndims=1,
147 name="event_count",
148 )
150 return fn
def discrete_move_events_proposal(
    incidence_matrix: Tensor,
    target_transition_id: int,
    num_units: int,
    delta_max: int,
    count_max: int,
    initial_conditions: Tensor,
    events: Tensor,
    count_bounding_fn: Callable[
        [Tensor, Tensor, Tensor, Tensor, Tensor], Tensor
    ],
    name=None,
):
    """Build a proposal mechanism for moving discrete transition events

    The proposal draws, in order: the units to update ("unit"), one
    timepoint per selected unit ("timepoint"), a time shift per selected
    unit ("delta"), and the number of events to move ("event_count").

    Args
    ----
    incidence_matrix: a `[S,R]` matrix representing the state transition
                      model
    target_transition_id: the id (column) of the transition to update
    num_units: number of units to update simultaneously
    delta_max: maximum time delta through which to move transitions
    count_max: absolute maximum number of transition events to move
    initial_conditions: a `[M,S]` matrix representing `M` units and `S`
                        states
    events: a `[T, M, R]` tensor giving the number of events that occur
            for each time $t=0,...,T-1$, unit $m=0,...,M-1$, and
            transition $r=0,...,R-1$
    count_bounding_fn: a callable taking `events`, `src_state`,
                       `dest_state`, `timepoints`, and `delta` and
                       returning tensor of length `num_units` of the
                       maximum possible events that can be moved.
    name: optional name given to the returned joint distribution

    Returns
    -------
    An instance of `tfd.JointDistributionCoroutine` which samples a proposal
    """
    events_ = jnp.asarray(events)
    initial_conditions_ = jnp.asarray(initial_conditions)

    def proposal():
        target_events = events_[..., target_transition_id]
        # Use the converted array here too, so that the state and the
        # target events are always derived from the same object.
        state = compute_state(initial_conditions_, events_, incidence_matrix)

        src_dest_ids = transition_coords(incidence_matrix)[target_transition_id]

        # Select units to update if they have events
        nonzero_units = jnp.any(target_events > 0, axis=-2)
        units = yield tfd.JointDistribution.Root(
            UniformKCategorical(
                k=num_units,
                mask=nonzero_units,
                float_dtype=events_.dtype,
                name="unit",
            )
        )

        # Select out the units of interest
        events_subset = target_events[..., units]
        state_subset = state[..., units, :]

        # Select timepoint to update
        timepoints = yield _timepoint_selector(events_subset, "timepoint")

        # Select distance to move; deltas that would leave the valid range
        # [0, events_subset.shape[-2]) receive zero probability
        delta = yield _delta_selector(
            delta_max, timepoints, events_subset, "delta"
        )

        # Compute number of events to move
        count_selector = _make_event_count_selector(
            src_dest_ids, count_bounding_fn, count_max
        )
        yield count_selector(events_subset, state_subset, timepoints, delta)

    return tfd.JointDistributionCoroutine(proposal, name=name)
class EventTimesKernelResults(NamedTuple):
    """Results carried alongside the state by the event times kernel."""

    # Reverse-proposal log probability minus forward-proposal log probability
    log_acceptance_correction: float
    # Target log density evaluated at the events this result describes
    target_log_prob: float
    # The move that was sampled from the proposal
    fwd_proposed_move: NamedTuple
    # The move that would undo `fwd_proposed_move`
    rev_proposed_move: NamedTuple
    # PRNG key used to draw the proposal
    seed: list[int]
237def _apply_move(event_tensor, event_id, move):
238 """Apply `move` to `event_tensor`
240 Args
241 ----
242 event_tensor: shape [T, M, R]
243 event_id: the event id to move
244 move: a data structure with fields ["unit", "timepoint", "delta",
245 "counts"]
247 Returns
248 -------
249 a copy of `event_tensor` with `move` applied
250 """
251 event_tensor = jnp.array(event_tensor)
253 # Subtract `count` from the `event_tensor[timepoint, :, event_id]`
254 new_state = event_tensor.at[move.timepoint, move.unit, event_id].subtract(
255 move.event_count
256 )
258 # Add `count` to [move.timpoint+delta, :, event_id]
259 new_state = new_state.at[
260 move.timepoint + move.delta, move.unit, event_id
261 ].add(move.event_count)
263 return new_state
266def _reverse_move(move):
267 return move._replace(
268 timepoint=move.timepoint + move.delta, delta=-move.delta
269 )
class UncalibratedEventTimesUpdate(tfp.mcmc.TransitionKernel):
    """Uncalibrated kernel proposing to move transition events in time.

    `one_step` always returns the proposed events together with the log
    acceptance correction; it performs no accept/reject step itself, which
    is why `is_calibrated` is `False`.
    """

    # Built with implicit string concatenation so that the emitted message
    # does not pick up source indentation, as a backslash continuation
    # inside a string literal would.
    _BATCH_WARNING = (
        "Batched event times updates are not supported. "
        "Using first event item only."
    )

    def __init__(
        self,
        target_log_prob_fn: Callable[[Position], float],
        incidence_matrix: Tensor,
        initial_conditions: Tensor,
        target_transition_id: int,
        num_units: int,
        delta_max: int,
        count_max: int,
        name: str | None = None,
    ):
        """An uncalibrated random walk for event times.

        :param target_log_prob_fn: the log density of the target
                                   distribution, called with an events
                                   tensor
        :param incidence_matrix: a `[S,R]` matrix representing the state
                                 transition model
        :param initial_conditions: a `[M,S]` matrix representing `M` units
                                   and `S` states
        :param target_transition_id: the id (column of `incidence_matrix`)
                                     of the transition whose events are
                                     moved
        :param num_units: number of units to update simultaneously
        :param delta_max: maximum time delta through which to move events
        :param count_max: absolute maximum number of events to move
        :param name: the name of the update step
        """
        self._name = name
        self._parameters = {
            "target_log_prob_fn": target_log_prob_fn,
            "incidence_matrix": incidence_matrix,
            "initial_conditions": initial_conditions,
            "target_transition_id": target_transition_id,
            "num_units": num_units,
            "delta_max": delta_max,
            "count_max": count_max,
            "name": name,
        }

    @property
    def target_log_prob_fn(self):
        return self._parameters["target_log_prob_fn"]

    @property
    def incidence_matrix(self):
        return self._parameters["incidence_matrix"]

    @property
    def initial_conditions(self):
        return self._parameters["initial_conditions"]

    @property
    def target_transition_id(self):
        return self._parameters["target_transition_id"]

    @property
    def num_units(self):
        return self._parameters["num_units"]

    @property
    def delta_max(self):
        return self._parameters["delta_max"]

    @property
    def count_max(self):
        return self._parameters["count_max"]

    @property
    def name(self):
        return self._parameters["name"]

    @property
    def parameters(self):
        """Return `dict` of ``__init__`` arguments and their values."""
        return self._parameters

    @property
    def is_calibrated(self):
        return False

    def _proposal(self, events):
        """Build the move proposal distribution conditioned on `events`."""
        return discrete_move_events_proposal(
            incidence_matrix=self.incidence_matrix,
            target_transition_id=self.target_transition_id,
            num_units=self.num_units,
            delta_max=self.delta_max,
            count_max=self.count_max,
            initial_conditions=self.initial_conditions,
            events=events,
            count_bounding_fn=events_state_count_bounding_fn(self.delta_max),
            name=self.name,
        )

    def one_step(self, current_events, previous_kernel_results, seed):  # noqa: ARG002
        """One update of event times.

        :param current_events: a [T, M, X] tensor containing number of
                               events per time t, metapopulation m,
                               and transition x.  If a list-like is given,
                               only its first item is used and a warning
                               is emitted.
        :param previous_kernel_results: an object of type
                                        EventTimesKernelResults (unused).
        :param seed: PRNG seed; it is passed through `sanitize_key` before
                     use.
        :returns: a tuple containing the proposed events and an
                  EventTimesKernelResults
        """
        seed = sanitize_key(seed)
        is_batched = mcmc_util.is_list_like(current_events)
        step_events = current_events
        if is_batched:
            step_events = current_events[0]
            warn(self._BATCH_WARNING, stacklevel=2)

        fwd_proposal_dist = self._proposal(step_events)
        fwd_proposed_move = fwd_proposal_dist.sample(seed=seed)
        fwd_prob = fwd_proposal_dist.log_prob(fwd_proposed_move)

        # Propagate the proposal into events
        proposed_events = _apply_move(
            event_tensor=step_events,
            event_id=self.target_transition_id,
            move=fwd_proposed_move,
        )

        next_target_log_prob = self.target_log_prob_fn(proposed_events)

        # Score the reverse move under the proposal built from the
        # proposed events
        rev_proposed_move = _reverse_move(fwd_proposed_move)
        rev_prob = self._proposal(proposed_events).log_prob(rev_proposed_move)
        log_acceptance_correction = rev_prob - fwd_prob

        # Hand back the same container type that was passed in
        if is_batched:
            proposed_events = [proposed_events]

        return (
            proposed_events,
            EventTimesKernelResults(
                log_acceptance_correction=log_acceptance_correction,
                target_log_prob=next_target_log_prob,
                fwd_proposed_move=fwd_proposed_move,
                rev_proposed_move=rev_proposed_move,
                seed=seed,
            ),
        )

    def bootstrap_results(self, init_state):
        """Build initial kernel results for `init_state`.

        :param init_state: an events tensor.  If a list-like is given, only
                           its first item is used and a warning is emitted.
        :returns: an EventTimesKernelResults with a zero log acceptance
                  correction, the target log density at `init_state`, and
                  an example proposal drawn with the key
                  `sanitize_key(0)` in both move fields.
        """
        if mcmc_util.is_list_like(init_state):
            init_state = init_state[0]
            warn(self._BATCH_WARNING, stacklevel=2)

        init_target_log_prob = self.target_log_prob_fn(init_state)

        # A sampled move is stored only to give the results their structure
        bootstrap_seed = sanitize_key(0)
        proposal_example = self._proposal(init_state).sample(
            seed=bootstrap_seed
        )

        return EventTimesKernelResults(
            log_acceptance_correction=jnp.zeros_like(init_target_log_prob),
            target_log_prob=init_target_log_prob,
            fwd_proposed_move=proposal_example,
            rev_proposed_move=proposal_example,
            seed=bootstrap_seed,
        )