Coverage for gemlib/func_util.py: 100%

15 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-12-02 22:09 +0000

1"""Utility functions for state transition models""" 

2 

3from collections.abc import Callable, Iterable 

4from warnings import warn 

5 

6 

def _check_deprecated(fn):
    """Wrap ``fn`` so that returning a tuple emits a DeprecationWarning.

    Args:
        fn: the callable to wrap.

    Returns:
        A callable that forwards all positional and keyword arguments to
        ``fn`` and returns its result unchanged. If that result is a
        ``tuple``, a ``DeprecationWarning`` is issued before returning.
    """
    from functools import wraps

    @wraps(fn)  # keep fn's __name__/__doc__ so the wrapper is introspectable
    def dep_fn(*args, **kwargs):
        result = fn(*args, **kwargs)
        if isinstance(result, tuple):
            warn(
                # Implicit concatenation: a backslash continuation inside the
                # literal would embed the source indentation in the message.
                "Returning a tuple of tensors is deprecated. Please instead "
                "supply a list of functions returning single tensors. This "
                "functionality will be removed in a future release.",
                DeprecationWarning,
                # NOTE(review): stacklevel=3 attributes the warning to the
                # caller of dep_fn's caller — presumably user code invoking
                # the combined function; confirm against call sites.
                stacklevel=3,
            )
        return result

    return dep_fn

22 

23 

24def maybe_combine_fn(fn: Callable | Iterable[Callable]) -> Callable: 

25 """Takes an iterable of Callables of the same signature, returning a 

26 function that combines their results as a tuple.""" 

27 

28 if isinstance(fn, Iterable): 

29 

30 def fn_combined(*args, **kwargs): 

31 return tuple([f(*args, **kwargs) for f in fn]) 

32 

33 return fn_combined 

34 

35 return _check_deprecated(fn)