Coverage for gemlib/util.py: 100%

24 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-12-02 22:09 +0000

1"""Utility functions for model implementation code.""" 

2 

3import jax.numpy as jnp 

4import numpy as np 

5from jax import Array 

6from jax.typing import ArrayLike 

7 

8from gemlib.math import cumsum_np 

9 

10 

def batch_gather(arr: ArrayLike, indices: ArrayLike) -> Array:
    """Batched gather of ``indices`` from the right-most dimensions of ``arr``.

    Let ``arr`` have dimension ``m`` and shape
    ``(a_1, ..., a_{m-d}, a_{m-d+1}, ..., a_m)``, and let ``indices`` have
    dimension ``n`` and shape ``(b_1, ..., b_{n-1}, d)``. Each length-``d``
    row of ``indices`` is a coordinate into the right-most ``d`` dimensions
    of ``arr``. The output has shape
    ``(a_1, ..., a_{m-d}, b_1, ..., b_{n-1})`` and::

        output[i_1, ..., i_{m-d}, j_1, ..., j_{n-1}] = arr[
            i_1, ..., i_{m-d},
            indices[j_1, ..., j_{n-1}, 0], ..., indices[j_1, ..., j_{n-1}, d-1],
        ]

    For ``d >= 1`` and non-negative, in-bounds indices this is equivalent
    to::

        indices = jnp.moveaxis(jnp.asarray(indices, dtype=jnp.int32), -1, 0)
        return arr[..., *indices]

    but avoids fancy indexing. The right-most ``d`` dimensions of ``arr``
    are flattened into one, and each coordinate is converted to a row-major
    flat offset. The gather is then a single :func:`jax.numpy.take`.

    Args:
        arr: an ``m``-dimensional array.
        indices: an integer array of coordinates into the right-most
            ``indices.shape[-1]`` dimensions of ``arr``; cast to ``int32``.

    Returns:
        jax.Array: gathered values with
        ``arr.ndim - indices.shape[-1] + indices.ndim - 1`` dimensions.

    Raises:
        ValueError: if ``indices.shape[-1]`` exceeds ``arr.ndim``.
    """
    arr = jnp.asarray(arr)
    indices = jnp.asarray(indices, dtype=jnp.int32)
    index_dims = indices.shape[-1]
    if index_dims > arr.ndim:
        raise ValueError(
            f"indices.shape[-1] ({index_dims}) exceeds arr.ndim ({arr.ndim})"
        )

    # Slice by position rather than ``[:-index_dims]``: with
    # ``index_dims == 0`` the latter would drop every batch dimension.
    batch_ndim = arr.ndim - index_dims
    indexed_shape = arr.shape[batch_ndim:]
    flat_arr = jnp.reshape(arr, arr.shape[:batch_ndim] + (-1,))

    # Row-major strides of the indexed dimensions,
    # e.g. (a, b, c) -> (b * c, c, 1). Shapes are static, so plain Python
    # arithmetic suffices.
    strides = []
    stride = 1
    for size in reversed(indexed_shape):
        strides.append(stride)
        stride *= size
    # ``dtype=int`` keeps an empty stride vector integer-typed (the default
    # dtype of an empty list would be floating point).
    strides = jnp.asarray(strides[::-1], dtype=int)

    # Equivalent to ``jnp.ravel_multi_index`` over ``indexed_shape`` for
    # in-bounds coordinates.
    flat_indices = jnp.dot(indices, strides)

    return jnp.take(flat_arr, flat_indices, axis=-1)

53 

54 

