Coverage for gemlib/mcmc/mwg_step.py: 96%
54 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-12-02 22:09 +0000
« prev ^ index » next coverage.py v7.10.3, created at 2025-12-02 22:09 +0000
1"""Implementation of Metropolis-within-Gibbs framework"""
3from __future__ import annotations
5from collections import namedtuple
6from collections.abc import Callable
8import tensorflow_probability.substrates.jax as tfp
10from gemlib.mcmc.mcmc_util import is_list_like
11from gemlib.mcmc.sampling_algorithm import (
12 ChainAndKernelState,
13 ChainState,
14 KernelInfo,
15 LogProbFnType,
16 Position,
17 SamplingAlgorithm,
18 SeedType,
19)
# Names exported by ``from ... import *``.
__all__ = ["MwgStep"]

# Convenience alias for the seed splitter of TFP's JAX substrate.
split_seed = tfp.random.split_seed
def as_list(x):
    """Return ``x`` itself if it is list-like, else ``[x]``.

    List-like inputs (as judged by ``is_list_like``) are returned as-is,
    not copied.
    """
    return x if is_list_like(x) else [x]
def _make_target_type(target_names):
    """Return a constructor for the position the base kernel operates on.

    If ``target_names`` is not list-like (a single name), the constructor is
    the identity function, so the lone target value is passed through
    unwrapped.  Otherwise it is a namedtuple class named ``_Target`` with
    one field per entry of ``target_names``.
    """
    if not is_list_like(target_names):
        return lambda x: x  # identity: single target is left bare
    return namedtuple("_Target", target_names)
def _make_position_projector(target_names: list[str]):
    """Build a function that splits a position into target and complement.

    Args:
        target_names: a name, or list of names, of the target fields.

    Returns:
        A function taking a namedtuple-like ``position`` (it must provide
        ``_asdict``) and returning ``(target_values, complement)``.
        ``target_values`` is a tuple ordered like ``target_names``, and
        ``complement`` is a dict holding every other field of ``position``.
        The returned function raises ``ValueError`` naming the first target
        that is absent from ``position``.
    """
    names = as_list(target_names)

    def fn(position: Position) -> tuple[tuple, dict]:
        fields = position._asdict()

        missing = [n for n in names if n not in fields]
        if missing:
            raise ValueError(f"`{missing[0]}` is not present in `position`")

        target_values = tuple(fields[n] for n in names)
        complement = {
            key: value for key, value in fields.items() if key not in names
        }
        return target_values, complement

    return fn
class MwgStep:  # pylint: disable=too-few-public-methods
    """A Metropolis-within-Gibbs step.

    Transforms a base kernel to operate on a substate of a Markov chain.

    Args:
        sampling_algorithm: a named tuple containing the generic kernel `init`
            and `step` function.
        target_names: a variable name, or a list of variable names, on which
            the Metropolis-within-Gibbs step is to operate.  With a single
            (non-list-like) name the base kernel sees the bare variable; with
            a list it sees a namedtuple with one field per name.
        kernel_kwargs_fn: a callable taking the chain position as an argument,
            and returning a dictionary of extra kwargs passed to both
            `sampling_algorithm.init` and `sampling_algorithm.step`.

    Returns:
        An instance of SamplingAlgorithm.  Note that calling ``MwgStep(...)``
        does not return an ``MwgStep`` instance.
    """

    def __new__(
        cls,
        sampling_algorithm: SamplingAlgorithm,
        target_names: str | list[str],
        kernel_kwargs_fn: Callable[[Position], dict] = lambda _: {},
    ):
        """Create a new Metropolis-within-Gibbs step."""

        target_names_list = as_list(target_names)
        # Decided with the same test `_make_target_type` uses, so the
        # wrapping done by `TargetType` is undone consistently in `step`.
        is_single_target = not is_list_like(target_names)
        _project_position = _make_position_projector(target_names_list)
        TargetType = _make_target_type(target_names)

        def _make_conditional_tlp(target_log_prob_fn, target_compl):
            """Return `target_log_prob_fn` as a function of the targets only.

            The non-target variables are held fixed at `target_compl`.
            Positional args are matched to `target_names_list` in order.
            """

            def conditional_tlp(*args):
                tlp_kwargs = (
                    dict(zip(target_names_list, args, strict=True))
                    | target_compl
                )
                return target_log_prob_fn(**tlp_kwargs)

            return conditional_tlp

        def _unpack_target(target_position) -> tuple:
            """Flatten the base kernel's position into a tuple of values.

            The values are ordered like `target_names_list`.
            """
            if is_single_target:
                # Never iterate into a bare value: it may itself be
                # list-like, which would break the name/value pairing.
                return (target_position,)
            return tuple(target_position)

        def init(
            target_log_prob_fn: LogProbFnType,
            initial_position: Position,
        ):
            target, target_compl = _project_position(initial_position)

            inner_chain_state, inner_kernel_state = sampling_algorithm.init(
                _make_conditional_tlp(target_log_prob_fn, target_compl),
                TargetType(*target),
                **kernel_kwargs_fn(initial_position),
            )

            # The global chain state keeps the full position.  The log
            # density and gradient are taken from the base kernel's state.
            chain_state = ChainState(
                position=initial_position,
                log_density=inner_chain_state.log_density,
                log_density_grad=inner_chain_state.log_density_grad,
            )

            return chain_state, inner_kernel_state

        def step(
            target_log_prob_fn: LogProbFnType,
            chain_and_kernel_state: ChainAndKernelState,
            seed: SeedType,
        ) -> tuple[ChainAndKernelState, KernelInfo]:
            chain_state, kernel_state = chain_and_kernel_state
            global_position = chain_state.position

            # Split global state and build the conditional density
            target, target_compl = _project_position(global_position)
            conditional_tlp = _make_conditional_tlp(
                target_log_prob_fn, target_compl
            )
            chain_substate = chain_state._replace(position=TargetType(*target))

            # Invoke the kernel on the target state
            (new_chain_substate, new_kernel_state), info = (
                sampling_algorithm.step(
                    conditional_tlp,
                    (chain_substate, kernel_state),
                    seed,
                    **kernel_kwargs_fn(global_position),
                )
            )

            # Stitch the global position back together
            new_position_dict = dict(
                zip(
                    target_names_list,
                    _unpack_target(new_chain_substate.position),
                    strict=True,
                )
            )
            new_position = type(global_position)(
                **(new_position_dict | target_compl)
            )
            new_global_state = new_chain_substate._replace(
                position=new_position
            )

            return (new_global_state, new_kernel_state), info

        return SamplingAlgorithm(init, step)