Coverage for gemlib/mcmc/hmc.py: 97%

35 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-12-02 22:09 +0000

1"""Base Hamiltonian Monte Carlo""" 

2 

3from collections.abc import Iterable 

4from typing import NamedTuple 

5 

6import jax.numpy as jnp 

7import numpy as np 

8import tensorflow_probability.substrates.jax as tfp 

9 

10from .sampling_algorithm import ChainState, SamplingAlgorithm 

11 

12pu = tfp.experimental.mcmc.preconditioning_utils 

13tfd = tfp.distributions 

14tfde = tfp.experimental.distributions 

15 

16 

class HmcKernelState(NamedTuple):
    """Hyper-parameters of the Hamiltonian Monte Carlo transition kernel."""

    # Leapfrog integrator step size.
    step_size: float
    # Number of leapfrog steps per proposal. This is a count, so it is an
    # ``int`` (the ``hmc()`` factory annotates it as ``int`` too; the
    # original ``float`` annotation here was inconsistent).
    num_leapfrog_steps: int
    # Mass matrix used to precondition the momentum distribution.
    mass_matrix: Iterable

21 

22 

def hmc(
    step_size: float = 0.1,
    num_leapfrog_steps: int = 16,
    mass_matrix: Iterable | None = None,
) -> SamplingAlgorithm:
    """Hamiltonian Monte Carlo

    Args:
        step_size: the step size to take
        num_leapfrog_steps: number of leapfrog steps to take
        mass_matrix: a mass matrix (defaults to :obj:`diag(1)` if :obj:`None`)

    Returns:
        A :obj:`SamplingAlgorithm`

    References:
        Michael Betancourt. A Conceptual Introduction to Hamiltonian
        Monte Carlo. arXiv, 2017. https://arxiv.org/abs/1701.02434

    """

    def _make_momentum_distribution(position):
        # No explicit mass matrix: use TFP's preconditioning utilities to
        # build the default momentum distribution matching the structure
        # of `position`.
        if mass_matrix is None:
            return pu.make_momentum_distribution(
                position, np.array((), np.int32)
            )

        # Fix: `LinearOperatorFullMatrix` lives in the TFP-on-JAX
        # TF-compatibility backend (`tfp.tf2jax.linalg`), not in
        # `jax.numpy.linalg` — the original
        # `jnp.linalg.LinearOperatorFullMatrix` raised AttributeError
        # whenever a mass matrix was supplied.
        operator = tfp.tf2jax.linalg.LinearOperatorFullMatrix(mass_matrix)
        return tfde.MultivariateNormalPrecisionFactorLinearOperator(
            precision_factor=operator,
            precision=operator,
        )

    def _build_kernel(target_log_prob_fn, momentum_distribution):
        # Parameters are stored in the kernel results so the kernel state
        # can be threaded through successive calls to `step_fn`.
        return tfp.experimental.mcmc.PreconditionedHamiltonianMonteCarlo(
            target_log_prob_fn=target_log_prob_fn,
            step_size=step_size,
            num_leapfrog_steps=num_leapfrog_steps,
            momentum_distribution=momentum_distribution,
            store_parameters_in_results=True,
        )

    def init_fn(target_log_prob_fn, initial_position):
        # Bootstrap the kernel results at the initial position; these
        # become the initial kernel state.
        kernel = _build_kernel(
            target_log_prob_fn, _make_momentum_distribution(initial_position)
        )
        results = kernel.bootstrap_results(initial_position)

        # Repack the results data structure into our own
        chain_state = ChainState(
            position=initial_position,
            log_density=results.accepted_results.target_log_prob,
            log_density_grad=results.accepted_results.grads_target_log_prob,
        )
        kernel_state = results

        return chain_state, kernel_state

    def step_fn(target_log_prob_fn, chain_and_kernel_state, seed):
        chain_state, kernel_state = chain_and_kernel_state

        seed = tfp.random.sanitize_seed(seed)

        kernel = _build_kernel(
            target_log_prob_fn,
            _make_momentum_distribution(chain_state.position),
        )

        # Fix: reuse the kernel state threaded through from `init_fn` /
        # the previous step rather than calling `bootstrap_results` again,
        # which discarded the carried state and wasted a target
        # log-density + gradient evaluation on every step.
        new_position, results = kernel.one_step(
            chain_state.position,
            kernel_state,
            seed,
        )

        # Repack the results data structure into our own
        chain_state = ChainState(
            position=new_position,
            log_density=results.accepted_results.target_log_prob,
            log_density_grad=results.accepted_results.grads_target_log_prob,
        )

        # The kernel results double as the new kernel state and as the
        # per-step info returned to the caller.
        return (chain_state, results), results

    return SamplingAlgorithm(init_fn, step_fn)