Coverage for gemlib/mcmc/multi_scan.py: 100%
28 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-12-02 22:09 +0000
« prev ^ index » next coverage.py v7.10.3, created at 2025-12-02 22:09 +0000
1"""MultiScanKernel calls one_step a number of times on an inner kernel"""
3from functools import partial
4from typing import NamedTuple
6import jax
8from .sampling_algorithm import LogProbFnType, Position, SamplingAlgorithm
# Names exported by ``from ... import *``: only the ``multi_scan`` factory;
# the state/info NamedTuples are implementation details.
__all__: list[str] = ["multi_scan"]
class MultiScanKernelState(NamedTuple):
    """Kernel state of the multi-scan algorithm.

    A thin wrapper that holds the inner sampling algorithm's kernel state
    between calls to the multi-scan ``step``.
    """

    # Inner kernel state: the second element returned by the inner
    # algorithm's ``init``, or the inner state after the most recent update.
    last_results: NamedTuple
class MultiScanKernelInfo(NamedTuple):
    """Step information reported by the multi-scan algorithm.

    A thin wrapper around the info produced by the final application of the
    inner sampling algorithm's ``step``.
    """

    # Info returned by the inner kernel on the last of its updates.
    last_results: NamedTuple
def multi_scan(
    num_updates: int, sampling_algorithm: SamplingAlgorithm
) -> SamplingAlgorithm:
    """Performs multiple applications of a kernel

    :obj:`sampling_algorithm` is invoked :obj:`num_updates` times,
    returning the state and info after the last step.

    Args:
        num_updates: integer giving the number of updates. Must be at least
            1, and concrete (not traced), because it is passed to
            ``jax.random.split`` as ``num``.
        sampling_algorithm: an instance of :obj:`SamplingAlgorithm`

    Returns:
        An instance of :obj:`SamplingAlgorithm`. Its ``init`` wraps the inner
        kernel state in :obj:`MultiScanKernelState`. Its ``step`` applies
        ``sampling_algorithm.step`` :obj:`num_updates` times, giving every
        application its own key split from ``seed``, and returns the info of
        the final application wrapped in :obj:`MultiScanKernelInfo`.

    Raises:
        ValueError: if ``num_updates`` is less than 1.
    """
    # warn(
    #     "Use of `multi_scan` is deprecated, and will be removed in future.\
    #  Instead, please make use of SamplingAlgorithm.__mul__.",
    #     DeprecationWarning,
    #     stacklevel=2,
    # )

    if num_updates < 1:
        raise ValueError(
            f"`num_updates` must be at least 1, got {num_updates}"
        )

    def init_fn(target_log_prob_fn, position):
        # Delegate to the inner algorithm; only its kernel state is wrapped.
        chain_state, kernel_state = sampling_algorithm.init(
            target_log_prob_fn, position
        )
        return chain_state, MultiScanKernelState(kernel_state)

    def step_fn(
        target_log_prob_fn: LogProbFnType,
        current_state: tuple[Position, MultiScanKernelState],
        seed=None,
    ):
        # `seed` must be a JAX PRNG key despite the `None` default: it is
        # split into one sub-key per inner update.  The parent key itself is
        # never handed to the inner kernel -- using a key *and* keys split
        # from it lets the inner kernel re-derive a key that a later update
        # also receives.
        seeds = jax.random.split(seed, num=num_updates)
        kernel_step = partial(sampling_algorithm.step, target_log_prob_fn)

        chain_state, kernel_state = current_state

        # The first update is unrolled: the loop carry must keep a fixed
        # structure, and it includes the inner kernel's info, which only
        # exists after one step.
        state, info = kernel_step(
            (chain_state, kernel_state.last_results), seeds[0]
        )

        def body(carry, step_seed):
            prev_state, _ = carry
            next_state, next_info = kernel_step(prev_state, step_seed)
            return (next_state, next_info), None

        # The remaining `num_updates - 1` updates.  With num_updates == 1,
        # `seeds[1:]` is empty and scan returns the carry unchanged.
        (last_state, last_info), _ = jax.lax.scan(
            body, (state, info), seeds[1:]
        )

        return (
            (last_state[0], MultiScanKernelState(last_state[1])),
            MultiScanKernelInfo(last_info),
        )

    return SamplingAlgorithm(init_fn, step_fn)