Coverage for src/driada/dimensionality/utils.py: 95.45%
44 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-25 15:40 +0300
1from scipy.spatial.distance import pdist, cdist
2from scipy.sparse.csgraph import shortest_path
3from scipy.stats import pearsonr, norm
4from scipy.linalg import eigh
5import numpy as np
6import warnings
def res_var_metric(all_dists, emb_dists):
    """Return the residual variance ``1 - r**2`` between two samples.

    ``r`` is the Pearson correlation coefficient of ``all_dists`` and
    ``emb_dists`` as returned by ``scipy.stats.pearsonr``. Judging by the
    argument names these are presumably flattened pairwise distances in the
    original space and in an embedding -- confirm against callers.
    """
    correlation = pearsonr(all_dists, emb_dists)[0]
    residual = 1 - correlation ** 2
    return residual
def _sample_scaled_gram(eigs, N, T, random_state=None):
    """Draw ``sqrt(D) @ M @ M.T @ sqrt(D) / T`` for ``D = diag(eigs)`` and Gaussian ``M`` of shape (N, T)."""
    M = norm.rvs(size=(N, T), random_state=random_state)
    # Scaling the rows of M is the same as left-multiplying by diag(sqrt(eigs)),
    # without materialising the dense N x N diagonal matrix.
    scaled = np.sqrt(eigs)[:, np.newaxis] * M
    return scaled @ scaled.T / T


def correct_cov_spectrum(N, T, cmat, correction_iters=10, ensemble_size=1,
                         min_eigenvalue=1e-10, random_state=None):
    """
    Correct the spectrum of a covariance/correlation matrix.

    Each iteration has two phases. Phase 1 draws ``ensemble_size`` random
    Gaussian matrices, compares the spectrum of the resulting sample matrix
    with the current estimate and rescales the initial eigenvalues by the
    accumulated ratios. Phase 2 draws one more random matrix and projects the
    initial eigenvalues through the eigenvectors of its sample matrix to get
    the updated estimate.

    Parameters
    ----------
    N : int
        Number of variables (neurons). Must equal the side length of ``cmat``.
    T : int
        Number of time points.
    cmat : ndarray
        Covariance or correlation matrix of shape (N, N).
    correction_iters : int, optional
        Number of correction iterations. Default is 10.
    ensemble_size : int, optional
        Number of random matrices drawn in phase 1 of every iteration.
        Default is 1.
    min_eigenvalue : float, optional
        Minimum eigenvalue threshold to avoid numerical issues. Every
        eigenvalue estimate is floored at this value. Default is 1e-10.
    random_state : None, int, numpy.random.Generator or numpy.random.RandomState, optional
        Source of randomness for the Gaussian draws. ``None`` (default) keeps
        the previous behaviour of using SciPy's global random state. An
        integer seeds a fresh generator once per call, so repeated calls with
        the same seed return identical results.

    Returns
    -------
    corrected_eigs : list
        ``correction_iters + 1`` eigenvalue arrays of length N: the initial
        (clipped) spectrum of ``cmat`` followed by one array per iteration.

    Raises
    ------
    ValueError
        If ``cmat`` is not a square 2-D matrix of shape (N, N).

    Warns
    -----
    UserWarning
        If the smallest eigenvalue of ``cmat`` is below -1e-6.
    """
    cmat = np.asarray(cmat)
    if cmat.ndim != 2 or cmat.shape[0] != cmat.shape[1]:
        raise ValueError(
            f"cmat must be a square 2-D matrix, got shape {cmat.shape}"
        )
    if cmat.shape[0] != N:
        raise ValueError(
            f"cmat has shape {cmat.shape} but N={N}; expected ({N}, {N})"
        )

    # An integer seed must become ONE generator shared by all draws; passing
    # the raw integer to every rvs() call would re-seed each time and yield
    # the same random matrix over and over.
    if isinstance(random_state, (int, np.integer)):
        rng = np.random.default_rng(random_state)
    else:
        rng = random_state

    eigs = eigh(cmat, eigvals_only=True)

    # Warn about clearly negative eigenvalues before flooring the spectrum.
    negative = eigs < 0
    if negative.any():
        neg_fraction = negative.sum() / len(eigs)
        min_eig = eigs.min()
        if min_eig < -1e-6:  # Significant negative eigenvalue
            warnings.warn(
                f"Found significant negative eigenvalues (min={min_eig:.2e}). "
                f"{neg_fraction:.1%} of eigenvalues are negative. "
                "This may indicate numerical precision issues with the correlation matrix.",
                stacklevel=2,
            )
    # Floor unconditionally so that tiny non-negative eigenvalues are treated
    # the same way whether or not a negative one happens to be present.
    eigs = np.maximum(eigs, min_eigenvalue)

    init_eigs = eigs.copy()
    iter_eigs = eigs.copy()
    corrected_eigs = [init_eigs]

    for _ in range(correction_iters):
        # phase 1
        # iter_eigs does not change inside the ensemble loop, so clip it once.
        iter_eigs = np.maximum(iter_eigs, min_eigenvalue)
        all_ratios = np.empty((ensemble_size, N))
        for j in range(ensemble_size):
            sample = _sample_scaled_gram(iter_eigs, N, T, rng)
            # Only the eigenvalues are needed here; skip the eigenvectors.
            ps_eigs = eigh(sample, eigvals_only=True)
            # Clip ps_eigs to avoid division issues
            ps_eigs = np.maximum(ps_eigs, min_eigenvalue)
            all_ratios[j, :] = ps_eigs / iter_eigs

        s1 = all_ratios.sum(axis=0)
        s2 = np.square(all_ratios).sum(axis=0)
        # The eigenvalues of diag(init_eigs) @ diag(s1 / s2) are just its
        # diagonal entries in ascending order, so sort instead of running a
        # dense eigendecomposition.
        iter_eigs = np.sort(init_eigs * (s1 / s2))

        # phase 2
        # Ensure eigenvalues are non-negative before taking sqrt
        iter_eigs = np.maximum(iter_eigs, min_eigenvalue)
        W = _sample_scaled_gram(iter_eigs, N, T, rng)
        _, V = eigh(W)
        # diag(V @ diag(init_eigs) @ V.T)[i] == sum_k V[i, k]**2 * init_eigs[k]
        upd_eigs = np.square(V) @ init_eigs
        # Ensure updated eigenvalues are non-negative
        upd_eigs = np.maximum(upd_eigs, min_eigenvalue)

        corrected_eigs.append(upd_eigs)
        iter_eigs = upd_eigs

    return corrected_eigs