Coverage for gemlib/mcmc/sampling_algorithm.py: 90%

88 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-12-02 22:09 +0000

1"""Base MCMC datatypes""" 

2 

3from __future__ import annotations 

4 

5from collections.abc import Callable 

6from typing import Any, NamedTuple 

7 

8import tensorflow_probability.substrates.jax as tfp 

9 

# Module-level shorthand for TFP's (JAX substrate) seed-splitting helper; used
# by the compound step kernel below to derive one seed per chained kernel.
split_seed = tfp.random.split_seed

11 

12 

# Names exported by a star-import of this module.  The monad classes and the
# underscore-prefixed helpers defined below are not listed here.
__all__ = [
    "ChainState",
    "KernelState",
    "ChainAndKernelState",
    "LogProbFnType",
    "KernelInfo",
    "KernelInitFnType",
    "KernelStepFnType",
    "Position",
    "SamplingAlgorithm",
    "SeedType",
]

25 

26 

# Type aliases
# Both names are bound to `typing.NamedTuple` itself and are only used in
# annotations in this module -- presumably as a stand-in for "any named-tuple
# structure" rather than a concrete class (confirm against callers).
Position = NamedTuple
KernelInfo = NamedTuple

30 

31 

class ChainState(NamedTuple):
    """Represent the state of an MCMC probability space

    Attributes:
        position: the position of the chain.
        log_density: the log density stored alongside `position`
            (presumably the target log density evaluated at `position` --
            confirm against the kernels that build this tuple).
        log_density_grad: optional gradient information; defaults to an
            empty tuple.
    """

    position: Position
    log_density: float
    # NOTE(review): the default is `()` although the annotation reads
    # `float | None`.  Presumably the empty tuple is deliberate (an empty
    # structure rather than `None`) -- confirm before changing either side.
    log_density_grad: float | None = ()

38 

39 

class KernelState(NamedTuple):
    """Field-less named tuple standing for the state of a stateful MCMC
    kernel.
    """

44 

45 

class ChainAndKernelState(NamedTuple):
    """Pair a chain state with a kernel state.

    Attributes:
        chain_state: the `ChainState` part of the pair.
        kernel_state: the `KernelState` part of the pair.
    """

    chain_state: ChainState
    kernel_state: KernelState

49 

50 

# Log-probability function: takes a position, returns a float.
LogProbFnType = Callable[[Position], float]

# Kernel initialiser: (log-prob function, position) -> chain-and-kernel state.
KernelInitFnType = Callable[[LogProbFnType, Position], ChainAndKernelState]

# Random seed, typed as a pair of ints.
SeedType = tuple[int, int]

# Kernel step: (log-prob function, current state, seed) -> (new state, info).
KernelStepFnType = Callable[
    [LogProbFnType, ChainAndKernelState, SeedType],
    tuple[ChainAndKernelState, KernelInfo],
]

61 

62 

63def _maybe_flatten(x: list[Any]): 

64 """Flatten a list if `len(x) <= 1`""" 

65 if len(x) == 0: 

66 return None 

67 if len(x) == 1: 

68 return x[0] 

69 return x 

70 

71 

72def _squeeze(x: list[Any]): 

73 if len(x) == 1: 

74 return x[0] 

75 return x 

76 

77 

78def _maybe_list(x): 

79 if isinstance(x, list): 

80 return x 

81 return [x] 

82 

83 

84def _maybe_tuple(x): 

85 if type(x) is tuple: 

86 return x 

87 return (x,) 

88 

89 

