Coverage for gemlib/mcmc/sampling_algorithm.py: 90%
88 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-12-02 22:09 +0000
« prev ^ index » next coverage.py v7.10.3, created at 2025-12-02 22:09 +0000
1"""Base MCMC datatypes"""
3from __future__ import annotations
5from collections.abc import Callable
6from typing import Any, NamedTuple
8import tensorflow_probability.substrates.jax as tfp
# Module-level alias for TFP's seed splitter; KernelStepMonad.then uses it
# to derive a separate seed for each kernel in a chained step.
split_seed = tfp.random.split_seed

# Public names exported by ``from ... import *``.
__all__ = [
    "ChainState",
    "KernelState",
    "ChainAndKernelState",
    "LogProbFnType",
    "KernelInfo",
    "KernelInitFnType",
    "KernelStepFnType",
    "Position",
    "SamplingAlgorithm",
    "SeedType",
]

# Type aliases: both names are aliases of ``typing.NamedTuple`` itself.
# Presumably positions and per-kernel info records are NamedTuple
# instances -- confirm against the kernels that produce them.
Position = KernelInfo = NamedTuple
class ChainState(NamedTuple):
    """Represent the state of an MCMC probability space.

    Attributes:
        position: the point in the sampled space.
        log_density: the log density value carried with ``position``.
        log_density_grad: optional gradient information; defaults to the
            empty tuple ``()``. Presumably the gradient of ``log_density``
            with respect to ``position`` -- confirm with the kernels that
            populate it.
    """

    position: Position
    log_density: float
    # The default is ``()`` rather than ``None`` (kept as-is: changing it
    # would alter the tuple structure callers may rely on), so the
    # annotation must admit the empty tuple.
    log_density_grad: float | tuple[()] | None = ()
class KernelState(NamedTuple):
    """Field-less state record for a stateful MCMC kernel.

    Used as the annotated type of ``ChainAndKernelState.kernel_state``.
    """
class ChainAndKernelState(NamedTuple):
    """Pair a chain state with the state of the kernel advancing it.

    Attributes:
        chain_state: the current :class:`ChainState`.
        kernel_state: the kernel's own state.
    """

    chain_state: ChainState
    kernel_state: KernelState
# Seed accepted by a kernel step function: a pair of integers.
SeedType = tuple[int, int]

# Maps a position to a (scalar) log density value.
LogProbFnType = Callable[[Position], float]

# ``init_fn(target_log_prob_fn, initial_position) -> ChainAndKernelState``
KernelInitFnType = Callable[[LogProbFnType, Position], ChainAndKernelState]

