Coverage for gemlib/distributions/continuous_time_state_transition_model.py: 97%

61 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-12-02 22:09 +0000

1"""Continuous time state transition model""" 

2 

3from __future__ import annotations 

4 

5import math as pymath 

6from collections.abc import Callable, Sequence 

7 

8import jax 

9import jax.numpy as jnp 

10import numpy as np 

11import tensorflow_probability.substrates.jax as tfp 

12 

13from gemlib.distributions.continuous_markov import ( 

14 EventList, 

15 compute_state, 

16 continuous_markov_simulation, 

17 continuous_time_log_likelihood, 

18) 

19from gemlib.func_util import maybe_combine_fn 

20from gemlib.prng_util import sanitize_key 

21from gemlib.tensor_util import broadcast_fn_to 

22 

# Short module-level names, so that type annotations and the TFP
# distributions namespace stay concise in the definitions below.
tfd = tfp.distributions
NDArray = np.typing.NDArray
ArrayLike = jax.typing.ArrayLike
Array = jax.Array

28 

29 

class ContinuousTimeStateTransitionModel(tfd.Distribution):
    """Continuous time state transition model.

    Example:
        A homogeneously mixing SIR model implementation::

            import jax.numpy as jnp
            import numpy as np

            from gemlib.distributions import ContinuousTimeStateTransitionModel

            # Initial state, counts per compartment (S, I, R), for one
            # population
            initial_state = jnp.array([[99, 1, 0]], jnp.float32)

            # Note that the incidence_matrix is treated as a static parameter,
            # and so is supplied as a numpy ndarray, not a jax Array.
            incidence_matrix = np.array(
                [
                    [-1, 0],
                    [1, -1],
                    [0, 1],
                ],
                dtype=np.float32,
            )


            def si_rate(t, state):
                return 0.28 * state[:, 1] / jnp.sum(state, axis=-1)


            def ir_rate(t, state):
                return 0.14


            # Instantiate model
            sir = ContinuousTimeStateTransitionModel(
                transition_rate_fn=(si_rate, ir_rate),
                incidence_matrix=incidence_matrix,
                initial_state=initial_state,
                num_steps=100,
            )

            # One realisation of the epidemic process
            sim = sir.sample(seed=0)

            # Compute the log probability of observing `sim`
            sir.log_prob(sim)
    """

    def __init__(
        self,
        transition_rate_fn: Sequence[Callable[[float, ArrayLike], ArrayLike]]
        | Callable[[float, ArrayLike], tuple[ArrayLike, ...]],
        incidence_matrix: NDArray,
        initial_state: Array,
        num_steps: int,
        initial_time: float = 0.0,
        validate_args: bool = False,
        allow_nan_stats: bool = True,
        name: str = "ContinuousTimeStateTransitionModel",
    ):
        """
        Initializes a ContinuousTimeStateTransitionModel object.

        Args:
            transition_rate_fn: Either a list of callables of the form
              :code:`fn(t: float, state: Tensor) -> Tensor`, or a single
              Python callable of the form
              :code:`fn(t: float, state: Tensor) -> tuple(Tensor, ...)`.
              In the first (preferred) form, each callable in the list
              corresponds to the respective transition in
              :code:`incidence_matrix`. In the second form, the callable
              should return a :code:`tuple` of transition rate tensors
              corresponding to transitions in :code:`incidence_matrix`.
              **Note**: the second form will be deprecated in future releases
              of :code:`gemlib`.
            incidence_matrix: Matrix representing the incidence of transitions
              between states (in the class example, one row per state and one
              column per transition).
            initial_state: A :code:`[N, S]` tensor containing the initial state
              of the population of :code:`N` units in :code:`S`
              epidemiological classes.
            num_steps: The number of Markov jumps (events) simulated per
              realisation; this is also the length of the event dimension.
            initial_time: Initial time of the model. Defaults to 0.0.
            validate_args: Forwarded to :code:`tfd.Distribution`. Defaults to
              False.
            allow_nan_stats: Forwarded to :code:`tfd.Distribution`. Defaults
              to True.
            name: Name of the model. Defaults to
              "ContinuousTimeStateTransitionModel".
        """
        # Must be the first statement: snapshots exactly the constructor
        # arguments. The public properties below return these raw values.
        parameters = dict(locals())

        # Converted copies used internally by the sampler.
        self._incidence_matrix = jnp.asarray(incidence_matrix)
        self._initial_state = initial_state
        self._initial_time = jnp.asarray(initial_time)

        # Accepts either form of `transition_rate_fn`; presumably
        # `maybe_combine_fn` merges a sequence of per-transition callables
        # into one callable returning a tuple (see gemlib.func_util) - confirm.
        self._transition_rate_fn = maybe_combine_fn(transition_rate_fn)

        # Sample structure: event times share the dtype of `initial_time`;
        # transition and unit indices are integers.
        dtype = EventList(
            time=self._initial_time.dtype,
            transition=jnp.int32,
            unit=jnp.int32,
        )

        super().__init__(
            dtype=dtype,
            # NOTE(review): samples contain discrete transition/unit indices,
            # so FULLY_REPARAMETERIZED looks questionable and
            # NOT_REPARAMETERIZED may be more accurate. Confirm before
            # changing, since callers may inspect this attribute.
            reparameterization_type=tfd.FULLY_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            name=name,
        )

    @property
    def transition_rate_fn(self):
        """Transition rate function for the model, as supplied."""
        return self._parameters["transition_rate_fn"]

    @property
    def incidence_matrix(self):
        """Incidence matrix for the model, as supplied."""
        return self._parameters["incidence_matrix"]

    @property
    def initial_state(self):
        """Initial state of the model."""
        return self._parameters["initial_state"]

    @property
    def num_steps(self):
        """Number of events to simulate."""
        return self._parameters["num_steps"]

    @property
    def initial_time(self):
        """Initial wall clock for the model. Sets the time scale."""
        return self._parameters["initial_time"]

    def compute_state(
        self, event_list: EventList, include_final_state: bool = False
    ) -> jax.Array:
        """Compute state timeseries given an event list

        Args:
            event_list: the event list, assumed to be sorted by time.
            include_final_state: should the final state be included in the
              returned timeseries? If `True`, then the time dimension of
              the returned tensor will be 1 greater than the length of the
              event list. If `False` (default) these will be equal.

        Returns:
            A `[T, N, S]` tensor where `T` is the number of events, `N` is the
            number of units, and `S` is the number of states.
        """
        # The bare name resolves to the module-level `compute_state` imported
        # from gemlib.distributions.continuous_markov, not to this method.
        return compute_state(
            self.incidence_matrix,
            self.initial_state,
            event_list,
            include_final_state,
        )

    def _sample_n(self, n: int, seed: int | Array | None = None) -> EventList:
        """
        Samples n outcomes from the continuous time state transition model.

        Args:
            n (int): The number of independent realisations of the Markov
              process to sample. One PRNG key is split off per realisation
              and the simulations are vectorised with :code:`jax.vmap`.
            seed (int | Array | None, optional): Seed or PRNG key for random
              number generation, normalised by :code:`sanitize_key`.
              Defaults to None.

        Returns:
            An :code:`EventList` whose fields carry a leading dimension of
            size :code:`n`.
        """
        keys = jax.random.split(sanitize_key(seed), n)

        # The broadcast rate function is the same for every realisation, so
        # build it once instead of inside the vmapped closure.
        rate_fn = broadcast_fn_to(
            self._transition_rate_fn,
            self._initial_state.shape[:-1],
        )

        def one_sample(key):
            return continuous_markov_simulation(
                transition_rate_fn=rate_fn,
                incidence_matrix=self._incidence_matrix,
                initial_state=self._initial_state,
                initial_time=self._initial_time,
                num_markov_jumps=self.num_steps,
                seed=key,
            )

        return jax.vmap(one_sample)(keys)

    def _log_prob(self, value: EventList) -> Array:
        """
        Computes the log probability of the given outcomes.

        Args:
            value (EventList): an EventList object representing the
              outcomes. The last axis of each field indexes events; any
              leading axes are treated as batch dimensions.

        Returns:
            Array: The log probability of each outcome, with shape equal to
            the leading (batch) shape of :code:`value.time`.
        """
        value = jax.tree_util.tree_map(jnp.asarray, value)

        # Collapse all leading dimensions into one so a single vmap handles
        # arbitrary batch shapes (including none at all).
        batch_shape = value.time.shape[:-1]
        flat_shape = (pymath.prod(batch_shape), value.time.shape[-1])
        flat_value = jax.tree_util.tree_map(
            lambda x: x.reshape(flat_shape), value
        )

        # Loop-invariant across the batch; build once.
        rate_fn = broadcast_fn_to(
            self._transition_rate_fn,
            self._initial_state.shape[:-1],
        )

        def one_log_prob(event_list):
            return continuous_time_log_likelihood(
                transition_rate_fn=rate_fn,
                incidence_matrix=self.incidence_matrix,
                initial_state=self.initial_state,
                initial_time=self.initial_time,
                event_list=event_list,
            )

        log_probs = jax.vmap(one_log_prob)(flat_value)

        # Restore the caller's batch shape.
        return jnp.reshape(log_probs, batch_shape)

    def _event_shape(self) -> EventList:
        """Per-field event shape: each field holds `num_steps` events."""
        return EventList(
            time=(self.num_steps,),
            transition=(self.num_steps,),
            unit=(self.num_steps,),
        )

    def _batch_shape(self) -> tuple:
        """The model defines no batch dimensions."""
        return ()