class KernelInitMonad:
    """KernelInitMonad is a Writer monad allowing us to build an initial
    state tuple for a Metropolis-within-Gibbs algorithm.

    It wraps a kernel initialisation function so that several initialisers
    can be chained with `then` / `>>`, accumulating one kernel state per
    chained kernel.
    """

    def __init__(self, fn: KernelInitFnType):
        """The monad 'unit' function

        Args:
            fn: the initialisation function to wrap; it is invoked by
                `__call__` with exactly the arguments the monad receives.
        """
        self._fn = fn

    def __call__(self, *args, **kwargs):
        """Monad ``run'' function: forward all arguments to the wrapped
        function and return its result.
        """
        return self._fn(*args, **kwargs)

    def then(self, next_kernel_init_fn: KernelInitMonad) -> KernelInitMonad:
        """Monad combination, i.e. Haskell fish operator

        Args:
            next_kernel_init_fn: the initialiser to combine with this one.

        Returns:
            A new `KernelInitMonad`.  When called with
            `(target_log_prob_fn, initial_position)` it runs this
            initialiser and `next_kernel_init_fn` on the same arguments and
            returns the chain state produced by `next_kernel_init_fn`
            together with a list of kernel states (this initialiser's
            state(s) followed by the next one's).
        """

        @KernelInitMonad
        def compound_init_fn(
            target_log_prob_fn: LogProbFnType,
            initial_position: Position,
        ) -> ChainAndKernelState:
            # Only the kernel state of this initialiser is kept; its chain
            # state is discarded in favour of the one returned by the next
            # initialiser, which receives the same arguments.
            _, self_kernel_state = self(target_log_prob_fn, initial_position)
            next_chain_state, next_kernel_state = next_kernel_init_fn(
                target_log_prob_fn, initial_position
            )

            # If `self` is already a compound initialiser its kernel state is
            # a list, so `_maybe_list` keeps the accumulated list flat.
            return (
                next_chain_state,
                _maybe_list(self_kernel_state) + [next_kernel_state],
            )

        return compound_init_fn

    def __rshift__(self, next_kernel: KernelInitMonad) -> KernelInitMonad:
        """Operator form of `then`: `a >> b` is `a.then(b)`."""
        return self.then(next_kernel)

125 

126 

class KernelStepMonad:
    """State monad around an MCMC kernel step function, allowing kernels to
    be chained together.
    """

    def __init__(self, fn: KernelStepFnType):
        """The monad 'unit' function: wrap the step function `fn`."""
        self._fn = fn

    def __call__(self, *args, **kwargs):
        """Run the wrapped state transformer, forwarding every argument."""
        return self._fn(*args, **kwargs)

    def then(self, next_kernel_fn: KernelStepMonad):
        """The monad 'bind' operator which allows chaining.
        ma >> mb :: ma -> mb -> mc

        The returned monad steps this kernel first and `next_kernel_fn`
        second, each with its own seed obtained from `split_seed`.
        """

        def compound_step_kernel(
            target_log_prob_fn: LogProbFnType,
            chain_and_kernel_state: ChainAndKernelState,
            seed: SeedType,
        ) -> tuple[ChainAndKernelState, KernelInfo]:
            first_seed, second_seed = split_seed(seed)

            chain, kernel_states = chain_and_kernel_state

            # The last entry is the next kernel's state; everything before it
            # belongs to this kernel (unwrapped when it is a single entry).
            first_state = _squeeze(kernel_states[:-1])
            second_state = kernel_states[-1]

            first_result, first_info = self._fn(
                target_log_prob_fn,
                (chain, first_state),
                seed=first_seed,
            )
            chain, first_state = first_result

            # The second kernel starts from the chain state left by the first.
            second_result, second_info = next_kernel_fn(
                target_log_prob_fn,
                (chain, second_state),
                seed=second_seed,
            )
            chain, second_state = second_result

            merged_states = _maybe_list(first_state) + [second_state]
            merged_info = _maybe_list(first_info) + [second_info]
            return (chain, merged_states), merged_info

        return KernelStepMonad(compound_step_kernel)

    def __rshift__(self, next_kernel: KernelStepMonad):
        """Operator form of `then`: `a >> b` is `a.then(b)`."""
        return self.then(next_kernel)

182 

183 

