Coverage for gemlib/mcmc/test_util.py: 100%

19 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-12-02 22:09 +0000

1"""Simple counting kernel for testing""" 

2 

3from typing import NamedTuple 

4 

5from .sampling_algorithm import ChainState, SamplingAlgorithm 

6 

# Names exported by ``from <module> import *``: the two kernel record types
# and the factory that builds the counting sampling algorithm.
__all__ = [
    "CountingKernelInfo",
    "CountingKernelState",
    "counting_kernel",
]

8 

9 

class CountingKernelState(NamedTuple):
    """Per-chain state carried by the counting kernel between steps."""

    # Number of times the kernel's step function has been applied so far
    # (starts at 0 in ``init_fn`` and is incremented by one per step).
    invocation: int

12 

13 

class CountingKernelInfo(NamedTuple):
    """Diagnostic record returned alongside each counting-kernel step."""

    # Whether the proposed move was accepted; the counting kernel always
    # reports ``True``.
    is_accepted: bool

16 

17 

def counting_kernel():
    """Build a deterministic "counting" sampling algorithm for tests.

    The resulting kernel uses no randomness: each step adds ``1.0`` to every
    field of the current position, increments an invocation counter, and
    reports the step as accepted.

    Returns:
        A ``SamplingAlgorithm`` built from this kernel's ``init_fn`` and
        ``step_fn``.
    """

    def _chain_state_at(log_prob_fn, where):
        # Evaluate the target density at ``where``, passing the position's
        # fields as keyword arguments; the gradient slot is left empty.
        return ChainState(
            position=where,
            log_density=log_prob_fn(**where._asdict()),
            log_density_grad=(),
        )

    def init_fn(target_log_prob_fn, position):
        """Return ``(chain_state, kernel_state)`` with the counter at zero."""
        return (
            _chain_state_at(target_log_prob_fn, position),
            CountingKernelState(0),
        )

    def step_fn(target_log_prob_fn, chain_and_kernel_state, seed):  # noqa: ARG001
        """Advance one step by shifting every position field by ``1.0``.

        ``seed`` is accepted to match the sampling-algorithm interface but is
        not used, because this kernel is deterministic.

        Returns:
            ``((new_chain_state, new_kernel_state), CountingKernelInfo(True))``.
        """
        current, counter = chain_and_kernel_state
        old_position = current.position

        # Rebuild a position of the same type with each field incremented.
        shifted_fields = {
            name: value + 1.0 for name, value in old_position._asdict().items()
        }
        moved_position = type(old_position)(**shifted_fields)

        next_states = (
            _chain_state_at(target_log_prob_fn, moved_position),
            CountingKernelState(counter.invocation + 1),
        )
        return next_states, CountingKernelInfo(True)

    return SamplingAlgorithm(init_fn, step_fn)