Coverage for gemlib/mcmc/discrete_time_state_transition_model/left_censored_events_mh.py: 100%
44 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-12-02 22:09 +0000
« prev ^ index » next coverage.py v7.10.3, created at 2025-12-02 22:09 +0000
1"""Left-censored events MCMC kernel for DiscreteTimeStateTransitionModel"""
3from functools import partial
4from typing import NamedTuple
6import numpy as np
7import tensorflow_probability.substrates.jax as tfp
9from gemlib.mcmc.discrete_time_state_transition_model.left_censored_events_impl import ( # noqa: E501
10 UncalibratedLeftCensoredEventTimesUpdate,
11)
12from gemlib.mcmc.sampling_algorithm import (
13 ChainState,
14 LogProbFnType,
15 SamplingAlgorithm,
16)
# Array type used in the annotations throughout this module.
Tensor = np.typing.NDArray
class LeftCensoredEventsState(NamedTuple):
    """Kernel state for `left_censored_events_mh`.

    Records the proposal bounds the kernel was built with; the step function
    returns it unchanged.
    """

    max_timepoint: int  # max timepoint up to which moves are proposed
    max_events: int  # max number of events per unit/timepoint to move
class LeftCensoredEventsInfo(NamedTuple):
    """Per-step diagnostics emitted by the `left_censored_events_mh` step.

    Apart from `is_accepted` and `seed`, every field is copied from the
    Metropolis-Hastings `proposed_results`. These fields therefore describe
    the proposal rather than the state that was finally kept.

    NOTE(review): the annotations are nominal. At runtime the values come
    from JAX-substrate kernel results and are presumably arrays — confirm.
    """

    is_accepted: bool  # whether the MH step accepted the proposal
    log_acceptance_correction: float
    target_log_prob: float  # log density evaluated at the proposed state
    # The next four fields presumably identify the proposed move (which
    # unit and timepoint, the move direction and how many events). Verify
    # against UncalibratedLeftCensoredEventTimesUpdate.
    unit: int
    timepoint: int
    direction: int
    num_events: int
    seed: tuple[int, int]  # the seed that was passed to the step
class LeftCensoredEventsPosition(NamedTuple):
    """Canonical view of the two variables this kernel updates.

    It is built from the caller's position by looking the variables up by
    name. The MCMC kernel therefore always sees these fixed field names,
    whatever the model calls them.
    """

    initial_conditions: Tensor
    events: Tensor
def _get_state_tuple(
    initial_conditions_varname: str,
    events_varname: str,
    position: NamedTuple,
):
    """Extract the kernel's two variables from `position` by attribute name.

    Args
    ----
    initial_conditions_varname: attribute of `position` holding the initial
                                conditions.
    events_varname: attribute of `position` holding the events timeseries.
    position: the caller's position structure.

    Returns
    -------
    A LeftCensoredEventsPosition holding the two looked-up values. Raises
    AttributeError if either name is not an attribute of `position`.
    """
    ic_value = getattr(position, initial_conditions_varname)
    events_value = getattr(position, events_varname)
    return LeftCensoredEventsPosition(
        initial_conditions=ic_value,
        events=events_value,
    )
def _repack_state_tuple(
    initial_conditions_varname: str,
    events_varname: str,
    new_structure: LeftCensoredEventsPosition,
    original_structure: NamedTuple,
):
    """Write the kernel's updated variables back into the caller's structure.

    This is the inverse of `_get_state_tuple`. The values in `new_structure`
    are stored under the caller's own attribute names.

    Args
    ----
    initial_conditions_varname: attribute of `original_structure` that
                                receives `new_structure.initial_conditions`.
    events_varname: attribute of `original_structure` that receives
                    `new_structure.events`.
    new_structure: the updated canonical (initial_conditions, events) pair.
    original_structure: the caller's position before the update. It is used
                        as the template for the result and is not mutated.

    Returns
    -------
    An object of the same class as `original_structure` with the two named
    fields replaced. When `original_structure` is a namedtuple, every other
    field is carried over unchanged.
    """
    updates = {
        initial_conditions_varname: new_structure.initial_conditions,
        events_varname: new_structure.events,
    }
    # Prefer `_replace` (namedtuple / typing.NamedTuple) so fields other
    # than the two updated here are preserved. Calling the class constructor
    # with only these two keywords raises TypeError for any other required
    # field, and silently resets any other defaulted field.
    replace = getattr(original_structure, "_replace", None)
    if replace is not None:
        return replace(**updates)
    # Fallback for structures without `_replace`: keep the original
    # keyword-constructor behaviour.
    return original_structure.__class__(**updates)
