Coverage for gemlib/math.py: 97%
29 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-12-02 22:09 +0000
1"""Mathematical operators and functions"""
3import jax
4import jax.numpy as jnp
5import numpy as np
# Names exported by a star-import of this module.  The other public helpers
# defined below are not listed here but remain importable by explicit name.
__all__ = [
    "cumsum",
]
def cumsum(
    x: np.typing.ArrayLike,
    axis: int = 0,
    exclusive: bool = False,
    reverse: bool = False,
) -> np.typing.ArrayLike:
    """Compute the cumulative sum of :code:`x` along :code:`axis`

    Follows the :code:`tf.math.cumsum` convention.  For ``x = [a, b, c]``::

        cumsum(x)                                == [a, a + b, a + b + c]
        cumsum(x, exclusive=True)                == [0, a, a + b]
        cumsum(x, reverse=True)                  == [a + b + c, b + c, c]
        cumsum(x, exclusive=True, reverse=True)  == [b + c, c, 0]

    Args:
        x: input array
        axis: Axis along which to compute the cumulative sum.  Negative
          values count back from the last axis.
        exclusive: if truthy, perform an exclusive cumsum: element ``i``
          excludes ``x[i]`` and the first element along :code:`axis` is 0.
        reverse: if truthy, accumulate from the end of :code:`axis` towards
          the start.  The output stays aligned with the input's indices.

    Returns:
        A JAX array with the same shape as :code:`x` (dtype as produced by
        :code:`jnp.cumsum`).
    """
    x = jnp.asarray(x)
    if axis < 0:  # Turn neg axis into pos axis
        axis = x.ndim + axis

    # Use truthiness rather than ``is True`` / ``is False`` so NumPy/JAX
    # booleans and integer flags are honoured instead of silently ignored.
    if reverse:
        x = jnp.flip(x, axis=axis)

    result = jnp.cumsum(x, axis)

    if exclusive:
        # Shift the running totals one step along `axis` and zero the first
        # slot, so element i holds the sum of elements 0..i-1.
        zeros_slice = tuple(
            slice(0, 1) if i == axis else slice(None) for i in range(x.ndim)
        )
        result = jnp.roll(result, 1, axis=axis).at[zeros_slice].set(0)

    if reverse:
        # Undo the input flip so the output lines up with the input.
        result = jnp.flip(result, axis=axis)

    return result
def cumsum_np(
    x: np.typing.ArrayLike,
    axis: int = 0,
    exclusive: bool = False,
    reverse: bool = False,
) -> np.typing.ArrayLike:
    """Compute the cumulative sum of :code:`x` along :code:`axis` with NumPy

    NumPy counterpart of :code:`cumsum`, following the :code:`tf.math.cumsum`
    convention.  For ``x = [a, b, c]``::

        cumsum_np(x)                                == [a, a + b, a + b + c]
        cumsum_np(x, exclusive=True)                == [0, a, a + b]
        cumsum_np(x, reverse=True)                  == [a + b + c, b + c, c]
        cumsum_np(x, exclusive=True, reverse=True)  == [b + c, c, 0]

    Args:
        x: input array
        axis: Axis along which to compute the cumulative sum.  Negative
          values count back from the last axis.
        exclusive: if truthy, perform an exclusive cumsum: element ``i``
          excludes ``x[i]`` and the first element along :code:`axis` is 0.
        reverse: if truthy, accumulate from the end of :code:`axis` towards
          the start.  The output stays aligned with the input's indices.

    Returns:
        A NumPy array with the same shape as :code:`x` (dtype as produced by
        :code:`np.cumsum`).
    """
    x = np.asarray(x)
    if axis < 0:  # Turn neg axis into pos axis
        axis = x.ndim + axis

    # Use truthiness rather than ``is True`` / ``is False`` so NumPy booleans
    # and integer flags are honoured instead of silently ignored.
    if reverse:
        x = np.flip(x, axis=axis)

    result = np.cumsum(x, axis)

    if exclusive:
        # Shift the running totals one step along `axis` and zero the first
        # slot, so element i holds the sum of elements 0..i-1.  np.roll
        # returns a new array, so the in-place assignment cannot touch `x`.
        zeros_slice = tuple(
            slice(0, 1) if i == axis else slice(None) for i in range(x.ndim)
        )
        result = np.roll(result, 1, axis=axis)
        result[zeros_slice] = 0

    if reverse:
        # Undo the input flip so the output lines up with the input.
        result = np.flip(result, axis=axis)

    return result
def multiply_no_nan(x: jax.Array, y: jax.Array) -> jax.Array:
    """Multiply elementwise, returning zero wherever ``y`` is zero.

    JAX equivalent of :code:`tf.multiply_no_nan`: positions where ``y == 0``
    yield ``0`` even when the matching ``x`` is NaN or infinite (where the
    plain product ``x * y`` would be NaN).  Everywhere else the result is
    the ordinary product.  Reportedly compiles to the same HLO as the
    TensorFlow op.

    Args:
        x: First argument.
        y: Second argument; its zero entries mask the product.

    Returns:
        The elementwise (broadcast) product of ``x`` and ``y``, with zeros
        at every position where ``y`` is zero.
    """
    product = x * y
    # `y != 0` is the exact complement of `y == 0` (NaN entries in `y`
    # compare unequal to 0 and therefore keep the product).
    keep_product = y != 0
    return jnp.where(keep_product, product, jnp.zeros_like(product))