Coverage for gemlib/mcmc/random_walk_metropolis.py: 100%
26 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-12-02 22:09 +0000
« prev ^ index » next coverage.py v7.10.3, created at 2025-12-02 22:09 +0000
1"""Implementation of the Random Walk Metropolis algorithm"""
3from collections.abc import Callable
4from typing import NamedTuple, TypeVar
6import tensorflow_probability.substrates.jax as tfp
8from .sampling_algorithm import ChainState, SamplingAlgorithm
# Type placeholder used to annotate tensor-valued fields such as
# RwmhInfo.proposed_state.
Tensor = TypeVar("Tensor")


# Per-step diagnostics for the random walk Metropolis-Hastings (RWMH) kernel.
# Declared with the functional NamedTuple syntax; further fields can be added
# here if a specific kernel needs to report more information.
RwmhInfo = NamedTuple(
    "RwmhInfo",
    [
        ("is_accepted", bool),
        ("proposed_state", Tensor),
    ],
)
RwmhInfo.__doc__ = """Diagnostics reported by one random walk Metropolis-Hastings step.

This can be expanded to include more information in the future (as needed
for a specific kernel).

Attributes
----------
is_accepted: whether the proposal made during the step was accepted.
proposed_state: the state that was proposed during the step.
"""
# Kernel-side state carried alongside the chain state between RWMH steps.
# Declared with the functional NamedTuple syntax.
RwmhKernelState = NamedTuple("RwmhKernelState", [("scale", float)])
RwmhKernelState.__doc__ = """Kernel state of the random walk Metropolis-Hastings (RWMH) kernel.

Attributes
----------
scale (float): The scale parameter for the RWMH kernel.
"""
def rwmh(scale: float = 1.0):
    """Random walk Metropolis-Hastings MCMC kernel.

    The kernel is a ``tfp.mcmc.RandomWalkMetropolis`` whose proposals are
    generated by ``tfp.mcmc.random_walk_normal_fn(scale=scale)``.

    Args:
        scale: the random walk scaling parameter. Should broadcast with the
            state. Defaults to ``1.0``.

    Returns:
        an instance of :obj:`SamplingAlgorithm` wrapping ``init_fn`` and
        ``step_fn``.
    """

    def _build_kernel(log_prob_fn):
        """Construct a TFP RandomWalkMetropolis kernel targeting `log_prob_fn`."""
        return tfp.mcmc.RandomWalkMetropolis(
            target_log_prob_fn=log_prob_fn,
            new_state_fn=tfp.mcmc.random_walk_normal_fn(scale=scale),
        )

    def init_fn(
        target_log_prob_fn, target_state
    ) -> tuple[ChainState, RwmhKernelState]:
        """Initialise the chain and kernel state at `target_state`.

        Args:
            target_log_prob_fn: log density function of the target.
            target_state: the initial position of the chain.

        Returns:
            A ``(ChainState, RwmhKernelState)`` pair. The chain state's
            ``log_density`` is the target log density at ``target_state``,
            obtained from ``kernel.bootstrap_results``. ``log_density_grad``
            is an empty tuple because this kernel uses no gradients.
        """
        kernel = _build_kernel(target_log_prob_fn)
        results = kernel.bootstrap_results(target_state)

        chain_state = ChainState(
            position=target_state,
            log_density=results.accepted_results.target_log_prob,
            log_density_grad=(),
        )
        kernel_state = RwmhKernelState(scale=scale)

        return chain_state, kernel_state

    def step_fn(
        target_log_prob_fn: Callable[[NamedTuple], float],
        target_and_kernel_state: tuple[ChainState, RwmhKernelState],
        seed,
    ) -> tuple[tuple[ChainState, RwmhKernelState], RwmhInfo]:
        """Advance the chain by one random walk Metropolis-Hastings step.

        Args:
            target_log_prob_fn: log density function of the target.
            target_and_kernel_state: ``(ChainState, RwmhKernelState)`` pair,
                as returned by ``init_fn`` or a previous ``step_fn`` call.
            seed: random seed forwarded to ``kernel.one_step``.

        Returns:
            ``((new_chain_state, kernel_state), info)``, where ``info`` is a
            :class:`RwmhInfo` holding the acceptance flag and the proposed
            state reported by the TFP kernel results.
        """
        # This could be replaced with BlackJAX easily
        kernel = _build_kernel(target_log_prob_fn)

        target_chain_state, kernel_state = target_and_kernel_state

        # NOTE(review): bootstrap_results evaluates target_log_prob_fn at the
        # current position on every step, although
        # target_chain_state.log_density already holds that value. Reusing it
        # would save one density evaluation per step. Confirm TFP's
        # kernel-results structure before changing this.
        new_target_position, results = kernel.one_step(
            target_chain_state.position,
            kernel.bootstrap_results(target_chain_state.position),
            seed=seed,
        )

        new_chain_state = ChainState(
            position=new_target_position,
            log_density=results.accepted_results.target_log_prob,
            log_density_grad=(),
        )

        # kernel_state is passed through unchanged. The proposal scale comes
        # from the enclosing `scale` argument, not from kernel_state.scale.
        new_chain_and_kernel_state = (new_chain_state, kernel_state)

        return new_chain_and_kernel_state, RwmhInfo(
            results.is_accepted, results.proposed_state
        )

    return SamplingAlgorithm(init_fn, step_fn)