Coverage for gemlib/mcmc/discrete_time_state_transition_model/right_censored_events_impl.py: 96%
107 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-12-02 22:09 +0000
« prev ^ index » next coverage.py v7.10.3, created at 2025-12-02 22:09 +0000
1"""Sampler for discrete-space occult events"""
3from collections.abc import Callable
4from typing import NamedTuple
6import jax
7import jax.numpy as jnp
8import numpy as np
9import tensorflow_probability.substrates.jax as tfp
10from tensorflow_probability.substrates.jax.mcmc.internal import (
11 util as mcmc_util,
12)
14from gemlib.distributions.discrete_markov import (
15 compute_state,
16)
17from gemlib.mcmc.discrete_time_state_transition_model.right_censored_events_proposal import ( # noqa:E501
18 add_occult_proposal,
19 del_occult_proposal,
20)
21from gemlib.mcmc.sampling_algorithm import Position
22from gemlib.prng_util import sanitize_key
23from gemlib.util import transition_coords
# Short alias for the TFP (JAX substrate) distributions namespace.
tfd = tfp.distributions
# Type alias used in signature annotations below.
Tensor = np.typing.NDArray

__all__ = ["UncalibratedOccultUpdate"]

# Probability of choosing a *delete* move in `UncalibratedOccultUpdate.one_step`
# (compared against a Uniform(0, 1) draw).  The delete move is only eligible
# when the time window contains non-zero events; otherwise an add is proposed.
PROB_DIRECTION: float = 0.5
class OccultKernelResults(NamedTuple):
    """Kernel results produced by the occult add/delete update.

    All fields hold JAX arrays: they are produced inside ``jax.lax.cond`` /
    ``jnp`` operations so that the structure can be carried through JAX
    control flow (e.g. an MCMC scan).
    """

    # log q(reverse move) - log q(forward move) for the proposed update.
    log_acceptance_correction: jax.Array
    # Target log density evaluated at the proposed events tensor.
    target_log_prob: jax.Array
    # Index along the unit (second) axis of the events tensor that was moved.
    unit: jax.Array
    # Time index of the move, relative to the start of the ``t_range`` window.
    timepoint: jax.Array
    # True for an add move, False for a delete move.
    is_add: jax.Array
    # Number of events the proposal added (or removed, for a delete move).
    event_count: jax.Array
    # PRNG key that was passed to ``one_step``.
    seed: jax.Array
