Coverage for gemlib/mcmc/mcmc_util.py: 100%
11 statements
Report generated by coverage.py v7.10.3 at 2025-12-02 22:09 +0000.
1"""MCMC utility functions"""
3import jax
4import numpy as np
5import tensorflow_probability.substrates.jax as tfp
def is_list_like(x):
    """Return True if ``x`` is a list or a tuple (including subclasses).

    Other sequences such as strings, sets, or arrays are not considered
    list-like.
    """
    list_like_types = (list, tuple)
    return isinstance(x, list_like_types)
def get_flattening_bijector(example):
    """A bijector that converts a data structure to a 1-D tensor.

    The returned bijector's ``forward`` maps a structure shaped like
    ``example`` to a single vector whose last axis is the concatenation of
    every (flattened) leaf. Its ``inverse`` splits such a vector back into
    leaves, reshapes each one, and rebuilds the original structure.

    Args:
        example: a pytree (list, tuple, dict, ...) whose leaves are arrays or
            scalars. Only its tree structure and leaf shapes are used; the
            leaf values are ignored.

    Returns:
        A ``tfp.bijectors.Bijector`` from ``example``-like structures to
        1-D tensors.
    """
    flat_example, example_treedef = jax.tree.flatten(example)

    # np.shape returns ``x.shape`` when present, and also handles plain
    # Python scalars, which have no ``.shape`` attribute.
    leaf_shapes = [np.shape(x) for x in flat_example]

    # The chain below is built in the inverse direction (vector -> structure);
    # Chain applies its members right-to-left.
    split = tfp.bijectors.Split(
        # int() because np.prod(()) is the float 1.0 for scalar leaves, while
        # split sizes must be integers.
        [int(np.prod(shape)) for shape in leaf_shapes],
        axis=-1,
    )
    reshape = tfp.bijectors.JointMap(
        [tfp.bijectors.Reshape(shape) for shape in leaf_shapes]
    )
    restructure = tfp.bijectors.Restructure(
        # Integer tokens 0..n-1 placed into the original tree structure mark
        # where each element of the split list ends up.
        output_structure=jax.tree.unflatten(
            example_treedef, range(len(flat_example))
        )
    )

    # Invert so that forward goes structure -> flat vector.
    return tfp.bijectors.Invert(
        tfp.bijectors.Chain([restructure, reshape, split])
    )