Coverage for gemlib/mcmc/mcmc_util.py: 100%

11 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-12-02 22:09 +0000

1"""MCMC utility functions""" 

2 

3import jax 

4import numpy as np 

5import tensorflow_probability.substrates.jax as tfp 

6 

7 

def is_list_like(x):
    """Report whether ``x`` is a ``list`` or ``tuple``.

    Subclasses (e.g. namedtuples) count as list-like; every other type,
    including strings, ``range`` objects and arrays, does not.
    """
    list_like_types = (list, tuple)
    return isinstance(x, list_like_types)

10 

11 

def get_flattening_bijector(example):
    """A bijector that converts a data structure to a 1-D tensor.

    The returned bijector's ``forward`` packs a pytree shaped like
    ``example`` into a single vector. It flattens every leaf and
    concatenates the results along the last axis, in ``jax.tree.flatten``
    leaf order. Its ``inverse`` splits such a vector back into the original
    tree structure and leaf shapes.

    Args:
        example: a pytree whose leaves expose a ``shape`` attribute (e.g.
            JAX or NumPy arrays). Only the tree structure and the leaf
            shapes are used; leaf values are ignored.

    Returns:
        A ``tfp.bijectors.Bijector`` (an ``Invert`` of a ``Chain``) mapping
        the structure to a 1-D tensor.
    """
    flat_example, example_treedef = jax.tree.flatten(example)

    # Number of elements in each leaf. ``np.prod(())`` is the *float* 1.0,
    # so cast to ``int`` to keep split sizes integral for scalar leaves.
    sizes = [int(np.prod(leaf.shape)) for leaf in flat_example]
    shapes = [leaf.shape for leaf in flat_example]

    # The chain is built in the vector -> structure direction, then inverted.
    # 1. Split the last axis into one chunk per leaf.
    split = tfp.bijectors.Split(sizes, axis=-1)
    # 2. Reshape each chunk to its leaf's shape.
    reshape = tfp.bijectors.JointMap(
        [tfp.bijectors.Reshape(shape) for shape in shapes]
    )
    # 3. Rearrange the list of leaves into ``example``'s tree structure.
    #    The structure is described by leaf indices 0..n-1.
    restructure = tfp.bijectors.Restructure(
        output_structure=jax.tree.unflatten(
            example_treedef, range(len(flat_example))
        )
    )

    # Chain applies its bijectors right-to-left (split, reshape,
    # restructure); Invert turns that into structure -> flat vector.
    return tfp.bijectors.Invert(
        tfp.bijectors.Chain([restructure, reshape, split])
    )