Coverage for gemlib/spatial/sp_dist.py: 93%
46 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-12-02 22:09 +0000
« prev ^ index » next coverage.py v7.10.3, created at 2025-12-02 22:09 +0000
1"""Compute a sparse distance matrix given coordinates"""
3from __future__ import annotations
5from collections.abc import Callable
6from functools import partial
8import jax
9import jax.experimental.sparse as jsp
10import jax.numpy as jnp
11import numpy as np
12from tqdm import tqdm
14__all__ = ["pdist", "sparse_pdist"]
16Array = jax.Array
17ArrayLike = jax.typing.ArrayLike
18BooleanArray = np.typing.NDArray[np.bool]
def pdist(a: ArrayLike, b: ArrayLike) -> Array:
    """Compute the Euclidean distance between a and b

    Any array-like input is accepted: NumPy arrays, JAX arrays (including
    tracers inside :code:`jax.jit`), and plain nested Python sequences.

    Args:
        a: a :code:`[N, D]` tensor of coordinates
        b: a :code:`[M, D]` tensor of coordinates

    Returns:
        A :code:`[N, M]` matrix of Euclidean distances between
        coordinates.
    """
    # Lists and tuples do not support the tuple indexing used below, so
    # promote anything that is not already array-shaped.  Objects that
    # expose ``shape`` (NumPy/JAX arrays, tracers) are passed through
    # untouched so their dtype and device placement are preserved.
    if not hasattr(a, "shape"):
        a = np.asarray(a)
    if not hasattr(b, "shape"):
        b = np.asarray(b)

    # Broadcast [N, 1, D] against [1, M, D] to get every pairwise difference.
    delta = a[..., np.newaxis, :] - b[np.newaxis, ...]
    sqdist = jnp.sum(delta * delta, axis=-1)

    return jnp.sqrt(sqdist)
def include_all(x: np.typing.ArrayLike):
    """Inclusion predicate that keeps every element.

    Args:
        x: array-like of values; only its shape is used.

    Returns:
        A boolean JAX array with the same shape as :code:`x` whose
        entries are all :code:`True`.
    """
    target_shape = jnp.asarray(x).shape
    return jnp.full(shape=target_shape, fill_value=True)
@partial(jax.jit, static_argnums=2)
def _pdist_indices_mask(a, b, include_fn):
    """Compute all pairwise distances with their matrix indices and a mask.

    Args:
        a: a :code:`[N, D]` array of coordinates
        b: a :code:`[M, D]` array of coordinates
        include_fn: callable mapping the flat array of distances to a
            boolean array of the same length.  It is a static argument of
            the jitted function, so it must be hashable.

    Returns:
        A tuple :code:`(values, indices, mask)` where :code:`values` is the
        row-major flattened :code:`[N * M]` distance matrix,
        :code:`indices` is an :code:`[N * M, 2]` array holding the
        :code:`(row, column)` position of each value in the :code:`[N, M]`
        matrix, and :code:`mask` is :code:`include_fn(values)`.
    """
    values = pdist(a, b).flatten()
    mask = include_fn(values)

    # Row-major flattening places element (r, c) of the [N, M] matrix at
    # flat position r * M + c, so both the row and the column are recovered
    # using the number of columns M.  The position must not depend on the
    # mask: every value keeps its own matrix coordinates regardless of
    # which other values are kept.
    num_cols = b.shape[0]
    flat_index = jnp.arange(values.shape[0])
    row_idx = flat_index // num_cols
    col_idx = flat_index % num_cols
    indices = jnp.stack([row_idx, col_idx], axis=-1)
    return values, indices, mask
def compress_distance(a, b, include_fn):
    """Return the pairwise distances between :code:`a` and :code:`b`
    that are selected by :code:`include_fn`, with their index pairs.

    Args:
        a: a :code:`[N, D]` array of coordinates
        b: a :code:`[M, D]` array of coordinates
        include_fn: callable mapping the flat array of distances to a
            boolean mask of the same length.

    Returns:
        A tuple :code:`(values, indices)`: the selected distances as a
        1-D array, and a :code:`[K, 2]` array with the index pair that
        :code:`_pdist_indices_mask` assigned to each selected distance.
    """
    dist, idx, keep = _pdist_indices_mask(a, b, include_fn)

    # Boolean selection has a data-dependent output size, so it is applied
    # here on the results of the jitted helper rather than inside it.
    kept_values = dist[keep]
    kept_indices = idx[keep]
    return kept_values, kept_indices
def sparse_pdist(
    coords: ArrayLike,
    include_fn: Callable[[ArrayLike], BooleanArray] = include_all,
    chunk_size: int | None = None,
) -> jsp.BCOO:
    """Compute a sparse distance matrix

    Compute a sparse Euclidean distance matrix between all pairs of
    :code:`coords`, keeping only the distances selected by
    :code:`include_fn`.

    Args:
        coords: a ``[N, D]`` array of coordinates
        include_fn: a callable that takes an array of distances and
                    returns a boolean array of the same shape that is
                    :code:`True` where the distance should be included
                    as a "non-zero" element of the returned sparse
                    matrix.  It is passed on as a static argument of a
                    jitted function, so it must be hashable; reuse the
                    same callable object across calls to avoid
                    recompilation.
        chunk_size: If memory is limited, compute the distances in
                    stripes of ``[chunk_size, N]``.  Defaults to
                    :code:`None`, which computes all rows at once.

    Returns:
        A ``[N, N]`` :code:`jax.experimental.sparse.BCOO` matrix of
        Euclidean distances satisfying `include_fn`.

    Raises:
        ValueError: if :code:`chunk_size` is given and is less than 1.

    Example:

      >>> import numpy as np
      >>> from gemlib.spatial import sparse_pdist
      >>> coords = np.random.uniform(size=(1000, 2))
      >>> def within_range(d):
      ...     return d < 0.01
      >>> d_sparse = sparse_pdist(
      ...     coords, include_fn=within_range, chunk_size=200
      ... )
      >>> d_sparse.shape
      (1000, 1000)

    """
    coords = np.asarray(coords)
    num_coords = coords.shape[-2]

    if chunk_size is not None and chunk_size < 1:
        raise ValueError(
            f"chunk_size must be a positive integer, got {chunk_size}"
        )

    if num_coords == 0:
        # No coordinates means no stripes to process; return an empty
        # matrix instead of failing on a zero step in range().
        return jsp.BCOO(
            (jnp.zeros((0,)), jnp.zeros((0, 2), dtype=int)),
            shape=(0, 0),
            indices_sorted=True,
            unique_indices=True,
        )

    if chunk_size is None:
        chunk_size = num_coords

    # Results are moved to host memory as they are produced so that only
    # one stripe at a time has to live on the compute device.
    cpu = jax.devices("cpu")[0]
    values_accum = []
    indices_accum = []

    for start in tqdm(
        range(0, num_coords, chunk_size),
        unit="rows",
        unit_scale=chunk_size,
        miniters=1,
    ):
        stop = min(start + chunk_size, num_coords)
        values, indices = compress_distance(
            coords[start:stop], coords, include_fn
        )
        # compress_distance only sees the stripe, so its row numbers are
        # relative to the stripe; shift them to rows of the full matrix.
        indices = indices.at[:, 0].add(start)
        values_accum.append(jax.device_put(values, cpu))
        indices_accum.append(jax.device_put(indices, cpu))

    res = jsp.BCOO(
        (jnp.concatenate(values_accum, 0), jnp.concatenate(indices_accum, 0)),
        shape=(num_coords, num_coords),
        indices_sorted=True,
        unique_indices=True,
    )

    return res