Coverage for gemlib/mcmc/discrete_time_state_transition_model/left_censored_events_mh.py: 100%

44 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-12-02 22:09 +0000

1"""Left-censored events MCMC kernel for DiscreteTimeStateTransitionModel""" 

2 

3from functools import partial 

4from typing import NamedTuple 

5 

6import numpy as np 

7import tensorflow_probability.substrates.jax as tfp 

8 

9from gemlib.mcmc.discrete_time_state_transition_model.left_censored_events_impl import ( # noqa: E501 

10 UncalibratedLeftCensoredEventTimesUpdate, 

11) 

12from gemlib.mcmc.sampling_algorithm import ( 

13 ChainState, 

14 LogProbFnType, 

15 SamplingAlgorithm, 

16) 

17 

# `numpy.typing` is imported explicitly rather than reached as `np.typing`:
# some NumPy releases do not load this submodule on a bare `import numpy`, so
# the attribute lookup only succeeded if something else had imported it first.
import numpy.typing as npt

# Array type used in this module's annotations. NOTE(review): at runtime the
# values are presumably JAX arrays (this module uses the tfp JAX substrate);
# the alias serves annotation purposes only.
Tensor = npt.NDArray

19 

20 

class LeftCensoredEventsState(NamedTuple):
    """Kernel state carried between steps of the left-censored events sampler.

    Holds the `max_timepoint` and `max_events` settings given to
    `left_censored_events_mh`; the step function passes it through unchanged.
    """

    # Max timepoint up to which moves are proposed.
    max_timepoint: int
    # Max number of events per unit/timepoint to move.
    max_events: int

24 

25 

class LeftCensoredEventsInfo(NamedTuple):
    """Per-step diagnostics emitted by the left-censored events kernel.

    Filled in by the step function of `left_censored_events_mh` from the
    Metropolis-Hastings kernel results. NOTE(review): the Python scalar
    annotations are nominal; at runtime the values are presumably JAX
    arrays -- confirm.
    """

    # Whether the Metropolis-Hastings proposal was accepted.
    is_accepted: bool
    # Log acceptance correction reported for the proposal.
    log_acceptance_correction: float
    # Target log-density reported for the *proposed* state.
    target_log_prob: float
    # Proposal details copied from the inner kernel's proposed results.
    # Presumably the chosen unit, timepoint, move direction and number of
    # events moved -- confirm against UncalibratedLeftCensoredEventTimesUpdate.
    unit: int
    timepoint: int
    direction: int
    num_events: int
    # The PRNG seed that was passed to the step function.
    seed: tuple[int, int]

35 

36 

class LeftCensoredEventsPosition(NamedTuple):
    """The two state components updated by the left-censored events kernel.

    This is the canonical structure handed to the inner MCMC kernel. It is
    extracted from a caller's position by `_get_state_tuple` and written back
    into that position by `_repack_state_tuple`.
    """

    # Value of the variable named by `initial_conditions_varname`.
    initial_conditions: Tensor
    # Value of the variable named by `events_varname` (events timeseries).
    events: Tensor

40 

41 

def _get_state_tuple(
    initial_conditions_varname: str,
    events_varname: str,
    position: NamedTuple,
):
    """Pull the initial conditions and events out of a caller's position.

    Args
    ----
      initial_conditions_varname: attribute name under which `position`
                                  stores the initial conditions.
      events_varname: attribute name under which `position` stores the
                      events timeseries.
      position: the caller's position; both attributes are read via
                `getattr`.

    Returns
    -------
      A LeftCensoredEventsPosition holding the two extracted values.
    """
    initial_conditions = getattr(position, initial_conditions_varname)
    events = getattr(position, events_varname)
    return LeftCensoredEventsPosition(
        initial_conditions=initial_conditions,
        events=events,
    )

51 

52 

53def _repack_state_tuple( 

54 initial_conditions_varname: str, 

55 events_varname: str, 

56 new_structure: LeftCensoredEventsPosition, 

57 original_structure: NamedTuple, 

58): 

