Coverage for gemlib/mcmc/discrete_time_state_transition_model/move_events.py: 100%

33 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-12-02 22:09 +0000

1"""Move partially-censored events in DiscreteTimeStateTransitionModel""" 

2 

3from typing import NamedTuple 

4 

5import jax 

6import tensorflow_probability.substrates.jax as tfp 

7 

8from gemlib.mcmc.discrete_time_state_transition_model.move_events_impl import ( 

9 UncalibratedEventTimesUpdate, 

10) 

11from gemlib.mcmc.sampling_algorithm import ChainState, SamplingAlgorithm 

12 

# Shorthand for the JAX array type, used in the annotations of this module.
Tensor = jax.Array

14 

15 

class MoveEventsState(NamedTuple):
    """Kernel state for the `move_events` sampling algorithm.

    Holds the tuning constants given to `move_events`. The state is built
    once by `init` and passed back unchanged by every `step`.
    """

    num_units: int  # epidemiological units updated per step
    delta_max: int  # largest time interval an event may be moved
    count_max: int  # largest number of transitions moved per step

20 

21 

class MoveEventsInfo(NamedTuple):
    """Per-step diagnostics emitted by the `move_events` `step` function.

    Field order is unchanged, so positional construction and tuple
    unpacking by existing callers continue to work.
    """

    # Acceptance flag taken from the Metropolis-Hastings results.
    # NOTE(review): presumably a JAX boolean scalar rather than a Python
    # bool; confirm before tightening the annotation.
    is_accepted: bool
    # The `initial_conditions` passed to `step`. It is documented and passed
    # as a Tensor, not a float.
    initial_conditions: Tensor
    # The proposed event timeseries (the kernel's `proposed_state`).
    proposed_events: Tensor
    # Target log density evaluated at the proposal.
    proposed_log_density: float
    # Log Metropolis-Hastings acceptance ratio for the proposal.
    log_accept_ratio: float
    # The seed that was passed to `step`.
    seed: Tensor

29 

30 

def move_events(
    incidence_matrix: Tensor,
    transition_index: int,
    num_units: int,
    delta_max: int,
    count_max: int,
    name: str | None = None,
):
    r"""Move partially-censored transition events

    This kernel provides an MCMC algorithm that moves partially-censored
    transition events in a DiscreteTimeStateTransitionModel event timeseries.
    It caters for the situation in which a transition is known to have occurred
    (e.g. as a result of a subsequent case detection), but the time at
    which it occurred is unknown.

    Args:
        incidence_matrix: the state-transition graph incidence matrix
        transition_index: the index of the transition in `incidence_matrix`
            to update.
        num_units: the number of epidemiological units to update at once
        delta_max: the maximum time interval over which to move transition
            times.
        count_max: the maximum number of transitions to move at once.
        name: the name of the kernel. Defaults to ``"move_events"``.

    Returns:
        An instance of :obj:`SamplingAlgorithm` with members
        :obj:`init(tlp_fn: LogProbFnType, position: Position,
        initial_conditions: Tensor)` and :obj:`step(tlp_fn: LogProbFnType,
        cks: ChainAndKernelState, seed: SeedType, initial_conditions: Tensor)`.
        :obj:`initial_conditions` is an extra keyword
        argument required to be passed to the :obj:`init` and :obj:`step`
        functions. :obj:`init` returns a ``(ChainState, MoveEventsState)``
        pair. :obj:`step` returns ``((ChainState, MoveEventsState),
        MoveEventsInfo)``.

    References:
        Philip D O'Neill and Gareth O Roberts (1999) Bayesian inference for
        partially observed stochastic epidemics. _Journal of the Royal
        Statistical Society: Series A (Statistics in Society), **162**\ :121--129.

        Jewell *et al.* (2023) Bayesian inference for high-dimensional discrete-
        time epidemic models: spatial dynamics of the UK COVID-19 outbreak.
        Pre-print arXiv:2306.07987
    """

    def _build_kernel(target_log_prob_fn, initial_conditions):
        # The kernel is rebuilt on every call because both the target log
        # density and `initial_conditions` are supplied to `init`/`step`, not
        # to `move_events` itself.
        return tfp.mcmc.MetropolisHastings(
            inner_kernel=UncalibratedEventTimesUpdate(
                target_log_prob_fn,
                incidence_matrix=incidence_matrix,
                initial_conditions=initial_conditions,
                target_transition_id=transition_index,
                num_units=num_units,
                delta_max=delta_max,
                count_max=count_max,
                name=name or "move_events",
            )
        )

    def init_fn(target_log_prob_fn, position, initial_conditions):
        """Create the initial chain state and static kernel state."""
        kernel = _build_kernel(target_log_prob_fn, initial_conditions)
        results = kernel.bootstrap_results(position)
        chain_state = ChainState(
            position=position,
            log_density=results.accepted_results.target_log_prob,
        )
        # The kernel state only records the tuning constants; `step_fn`
        # passes it through unchanged.
        kernel_state = MoveEventsState(num_units, delta_max, count_max)

        return chain_state, kernel_state

    def step_fn(
        target_log_prob_fn, target_and_kernel_state, seed, initial_conditions
    ):
        """Run one Metropolis-Hastings event-move update."""
        target_chain_state, kernel_state = target_and_kernel_state

        kernel = _build_kernel(target_log_prob_fn, initial_conditions)

        # No tfp kernel results are carried between steps, so they are
        # re-bootstrapped from the current position. This re-evaluates the
        # target log density rather than reusing
        # `target_chain_state.log_density`.
        new_position, results = kernel.one_step(
            target_chain_state.position,
            kernel.bootstrap_results(target_chain_state.position),
            seed=seed,
        )

        new_chain_and_kernel_state = (
            ChainState(
                position=new_position,
                log_density=results.accepted_results.target_log_prob,
            ),
            kernel_state,
        )

        return new_chain_and_kernel_state, MoveEventsInfo(
            is_accepted=results.is_accepted,
            initial_conditions=initial_conditions,
            proposed_events=results.proposed_state,
            proposed_log_density=results.proposed_results.target_log_prob,
            log_accept_ratio=results.log_accept_ratio,
            seed=seed,
        )

    return SamplingAlgorithm(init_fn, step_fn)