Coverage for gemlib/mcmc/test_util.py: 100%
19 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-12-02 22:09 +0000
« prev ^ index » next coverage.py v7.10.3, created at 2025-12-02 22:09 +0000
1"""Simple counting kernel for testing"""
3from typing import NamedTuple
5from .sampling_algorithm import ChainState, SamplingAlgorithm
7__all__ = ["CountingKernelInfo", "CountingKernelState", "counting_kernel"]
class CountingKernelState(NamedTuple):
    """Kernel-side state carried between steps of the counting kernel.

    A fresh state is created with ``invocation=0`` by the kernel's
    ``init_fn``; each call to ``step_fn`` returns a new state whose
    ``invocation`` is one greater than the state it was given.
    """

    # Number of times the kernel's step function has been applied.
    invocation: int
class CountingKernelInfo(NamedTuple):
    """Per-step diagnostic record emitted by the counting kernel.

    The counting kernel's ``step_fn`` always reports
    ``is_accepted=True``.
    """

    # Whether the step was accepted.
    is_accepted: bool
def counting_kernel():
    """Build a deterministic "counting" sampling algorithm for tests.

    The returned ``SamplingAlgorithm`` wraps two closures:

    * ``init_fn`` evaluates the target log density at the supplied
      position and starts the invocation counter at zero.
    * ``step_fn`` adds ``1.0`` to every field of the position,
      re-evaluates the target log density there, and increments the
      invocation counter by one.

    Positions must expose ``_asdict()`` (as namedtuples do) and be
    re-constructible from keyword arguments, and the target log-prob
    function is called with the position's fields as keyword arguments.

    Returns:
        ``SamplingAlgorithm(init_fn, step_fn)``.
    """

    def init_fn(target_log_prob_fn, position):
        """Return the initial ``(ChainState, CountingKernelState)`` pair."""
        initial_log_density = target_log_prob_fn(**position._asdict())
        initial_chain = ChainState(
            position=position,
            log_density=initial_log_density,
            log_density_grad=(),  # no gradient is computed by this kernel
        )
        return initial_chain, CountingKernelState(0)

    def step_fn(target_log_prob_fn, chain_and_kernel_state, seed):  # noqa: ARG001
        """Advance one step; ``seed`` is accepted but never used.

        Returns ``((ChainState, CountingKernelState), CountingKernelInfo)``.
        """
        current_chain, current_kernel = chain_and_kernel_state

        # Shift every field of the position by exactly 1.0 and rebuild a
        # position of the same class from the shifted fields.
        shifted_fields = {}
        for field_name, field_value in current_chain.position._asdict().items():
            shifted_fields[field_name] = field_value + 1.0
        position_cls = current_chain.position.__class__
        next_position = position_cls(**shifted_fields)

        next_log_density = target_log_prob_fn(**next_position._asdict())
        next_chain = ChainState(
            position=next_position,
            log_density=next_log_density,
            log_density_grad=(),  # no gradient is computed by this kernel
        )
        next_kernel = CountingKernelState(current_kernel.invocation + 1)

        # The move is always reported as accepted.
        return (next_chain, next_kernel), CountingKernelInfo(True)

    return SamplingAlgorithm(init_fn, step_fn)