Coverage for gemlib/mcmc/transformed_sampling_algorithm.py: 100%

41 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-12-02 22:09 +0000

1"""Transform a sampling algorithm""" 

2 

3from typing import Callable, Iterable # noqa: UP035 

4 

5import jax 

6import jax.numpy as jnp 

7 

8from gemlib.mcmc.mcmc_util import is_list_like 

9 

10from .sampling_algorithm import ( 

11 ChainAndKernelState, 

12 ChainState, 

13 LogProbFnType, 

14 Position, 

15 SamplingAlgorithm, 

16 SeedType, 

17) 

18 

19 

20def map_structure(fn, first, *rest): 

21 """Map `fn` over `first` and `rest`. 

22 

23 This function relaxes :code:`jax.tree.map`'s strict requirement 

24 that `first` and `*rest` have exactly equal structures, providing 

25 for situations where duck-typed tree topology suffices. In this respect, 

26 it behaves like `tf.nest.map_structure` where `check_types=False`. However, 

27 it _only_ requires that the flat representations of each tree are 

28 compatible, and will therefore not raise an error if the arity of the nodes 

29 differ. 

30 """ 

31 

32 def is_leaf(x): 

33 return not isinstance(x, list | tuple) 

34 

35 flat_first, first_treedef = jax.tree.flatten(first, is_leaf=is_leaf) 

36 flat_rest = [jax.tree.leaves(x, is_leaf=is_leaf) for x in rest] 

37 result = tuple( 

38 fn(*args) for args in zip(flat_first, *flat_rest, strict=True) 

39 ) 

40 

41 return jax.tree.unflatten(first_treedef, result) 

42 

43 

def _transform_tlp_fn(
    bijectors: list, target_log_prob_fn: Callable, disable_bijector_caching=True
):
    """Build the log-density function acting on the transformed space.

    Args:
      bijectors: a bijector, or a list-like structure of bijectors, each
        providing `forward` and `inverse_log_det_jacobian`.  A value that
        is not list-like is wrapped in a one-element list.
      target_log_prob_fn: log-density function on the untransformed space,
        called with the untransformed position parts as positional
        arguments.
      disable_bijector_caching: if true, every part of the incoming
        position is passed through an identity map before the bijectors
        are applied.

    Returns:
      A function of `*transformed_position` that returns
      `target_log_prob_fn` evaluated at the forward-transformed position,
      minus the sum of the bijectors' inverse log-det-Jacobians evaluated
      at that forward-transformed position.
    """
    bijector_struct = bijectors if is_list_like(bijectors) else [bijectors]

    def _forward(value, bij):
        return bij.forward(value)

    def _ildj(value, bij):
        return bij.inverse_log_det_jacobian(value)

    def fn(*transformed_position):
        if disable_bijector_caching:
            # Identity map over the leaves; the values are unchanged.
            # NOTE(review): presumably kept to defeat bijector caching
            # keyed on the input objects -- confirm it still has an effect.
            transformed_position = map_structure(
                lambda leaf: leaf, transformed_position
            )

        constrained_position = map_structure(
            _forward, transformed_position, bijector_struct
        )

        # Evaluate the target first, then the Jacobian correction.
        target_value = target_log_prob_fn(*constrained_position)
        ildj_terms = map_structure(
            _ildj, constrained_position, bijector_struct
        )

        return target_value - jnp.sum(jnp.asarray(ildj_terms))

    return fn

75 

76 

