Coverage for partipy/utils.py: 86%

21 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-09 10:24 +0200

1import numpy as np 

2from scipy.optimize import linear_sum_assignment 

3from scipy.spatial.distance import cdist, pdist 

4 

5 

def align_archetypes(ref_arch: np.ndarray, query_arch: np.ndarray) -> np.ndarray:
    """Align the query archetypes to the reference archetypes.

    The rows of ``query_arch`` are reordered so that the summed Euclidean
    distance between each reference row and its matched query row is minimal
    (optimal one-to-one assignment via
    :func:`scipy.optimize.linear_sum_assignment`).

    Parameters
    ----------
    ref_arch
        Reference archetypes, one archetype per row (2-D array-like).
    query_arch
        Query archetypes, same shape as ``ref_arch``.

    Returns
    -------
    np.ndarray
        A reordered copy of ``query_arch`` whose row ``i`` is the query
        archetype matched to reference archetype ``i``.

    Raises
    ------
    ValueError
        If the inputs are not 2-D or their shapes differ.
    """
    ref_arch = np.asarray(ref_arch)
    query_arch = np.asarray(query_arch)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    # Without it, a query with extra rows would yield a rectangular cost
    # matrix and some query archetypes would be silently dropped.
    if ref_arch.ndim != 2 or ref_arch.shape != query_arch.shape:
        raise ValueError(
            "ref_arch and query_arch must be 2-D arrays of the same shape, "
            f"got {ref_arch.shape} and {query_arch.shape}"
        )
    # Pairwise Euclidean distances: entry (i, j) compares ref row i, query row j.
    euclidean_d = cdist(ref_arch, query_arch)
    ref_idx, query_idx = linear_sum_assignment(euclidean_d)
    # Order the matched query rows by their reference row so that output
    # row i always corresponds to reference row i.
    return query_arch[query_idx[np.argsort(ref_idx)], :]

12 

13 

def compute_rowwise_l2_distance(mtx_1: np.ndarray, mtx_2: np.ndarray) -> np.ndarray:
    """Compute the L2 (Euclidean) distance between corresponding rows.

    Row ``i`` of ``mtx_1`` is compared with row ``i`` of ``mtx_2``; the
    distance is taken along axis 1.

    Parameters
    ----------
    mtx_1, mtx_2
        Array-likes of identical shape (at least 2-D).

    Returns
    -------
    np.ndarray
        The distances, i.e. an array of the input shape with axis 1 removed
        (one value per row for 2-D input).

    Raises
    ------
    ValueError
        If the shapes of ``mtx_1`` and ``mtx_2`` differ.
    """
    mtx_1 = np.asarray(mtx_1)
    mtx_2 = np.asarray(mtx_2)
    # Explicit check instead of `assert` (stripped under `python -O`); without
    # it NumPy broadcasting could silently pair rows of mismatched inputs.
    if mtx_1.shape != mtx_2.shape:
        raise ValueError(
            f"mtx_1 and mtx_2 must have the same shape, got {mtx_1.shape} and {mtx_2.shape}"
        )
    dist = np.sqrt(np.sum(np.square(mtx_1 - mtx_2), axis=1))
    return dist

19 

20 

def compute_rowwise_l1_distance(mtx_1: np.ndarray, mtx_2: np.ndarray) -> np.ndarray:
    """Compute the L1 (Manhattan) distance between corresponding rows.

    Row ``i`` of ``mtx_1`` is compared with row ``i`` of ``mtx_2``; absolute
    differences are summed along axis 1.

    Parameters
    ----------
    mtx_1, mtx_2
        Array-likes of identical shape (at least 2-D).

    Returns
    -------
    np.ndarray
        The distances, i.e. an array of the input shape with axis 1 removed
        (one value per row for 2-D input).

    Raises
    ------
    ValueError
        If the shapes of ``mtx_1`` and ``mtx_2`` differ.
    """
    mtx_1 = np.asarray(mtx_1)
    mtx_2 = np.asarray(mtx_2)
    # Explicit check instead of `assert` (stripped under `python -O`); without
    # it NumPy broadcasting could silently pair rows of mismatched inputs.
    if mtx_1.shape != mtx_2.shape:
        raise ValueError(
            f"mtx_1 and mtx_2 must have the same shape, got {mtx_1.shape} and {mtx_2.shape}"
        )
    dist = np.sum(np.abs(mtx_1 - mtx_2), axis=1)
    return dist

26 

27 

def compute_relative_rowwise_l2_distance(mtx_1: np.ndarray, mtx_2: np.ndarray) -> np.ndarray:
    """Compute row-wise L2 distances relative to the dispersion of ``mtx_1``.

    Each row-wise L2 distance between ``mtx_1`` and ``mtx_2`` is divided by
    the mean pairwise Euclidean distance between the rows of ``mtx_1``
    ("archetype dispersion"), making the result independent of the overall
    scale of the archetypes.

    Parameters
    ----------
    mtx_1
        Reference matrix, one archetype per row; also defines the dispersion.
    mtx_2
        Matrix compared row-by-row against ``mtx_1``. Shape validation is
        delegated to ``compute_rowwise_l2_distance``.

    Returns
    -------
    np.ndarray
        The row-wise L2 distances divided by the archetype dispersion.

    Raises
    ------
    ValueError
        If ``mtx_1`` has fewer than two rows (no pairwise distances exist)
        or all of its rows coincide (dispersion is zero); in both cases the
        relative distance is undefined.
    """
    rowwise_l2 = compute_rowwise_l2_distance(mtx_1, mtx_2)
    pairwise_d = pdist(np.asarray(mtx_1))  # condensed pairwise distances
    # Previously these degenerate cases silently produced NaN/inf together
    # with NumPy RuntimeWarnings ("Mean of empty slice", divide by zero).
    if pairwise_d.size == 0:
        raise ValueError("mtx_1 needs at least two rows to compute the archetype dispersion")
    archetype_dispersion = np.mean(pairwise_d)  # average pairwise distance
    if archetype_dispersion == 0:
        raise ValueError("archetype dispersion is zero (all rows of mtx_1 are identical)")
    rowwise_l2_normalized = rowwise_l2 / archetype_dispersion
    return rowwise_l2_normalized