Coverage for gemlib/mcmc/discrete_time_state_transition_model/left_censored_events_proposal.py: 100%

45 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-12-02 22:09 +0000

1"""Implements a proposal for left-censored events""" 

2 

3import jax 

4import jax.numpy as jnp 

5import numpy as np 

6import tensorflow_probability.substrates.jax as tfp 

7 

8from gemlib.distributions.discrete_markov import compute_state 

9from gemlib.distributions.kcategorical import UniformKCategorical 

10from gemlib.distributions.uniform_integer import UniformInteger 

11from gemlib.util import states_from_transition_idx 

12 

# Short alias for the TFP (JAX substrate) distributions namespace; used below
# to build the JointDistributionNamed proposal.
tfd = tfp.distributions

14 

15 

def _mask_max(x, t, axis=0):
    """Fill every element of `x` whose index along `axis` exceeds `t` with
    the largest value representable in `x.dtype`.

    Used to exclude trailing positions before taking a `min` along `axis`:
    a masked entry can never be the minimum of a slice that still contains
    at least one unmasked entry.

    :param x: array to mask; any rank >= 1, floating-point or integer dtype.
    :param t: index threshold -- a scalar or a shape-`(1,)` array.  Entries
              at positions `i > t` along `axis` are replaced.
    :param axis: axis along which positions are compared with `t`.  Negative
                 values count from the end, as in NumPy.
    :returns: an array with the same shape and dtype as `x`.
    """
    x = jnp.asarray(x)
    axis = axis % x.ndim  # normalise negative axes, e.g. -2 -> 0 for 2-D x
    length = x.shape[axis]

    # Boolean mask of shape (length,), reshaped so that it broadcasts along
    # `axis` only, whatever the rank of `x`.
    beyond = jnp.arange(length) > t
    bcast_shape = [1] * x.ndim
    bcast_shape[axis] = length
    beyond = jnp.reshape(beyond, bcast_shape)

    # `finfo` is only defined for inexact dtypes; integers need `iinfo`.
    if jnp.issubdtype(x.dtype, jnp.inexact):
        fill = jnp.finfo(x.dtype).max
    else:
        fill = jnp.iinfo(x.dtype).max

    # `where` sets masked entries to exactly `fill`; adding `fill` to `x`
    # would overflow integer inputs.
    return jnp.where(beyond, jnp.asarray(fill, dtype=x.dtype), x)

22 

23 

