Coverage for gemlib/distributions/brownian.py: 100%

54 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-12-02 22:09 +0000

1"""Brownian motion as a distribution""" 

2 

3import jax 

4import jax.numpy as jnp 

5import tensorflow_probability.substrates.jax as tfp 

6from tensorflow_probability.substrates.jax.internal import ( 

7 distribution_util as dist_util, 

8) 

9from tensorflow_probability.substrates.jax.internal import ( 

10 dtype_util, 

11) 

12from tensorflow_probability.substrates.jax.internal.tensor_util import ( 

13 convert_nonref_to_tensor, 

14) 

15 

# Short alias for the TFP (JAX substrate) distributions namespace used below.
tfd = tfp.distributions

17 

18 

class BrownianMotion(tfd.Distribution):
    """Brownian motion observed on a fixed grid of index points.

    The path starts at ``x0`` at ``index_points[..., 0]``. The increments
    between consecutive index points are independent zero-mean Gaussians
    with standard deviation ``scale * sqrt(dt)``. The event is the path
    value at ``index_points[..., 1:]``; the starting value is not part of
    the event. The event size is therefore ``index_points.shape[-1] - 1``.

    Args:
      index_points: Times at which the path is observed. The first entry
        is the start time. NOTE(review): assumed non-decreasing; a negative
        gap gives a NaN standard deviation, and this is not validated here.
      x0: Value of the path at ``index_points[..., 0]``.
      scale: Multiplier applied to ``sqrt(dt)`` for every increment.
      validate_args: Forwarded to the base class and increments distribution.
      allow_nan_stats: Forwarded to the base class and increments distribution.
      name: Name of the distribution; also used as the JAX named scope.
    """

    def __init__(
        self,
        index_points,
        x0=0.0,
        scale=1.0,
        validate_args=False,
        allow_nan_stats=True,
        name="BrownianMotion",
    ):
        # Must run first so that only the constructor arguments are captured.
        parameters = dict(locals())
        dtype = dtype_util.common_dtype(
            [index_points, x0, scale],
            dtype_hint=jnp.asarray(index_points).dtype,
        )

        self._index_points = convert_nonref_to_tensor(
            index_points, dtype_hint=dtype
        )
        self._x0 = convert_nonref_to_tensor(x0, dtype_hint=dtype)
        self._scale = convert_nonref_to_tensor(scale, dtype_hint=dtype)

        grid = self._index_points
        step_sd = jnp.sqrt(grid[..., 1:] - grid[..., :-1]).astype(dtype)
        step_sd = step_sd * self._scale
        step_mean = jnp.zeros(grid[..., 1:].shape, dtype=dtype)

        # One independent Gaussian increment per grid interval.
        self._increments = tfd.MultivariateNormalDiag(
            loc=step_mean,
            scale_diag=step_sd,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            name="bm_increments",
        )

        with jax.named_scope(name):
            super().__init__(
                dtype=dtype,
                reparameterization_type=tfd.FULLY_REPARAMETERIZED,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                name=name,
            )

    def _batch_shape(self):
        # Batch shape is taken from the starting value alone.
        return self._x0.shape

    def _event_shape(self):
        # One event entry per index point after the first.
        num_points = self._index_points.shape[-1]
        return num_points - 1

    def _sample_n(self, n, seed=None):
        """Draw ``n`` paths as ``x0`` plus the running sum of increments.

        NOTE(review): ``x0`` is added along the event axis, so a non-scalar
        ``x0`` is likely unsupported -- confirm.
        """
        steps = self._increments.sample(n, seed=seed)
        displacement = jnp.cumsum(steps, axis=-1)
        return displacement + self._x0

    def _log_prob(self, x):
        """Log density of the increments of the path ``[x0, *x]``."""
        start = jnp.astype(self._x0, x.dtype)
        full_path = dist_util.pad(x, axis=-1, front=True, value=start)
        steps = jnp.diff(full_path, axis=-1)
        return self._increments.log_prob(steps)

79 

80 

