Coverage for src/driada/information/ksg.py: 33.33%
87 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-25 15:40 +0300
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-25 15:40 +0300
1# TODO: Credit to Greg Ver Steeg (http://www.isi.edu/~gregv/npeet.html)
3import numpy as np
4import numpy.linalg as la
5from numpy import log
6from sklearn.neighbors import BallTree, KDTree
8from .info_utils import py_fast_digamma
# Default number of nearest neighbours (k) used by the estimators in this module.
DEFAULT_NN = 5
11# UTILITY FUNCTIONS
13#TODO: add automatic alpha selection for LNC correction from https://github.com/BiuBiuBiLL/NPEET_LNC
def add_noise(x, ampl=1e-10):
    """Return ``x`` plus a small random perturbation.

    Every element is shifted by an independent draw from
    ``np.random.random_sample`` scaled by ``ampl``. The jitter is meant to
    break degeneracy between identical samples.

    Parameters
    ----------
    x : numpy.ndarray
        Input samples; only ``x.shape`` and element-wise addition are used.
    ampl : float, optional
        Scale applied to the random draws.

    Returns
    -------
    numpy.ndarray
        A new array of the same shape as ``x``; ``x`` itself is not modified.
    """
    jitter = np.random.random_sample(x.shape)
    return x + ampl * jitter
def query_neighbors(tree, x, k):
    """Return the distance from each row of ``x`` to its k-th neighbour.

    ``k + 1`` neighbours are requested from ``tree.query`` and column ``k`` of
    the returned distance matrix is kept.  When ``x`` holds the same points
    the tree was built on, column 0 is presumably the zero-distance
    self-match, which this indexing skips.

    Parameters
    ----------
    tree : object
        Spatial index exposing ``query(x, k=...)`` whose first return value is
        a 2-D array of neighbour distances.
    x : numpy.ndarray
        Query points, one per row.
    k : int
        Neighbour rank to report.

    Returns
    -------
    numpy.ndarray
        1-D array with one distance per query point.
    """
    distances = tree.query(x, k=k + 1)[0]
    kth_distance = distances[:, k]
    return kth_distance
def _count_neighbors_single(tree, x, radii, ind):
    """Count distinct neighbour indices found for the single point ``x[ind]``.

    Queries ``tree`` for up to ``DEFAULT_NN`` neighbours of ``x[ind]`` bounded
    by ``radii[ind]`` and returns the number of distinct indices minus 2.

    NOTE(review): the ``- 2`` presumably discounts the point itself and a
    "not found" filler index -- confirm.
    NOTE(review): ``distance_upper_bound`` is a keyword of the scipy
    ``KDTree.query`` API; the sklearn trees produced by ``build_tree`` do not
    appear to accept it.  This helper is only referenced from commented-out
    code in ``count_neighbors`` -- confirm before reviving it.
    """
    query_point = x[ind:ind + 1]
    _, indices = tree.query(
        query_point,
        k=DEFAULT_NN,
        distance_upper_bound=radii[ind],
    )
    distinct = np.unique(indices[0])
    return len(distinct) - 2
def count_neighbors(tree, x, radii):
    """Return, for each row of ``x``, how many tree points lie within its radius.

    Thin wrapper around ``tree.query_radius(..., count_only=True)``.

    Parameters
    ----------
    tree : object
        Spatial index exposing ``query_radius``.
    x : numpy.ndarray
        Query points, one per row.
    radii : numpy.ndarray or float
        Search radius (or one radius per query point), passed straight
        through to ``query_radius``.

    Returns
    -------
    numpy.ndarray
        Whatever ``query_radius`` returns in ``count_only`` mode (one count
        per query point).
    """
    counts = tree.query_radius(x, radii, count_only=True)
    return counts
def build_tree(points, lf=5):
    """Build a Chebyshev-metric neighbour-search tree over ``points``.

    A ``KDTree`` is used for fewer than 20 columns and a ``BallTree`` for
    20 or more.

    Parameters
    ----------
    points : numpy.ndarray
        2-D array, one sample per row (``points.shape[1]`` is read).
    lf : int, optional
        Leaf size handed to ``KDTree``.  It is not forwarded to ``BallTree``,
        which therefore uses its own default leaf size.

    Returns
    -------
    KDTree or BallTree
        The constructed tree.
    """
    n_features = points.shape[1]
    if n_features < 20:
        return KDTree(points, metric='chebyshev', leaf_size=lf)
    return BallTree(points, metric='chebyshev')
