Coverage for gemlib/mcmc/discrete_time_state_transition_model/right_censored_events_mh.py: 100%
40 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-12-02 22:09 +0000
« prev ^ index » next coverage.py v7.10.3, created at 2025-12-02 22:09 +0000
1"""Right-censored events MCMC kernel for DiscreteTimeStateTransitionModel"""
3from typing import NamedTuple
5import jax.numpy as jnp
6import numpy as np
7import tensorflow_probability.substrates.jax as tfp
9from gemlib.mcmc.discrete_time_state_transition_model.right_censored_events_impl import ( # noqa: E501
10 UncalibratedOccultUpdate,
11)
12from gemlib.mcmc.sampling_algorithm import ChainState, SamplingAlgorithm
# ``numpy.typing`` is a submodule. Import it explicitly rather than relying on
# ``np.typing`` attribute access, which older NumPy releases only provide if
# something else has already imported the submodule.
import numpy.typing as npt

# Alias used for array-valued annotations in this module.
Tensor = npt.NDArray
class RightCensoredEventsState(NamedTuple):
    """Kernel state for the right-censored events sampling algorithm.

    Holds the configuration recorded by ``init``; ``step`` returns it
    unchanged alongside the updated chain state.
    """

    # maximum number of transition events added or deleted in one update
    count_max: int
    # time range over which censored transition events are updated
    t_range: tuple[int, int]
class RightCensoredEventsInfo(NamedTuple):
    """Diagnostics emitted by one step of the right-censored events kernel.

    Values are copied from the Metropolis-Hastings kernel results and from
    the proposal results of the inner occult-update kernel.
    """

    # event count reported by the proposal (presumably the number of events
    # added or deleted -- TODO confirm against UncalibratedOccultUpdate)
    event_count: int
    # whether the Metropolis-Hastings step accepted the proposal
    is_accepted: bool
    # add/delete indicator reported by the proposal (presumably nonzero for
    # an "add" move -- TODO confirm)
    is_add: int
    # log acceptance correction reported by the proposal
    log_acceptance_correction: float
    # the proposed chain position
    proposed_state: Tensor
    # target log-density evaluated at the proposed position
    proposed_log_density: float
    # seed recorded in the Metropolis-Hastings kernel results
    seed: tuple[int, int]
    # time index reported by the proposal (presumably the updated timepoint)
    timepoint: int
    # unit index reported by the proposal (presumably the updated unit)
    unit: int
def right_censored_events_mh(
    incidence_matrix: Tensor,
    transition_index: int,
    t_range: tuple[int, int],
    count_max: int = 1,
    name: str | None = None,
) -> SamplingAlgorithm:
    r"""Update right-censored events for DiscreteTimeStateTransitionModel

    In a partially-complete epidemic, we may wish to explore the space of
    events that _might_ have occurred, but are not apparent because
    of some detection event that is yet to occur. This MCMC kernel performs
    a single-site "add/delete" move, as described (in continuous time) in
    O'Neill and Roberts (1999).

    Args:
      incidence_matrix: the state-transition graph incidence matrix
      transition_index: the index of the transition in :obj:`incidence_matrix`
        to update
      t_range: the time-range for which to update censored transition events.
        Typically this would be :obj:`[s, num_steps)` where :obj:`s < num_steps`
        for :obj:`num_steps` the total number of timesteps in the model.
      count_max: the max number of transitions to add or delete in any one
        update
      name: name of the kernel

    Returns:
      An instance of :obj:`SamplingAlgorithm` with members
      :obj:`init(tlp_fn: LogProbFnType, position: Position,
      initial_conditions: Tensor)` and :obj:`step(tlp_fn: LogProbFnType,
      cks: ChainAndKernelState, seed: SeedType, initial_conditions: Tensor)`.
      :obj:`initial_conditions` is an extra keyword
      argument required to be passed to the :obj:`init` and :obj:`step`
      functions. :obj:`step` returns a pair
      ``((chain_state, kernel_state), info)`` where ``info`` is a
      :obj:`RightCensoredEventsInfo`.

    References:
      1. Philip D O'Neill and Gareth O Roberts (1999) Bayesian inference for
         partially observed stochastic epidemics. _Journal of the Royal
         Statistical Society: Series A (Statistics in Society)_, **162**\
         :121--129.

      2. Jewell _et al._ (2023) Bayesian inference for high-dimensional
         discrete-time epidemic models: spatial dynamics of the UK COVID-19
         outbreak. Pre-print arXiv:2306.07987
    """

    def _build_kernel(target_log_prob_fn, initial_conditions):
        # Wrap the occult add/delete proposal in a Metropolis-Hastings
        # accept/reject step. The kernel is rebuilt on every call because
        # ``initial_conditions`` is supplied per call to ``init``/``step``.
        return tfp.mcmc.MetropolisHastings(
            inner_kernel=UncalibratedOccultUpdate(
                target_log_prob_fn,
                incidence_matrix=incidence_matrix,
                initial_conditions=initial_conditions,
                target_transition_id=transition_index,
                count_max=count_max,
                t_range=t_range,
                name=name,
            ),
        )

    def init_fn(target_log_prob_fn, position, initial_conditions):
        """Build the initial chain state and kernel state.

        Args:
          target_log_prob_fn: the target log-density function
          position: the initial chain position
          initial_conditions: initial conditions forwarded to the proposal

        Returns:
          A tuple ``(chain_state, kernel_state)``.
        """
        position = jnp.asarray(position)
        initial_conditions = jnp.asarray(initial_conditions)

        kernel = _build_kernel(target_log_prob_fn, initial_conditions)
        results = kernel.bootstrap_results(position)
        chain_state = ChainState(
            position=position,
            log_density=results.accepted_results.target_log_prob,
        )
        kernel_state = RightCensoredEventsState(
            count_max=count_max, t_range=t_range
        )

        return chain_state, kernel_state

    def step_fn(
        target_log_prob_fn, target_and_kernel_state, seed, initial_conditions
    ):
        """Perform one Metropolis-Hastings add/delete update.

        Args:
          target_log_prob_fn: the target log-density function
          target_and_kernel_state: a ``(chain_state, kernel_state)`` pair
          seed: the random seed for this step
          initial_conditions: initial conditions forwarded to the proposal

        Returns:
          A tuple ``((chain_state, kernel_state), info)`` where ``info`` is a
          :obj:`RightCensoredEventsInfo`.
        """
        initial_conditions = jnp.asarray(initial_conditions)

        target_chain_state, kernel_state = target_and_kernel_state

        kernel = _build_kernel(target_log_prob_fn, initial_conditions)

        # Kernel results are not carried between steps, so they are
        # re-bootstrapped from the current position before stepping.
        new_target_position, results = kernel.one_step(
            target_chain_state.position,
            kernel.bootstrap_results(target_chain_state.position),
            seed=seed,
        )
        # The kernel state is static configuration and is passed through.
        new_chain_and_kernel_state = (
            ChainState(
                position=new_target_position,
                log_density=results.accepted_results.target_log_prob,
            ),
            kernel_state,
        )

        pr = results.proposed_results

        return new_chain_and_kernel_state, RightCensoredEventsInfo(
            is_accepted=results.is_accepted,
            unit=pr.unit,
            timepoint=pr.timepoint,
            is_add=pr.is_add,
            event_count=pr.event_count,
            log_acceptance_correction=pr.log_acceptance_correction,
            proposed_state=results.proposed_state,
            proposed_log_density=pr.target_log_prob,
            seed=results.seed,
        )

    return SamplingAlgorithm(init_fn, step_fn)