Coverage for src/driada/dimensionality/utils.py: 95.45%

44 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-25 15:40 +0300

1from scipy.spatial.distance import pdist, cdist 

2from scipy.sparse.csgraph import shortest_path 

3from scipy.stats import pearsonr, norm 

4from scipy.linalg import eigh 

5import numpy as np 

6import warnings 

7 

8 

def res_var_metric(all_dists, emb_dists):
    """
    Return one minus the squared Pearson correlation of two sequences.

    Parameters
    ----------
    all_dists : array_like
        First 1-D sequence of values (presumably distances measured in the
        original space; confirm against callers).
    emb_dists : array_like
        Second 1-D sequence of the same length (presumably the matching
        distances in the embedding; confirm against callers).

    Returns
    -------
    float
        ``1 - r**2`` where ``r`` is the Pearson correlation coefficient, i.e.
        the share of variance not explained by a linear relationship.
    """
    # pearsonr yields (statistic, p-value); only the statistic is needed.
    corr_coef, _p_value = pearsonr(all_dists, emb_dists)
    return 1 - corr_coef ** 2

12 

13 

def correct_cov_spectrum(N, T, cmat, correction_iters=10, ensemble_size=1, min_eigenvalue=1e-10):
    """
    Correct the spectrum of a covariance/correlation matrix.

    Each iteration has two phases. Phase 1 draws ``ensemble_size`` random
    ``N x T`` Gaussian matrices, builds a sample covariance whose population
    spectrum is the current estimate, and rescales the initial eigenvalues by
    ``sum(ratios) / sum(ratios**2)``, where ``ratios`` are sample eigenvalues
    divided by the current estimate. Phase 2 draws one more random sample,
    takes the eigenvectors of the resulting matrix and projects the initial
    eigenvalues through them to obtain the updated estimate.

    Parameters
    ----------
    N : int
        Number of variables (neurons). Must match the size of ``cmat``.
    T : int
        Number of time points.
    cmat : ndarray
        Covariance or correlation matrix.
    correction_iters : int, optional
        Number of correction iterations. Default is 10.
    ensemble_size : int, optional
        Size of the ensemble for phase 1. Default is 1.
    min_eigenvalue : float, optional
        Minimum eigenvalue threshold to avoid numerical issues. Default is 1e-10.

    Returns
    -------
    corrected_eigs : list
        List of ``correction_iters + 1`` eigenvalue arrays. The first entry
        holds the eigenvalues of ``cmat`` in ascending order (clipped to
        ``min_eigenvalue`` if any of them is negative); each following entry
        is the estimate produced by one iteration.

    Warns
    -----
    UserWarning
        If the smallest eigenvalue of ``cmat`` is below ``-1e-6``.

    Notes
    -----
    Random matrices come from ``scipy.stats.norm.rvs`` without an explicit
    ``random_state``, so results are reproducible only if the caller seeds
    NumPy's global random state beforehand.
    """
    eigs = eigh(cmat, eigvals_only=True)

    # Check for negative eigenvalues and clip them
    if np.any(eigs < 0):
        min_eig = np.min(eigs)
        if min_eig < -1e-6:  # Significant negative eigenvalue
            neg_fraction = np.sum(eigs < 0) / len(eigs)
            warnings.warn(
                f"Found significant negative eigenvalues (min={min_eig:.2e}). "
                f"{neg_fraction:.1%} of eigenvalues are negative. "
                "This may indicate numerical precision issues with the correlation matrix.",
                stacklevel=2,
            )
        eigs = np.maximum(eigs, min_eigenvalue)

    init_eigs = eigs.copy()
    iter_eigs = eigs.copy()
    corrected_eigs = [init_eigs]

    for _ in range(correction_iters):
        # phase 1
        # Ensure eigenvalues are positive before taking sqrt / dividing.
        # The estimate does not change inside the ensemble loop, so clip once.
        iter_eigs = np.maximum(iter_eigs, min_eigenvalue)
        row_scale = np.sqrt(iter_eigs)[:, None]

        all_ratios = np.zeros((ensemble_size, N))
        for j in range(ensemble_size):
            # Scaling the rows of the noise matrix equals left-multiplying by
            # diag(sqrt(iter_eigs)), without building a dense N x N matrix.
            sample = row_scale * norm.rvs(size=(N, T))
            # Only the eigenvalues are needed here; skip the eigenvectors.
            ps_eigs = eigh(sample @ sample.T / T, eigvals_only=True)
            # Clip ps_eigs to avoid division issues
            ps_eigs = np.maximum(ps_eigs, min_eigenvalue)
            all_ratios[j, :] = ps_eigs / iter_eigs

        s1 = np.sum(all_ratios, axis=0)
        s2 = np.sum(np.square(all_ratios), axis=0)

        # The eigenvalues of a diagonal matrix are its diagonal entries in
        # ascending order, so a sort replaces the dense eigendecomposition.
        iter_eigs = np.sort(init_eigs * (s1 / s2))

        # phase 2
        # Ensure eigenvalues are non-negative before taking sqrt
        iter_eigs = np.maximum(iter_eigs, min_eigenvalue)
        sample = np.sqrt(iter_eigs)[:, None] * norm.rvs(size=(N, T))
        _, V = eigh(sample @ sample.T / T)
        # diag(V @ diag(d) @ V.T)[i] == sum_k V[i, k]**2 * d[k]
        upd_eigs = np.square(V) @ init_eigs
        # Ensure updated eigenvalues are non-negative
        upd_eigs = np.maximum(upd_eigs, min_eigenvalue)

        corrected_eigs.append(upd_eigs)
        iter_eigs = upd_eigs

    return corrected_eigs