Coverage for gemlib/prng_util.py: 88%
8 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-12-02 22:09 +0000
1"""Pseudo random number utils"""
3import jax
4from jax import Array
7def sanitize_key(key: int | Array | None):
8 """Sanitize a user-supplied key
10 Args:
11 key: either an int or a JAX pseudorandom key
13 Returns:
14 an instance of a JAX pseudorandom key
15 """
16 if key is None:
17 raise ValueError("Explicit key required")
19 if isinstance(key, int):
20 key = jax.random.key(key)
22 return key