def left_censored_event_time_proposal(
    events,
    initial_state,
    transition,
    incidence_matrix,
    num_units,
    max_timepoint,
    max_events,
    dtype=np.int32,
    name=None,
):
    """Propose to move `transition` events into or out of a randomly selected
    timepoint in the window `[0, max_timepoint]` (n.b. inclusive!), subject
    to bounds imposed by the state-transition process described by
    `incidence_matrix`.  The event timeseries describes the number of events
    occurring along `R` transitions in `M` coupled units across `T`
    timepoints between `S` states.

    :param events: a [T, M, R] tensor describing the number of events for
                   each transition in each unit at each timepoint.
    :param initial_state: a [M, S] tensor describing the initial values of
                          the state at the first timepoint.
    :param transition: index of the transition whose events are moved
                       backwards or forwards in time.
    :param incidence_matrix: a [S, R] matrix describing the change in the
                             value of each state in response to an event
                             along each transition.
    :param num_units: number of units to propose for (currently restricted
                      to 1!).
    :param max_timepoint: events are moved to and from the time window
                          `[0, max_timepoint]`.
    :param max_events: upper bound on the number of events moved in a single
                       proposal; the effective bound is the minimum of this
                       and the bound implied by the state trajectory.
    :param dtype: unused; retained for backward compatibility.  Every
                  component distribution uses `events.dtype` as its float
                  dtype.
    :param name: name of the returned JointDistributionNamed.
    :returns: a JointDistributionNamed whose `sample()` returns the id of a
              unit (`unit`), a timepoint (`timepoint`), a direction
              (`direction`; 0 = pull from the past into the window,
              1 = push from the window into the past) and a number of
              events (`num_events`).
    :raises AssertionError: if `num_units != 1` (checked with `assert`, so
                            not enforced under `python -O`).
    """
    assert num_units == 1, (
        "LeftCensoredEventTimeProposal with `num_units!=1` is not supported"
    )

    events = jnp.asarray(events)
    initial_state = jnp.asarray(initial_state)

    # Identify the source and destination states for `transition`
    src_state_idx, dest_state_idx = states_from_transition_idx(
        transition,
        incidence_matrix,
    )

    # Float dtype shared by every component distribution.  Named distinctly
    # so it does not silently shadow the (unused) `dtype` argument.
    float_dtype = events.dtype

    def unit():
        """Draw the unit to update."""
        return UniformKCategorical(
            num_units,
            mask=np.ones(events.shape[-2], dtype=float_dtype),
            float_dtype=float_dtype,
            name="unit",
        )

    def timepoint():
        """Draw the timepoint to move events to or from, in
        `[0, max_timepoint]`."""
        return UniformInteger(
            low=[0], high=[max_timepoint + 1], float_dtype=float_dtype
        )

    def direction():
        r"""Do we move events from [-\infty, 0) into [0, max_timepoint]
        or vice versa?
            0=move from past into present
            1=move from present into past
        """
        return UniformInteger(low=[0], high=[2], float_dtype=float_dtype)

    def num_events(unit, timepoint, direction):
        """Draw a number of events to move to/from `timepoint` in `unit`
        from the past into the present or vice versa according to
        `direction`.  The number of possible events is bounded by the state
        trajectory implied by `incidence_matrix` and by `max_events`,
        whichever is smaller.
        """
        # `events` is a [T, M, R] tensor; select the unit (shape [T, 1, R],
        # assuming `unit` has shape (1,)), then the leading window
        # [0, max_timepoint].  A contiguous leading range is a static slice;
        # no index gather is needed.
        unit_events = events[:, unit, :]
        unit_time_events = unit_events[: max_timepoint + 1]

        # `initial_state` is a [M, S] tensor; select the same unit.
        unit_initial_state = initial_state[unit, :]
        state = compute_state(
            unit_initial_state,
            unit_time_events,
            incidence_matrix,
        )  # [T', 1, S], T' = number of timepoints in the window

        def pull_from_past():
            r"""Bound on the number of events to move from pre-history into
            the window at `timepoint`:

            $$
            \min(b_0, \dots, b_{t+1}, n_{\max})
            $$

            where $b_i$ is the destination-state value at window index $i$
            and $t$ is `timepoint`.
            """
            dest_state = state[..., dest_state_idx]  # [T', 1]
            # NOTE(review): this branch masks indices > timepoint + 1, while
            # push_to_past masks indices > timepoint.  Confirm against
            # compute_state's time convention that the asymmetry is intended.
            dest_state = _mask_max(dest_state, timepoint + 1, axis=-2)
            return jnp.minimum(jnp.min(dest_state, axis=-2), max_events)

        def push_to_past():
            r"""Bound on the number of events to move from `timepoint` in
            the window into pre-history:

            $$
            \min(a_0, \dots, a_t, y_t, n_{\max})
            $$

            where $a_i$ is the source-state value at window index $i$,
            $t$ is `timepoint` and $y_t$ is
            `events[timepoint, unit, transition]`.
            """
            # Source-state values over the window for the selected unit.
            src_state = state[..., src_state_idx]  # [T', 1]
            src_state = _mask_max(src_state, timepoint, axis=-2)

            # Cannot push more events than currently occur at `timepoint`.
            num_available_events = events[timepoint, unit, transition]

            candidates = jnp.concatenate(
                [
                    src_state,
                    num_available_events[np.newaxis, :],
                    np.asarray(max_events)[np.newaxis, np.newaxis],
                ],
                axis=-2,
            )
            return jnp.min(candidates, axis=0)

        x_bound = jax.lax.cond(
            direction[0] == 0,
            true_fun=pull_from_past,
            false_fun=push_to_past,
        ).astype(np.int32)

        # jnp.minimum caters for x_bound == 0 (no events can be moved): low
        # becomes 0 and high 1, so -- with `high` exclusive, as the `+ 1`
        # implies -- only 0 events can be drawn, i.e. effectively an
        # identity proposal.
        return UniformInteger(
            low=jnp.minimum(1, x_bound),
            high=x_bound + 1,
            float_dtype=float_dtype,
        )

    return tfd.JointDistributionNamed(
        {
            "unit": unit,
            "timepoint": timepoint,
            "direction": direction,
            "num_events": num_events,
        },
        name=name,
    )

