Coverage for partipy/utils.py: 86%
21 statements
coverage.py v7.8.0, created at 2025-05-09 10:24 +0200
1import numpy as np
2from scipy.optimize import linear_sum_assignment
3from scipy.spatial.distance import cdist, pdist
def align_archetypes(ref_arch: np.ndarray, query_arch: np.ndarray) -> np.ndarray:
    """Reorder the rows of ``query_arch`` to best match the rows of ``ref_arch``.

    The pairing minimises the total Euclidean distance between matched rows,
    solved as a linear sum assignment problem on the pairwise distance matrix.

    Parameters
    ----------
    ref_arch : np.ndarray
        Reference archetypes, one archetype per row.
    query_arch : np.ndarray
        Query archetypes; must have exactly the same shape as ``ref_arch``.

    Returns
    -------
    np.ndarray
        A row-permuted copy of ``query_arch`` whose i-th row is the query
        archetype assigned to the i-th reference archetype.

    Raises
    ------
    ValueError
        If ``ref_arch`` and ``query_arch`` do not have the same shape.
    """
    # Explicit check instead of ``assert``: asserts are stripped under
    # ``python -O``, and a non-square cost matrix would then silently yield a
    # truncated assignment.
    if ref_arch.shape != query_arch.shape:
        raise ValueError(
            "ref_arch and query_arch must have the same shape, "
            f"got {ref_arch.shape} and {query_arch.shape}"
        )
    euclidean_d = cdist(ref_arch, query_arch)
    # For a square cost matrix the returned row indices are exactly
    # 0..n-1 in order, so the column indices alone give, for each reference
    # row, the matched query row.
    _, query_idx = linear_sum_assignment(euclidean_d)
    return query_arch[query_idx, :]
def compute_rowwise_l2_distance(mtx_1: np.ndarray, mtx_2: np.ndarray) -> np.ndarray:
    """Compute the Euclidean (L2) distance between corresponding rows.

    Parameters
    ----------
    mtx_1 : np.ndarray
        First matrix, one observation per row.
    mtx_2 : np.ndarray
        Second matrix; must have exactly the same shape as ``mtx_1``.

    Returns
    -------
    np.ndarray
        For 2-D inputs, a 1-D array whose i-th entry is
        ``||mtx_1[i] - mtx_2[i]||_2`` (reduction over axis 1).

    Raises
    ------
    ValueError
        If ``mtx_1`` and ``mtx_2`` do not have the same shape.
    """
    # Explicit check instead of ``assert`` (stripped under ``python -O``):
    # without it, NumPy broadcasting would silently accept e.g. (n, d) vs (1, d).
    if mtx_1.shape != mtx_2.shape:
        raise ValueError(
            "mtx_1 and mtx_2 must have the same shape, "
            f"got {mtx_1.shape} and {mtx_2.shape}"
        )
    return np.sqrt(np.sum(np.square(mtx_1 - mtx_2), axis=1))
def compute_rowwise_l1_distance(mtx_1: np.ndarray, mtx_2: np.ndarray) -> np.ndarray:
    """Compute the Manhattan (L1) distance between corresponding rows.

    Parameters
    ----------
    mtx_1 : np.ndarray
        First matrix, one observation per row.
    mtx_2 : np.ndarray
        Second matrix; must have exactly the same shape as ``mtx_1``.

    Returns
    -------
    np.ndarray
        For 2-D inputs, a 1-D array whose i-th entry is
        ``sum(|mtx_1[i] - mtx_2[i]|)`` (reduction over axis 1).

    Raises
    ------
    ValueError
        If ``mtx_1`` and ``mtx_2`` do not have the same shape.
    """
    # Explicit check instead of ``assert`` (stripped under ``python -O``):
    # without it, NumPy broadcasting would silently accept e.g. (n, d) vs (1, d).
    if mtx_1.shape != mtx_2.shape:
        raise ValueError(
            "mtx_1 and mtx_2 must have the same shape, "
            f"got {mtx_1.shape} and {mtx_2.shape}"
        )
    return np.sum(np.abs(mtx_1 - mtx_2), axis=1)
def compute_relative_rowwise_l2_distance(mtx_1: np.ndarray, mtx_2: np.ndarray) -> np.ndarray:
    """Row-wise L2 distances divided by the spread of the rows of ``mtx_1``.

    The distance between ``mtx_1[i]`` and ``mtx_2[i]`` is scaled by the mean
    pairwise Euclidean distance among all rows of ``mtx_1``, so the result is
    expressed in units of that average spread.

    Parameters
    ----------
    mtx_1 : np.ndarray
        Reference matrix, one observation per row; also defines the scale.
    mtx_2 : np.ndarray
        Matrix compared row-by-row against ``mtx_1``.

    Returns
    -------
    np.ndarray
        The row-wise L2 distances divided by the mean pairwise distance.

    Notes
    -----
    With fewer than two rows in ``mtx_1`` there are no pairs, so the scale is
    the mean of an empty array (NaN, with a NumPy RuntimeWarning). If all rows
    coincide the scale is zero and the division produces inf/NaN.
    """
    # Distances first: the helper's shape check must run before ``pdist``.
    raw_distances = compute_rowwise_l2_distance(mtx_1, mtx_2)
    spread = pdist(mtx_1).mean()  # mean pairwise distance among mtx_1 rows
    return raw_distances / spread