Coverage for gemlib/mcmc/mcmc_sampler.py: 100%
22 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-12-02 22:09 +0000
1"""Higher-order functions to run MCMC"""
3from collections.abc import Callable, Iterable
4from functools import partial
5from typing import Any
7import jax
8import jax.numpy as jnp
9import tensorflow_probability.substrates.jax as tfp
11from .sampling_algorithm import SamplingAlgorithm
# Public API of this module: only `mcmc` is exported; the `_`-prefixed
# helpers below are internal.
__all__ = ["mcmc"]
def _split_seed(seed, n):
    """Derive ``n`` PRNG seeds from a single ``seed``.

    A thin indirection over ``tfp.random.split_seed`` (JAX substrate), kept
    so the seed-splitting strategy lives in one place.

    Args:
        seed: the PRNG seed to split.
        n: how many derived seeds to produce.

    Returns:
        Whatever ``tfp.random.split_seed`` returns for ``(seed, n=n)``.
    """
    derived_seeds = tfp.random.split_seed(seed, n=n)
    return derived_seeds
def _scan(fn, init, xs):
    """Fold ``fn`` over the leading axis of ``xs`` while collecting outputs.

    Semantically this is

    ```
    scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])
    ```

    i.e. a left fold threading the carry ``c`` that also records each
    per-step output ``b``. It delegates to ``jax.lax.scan``, which stacks
    the collected outputs along a new leading axis.

    Args:
        fn: step function ``(carry, x) -> (new_carry, y)``.
        init: initial carry.
        xs: values to iterate over along their leading axis.

    Returns:
        A pair ``(final_carry, stacked_ys)``.
    """
    final_carry_and_outputs = jax.lax.scan(f=fn, init=init, xs=xs)
    return final_carry_and_outputs
def mcmc(
    num_samples: int,
    sampling_algorithm: SamplingAlgorithm,
    target_density_fn: Callable[..., float],
    initial_position: Iterable,
    seed: tuple[int, int],
    kernel_kwargs_fn: Callable[[Any], dict[str, Any]] = lambda _: {},
):
    """Runs an MCMC using `sampling_algorithm`.

    The chain is advanced `num_samples` times with `jax.lax.scan`, consuming
    one PRNG key per update.

    Args:
        num_samples: number of MCMC updates to perform. This is also the
            length of the leading axis of every array in the returned trace.
        sampling_algorithm: an instance of `SamplingAlgorithm`. Its `init`
            builds the initial sampler state and its `step` performs a
            single update.
        target_density_fn: Python callable which takes an argument like
            `current_state` and returns its (possibly unnormalized)
            log-density under the target distribution. It is forwarded to
            both `sampling_algorithm.init` and `sampling_algorithm.step`.
        initial_position: initial state structured tuple (a pytree). Every
            leaf is converted with `jnp.asarray` before use.
        seed: a PRNG seed (e.g. a pair of scalar ``int`` values). It is
            split into `num_samples` keys, one per update. This argument is
            required.
        kernel_kwargs_fn: a callable returning a dictionary of extra kwargs
            forwarded to `sampling_algorithm.init` and to every
            `sampling_algorithm.step` call. The default adds no extra
            kwargs. For `init` it is called with the initial position. For
            each `step` it is called with the full current sampler state,
            not only the chain position.

    Returns:
        A tuple ``(positions, info)``. ``positions`` holds the chain position
        after each update. ``info`` holds the per-update information returned
        by `sampling_algorithm.step` (e.g. whether kernels accepted or
        rejected, adaptive covariance matrices, etc). Both are stacked by
        `jax.lax.scan` along a new leading axis of length `num_samples`.
    """
    # Convert every leaf of the (possibly nested) initial position to a JAX
    # array so the sampling algorithm always receives array leaves.
    initial_position = jax.tree.map(jnp.asarray, initial_position)
    initial_state = sampling_algorithm.init(
        target_density_fn,
        initial_position,
        **kernel_kwargs_fn(initial_position),
    )
    # Bind the target density once; the scanned step then only threads state.
    kernel_step_fn = partial(sampling_algorithm.step, target_density_fn)

    def one_step(state, rng_key):
        # NOTE(review): here `kernel_kwargs_fn` receives the whole sampler
        # state, whereas `init` above passes the bare position. Confirm this
        # asymmetry is intended (vs. passing `state[0].position`).
        new_state, info = kernel_step_fn(
            state, rng_key, **kernel_kwargs_fn(state)
        )
        # The full state is carried to the next iteration. Only the position
        # (read from the state's first element) and `info` are emitted.
        return new_state, (new_state[0].position, info)

    # One key per update; the scan iterates over the keys' leading axis.
    keys = _split_seed(seed, num_samples)
    _, trace = _scan(one_step, initial_state, keys)
    return trace