Coverage for gemlib/mcmc/random_walk_metropolis.py: 100%

26 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-12-02 22:09 +0000

1"""Implementation of the Random Walk Metropolis algorithm""" 

2 

3from collections.abc import Callable 

4from typing import NamedTuple, TypeVar 

5 

6import tensorflow_probability.substrates.jax as tfp 

7 

8from .sampling_algorithm import ChainState, SamplingAlgorithm 

9 

# Generic placeholder used only in type annotations for array-valued fields
# (presumably JAX arrays, given the tfp JAX substrate import — confirm).
Tensor = TypeVar("Tensor")

11 

12 

class RwmhInfo(NamedTuple):
    """Diagnostics produced by one random walk Metropolis-Hastings (RWMH) step.

    More fields can be added later if a specific kernel needs to report
    additional information.

    Attributes
    ----------
    is_accepted: whether the proposal made in this step was accepted.
    proposed_state: the candidate state that was proposed in this step.
    """

    # Acceptance indicator reported by the underlying TFP kernel results.
    is_accepted: bool
    # Candidate position generated by the proposal function.
    proposed_state: Tensor

28 

29 

class RwmhKernelState(NamedTuple):
    """Kernel state carried between steps of the random walk
    Metropolis-Hastings (RWMH) kernel.

    Attributes
    ----------
    scale: the random walk proposal scale. Usually a float, but it may be
        any value that broadcasts against the chain state.
    """

    scale: float

39 

40 

def rwmh(scale: float = 1.0):
    """Random walk Metropolis Hastings MCMC kernel

    Args:
        scale: the random walk scaling parameter. Should broadcast with the
            state. It is stored in the kernel state returned by ``init_fn``,
            and ``step_fn`` builds its proposal from the scale held in the
            kernel state it is given.

    Returns:
        an instance of :obj:`SamplingAlgorithm` wrapping this kernel's
        ``init_fn`` and ``step_fn``.
    """

    def _build_kernel(log_prob_fn, kernel_scale=scale):
        """Build a TFP random walk Metropolis kernel targeting ``log_prob_fn``.

        Proposals come from ``tfp.mcmc.random_walk_normal_fn`` with scale
        ``kernel_scale``, which defaults to the ``scale`` passed to ``rwmh``.
        """
        return tfp.mcmc.RandomWalkMetropolis(
            target_log_prob_fn=log_prob_fn,
            new_state_fn=tfp.mcmc.random_walk_normal_fn(scale=kernel_scale),
        )

    def init_fn(target_log_prob_fn, target_state):
        """Initialise the chain and kernel state at ``target_state``.

        Returns:
            ``(chain_state, kernel_state)``: a :class:`ChainState` holding the
            position and its log density, and a :class:`RwmhKernelState`
            holding ``scale``.
        """
        kernel = _build_kernel(target_log_prob_fn)
        results = kernel.bootstrap_results(target_state)

        chain_state = ChainState(
            position=target_state,
            log_density=results.accepted_results.target_log_prob,
            log_density_grad=(),  # no gradient is computed by this kernel
        )
        kernel_state = RwmhKernelState(scale=scale)

        return chain_state, kernel_state

    def step_fn(
        target_log_prob_fn: Callable[[NamedTuple], float],
        target_and_kernel_state: tuple[ChainState, RwmhKernelState],
        seed,
    ) -> tuple[tuple[ChainState, RwmhKernelState], RwmhInfo]:
        """Advance the chain by one random walk Metropolis-Hastings step.

        Args:
            target_log_prob_fn: log density of the target distribution.
            target_and_kernel_state: ``(ChainState, RwmhKernelState)`` pair,
                as returned by ``init_fn`` or by a previous ``step_fn`` call.
            seed: random seed forwarded to ``kernel.one_step``.

        Returns:
            ``((new_chain_state, kernel_state), info)``. The kernel state is
            passed through unchanged, and ``info`` is an :class:`RwmhInfo`
            with the acceptance indicator and the proposed state.
        """
        target_chain_state, kernel_state = target_and_kernel_state

        # Use the scale carried in the kernel state so that the stored value
        # is authoritative (init_fn stores the same ``scale`` as the closure).
        # This could be replaced with BlackJAX easily
        kernel = _build_kernel(target_log_prob_fn, kernel_state.scale)

        # Kernel results are re-bootstrapped from the current position on
        # every step (re-evaluating target_log_prob_fn) rather than reusing
        # target_chain_state.log_density. NOTE(review): presumably deliberate,
        # since target_log_prob_fn may differ between calls and make the
        # cached value stale — confirm before optimising this away.
        new_target_position, results = kernel.one_step(
            target_chain_state.position,
            kernel.bootstrap_results(target_chain_state.position),
            seed=seed,
        )

        new_chain_state = ChainState(
            position=new_target_position,
            log_density=results.accepted_results.target_log_prob,
            log_density_grad=(),  # no gradient is computed by this kernel
        )
        info = RwmhInfo(results.is_accepted, results.proposed_state)

        return (new_chain_state, kernel_state), info

    return SamplingAlgorithm(init_fn, step_fn)