Coverage for gemlib/mcmc/discrete_time_state_transition_model/right_censored_events_proposal.py: 100%

41 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-12-02 22:09 +0000

1"""Right-censored event time proposal mechanism""" 

2 

3import jax.numpy as jnp 

4import numpy as np 

5import tensorflow_probability.substrates.jax as tfp 

6 

7from gemlib.distributions.uniform_integer import UniformInteger 

8 

# Short alias for the TFP (JAX substrate) distributions namespace.
tfd = tfp.distributions
# Wrapper applied in the coroutine models below to the yielded distributions
# that do not depend on any earlier yielded value.
Root = tfp.distributions.JointDistribution.Root

# Array type used in the public function signatures of this module.
Tensor = np.typing.NDArray

13 

14 

15def _slice_min(state_tensor, start): 

16 """Compute `min(state_tensor[start:]` in an XLA-safe way 

17 

18 Args 

19 ---- 

20 state_tensor: a 1-D tensor 

21 start: an index into state_tensor 

22 

23 Return 

24 ------ 

25 `min(state_tensor[start:]` 

26 """ 

27 state_tensor = jnp.asarray(state_tensor) 

28 

29 masked_state_tensor = jnp.where( 

30 jnp.arange(state_tensor.shape[-1]) < start, 

31 np.inf, 

32 state_tensor, 

33 ) 

34 

35 return jnp.min(masked_state_tensor) 

36 

37 

def add_occult_proposal(
    count_max: int,
    events: Tensor,
    src_state: Tensor,
    name=None,
):
    """Build a proposal distribution for adding occult events.

    The returned joint distribution draws, in order:

    * ``unit`` -- from ``UniformInteger(0, num_units)``;
    * ``timepoint`` -- from ``UniformInteger(0, num_times)``;
    * ``event_count`` -- from ``UniformInteger(min(1, bound), bound + 1)``
      where ``bound = min(min(src_state[..., unit][timepoint:]), count_max)``
      cast to int32.

    Args
    ----
    count_max: upper limit on the number of events proposed in one move.
    events: event tensor; its axis -2 gives ``num_times`` and its axis -1
        gives ``num_units``.
    src_state: source-state tensor, indexed ``[..., unit]``; the resulting
        series is minimised from ``timepoint`` onwards.
    name: optional name for the returned distribution.

    Return
    ------
    A ``tfd.JointDistributionCoroutineAutoBatched`` over
    ``(unit, timepoint, event_count)``.
    """
    events = jnp.asarray(events)
    src_state = jnp.asarray(src_state)
    dtype = events.dtype

    num_times = events.shape[-2]
    num_units = events.shape[-1]

    def proposal():
        # Unit and timepoint are drawn independently of each other.
        unit = yield Root(
            UniformInteger(low=0, high=num_units, float_dtype=dtype, name="unit")
        )
        timepoint = yield Root(
            UniformInteger(
                low=0, high=num_times, float_dtype=dtype, name="timepoint"
            )
        )

        # event_count is bounded by the smallest source-state value over
        # [timepoint, num_times) and by count_max.
        src_floor = _slice_min(src_state[..., unit], timepoint)
        upper = jnp.minimum(src_floor, count_max).astype(np.int32)

        # If upper is 0, low is also 0, so the only admissible count is 0.
        yield UniformInteger(
            low=jnp.minimum(1, upper),
            high=upper + 1,
            float_dtype=dtype,
            name="event_count",
        )

    return tfd.JointDistributionCoroutineAutoBatched(proposal, name=name)

83 

84 

def del_occult_proposal(
    count_max: int,
    events: Tensor,
    dest_state: Tensor,
    name=None,
):
    """Build a proposal distribution for deleting occult events.

    The returned joint distribution draws, in order:

    * ``unit`` -- uniformly among units with at least one positive entry in
      ``events`` (any over axis -2);
    * ``timepoint`` -- uniformly among the times at which ``unit`` has a
      positive entry in ``events``;
    * ``event_count`` -- from ``UniformInteger(min(1, bound), bound + 1)``
      where ``bound`` is the smallest of
      ``min(dest_state[..., unit][timepoint + 1:])``,
      ``events[..., timepoint, unit]`` and ``count_max`` (cast to int32).

    If there are no positive entries to choose from, the categorical
    probabilities are 0/0 (NaN); the sampled index is then clipped into the
    valid range so that the subsequent indexing does not fail.

    Args
    ----
    count_max: upper limit on the number of events deleted in one move.
    events: event tensor; axis -2 is indexed by timepoint and axis -1 by
        unit.
    dest_state: destination-state tensor, indexed ``[..., unit]``; the
        resulting series is minimised strictly after ``timepoint``.
    name: optional name for the returned distribution.

    Return
    ------
    A ``tfd.JointDistributionCoroutineAutoBatched`` over
    ``(unit, timepoint, event_count)``.
    """
    events = jnp.asarray(events)
    dest_state = jnp.asarray(dest_state)

    num_times = events.shape[-2]
    num_units = events.shape[-1]

    def normalise_mask(mask):
        # Uniform probabilities over the True entries of a boolean mask.
        # keepdims=True keeps the division aligned with the last axis when
        # the mask has rank > 1; for 1-D masks it is equivalent to dividing
        # by the scalar count.
        return mask / jnp.sum(mask, axis=-1, keepdims=True)

    def proposal():
        # Select unit to delete events from
        has_events = jnp.any(events > 0, axis=-2)
        unit = yield Root(
            tfd.Categorical(
                probs=normalise_mask(has_events),
                name="unit",
            )
        )
        # With no events anywhere the probabilities are NaN and the sampled
        # index is not guaranteed to be valid, so clip it into range.
        unit = jnp.clip(unit, min=0, max=num_units - 1)

        # Select timepoint to delete events from
        unit_events = events[..., unit]  # time series of `unit`
        timepoint = yield tfd.Categorical(
            probs=normalise_mask(unit_events > 0), name="timepoint"
        )
        # Clip if there are no events to delete
        timepoint = jnp.clip(timepoint, min=0, max=num_times - 1)

        # Draw num to delete - this is bounded by the events present at
        # `timepoint`, by count_max, and by the minimum value of the
        # destination state over [timepoint + 1, num_times).
        dest_floor = _slice_min(dest_state[..., unit], timepoint + 1)
        bound = jnp.minimum(dest_floor, events[..., timepoint, unit])
        bound = jnp.minimum(bound, count_max).astype(np.int32)

        yield UniformInteger(
            low=jnp.minimum(1, bound),
            high=bound + 1,
            float_dtype=events.dtype,
            name="event_count",
        )

    return tfd.JointDistributionCoroutineAutoBatched(proposal, name=name)