Coverage for gemlib/spatial/sp_dist.py: 93%

46 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-12-02 22:09 +0000

1"""Compute a sparse distance matrix given coordinates""" 

2 

3from __future__ import annotations 

4 

5from collections.abc import Callable 

6from functools import partial 

7 

8import jax 

9import jax.experimental.sparse as jsp 

10import jax.numpy as jnp 

11import numpy as np 

12from tqdm import tqdm 

13 

14__all__ = ["pdist", "sparse_pdist"] 

15 

16Array = jax.Array 

17ArrayLike = jax.typing.ArrayLike 

18BooleanArray = np.typing.NDArray[np.bool] 

19 

20 

def pdist(a: ArrayLike, b: ArrayLike) -> Array:
    """Return the pairwise Euclidean distances between two point sets.

    Args:
        a: a :code:`[N, D]` tensor of coordinates
        b: a :code:`[M, D]` tensor of coordinates

    Returns:
        A :code:`[N, M]` matrix whose :code:`[i, j]` entry is the Euclidean
        distance between :code:`a[i]` and :code:`b[j]`.
    """
    # Insert singleton axes so the subtraction broadcasts to [N, M, D].
    diff = a[..., None, :] - b[None, ...]

    # Sum of squared coordinate differences over the trailing (D) axis.
    return jnp.sqrt(jnp.sum(diff * diff, axis=-1))

36 

37 

def include_all(x: np.typing.ArrayLike):
    """Return an all-:code:`True` boolean array with the same shape as `x`.

    This is the default inclusion predicate of :code:`sparse_pdist`: every
    distance is kept.
    """
    shape = jnp.asarray(x).shape
    return jnp.ones(shape, dtype=jnp.bool_)

41 

42 

@partial(jax.jit, static_argnums=2)
def _pdist_indices_mask(a, b, include_fn):
    """Compute all pairwise distances with their indices and an inclusion mask.

    Args:
        a: a :code:`[N, D]` array of coordinates
        b: a :code:`[M, D]` array of coordinates
        include_fn: callable applied to the flattened :code:`[N * M]` vector
            of distances, returning a boolean mask of the same shape.  It is
            a static argument of the jitted function, so it must be hashable.

    Returns:
        A tuple :code:`(values, indices, mask)` where :code:`values` is the
        flattened :code:`[N * M]` distance vector, :code:`indices` is an
        :code:`[N * M, 2]` array of the :code:`(row, col)` position of each
        value in the dense :code:`[N, M]` matrix, and :code:`mask` is the
        result of :code:`include_fn(values)`.
    """
    num_cols = b.shape[0]

    values = pdist(a, b).flatten()
    mask = include_fn(values)

    # The dense matrix is flattened row-major, so element k sits at
    # (k // M, k % M).  The position must come from the dense flat index
    # (not from a running count of included elements), and the divisor must
    # be the row length M = b.shape[0] (not N = a.shape[0]).
    flat_index = jnp.arange(values.shape[0])
    row_idx = flat_index // num_cols
    col_idx = flat_index % num_cols
    indices = jnp.stack([row_idx, col_idx], axis=-1)
    return values, indices, mask

53 

54 

def compress_distance(a, b, include_fn):
    """Return the distances between :code:`a` and :code:`b` selected by
    :code:`include_fn`, together with their indices.

    Returns:
        A tuple :code:`(values, indices)` holding only those entries for
        which the mask produced via :code:`include_fn` is :code:`True`.
    """
    dense_values, dense_indices, keep = _pdist_indices_mask(a, b, include_fn)

    # Boolean-mask selection drops every entry rejected by include_fn.
    kept_values = dense_values[keep]
    kept_indices = dense_indices[keep]
    return kept_values, kept_indices

62 

63 

def sparse_pdist(
    coords: ArrayLike,
    include_fn: Callable[[ArrayLike], BooleanArray] = include_all,
    chunk_size: int | None = None,
) -> jsp.BCOO:
    """Compute a sparse distance matrix

    Compute a sparse Euclidean distance matrix between all pairs of
    :code:`coords`, storing only the distances selected by
    :code:`include_fn`.

    Args:
        coords: a ``[N, D]`` array of coordinates
        include_fn: a callable that takes an array of distances and returns
                    a boolean array of the same shape, :code:`True` where
                    the distance should be stored as a "non-zero" element
                    of the returned sparse matrix.  It is passed as a static
                    argument to a jitted function, so it must be hashable.
                    Defaults to including every distance.
        chunk_size: If memory is limited, compute the distances in stripes
                    of ``[chunk_size, N]`` rows.  Defaults to ``N`` (a
                    single stripe).

    Returns:
        A ``[N, N]`` :code:`jax.experimental.sparse.BCOO` matrix of the
        Euclidean distances satisfying `include_fn`.

    Raises:
        ValueError: if :code:`coords` contains no coordinates or
            :code:`chunk_size` is not a positive integer.

    Example:

        >>> import numpy as np
        >>> from gemlib.spatial import sparse_pdist
        >>> coords = np.random.uniform(size=(1000, 2))
        >>> d_sparse = sparse_pdist(
        ...     coords, include_fn=lambda d: d < 0.01, chunk_size=200
        ... )
        >>> d_sparse.shape
        (1000, 1000)

    """
    coords = np.asarray(coords)
    num_coords = coords.shape[-2]

    if num_coords == 0:
        raise ValueError("`coords` must contain at least one coordinate")

    if chunk_size is None:
        chunk_size = num_coords
    if chunk_size < 1:
        raise ValueError("`chunk_size` must be a positive integer")

    # Results are accumulated on the host CPU device so that only one
    # stripe at a time has to live on the compute device.
    cpu = jax.devices("cpu")[0]
    values_accum = []
    indices_accum = []

    for i in tqdm(
        range(0, num_coords, chunk_size),
        unit="rows",
        unit_scale=chunk_size,
        miniters=1,
    ):
        j = min(i + chunk_size, num_coords)
        values, indices = compress_distance(coords[i:j], coords, include_fn)

        # Row indices come back relative to the stripe coords[i:j]; shift
        # them by the stripe's starting row to index the full matrix.
        indices = indices + jnp.asarray([i, 0], dtype=indices.dtype)

        values_accum.append(jax.device_put(values, cpu))
        indices_accum.append(jax.device_put(indices, cpu))

    res = jsp.BCOO(
        (jnp.concatenate(values_accum, 0), jnp.concatenate(indices_accum, 0)),
        shape=(num_coords, num_coords),
        indices_sorted=True,
        unique_indices=True,
    )

    return res