Coverage for gemlib/mcmc/mwg_step.py: 96%

54 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-12-02 22:09 +0000

1"""Implementation of Metropolis-within-Gibbs framework""" 

2 

3from __future__ import annotations 

4 

5from collections import namedtuple 

6from collections.abc import Callable 

7 

8import tensorflow_probability.substrates.jax as tfp 

9 

10from gemlib.mcmc.mcmc_util import is_list_like 

11from gemlib.mcmc.sampling_algorithm import ( 

12 ChainAndKernelState, 

13 ChainState, 

14 KernelInfo, 

15 LogProbFnType, 

16 Position, 

17 SamplingAlgorithm, 

18 SeedType, 

19) 

20 

# Public API of this module.
__all__ = ["MwgStep"]

# Module-level alias for the seed-splitting utility of TFP's JAX substrate.
# NOTE(review): not referenced elsewhere in this module as shown; presumably
# kept so other modules can import it from here — confirm before removing.
split_seed = tfp.random.split_seed

24 

25 

def as_list(x):
    """Normalise ``x`` to a list-like value.

    Values accepted by `is_list_like` are returned unchanged (no copy is
    made); anything else is wrapped in a single-element list.
    """
    return x if is_list_like(x) else [x]

30 

31 

def _make_target_type(target_names):
    """Return a constructor for the target sub-position of a chain.

    If `target_names` is list-like, the constructor is a freshly created
    namedtuple class named ``_Target`` with one field per name. Otherwise
    (a single name) the identity function is returned, so the target is
    passed around as the bare value rather than a 1-field tuple.
    """
    if not is_list_like(target_names):
        return lambda value: value  # identity: single target, no wrapping
    return namedtuple("_Target", target_names)

36 

37 

def _make_position_projector(target_names: str | list[str]):
    """Build a function splitting a position into target and complement parts.

    Args:
        target_names: a variable name, or a list of variable names, forming
            the target of a Metropolis-within-Gibbs step.

    Returns:
        A function taking a namedtuple-like `position` (anything exposing
        `_asdict()`) and returning a pair ``(target_tuple, target_compl_dict)``:
        the values of the target variables in `target_names` order, and a dict
        of all remaining variables keyed by name. The returned function raises
        `ValueError` if a name in `target_names` is absent from `position`.
    """
    # Snapshot the names so later mutation of the caller's list cannot
    # silently change the projection.
    target_names = tuple(as_list(target_names))
    # Set view for O(1) membership tests when collecting the complement.
    target_name_set = frozenset(target_names)

    def fn(position: Position) -> tuple[tuple, dict]:
        position_dict = position._asdict()

        for name in target_names:
            if name not in position_dict:
                raise ValueError(f"`{name}` is not present in `position`")

        target_tuple = tuple(position_dict[name] for name in target_names)
        target_compl_dict = {
            key: value
            for key, value in position_dict.items()
            if key not in target_name_set
        }

        return target_tuple, target_compl_dict

    return fn

59 

60 

class MwgStep:  # pylint: disable=too-few-public-methods
    """A Metropolis-within-Gibbs step.

    Transforms a base kernel to operate on a substate of a Markov chain.

    Args:
        sampling_algorithm: a named tuple containing the generic kernel `init`
            and `step` function.
        target_names: a variable name, or a list of variable names, on which
            the Metropolis-within-Gibbs step is to operate.
        kernel_kwargs_fn: a callable taking the (full) chain position as an
            argument, and returning a dictionary of extra kwargs passed to
            both `sampling_algorithm.init` and `sampling_algorithm.step`.

    Returns:
        An instance of SamplingAlgorithm.

    """

    def __new__(
        cls,
        sampling_algorithm: SamplingAlgorithm,
        target_names: str | list[str],
        kernel_kwargs_fn: Callable[[Position], dict] = lambda _: {},
    ):
        """Create a new Metropolis-within-Gibbs step"""

        target_names_list = as_list(target_names)
        _project_position = _make_position_projector(target_names_list)

        # Built from `target_names` (not the list form) so that a single name
        # yields the bare value via the identity rather than a 1-field tuple.
        TargetType = _make_target_type(target_names)

        def _conditional_log_prob_fn(
            target_log_prob_fn: LogProbFnType, target_compl: dict
        ):
            """Return `target_log_prob_fn` as a function of the targets only.

            The returned callable takes the target variables positionally, in
            `target_names_list` order, and evaluates the full log density with
            the complement variables held fixed at `target_compl`.
            """

            def conditional_tlp(*args):
                tlp_kwargs = (
                    dict(zip(target_names_list, args, strict=True))
                    | target_compl
                )
                return target_log_prob_fn(**tlp_kwargs)

            return conditional_tlp

        def init(
            target_log_prob_fn: LogProbFnType,
            initial_position: Position,
        ):
            """Initialise the wrapped kernel on the target substate."""
            target, target_compl = _project_position(initial_position)

            # The wrapped kernel returns a (chain state, kernel state) pair
            # for the target substate.
            sub_state = sampling_algorithm.init(
                _conditional_log_prob_fn(target_log_prob_fn, target_compl),
                TargetType(*target),
                **kernel_kwargs_fn(initial_position),
            )
            sub_chain_state, kernel_state = sub_state[0], sub_state[1]

            # The chain carries the full position, together with the log
            # density (and its gradient) reported by the wrapped kernel.
            chain_state = ChainState(
                position=initial_position,
                log_density=sub_chain_state.log_density,
                log_density_grad=sub_chain_state.log_density_grad,
            )

            return chain_state, kernel_state

        def step(
            target_log_prob_fn: LogProbFnType,
            chain_and_kernel_state: ChainAndKernelState,
            seed: SeedType,
        ) -> tuple[ChainAndKernelState, KernelInfo]:
            """Advance the target variables, holding the complement fixed."""
            chain_state, kernel_state = chain_and_kernel_state

            # Split the global position into target and complement parts
            target, target_compl = _project_position(chain_state.position)

            # Restrict the chain state to the target; `log_density` and
            # `log_density_grad` are carried over unchanged from the global
            # state.
            # NOTE(review): the carried-over gradient was computed with
            # respect to whichever variables the kernel that last wrote this
            # state targeted. If that differs from `target_names` (e.g. when
            # several MwgSteps are composed), confirm the wrapped kernel does
            # not rely on it, or that it is recomputed upstream.
            chain_substate = chain_state._replace(position=TargetType(*target))

            # Invoke the kernel on the target state
            (new_chain_substate, new_kernel_state), info = (
                sampling_algorithm.step(
                    _conditional_log_prob_fn(target_log_prob_fn, target_compl),
                    (chain_substate, kernel_state),
                    seed,
                    **kernel_kwargs_fn(chain_state.position),
                )
            )

            # Stitch the global position back together: the new target
            # values (wrapped in a list by `as_list` when the kernel returned
            # a single non-list-like value) plus the unchanged complement.
            new_position_dict = dict(
                zip(
                    target_names_list,
                    as_list(new_chain_substate.position),
                    strict=True,
                )
            )
            new_global_state = new_chain_substate._replace(
                position=chain_state.position.__class__(
                    **(new_position_dict | target_compl)
                )
            )

            return (new_global_state, new_kernel_state), info

        return SamplingAlgorithm(init, step)