Coverage for gemlib/mcmc/hmc.py: 97%
35 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-12-02 22:09 +0000
« prev ^ index » next coverage.py v7.10.3, created at 2025-12-02 22:09 +0000
1"""Base Hamiltonian Monte Carlo"""
3from collections.abc import Iterable
4from typing import NamedTuple
6import jax.numpy as jnp
7import numpy as np
8import tensorflow_probability.substrates.jax as tfp
10from .sampling_algorithm import ChainState, SamplingAlgorithm
12pu = tfp.experimental.mcmc.preconditioning_utils
13tfd = tfp.distributions
14tfde = tfp.experimental.distributions
class HmcKernelState(NamedTuple):
    """Tunable parameters of a Hamiltonian Monte Carlo kernel."""

    # Leapfrog integrator step size.
    step_size: float
    # Number of leapfrog steps taken per proposal (an integer count).
    num_leapfrog_steps: int
    # Mass matrix for the momentum distribution (matrix-like; format
    # presumably matches the ``mass_matrix`` argument of ``hmc`` -- confirm).
    mass_matrix: Iterable
def hmc(
    step_size: float = 0.1,
    num_leapfrog_steps: int = 16,
    mass_matrix: Iterable | None = None,
) -> SamplingAlgorithm:
    """Hamiltonian Monte Carlo

    Args:
        step_size: the step size to take
        num_leapfrog_steps: number of leapfrog steps to take
        mass_matrix: a mass matrix (defaults to :obj:`diag(1)` if :obj:`None`).
            When given, it must be a symmetric positive-definite matrix whose
            size matches the chain position (a single flat vector); momenta
            are then drawn from a zero-mean normal distribution with
            covariance ``mass_matrix``.

    Returns:
        A :obj:`SamplingAlgorithm`

    References:
        Michael Betancourt. A Conceptual Introduction to Hamiltonian
        Monte Carlo. arXiv, 2017. https://arxiv.org/abs/1701.02434
    """

    def _make_momentum_distribution(position):
        """Return the momentum distribution for a chain at ``position``."""
        if mass_matrix is None:
            # No variance estimate is supplied, so TFP builds the default
            # (identity mass matrix) momentum distribution shaped like
            # ``position``, with an empty batch shape.
            return pu.make_momentum_distribution(
                position, np.array((), np.int32)
            )

        # Momentum ~ N(0, M), parameterised by the Cholesky factor L of
        # M = L L^T.  A single factor drives both sampling and log_prob (the
        # kinetic energy), so the two cannot disagree -- a requirement for
        # HMC to leave the target invariant.  The matrix is cast to the
        # position's dtype (when it has one) so leapfrog updates do not
        # promote the chain state to a different float type.
        mass = jnp.asarray(
            mass_matrix, dtype=getattr(position, "dtype", None)
        )
        return tfd.MultivariateNormalTriL(
            scale_tril=jnp.linalg.cholesky(mass)
        )

    def _build_kernel(target_log_prob_fn, momentum_distribution):
        """Construct the TFP preconditioned HMC kernel for one call."""
        return tfp.experimental.mcmc.PreconditionedHamiltonianMonteCarlo(
            target_log_prob_fn=target_log_prob_fn,
            step_size=step_size,
            num_leapfrog_steps=num_leapfrog_steps,
            momentum_distribution=momentum_distribution,
            store_parameters_in_results=True,
        )

    def init_fn(target_log_prob_fn, initial_position):
        """Bootstrap the chain and kernel state at ``initial_position``.

        Returns:
            A ``(chain_state, kernel_state)`` tuple, where ``kernel_state`` is
            the TFP kernel results object from ``bootstrap_results``.
        """
        kernel = _build_kernel(
            target_log_prob_fn, _make_momentum_distribution(initial_position)
        )
        results = kernel.bootstrap_results(initial_position)

        # Repack the results data structure into our own
        chain_state = ChainState(
            position=initial_position,
            log_density=results.accepted_results.target_log_prob,
            log_density_grad=results.accepted_results.grads_target_log_prob,
        )
        return chain_state, results

    def step_fn(target_log_prob_fn, chain_and_kernel_state, seed):
        """Advance the chain by one HMC transition.

        Returns:
            ``((chain_state, kernel_results), info)`` where ``info`` is the
            same TFP kernel results object.
        """
        # The incoming kernel state is intentionally not reused: results are
        # re-bootstrapped from the current position so the cached log density
        # and gradient correspond to the ``target_log_prob_fn`` passed to THIS
        # call (presumably it may change between calls, e.g. within a Gibbs
        # scheme -- confirm before switching to the cached state).
        chain_state, _ = chain_and_kernel_state
        seed = tfp.random.sanitize_seed(seed)

        kernel = _build_kernel(
            target_log_prob_fn,
            _make_momentum_distribution(chain_state.position),
        )
        new_position, results = kernel.one_step(
            chain_state.position,
            kernel.bootstrap_results(chain_state.position),
            seed,
        )

        new_chain_state = ChainState(
            position=new_position,
            log_density=results.accepted_results.target_log_prob,
            log_density_grad=results.accepted_results.grads_target_log_prob,
        )
        return (new_chain_state, results), results

    return SamplingAlgorithm(init_fn, step_fn)