Coverage for gemlib/mcmc/discrete_time_state_transition_model/left_censored_events_impl.py: 96%

78 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-12-02 22:09 +0000

1"""Metropolis Hastings implementation for left-censored events 

2in a discrete-time metapopulation epidemic model 

3""" 

4 

5from typing import NamedTuple 

6 

7import jax.numpy as jnp 

8import numpy as np 

9import tensorflow_probability.substrates.jax as tfp 

10from tensorflow_probability.substrates.jax.mcmc.internal import ( 

11 util as mcmc_util, 

12) 

13 

14from gemlib.prng_util import sanitize_key 

15 

16from .left_censored_events_proposal import ( 

17 left_censored_event_time_proposal, 

18) 

19 

20tfd = tfp.distributions 

21 

22__all__ = ["UncalibratedLeftCensoredEventTimesUpdate"] 

23 

24LEFT_CENSORED_EVENT_TIMES_UPDATE_TUPLE_LEN = 2 

25 

26 

27def _update_state(update, current_state, transition_idx, incidence_matrix): 

28 current_initial_conditions, current_events = ( 

29 jnp.asarray(x) for x in current_state 

30 ) 

31 

32 # -1 is moving events forward in time 

33 sign = jnp.array([-1, 1])[update["direction"]] 

34 events_delta = update["num_events"] * sign 

35 

36 # Update initial conditions 

37 indices = ( 

38 jnp.broadcast_to( 

39 update["unit"], [current_initial_conditions.shape[-1]] 

40 ), 

41 np.arange(current_initial_conditions.shape[-1]), 

42 ) 

43 new_initial_conditions = current_initial_conditions.at[indices].add( 

44 events_delta * incidence_matrix[..., transition_idx] 

45 ) 

46 

47 # Update events 

48 new_events = current_events.at[ 

49 update["timepoint"], update["unit"], transition_idx 

50 ].subtract(events_delta) 

51 

52 return new_initial_conditions, new_events 

53 

54 

def _reverse_update(update):
    """Return the move that undoes `update`.

    The unit, timepoint and number of events are carried over unchanged;
    only the direction flag is flipped (0 <-> 1).

    :param update: dict with entries `"unit"`, `"timepoint"`, `"direction"`
        and `"num_events"`
    :returns: a new dict with the same four keys
    """
    flipped_direction = (update["direction"] + 1) % 2

    reversed_move = {}
    for field in ("unit", "timepoint", "direction", "num_events"):
        if field == "direction":
            reversed_move[field] = flipped_direction
        else:
            reversed_move[field] = update[field]
    return reversed_move

63 

64 

class LeftCensoredEventTimeResults(NamedTuple):
    """Kernel results recorded for one left-censored event time update."""

    # Sum of (reverse proposal log prob - forward proposal log prob).
    log_acceptance_correction: float
    # Target log density evaluated at the proposed state.
    target_log_prob: float
    # Components of the proposed move.
    unit: int
    timepoint: int
    direction: int
    num_events: int
    # Seed associated with the step that produced these results.
    seed: tuple[int, int]

73 

74 

