Coverage for gemlib/distributions/brownian.py: 100%
54 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-12-02 22:09 +0000
1"""Brownian motion as a distribution"""
3import jax
4import jax.numpy as jnp
5import tensorflow_probability.substrates.jax as tfp
6from tensorflow_probability.substrates.jax.internal import (
7 distribution_util as dist_util,
8)
9from tensorflow_probability.substrates.jax.internal import (
10 dtype_util,
11)
12from tensorflow_probability.substrates.jax.internal.tensor_util import (
13 convert_nonref_to_tensor,
14)
# Short alias for the distributions namespace of the JAX-substrate TFP
# imported above; both classes below subclass ``tfd.Distribution``.
tfd = tfp.distributions
class BrownianMotion(tfd.Distribution):
    """Brownian motion observed on a grid of index points, as a distribution.

    The process starts at ``x0`` at ``index_points[..., 0]``.  The random
    variable is the vector of path values at ``index_points[..., 1:]``, so the
    event shape is ``(index_points.shape[-1] - 1,)``.  The increment over
    ``[t_{i-1}, t_i]`` is ``Normal(0, scale * sqrt(t_i - t_{i-1}))`` and the
    increments are mutually independent.

    Args:
        index_points: time points at which the path is observed; the first
            one is the start time.  NOTE(review): assumed non-decreasing -- a
            negative step makes ``sqrt`` produce NaN standard deviations.
        x0: value of the process at ``index_points[..., 0]``.  Its shape is
            reported as the batch shape.
        scale: multiplier on ``sqrt(dt)`` giving each increment's standard
            deviation.
        validate_args: forwarded to ``tfd.Distribution`` and to the
            increment distribution.
        allow_nan_stats: forwarded to ``tfd.Distribution`` and to the
            increment distribution.
        name: distribution name, also used as a JAX named scope.
    """

    def __init__(
        self,
        index_points,
        x0=0.0,
        scale=1.0,
        validate_args=False,
        allow_nan_stats=True,
        name="BrownianMotion",
    ):
        # Must come first so that only the constructor arguments are recorded.
        parameters = dict(locals())
        dtype = dtype_util.common_dtype(
            [index_points, x0, scale],
            dtype_hint=jnp.asarray(index_points).dtype,
        )
        self._x0 = convert_nonref_to_tensor(x0, dtype_hint=dtype)
        self._index_points = convert_nonref_to_tensor(
            index_points, dtype_hint=dtype
        )
        self._scale = convert_nonref_to_tensor(scale, dtype_hint=dtype)

        # Independent Gaussian increments between consecutive index points.
        time_steps = self._index_points[..., 1:] - self._index_points[..., :-1]
        self._increments = tfd.MultivariateNormalDiag(
            loc=jnp.zeros(time_steps.shape, dtype=dtype),
            scale_diag=jnp.sqrt(time_steps).astype(dtype) * self._scale,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            name="bm_increments",
        )

        with jax.named_scope(name):
            super().__init__(
                dtype=dtype,
                reparameterization_type=tfd.FULLY_REPARAMETERIZED,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                name=name,
            )

    def _batch_shape(self):
        # NOTE(review): only x0 contributes; batch dimensions carried by
        # index_points or scale are assumed to match x0's shape.
        return self._x0.shape

    def _event_shape(self):
        # One value per index point after the fixed starting point.
        return (self._index_points.shape[-1] - 1,)

    def _sample_n(self, n, seed=None):
        # A path is x0 plus the running sum of its increments.  The trailing
        # axis on x0 lines a batched x0 up with the path's batch dimensions
        # rather than with its time axis (a no-op for scalar x0).
        displacement = jnp.cumsum(
            self._increments.sample(n, seed=seed), axis=-1
        )
        return self._x0[..., jnp.newaxis] + displacement

    def _log_prob(self, x):
        # Prepend x0 to every path, then score consecutive differences under
        # the increment distribution.  x0 is broadcast explicitly against x's
        # leading dimensions because tf.pad-style padding only accepts a
        # scalar fill value.
        start = jnp.astype(self._x0, x.dtype)[..., jnp.newaxis]
        batch = jnp.broadcast_shapes(x.shape[:-1], start.shape[:-1])
        path = jnp.concatenate(
            [
                jnp.broadcast_to(start, batch + (1,)),
                jnp.broadcast_to(x, batch + x.shape[-1:]),
            ],
            axis=-1,
        )
        return self._increments.log_prob(jnp.diff(path, axis=-1))
