Coverage for gemlib/mcmc/discrete_time_state_transition_model/move_events.py: 100%
33 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-12-02 22:09 +0000
« prev ^ index » next coverage.py v7.10.3, created at 2025-12-02 22:09 +0000
1"""Move partially-censored events in DiscreteTimeStateTransitionModel"""
3from typing import NamedTuple
5import jax
6import tensorflow_probability.substrates.jax as tfp
8from gemlib.mcmc.discrete_time_state_transition_model.move_events_impl import (
9 UncalibratedEventTimesUpdate,
10)
11from gemlib.mcmc.sampling_algorithm import ChainState, SamplingAlgorithm
# Alias for JAX arrays, used in the type annotations throughout this module.
Tensor = jax.Array
class MoveEventsState(NamedTuple):
    """Kernel state carried between `move_events` steps.

    Records the tuning configuration the kernel was built with. `move_events`
    creates it in ``init`` and hands it back from ``step`` unchanged.
    """

    # Number of epidemiological units updated per step.
    num_units: int
    # Maximum time interval over which transition times may be moved.
    delta_max: int
    # Maximum number of transitions moved at once.
    count_max: int
class MoveEventsInfo(NamedTuple):
    """Per-step diagnostics returned by the `move_events` ``step`` function.

    The fields are filled from the Metropolis-Hastings kernel results (and
    the ``step`` arguments), so they hold array values rather than Python
    scalars. They are therefore annotated as `Tensor` throughout.
    """

    # Whether the proposed move of events was accepted.
    is_accepted: Tensor
    # The initial conditions that were passed to ``step``.
    initial_conditions: Tensor
    # The proposed event timeseries.
    proposed_events: Tensor
    # Target log density evaluated at the proposed events.
    proposed_log_density: Tensor
    # Log Metropolis-Hastings acceptance ratio of the proposal.
    log_accept_ratio: Tensor
    # The random seed that was passed to ``step``.
    seed: Tensor
def move_events(
    incidence_matrix: Tensor,
    transition_index: int,
    num_units: int,
    delta_max: int,
    count_max: int,
    name: str | None = None,
):
    r"""Move partially-censored transition events.

    This kernel provides an MCMC algorithm that moves partially-censored
    transition events in a DiscreteTimeStateTransitionModel event timeseries.
    It caters for the situation in which a transition is known to have occurred
    (e.g. as a result of a subsequent case detection), but the time at
    which it occurred is unknown.

    Args:
        incidence_matrix: the state-transition graph incidence matrix
        transition_index: the index of the transition in `incidence_matrix`
            to update.
        num_units: the number of epidemiological units to update at once
        delta_max: the maximum time interval over which to move transition
            times.
        count_max: the maximum number of transitions to move at once.
        name: the name of the kernel. If `None` (or empty), the name
            ``"move_events"`` is used.

    Returns:
        An instance of :obj:`SamplingAlgorithm` with members
        :obj:`init(target_log_prob_fn: LogProbFnType, position: Position,
        initial_conditions: Tensor)` and
        :obj:`step(target_log_prob_fn: LogProbFnType,
        target_and_kernel_state: ChainAndKernelState, seed: SeedType,
        initial_conditions: Tensor)`.
        :obj:`initial_conditions` is an extra keyword
        argument required to be passed to the :obj:`init` and :obj:`step`
        functions. :obj:`init` returns a ``(ChainState, MoveEventsState)``
        pair; :obj:`step` returns the updated ``(ChainState, MoveEventsState)``
        pair together with a :obj:`MoveEventsInfo` of per-step diagnostics.

    References:
        Philip D O'Neill and Gareth O Roberts (1999) Bayesian inference for
        partially observed stochastic epidemics. *Journal of the Royal
        Statistical Society: Series A (Statistics in Society)*,
        **162**\ :121--129.

        Jewell *et al.* (2023) Bayesian inference for high-dimensional discrete-
        time epidemic models: spatial dynamics of the UK COVID-19 outbreak.
        Pre-print arXiv:2306.07987
    """

    def _build_kernel(target_log_prob_fn, initial_conditions):
        # Wrap the uncalibrated event-time proposal in a Metropolis-Hastings
        # correction so that the resulting kernel is calibrated.
        return tfp.mcmc.MetropolisHastings(
            inner_kernel=UncalibratedEventTimesUpdate(
                target_log_prob_fn,
                incidence_matrix=incidence_matrix,
                initial_conditions=initial_conditions,
                target_transition_id=transition_index,
                num_units=num_units,
                delta_max=delta_max,
                count_max=count_max,
                name=name or "move_events",
            )
        )

    def init_fn(target_log_prob_fn, position, initial_conditions):
        kernel = _build_kernel(target_log_prob_fn, initial_conditions)
        results = kernel.bootstrap_results(position)
        chain_state = ChainState(
            position=position,
            log_density=results.accepted_results.target_log_prob,
        )
        # The kernel state only records the tuning configuration; `step_fn`
        # passes it through unchanged.
        kernel_state = MoveEventsState(num_units, delta_max, count_max)

        return chain_state, kernel_state

    def step_fn(
        target_log_prob_fn, target_and_kernel_state, seed, initial_conditions
    ):
        target_chain_state, kernel_state = target_and_kernel_state

        # The kernel is rebuilt from the `target_log_prob_fn` and
        # `initial_conditions` supplied on this call. Its results are
        # therefore bootstrapped afresh at the current position rather than
        # carried over from a previous step.
        kernel = _build_kernel(target_log_prob_fn, initial_conditions)

        new_position, results = kernel.one_step(
            target_chain_state.position,
            kernel.bootstrap_results(target_chain_state.position),
            seed=seed,
        )

        new_chain_and_kernel_state = (
            ChainState(
                position=new_position,
                log_density=results.accepted_results.target_log_prob,
            ),
            kernel_state,
        )

        return new_chain_and_kernel_state, MoveEventsInfo(
            is_accepted=results.is_accepted,
            initial_conditions=initial_conditions,
            proposed_events=results.proposed_state,
            proposed_log_density=results.proposed_results.target_log_prob,
            log_accept_ratio=results.log_accept_ratio,
            seed=seed,
        )

    return SamplingAlgorithm(init_fn, step_fn)