class UncalibratedLeftCensoredEventTimesUpdate(tfp.mcmc.TransitionKernel):
    """Uncalibrated proposal kernel for left-censored event times.

    Each step draws a move from `left_censored_event_time_proposal`, applies
    it to both the initial conditions and the events, and reports the target
    log density of the proposed state together with the log proposal ratio.
    No accept/reject step is performed here (`is_calibrated` is `False`).
    """

    def __init__(
        self,
        target_log_prob_fn,
        transition_index,
        incidence_matrix,
        max_timepoint,
        max_events,
        name=None,
    ):
        """An uncalibrated random walk for initial conditions.

        :param target_log_prob_fn: the log density of the target distribution
        :param transition_index: the index of the transition to adjust
        :param incidence_matrix: the `[S,R]` incidence matrix
        :param max_timepoint: max timepoint up to which to move events
        :param max_events: max number of events per unit/timepoint to move
        :param name: optional name, used as a prefix for the proposal names
        """
        self._name = name
        self._parameters = {
            "target_log_prob_fn": target_log_prob_fn,
            "transition_index": transition_index,
            "incidence_matrix": incidence_matrix,
            "max_timepoint": max_timepoint,
            "max_events": max_events,
            "name": name,
        }

    @property
    def target_log_prob_fn(self):
        return self._parameters["target_log_prob_fn"]

    @property
    def transition_index(self):
        return self._parameters["transition_index"]

    @property
    def incidence_matrix(self):
        return self._parameters["incidence_matrix"]

    @property
    def max_timepoint(self):
        return self._parameters["max_timepoint"]

    @property
    def max_events(self):
        return self._parameters["max_events"]

    @property
    def name(self):
        return self._parameters["name"]

    @property
    def parameters(self):
        return self._parameters

    @property
    def is_calibrated(self):
        return False

    @staticmethod
    def _check_state(state):
        """Validate that `state` is a list/tuple of the expected length.

        :param state: candidate `(initial_conditions, events)` state
        :raises ValueError: if `state` is not list-like or does not have
            `LEFT_CENSORED_EVENT_TIMES_UPDATE_TUPLE_LEN` elements
        """
        # `len` is only evaluated once the state is known to be list-like.
        if not (
            mcmc_util.is_list_like(state)
            and len(state) == LEFT_CENSORED_EVENT_TIMES_UPDATE_TUPLE_LEN
        ):
            raise ValueError(
                "State for LeftCensoredEventTimesUpdate must be list/tuple "
                f"of length {LEFT_CENSORED_EVENT_TIMES_UPDATE_TUPLE_LEN}"
            )

    def one_step(self, current_state, previous_kernel_results, seed=None):  # noqa: ARG002
        """Propose an update to the initial conditions and events

        :param current_state: a tuple of `(current_initial_conditions,
                              current_events)`
        :param previous_kernel_results: previous kernel results tuple
                                        (unused)
        :param seed: optional seed tuple `(int32, int32)`
        :returns: a tuple `(next_state, kernel_results)` where `next_state`
                  is `(next_initial_conditions, next_events)` and
                  `kernel_results` is a `LeftCensoredEventTimeResults`
        :raises ValueError: if `current_state` is not a list/tuple of
                            length 2
        """
        self._check_state(current_state)

        # NOTE(review): when `name` is None the proposal names below are
        # formatted as "None/...".
        proposal = left_censored_event_time_proposal(
            events=current_state[1],
            initial_state=current_state[0],
            transition=self.transition_index,
            incidence_matrix=self.incidence_matrix,
            num_units=1,
            max_timepoint=self.max_timepoint,
            max_events=self.max_events,
            name=f"{self.name}/fwd_proposal",
        )
        fwd_update = proposal.sample(seed=seed)
        fwd_proposal_log_prob = proposal.log_prob(
            fwd_update, name="fwd_proposal_log_prob"
        )

        next_state = _update_state(
            fwd_update,
            current_state,
            self.transition_index,
            self.incidence_matrix,
        )
        next_target_log_prob = self.target_log_prob_fn(*next_state)

        # The reverse move is scored under a proposal built from the
        # proposed state, which gives the proposal ratio.
        rev_update = _reverse_update(fwd_update)
        rev_proposal = left_censored_event_time_proposal(
            events=next_state[1],
            initial_state=next_state[0],
            transition=self.transition_index,
            incidence_matrix=self.incidence_matrix,
            num_units=1,
            max_timepoint=self.max_timepoint,
            max_events=self.max_events,
            name=f"{self.name}/rev_proposal",
        )
        rev_proposal_log_prob = rev_proposal.log_prob(rev_update)
        log_acceptance_correction = jnp.sum(
            rev_proposal_log_prob - fwd_proposal_log_prob
        )

        kernel_results = LeftCensoredEventTimeResults(
            log_acceptance_correction=log_acceptance_correction,
            target_log_prob=next_target_log_prob,
            unit=fwd_update["unit"],
            timepoint=fwd_update["timepoint"],
            direction=fwd_update["direction"],
            num_events=fwd_update["num_events"],
            seed=seed,
        )
        return next_state, kernel_results

    def bootstrap_results(self, init_state):
        """Build the initial kernel results for `init_state`

        :param init_state: a tuple of `(initial_conditions, events)`
        :returns: a `LeftCensoredEventTimeResults` holding the target log
                  density at `init_state`, a zero log acceptance correction
                  and placeholder values for the move fields
        :raises ValueError: if `init_state` is not a list/tuple of length 2
        """
        self._check_state(init_state)

        initial_conditions = init_state[0]
        events = init_state[1]
        init_target_log_prob = self.target_log_prob_fn(
            initial_conditions, events
        )
        return LeftCensoredEventTimeResults(
            log_acceptance_correction=jnp.asarray(
                0.0, dtype=init_target_log_prob.dtype
            ),
            target_log_prob=init_target_log_prob,
            unit=jnp.zeros([1], dtype=np.int32),
            timepoint=jnp.zeros([1], dtype=np.int32),
            direction=jnp.array(0, dtype=np.int32),
            num_events=jnp.ones([1], dtype=np.int32),
            seed=sanitize_key(0),
        )

230 )