Coverage for gemlib/mcmc/discrete_time_state_transition_model/right_censored_events_impl.py: 96%

107 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-12-02 22:09 +0000

1"""Sampler for discrete-space occult events""" 

2 

3from collections.abc import Callable 

4from typing import NamedTuple 

5 

6import jax 

7import jax.numpy as jnp 

8import numpy as np 

9import tensorflow_probability.substrates.jax as tfp 

10from tensorflow_probability.substrates.jax.mcmc.internal import ( 

11 util as mcmc_util, 

12) 

13 

14from gemlib.distributions.discrete_markov import ( 

15 compute_state, 

16) 

17from gemlib.mcmc.discrete_time_state_transition_model.right_censored_events_proposal import ( # noqa:E501 

18 add_occult_proposal, 

19 del_occult_proposal, 

20) 

21from gemlib.mcmc.sampling_algorithm import Position 

22from gemlib.prng_util import sanitize_key 

23from gemlib.util import transition_coords 

24 

25tfd = tfp.distributions 

26Tensor = np.typing.NDArray 

27 

28__all__ = ["UncalibratedOccultUpdate"] 

29 

30 

31PROB_DIRECTION = 0.5 

32 

33 

34class OccultKernelResults(NamedTuple): 

35 log_acceptance_correction: float 

36 target_log_prob: float 

37 unit: int 

38 timepoint: int 

39 is_add: bool 

40 event_count: int 

41 seed: tuple[int, int] 

42 

43 

44def _is_row_nonzero(m): 

45 return jnp.sum(m, axis=-1) > 0.0 

46 

47 

48def _maybe_expand_dims(x): 

49 """If x is a scalar, give it at least 1 dimension""" 

50 return jnp.atleast_1d(x) 

51 

52 

53def _add_events(events, unit, timepoint, target_transition_id, event_count): 

54 """Adds `x_star` events to metapopulation `m`, 

55 time `t`, transition `x` in `events`. 

56 """ 

57 events = jnp.asarray(events) 

58 return events.at[timepoint, unit, target_transition_id].add(event_count) 

59 

60 

61class UncalibratedOccultUpdate(tfp.mcmc.TransitionKernel): 

62 """UncalibratedOccultUpdate""" 

63 

64 def __init__( 

65 self, 

66 target_log_prob_fn: Callable[[Position], float], 

67 incidence_matrix: Tensor, 

68 initial_conditions: Tensor, 

69 target_transition_id: int, 

70 count_max: int, 

71 t_range, 

72 name=None, 

73 ): 

74 """An uncalibrated random walk for event times. 

75 :param target_log_prob_fn: the log density of the target distribution 

76 :param target_event_id: the position in the last dimension of the events 

77 tensor that we wish to move 

78 :param t_range: a tuple containing earliest and latest times between 

79 which to update occults. 

80 :param seed: a random seed 

81 :param name: the name of the update step 

82 """ 

83 self._parameters = dict(locals()) 

84 self._name = name or "uncalibrated_occult_update" 

85 self._dtype = jnp.asarray(initial_conditions).dtype 

86 

87 @property 

88 def target_log_prob_fn(self): 

89 return self._parameters["target_log_prob_fn"] 

90 

91 @property 

92 def incidence_matrix(self): 

93 return self._parameters["incidence_matrix"] 

94 

95 @property 

96 def target_transition_id(self): 

97 return self._parameters["target_transition_id"] 

98 

99 @property 

100 def initial_conditions(self): 

101 return self._parameters["initial_conditions"] 

102 

103 @property 

104 def count_max(self): 

105 return self._parameters["count_max"] 

106 

107 @property 

108 def t_range(self): 

109 return self._parameters["t_range"] 

110 

111 @property 

112 def name(self): 

113 return self._parameters["name"] 

114 

115 @property 

116 def parameters(self): 

117 """Return `dict` of ``__init__`` arguments and their values.""" 

118 return self._parameters 

119 

120 @property 

121 def is_calibrated(self): 

122 return False 

123 

124 def one_step(self, current_events, previous_kernel_results, seed): # noqa: ARG002 

125 """One update of event times. 

126 :param current_events: a [T, M, R] tensor containing number of events 

127 per time t, metapopulation m, 

128 and transition r. 

129 :param previous_kernel_results: an object of type 

130 UncalibratedRandomWalkResults. 

131 :returns: a tuple containing new_state and UncalibratedRandomWalkResults 

132 """ 

133 with jax.named_scope("occult_rw/onestep"): 

134 current_events = jnp.asarray(current_events) 

135 initial_conditions = jnp.asarray(self.initial_conditions) 

136 

137 proposal_seed, add_del_seed = jax.random.split(seed) 

138 t_range_slice = slice(*self.t_range) 

139 

140 state = compute_state( 

141 initial_conditions, 

142 current_events, 

143 self.incidence_matrix, 

144 ) 