# ``step_fn(target_log_prob_fn, chain_and_kernel_state, seed)
#     -> (ChainAndKernelState, KernelInfo)``
KernelStepFnType = Callable[
    [LogProbFnType, ChainAndKernelState, SeedType],
    tuple[ChainAndKernelState, KernelInfo],
]
63def _maybe_flatten(x: list[Any]):
64 """Flatten a list if `len(x) <= 1`"""
65 if len(x) == 0:
66 return None
67 if len(x) == 1:
68 return x[0]
69 return x
72def _squeeze(x: list[Any]):
73 if len(x) == 1:
74 return x[0]
75 return x
78def _maybe_list(x):
79 if isinstance(x, list):
80 return x
81 return [x]
84def _maybe_tuple(x):
85 if type(x) is tuple:
86 return x
87 return (x,)
class KernelInitMonad:
    """KernelInitMonad is a Writer monad allowing us to build an initial
    state tuple for a Metropolis-within-Gibbs algorithm.

    It wraps a kernel initialisation function. Composing two instances with
    :meth:`then` (or ``>>``) gives a new instance whose kernel state is
    ``self``'s kernel state(s) followed by the next kernel's state.
    """

    def __init__(self, fn: KernelInitFnType):
        """Wrap ``fn`` (the monad 'unit' function).

        Args:
            fn: a kernel initialisation function, called as
                ``fn(target_log_prob_fn, initial_position)``.
        """
        self._fn = fn

    def __call__(self, *args, **kwargs):
        """Run the wrapped initialisation function (the monad 'run')."""
        return self._fn(*args, **kwargs)

    def then(
        self, next_kernel_init_fn: KernelInitMonad | KernelInitFnType
    ) -> KernelInitMonad:
        """Monad combination, i.e. Haskell's fish operator ``>=>``.

        The returned initialiser calls ``self`` and then
        ``next_kernel_init_fn``, both with the same
        ``(target_log_prob_fn, initial_position)``. It returns the chain
        state produced by ``next_kernel_init_fn`` (``self``'s chain state is
        discarded) together with the list
        ``_maybe_list(self_kernel_state) + [next_kernel_state]``.

        Args:
            next_kernel_init_fn: the initialiser to run after ``self``.

        Returns:
            A new :class:`KernelInitMonad` wrapping the compound initialiser.
        """

        @KernelInitMonad
        def compound_init_fn(
            target_log_prob_fn: LogProbFnType,
            initial_position: Position,
        ) -> tuple[ChainState, list[Any]]:
            # Only self's kernel state is kept; the chain state comes from
            # the last initialiser in the composition.
            _, self_kernel_state = self(target_log_prob_fn, initial_position)
            next_chain_state, next_kernel_state = next_kernel_init_fn(
                target_log_prob_fn, initial_position
            )

            # A compound self already yields a list, so extending it keeps
            # left-nested compositions flat: (a >> b) >> c -> [ka, kb, kc].
            return (
                next_chain_state,
                _maybe_list(self_kernel_state) + [next_kernel_state],
            )

        return compound_init_fn

    def __rshift__(self, next_kernel: KernelInitMonad) -> KernelInitMonad:
        """Alias for :meth:`then`, enabling ``init_a >> init_b``."""
        return self.then(next_kernel)
class KernelStepMonad:
    """State monad wrapping an MCMC kernel step function so that kernels
    can be chained together with ``>>``.
    """

    def __init__(self, fn: KernelStepFnType):
        """Wrap the step function ``fn`` (the monad 'unit' function)."""
        self._fn = fn

    def __call__(self, *args, **kwargs):
        """Run the wrapped step function, forwarding all arguments."""
        return self._fn(*args, **kwargs)

    def __rshift__(self, next_kernel: KernelStepMonad):
        """Alias for :meth:`then`, enabling ``step_a >> step_b``."""
        return self.then(next_kernel)

    def then(self, next_kernel_fn: KernelStepMonad):
        """Chain ``next_kernel_fn`` after ``self`` (the monad 'bind').

        The compound step splits its seed in two, advances the chain with
        ``self`` and then with ``next_kernel_fn``. The kernel-state list is
        split so that its last entry belongs to ``next_kernel_fn`` and the
        remaining entries to ``self``. Kernel states and infos are returned
        as lists in the same order.
        """

        def compound_step_kernel(
            target_log_prob_fn: LogProbFnType,
            chain_and_kernel_state: ChainAndKernelState,
            seed: SeedType,
        ) -> tuple[ChainAndKernelState, KernelInfo]:
            first_seed, second_seed = split_seed(seed)
            chain_state, kernel_states = chain_and_kernel_state

            # Last entry -> next kernel; everything before it -> self
            # (unwrapped when self is a single, non-compound kernel).
            head_state = _squeeze(kernel_states[:-1])
            tail_state = kernel_states[-1]

            (chain_state, head_state), head_info = self._fn(
                target_log_prob_fn,
                (chain_state, head_state),
                seed=first_seed,
            )
            (chain_state, tail_state), tail_info = next_kernel_fn(
                target_log_prob_fn,
                (chain_state, tail_state),
                seed=second_seed,
            )

            merged_states = _maybe_list(head_state) + [tail_state]
            merged_info = _maybe_list(head_info) + [tail_info]
            return (chain_state, merged_states), merged_info

        return KernelStepMonad(compound_step_kernel)