class BrownianBridge(tfd.Distribution):
    """Brownian bridge pinned at both ends of a grid of index points.

    The path is fixed to ``x0`` at ``index_points[..., 0]`` and to ``x1`` at
    ``index_points[..., -1]``.  The random variable is the vector of path
    values at the interior points ``index_points[..., 1:-1]``, so the event
    shape is ``(index_points.shape[-1] - 2,)``.  The underlying free motion
    has independent increments ``Normal(0, scale * sqrt(dt))``.

    Args:
        index_points: time points; the first and last carry the pinned
            values.  NOTE(review): assumed non-decreasing.
        x0: value at ``index_points[..., 0]``; also fixes the dtype of the
            distribution.  Its shape is reported as the batch shape.
        x1: value at ``index_points[..., -1]``.
        scale: multiplier on ``sqrt(dt)`` giving each increment's standard
            deviation.
        validate_args: forwarded to ``tfd.Distribution`` and to the
            increment distribution.
        allow_nan_stats: forwarded to ``tfd.Distribution`` and to the
            increment distribution.
        name: distribution name, also used as a JAX named scope.
    """

    def __init__(
        self,
        index_points,
        x0=0.0,
        x1=0.0,
        scale=1.0,
        validate_args=False,
        allow_nan_stats=True,
        name="BrownianBridge",
    ):
        # Must come first so that only the constructor arguments are recorded.
        parameters = dict(locals())
        # The distribution dtype follows x0; the other inputs use it as a
        # conversion hint and are cast to it wherever they enter arithmetic.
        dtype = jnp.asarray(x0).dtype
        self._index_points = convert_nonref_to_tensor(
            index_points, dtype_hint=dtype
        )
        self._x0 = convert_nonref_to_tensor(x0, dtype_hint=dtype)
        self._x1 = convert_nonref_to_tensor(x1, dtype_hint=dtype)
        self._scale = convert_nonref_to_tensor(scale, dtype_hint=dtype)

        # Increments of the free (unpinned) motion over the full grid.
        grid = jnp.astype(self._index_points, dtype)
        time_steps = grid[..., 1:] - grid[..., :-1]
        self._increments = tfd.MultivariateNormalDiag(
            loc=jnp.zeros(time_steps.shape, dtype=dtype),
            scale_diag=jnp.sqrt(time_steps) * jnp.astype(self._scale, dtype),
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            name="bb_increments",
        )

        with jax.named_scope(name):
            super().__init__(
                dtype=dtype,
                reparameterization_type=tfd.FULLY_REPARAMETERIZED,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                name=name,
            )

    def _batch_shape(self):
        # NOTE(review): only x0 contributes; x1, scale and index_points are
        # assumed to share its batch shape.
        return self._x0.shape

    def _event_shape(self):
        # Both end points are fixed, leaving the interior points random.
        return (self._index_points.shape[-1] - 2,)

    def _sample_n(self, n, seed=None):
        """Sample by re-levelling a free Brownian path.

        A free path W with W(t_0) = 0 is drawn on the whole grid.  The straight
        line from ``x0 - W(t_0)`` at ``t_0`` to ``x1 - W(t_T)`` at ``t_T`` is
        added, which pins both ends, and the interior values are returned.
        """
        dtype = self.dtype
        grid = jnp.astype(self._index_points, dtype)
        free_path = jnp.cumsum(self._increments.sample(n, seed=seed), axis=-1)

        free_ends = jnp.stack(
            [jnp.zeros_like(free_path[..., 0]), free_path[..., -1]], axis=-1
        )
        pinned_ends = jnp.stack(
            [jnp.astype(self._x0, dtype), jnp.astype(self._x1, dtype)], axis=-1
        )
        correction = tfp.math.interp_regular_1d_grid(
            x=grid[..., 1:-1],
            x_ref_min=grid[..., 0],
            x_ref_max=grid[..., -1],
            y_ref=pinned_ends - free_ends,
        )
        # free_path[..., -1] is the right end point; drop it.
        return free_path[..., :-1] + correction

    def _log_prob(self, x):
        # Bridge density = joint density of the free motion through
        # (x0, x, x1) divided by the density of reaching x1 from x0.
        x0 = jnp.astype(self._x0, x.dtype)
        x1 = jnp.astype(self._x1, x.dtype)
        path = dist_util.pad(x, -1, front=True, value=x0)
        path = dist_util.pad(path, -1, back=True, value=x1)
        diff = path[..., 1:] - path[..., :-1]
        joint_log_prob = self._increments.log_prob(diff)

        # W(t_T) - W(t_0) is Gaussian with variance sum(dt * scale**2), the
        # sum of the increment variances used above.
        grid = jnp.astype(self._index_points, x.dtype)
        time_steps = grid[..., 1:] - grid[..., :-1]
        end_variance = jnp.sum(
            time_steps * jnp.square(jnp.astype(self._scale, x.dtype)), axis=-1
        )
        endpoint_log_prob = -0.5 * (
            jnp.square(x1 - x0) / end_variance
            + jnp.log(2.0 * jnp.pi * end_variance)
        )
        return joint_log_prob - endpoint_log_prob