Coverage for gemlib/mcmc/transformed_sampling_algorithm.py: 100%
41 statements
1"""Transform a sampling algorithm"""
3from typing import Callable, Iterable # noqa: UP035
5import jax
6import jax.numpy as jnp
8from gemlib.mcmc.mcmc_util import is_list_like
10from .sampling_algorithm import (
11 ChainAndKernelState,
12 ChainState,
13 LogProbFnType,
14 Position,
15 SamplingAlgorithm,
16 SeedType,
17)
def map_structure(fn, first, *rest):
    """Map `fn` over `first` and `rest`.

    This function relaxes :code:`jax.tree.map`'s strict requirement
    that `first` and `*rest` have exactly equal structures, providing
    for situations where duck-typed tree topology suffices. In this
    respect, it behaves like `tf.nest.map_structure` with
    `check_types=False`. However, it _only_ requires that the flat
    representations of each tree are compatible, and will therefore not
    raise an error if the arity of the nodes differs.
    """

    def is_leaf(x):
        return not isinstance(x, list | tuple)

    flat_first, first_treedef = jax.tree.flatten(first, is_leaf=is_leaf)
    flat_rest = [jax.tree.leaves(x, is_leaf=is_leaf) for x in rest]
    result = tuple(
        fn(*args) for args in zip(flat_first, *flat_rest, strict=True)
    )

    return jax.tree.unflatten(first_treedef, result)
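
# A minimal sketch (illustrative only, not part of the module) of the
# relaxed matching `map_structure` provides. A tuple and a flat list with
# the same number of leaves map together, even though their node arities
# differ, where `jax.tree.map` would raise a structure error:
#
#     map_structure(lambda x, y: x + y, (1, 2), [10, 20])
#     # -> (11, 22)
#     map_structure(lambda x, y: x + y, (1, (2, 3)), [10, 20, 30])
#     # -> (11, (22, 33)); jax.tree.map would reject this pairing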
def _transform_tlp_fn(
    bijectors: list, target_log_prob_fn: Callable, disable_bijector_caching=True
):
    if not is_list_like(bijectors):
        bijectors = [bijectors]

    def fn(*transformed_position):
        if disable_bijector_caching:
            # Rebuild the position structure so the forward transform
            # below is evaluated afresh rather than short-circuited by
            # TFP's bijector caching.
            transformed_position = map_structure(
                lambda x: x,
                transformed_position,
            )

        untransformed_position = map_structure(
            lambda x, bij: bij.forward(x),
            transformed_position,
            bijectors,
        )
        tlp = target_log_prob_fn(*untransformed_position) - jnp.sum(
            jnp.asarray(
                map_structure(
                    lambda x, bij: bij.inverse_log_det_jacobian(x),
                    untransformed_position,
                    bijectors,
                )
            )
        )

        return tlp

    return fn
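
# A minimal sketch (illustrative only, not part of the module) of the
# Jacobian correction made by `_transform_tlp_fn`, assuming the JAX
# substrate of TensorFlow Probability:
#
#     from tensorflow_probability.substrates import jax as tfp
#
#     exponential_lp = lambda x: -x  # Exponential(1) log density, x > 0
#     unconstrained_lp = _transform_tlp_fn(
#         tfp.bijectors.Exp(), exponential_lp
#     )
#     # unconstrained_lp(y) == -exp(y) + y, the log density of
#     # Y = log(X) when X ~ Exponential(1): the `+ y` term is the
#     # forward log-Jacobian of the Exp bijector.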
def transform_sampling_algorithm(
    bijectors: Iterable, sampling_algorithm: SamplingAlgorithm
):
    r"""Transform a sampling algorithm.

    This wrapper transforms `sampling_algorithm` with respect to the
    probability measure on which it acts. `transform_sampling_algorithm`
    is particularly useful for unbounding parameter spaces in order to use
    algorithms such as Hamiltonian Monte Carlo or the No-U-Turn Sampler
    (NUTS).

    It does this by applying the change-of-variables formula, such that for
    :math:`Y = g(X)`,

    .. math::

        f_Y(y)=f_X(g^{-1}(y))\left|\frac{\mathrm{d}g^{-1}(y)}{\mathrm{d}y}\right|

    Args:
        bijectors: a structure of TensorFlow Probability bijectors
            compatible with `position`
        sampling_algorithm: a sampling algorithm

    Returns:
        A new :obj:`SamplingAlgorithm` representing a kernel working on the
        transformed space.

    Examples:
        Instantiate a transformed HMC kernel::

            import tensorflow_probability as tfp
            from gemlib.mcmc import transform_sampling_algorithm
            from gemlib.mcmc import hmc

            kernel = transform_sampling_algorithm(
                bijectors=[tfp.bijectors.Exp(), tfp.bijectors.Exp()],
                sampling_algorithm=hmc(step_size=0.1, num_leapfrog_steps=16),
            )
    """
    def init_fn(
        target_log_prob_fn: LogProbFnType, position: Position, **kwargs
    ):
        transformed_tlp = _transform_tlp_fn(bijectors, target_log_prob_fn)

        # Initialise the wrapped kernel on the unconstrained space
        transformed_position = map_structure(
            lambda x, bij: bij.inverse(x),
            position,
            bijectors,
        )
        chain_state, kernel_state = sampling_algorithm.init(
            transformed_tlp, transformed_position, **kwargs
        )

        # Report the log density on the constrained space by undoing the
        # Jacobian correction applied in `_transform_tlp_fn`
        new_chain_state = ChainState(
            position=position,
            log_density=chain_state.log_density
            + jnp.sum(
                jnp.asarray(
                    map_structure(
                        lambda bij, x: bij.inverse_log_det_jacobian(x),
                        bijectors,
                        position,
                    )
                )
            ),
            log_density_grad=(),
        )

        return new_chain_state, kernel_state
    def step_fn(
        target_log_prob_fn: LogProbFnType,
        chain_and_kernel_state: ChainAndKernelState,
        seed: SeedType,
        **kwargs,
    ):
        chain_state, kernel_state = chain_and_kernel_state

        # Transform to the unconstrained space (inverse transform)
        transformed_tlp_fn = _transform_tlp_fn(bijectors, target_log_prob_fn)
        transformed_position = map_structure(
            lambda x, bij: bij.inverse(x),
            chain_state.position,
            bijectors,
        )
        transformed_tlp = chain_state.log_density - jnp.sum(
            jnp.asarray(
                map_structure(
                    lambda x, bij: bij.inverse_log_det_jacobian(x),
                    chain_state.position,
                    bijectors,
                )
            )
        )

        transformed_chain_state = ChainState(
            position=transformed_position,
            log_density=transformed_tlp,
            log_density_grad=(),
        )
        (new_chain_state, new_kernel_state), info = sampling_algorithm.step(
            transformed_tlp_fn,
            (transformed_chain_state, kernel_state),
            seed=seed,
            **kwargs,
        )

        # Transform back to the constrained space
        constrained_new_position = map_structure(
            lambda x, bij: bij.forward(x),
            new_chain_state.position,
            bijectors,
        )
        constrained_log_density = new_chain_state.log_density + jnp.sum(
            jnp.asarray(
                map_structure(
                    lambda x, bij: bij.inverse_log_det_jacobian(x),
                    constrained_new_position,
                    bijectors,
                )
            )
        )
        new_chain_state = ChainState(
            position=constrained_new_position,
            log_density=constrained_log_density,
            log_density_grad=(),
        )

        return (new_chain_state, new_kernel_state), info

    return SamplingAlgorithm(init_fn, step_fn)
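
# A minimal end-to-end sketch (illustrative only, not part of the module),
# following the init/step contract above. It assumes the JAX substrate of
# TensorFlow Probability, the `hmc` kernel from `gemlib.mcmc`, and that a
# JAX PRNG key is an acceptable `seed`:
#
#     import jax
#     from tensorflow_probability.substrates import jax as tfp
#     from gemlib.mcmc import hmc
#
#     def target_log_prob_fn(rate):
#         return -rate  # Exponential(1) log density; `rate` must be positive
#
#     kernel = transform_sampling_algorithm(
#         bijectors=[tfp.bijectors.Exp()],
#         sampling_algorithm=hmc(step_size=0.1, num_leapfrog_steps=16),
#     )
#     state = kernel.init(target_log_prob_fn, (jnp.array(1.0),))
#     state, info = kernel.step(
#         target_log_prob_fn, state, seed=jax.random.PRNGKey(0)
#     )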