59 return original_structure.__class__( 

60 **{ 

61 initial_conditions_varname: new_structure.initial_conditions, 

62 events_varname: new_structure.events, 

63 } 

64 ) 

65 

66 

def left_censored_events_mh(
    incidence_matrix: Tensor,
    transition_index: int,
    max_timepoint: int,
    max_events: int,
    events_varname: str,
    initial_conditions_varname: str,
    name: str | None = None,
):
    """Update initial conditions and events for DiscreteTimeStateTransitionModel

    In observations of a DiscreteTimeStateTransitionModel realisation,
    there may be uncertainty about the initial conditions and hence the number
    of events occurring in the early part of the timeseries. This MCMC kernel
    provides a means of updating these left-censored events.

    Args
    ----
      incidence_matrix: the state-transition graph incidence matrix
      transition_index: the index of the transition in `incidence_matrix` to
                        update
      max_timepoint: max timepoint up to which to propose moves
      max_events: max number of events per unit/timepoint to move
      events_varname: the name of the random variable holding the events
                      timeseries in a NamedTuple supplied to both the `init`
                      and `step` functions.
      initial_conditions_varname: the name of the random variable
                                  representing the initial conditions in a
                                  NamedTuple supplied to both the `init` and
                                  `step` functions.
      name: name of the kernel.

    Returns
    -------
      An instance of SamplingAlgorithm whose `init` and `step` functions
      read (and `step` writes back) the fields named by
      `initial_conditions_varname` and `events_varname` of the position
      NamedTuple.
    """
    # Convert between the caller's position NamedTuple and the two-field
    # LeftCensoredEventsPosition structure the inner kernel operates on.
    canonize = partial(
        _get_state_tuple, initial_conditions_varname, events_varname
    )
    uncanonize = partial(
        _repack_state_tuple, initial_conditions_varname, events_varname
    )

    def _build_kernel(target_log_prob_fn):
        # Built per call because the kernel closes over `target_log_prob_fn`,
        # which is supplied anew to each init/step invocation.
        return tfp.mcmc.MetropolisHastings(
            inner_kernel=UncalibratedLeftCensoredEventTimesUpdate(
                target_log_prob_fn,
                transition_index,
                incidence_matrix,
                max_timepoint,
                max_events,
                name,
            )
        )

    def init_fn(target_log_prob_fn: LogProbFnType, target_state: NamedTuple):
        kernel = _build_kernel(target_log_prob_fn)
        results = kernel.bootstrap_results(canonize(target_state))
        chain_state = ChainState(
            position=target_state,
            log_density=results.accepted_results.target_log_prob,
        )
        kernel_state = LeftCensoredEventsState(
            max_timepoint=max_timepoint, max_events=max_events
        )

        return chain_state, kernel_state

    def step_fn(
        target_log_prob_fn: LogProbFnType, target_and_kernel_state, seed
    ):
        kernel = _build_kernel(target_log_prob_fn)
        target_chain_state, kernel_state = target_and_kernel_state
        current_position = canonize(target_chain_state.position)

        # MH results are re-bootstrapped from the current position on every
        # step instead of being carried between steps. NOTE(review):
        # presumably because `target_log_prob_fn` may differ between calls
        # (e.g. other kernels updating conditioning variables), making any
        # cached log-density stale -- confirm.
        new_target_position, results = kernel.one_step(
            current_position,
            kernel.bootstrap_results(current_position),
            seed=seed,
        )

        new_chain_and_kernel_state = (
            ChainState(
                position=uncanonize(
                    new_target_position, target_chain_state.position
                ),
                log_density=results.accepted_results.target_log_prob,
            ),
            kernel_state,
        )

        proposed = results.proposed_results
        return new_chain_and_kernel_state, LeftCensoredEventsInfo(
            is_accepted=results.is_accepted,
            log_acceptance_correction=proposed.log_acceptance_correction,
            target_log_prob=proposed.target_log_prob,
            unit=proposed.unit,
            timepoint=proposed.timepoint,
            direction=proposed.direction,
            num_events=proposed.num_events,
            seed=seed,
        )

    return SamplingAlgorithm(init_fn, step_fn)