Coverage for src/driada/information/ksg.py: 33.33%

87 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-25 15:40 +0300

1# TODO: Credit to Greg Ver Steeg (http://www.isi.edu/~gregv/npeet.html) 

2 

3import numpy as np 

4import numpy.linalg as la 

5from numpy import log 

6from sklearn.neighbors import BallTree, KDTree 

7 

8from .info_utils import py_fast_digamma 

9 

DEFAULT_NN = 5  # default neighbour order k for the estimators in this module

11# UTILITY FUNCTIONS 

12 

13#TODO: add automatic alpha selection for LNC correction from https://github.com/BiuBiuBiLL/NPEET_LNC 

14 

def add_noise(x, ampl=1e-10):
    """Return ``x`` plus a tiny uniform jitter that breaks ties between samples.

    Identical samples make nearest-neighbour distances degenerate, so every
    element receives an independent draw from ``[0, ampl)`` taken from
    NumPy's global random state.

    Parameters
    ----------
    x : numpy.ndarray
        Input array; its ``shape`` determines the shape of the noise.
    ampl : float
        Width of the uniform noise interval.

    Returns
    -------
    numpy.ndarray
        A new array equal to ``x + ampl * U[0, 1)``, with the same shape as ``x``.
    """
    jitter = np.random.random_sample(x.shape)
    return x + jitter * ampl

18 

19 

def query_neighbors(tree, x, k):
    """Distance from each row of ``x`` to its ``k``-th nearest neighbour.

    The tree is asked for ``k + 1`` neighbours. When ``x`` contains the points
    the tree was built on, the closest hit (column 0) is each point itself, so
    column ``k`` is the k-th *other* point.

    Parameters
    ----------
    tree : neighbour-search tree exposing ``query(x, k=...)`` that returns
        ``(distances, indices)``.
    x : numpy.ndarray
        Query points, one per row.
    k : int
        Neighbour order.

    Returns
    -------
    numpy.ndarray
        One distance per query row.
    """
    result = tree.query(x, k=k + 1)
    distances = result[0]
    return distances[:, k]

23 

24 

def _count_neighbors_single(tree, x, radii, ind):
    """Count neighbours of ``x[ind]`` within ``radii[ind]``, excluding the point itself.

    Legacy single-point counterpart of ``count_neighbors``. It is written
    against the ``scipy.spatial.cKDTree.query`` signature
    (``distance_upper_bound`` keyword); the sklearn trees returned by
    ``build_tree`` do not accept that keyword. Only ``DEFAULT_NN``
    neighbours are requested, so the result never exceeds ``DEFAULT_NN - 1``.

    Parameters
    ----------
    tree : scipy-style KD-tree over the data points.
    x : numpy.ndarray
        Query points, one per row.
    radii : sequence of float
        Per-point search radius.
    ind : int
        Row of ``x`` to query.

    Returns
    -------
    int
        Number of in-radius neighbours other than the query point. This
        assumes the query point itself is found in the tree.
    """
    dists, _ = tree.query(x[ind:ind + 1], k=DEFAULT_NN, distance_upper_bound=radii[ind])
    # cKDTree pads slots it cannot fill within the bound with an infinite
    # distance (and index ``tree.n``). Counting finite distances is right
    # whether or not any slot was padded. Counting unique indices and always
    # subtracting one for a padding index is one short when every slot holds
    # a real neighbour.
    n_found = int(np.count_nonzero(np.isfinite(dists[0])))
    return n_found - 1  # drop the query point itself

28 

29 

def count_neighbors(tree, x, radii):
    """Count, for each row of ``x``, the tree points within that row's radius.

    Thin wrapper over ``tree.query_radius(x, radii, count_only=True)`` (the
    sklearn ``KDTree``/``BallTree`` API). ``radii`` is passed through
    unchanged; per sklearn it may be a scalar or one radius per query row.

    Returns
    -------
    numpy.ndarray
        Integer neighbour counts, one per row of ``x``.
    """
    neighbour_counts = tree.query_radius(x, radii, count_only=True)
    return neighbour_counts

36 

37 

def build_tree(points, lf=5):
    """Build a Chebyshev (max-norm) neighbour-search tree over ``points``.

    With fewer than 20 columns a ``KDTree`` with ``leaf_size=lf`` is returned.
    From 20 columns upward a ``BallTree`` is used instead. In that case ``lf``
    is not forwarded, so sklearn's default leaf size applies.

    Parameters
    ----------
    points : numpy.ndarray
        Data matrix, one sample per row.
    lf : int
        Leaf size for the KD-tree branch.
    """
    n_dims = points.shape[1]
    if n_dims < 20:
        return KDTree(points, metric='chebyshev', leaf_size=lf)
    return BallTree(points, metric='chebyshev')

