Coverage for gemlib/mcmc/discrete_time_state_transition_model/right_censored_events_proposal.py: 100%
41 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-12-02 22:09 +0000
« prev ^ index » next coverage.py v7.10.3, created at 2025-12-02 22:09 +0000
1"""Right-censored event time proposal mechanism"""
3import jax.numpy as jnp
4import numpy as np
5import tensorflow_probability.substrates.jax as tfp
7from gemlib.distributions.uniform_integer import UniformInteger
# Short alias for the TensorFlow Probability distributions namespace
# (the JAX substrate imported above).
tfd = tfp.distributions
# Marker wrapped around coroutine yields that have no parent distribution
# in a JointDistributionCoroutine model.
Root = tfp.distributions.JointDistribution.Root
# Array type used in the public function signatures of this module.
Tensor = np.typing.NDArray
def _slice_min(state_tensor, start):
    """Compute `min(state_tensor[..., start:])` in an XLA-safe way.

    Slicing with a traced `start` would give a data-dependent shape.
    Instead, every element before `start` is replaced by `+inf` and the
    minimum is taken over the whole (statically shaped) last axis.

    Args
    ----
        state_tensor: a tensor whose last axis is sliced; typically 1-D.
        start: an index into the last axis of `state_tensor`.

    Returns
    -------
        `min(state_tensor[..., start:])`, reduced over the last axis only.
        The result is `inf` when `start >= state_tensor.shape[-1]`.
        Because `+inf` is the fill value, integer inputs are promoted to
        a floating dtype.
    """
    state_tensor = jnp.asarray(state_tensor)

    before_start = jnp.arange(state_tensor.shape[-1]) < start
    masked_state_tensor = jnp.where(before_start, np.inf, state_tensor)

    # Reduce along the same (last) axis the mask was built on, so batched
    # inputs yield one minimum per batch member; 1-D input gives a scalar.
    return jnp.min(masked_state_tensor, axis=-1)
def add_occult_proposal(
    count_max: int,
    events: Tensor,
    src_state: Tensor,
    name=None,
):
    """Build a proposal distribution for adding occult events.

    The returned joint distribution draws, in order:

    1. ``unit``: a unit index, uniform with ``low=0``, ``high=num_units``.
    2. ``timepoint``: a time index, uniform with ``low=0``,
       ``high=num_times``.
    3. ``event_count``: the number of events to add, uniform with
       ``low=min(1, bound)`` and ``high=bound + 1``, where ``bound`` is the
       smaller of ``count_max`` and the minimum of the chosen unit's source
       state from ``timepoint`` onwards.

    NOTE(review): assumes ``UniformInteger`` samples from ``[low, high)``,
    as the ``high=num_units`` / ``high=bound + 1`` usage suggests.

    Args
    ----
        count_max: upper limit on the number of events proposed.
        events: event tensor; axis -2 is treated as time and axis -1 as
            units.
        src_state: source-state tensor, indexed as ``src_state[..., unit]``
            to obtain the per-time series of one unit.
        name: optional name for the returned distribution.

    Returns
    -------
        A ``tfd.JointDistributionCoroutineAutoBatched`` over
        ``(unit, timepoint, event_count)``.
    """
    events = jnp.asarray(events)
    src_state = jnp.asarray(src_state)

    event_shape = events.shape
    num_times = event_shape[-2]
    num_units = event_shape[-1]
    float_dtype = events.dtype

    def proposal():
        # Pick a unit, then a timepoint, each uniformly at random.
        unit = yield Root(
            UniformInteger(
                low=0,
                high=num_units,
                float_dtype=float_dtype,
                name="unit",
            )
        )
        timepoint = yield Root(
            UniformInteger(
                low=0,
                high=num_times,
                float_dtype=float_dtype,
                name="timepoint",
            )
        )

        # Cap the count by the smallest source state from `timepoint`
        # onwards and by `count_max`.
        unit_src_state = src_state[..., unit]
        max_events = jnp.minimum(
            _slice_min(unit_src_state, timepoint), count_max
        ).astype(np.int32)

        # When max_events == 0, low collapses to 0 so the range is {0}.
        yield UniformInteger(
            low=jnp.minimum(1, max_events),
            high=max_events + 1,
            float_dtype=float_dtype,
            name="event_count",
        )

    return tfd.JointDistributionCoroutineAutoBatched(proposal, name=name)
