Coverage for tests/unit/test_optimization.py: 100%

48 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-02-11 20:12 +0100

1import pytest 

2import numpy as np 

3from scipy.optimize import linear_sum_assignment 

4 

5from ParTIpy.arch import AA 

6from ParTIpy.generate_test_data import simulate 

7from ParTIpy.const import OPTIM_ALGS, WEIGHT_ALGS, INIT_ALGS 

8 

def compute_dist_mtx(mtx_1, mtx_2):
    """Return the pairwise Euclidean distances between rows of two matrices.

    Parameters
    ----------
    mtx_1 : array_like of shape (n_1, n_features)
    mtx_2 : array_like of shape (n_2, n_features)

    Returns
    -------
    ndarray of shape (n_2, n_1), dtype float64
        ``dist_mtx[j, i]`` is the distance between ``mtx_2[j]`` and
        ``mtx_1[i]``. Note the orientation: rows index ``mtx_2`` and columns
        index ``mtx_1``; callers transpose as needed. The inputs are not
        modified.
    """
    # Local import keeps this helper self-contained; scipy is already a
    # dependency of this test module (linear_sum_assignment).
    from scipy.spatial.distance import cdist

    # Distances are computed from coordinate differences rather than the
    # ||a||^2 + ||b||^2 - 2 a.b expansion: with large coordinates that
    # expansion suffers cancellation and can yield negative squared
    # distances (well beyond np.isclose's default tolerance), which np.sqrt
    # turns into NaN.
    return cdist(mtx_2, mtx_1, metric="euclidean")

19 

def align_archetypes(ref_arch, query_arch):
    """Reorder the rows of ``query_arch`` to best match the rows of ``ref_arch``.

    Each query row is paired with a reference row by solving the linear
    assignment problem on their pairwise Euclidean distances (minimising the
    summed distance over all pairs).

    Parameters
    ----------
    ref_arch : ndarray of shape (n_archetypes, n_features)
        Reference archetypes; their row order is kept.
    query_arch : ndarray of shape (n_archetypes, n_features)
        Archetypes to reorder.

    Returns
    -------
    ndarray
        Rows of ``query_arch`` permuted so that row ``i`` is the one matched
        to ``ref_arch[i]``.

    NOTE(review): assumes both inputs have the same number of rows. With
    unequal counts linear_sum_assignment matches only min(n_ref, n_query)
    pairs, so the result may not be row-aligned with ``ref_arch``.
    """
    # compute_dist_mtx returns an (n_query, n_ref) matrix; transpose it so
    # rows are reference archetypes and columns are query archetypes. It does
    # not modify its inputs, so no defensive copy is needed.
    cost = compute_dist_mtx(ref_arch, query_arch).T
    # For a square cost matrix the returned row indices are 0..n-1 in order,
    # so query_idx[i] is the query row assigned to ref_arch[i].
    _, query_idx = linear_sum_assignment(cost)
    return query_arch[query_idx, :]

25 

def compute_rowwise_correlation(mtx_1, mtx_2):
    """Return the Pearson correlation between corresponding rows of two matrices.

    Parameters
    ----------
    mtx_1, mtx_2 : ndarray of shape (n_rows, n_cols)
        Matrices with identical shapes. Neither input is modified.

    Returns
    -------
    ndarray of shape (n_rows,)
        ``corr[i]`` is the correlation between ``mtx_1[i]`` and ``mtx_2[i]``.
        A constant row has zero standard deviation, so its entry is NaN
        (numpy also emits a RuntimeWarning for the 0/0 division).

    Raises
    ------
    AssertionError
        If the two shapes differ.
    """
    assert mtx_1.shape == mtx_2.shape, (
        f"shape mismatch: {mtx_1.shape} vs {mtx_2.shape}"
    )

    def _zscore_rows(mtx):
        # Centre and scale each row (population std, ddof=0); returns a new
        # array so the caller's data is left untouched.
        centered = mtx - mtx.mean(axis=1, keepdims=True)
        return centered / centered.std(axis=1, keepdims=True)

    # With ddof=0 z-scores, the mean of their elementwise product is exactly
    # the Pearson correlation coefficient.
    return np.mean(_zscore_rows(mtx_1) * _zscore_rows(mtx_2), axis=1)

34 

@pytest.mark.parametrize("n_archetypes", list(range(2, 8)))
@pytest.mark.parametrize("optim_str", OPTIM_ALGS)
def test_that_archetypes_can_be_identified(
    n_archetypes: int,
    optim_str: str,
) -> None:
    """Check that AA recovers noise-free simulated archetypes.

    Data are simulated from known archetypes ``Z`` with ``noise_std=0.0``.
    After fitting, the estimated archetypes are aligned to ``Z`` by linear
    assignment, and every aligned pair must have a row-wise correlation
    above ``MIN_CORR``.
    """
    N_SAMPLES = 1_000
    N_DIMENSIONS = 10
    MIN_CORR = 0.9
    X, _, Z = simulate(
        n_samples=N_SAMPLES,
        n_archetypes=n_archetypes,
        n_dimensions=N_DIMENSIONS,
        noise_std=0.0,
        seed=111,
    )

    # return_all() yields five values (A_hat, B_hat, Z_hat, RSS, varexpl in
    # this file's naming); only the estimated archetypes are needed here.
    _, _, Z_hat, _, _ = (
        AA(n_archetypes=n_archetypes, optim=optim_str).fit(X).return_all()
    )

    # Estimated archetypes come back in arbitrary order, so match them to the
    # ground truth before comparing row by row.
    Z_hat = align_archetypes(Z, Z_hat)

    corr_between_archetypes = compute_rowwise_correlation(Z, Z_hat)
    assert np.all(corr_between_archetypes > MIN_CORR), (
        f"archetype correlations {corr_between_archetypes} "
        f"not all above {MIN_CORR}"
    )

57 

@pytest.mark.parametrize("optim_str", OPTIM_ALGS)
@pytest.mark.parametrize("weight_str", WEIGHT_ALGS)
@pytest.mark.parametrize("init_str", INIT_ALGS)
def test_that_input_to_AA_is_not_modfied(
    optim_str: str,
    weight_str: str,
    init_str: str,
) -> None:
    """Check that fitting AA leaves the input data matrix untouched.

    Runs over every combination of optimisation, weighting and
    initialisation algorithm. (The misspelling "modfied" in the name is kept
    unchanged so existing test selections by name keep working.)
    """
    N_SAMPLES = 200
    N_DIMENSIONS = 3
    N_ARCHETYPES = 5
    X, _, _ = simulate(
        n_samples=N_SAMPLES,
        n_archetypes=N_ARCHETYPES,
        n_dimensions=N_DIMENSIONS,
        noise_std=0.0,
        seed=111,
    )
    X_in = X.copy()

    # The results are irrelevant; return_all() is still called so this code
    # path is exercised for possible mutation of X as well.
    AA(
        n_archetypes=N_ARCHETYPES,
        optim=optim_str,
        weight=weight_str,
        init=init_str,
    ).fit(X).return_all()

    # "Not modified" means identical values and shape. A tolerance-based check
    # like np.isclose would let small in-place changes (e.g. centring or
    # rescaling by a tiny amount) go unnoticed.
    np.testing.assert_array_equal(X, X_in)