class SamplingAlgorithm:
    """Represent a sampling algorithm as a pair of init and step functions.

    Algorithms compose sequentially with :meth:`then` (or ``>>``).
    """

    def __init__(
        self,
        init_fn: KernelInitFnType | KernelInitMonad,
        step_fn: KernelStepFnType | KernelStepMonad,
    ):
        """Create a new sampling algorithm

        Each argument is wrapped in its monad type independently, so an
        argument that is already a monad is stored as-is and never wrapped
        a second time (even when the other argument is a plain function).

        Args:
            init_fn: the kernel initialisation function
            step_fn: the kernel step function
        """
        self._init: KernelInitMonad = (
            init_fn
            if isinstance(init_fn, KernelInitMonad)
            else KernelInitMonad(init_fn)
        )
        self._step: KernelStepMonad = (
            step_fn
            if isinstance(step_fn, KernelStepMonad)
            else KernelStepMonad(step_fn)
        )

    def init(self, *args, **kwargs):
        """Initialize an MCMC chain; arguments go to the init function."""
        return self._init(*args, **kwargs)

    def step(self, *args, **kwargs):
        """Invoke the MCMC kernel; arguments go to the step function."""
        return self._step(*args, **kwargs)

    def then(self, next_kernel: SamplingAlgorithm):
        """Sequential combinator.

        Returns a new :class:`SamplingAlgorithm` whose init and step
        functions are those of ``self`` composed with those of
        ``next_kernel``.
        """
        return SamplingAlgorithm(
            init_fn=(self._init >> next_kernel._init),
            step_fn=(self._step >> next_kernel._step),
        )

    def __rshift__(self, next_kernel: SamplingAlgorithm):
        """Alias for :meth:`then`, enabling ``algo_a >> algo_b``."""
        return self.then(next_kernel)

    def __mul__(self, n: int):
        """Repeated application of a kernel (not implemented).

        Intended to invoke this algorithm ``n`` times, returning the state
        and info after the last step.

        Args:
            n: integer giving the number of updates

        Raises:
            NotImplementedError: always; use ``multi_scan`` instead.
        """
        # A disabled draft (``_repeat_sampling_algorithm``) is kept
        # commented out further down this module.
        raise NotImplementedError(
            "Not implemented. Please use `multi_scan` instead."
        )
244# def _repeat_sampling_algorithm(
245# num_updates: int, sampling_algorithm: SamplingAlgorithm
246# ) -> SamplingAlgorithm:
247# """Performs multiple applications of a kernel
249# :obj:`sampling_algorithm` is invoked :obj:`num_updates` times
250# returning the state and info after the last step.
252# Args:
253# num_updates: integer giving the number of updates
254# sampling_algorithm: an instance of :obj:`SamplingAlgorithm`
256# Returns:
257# An instance of :obj:`SamplingAlgorithm`
258# """
260# num_updates_ = tf.convert_to_tensor(num_updates)
262# def init_fn(target_log_prob_fn, position):
263# cs, ks = sampling_algorithm.init(target_log_prob_fn, position)
264# return cs, ks
266# def step_fn(
267# target_log_prob_fn: LogProbFnType,
268# current_state: tuple[Position, NamedTuple],
269# seed=None,
270# ):
271# seeds = tfp.random.split_seed(
272# seed, n=num_updates, salt="multi_scan_kernel"
273# )
274# step_fn = partial(sampling_algorithm.step, target_log_prob_fn)
276# def body(i, state, _):
277# state, info = step_fn(state, tf.gather(seeds, i, axis=-2))
278# return i + 1, state, info
280# def cond(i, *_):
281# return i < num_updates_
283# chain_state, kernel_state = current_state
285# init_state, init_info = step_fn(
286# (chain_state, kernel_state), seed
287# ) # unrolled first it
289# _, last_state, last_info = tf.while_loop(
290# cond, body, loop_vars=(1, init_state, init_info)
291# )
293# return last_state, last_info
295# return SamplingAlgorithm(init_fn, step_fn)