Coverage for gemlib/mcmc/discrete_time_state_transition_model/right_censored_events_mh.py: 100%

40 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-12-02 22:09 +0000

1"""Right-censored events MCMC kernel for DiscreteTimeStateTransitionModel""" 

2 

3from typing import NamedTuple 

4 

5import jax.numpy as jnp 

6import numpy as np 

7import tensorflow_probability.substrates.jax as tfp 

8 

9from gemlib.mcmc.discrete_time_state_transition_model.right_censored_events_impl import ( # noqa: E501 

10 UncalibratedOccultUpdate, 

11) 

12from gemlib.mcmc.sampling_algorithm import ChainState, SamplingAlgorithm 

13 

# Shared array annotation for the signatures and result tuples in this module
Tensor = np.typing.NDArray

15 

16 

class RightCensoredEventsState(NamedTuple):
    """Kernel state for the right-censored events sampler.

    Stores the configuration values that :obj:`right_censored_events_mh`
    was called with; the sampler's ``step`` function passes this state
    through unchanged.
    """

    # Max number of transitions to add or delete in any one update
    count_max: int
    # Time-range for which censored transition events are updated
    t_range: tuple[int, int]

20 

21 

class RightCensoredEventsInfo(NamedTuple):
    """Per-step diagnostics returned by the right-censored events sampler.

    The ``step`` function built by :obj:`right_censored_events_mh` fills
    this tuple from the Metropolis-Hastings kernel results and from the
    results of the inner proposal kernel.
    """

    # Copied from the proposal results; presumably the number of events the
    # proposal adds or deletes -- confirm against UncalibratedOccultUpdate
    event_count: int
    # Whether the Metropolis-Hastings step accepted the proposal
    is_accepted: bool
    # Copied from the proposal results; presumably flags an "add" move as
    # opposed to a "delete" move -- confirm against UncalibratedOccultUpdate
    is_add: int
    # Log acceptance correction reported by the proposal kernel
    log_acceptance_correction: float
    # State proposed in this step, whether or not it was accepted
    proposed_state: Tensor
    # Target log probability reported for the proposed state
    proposed_log_density: float
    # Seed reported in the Metropolis-Hastings kernel results
    seed: tuple[int, int]
    # Copied from the proposal results; presumably the timepoint index the
    # proposal acts on -- confirm against UncalibratedOccultUpdate
    timepoint: int
    # Copied from the proposal results; presumably the unit index the
    # proposal acts on -- confirm against UncalibratedOccultUpdate
    unit: int

32 

33 

def right_censored_events_mh(
    incidence_matrix: Tensor,
    transition_index: int,
    t_range: tuple[int, int],
    count_max: int = 1,
    name: str | None = None,
) -> SamplingAlgorithm:
    r"""Update right-censored events for DiscreteTimeStateTransitionModel

    In a partially-complete epidemic, some events _might_ already have
    occurred without being apparent, because the detection event that
    would reveal them is yet to occur.  This MCMC kernel explores the
    space of such events with a single-site "add/delete" move, as
    described (in continuous time) in O'Neill and Roberts (1999).

    Args:
        incidence_matrix: the state-transition graph incidence matrix
        transition_index: the index of the transition in
            :obj:`incidence_matrix` to update
        t_range: the time-range for which to update censored transition
            events.  Typically this would be :obj:`[s, num_steps)` where
            :obj:`s < num_steps` for :obj:`num_steps` the total number of
            timesteps in the model.
        count_max: the max number of transitions to add or delete in any
            one update
        name: name of the kernel

    Returns:
        An instance of :obj:`SamplingAlgorithm` with members
        :obj:`init(tlp_fn: LogProbFnType, position: Position,
        initial_conditions: Tensor)` and :obj:`step(tlp_fn: LogProbFnType,
        cks: ChainAndKernelState, seed: SeedType,
        initial_conditions: Tensor)`.  :obj:`initial_conditions` is an
        extra keyword argument required to be passed to the :obj:`init`
        and :obj:`step` functions.

    References:
        1. Philip D O'Neill and Gareth O Roberts (1999) Bayesian inference
           for partially observed stochastic epidemics. _Journal of the
           Royal Statistical Society: Series A (Statistics in Society)_,
           **162**\ :121--129.

        2. Jewell _et al._ (2023) Bayesian inference for high-dimensional
           discrete-time epidemic models: spatial dynamics of the UK
           COVID-19 outbreak. Pre-print arXiv:2306.07987
    """

    def _make_mh_kernel(log_prob_fn, init_conds):
        # The occult-update proposal is uncalibrated, so it is wrapped in
        # a Metropolis-Hastings kernel.
        occult_proposal = UncalibratedOccultUpdate(
            log_prob_fn,
            incidence_matrix=incidence_matrix,
            initial_conditions=init_conds,
            target_transition_id=transition_index,
            count_max=count_max,
            t_range=t_range,
            name=name,
        )
        return tfp.mcmc.MetropolisHastings(inner_kernel=occult_proposal)

    def init_fn(target_log_prob_fn, position, initial_conditions):
        events = jnp.asarray(position)
        init_conds = jnp.asarray(initial_conditions)

        mh_kernel = _make_mh_kernel(target_log_prob_fn, init_conds)
        bootstrap = mh_kernel.bootstrap_results(events)

        # Initial log density comes from the bootstrapped kernel results.
        start_chain_state = ChainState(
            position=events,
            log_density=bootstrap.accepted_results.target_log_prob,
        )
        start_kernel_state = RightCensoredEventsState(
            count_max=count_max, t_range=t_range
        )
        return start_chain_state, start_kernel_state

    def step_fn(
        target_log_prob_fn, target_and_kernel_state, seed, initial_conditions
    ):
        init_conds = jnp.asarray(initial_conditions)
        chain_state, kernel_state = target_and_kernel_state
        current_position = chain_state.position

        # The kernel is rebuilt on every call, so its previous results are
        # bootstrapped afresh from the current position.
        mh_kernel = _make_mh_kernel(target_log_prob_fn, init_conds)
        previous_results = mh_kernel.bootstrap_results(current_position)
        next_position, mh_results = mh_kernel.one_step(
            current_position, previous_results, seed=seed
        )

        proposal = mh_results.proposed_results
        step_info = RightCensoredEventsInfo(
            is_accepted=mh_results.is_accepted,
            unit=proposal.unit,
            timepoint=proposal.timepoint,
            is_add=proposal.is_add,
            event_count=proposal.event_count,
            log_acceptance_correction=proposal.log_acceptance_correction,
            proposed_state=mh_results.proposed_state,
            proposed_log_density=proposal.target_log_prob,
            seed=mh_results.seed,
        )

        # The kernel state is passed through unchanged.
        next_chain_state = ChainState(
            position=next_position,
            log_density=mh_results.accepted_results.target_log_prob,
        )
        return (next_chain_state, kernel_state), step_info

    return SamplingAlgorithm(init_fn, step_fn)