45 

46 

def avgdigamma(points, dvec, lf=30, tree=None):
    """Mean digamma of per-point neighbour counts, ``<psi(n)>``.

    For every row of ``points``, count the tree points lying within that
    row's radius from ``dvec``. This is typically a marginal space, with the
    radii measured in the joint space. The function then averages
    ``py_fast_digamma`` over those counts.

    Parameters
    ----------
    points : numpy.ndarray
        Query points, one sample per row.
    dvec : numpy.ndarray
        Per-sample search radius, e.g. the k-th-neighbour distance in the
        joint space.
    lf : int
        Leaf size, used only when ``tree`` must be built here.
    tree : optional
        Prebuilt neighbour tree; built over ``points`` with ``build_tree``
        when None.

    Returns
    -------
    float
        Mean of ``py_fast_digamma`` over the neighbour counts.

    Raises
    ------
    ValueError
        If ``points`` is empty, or if more than 1% of the points have no
        neighbour inside their radius. ``ValueError`` subclasses
        ``Exception``, so existing ``except Exception`` handlers still apply.
    """
    if tree is None:
        tree = build_tree(points, lf=lf)

    # Shrink every radius by a hair, presumably so that points lying exactly
    # on the boundary are not counted (strict inequality).
    radii = dvec - 1e-15
    counts = count_neighbors(tree, points, radii).astype(float)
    # NOTE(review): sklearn's query_radius counts points at distance <= r, so
    # when the tree holds ``points`` each count includes the query point
    # itself — confirm for custom trees.

    if len(counts) == 0:
        raise ValueError('avgdigamma requires at least one point')

    empty = counts == 0
    if np.count_nonzero(empty) / len(counts) > 0.01:
        raise ValueError('No neighbours in more than 1% points, check input!')
    # A few empty neighbourhoods are tolerated; psi has a pole at 0, so use 0.5.
    counts[empty] = 0.5

    return np.mean([py_fast_digamma(c) for c in counts])

69 

70 

71# CONTINUOUS ESTIMATORS 

72 

def nonparam_entropy_c(x, k=DEFAULT_NN, base=np.e):
    """Kozachenko-Leonenko k-nearest-neighbour estimate of the entropy H(X).

    With r_i the max-norm distance from sample i to its k-th neighbour,
    H = psi(N) - psi(k) + d*log(2) + d*<log r_i>, then converted to ``base``.

    Parameters
    ----------
    x : array_like
        Samples. 1-D input is treated as N scalar samples; 2-D input is
        treated as (N, d).
    k : int
        Neighbour order.
    base : float
        Logarithm base of the result (natural log by default).

    Returns
    -------
    float
    """
    samples = np.asarray(x)
    if samples.ndim == 1:
        samples = samples[:, np.newaxis]
    n_samples, n_dims = samples.shape

    # Jitter breaks ties between identical samples (degenerate distances).
    samples = add_noise(samples)
    kth_dist = query_neighbors(build_tree(samples), samples, k)

    offset = py_fast_digamma(n_samples) - py_fast_digamma(k) + n_dims * log(2)
    return (offset + n_dims * np.mean(np.log(kth_dist))) / log(base)

87 

88 

def nonparam_cond_entropy_cc(x, y, k=DEFAULT_NN, base=np.e):
    """k-NN estimate of the conditional entropy H(X | Y) = H(X, Y) - H(Y).

    Both terms come from ``nonparam_entropy_c``, using the same ``k`` and
    ``base``. The joint term is evaluated first.

    Parameters
    ----------
    x, y : array_like
        Samples aligned along the first axis.
    k : int
        Neighbour order.
    base : float
        Logarithm base of the result.
    """
    joint = np.c_[x, y]
    h_joint = nonparam_entropy_c(joint, k=k, base=base)
    h_cond = nonparam_entropy_c(y, k=k, base=base)
    return h_joint - h_cond

97 

98 

