Coverage for gemlib/prng_util.py: 88%

8 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-12-02 22:09 +0000

1"""Pseudo random number utils""" 

2 

3import jax 

4from jax import Array 

5 

6 

7def sanitize_key(key: int | Array | None): 

8 """Sanitize a user-supplied key 

9 

10 Args: 

11 key: either an int or a JAX pseudorandom key 

12 

13 Returns: 

14 an instance of a JAX pseudorandom key 

15 """ 

16 if key is None: 

17 raise ValueError("Explicit key required") 

18 

19 if isinstance(key, int): 

20 key = jax.random.key(key) 

21 

22 return key