Coverage for gemlib/util.py: 100%
24 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-12-02 22:09 +0000
« prev ^ index » next coverage.py v7.10.3, created at 2025-12-02 22:09 +0000
1"""Utility functions for model implementation code."""
3import jax.numpy as jnp
4import numpy as np
5from jax import Array
6from jax.typing import ArrayLike
8from gemlib.math import cumsum_np
def batch_gather(arr: ArrayLike, indices: ArrayLike) -> Array:
    """Batched gather of ``indices`` from the right-most dimensions of ``arr``.

    This function gathers elements on the right-most ``indices.shape[-1]``
    dimensions of ``arr``. If ``arr`` has dimension ``m`` and shape
    ``(a_1, a_2, ..., a_{m-d}, a_{m-d+1}, ..., a_{m})``
    and ``indices`` has dimension ``n`` and shape ``(b_1, ..., b_{n-1}, d)``,
    then the output has shape ``(a_1, ..., a_{m-d}, b_1, ..., b_{n-1})``,
    where ``output[i_1, ..., i_{m-d}, j_1, ..., j_{n-1}] =
    arr[i_1, ..., i_{m-d}, indices[j_1, ..., j_{n-1}, 0], ...,
    indices[j_1, ..., j_{n-1}, d-1]]``.

    This function is equivalent to
    ``indices = jnp.moveaxis(jnp.asarray(indices, dtype=jnp.int32), -1, 0);
    return arr[..., *indices]``
    but avoids fancy indexing. As with fancy indexing, a negative
    coordinate counts from the end of its own dimension.

    Args:
        arr (jax.Array): an ``m``-dimensional array.
        indices (jax.Array): an array of coordinates into the right-most
            ``indices.shape[-1]`` dimensions of ``arr``.

    Returns:
        jax.Array: Gathered values of dimension
        ``arr.ndim - indices.shape[-1] + indices.ndim - 1``
    """
    arr = jnp.asarray(arr)
    indices = jnp.asarray(indices, dtype=jnp.int32)
    index_dims = indices.shape[-1]
    gather_shape = arr.shape[-index_dims:]

    # Collapse the gathered dimensions into one so a single `take` suffices.
    flat_arr = jnp.reshape(arr, arr.shape[:-index_dims] + (-1,))

    # Wrap negative coordinates per dimension, as fancy indexing does.
    # Without this, a negative coordinate would wrap across the whole
    # flattened axis and select the wrong element.
    dim_sizes = jnp.asarray(gather_shape, dtype=jnp.int32)
    indices = jnp.where(indices < 0, indices + dim_sizes, indices)

    # Row-major strides of the gathered dimensions, e.g. (20, 5, 1) for a
    # trailing shape of (3, 4, 5).
    stride_shape = gather_shape[1:] + (1,)
    strides = jnp.cumprod(jnp.array(stride_shape)[::-1])[::-1]

    flat_indices = jnp.dot(indices, strides)
    # Alternatively, flat_indices = jnp.ravel_multi_index(...) on the
    # moved-axis indices.
    return jnp.take(flat_arr, flat_indices, axis=-1)
def transition_coords(
    incidence_matrix: np.ndarray, dtype=np.int32
) -> np.ndarray:
    """Locate the source and destination state of every transition.

    Works on a single incidence matrix or on a batch of them. Each column
    of an incidence matrix describes one transition. The row holding its
    negative entry is the source state, and the row holding its positive
    entry is the destination state.

    Note: this requires ``incidence_matrix`` to be that of a state
    transition model, with exactly one negative (source) and one positive
    (destination) entry in each column and all other entries ``0``.

    Args:
        incidence_matrix (np.ndarray): an array of shape ``(..., S, R)``
            holding one incidence matrix or a batch of them, each
            describing ``R`` transitions between ``S`` states.
        dtype (type): data type of the returned array.

    Returns:
        np.ndarray: an array of shape ``(..., R, 2)`` in which
        ``[..., r, 0]`` is the index of the source state and
        ``[..., r, 1]`` the index of the destination state of the
        ``r``-th transition.
    """
    # Stack source/destination masks on a new trailing axis, giving shape
    # (..., S, R, 2); the state axis is therefore axis -3.
    endpoint_masks = np.stack(
        [incidence_matrix < 0, incidence_matrix > 0], axis=-1
    ).astype(dtype)

    # Presumably cumsum_np(..., reverse=True, exclusive=True) gives, at
    # each state s, the number of marked states strictly after s (TODO
    # confirm against gemlib.math). For a single marked row k, that is 1
    # for every s < k and 0 otherwise, so summing it over the state axis
    # recovers k.
    marks_after = cumsum_np(
        endpoint_masks, reverse=True, exclusive=True, axis=-3
    )
    state_indices = np.sum(marks_after, axis=-3)
    return state_indices.astype(dtype)
def transition_coords_tuple(
    incidence_matrix: np.ndarray, dtype=jnp.int32
) -> tuple:
    """Return ``transition_coords(incidence_matrix, dtype)`` as nested tuples.

    Nested tuples of Python scalars are hashable, unlike NumPy arrays, so
    the result can be used wherever a hashable value is required.

    Args:
        incidence_matrix (np.ndarray): an array of shape ``(..., S, R)``
            holding one incidence matrix or a batch of them.
        dtype (type): data type used when computing the coordinates.

    Returns:
        tuple: nested tuples mirroring the ``(..., R, 2)`` shape produced
        by ``transition_coords``. For a single ``(S, R)`` matrix, this is
        a length-``R`` tuple of ``(source, destination)`` pairs.
    """

    def _as_nested_tuple(value):
        # Convert every nesting level, not only the outermost one, so that
        # batched inputs also come back as (hashable) tuples all the way
        # down instead of a tuple of lists.
        if isinstance(value, list):
            return tuple(_as_nested_tuple(item) for item in value)
        return value

    return _as_nested_tuple(
        transition_coords(incidence_matrix, dtype).tolist()
    )
def states_from_transition_idx(
    transition_index: int,
    incidence_matrix: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
    """Return source and destination state indices from an incidence matrix
    for a particular transition index.

    Given the index of a transition (a column of ``incidence_matrix``),
    return the indices (rows) of its source and destination states.

    Note: this requires every column of ``incidence_matrix`` to have one
    negative (source) entry, one positive (destination) entry, and all
    other entries ``0``.

    Args:
        transition_index (int): the index (column) of the transition in
            ``incidence_matrix``.
        incidence_matrix (np.ndarray): an array of shape ``(..., S, R)``
            containing a single or a batch of incidence matrices,
            each representing ``R`` transitions within ``S`` states.

    Returns:
        tuple[np.ndarray, np.ndarray]: A pair of integer arrays whose shape
        is the batch shape ``(...)`` of ``incidence_matrix``. For a single
        ``(S, R)`` matrix, these are 0-d arrays. They give the index of
        the source state and of the destination state of the selected
        transition.
    """
    # Selecting one transition drops the R axis, leaving shape (..., 2):
    # [..., 0] is the source state and [..., 1] the destination state.
    endpoints = transition_coords(incidence_matrix)[..., transition_index, :]
    source, destination = endpoints[..., 0], endpoints[..., 1]
    return source, destination