def nonparam_mi_cc(x, y, z=None, k=DEFAULT_NN, base=np.e, alpha=0,
                   lf=5, precomputed_tree_x=None, precomputed_tree_y=None):
    """KSG estimate of the mutual information I(X; Y), or I(X; Y | Z) if ``z`` is given.

    Radii are the distances to the ``k``-th neighbour in the joint
    (x, y[, z]) space. Marginal neighbour counts within those radii are
    averaged through ``avgdigamma``.

    Parameters
    ----------
    x, y : array_like
        Samples. The first axis is the sample index; any remaining axes are
        flattened into features. Tiny noise is added via ``add_noise``.
    z : array_like, optional
        Conditioning variable. It is reshaped like ``x``/``y`` but not
        jittered.
    k : int
        Neighbour order; must be <= len(x) - 1.
    base : float
        Logarithm base of the result (natural log by default).
    alpha : float
        If > 0 and ``z`` is None, ``lnc_correction`` is added. Ignored for
        the conditional estimate.
    lf : int
        Leaf size forwarded to every tree built by this function.
    precomputed_tree_x, precomputed_tree_y : optional
        Prebuilt trees for the x / y marginals (unconditional case only).
        NOTE(review): they are queried with the jittered x / y, so they
        should be built on equivalent data — confirm with callers.

    Returns
    -------
    float
    """
    assert len(x) == len(y), "Arrays should have same length"
    assert k <= len(x) - 1, "Set k smaller than num. samples - 1"

    x, y = np.asarray(x), np.asarray(y)
    x, y = x.reshape(x.shape[0], -1), y.reshape(y.shape[0], -1)
    x = add_noise(x)
    y = add_noise(y)

    blocks = [x, y]
    if z is not None:
        z = np.asarray(z)
        z = z.reshape(z.shape[0], -1)
        blocks.append(z)
    points = np.hstack(blocks)

    # k-th neighbour distance in the joint space (build_tree uses the Chebyshev metric).
    tree = build_tree(points, lf=lf)
    dvec = query_neighbors(tree, points, k)

    if z is None:
        a = avgdigamma(x, dvec, tree=precomputed_tree_x, lf=lf)
        b = avgdigamma(y, dvec, tree=precomputed_tree_y, lf=lf)
        c = py_fast_digamma(k)
        d = py_fast_digamma(len(x))
        if alpha > 0:
            d += lnc_correction(tree, points, k, alpha)
    else:
        xz = np.c_[x, z]
        yz = np.c_[y, z]
        # Forward lf here too, so the caller's leaf-size choice applies to every tree.
        a = avgdigamma(xz, dvec, lf=lf)
        b = avgdigamma(yz, dvec, lf=lf)
        c = avgdigamma(z, dvec, lf=lf)
        d = py_fast_digamma(k)

    return (-a - b + c + d) / log(base)

142 

143 

def lnc_correction(tree, points, k, alpha):
    """Local non-uniformity correction (LNC) term for the KSG MI estimate.

    For every sample, its ``k + 1`` nearest neighbours (itself included) are
    taken as offsets from the first returned neighbour. Two boxes bounding
    these offsets are compared:

    - the PCA-aligned box, whose axes are the eigenvectors of their scatter
      matrix;
    - the axis-aligned box.

    When the PCA box's volume is below ``alpha`` times the axis-aligned
    volume, the log-volume difference, divided by the number of samples, is
    added to the result.

    Parameters
    ----------
    tree : neighbour tree over ``points`` exposing an sklearn-style
        ``query(X, k=..., return_distance=False)``.
    points : numpy.ndarray
        Joint-space samples, one per row.
    k : int
        Neighbour order.
    alpha : float
        Non-uniformity threshold. Values <= 0 never trigger a correction.

    Returns
    -------
    float
        Correction to add to the estimate (0 if no neighbourhood qualifies).
    """
    n_sample = points.shape[0]
    if n_sample == 0 or alpha <= 0:
        # log(alpha) would be -inf or nan, so the threshold test could never pass.
        return 0

    log_alpha = np.log(alpha)
    # One batched query instead of one tree.query call per point.
    knn_all = tree.query(points, k=k + 1, return_distance=False)

    e = 0
    for knn in knn_all:
        # Offsets relative to the first neighbour (normally the query point
        # itself), not the neighbourhood mean.
        local = points[knn] - points[knn[0]]
        scatter = local.T @ local / k
        # The scatter matrix is symmetric: eigh yields real, orthonormal
        # eigenvectors, while the general solver eig guarantees neither.
        _, eigvecs = la.eigh(scatter)
        log_pca_box = np.log(np.abs(local @ eigvecs).max(axis=0)).sum()
        log_axis_box = np.log(np.abs(local).max(axis=0)).sum()

        # Local non-uniformity check.
        if log_pca_box < log_axis_box + log_alpha:
            e += (log_axis_box - log_pca_box) / n_sample
    return e