Coverage for tests/unit/test_optimization.py: 100%
48 statements
coverage.py v7.6.12, created at 2025-02-11 20:12 +0100
1import pytest
2import numpy as np
3from scipy.optimize import linear_sum_assignment
5from ParTIpy.arch import AA
6from ParTIpy.generate_test_data import simulate
7from ParTIpy.const import OPTIM_ALGS, WEIGHT_ALGS, INIT_ALGS
def compute_dist_mtx(mtx_1, mtx_2):
    """Return pairwise Euclidean distances between the rows of two matrices.

    Uses the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, which is fully
    vectorised but subject to floating-point cancellation.

    Parameters
    ----------
    mtx_1 : ndarray of shape (n_1, d)
    mtx_2 : ndarray of shape (n_2, d)

    Returns
    -------
    ndarray of shape (n_2, n_1)
        Entry ``[j, i]`` is the distance between ``mtx_2[j]`` and ``mtx_1[i]``.
        Note the orientation: rows index ``mtx_2``; callers wanting
        ``(n_1, n_2)`` must transpose. Inputs are not modified.
    """
    cross = np.dot(mtx_1, mtx_2.T)  # (n_1, n_2)
    sq_norms_1 = np.sum(np.square(mtx_1), axis=1)  # (n_1,)
    sq_norms_2 = np.sum(np.square(mtx_2), axis=1)  # (n_2,)
    sq_dist = (sq_norms_2 - 2 * cross).T + sq_norms_1  # (n_2, n_1), fresh array
    # Cancellation in the expansion can leave tiny -- or, for large-magnitude
    # inputs, not-so-tiny -- negative values. Snap near-zero entries to 0 and
    # clip any remaining negatives so np.sqrt can never produce NaN.
    sq_dist[np.isclose(sq_dist, 0)] = 0
    np.maximum(sq_dist, 0, out=sq_dist)
    return np.sqrt(sq_dist)
def align_archetypes(ref_arch, query_arch):
    """Reorder the rows of ``query_arch`` to best match the rows of ``ref_arch``.

    Solves the linear assignment problem (minimum total Euclidean distance)
    between reference rows and query rows. Row ``k`` of the result is the
    query row assigned to the ``k``-th assigned reference row; for equally
    sized inputs every reference row is assigned, in order, so the result
    lines up one-to-one with ``ref_arch``.

    Returns a new array (fancy indexing copies); neither input is modified.
    """
    # compute_dist_mtx returns shape (n_query, n_ref) and does not mutate its
    # inputs, so no defensive copy is needed. Transpose so rows are reference
    # archetypes and columns are query archetypes.
    cost = compute_dist_mtx(ref_arch, query_arch).T
    _, query_idx = linear_sum_assignment(cost)
    return query_arch[query_idx, :]
def compute_rowwise_correlation(mtx_1, mtx_2):
    """Pearson correlation between corresponding rows of two same-shaped arrays.

    Each row is z-scored (population standard deviation) and the mean of the
    element-wise product is taken, giving one coefficient per row. The inputs
    themselves are left untouched. A constant row has zero spread and yields NaN.
    """
    assert mtx_1.shape == mtx_2.shape

    def _standardise(mtx):
        centred = mtx - mtx.mean(axis=1, keepdims=True)
        return centred / centred.std(axis=1, keepdims=True)

    return (_standardise(mtx_1) * _standardise(mtx_2)).mean(axis=1)
@pytest.mark.parametrize("n_archetypes", list(range(2, 8)))
@pytest.mark.parametrize("optim_str", OPTIM_ALGS)
def test_that_archetypes_can_be_identified(
    n_archetypes: int,
    optim_str: str,
) -> None:
    """On noise-free simulated data, fitted archetypes must recover the true ones.

    The fitted archetype rows are matched to the simulated ones by optimal
    assignment; every matched pair must then correlate above ``MIN_CORR``.
    """
    N_SAMPLES = 1_000
    N_DIMENSIONS = 10
    MIN_CORR = 0.9

    X, _A_true, Z_true = simulate(
        n_samples=N_SAMPLES,
        n_archetypes=n_archetypes,
        n_dimensions=N_DIMENSIONS,
        noise_std=0.0,
        seed=111,
    )

    model = AA(n_archetypes=n_archetypes, optim=optim_str)
    _A_hat, _B_hat, Z_hat, _rss, _varexpl = model.fit(X).return_all()

    Z_matched = align_archetypes(Z_true, Z_hat)
    corr = compute_rowwise_correlation(Z_true, Z_matched)
    assert (corr > MIN_CORR).all()
@pytest.mark.parametrize("optim_str", OPTIM_ALGS)
@pytest.mark.parametrize("weight_str", WEIGHT_ALGS)
@pytest.mark.parametrize("init_str", INIT_ALGS)
def test_that_input_to_AA_is_not_modfied(
    optim_str,
    weight_str,
    init_str,
) -> None:
    """Fitting AA must leave the caller's data matrix untouched.

    Runs every combination of optimiser, weighting and initialisation and
    compares ``X`` after fitting against a copy taken beforehand.
    (The misspelt test name is kept so existing test IDs stay stable.)
    """
    N_SAMPLES = 200
    N_DIMENSIONS = 3
    N_ARCHETYPES = 5
    X, _A, _Z = simulate(
        n_samples=N_SAMPLES,
        n_archetypes=N_ARCHETYPES,
        n_dimensions=N_DIMENSIONS,
        noise_std=0.0,
        seed=111,
    )
    X_in = X.copy()

    _A_hat, _B_hat, _Z_hat, _rss, _varexpl = AA(
        n_archetypes=N_ARCHETYPES,
        optim=optim_str,
        weight=weight_str,
        init=init_str,
    ).fit(X).return_all()

    # "Not modified" means identical, so compare exactly: np.isclose would
    # accept in-place changes that fall within its default rtol/atol. This
    # also verifies the shape and prints a diff on failure.
    np.testing.assert_array_equal(X, X_in)