Coverage for gemlib/math.py: 97%

29 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-12-02 22:09 +0000

1"""Mathematical operators and functions""" 

2 

3import jax 

4import jax.numpy as jnp 

5import numpy as np 

6 

7__all__ = ["cumsum"] 

8 

9 

10def cumsum( 

11 x: np.typing.ArrayLike, 

12 axis: int = 0, 

13 exclusive: bool = False, 

14 reverse: bool = False, 

15) -> np.typing.ArrayLike: 

16 """Compute the cumulative sum of :code:`x` along :code:`axis` 

17 

18 Args: 

19 x: input array 

20 axis: Axis along which to compute the cumulative sum 

21 exclusive: if :code:`True`, perform exclusive cumsum 

22 reverse: if :code:`True`, perform a reverse cumsum 

23 

24 Returns: 

25 An array of the same type and shape as :code:`x` 

26 """ 

27 

28 x = jnp.asarray(x) 

29 if axis < 0: # Turn neg axis into pos axis 

30 axis = x.ndim + axis 

31 

32 if reverse is True: 

33 x = jnp.flip(x, axis=axis) 

34 

35 if exclusive is False: 

36 return jnp.cumsum(x, axis) 

37 

38 zeros_slice = tuple( 

39 slice(0, 1) if i == axis else slice(None) for i in range(x.ndim) 

40 ) 

41 

42 return jnp.roll(jnp.cumsum(x, axis), 1, axis=axis).at[zeros_slice].set(0) 

43 

44 

45def cumsum_np( 

46 x: np.typing.ArrayLike, 

47 axis: int = 0, 

48 exclusive: bool = False, 

49 reverse: bool = False, 

50) -> np.typing.ArrayLike: 

51 """Compute the cumulative sum of :code:`x` along :code:`axis` 

52 

53 Args: 

54 x: input array 

55 axis: Axis along which to compute the cumulative sum 

56 exclusive: if :code:`True`, perform exclusive cumsum 

57 reverse: if :code:`True`, perform a reverse cumsum 

58 

59 Returns: 

60 An array of the same type and shape as :code:`x` 

61 """ 

62 

63 x = np.asarray(x) 

64 if axis < 0: # Turn neg axis into pos axis 

65 axis = x.ndim + axis 

66 

67 if reverse is True: 

68 x = np.flip(x, axis=axis) 

69 

70 if exclusive is False: 

71 return np.cumsum(x, axis) 

72 

73 zeros_slice = tuple( 

74 slice(0, 1) if i == axis else slice(None) for i in range(x.ndim) 

75 ) 

76 

77 res = np.roll(np.cumsum(x, axis), 1, axis=axis) 

78 res[zeros_slice] = 0 

79 return res 

80 

81 

def multiply_no_nan(x: jax.Array, y: jax.Array) -> jax.Array:
    """Multiply :code:`x` by :code:`y` elementwise, returning 0 where
    :code:`y` is 0.

    Intended as the counterpart of tf.multiply_no_nan (and appears to
    compile to the same HLO).  Wherever :code:`y` equals zero the result
    is zero, even if the plain product there would be NaN or infinite.

    Args:
      x (jax.Array): First argument
      y (jax.Array): Second argument

    Returns:
      jax.Array: The elementwise product of x and y, with zero substituted
        at every position where y is zero.
    """
    product = x * y
    y_is_zero = y == 0
    # Replace entries of the raw product with zeros wherever y vanished.
    return jnp.where(y_is_zero, jnp.zeros_like(product), product)