Coverage for gemlib/tensor_util.py: 94%

17 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-12-02 22:09 +0000

1"""Tensor utilities""" 

2 

3from functools import reduce as ft_reduce 

4 

5import jax.numpy as jnp 

6 

7__all__ = ["broadcast_fn_to", "broadcast_together"] 

8 

9 

10# TODO: Port to jax 

11 

12 

def broadcast_together(*tensors):
    """Broadcast tensors to a common shape.

    Args:
        *tensors: zero or more array-like values whose shapes are mutually
            broadcastable.

    Returns:
        A tuple containing each input broadcast to the common shape.  If
        exactly one tensor is supplied, the broadcast tensor itself is
        returned instead of a 1-tuple.  If no tensors are supplied, an
        empty tuple is returned.
    """
    shapes = [jnp.asarray(x).shape for x in tensors]

    # Seed the fold with the scalar shape `()`: it broadcasts with any shape
    # without changing it, and it makes the zero-argument call well defined
    # (a bare reduce over an empty list raises TypeError).
    common_shape = ft_reduce(jnp.broadcast_shapes, shapes, ())
    broadcast_tensors = [jnp.broadcast_to(x, common_shape) for x in tensors]

    # A lone input is returned unwrapped.
    if len(tensors) == 1:
        return broadcast_tensors[0]

    return tuple(broadcast_tensors)

31 

32 

def broadcast_fn_to(func, shape):
    """Wrap `func` so that every output it produces is broadcast to `shape`.

    Args:
        func: the callable to wrap; called with whatever positional and
            keyword arguments the wrapper receives.
        shape: the target shape passed to `jnp.broadcast_to`.

    Returns:
        A callable that always returns a tuple.  A tuple returned by `func`
        is broadcast element by element; any other return value is treated
        as a single output and yields a 1-tuple.
    """

    def wrapped(*args, **kwargs):
        result = func(*args, **kwargs)
        # Normalise to a tuple so both cases share one broadcast path.
        outputs = result if isinstance(result, tuple) else (result,)
        return tuple(jnp.broadcast_to(out, shape) for out in outputs)

    return wrapped