class SamplingAlgorithm:
    """Represent a sampling algorithm

    A sampling algorithm pairs a kernel initialisation function with a
    kernel step function.  Both are held as monads so that algorithms can
    be composed sequentially with `then` / `>>`.
    """

    def __init__(
        self,
        init_fn: KernelInitFnType | KernelInitMonad,
        step_fn: KernelStepFnType | KernelStepMonad,
    ):
        """Create a new sampling algorithm

        Args:
            init_fn: the kernel initialisation function
            step_fn: the kernel step function
        """
        # Normalise each argument on its own: an argument that is already a
        # monad is stored as-is, a plain callable is wrapped.  This avoids
        # re-wrapping a monad when only one of the two arguments is one.
        self._init: KernelInitMonad = (
            init_fn
            if isinstance(init_fn, KernelInitMonad)
            else KernelInitMonad(init_fn)
        )
        self._step: KernelStepMonad = (
            step_fn
            if isinstance(step_fn, KernelStepMonad)
            else KernelStepMonad(step_fn)
        )

    def init(self, *args, **kwargs):
        """Initialize an MCMC chain

        All arguments are forwarded to the initialisation function.
        """
        return self._init(*args, **kwargs)

    def step(self, *args, **kwargs):
        """Function to invoke the MCMC kernel

        All arguments are forwarded to the step function.
        """
        return self._step(*args, **kwargs)

    def then(self, next_kernel: SamplingAlgorithm) -> SamplingAlgorithm:
        """Sequential combinator

        Args:
            next_kernel: the algorithm to run after this one.

        Returns:
            A new :obj:`SamplingAlgorithm` whose init and step functions are
            this algorithm's chained (`>>`) with those of `next_kernel`.
        """
        return SamplingAlgorithm(
            init_fn=(self._init >> next_kernel._init),
            step_fn=(self._step >> next_kernel._step),
        )

    def __rshift__(self, next_kernel: SamplingAlgorithm) -> SamplingAlgorithm:
        """Operator form of `then`: `a >> b` is `a.then(b)`."""
        return self.then(next_kernel)

    def __mul__(self, n: int):
        """Repeated application of a kernel (not implemented)

        Args:
            n: integer giving the number of updates

        Raises:
            NotImplementedError: always; the message directs callers to
                `multi_scan` instead.
        """
        raise NotImplementedError(
            "Not implemented. Please use `multi_scan` instead."
        )

241 # return _repeat_sampling_algorithm(n, self) 

242 

243 

244# def _repeat_sampling_algorithm( 

245# num_updates: int, sampling_algorithm: SamplingAlgorithm 

246# ) -> SamplingAlgorithm: 

247# """Performs multiple applications of a kernel 

248 

249# :obj:`sampling_algorithm` is invoked :obj:`num_updates` times 

250# returning the state and info after the last step. 

251 

252# Args: 

253# num_updates: integer giving the number of updates 

254# sampling_algorithm: an instance of :obj:`SamplingAlgorithm` 

255 

256# Returns: 

257# An instance of :obj:`SamplingAlgorithm` 

258# """ 

259 

260# num_updates_ = tf.convert_to_tensor(num_updates) 

261 

262# def init_fn(target_log_prob_fn, position): 

263# cs, ks = sampling_algorithm.init(target_log_prob_fn, position) 

264# return cs, ks 

265 

266# def step_fn( 

267# target_log_prob_fn: LogProbFnType, 

268# current_state: tuple[Position, NamedTuple], 

269# seed=None, 

270# ): 

271# seeds = tfp.random.split_seed( 

272# seed, n=num_updates, salt="multi_scan_kernel" 

273# ) 

274# step_fn = partial(sampling_algorithm.step, target_log_prob_fn) 

275 

276# def body(i, state, _): 

277# state, info = step_fn(state, tf.gather(seeds, i, axis=-2)) 

278# return i + 1, state, info 

279 

280# def cond(i, *_): 

281# return i < num_updates_ 

282 

283# chain_state, kernel_state = current_state 

284 

285# init_state, init_info = step_fn( 

286# (chain_state, kernel_state), seed 

287# ) # unrolled first it 

288 

289# _, last_state, last_info = tf.while_loop( 

290# cond, body, loop_vars=(1, init_state, init_info) 

291# ) 

292 

293# return last_state, last_info 

294 

295# return SamplingAlgorithm(init_fn, step_fn)