89 

90 def direction(): 

91 r"""Do we move events from [-\infty, 0) into [0, max_timepoint] 

92 or vice versa? 

93 0=move from past into present 

94 1=move from present into past 

95 """ 

96 return UniformInteger(low=[0], high=[2], float_dtype=dtype) 

97 

98 def num_events(unit, timepoint, direction): 

99 """Draw a number of events to move to/from `timepoint` in `unit` 

100 from the past into the present or vice versa according to `direction`. 

101 The number of possible events is bounded by the topology of the state 

102 transition model or a user-configurable `max_events`, whichever is the 

103 minimum. 

104 """ 

105 # Events is a [num_timepoints, num_units, num_transitions] tensor. 

106 # We need to gather on the first 2 dimensions, 

107 # i.e. events[time_range, unit, :] 

108 unit_events = events[:, unit, :] # Tx1xS 

109 unit_time_events = unit_events[time_range, :, :] 

110 # Initial state is a [M, S] tensor. We gather the 

111 unit_initial_state = initial_state[unit, :] 

112 state = compute_state( 

113 unit_initial_state, 

114 unit_time_events, 

115 incidence_matrix, 

116 ) # TxMxS 

117 

118 def pull_from_past(): 

119 r"""Choose number of A->B transition events to move from 

120 pre-history into the events time-window, subject to the bound 

121 $$ 

122 \phi(m,t,B) = \min (1, b_1,\dots, b_T, n_{\max}) 

123 $$ 

124 """ 

125 state_ = state[..., dest_state_idx] # TxM 

126 state_ = _mask_max(state_, timepoint + 1, axis=-2) 

127 x_bound = jnp.minimum(jnp.min(state_, axis=-2), max_events) 

128 return x_bound 

129 

130 def push_to_past(): 

131 r"""Choose number of A->B transition events to move from events 

132 time-window into pre-history, subject to bound 

133 $$ 

134 \phi{mtA} = \min (1, a_1,\dots, a_t, y^{AB}_{mt}, n_{\max}) 

135 $$ 

136 """ 

137 # Extract vector of source state values over [0, max_timepoint] 

138 # for the selected metapopulation 

139 state_ = state[..., src_state_idx] # TxM 

140 state_ = _mask_max(state_, timepoint, axis=-2) 

141 

142 # Pushing to past is also bound by 

143 # events[timepoint, unit, transition] 

144 num_available_events = events[timepoint, unit, transition] 

145 

146 x_bound = jnp.concatenate( 

147 [ 

148 state_, 

149 num_available_events[np.newaxis, :], 

150 np.asarray(max_events)[np.newaxis, np.newaxis], 

151 ], 

152 axis=-2, 

153 ) 

154 x_bound = jnp.min(x_bound, axis=0) 

155 

156 return x_bound 

157 

158 x_bound = jax.lax.cond( 

159 direction[0] == 0, 

160 true_fun=pull_from_past, 

161 false_fun=push_to_past, 

162 ).astype(np.int32) 

163 

164 # tf.minimum is to cater for conditions where 

165 # `events[unit,timepoint,target]` == 0. If it does, 

166 # we effectively return an identity proposal. 

167 return UniformInteger( 

168 low=jnp.minimum(1, x_bound), # x_bound may be 0 if no events 

169 high=x_bound + 1, # are available 

170 float_dtype=dtype, 

171 ) 

172 

173 return tfd.JointDistributionNamed( 

174 { 

175 "unit": unit, 

176 "timepoint": timepoint, 

177 "direction": direction, 

178 "num_events": num_events, 

179 }, 

180 name=name, 

181 )