Coverage for gemlib/tensor_util.py: 94%
17 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-12-02 22:09 +0000
« prev ^ index » next coverage.py v7.10.3, created at 2025-12-02 22:09 +0000
1"""Tensor utilities"""
from functools import reduce as ft_reduce
from functools import wraps

import jax.numpy as jnp
# Public API of this module.
__all__ = ["broadcast_fn_to", "broadcast_together"]

# NOTE(review): this TODO looks stale -- the functions below already use
# ``jax.numpy``; confirm nothing remains to port, then remove it.
# TODO: Port to jax
def broadcast_together(*tensors):
    """Broadcast tensors together to their common shape.

    Args:
        *tensors: arrays or scalars accepted by ``jnp.broadcast_to``.

    Returns:
        With exactly one input, the broadcast array itself (NOT wrapped in a
        tuple). With two or more inputs, a tuple of arrays that all share the
        common broadcast shape. With no inputs, an empty tuple.

    Raises:
        Whatever ``jnp.broadcast_shapes`` raises when the shapes are not
        broadcast-compatible (a ``ValueError`` in current JAX releases).
    """
    if not tensors:
        # Nothing to broadcast: return an empty tuple instead of failing
        # inside the shape computation.
        return ()

    # ``jnp.shape`` reads each shape without first converting the input
    # into a device array.
    shapes = [jnp.shape(x) for x in tensors]
    # ``broadcast_shapes`` is variadic, so no pairwise reduction is needed.
    common_shape = jnp.broadcast_shapes(*shapes)
    broadcast_tensors = tuple(jnp.broadcast_to(x, common_shape) for x in tensors)

    # Single input: unwrap so callers get the array rather than a 1-tuple.
    if len(broadcast_tensors) == 1:
        return broadcast_tensors[0]
    return broadcast_tensors
def broadcast_fn_to(func, shape):
    """Transform function `func` such that its outputs broadcast to `shape`.

    Args:
        func: callable returning either a single array/scalar or a tuple of
            arrays/scalars.
        shape: target shape passed to ``jnp.broadcast_to``.

    Returns:
        A wrapper that forwards all positional and keyword arguments to
        ``func`` and ALWAYS returns a tuple. Each element of a tuple result
        is broadcast to ``shape``. A non-tuple result is broadcast and then
        wrapped in a 1-tuple.
    """

    @wraps(func)  # keep func's __name__/__doc__ and expose it as __wrapped__
    def wrapped(*args, **kwargs):
        retval = func(*args, **kwargs)
        # Normalise to a tuple so callers can always unpack the result.
        outputs = retval if isinstance(retval, tuple) else (retval,)
        return tuple(jnp.broadcast_to(x, shape) for x in outputs)

    return wrapped