Coverage for gemlib/mcmc/mcmc_sampler.py: 100%

22 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-12-02 22:09 +0000

1"""Higher-order functions to run MCMC""" 

2 

3from collections.abc import Callable, Iterable 

4from functools import partial 

5from typing import Any 

6 

7import jax 

8import jax.numpy as jnp 

9import tensorflow_probability.substrates.jax as tfp 

10 

11from .sampling_algorithm import SamplingAlgorithm 

12 

13__all__ = ["mcmc"] 

14 

15 

def _split_seed(seed, n):
    """Derive ``n`` seeds from ``seed``.

    Delegates to ``tfp.random.split_seed`` from the JAX substrate of
    TensorFlow Probability.
    """
    child_seeds = tfp.random.split_seed(seed, n=n)
    return child_seeds

18 

19 

def _scan(fn, init, xs):
    """Fold ``fn`` over the leading axis of ``xs``, collecting per-step outputs.

    Thin wrapper around ``jax.lax.scan``. In Haskell-style notation:

    ```
    scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])
    ```

    ``fn`` maps a carry ``c`` and one slice ``a`` of ``xs`` to a new carry
    and an output ``b``. The result is the final carry together with the
    outputs stacked along a new leading axis.
    """
    final_carry, stacked_outputs = jax.lax.scan(fn, init, xs)
    return final_carry, stacked_outputs

30 

31 

def mcmc(
    num_samples: int,
    sampling_algorithm: SamplingAlgorithm,
    target_density_fn: Callable[..., float],
    initial_position: Iterable,
    seed: tuple[int, int],
    kernel_kwargs_fn=lambda _: {},
):
    """Runs an MCMC using `sampling_algorithm`

    Args:
      num_samples: integer giving the number of MCMC updates to perform. One
        sample is recorded per update, so this is also the length of the
        leading axis of the returned trace.
      sampling_algorithm: an instance of `SamplingAlgorithm`. Its `init`
        method builds the initial sampler state and its `step` method
        advances that state by one update.
      target_density_fn: Python callable which takes an argument like
        `current_state` and returns its (possibly unnormalized) log-density
        under the target distribution.
      initial_position: initial state structured tuple (a pytree). Every
        leaf is converted to a JAX array before sampling starts.
      seed: a pair of two scalar ``int`` values, or any seed accepted by
        `tfp.random.split_seed`. It is required, and is split into one key
        per update.
      kernel_kwargs_fn: a callable returning a dictionary of extra keyword
        arguments for `sampling_algorithm.init` and `sampling_algorithm.step`.
        It is called with the initial position for `init`, and with the
        current sampler state for each `step`. The default passes no extra
        kwargs.

    Returns:
      A tuple `(positions, info)`. `positions` holds the chain position after
      each of the `num_samples` updates, stacked along a new leading axis.
      `info` holds the per-update information returned by
      `sampling_algorithm.step` (e.g. whether kernels accepted or rejected,
      adaptive covariance matrices, etc.), stacked the same way.
    """

    # Convert every leaf of the initial-position pytree to a JAX array.
    initial_position = jax.tree.map(jnp.asarray, initial_position)
    initial_state = sampling_algorithm.init(
        target_density_fn,
        initial_position,
        **kernel_kwargs_fn(initial_position),
    )
    # Bind the target density once; each step then only needs state and key.
    kernel_step_fn = partial(sampling_algorithm.step, target_density_fn)

    def one_step(state, rng_key):
        """Advance the chain by one update; emit (position, info) for the trace."""
        # NOTE(review): here `kernel_kwargs_fn` receives the full sampler
        # state, whereas `init` above receives the bare position. Confirm
        # which argument callers expect before relying on either.
        new_state, info = kernel_step_fn(
            state, rng_key, **kernel_kwargs_fn(state)
        )
        # Only the chain position (element 0's `.position`) and the step
        # info are recorded; the full state is carried to the next step.
        return new_state, (new_state[0].position, info)

    # One key per update; the scan consumes them along the leading axis.
    keys = _split_seed(seed, num_samples)

    # The final sampler state (scan carry) is discarded; only the stacked
    # per-update outputs are returned.
    _, trace = _scan(one_step, initial_state, keys)

    return trace