def avgdigamma(points, dvec, lf=30, tree=None):
    """Return the mean digamma of per-point neighbour counts in a marginal space.

    For every row of ``points`` the number of tree points within radius
    ``dvec[i] - 1e-15`` is obtained via ``count_neighbors``; the result is the
    average of ``py_fast_digamma`` over those counts, i.e. <psi(n_x)>.

    Parameters
    ----------
    points : numpy.ndarray
        2-D array of samples in the marginal space, one per row.
    dvec : numpy.ndarray
        Per-point search radii (not modified; a shrunken copy is used).
    lf : int, optional
        Leaf size used only when a tree has to be built here.
    tree : object, optional
        Pre-built search tree; when ``None`` one is built with ``build_tree``.

    Returns
    -------
    float
        Mean of the digamma values.

    Raises
    ------
    Exception
        If more than 1% of the points have a neighbour count of zero.
    """
    search_tree = build_tree(points, lf=lf) if tree is None else tree

    # Shrink every radius by a tiny margin -- presumably so that points lying
    # exactly at the original radius are not counted.
    radii = dvec - 1e-15
    counts = count_neighbors(search_tree, points, radii).astype(float)

    empty = np.where(counts == 0)[0]
    if 1.0 * len(empty) / len(counts) > 0.01:
        raise Exception('No neighbours in more than 1% points, check input!')
    if len(empty) != 0:
        # A handful of zero counts is tolerated: they are replaced by 0.5,
        # presumably to keep the digamma finite.
        counts[empty] = 0.5

    return np.mean([py_fast_digamma(count) for count in counts])
71# CONTINUOUS ESTIMATORS
def nonparam_entropy_c(x, k=DEFAULT_NN, base=np.e):
    """Kozachenko-Leonenko (K-L) k-nearest-neighbour entropy estimate.

    Parameters
    ----------
    x : array-like
        Samples, either 1-D (treated as a single feature) or 2-D with one
        sample per row.
    k : int, optional
        Neighbour rank used for the distance statistic.
    base : float, optional
        Logarithm base of the returned entropy.

    Returns
    -------
    float
        ``(psi(n) - psi(k) + d*log(2) + d*mean(log(r_k))) / log(base)`` where
        ``n`` is the number of samples, ``d`` the number of features and
        ``r_k`` the distance to the k-th neighbour.
    """
    samples = np.asarray(x)
    if samples.ndim == 1:
        samples = samples.reshape(-1, 1)
    n_elements, n_features = samples.shape

    # Jitter first so that identical samples do not share a zero distance.
    samples = add_noise(samples)
    kth_dist = query_neighbors(build_tree(samples), samples, k)
    mean_log_dist = np.log(kth_dist).mean()

    const = py_fast_digamma(n_elements) - py_fast_digamma(k) + n_features * log(2)
    return (const + n_features * mean_log_dist) / log(base)
def nonparam_cond_entropy_cc(x, y, k=DEFAULT_NN, base=np.e):
    """Estimate the entropy of X conditioned on Y as H(X, Y) - H(Y).

    Both terms are computed with ``nonparam_entropy_c`` using the same ``k``
    and ``base``.

    Parameters
    ----------
    x, y : array-like
        Samples of X and Y; they are column-stacked with ``np.c_`` to form
        the joint variable.
    k : int, optional
        Neighbour rank forwarded to ``nonparam_entropy_c``.
    base : float, optional
        Logarithm base of the result.

    Returns
    -------
    float
        The conditional entropy estimate.
    """
    joint = np.c_[x, y]
    # The joint term is evaluated first, then the marginal one.
    h_joint = nonparam_entropy_c(joint, k=k, base=base)
    h_y = nonparam_entropy_c(y, k=k, base=base)
    return h_joint - h_y