44def _is_row_nonzero(m):
45 return jnp.sum(m, axis=-1) > 0.0
48def _maybe_expand_dims(x):
49 """If x is a scalar, give it at least 1 dimension"""
50 return jnp.atleast_1d(x)
53def _add_events(events, unit, timepoint, target_transition_id, event_count):
54 """Adds `x_star` events to metapopulation `m`,
55 time `t`, transition `x` in `events`.
56 """
57 events = jnp.asarray(events)
58 return events.at[timepoint, unit, target_transition_id].add(event_count)
class UncalibratedOccultUpdate(tfp.mcmc.TransitionKernel):
    """Uncalibrated add/delete update for occult events.

    Each step proposes one of two moves on the events of a single transition
    (``target_transition_id``) inside the time window ``t_range``:

    * **delete** -- chosen when a Uniform(0, 1) draw is below
      ``PROB_DIRECTION`` *and* the window contains non-zero events; events are
      removed as proposed by ``del_occult_proposal``;
    * **add** -- chosen otherwise; events are added as proposed by
      ``add_occult_proposal``.

    The log ratio of reverse to forward proposal densities is returned as
    ``log_acceptance_correction``.  ``is_calibrated`` is ``False``, so the
    kernel is meant to be wrapped by a Metropolis-Hastings kernel (e.g.
    ``tfp.mcmc.MetropolisHastings``) that performs the accept/reject step.
    """

    def __init__(
        self,
        target_log_prob_fn: Callable[[Position], float],
        incidence_matrix: Tensor,
        initial_conditions: Tensor,
        target_transition_id: int,
        count_max: int,
        t_range,
        name=None,
    ):
        """An uncalibrated add/delete update for occult event times.

        :param target_log_prob_fn: the log density of the target
            distribution; called with the full events tensor.
        :param incidence_matrix: the state-transition incidence matrix, used
            (with ``initial_conditions``) to reconstruct the state from
            events and to find the source/destination states of the target
            transition.
        :param initial_conditions: the initial state.  Its dtype is also the
            dtype that ``bootstrap_results`` casts the initial events to.
        :param target_transition_id: the position in the last dimension of
            the events tensor that we wish to move.
        :param count_max: passed as ``count_max`` to the add/delete proposals
            (presumably the maximum number of events moved per step --
            confirm against the proposal module).
        :param t_range: arguments to ``slice`` (e.g. ``(start, stop)``)
            selecting the interval of the time (first) axis of the events
            tensor in which occults are updated.
        :param name: the name of the update step; defaults to
            ``"uncalibrated_occult_update"``.
        """
        # Snapshot the constructor arguments for the ``parameters`` property.
        # ``self`` is dropped so that ``TransitionKernel.copy()``, which
        # rebuilds the kernel via ``type(self)(**self.parameters)``, does not
        # receive ``self`` twice.
        parameters = dict(locals())
        parameters.pop("self", None)
        self._parameters = parameters
        self._name = name or "uncalibrated_occult_update"
        self._dtype = jnp.asarray(initial_conditions).dtype

    @property
    def target_log_prob_fn(self):
        return self._parameters["target_log_prob_fn"]

    @property
    def incidence_matrix(self):
        return self._parameters["incidence_matrix"]

    @property
    def target_transition_id(self):
        return self._parameters["target_transition_id"]

    @property
    def initial_conditions(self):
        return self._parameters["initial_conditions"]

    @property
    def count_max(self):
        return self._parameters["count_max"]

    @property
    def t_range(self):
        return self._parameters["t_range"]

    @property
    def name(self):
        """The kernel name, falling back to the default when none was given."""
        return self._name

    @property
    def parameters(self):
        """Return `dict` of ``__init__`` arguments and their values."""
        return self._parameters

    @property
    def is_calibrated(self):
        return False

    def one_step(self, current_events, previous_kernel_results, seed):  # noqa: ARG002
        """Propose one add or delete move for the target transition.

        :param current_events: a [T, M, R] tensor containing number of events
            per time t, metapopulation m, and transition r.
        :param previous_kernel_results: unused; present to satisfy the
            ``TransitionKernel`` interface.
        :param seed: a JAX PRNG key; it is split into a key for the proposal
            and a key for choosing the move direction.
        :returns: a tuple ``(next_events, OccultKernelResults)``.  The
            ``timepoint`` field of the results is relative to the start of
            ``t_range``, because the move is applied to the windowed events.
        """
        with jax.named_scope("occult_rw/onestep"):
            current_events = jnp.asarray(current_events)
            initial_conditions = jnp.asarray(self.initial_conditions)

            proposal_seed, add_del_seed = jax.random.split(seed)
            t_range_slice = slice(*self.t_range)

            state = compute_state(
                initial_conditions,
                current_events,
                self.incidence_matrix,
            )
            # (source, destination) state indices of the target transition.
            src_dest_ids = transition_coords(self.incidence_matrix)[
                self.target_transition_id, :
            ]

            # Pull out the section of events and state that are within
            # the requested time interval - we focus on this, and insert
            # updated values back into the full events at the end.
            range_events = current_events[t_range_slice]
            state_slice = state[t_range_slice]

            def add_occult_fn():
                # Forward: add events, limited by the source-state occupancy.
                # Reverse: delete them again, from the destination state.
                with jax.named_scope("true_fn"):
                    proposal = add_occult_proposal(
                        count_max=self.count_max,
                        events=range_events[..., self.target_transition_id],
                        src_state=state_slice[..., src_dest_ids[0]],
                    )
                    update = proposal.sample(seed=proposal_seed)

                    next_events = _add_events(
                        events=range_events,
                        unit=update.unit,
                        timepoint=update.timepoint,
                        target_transition_id=self.target_transition_id,
                        event_count=update.event_count,
                    )

                    # State is recomputed from the first state in the window,
                    # since only events inside the window have changed.
                    next_dest_state = compute_state(
                        state_slice[0], next_events, self.incidence_matrix
                    )
                    reverse = del_occult_proposal(
                        count_max=self.count_max,
                        events=next_events[..., self.target_transition_id],
                        dest_state=next_dest_state[..., src_dest_ids[1]],
                    )
                    q_fwd = jnp.sum(proposal.log_prob(update))
                    q_rev = jnp.sum(reverse.log_prob(update))
                    log_acceptance_correction = q_rev - q_fwd

                    return (
                        update,
                        next_events,
                        log_acceptance_correction,
                        True,
                    )

            def del_occult_fn():
                # Forward: delete events, limited by the destination state.
                # Reverse: add them back, from the source state.
                with jax.named_scope("false_fn"):
                    proposal = del_occult_proposal(
                        count_max=self.count_max,
                        events=range_events[..., self.target_transition_id],
                        dest_state=state_slice[..., src_dest_ids[1]],
                    )
                    update = proposal.sample(seed=proposal_seed)

                    next_events = _add_events(
                        events=range_events,
                        unit=update.unit,
                        timepoint=update.timepoint,
                        target_transition_id=self.target_transition_id,
                        event_count=-update.event_count,
                    )

                    next_src_state = compute_state(
                        state_slice[0], next_events, self.incidence_matrix
                    )
                    reverse = add_occult_proposal(
                        count_max=self.count_max,
                        events=next_events[..., self.target_transition_id],
                        src_state=next_src_state[..., src_dest_ids[0]],
                    )
                    q_fwd = jnp.sum(proposal.log_prob(update))
                    q_rev = jnp.sum(reverse.log_prob(update))
                    log_acceptance_correction = q_rev - q_fwd

                    return (
                        update,
                        next_events,
                        log_acceptance_correction,
                        False,
                    )

            # NOTE(review): the emptiness check counts non-zero events over
            # *all* transitions in the window, not only the target
            # transition, so a delete may be attempted when the target
            # transition has no events there.  The direction probabilities
            # (PROB_DIRECTION vs. the forced add) are also not included in
            # log_acceptance_correction.  Confirm both are intended.
            u = tfd.Uniform().sample(seed=add_del_seed)
            delta, next_range_events, log_acceptance_correction, is_add = (
                jax.lax.cond(
                    (u < PROB_DIRECTION)
                    & (jnp.count_nonzero(range_events) > 0),
                    del_occult_fn,
                    add_occult_fn,
                )
            )

            # Update current_events with the new next_range_events tensor
            next_events = current_events.at[t_range_slice].set(
                next_range_events
            )
            next_target_log_prob = self.target_log_prob_fn(next_events)

            return (
                next_events,
                OccultKernelResults(
                    log_acceptance_correction=log_acceptance_correction,
                    target_log_prob=next_target_log_prob,
                    unit=delta.unit,
                    timepoint=delta.timepoint,
                    is_add=is_add,
                    event_count=delta.event_count,
                    seed=seed,
                ),
            )

    def bootstrap_results(self, init_state):
        """Build initial ``OccultKernelResults`` for ``init_state``.

        ``init_state`` (a single tensor or a list of tensors) is cast to the
        dtype of ``initial_conditions`` and passed positionally to
        ``target_log_prob_fn``.  The move-describing fields are zero/True
        placeholders so that the results structure matches ``one_step``.
        """
        with jax.named_scope("uncalibrated_event_times_rw/bootstrap_results"):
            if not mcmc_util.is_list_like(init_state):
                init_state = [init_state]

            init_state = [jnp.asarray(x, dtype=self._dtype) for x in init_state]
            init_target_log_prob = self.target_log_prob_fn(*init_state)
            return OccultKernelResults(
                log_acceptance_correction=jnp.asarray(
                    0.0, dtype=init_target_log_prob.dtype
                ),
                target_log_prob=init_target_log_prob,
                unit=jnp.zeros((), dtype=np.int32),
                timepoint=jnp.zeros((), dtype=np.int32),
                is_add=jnp.asarray(True),
                event_count=jnp.zeros((), dtype=np.int32),
                seed=sanitize_key(0),
            )