Coverage for gemlib/mcmc/adaptive_hmc.py: 100%

47 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-12-02 22:09 +0000

1"""Adaptive Hamiltonian Monte Carlo sampling algorithm""" 

2 

3from enum import IntEnum 

4 

5import jax 

6import numpy as np 

7import tensorflow_probability.substrates.jax as tfp 

8from tensorflow_probability.python.internal import unnest 

9 

10from gemlib.mcmc.sampling_algorithm import ( 

11 ChainAndKernelState, 

12 ChainState, 

13 LogProbFnType, 

14 Position, 

15 SamplingAlgorithm, 

16 SeedType, 

17) 

18 

# Re-export TFP's RunningVariance so callers can construct the initial
# mass-matrix state (see make_initial_running_variance) without importing
# tensorflow_probability themselves.
RunningVariance = tfp.experimental.stats.RunningVariance

# Public API of this module.
__all__ = [
    "RunningVariance",
    "StepSizeAlgorithm",
    "make_initial_running_variance",
    "adaptive_hmc",
]

27 

28 

# Base samplers
class StepSizeAlgorithm(IntEnum):
    """Step size adaptation algorithm selector.

    The integer value indexes into ``StepSizeKernels``: ``SIMPLE`` selects
    ``tfp.mcmc.SimpleStepSizeAdaptation`` and ``NESTEROV`` selects
    ``tfp.mcmc.DualAveragingStepSizeAdaptation``.
    """

    SIMPLE = 0
    NESTEROV = 1

33 

34 

# Aliases for the TFP kernels composed by adaptive_hmc.
PreconditionedHMC = tfp.experimental.mcmc.PreconditionedHamiltonianMonteCarlo
# Indexed by StepSizeAlgorithm value: 0 -> SIMPLE, 1 -> NESTEROV.
StepSizeKernels = [
    tfp.mcmc.SimpleStepSizeAdaptation,
    tfp.mcmc.DualAveragingStepSizeAdaptation,
]
MassMatrixAdaptation = tfp.experimental.mcmc.DiagonalMassMatrixAdaptation

41 

42 

def make_initial_running_variance(
    position: Position,
    variance: "Position | None" = None,
):
    """Build a tree of ``RunningVariance`` states matching *position*.

    Args:
        position: a pytree of arrays giving the initial means.
        variance: an optional pytree of arrays (same structure as
            *position*) giving the initial variances; defaults to ones.

    Returns:
        a pytree of ``RunningVariance`` objects, one per leaf of
        *position*.
    """
    # Default to unit variance for every leaf.
    if variance is None:
        variance = jax.tree.map(np.ones_like, position)

    def _leaf_variance(mean, var):
        # One observation with the given mean and variance.
        return RunningVariance.from_stats(
            np.array(1.0, mean.dtype), mean=mean, variance=var
        )

    return jax.tree.map(_leaf_variance, position, variance)

56 

57 

def adaptive_hmc(
    initial_step_size: float,
    num_leapfrog_steps: int,
    num_step_size_adaptation_steps: int,
    num_mass_matrix_estimation_steps: int,
    initial_running_variance: RunningVariance,
    step_size_algorithm: StepSizeAlgorithm = StepSizeAlgorithm.NESTEROV,
    step_size_kwargs: dict | None = None,
) -> SamplingAlgorithm:
    """Adaptive Hamiltonian Monte Carlo sampler

    Implements a Hamiltonian Monte Carlo sampler with adaptive step size
    and mass matrix.

    Args:
        initial_step_size: the initial step size
        num_leapfrog_steps: number of leapfrog steps to perform
        num_step_size_adaptation_steps: the number of steps before adaptation
            stops
        num_mass_matrix_estimation_steps: the number of steps before mass
            matrix estimation stops
        initial_running_variance: an initial running variance, constructed
            using :func:`make_initial_running_variance`
        step_size_algorithm: the algorithm used to adapt the step size
        step_size_kwargs: arguments to pass to the step size adaptation
            kernel

    Returns:
        a SamplingAlgorithm implementing an adaptive Hamiltonian Monte Carlo
        sampler
    """
    if step_size_kwargs is None:
        step_size_kwargs = {}

    def _build_kernel(target_log_prob_fn: LogProbFnType):
        # Innermost kernel: preconditioned HMC.  Parameters are stored in
        # the kernel results so the outer adaptation kernels can update
        # them between steps.
        base_hmc_kernel = PreconditionedHMC(
            target_log_prob_fn=target_log_prob_fn,
            step_size=initial_step_size,
            num_leapfrog_steps=num_leapfrog_steps,
            store_parameters_in_results=True,
        )

        # StepSizeAlgorithm values index into StepSizeKernels.
        StepSizeAlg = StepSizeKernels[step_size_algorithm]
        step_size_adapt = StepSizeAlg(
            inner_kernel=base_hmc_kernel,
            num_adaptation_steps=num_step_size_adaptation_steps,
            **step_size_kwargs,
        )

        # Outermost kernel: diagonal mass matrix adaptation (use the
        # module-level alias for consistency with the other kernel aliases).
        return MassMatrixAdaptation(
            inner_kernel=step_size_adapt,
            initial_running_variance=initial_running_variance,
            num_estimation_steps=num_mass_matrix_estimation_steps,
        )

    def init_fn(target_log_prob_fn: LogProbFnType, initial_position: Position):
        kernel = _build_kernel(target_log_prob_fn)
        results = kernel.bootstrap_results(initial_position)

        # Repack the results data structure into our own
        chain_state = ChainState(
            position=initial_position,
            log_density=unnest.get_innermost(results, "target_log_prob"),
            log_density_grad=tuple(
                unnest.get_innermost(results, "grads_target_log_prob")
            ),
        )

        return ChainAndKernelState(chain_state, results)

    def step_fn(
        target_log_prob_fn: LogProbFnType,
        chain_and_kernel_state: ChainAndKernelState,
        seed: SeedType,
    ):
        chain_state, kernel_state = chain_and_kernel_state

        seed = tfp.random.sanitize_seed(seed)

        kernel = _build_kernel(target_log_prob_fn)

        # If grads have not been forwarded, recompute them
        if chain_state.log_density_grad == ():
            _, log_density_grad = tfp.math.value_and_gradient(
                target_log_prob_fn,
                *chain_state.position,
            )
        else:
            log_density_grad = chain_state.log_density_grad

        # Push the current position's log density and gradients back into
        # the kernel results before stepping, in case the position was
        # updated by another kernel since the last step.
        next_results = unnest.replace_innermost(
            kernel_state,
            target_log_prob=chain_state.log_density,
            grads_target_log_prob=list(log_density_grad),
        )

        new_position, results = kernel.one_step(
            chain_state.position,
            next_results,
            seed,
        )

        chain_state = ChainState(
            position=new_position,
            log_density=unnest.get_innermost(results, "target_log_prob"),
            log_density_grad=tuple(
                unnest.get_innermost(results, "grads_target_log_prob")
            ),
        )

        # The raw kernel results double as the per-step info payload.
        return ChainAndKernelState(chain_state, results), results

    return SamplingAlgorithm(init_fn, step_fn)