def nonparam_mi_cc(x, y, z=None, k=DEFAULT_NN, base=np.e, alpha=0,
                   lf=5, precomputed_tree_x=None, precomputed_tree_y=None):
    """k-nearest-neighbour estimate of the mutual information of x and y.

    When ``z`` is given, the estimate is conditioned on ``z``.

    Parameters
    ----------
    x, y : array-like
        Samples with equal length along the first axis; each is reshaped to
        ``(n_samples, -1)``.
    z : array-like, optional
        Conditioning samples, reshaped the same way.  No noise is added to z.
    k : int, optional
        Neighbour rank; must satisfy ``k <= len(x) - 1``.
    base : float, optional
        Logarithm base of the result.
    alpha : float, optional
        When positive (and ``z`` is None) the result of ``lnc_correction`` is
        added.  Ignored when ``z`` is given.
    lf : int, optional
        Leaf size forwarded to every tree built by this call.
    precomputed_tree_x, precomputed_tree_y : object, optional
        Pre-built search trees for the x and y marginals, used instead of
        building new ones.  Only used when ``z`` is None.

    Returns
    -------
    float
        ``(-a - b + c + d) / log(base)`` with the digamma terms described in
        the body.

    Raises
    ------
    AssertionError
        If x and y differ in length or ``k`` is too large.
    """
    assert len(x) == len(y), "Arrays should have same length"
    assert k <= len(x) - 1, "Set k smaller than num. samples - 1"

    x, y = np.asarray(x), np.asarray(y)
    x, y = x.reshape(x.shape[0], -1), y.reshape(y.shape[0], -1)
    # Small jitter breaks ties between identical samples.
    x = add_noise(x)
    y = add_noise(y)

    points = [x, y]
    if z is not None:
        z = np.asarray(z)
        z = z.reshape(z.shape[0], -1)
        points.append(z)
    points = np.hstack(points)

    # Distance to the k-th neighbour in the joint space (max-norm tree).
    tree = build_tree(points, lf=lf)
    dvec = query_neighbors(tree, points, k)

    if z is None:
        # psi(k) + psi(n) - <psi(n_x)> - <psi(n_y)>
        a = avgdigamma(x, dvec, tree=precomputed_tree_x, lf=lf)
        b = avgdigamma(y, dvec, tree=precomputed_tree_y, lf=lf)
        c = py_fast_digamma(k)
        d = py_fast_digamma(len(x))

        if alpha > 0:
            d += lnc_correction(tree, points, k, alpha)
    else:
        # psi(k) + <psi(n_z)> - <psi(n_xz)> - <psi(n_yz)>
        xz = np.c_[x, z]
        yz = np.c_[y, z]
        # Forward lf here as well so the caller's leaf size is honoured in
        # both branches; leaf size affects tree layout, not query results.
        a = avgdigamma(xz, dvec, lf=lf)
        b = avgdigamma(yz, dvec, lf=lf)
        c = avgdigamma(z, dvec, lf=lf)
        d = py_fast_digamma(k)

    return (-a - b + c + d) / log(base)
def lnc_correction(tree, points, k, alpha):
    """Return the local non-uniformity (LNC) correction term.

    For every point, its ``k + 1`` nearest neighbours (from ``tree``) are
    shifted so the first returned neighbour sits at the origin.  The log
    volume of the box aligned with the eigenvectors of the neighbours' second
    moment matrix is compared with the log volume of the axis-aligned box;
    where the former is smaller by more than ``log(alpha)``, the difference
    divided by the number of samples is accumulated.

    Parameters
    ----------
    tree : object
        Search tree built over ``points``; must expose
        ``query(X, k=..., return_distance=False)``.
    points : numpy.ndarray
        2-D array of joint-space samples, one per row.
    k : int
        Number of neighbours (excluding the first returned one).
    alpha : float
        Non-uniformity threshold; its logarithm is taken, so it should be
        positive.

    Returns
    -------
    float
        Accumulated correction (0 when no point triggers it).
    """
    n_sample = points.shape[0]
    log_alpha = np.log(alpha)

    # One batched query for all points instead of one tree call per point.
    all_knn = tree.query(points, k=k + 1, return_distance=False)

    e = 0
    for knn in all_knn:
        knn_points = points[knn]
        # Shift so the first returned neighbour (presumably the query point
        # itself, at distance zero) is the origin.
        knn_points = knn_points - knn_points[0]
        # Second-moment matrix of the shifted neighbours and its eigenvectors.
        covr = knn_points.T @ knn_points / k
        _, v = la.eig(covr)
        # Log volume of the bounding box in the eigenvector basis.
        V_rect = np.log(np.abs(knn_points @ v).max(axis=0)).sum()
        # Log volume of the axis-aligned bounding box.
        log_knn_dist = np.log(np.abs(knn_points).max(axis=0)).sum()

        # Apply the correction only where the rotated box is sufficiently
        # smaller than the axis-aligned one.
        if V_rect < log_knn_dist + log_alpha:
            e += (log_knn_dist - V_rect) / n_sample
    return e