Coverage for gemlib/mcmc/adaptive_hmc.py: 100%
47 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-12-02 22:09 +0000
1"""Adaptive Hamiltonian Monte Carlo sampling algorithm"""
3from enum import IntEnum
5import jax
6import numpy as np
7import tensorflow_probability.substrates.jax as tfp
8from tensorflow_probability.python.internal import unnest
10from gemlib.mcmc.sampling_algorithm import (
11 ChainAndKernelState,
12 ChainState,
13 LogProbFnType,
14 Position,
15 SamplingAlgorithm,
16 SeedType,
17)
# Re-export TFP's running-variance accumulator so callers can construct the
# initial mass-matrix estimate without importing TFP themselves.
RunningVariance = tfp.experimental.stats.RunningVariance

# Public API of this module.
__all__ = [
    "RunningVariance",
    "StepSizeAlgorithm",
    "make_initial_running_variance",
    "adaptive_hmc",
]
# Base samplers
class StepSizeAlgorithm(IntEnum):
    """Selector for the step-size adaptation kernel.

    Member values index into ``StepSizeKernels`` below, so the ordering of
    that list must match the values here.
    """

    SIMPLE = 0  # tfp.mcmc.SimpleStepSizeAdaptation
    NESTEROV = 1  # tfp.mcmc.DualAveragingStepSizeAdaptation
# Short aliases for the TFP kernels this module composes.
PreconditionedHMC = tfp.experimental.mcmc.PreconditionedHamiltonianMonteCarlo
# Indexed by StepSizeAlgorithm (SIMPLE -> 0, NESTEROV -> 1).
StepSizeKernels = [
    tfp.mcmc.SimpleStepSizeAdaptation,
    tfp.mcmc.DualAveragingStepSizeAdaptation,
]
MassMatrixAdaptation = tfp.experimental.mcmc.DiagonalMassMatrixAdaptation
def make_initial_running_variance(
    position: Position,
    variance: "Position | None" = None,
):
    """Build per-leaf running-variance accumulators for a chain position.

    Args:
        position: a pytree of arrays giving the initial chain position
        variance: optional pytree of initial variances with the same
            structure as ``position``; defaults to unit variance per leaf

    Returns:
        a pytree of :class:`RunningVariance` accumulators, one per leaf,
        each seeded with a single pseudo-observation at the given mean
        and variance
    """
    if variance is None:
        # Default to unit variance, matching each position leaf's shape.
        variance = jax.tree.map(np.ones_like, position)

    def leaf_stats(mean, var):
        # num_samples=1 in the leaf's dtype seeds the accumulator.
        return RunningVariance.from_stats(
            np.array(1.0, mean.dtype), mean=mean, variance=var
        )

    return jax.tree.map(leaf_stats, position, variance)
def adaptive_hmc(
    initial_step_size: float,
    num_leapfrog_steps: int,
    num_step_size_adaptation_steps: int,
    num_mass_matrix_estimation_steps: int,
    initial_running_variance: RunningVariance,
    step_size_algorithm: StepSizeAlgorithm = StepSizeAlgorithm.NESTEROV,
    step_size_kwargs: dict | None = None,
) -> SamplingAlgorithm:
    """Adaptive Hamiltonian Monte Carlo sampler

    Implements a Hamiltonian Monte Carlo sampler with adaptive step size
    and mass matrix.

    Args:
        initial_step_size: the initial step size
        num_leapfrog_steps: number of leapfrog steps to perform
        num_step_size_adaptation_steps: the number of steps before adaptation
            stops
        num_mass_matrix_estimation_steps: the number of steps before mass
            matrix estimation stops
        initial_running_variance: an initial running variance, constructed
            using :func:`make_initial_running_variance`
        step_size_algorithm: the algorithm used to adapt the step size
        step_size_kwargs: arguments to pass to the step size adaptation kernel

    Returns:
        a SamplingAlgorithm implementing an adaptive Hamiltonian Monte Carlo
        sampler
    """
    if step_size_kwargs is None:
        step_size_kwargs = {}

    def _build_kernel(target_log_prob_fn: LogProbFnType):
        # Kernel stack, innermost first:
        #   HMC -> step-size adaptation -> diagonal mass-matrix adaptation.
        base_hmc_kernel = PreconditionedHMC(
            target_log_prob_fn=target_log_prob_fn,
            step_size=initial_step_size,
            num_leapfrog_steps=num_leapfrog_steps,
            # Keep step size / momentum distribution in the kernel results
            # so the adaptation wrappers can update them between steps.
            store_parameters_in_results=True,
        )

        StepSizeAlg = StepSizeKernels[step_size_algorithm]
        step_size_adapt = StepSizeAlg(
            inner_kernel=base_hmc_kernel,
            num_adaptation_steps=num_step_size_adaptation_steps,
            **step_size_kwargs,
        )

        # Use the module-level alias for consistency with the other kernels.
        kernel = MassMatrixAdaptation(
            inner_kernel=step_size_adapt,
            initial_running_variance=initial_running_variance,
            num_estimation_steps=num_mass_matrix_estimation_steps,
        )

        return kernel

    def init_fn(target_log_prob_fn: LogProbFnType, initial_position: Position):
        kernel = _build_kernel(target_log_prob_fn)
        results = kernel.bootstrap_results(initial_position)

        # Repack the results data structure into our own
        chain_state = ChainState(
            position=initial_position,
            log_density=unnest.get_innermost(results, "target_log_prob"),
            log_density_grad=tuple(
                unnest.get_innermost(results, "grads_target_log_prob")
            ),
        )

        return ChainAndKernelState(chain_state, results)

    def step_fn(
        target_log_prob_fn: LogProbFnType,
        chain_and_kernel_state: ChainAndKernelState,
        seed: SeedType,
    ):
        chain_state, kernel_state = chain_and_kernel_state

        seed = tfp.random.sanitize_seed(seed)

        kernel = _build_kernel(
            target_log_prob_fn,
        )

        # If grads have not been forwarded, recompute them
        if chain_state.log_density_grad == ():
            _, log_density_grad = tfp.math.value_and_gradient(
                target_log_prob_fn,
                *chain_state.position,
            )
        else:
            log_density_grad = chain_state.log_density_grad

        # Inject the (possibly externally-updated) log density and gradients
        # into the innermost kernel results before stepping.
        next_results = unnest.replace_innermost(
            kernel_state,
            target_log_prob=chain_state.log_density,
            grads_target_log_prob=list(log_density_grad),
        )

        new_position, results = kernel.one_step(
            chain_state.position,
            next_results,
            seed,
        )

        chain_state = ChainState(
            position=new_position,
            log_density=unnest.get_innermost(results, "target_log_prob"),
            log_density_grad=tuple(
                unnest.get_innermost(results, "grads_target_log_prob")
            ),
        )

        # The full kernel results double as the per-step info payload.
        return ChainAndKernelState(chain_state, results), results

    return SamplingAlgorithm(init_fn, step_fn)