def transform_sampling_algorithm(
    bijectors: Iterable, sampling_algorithm: SamplingAlgorithm
):
    r"""Wrap a sampling algorithm so that it runs on a transformed space.

    The wrapped `sampling_algorithm` operates on the probability measure
    obtained by transforming the target with `bijectors`.  This is
    particularly useful for unbounding parameter spaces so that algorithms
    such as Hamiltonian Monte Carlo or the No-U-Turn Sampler (NUTS) can be
    used.

    The density is adjusted with the change-of-variables formula: for
    :math:`Y = g(X)`,

    .. math::

      f_Y(y)=f_X(g^{-1}(y))\left|\frac{\mathrm{d}g^{-1}(y)}{\mathrm{d}y}\right|

    Args:
      bijectors: a structure of TensorFlow Probability bijectors compatible
        with `position`
      sampling_algorithm: the sampling algorithm to run on the transformed
        space

    Returns:
      A new :obj:`SamplingAlgorithm` representing a kernel working on the
      transformed space.  The chain states it returns hold the position and
      log density on the original (constrained) space, with
      `log_density_grad` set to `()`.

    Examples:
      Instantiate a transformed hmc kernel::

        import tensorflow_probability as tfp
        from gemlib.mcmc import transform_sampling_algorithm
        from gemlib.mcmc import hmc

        kernel = transform_sampling_algorithm(
            bijectors=[tfp.bijectors.Exp(), tfp.bijectors.Exp()],
            sampling_algorithm=hmc(step_size=0.1, num_leapfrog_steps=16),
        )
    """

    def _to_unconstrained(value, bij):
        return bij.inverse(value)

    def _to_constrained(value, bij):
        return bij.forward(value)

    def _ildj(value, bij):
        return bij.inverse_log_det_jacobian(value)

    def init_fn(
        target_log_prob_fn: LogProbFnType, position: Position, **kwargs
    ):
        unconstrained_tlp_fn = _transform_tlp_fn(bijectors, target_log_prob_fn)
        unconstrained_position = map_structure(
            _to_unconstrained, position, bijectors
        )

        inner_chain_state, kernel_state = sampling_algorithm.init(
            unconstrained_tlp_fn, unconstrained_position, **kwargs
        )

        # Add the inverse log-det-Jacobian back to report the log density
        # on the original space.  The bijectors are the leading tree here,
        # so the mapped result takes their structure.
        inner_log_density = inner_chain_state.log_density
        ildj_terms = map_structure(
            lambda bij, x: bij.inverse_log_det_jacobian(x),
            bijectors,
            position,
        )
        constrained_chain_state = ChainState(
            position=position,
            log_density=inner_log_density + jnp.sum(jnp.asarray(ildj_terms)),
            log_density_grad=(),
        )

        return constrained_chain_state, kernel_state

    def step_fn(
        target_log_prob_fn: LogProbFnType,
        chain_and_kernel_state: ChainAndKernelState,
        seed: SeedType,
        **kwargs,
    ):
        chain_state, kernel_state = chain_and_kernel_state

        # Move the current state to the unconstrained space (inverse
        # transform), subtracting the inverse log-det-Jacobian from the
        # stored log density.
        unconstrained_tlp_fn = _transform_tlp_fn(bijectors, target_log_prob_fn)
        unconstrained_position = map_structure(
            _to_unconstrained, chain_state.position, bijectors
        )
        current_log_density = chain_state.log_density
        current_ildj = map_structure(_ildj, chain_state.position, bijectors)
        unconstrained_chain_state = ChainState(
            position=unconstrained_position,
            log_density=current_log_density
            - jnp.sum(jnp.asarray(current_ildj)),
            log_density_grad=(),
        )

        # Advance the wrapped algorithm on the unconstrained space.
        (inner_chain_state, next_kernel_state), info = sampling_algorithm.step(
            unconstrained_tlp_fn,
            (unconstrained_chain_state, kernel_state),
            seed=seed,
            **kwargs,
        )

        # Map the proposal back to the constrained space and restore the
        # Jacobian term in the log density.
        constrained_position = map_structure(
            _to_constrained, inner_chain_state.position, bijectors
        )
        inner_log_density = inner_chain_state.log_density
        next_ildj = map_structure(_ildj, constrained_position, bijectors)
        next_chain_state = ChainState(
            position=constrained_position,
            log_density=inner_log_density + jnp.sum(jnp.asarray(next_ildj)),
            log_density_grad=(),
        )

        return (next_chain_state, next_kernel_state), info

    return SamplingAlgorithm(init_fn, step_fn)