def left_censored_events_mh(
    incidence_matrix: Tensor,
    transition_index: int,
    max_timepoint: int,
    max_events: int,
    events_varname: str,
    initial_conditions_varname: str,
    name: str | None = None,
):
    """Update initial conditions and events for DiscreteTimeStateTransitionModel

    In observations of a DiscreteTimeStateTransitionModel realisation,
    there may be uncertainty about the initial conditions and hence the number
    of events occurring in the early part of the timeseries. This MCMC kernel
    provides a means of updating these left-censored events.

    Each step wraps an `UncalibratedLeftCensoredEventTimesUpdate` proposal in
    a `tfp.mcmc.MetropolisHastings` accept/reject step. The two named
    variables are extracted from the position for the proposal. The updated
    values are then repacked into the caller's position class.

    Args
    ----
    incidence_matrix: the state-transition graph incidence matrix
    transition_index: the index of the transition in `incidence_matrix` to
                      update
    max_timepoint: max timepoint up to which to propose moves
    max_events: max number of events per unit/timepoint to move
    events_varname: the name of the random variable holding the events
                    timeseries in a NamedTuple supplied to both the `init` and
                    `step` functions.
    initial_conditions_varname: the name of the random variable representing
                                the initial conditions in a NamedTuple
                                supplied to both the `init` and `step`
                                functions.
    name: name of the kernel.

    Returns
    -------
    An instance of SamplingAlgorithm.
    - `init` returns a `(ChainState, LeftCensoredEventsState)` pair.
    - `step` returns `((ChainState, LeftCensoredEventsState),
      LeftCensoredEventsInfo)`.
    - The kernel state is passed through every step unchanged.
    """
    # Map between the caller's arbitrarily named position and the fixed
    # (initial_conditions, events) pair the kernel operates on.
    to_canonical = partial(
        _get_state_tuple, initial_conditions_varname, events_varname
    )
    from_canonical = partial(
        _repack_state_tuple, initial_conditions_varname, events_varname
    )

    def _build_kernel(target_log_prob_fn):
        # The inner kernel generates the move; MetropolisHastings adds the
        # accept/reject step on top of it.
        proposal_kernel = UncalibratedLeftCensoredEventTimesUpdate(
            target_log_prob_fn,
            transition_index,
            incidence_matrix,
            max_timepoint,
            max_events,
            name,
        )
        return tfp.mcmc.MetropolisHastings(inner_kernel=proposal_kernel)

    def init_fn(target_log_prob_fn: LogProbFnType, target_state: NamedTuple):
        bootstrap = _build_kernel(target_log_prob_fn).bootstrap_results(
            to_canonical(target_state)
        )
        return (
            ChainState(
                position=target_state,
                log_density=bootstrap.accepted_results.target_log_prob,
            ),
            LeftCensoredEventsState(
                max_timepoint=max_timepoint, max_events=max_events
            ),
        )

    def step_fn(target_log_prob_fn, target_and_kernel_state, seed):
        chain_state, kernel_state = target_and_kernel_state
        mh_kernel = _build_kernel(target_log_prob_fn)

        # Kernel results are rebuilt from the current position on every call
        # rather than carried between steps. `target_log_prob_fn` is supplied
        # afresh on each call, so the stored `log_density` is not reused.
        current = to_canonical(chain_state.position)
        next_position, mh_results = mh_kernel.one_step(
            current,
            mh_kernel.bootstrap_results(current),
            seed=seed,
        )

        new_chain_state = ChainState(
            position=from_canonical(next_position, chain_state.position),
            log_density=mh_results.accepted_results.target_log_prob,
        )

        proposal = mh_results.proposed_results
        step_info = LeftCensoredEventsInfo(
            is_accepted=mh_results.is_accepted,
            log_acceptance_correction=proposal.log_acceptance_correction,
            target_log_prob=proposal.target_log_prob,
            unit=proposal.unit,
            timepoint=proposal.timepoint,
            direction=proposal.direction,
            num_events=proposal.num_events,
            seed=seed,
        )
        return (new_chain_state, kernel_state), step_info

    return SamplingAlgorithm(init_fn, step_fn)