def del_occult_proposal(
    count_max: int,
    events: Tensor,
    dest_state: Tensor,
    name=None,
):
    """Build a proposal distribution for deleting occult events.

    The returned joint distribution draws, in order:

    1. ``unit``: uniform among units with at least one positive entry in
       ``events`` (along the time axis, -2).
    2. ``timepoint``: uniform among times at which the chosen unit has a
       positive entry in ``events``.
    3. ``event_count``: the number of events to delete, uniform with
       ``low=min(1, bound)`` and ``high=bound + 1``, where ``bound`` is the
       smallest of ``count_max``, ``events[..., timepoint, unit]``, and the
       minimum of the unit's destination state over
       ``[timepoint + 1, num_times)``.

    If ``events`` has no positive entries, the selection probabilities are
    0 / 0 = NaN.  The sampled indices are then clipped into range so that
    indexing does not fail.

    NOTE(review): assumes ``UniformInteger`` samples from ``[low, high)``.

    Args
    ----
        count_max: upper limit on the number of events proposed.
        events: event tensor; axis -2 is treated as time and axis -1 as
            units.
        dest_state: destination-state tensor, indexed as
            ``dest_state[..., unit]`` to obtain the per-time series of one
            unit.
        name: optional name for the returned distribution.

    Returns
    -------
        A ``tfd.JointDistributionCoroutineAutoBatched`` over
        ``(unit, timepoint, event_count)``.
    """
    events = jnp.asarray(events)
    dest_state = jnp.asarray(dest_state)

    def proposal():
        # Select a unit to delete events from, uniformly among units that
        # have any positive event count.  keepdims keeps the normaliser
        # broadcastable against the mask.
        nonzero_units = jnp.any(events > 0, axis=-2)
        probs = nonzero_units / jnp.sum(nonzero_units, axis=-1, keepdims=True)
        unit = yield Root(
            tfd.Categorical(
                probs=probs,
                name="unit",
            )
        )
        # With no events at all, `probs` is NaN (0 / 0) and the sampled
        # index may be out of range; clip so the indexing below is safe.
        unit = jnp.clip(unit, min=0, max=events.shape[-1] - 1)

        # Select a timepoint uniformly among those at which `unit` has
        # positive events.
        unit_has_events = events[..., unit] > 0
        probs = unit_has_events / jnp.sum(
            unit_has_events, axis=-1, keepdims=True
        )
        timepoint = yield tfd.Categorical(probs=probs, name="timepoint")

        # Clip if there are no events to delete (NaN probabilities again).
        timepoint = jnp.clip(timepoint, min=0, max=events.shape[-2] - 1)

        # Number to delete is bounded by the minimum of the destination
        # state over [timepoint + 1, num_times) -- presumably because events
        # at `timepoint` only affect the destination state at later times;
        # confirm against the state model -- by the events present at
        # (timepoint, unit), and by `count_max`.
        unit_dest_state = dest_state[..., unit]
        state_bound = _slice_min(unit_dest_state, timepoint + 1)
        bound = jnp.minimum(state_bound, events[..., timepoint, unit])
        bound = jnp.minimum(bound, count_max).astype(np.int32)

        # When bound == 0, low collapses to 0 so the range is {0}.
        yield UniformInteger(
            low=jnp.minimum(1, bound),
            high=bound + 1,
            float_dtype=events.dtype,
            name="event_count",
        )

    return tfd.JointDistributionCoroutineAutoBatched(proposal, name=name)