def transition_coords(
    incidence_matrix: np.ndarray, dtype=np.int32
) -> np.ndarray:
    """Compute coordinates of transitions in one or a batch of incidence
    matrices.

    For each column (transition) of ``incidence_matrix``, return the row
    indices of its source and destination states.

    Note: this requires ``incidence_matrix`` to be that of a state
    transition model. Each column must have a single negative (source)
    entry and a single positive (destination) entry, with all other entries
    ``0``. This is not checked:

    * a column with no negative (positive) entry yields source
      (destination) ``0``;
    * a column with several such entries yields the sum of their row
      indices.

    Args:
        incidence_matrix: an array-like of shape ``(..., S, R)`` containing
            a single incidence matrix or a batch of them, each representing
            ``R`` transitions between ``S`` states.
        dtype: data type of the returned array.

    Returns:
        np.ndarray: an array of shape ``(..., R, 2)``. Entry
        ``[..., r, 0]`` is the index of the source state of the ``r``-th
        transition, and ``[..., r, 1]`` is the index of its destination
        state.
    """
    # Accept any array-like (e.g. nested lists), not only ndarrays.
    incidence_matrix = np.asarray(incidence_matrix)

    # Boolean masks of shape (..., S, R, 2): [..., 0] flags sources
    # (negative entries), [..., 1] flags destinations (positive entries).
    is_src_dest = np.stack(
        [incidence_matrix < 0, incidence_matrix > 0], axis=-1
    )

    # Weight each flagged entry by its state (row) index and sum over the
    # state axis. For a column with exactly one flagged entry this is that
    # entry's row index.
    num_states = is_src_dest.shape[-3]
    state_index = np.arange(num_states, dtype=dtype).reshape(-1, 1, 1)
    coords = np.sum(is_src_dest.astype(dtype) * state_index, axis=-3)

    # ``np.sum`` may widen small integer dtypes; cast back to the request.
    return coords.astype(dtype)

99 

100 

def transition_coords_tuple(
    incidence_matrix: np.ndarray, dtype=np.int32
) -> tuple:
    """Return the output of ``transition_coords`` as nested tuples.

    The result has the same nesting as the ``(..., R, 2)`` array returned
    by ``transition_coords``, but every level is a ``tuple`` and every
    entry a Python scalar (``int`` for integer ``dtype``). The result is
    therefore hashable, e.g. usable as a dictionary key. For a single
    ``(S, R)`` incidence matrix the result is
    ``((src_0, dest_0), ..., (src_{R-1}, dest_{R-1}))``.

    Args:
        incidence_matrix: an array of shape ``(..., S, R)`` containing a
            single incidence matrix or a batch of them.
        dtype: data type of the intermediate coordinate array.

    Returns:
        tuple: nested tuples mirroring the shape ``(..., R, 2)``.
    """

    def _as_nested_tuple(value):
        # ``ndarray.tolist`` yields nested lists; convert every level so the
        # result stays hashable for batched (>2-D) inputs too.
        if isinstance(value, list):
            return tuple(_as_nested_tuple(item) for item in value)
        return value

    return _as_nested_tuple(
        transition_coords(incidence_matrix, dtype).tolist()
    )

107 

108 

def states_from_transition_idx(
    transition_index: int,
    incidence_matrix: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
    """Return source and destination state indices for one transition.

    Given the index of a transition (column) in ``incidence_matrix``, return
    the indices of its source and destination states.

    Note: this requires ``incidence_matrix`` to have exactly one negative
    (source) and one positive (destination) entry in each column, and all
    other entries ``0``.

    Args:
        transition_index: the index (column) of the transition in
            ``incidence_matrix``. Negative values count from the last
            column, as in NumPy indexing.
        incidence_matrix: an array of shape ``(..., S, R)`` containing a
            single incidence matrix or a batch of them, each representing
            ``R`` transitions between ``S`` states.

    Returns:
        tuple[np.ndarray, np.ndarray]: ``(source, destination)`` integer
        arrays of shape ``(...)``, the batch shape of ``incidence_matrix``
        (0-d for a single matrix). They give the state indices of transition
        ``transition_index`` in each matrix.
    """
    # (..., R, 2) coordinates -> (..., 2) for the requested transition.
    coords = transition_coords(incidence_matrix)[..., transition_index, :]
    source, destination = coords[..., 0], coords[..., 1]
    return source, destination