Coverage for gemlib/mcmc/multi_scan.py: 100%

28 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-12-02 22:09 +0000

1"""MultiScanKernel calls one_step a number of times on an inner kernel""" 

2 

3from functools import partial 

4from typing import NamedTuple 

5 

6import jax 

7 

8from .sampling_algorithm import LogProbFnType, Position, SamplingAlgorithm 

9 

10__all__ = ["multi_scan"] 

11 

12 

class MultiScanKernelState(NamedTuple):
    """Kernel state carried between calls of the multi-scan kernel.

    A single-field wrapper around the kernel state of the wrapped
    (inner) sampling algorithm.
    """

    # Kernel state produced by the most recent inner-kernel init/step.
    last_results: NamedTuple

15 

16 

class MultiScanKernelInfo(NamedTuple):
    """Diagnostic info returned by one step of the multi-scan kernel.

    A single-field wrapper around the info reported by the wrapped
    (inner) sampling algorithm.
    """

    # Info emitted by the final inner-kernel step of the scan.
    last_results: NamedTuple

19 

20 

def multi_scan(
    num_updates: int, sampling_algorithm: SamplingAlgorithm
) -> SamplingAlgorithm:
    """Performs multiple applications of a kernel

    :obj:`sampling_algorithm` is invoked :obj:`num_updates` times
    returning the state and info after the last step.  The first
    invocation is unrolled outside the loop, so the inner kernel is
    always applied at least once, even if :obj:`num_updates` is less
    than 1.

    Each inner update is driven by its own sub-key obtained by splitting
    the ``seed`` passed to the returned algorithm's step function; the
    parent ``seed`` itself is never handed to the inner kernel.

    Args:
      num_updates: integer giving the number of updates
      sampling_algorithm: an instance of :obj:`SamplingAlgorithm`

    Returns:
      An instance of :obj:`SamplingAlgorithm`
    """
    # NOTE: a DeprecationWarning steering users towards
    # ``SamplingAlgorithm.__mul__`` has been drafted for this function but
    # is currently disabled.

    def init_fn(target_log_prob_fn, position):
        # Initialise the inner kernel and wrap its kernel state.
        chain_state, inner_kernel_state = sampling_algorithm.init(
            target_log_prob_fn, position
        )
        return chain_state, MultiScanKernelState(inner_kernel_state)

    def step_fn(
        target_log_prob_fn: LogProbFnType,
        current_state: tuple[Position, MultiScanKernelState],
        seed=None,
    ):
        # One sub-key per inner update.  At least one key is always needed
        # because the first update below is unrolled unconditionally.
        seeds = jax.random.split(seed, num=max(num_updates, 1))
        inner_step = partial(sampling_algorithm.step, target_log_prob_fn)

        def body(carry):
            i, state, _ = carry
            next_state, next_info = inner_step(state, seeds[i])
            return i + 1, next_state, next_info

        def cond(carry):
            return carry[0] < num_updates

        chain_state, kernel_state = current_state

        # Unrolled first iteration: it supplies the initial `info` value
        # that the loop carry needs.  It uses seeds[0] rather than the
        # parent `seed`, which has already been split and must not be
        # consumed again.
        first_state, first_info = inner_step(
            (chain_state, kernel_state.last_results), seeds[0]
        )

        # Remaining updates 1 .. num_updates - 1.
        _, last_state, last_info = jax.lax.while_loop(
            cond, body, init_val=(1, first_state, first_info)
        )

        return (
            (last_state[0], MultiScanKernelState(last_state[1])),
            MultiScanKernelInfo(last_info),
        )

    return SamplingAlgorithm(init_fn, step_fn)