Coverage for gemlib/mcmc/discrete_time_state_transition_model/left_censored_events_impl.py: 96%
78 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-12-02 22:09 +0000
« prev ^ index » next coverage.py v7.10.3, created at 2025-12-02 22:09 +0000
1"""Metropolis Hastings implementation for left-censored events
2in a discrete-time metapopulation epidemic model
3"""
5from typing import NamedTuple
7import jax.numpy as jnp
8import numpy as np
9import tensorflow_probability.substrates.jax as tfp
10from tensorflow_probability.substrates.jax.mcmc.internal import (
11 util as mcmc_util,
12)
14from gemlib.prng_util import sanitize_key
16from .left_censored_events_proposal import (
17 left_censored_event_time_proposal,
18)
20tfd = tfp.distributions
22__all__ = ["UncalibratedLeftCensoredEventTimesUpdate"]
24LEFT_CENSORED_EVENT_TIMES_UPDATE_TUPLE_LEN = 2
27def _update_state(update, current_state, transition_idx, incidence_matrix):
28 current_initial_conditions, current_events = (
29 jnp.asarray(x) for x in current_state
30 )
32 # -1 is moving events forward in time
33 sign = jnp.array([-1, 1])[update["direction"]]
34 events_delta = update["num_events"] * sign
36 # Update initial conditions
37 indices = (
38 jnp.broadcast_to(
39 update["unit"], [current_initial_conditions.shape[-1]]
40 ),
41 np.arange(current_initial_conditions.shape[-1]),
42 )
43 new_initial_conditions = current_initial_conditions.at[indices].add(
44 events_delta * incidence_matrix[..., transition_idx]
45 )
47 # Update events
48 new_events = current_events.at[
49 update["timepoint"], update["unit"], transition_idx
50 ].subtract(events_delta)
52 return new_initial_conditions, new_events
def _reverse_update(update):
    """Build the update that undoes ``update``.

    The ``"unit"``, ``"timepoint"`` and ``"num_events"`` entries are copied
    unchanged; ``"direction"`` is flipped between 0 and 1.

    :param update: dict with keys ``"unit"``, ``"timepoint"``,
        ``"direction"`` and ``"num_events"``
    :returns: a new dict with the same four keys
    """
    flipped = (update["direction"] + 1) % 2
    keys = ("unit", "timepoint", "direction", "num_events")
    return {
        key: (flipped if key == "direction" else update[key]) for key in keys
    }
class LeftCensoredEventTimeResults(NamedTuple):
    """Kernel results recorded for one left-censored event-time update."""

    # Log proposal ratio: reverse minus forward proposal log-probability.
    log_acceptance_correction: float
    # Target log density evaluated at the state these results describe.
    target_log_prob: float
    # The sampled move: unit, timepoint, direction and number of events.
    unit: int
    timepoint: int
    direction: int
    num_events: int
    # Seed associated with this step.
    seed: tuple[int, int]