class BrownianBridge(tfd.Distribution):
    """Brownian bridge pinned at both ends of a fixed grid of index points.

    This is Brownian motion with increment standard deviation
    ``scale * sqrt(dt)``, conditioned to equal ``x0`` at
    ``index_points[..., 0]`` and ``x1`` at ``index_points[..., -1]``.
    The event is the path at the interior points ``index_points[..., 1:-1]``.
    The event size is therefore ``index_points.shape[-1] - 2``.

    Args:
      index_points: Grid of times, including both pinned endpoints.
        NOTE(review): assumed non-decreasing; this is not validated here.
      x0: Value of the path at ``index_points[..., 0]``.
      x1: Value of the path at ``index_points[..., -1]``.
      scale: Multiplier applied to ``sqrt(dt)`` for every increment.
      validate_args: Forwarded to the base class and internal distributions.
      allow_nan_stats: Forwarded to the base class and internal distributions.
      name: Name of the distribution; also used as the JAX named scope.
    """

    def __init__(
        self,
        index_points,
        x0=0.0,
        x1=0.0,
        scale=1.0,
        validate_args=False,
        allow_nan_stats=True,
        name="BrownianBridge",
    ):
        # Must run first so that only the constructor arguments are captured.
        parameters = dict(locals())
        dtype = jnp.asarray(x0).dtype
        self._index_points = convert_nonref_to_tensor(
            index_points, dtype_hint=dtype
        )
        self._x0 = convert_nonref_to_tensor(x0, dtype_hint=dtype)
        self._x1 = convert_nonref_to_tensor(x1, dtype_hint=dtype)
        self._scale = convert_nonref_to_tensor(scale, dtype_hint=dtype)

        # Independent increments of the *unconditioned* Brownian motion over
        # every grid interval, including the last one that ends at x1.
        self._increments = tfd.MultivariateNormalDiag(
            loc=0.0,
            scale_diag=jnp.sqrt(
                self._index_points[..., 1:] - self._index_points[..., :-1]
            )
            * self._scale,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            name="bb_increments",
        )

        with jax.named_scope(name):
            super().__init__(
                dtype=dtype,
                reparameterization_type=tfd.FULLY_REPARAMETERIZED,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                name=name,
            )

    def _batch_shape(self):
        # Batch shape is taken from the starting value alone.
        return self._x0.shape

    def _event_shape(self):
        # Interior points only: both endpoints are pinned.
        return self._index_points.shape[-1] - 2

    def _sample_n(self, n, seed=None):
        """Draw ``n`` bridge paths by re-levelling free Brownian motion.

        A free path ``W`` with ``W(t_0) = 0`` is built from the increments.
        Adding the straight line through ``(t_0, x0)`` and
        ``(t_end, x1 - W(t_end))`` at the interior points gives
        ``W(t) + x0 + (t - t_0) / (t_end - t_0) * (x1 - x0 - W(t_end))``.
        """
        # Free path at index_points[..., 1:], relative to 0 at t_0.
        free_path = jnp.cumsum(self._increments.sample(n, seed=seed), axis=-1)

        # Free-path values at the two ends versus the pinned targets.
        free_ends = jnp.stack(
            [jnp.zeros_like(free_path[..., 0]), free_path[..., -1]], axis=-1
        )
        target_ends = jnp.stack([self._x0, self._x1], axis=-1)
        # With two reference values this is a straight-line interpolation.
        correction = tfp.math.interp_regular_1d_grid(
            x=self._index_points[..., 1:-1],
            x_ref_min=self._index_points[..., 0],
            x_ref_max=self._index_points[..., -1],
            y_ref=target_ends - free_ends,
        )
        # Drop the last free-path value: the endpoint is not part of the event.
        return free_path[..., :-1] + correction

    def _log_prob(self, x):
        """Log density of the interior values ``x`` given both endpoints.

        The bridge density is p(x, x1 | x0) / p(x1 | x0). The numerator is
        the joint density of the increments of ``[x0, *x, x1]``. The
        denominator is the free Brownian motion's density of reaching
        ``x1`` from ``x0`` over the whole grid. Without the denominator
        the result would not be normalized over ``x``.
        """
        start = jnp.astype(self._x0, x.dtype)
        end = jnp.astype(self._x1, x.dtype)
        path = dist_util.pad(x, axis=-1, front=True, value=start)
        path = dist_util.pad(path, axis=-1, back=True, value=end)
        diff = path[..., 1:] - path[..., :-1]
        joint_log_prob = self._increments.log_prob(diff)

        # Marginal of the free motion at t_end: N(x0, scale^2 (t_end - t_0)).
        duration = self._index_points[..., -1] - self._index_points[..., 0]
        endpoint = tfd.Normal(
            loc=start,
            scale=(jnp.sqrt(duration) * self._scale).astype(x.dtype),
            validate_args=self.validate_args,
            allow_nan_stats=self.allow_nan_stats,
        )
        return joint_log_prob - endpoint.log_prob(end)