Coverage for gemlib/mcmc/adaptive_random_walk_metropolis.py: 97%
68 statements
1"""An adaptive random walk Metropolis Hastings algorithm"""

from typing import NamedTuple, TypeVar

import jax.numpy as jnp
import numpy as np
import tensorflow_probability.substrates.jax as tfp

from .mcmc_util import get_flattening_bijector, is_list_like
from .sampling_algorithm import (
    ChainState,
    LogProbFnType,
    Position,
    SamplingAlgorithm,
)

tfd = tfp.distributions

__all__ = ["adaptive_rwmh", "RunningCovariance"]

RunningCovariance = tfp.experimental.stats.RunningCovariance
Tensor = TypeVar("Tensor")


class AdaptiveRwmhKernelState(NamedTuple):
    scale: float
    adaptation_quantity: float
    running_covariance: RunningCovariance
    is_adaptive: bool
    is_accepted: bool
    num_steps: int


class AdaptiveRwmhInfo(NamedTuple):
    scale: float
    adaptation_quantity: float
    running_covariance: RunningCovariance
    is_adaptive: bool
    is_accepted: bool
    num_steps: int
    proposed_state: Tensor
    log_acceptance: Tensor


def adaptive_rwmh(
    initial_scale: float = 0.1,
    initial_running_covariance: Tensor | None = None,
    lambda0: float = 0.1,
    adapt_prob: float = 0.95,
    adaptation_quantity: float = 1.0,
) -> SamplingAlgorithm:
52 """An Adaptive Random Walk Metropolis Hastings algorithm
54 This algorithm implements an adaptive random walk Metropolis
55 Hastings algorithm as described in Sherlock et al. 2010.
57 Args:
58 initial_scale: the initial value of the covariance scalar
59 initial_running_covariance: an initial covariance matrix of shape `[p,p]`
60 where `p` is the length of the vector of
61 concatenated parameters
62 lambda0: the scaling for the non-adaptive covariance matrix
63 adapt_prob: probability we draw using the adaptive covariance matrix
64 adaptation_quantity: the (log) amount to increase or decrease the
65 covariance scaling parameter.
66 name: an optional name for the kernel
68 Returns:
69 A :obj:`SamplingAlgorithm`
71 References:
72 Chris Sherlock, Paul Fearnhead, Gareth O. Roberts. "The Random Walk
73 Metropolis: Linking Theory and Practice Through a Case Study."
74 Statistical Science, 25(2) 172-190 May 2010.
75 """

    def init_fn(target_log_prob_fn: LogProbFnType, position: Position):
        position_parts = position if is_list_like(position) else [position]

        # Flatten and concatenate
        flattening_bijector = get_flattening_bijector(position_parts)
        flat_position = flattening_bijector(position_parts)
        flat_size = flat_position.shape[-1]

        # The `initial_running_covariance` has 3 possibilities:
        #   1. a scalar, which is broadcast down the diagonal of
        #      the covariance matrix.
        #   2. a structure compatible with `position` of elements
        #      corresponding to variances of each variable in `position`.
        #   3. a full covariance matrix of dimension `(flat_position.shape[-1],
        #      flat_position.shape[-1])` giving the covariance matrix for all
        #      parameters in the flattened and concatenated `position`.
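        # NOTE: with the default (None), the running covariance is seeded
        # with 10 pseudo-observations centred on the initial position and a
        # summed squared residual of 0.1 * I, i.e. an initial empirical
        # covariance of 0.1 / 10 = 0.01 on the diagonal (assuming
        # `RunningCovariance.covariance()` divides by `num_samples`).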
        if initial_running_covariance is None:
            initial_running_covariance_ = (
                tfp.experimental.stats.RunningCovariance(
                    num_samples=10,
                    mean=flat_position,
                    sum_squared_residuals=0.1
                    * jnp.diag(jnp.ones_like(flat_position)),
                    event_ndims=1,
                )
            )
        else:
            initial_running_covariance_ = initial_running_covariance
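
        # If `initial_scale` is None, fall back to the widely used
        # 2.38 / sqrt(p) scaling rule for random walk Metropolis proposals.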
        if initial_scale is None:
            initial_scale_ = 2.38 / jnp.sqrt(
                jnp.astype(flat_size, flat_position.dtype)
            )
        else:
            initial_scale_ = jnp.asarray(initial_scale, flat_position.dtype)

        chain_state = ChainState(
            position=position_parts
            if is_list_like(position)
            else position_parts[0],
            log_density=target_log_prob_fn(*position_parts),
            log_density_grad=(),
        )

        kernel_state = AdaptiveRwmhKernelState(
            scale=initial_scale_,
            adaptation_quantity=adaptation_quantity,
            running_covariance=initial_running_covariance_,
            is_adaptive=False,
            is_accepted=False,
            num_steps=jnp.asarray(0, np.int32),
        )

        return chain_state, kernel_state

    def step_fn(target_log_prob_fn, chain_and_kernel_state, seed):
        mh_seed, adapt_seed, accept_seed = tfp.random.split_seed(seed, n=3)

        cs, ks = chain_and_kernel_state

        position_parts = (
            cs.position if is_list_like(cs.position) else [cs.position]
        )

        flatten = get_flattening_bijector(position_parts)
        flat_position = flatten(position_parts)
        flat_size = flat_position.shape[-1]

        # Update the covariance matrix
        next_running_covariance = ks.running_covariance.update(flat_position)
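        # Adapt the proposal scale based on the previous step's outcome: grow
        # it after an acceptance, shrink it after a rejection, with a step
        # size that decays as 1 / sqrt(num_steps) so adaptation diminishes
        # over time.  The expected log-change is zero when the acceptance
        # rate is about 1 / (1 + 2.3), i.e. roughly 0.3.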
        next_adapt_scale = jnp.where(
            ks.is_accepted,
            ks.scale
            * jnp.exp(
                2.3 * ks.adaptation_quantity / jnp.sqrt(ks.num_steps)
            ),  # Magic number from Alg 6 of Sherlock et al.
            ks.scale
            * jnp.exp(-ks.adaptation_quantity / jnp.sqrt(ks.num_steps)),
        )
        next_scale = jnp.where(ks.is_adaptive, next_adapt_scale, ks.scale)

        # Covariances
        static_covariance = jnp.diag(
            jnp.ones_like(flat_position) / flat_size * lambda0**2
        )
        adaptive_covariance = (
            next_scale**2 * next_running_covariance.covariance()
        )
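        # The static covariance (lambda0**2 / p on the diagonal) guards
        # against a poorly estimated empirical covariance early in the run;
        # mixing it in with probability 1 - adapt_prob is a standard
        # safeguard in adaptive MCMC.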

        # Choose either scaled empirical covariance or static covariance
        is_adaptive = tfd.Bernoulli(probs=adapt_prob, dtype=np.bool).sample(
            seed=adapt_seed
        )
        covariance_matrix = jnp.where(
            is_adaptive,
            adaptive_covariance,
            static_covariance,
        )
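        # (jnp.where broadcasts the scalar draw against the [p, p] matrices,
        # so the choice is made without Python-level branching.)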

        # Do the proposal
        flat_proposed_position = tfd.MultivariateNormalTriL(
            loc=flat_position,
            scale_tril=jnp.linalg.cholesky(covariance_matrix),
        ).sample(seed=mh_seed)

        # Accept/reject
        proposed_log_prob = target_log_prob_fn(
            *flatten.inverse(flat_proposed_position)
        )
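        # The Gaussian random-walk proposal is symmetric, so the
        # Metropolis-Hastings ratio reduces to the difference of log target
        # densities.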
        log_acceptance = proposed_log_prob - cs.log_density
        is_accept = tfd.Bernoulli(
            probs=jnp.exp(log_acceptance), dtype=np.bool
        ).sample(seed=accept_seed)
        next_position = jnp.where(
            is_accept, flat_proposed_position, flat_position
        )
        next_log_prob = jnp.where(is_accept, proposed_log_prob, cs.log_density)

        next_position = flatten.inverse(next_position)
        if not is_list_like(cs.position):
            next_position = next_position[0]

        new_chain_state = ChainState(
            position=next_position,
            log_density=next_log_prob,
            log_density_grad=(),
        )

        new_kernel_state = ks._replace(
            scale=next_scale,
            running_covariance=next_running_covariance,
            is_adaptive=is_adaptive,
            is_accepted=is_accept,
            num_steps=ks.num_steps + 1,
        )

        return (new_chain_state, new_kernel_state), AdaptiveRwmhInfo(
            *new_kernel_state,
            proposed_state=flatten.inverse(flat_proposed_position),
            log_acceptance=log_acceptance,
        )

    return SamplingAlgorithm(init_fn, step_fn)
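

# A minimal usage sketch (illustration only, not part of this module). It
# assumes that `SamplingAlgorithm` is a NamedTuple-like pair of the `init_fn`
# and `step_fn` callables constructed above, so it can be unpacked directly,
# and that seeds are ordinary JAX PRNG keys as used by the TFP JAX substrate.
#
#     import jax
#
#     def log_prob(x):
#         return -0.5 * jnp.sum(x**2)  # standard normal target
#
#     init, step = adaptive_rwmh(initial_scale=0.1)
#     state = init(log_prob, jnp.zeros(3))
#
#     seed = jax.random.PRNGKey(0)
#     for _ in range(1000):
#         seed, step_seed = jax.random.split(seed)
#         state, info = step(log_prob, state, step_seed)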