145 src_dest_ids = transition_coords(self.incidence_matrix)[ 

146 self.target_transition_id, : 

147 ] 

148 

149 # Pull out the section of events and state that are within 

150 # the requested time interval - we focus on this, and insert 

151 # updated values back into the full events at the end. 

152 range_events = current_events[t_range_slice] 

153 state_slice = state[t_range_slice] 

154 

155 def add_occult_fn(): 

156 with jax.named_scope("true_fn"): 

157 proposal = add_occult_proposal( 

158 count_max=self.count_max, 

159 events=range_events[..., self.target_transition_id], 

160 src_state=state_slice[..., src_dest_ids[0]], 

161 ) 

162 update = proposal.sample(seed=proposal_seed) 

163 

164 next_events = _add_events( 

165 events=range_events, 

166 unit=update.unit, 

167 timepoint=update.timepoint, 

168 target_transition_id=self.target_transition_id, 

169 event_count=update.event_count, 

170 ) 

171 

172 next_dest_state = compute_state( 

173 state_slice[0], next_events, self.incidence_matrix 

174 ) 

175 reverse = del_occult_proposal( 

176 count_max=self.count_max, 

177 events=next_events[..., self.target_transition_id], 

178 dest_state=next_dest_state[..., src_dest_ids[1]], 

179 ) 

180 q_fwd = jnp.sum(proposal.log_prob(update)) 

181 q_rev = jnp.sum(reverse.log_prob(update)) 

182 log_acceptance_correction = q_rev - q_fwd 

183 

184 return ( 

185 update, 

186 next_events, 

187 log_acceptance_correction, 

188 True, 

189 ) 

190 

191 def del_occult_fn(): 

192 with jax.named_scope("false_fn"): 

193 proposal = del_occult_proposal( 

194 count_max=self.count_max, 

195 events=range_events[..., self.target_transition_id], 

196 dest_state=state_slice[..., src_dest_ids[1]], 

197 ) 

198 update = proposal.sample(seed=proposal_seed) 

199 

200 next_events = _add_events( 

201 events=range_events, 

202 unit=update.unit, 

203 timepoint=update.timepoint, 

204 target_transition_id=self.target_transition_id, 

205 event_count=-update.event_count, 

206 ) 

207 

208 next_src_state = compute_state( 

209 state_slice[0], next_events, self.incidence_matrix 

210 ) 

211 reverse = add_occult_proposal( 

212 count_max=self.count_max, 

213 events=next_events[..., self.target_transition_id], 

214 src_state=next_src_state[..., src_dest_ids[0]], 

215 ) 

216 q_fwd = jnp.sum(proposal.log_prob(update)) 

217 q_rev = jnp.sum(reverse.log_prob(update)) 

218 log_acceptance_correction = q_rev - q_fwd 

219 

220 return ( 

221 update, 

222 next_events, 

223 log_acceptance_correction, 

224 False, 

225 ) 

226 

227 u = tfd.Uniform().sample(seed=add_del_seed) 

228 delta, next_range_events, log_acceptance_correction, is_add = ( 

229 jax.lax.cond( 

230 (u < PROB_DIRECTION) 

231 & (jnp.count_nonzero(range_events) > 0), 

232 del_occult_fn, 

233 add_occult_fn, 

234 ) 

235 ) 

236 

237 # Update current_events with the new next_range_events tensor 

238 next_events = current_events.at[t_range_slice].set( 

239 next_range_events 

240 ) 

241 next_target_log_prob = self.target_log_prob_fn(next_events) 

242 

243 return ( 

244 next_events, 

245 OccultKernelResults( 

246 log_acceptance_correction=log_acceptance_correction, 

247 target_log_prob=next_target_log_prob, 

248 unit=delta.unit, 

249 timepoint=delta.timepoint, 

250 is_add=is_add, 

251 event_count=delta.event_count, 

252 seed=seed, 

253 ), 

254 ) 

255 

256 def bootstrap_results(self, init_state): 

257 with jax.named_scope("uncalibrated_event_times_rw/bootstrap_results"): 

258 if not mcmc_util.is_list_like(init_state): 

259 init_state = [init_state] 

260 

261 init_state = [jnp.asarray(x, dtype=self._dtype) for x in init_state] 

262 init_target_log_prob = self.target_log_prob_fn(*init_state) 

263 return OccultKernelResults( 

264 log_acceptance_correction=jnp.asarray( 

265 0.0, dtype=init_target_log_prob.dtype 

266 ), 

267 target_log_prob=init_target_log_prob, 

268 unit=jnp.zeros((), dtype=np.int32), 

269 timepoint=jnp.zeros((), dtype=np.int32), 

270 is_add=jnp.asarray(True), 

271 event_count=jnp.zeros((), dtype=np.int32), 

272 seed=sanitize_key(0), 

273 )