class UncalibratedLeftCensoredEventTimesUpdate(tfp.mcmc.TransitionKernel):
    """Uncalibrated transition kernel proposing left-censored event moves.

    Each step draws a move from ``left_censored_event_time_proposal``,
    applies it to the ``(initial_conditions, events)`` state, and records the
    target log density of the proposed state together with the log proposal
    ratio.  No accept/reject decision is made here (``is_calibrated`` is
    ``False``).
    """

    def __init__(
        self,
        target_log_prob_fn,
        transition_index,
        incidence_matrix,
        max_timepoint,
        max_events,
        name=None,
    ):
        """An uncalibrated random walk for initial conditions.

        :param target_log_prob_fn: the log density of the target distribution
        :param transition_index: the index of the transition to adjust
        :param incidence_matrix: the `[S,R]` incidence matrix
        :param max_timepoint: max timepoint up to which to move events
        :param max_events: max number of events per unit/timepoint to move
        :param name: optional name, used as prefix for the proposal names
        """
        self._name = name
        self._parameters = {
            "target_log_prob_fn": target_log_prob_fn,
            "transition_index": transition_index,
            "incidence_matrix": incidence_matrix,
            "max_timepoint": max_timepoint,
            "max_events": max_events,
            "name": name,
        }

    @property
    def target_log_prob_fn(self):
        return self._parameters["target_log_prob_fn"]

    @property
    def transition_index(self):
        return self._parameters["transition_index"]

    @property
    def incidence_matrix(self):
        return self._parameters["incidence_matrix"]

    @property
    def max_timepoint(self):
        return self._parameters["max_timepoint"]

    @property
    def max_events(self):
        return self._parameters["max_events"]

    @property
    def name(self):
        return self._parameters["name"]

    @property
    def parameters(self):
        return self._parameters

    @property
    def is_calibrated(self):
        return False

    @staticmethod
    def _validate_state(state):
        """Check that ``state`` is a list/tuple of the expected length.

        :param state: candidate `(initial_conditions, events)` state
        :raises ValueError: if `state` is not list-like or has the wrong
            length
        """
        # Both conditions must hold for a valid state.  Testing list-likeness
        # first also avoids calling ``len`` on objects that are not sequences.
        if not (
            mcmc_util.is_list_like(state)
            and len(state) == LEFT_CENSORED_EVENT_TIMES_UPDATE_TUPLE_LEN
        ):
            raise ValueError(
                "State for LeftCensoredEventTimesUpdate must be list/tuple "
                f"of length {LEFT_CENSORED_EVENT_TIMES_UPDATE_TUPLE_LEN}"
            )

    def one_step(self, current_state, previous_kernel_results, seed=None):  # noqa: ARG002
        """Propose a move of left-censored events and apply it to the state.

        :param current_state: a tuple of `(current_initial_conditions,
                              current_events)`
        :param previous_kernel_results: previous kernel results tuple (unused)
        :param seed: optional seed tuple `(int32, int32)`
        :returns: a tuple `(next_state, kernel_results)` where `next_state`
                  is `(next_initial_conditions, next_events)` and
                  `kernel_results` is a `LeftCensoredEventTimeResults`
        :raises ValueError: if `current_state` is not a list/tuple of length 2
        """
        self._validate_state(current_state)

        # Forward move: sample an update and score it under the proposal
        # built from the current state.
        proposal = left_censored_event_time_proposal(
            events=current_state[1],
            initial_state=current_state[0],
            transition=self.transition_index,
            incidence_matrix=self.incidence_matrix,
            num_units=1,
            max_timepoint=self.max_timepoint,
            max_events=self.max_events,
            name=f"{self.name}/fwd_proposal",
        )
        fwd_update = proposal.sample(seed=seed)
        fwd_proposal_log_prob = proposal.log_prob(
            fwd_update, name="fwd_proposal_log_prob"
        )

        next_state = _update_state(
            fwd_update,
            current_state,
            self.transition_index,
            self.incidence_matrix,
        )
        next_target_log_prob = self.target_log_prob_fn(*next_state)

        # Reverse move: score the opposite update under the proposal built
        # from the proposed state.
        rev_update = _reverse_update(fwd_update)
        rev_proposal = left_censored_event_time_proposal(
            events=next_state[1],
            initial_state=next_state[0],
            transition=self.transition_index,
            incidence_matrix=self.incidence_matrix,
            num_units=1,
            max_timepoint=self.max_timepoint,
            max_events=self.max_events,
            name=f"{self.name}/rev_proposal",
        )
        rev_proposal_log_prob = rev_proposal.log_prob(rev_update)
        log_acceptance_correction = jnp.sum(
            rev_proposal_log_prob - fwd_proposal_log_prob
        )

        results = (
            next_state,
            LeftCensoredEventTimeResults(
                log_acceptance_correction=log_acceptance_correction,
                target_log_prob=next_target_log_prob,
                unit=fwd_update["unit"],
                timepoint=fwd_update["timepoint"],
                direction=fwd_update["direction"],
                num_events=fwd_update["num_events"],
                seed=seed,
            ),
        )

        return results

    def bootstrap_results(self, init_state):
        """Create initial kernel results for `init_state`.

        :param init_state: a tuple of `(initial_conditions, events)`
        :returns: a `LeftCensoredEventTimeResults` holding the target log
                  density at `init_state`, a zero log acceptance correction
                  and placeholder values for the move fields
        :raises ValueError: if `init_state` is not a list/tuple of length 2
        """
        self._validate_state(init_state)

        initial_conditions = init_state[0]
        events = init_state[1]
        init_target_log_prob = self.target_log_prob_fn(
            initial_conditions, events
        )
        return LeftCensoredEventTimeResults(
            log_acceptance_correction=jnp.asarray(
                0.0, dtype=init_target_log_prob.dtype
            ),
            target_log_prob=init_target_log_prob,
            unit=jnp.zeros([1], dtype=np.int32),
            timepoint=jnp.zeros([1], dtype=np.int32),
            direction=jnp.array(0, dtype=np.int32),
            num_events=jnp.ones([1], dtype=np.int32),
            seed=sanitize_key(0),
        )