Coverage for gemlib/mcmc/adaptive_random_walk_metropolis.py: 97% (68 statements)

1"""An adaptive random walk Metropolis Hastings algorithm""" 

2 

3from typing import NamedTuple, TypeVar 

4 

5import jax.numpy as jnp 

6import numpy as np 

7import tensorflow_probability.substrates.jax as tfp 

8 

9from .mcmc_util import get_flattening_bijector, is_list_like 

10from .sampling_algorithm import ( 

11 ChainState, 

12 LogProbFnType, 

13 Position, 

14 SamplingAlgorithm, 

15) 

16 

17tfd = tfp.distributions 

18 

19__all__ = ["adaptive_rwmh", "RunningCovariance"] 

20 

21RunningCovariance = tfp.experimental.stats.RunningCovariance 

22Tensor = TypeVar("Tensor") 

23 

24 

25class AdaptiveRwmhKernelState(NamedTuple): 

26 scale: float 

27 adaptation_quantity: float 

28 running_covariance: RunningCovariance 

29 is_adaptive: bool 

30 is_accepted: bool 

31 num_steps: bool 

32 

33 

34class AdaptiveRwmhInfo(NamedTuple): 

35 scale: float 

36 adaptation_quantity: float 

37 running_covariance: RunningCovariance 

38 is_adaptive: bool 

39 is_accepted: bool 

40 num_steps: bool 

41 proposed_state: Tensor 

42 log_acceptance: Tensor 

43 

44 

45def adaptive_rwmh( 

46 initial_scale: float = 0.1, 

47 initial_running_covariance: Tensor | None = None, 

48 lambda0: float = 0.1, 

49 adapt_prob: float = 0.95, 

50 adaptation_quantity: float = 1.0, 

51) -> SamplingAlgorithm: 

52 """An Adaptive Random Walk Metropolis Hastings algorithm 

53 

54 This algorithm implements an adaptive random walk Metropolis 

55 Hastings algorithm as described in Sherlock et al. 2010. 

56 

57 Args: 

58 initial_scale: the initial value of the covariance scalar 

59 initial_running_covariance: an initial covariance matrix of shape `[p,p]` 

60 where `p` is the length of the vector of 

61 concatenated parameters 

62 lambda0: the scaling for the non-adaptive covariance matrix 

63 adapt_prob: probability we draw using the adaptive covariance matrix 

64 adaptation_quantity: the (log) amount to increase or decrease the 

65 covariance scaling parameter. 

66 name: an optional name for the kernel 

67 

68 Returns: 

69 A :obj:`SamplingAlgorithm` 

70 

71 References: 

72 Chris Sherlock, Paul Fearnhead, Gareth O. Roberts. "The Random Walk 

73 Metropolis: Linking Theory and Practice Through a Case Study." 

74 Statistical Science, 25(2) 172-190 May 2010. 
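
    Examples:
      A minimal illustrative sketch, assuming a hypothetical standard-normal
      target ``tlp`` and that the returned :obj:`SamplingAlgorithm` is
      tuple-like, holding the init and step functions in the order they are
      constructed below::

        import jax
        import jax.numpy as jnp

        def tlp(x):
            return -0.5 * jnp.sum(x**2)  # log-density up to a constant

        init_fn, step_fn = adaptive_rwmh(initial_scale=0.1)
        state = init_fn(tlp, jnp.zeros(3))
        state, info = step_fn(tlp, state, seed=jax.random.PRNGKey(0))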

75 """ 

76 

77 def init_fn(target_log_prob_fn: LogProbFnType, position: Position): 

78 position_parts = position if is_list_like(position) else [position] 

79 

80 # Flatten and concatenate 

81 flattening_bijector = get_flattening_bijector(position_parts) 

82 flat_position = flattening_bijector(position_parts) 

83 flat_size = flat_position.shape[-1] 

84 

85 # The initial_variance has 3 possibilities: 

86 # 1. a scalar, which is broadcast down the diagonal of 

87 # the covariance matrix. 

88 # 2. a structure compatible with `position` of elements 

89 # corresponding to variances of each variable in `position`. 

90 # 3. a full covariance matrix of dimension `(flat_position.shape[-1], 

91 # flat_position.shape[-1])` giving the covariance matrix for all 

92 # parameters in the flattened and concatenated `position`. 

93 

94 if initial_running_covariance is None: 

95 initial_running_covariance_ = ( 

96 tfp.experimental.stats.RunningCovariance( 

97 num_samples=10, 

98 mean=flat_position, 

99 sum_squared_residuals=0.1 

100 * jnp.diag(jnp.ones_like(flat_position)), 

101 event_ndims=1, 

102 ) 

103 ) 

104 else: 

105 initial_running_covariance_ = initial_running_covariance 
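        # When no initial covariance is supplied, the default above seeds
        # the running estimator with 10 pseudo-observations centred at the
        # initial position; assuming `RunningCovariance.covariance()`
        # divides `sum_squared_residuals` by `num_samples`, this amounts to
        # an initial covariance estimate of 0.01 * I.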

        if initial_scale is None:
            initial_scale_ = 2.38 / jnp.sqrt(
                jnp.astype(flat_size, flat_position.dtype)
            )
        else:
            initial_scale_ = jnp.asarray(initial_scale, flat_position.dtype)

        chain_state = ChainState(
            position=position_parts
            if is_list_like(position)
            else position_parts[0],
            log_density=target_log_prob_fn(*position_parts),
            log_density_grad=(),
        )

        kernel_state = AdaptiveRwmhKernelState(
            scale=initial_scale_,
            adaptation_quantity=adaptation_quantity,
            running_covariance=initial_running_covariance_,
            is_adaptive=False,
            is_accepted=False,
            num_steps=jnp.asarray(0, np.int32),
        )

        return chain_state, kernel_state

    def step_fn(target_log_prob_fn, chain_and_kernel_state, seed):
        mh_seed, adapt_seed, accept_seed = tfp.random.split_seed(seed, n=3)

        cs, ks = chain_and_kernel_state

        position_parts = (
            cs.position if is_list_like(cs.position) else [cs.position]
        )

        flatten = get_flattening_bijector(position_parts)
        flat_position = flatten(position_parts)
        flat_size = flat_position.shape[-1]

        # Update the running covariance and the covariance scale factor
        next_running_covariance = ks.running_covariance.update(flat_position)
        next_adapt_scale = jnp.where(
            ks.is_accepted,
            ks.scale
            * jnp.exp(
                2.3 * ks.adaptation_quantity / jnp.sqrt(ks.num_steps)
            ),  # Magic number from Alg 6 of Sherlock et al.
            ks.scale
            * jnp.exp(-ks.adaptation_quantity / jnp.sqrt(ks.num_steps)),
        )
        next_scale = jnp.where(ks.is_adaptive, next_adapt_scale, ks.scale)
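        # Note: under this update, E[d log scale] per step is proportional
        # to (2.3 * a - (1 - a)), where `a` is the acceptance rate, so the
        # scale is stationary on average at a = 1 / 3.3, roughly 0.3.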

        # Covariances
        static_covariance = jnp.diag(
            jnp.ones_like(flat_position) / flat_size * lambda0**2
        )
        adaptive_covariance = (
            next_scale**2 * next_running_covariance.covariance()
        )

        # Choose either scaled empirical covariance or static covariance
        is_adaptive = tfd.Bernoulli(probs=adapt_prob, dtype=np.bool).sample(
            seed=adapt_seed
        )
        covariance_matrix = jnp.where(
            is_adaptive,
            adaptive_covariance,
            static_covariance,
        )
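        # Mixing in the fixed lambda0**2 / p proposal with probability
        # 1 - adapt_prob is a common safeguard: it keeps the chain moving
        # while the running covariance estimate is still poor, e.g. early
        # in the run.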

        # Do the proposal
        flat_proposed_position = tfd.MultivariateNormalTriL(
            loc=flat_position,
            scale_tril=jnp.linalg.cholesky(covariance_matrix),
        ).sample(seed=mh_seed)

        # Accept/reject
        proposed_log_prob = target_log_prob_fn(
            *flatten.inverse(flat_proposed_position)
        )
        log_acceptance = proposed_log_prob - cs.log_density
        is_accept = tfd.Bernoulli(
            probs=jnp.exp(log_acceptance), dtype=np.bool
        ).sample(seed=accept_seed)
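        # Note: `probs` exceeds 1 when log_acceptance > 0; assuming TFP's
        # Bernoulli samples via `uniform < probs`, such proposals are then
        # always accepted, which is the intended Metropolis behaviour.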

        next_position = jnp.where(
            is_accept, flat_proposed_position, flat_position
        )
        next_log_prob = jnp.where(is_accept, proposed_log_prob, cs.log_density)

        next_position = flatten.inverse(next_position)
        if not is_list_like(cs.position):
            next_position = next_position[0]

        new_chain_state = ChainState(
            position=next_position,
            log_density=next_log_prob,
            log_density_grad=(),
        )

        new_kernel_state = ks._replace(
            scale=next_scale,
            running_covariance=next_running_covariance,
            is_adaptive=is_adaptive,
            is_accepted=is_accept,
            num_steps=ks.num_steps + 1,
        )

        return (new_chain_state, new_kernel_state), AdaptiveRwmhInfo(
            *new_kernel_state,
            proposed_state=flatten.inverse(flat_proposed_position),
            log_acceptance=log_acceptance,
        )

    return SamplingAlgorithm(init_fn, step_fn)