Coverage for gemlib/mcmc/discrete_time_state_transition_model/left_censored_events_proposal.py: 100%
45 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-12-02 22:09 +0000
« prev ^ index » next coverage.py v7.10.3, created at 2025-12-02 22:09 +0000
1"""Implements a proposal for left-censored events"""
3import jax
4import jax.numpy as jnp
5import numpy as np
6import tensorflow_probability.substrates.jax as tfp
8from gemlib.distributions.discrete_markov import compute_state
9from gemlib.distributions.kcategorical import UniformKCategorical
10from gemlib.distributions.uniform_integer import UniformInteger
11from gemlib.util import states_from_transition_idx
# Short alias for the TFP (JAX substrate) distributions namespace, used to
# build the JointDistributionNamed returned by the proposal factory below.
tfd = tfp.distributions
16def _mask_max(x, t, axis=0):
17 """Fills all elements where `i>t` along axis `axis`
18 of `x` with `x.dtype.max`.
19 """
20 mask = (np.arange(x.shape[axis]) > t) * np.finfo(x.dtype).max
21 return x + mask[:, np.newaxis]
def left_censored_event_time_proposal(
    events,
    initial_state,
    transition,
    incidence_matrix,
    num_units,
    max_timepoint,
    max_events,
    dtype=np.int32,
    name=None,
):
    """Propose to move `transition` events into or out of a randomly selected
    time-window [0, max_timepoint] (n.b. inclusive!). The move is subject to
    bounds imposed by a state-transition process described by
    `incidence_matrix`. The event timeseries describes the number of events
    occurring along `R` transitions in `M` coupled units across `T`
    timepoints between `S` states.

    :param events: a [T, M, R] tensor describing number of events for each
                   transition in each unit at each timepoint.
    :param initial_state: a [M, S] tensor describing the initial values of the
                          state at the first timepoint.
    :param transition: the transition for which we wish to move events
                       backwards or forwards in time.
    :param incidence_matrix: a [S, R] matrix describing the change in the value
                             of each state in response to an event along each
                             transition.
    :param num_units: number of units to propose for (currently restricted to
                      1!)
    :param max_timepoint: events are moved to and from the time window
                          `[0, max_timepoint]`.
    :param max_events: user-configurable upper bound on the number of events
                       moved in a single proposal.
    :param dtype: not used; kept for backwards compatibility. The component
                  distributions are constructed with `float_dtype` equal to
                  `events.dtype`.
    :param name: name of the returned JointDistributionNamed.
    :returns: a JointDistributionNamed object for which `sample()` returns a
              dict with keys `unit` (id of the unit), `timepoint` (time to
              move events to or from), `direction` (0 = from past into the
              window, 1 = from the window into the past) and `num_events`
              (number of events to move).
    :raises AssertionError: if `num_units != 1`.
    """
    assert num_units == 1, (
        "LeftCensoredEventTimeProposal with `num_units!=1` is not supported"
    )

    events = jnp.asarray(events)
    initial_state = jnp.asarray(initial_state)
    # Timepoints of the proposal window [0, max_timepoint], inclusive.
    time_range = np.arange(0, max_timepoint + 1)

    # Identify the source and destination states for `transition`.
    src_state_idx, dest_state_idx = states_from_transition_idx(
        transition,
        incidence_matrix,
    )

    # The `dtype` argument is deliberately ignored. Every component
    # distribution is built with the dtype of `events`.
    float_dtype = events.dtype

    def unit():
        """Draw unit to update"""
        return UniformKCategorical(
            num_units,
            mask=np.ones(events.shape[-2], dtype=float_dtype),
            float_dtype=float_dtype,
            name="unit",
        )

    def timepoint():
        """Draw timepoint to move events to or from"""
        return UniformInteger(
            low=[0], high=[max_timepoint + 1], float_dtype=float_dtype
        )

    def direction():
        r"""Do we move events from [-\infty, 0) into [0, max_timepoint]
        or vice versa?
        0=move from past into present
        1=move from present into past
        """
        return UniformInteger(low=[0], high=[2], float_dtype=float_dtype)

    # NOTE: the argument names of `num_events` must match the keys of the
    # JointDistributionNamed below, since they declare its dependencies.
    def num_events(unit, timepoint, direction):
        """Draw a number of events to move to/from `timepoint` in `unit`
        from the past into the present or vice versa according to `direction`.
        The number of possible events is bounded by the topology of the state
        transition model or a user-configurable `max_events`, whichever is the
        minimum.
        """
        # `events` is a [T, M, R] tensor. Gather the selected unit, then
        # restrict to the proposal window, i.e. events[time_range, unit, :].
        unit_events = events[:, unit, :]  # [T, 1, R] if `unit` has length 1
        unit_time_events = unit_events[time_range, :, :]
        # `initial_state` is a [M, S] tensor; gather the selected unit's row.
        unit_initial_state = initial_state[unit, :]
        state = compute_state(
            unit_initial_state,
            unit_time_events,
            incidence_matrix,
        )  # presumably [max_timepoint + 1, 1, S], i.e. TxMxS

        def pull_from_past():
            r"""Choose number of A->B transition events to move from
            pre-history into the events time-window, subject to the bound
            $$
            \phi(m,t,B) = \min (1, b_1,\dots, b_T, n_{\max})
            $$
            """
            state_ = state[..., dest_state_idx]  # TxM
            # NOTE(review): this window extends to `timepoint + 1`, one step
            # further than in `push_to_past`. Confirm this against the time
            # convention of `compute_state`.
            state_ = _mask_max(state_, timepoint + 1, axis=-2)
            x_bound = jnp.minimum(jnp.min(state_, axis=-2), max_events)
            return x_bound

        def push_to_past():
            r"""Choose number of A->B transition events to move from events
            time-window into pre-history, subject to bound
            $$
            \phi{mtA} = \min (1, a_1,\dots, a_t, y^{AB}_{mt}, n_{\max})
            $$
            """
            # Extract vector of source state values over [0, max_timepoint]
            # for the selected metapopulation.
            state_ = state[..., src_state_idx]  # TxM
            state_ = _mask_max(state_, timepoint, axis=-2)

            # Pushing to past is also bound by
            # events[timepoint, unit, transition].
            num_available_events = events[timepoint, unit, transition]

            x_bound = jnp.concatenate(
                [
                    state_,
                    num_available_events[np.newaxis, :],
                    np.asarray(max_events)[np.newaxis, np.newaxis],
                ],
                axis=-2,
            )
            x_bound = jnp.min(x_bound, axis=0)

            return x_bound

        x_bound = jax.lax.cond(
            direction[0] == 0,
            true_fun=pull_from_past,
            false_fun=push_to_past,
        ).astype(np.int32)

        # `jnp.minimum(1, x_bound)` caters for the case where no events can be
        # moved (x_bound == 0). The draw is then always 0 events, which makes
        # this an identity proposal.
        return UniformInteger(
            low=jnp.minimum(1, x_bound),
            high=x_bound + 1,
            float_dtype=float_dtype,
        )

    return tfd.JointDistributionNamed(
        {
            "unit": unit,
            "timepoint": timepoint,
            "direction": direction,
            "num_events": num_events,
        },
        name=name,
    )