Coverage for src/driada/information/info_utils.py: 48.78%
41 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-25 15:40 +0300
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-25 15:40 +0300
1import numpy as np
2from numba import njit
@njit()
def py_fast_digamma_arr(data):
    """Evaluate a fast digamma approximation for every element of ``data``.

    Each element is shifted upward with the recurrence
    ``psi(x) = psi(x + 1) - 1 / x`` until it exceeds 5; the remainder is
    then evaluated with an asymptotic series in ``1 / x**2``.

    Parameters
    ----------
    data : sequence of scalars
        Values to evaluate. The approximation assumes every value is > 0.

    Returns
    -------
    numpy.ndarray
        Float array of length ``len(data)`` with one result per element.
    """
    res = np.zeros(len(data))
    for i, x in enumerate(data):
        # Promote to float (as the scalar py_fast_digamma does with
        # ``x * 1.0``) so integer inputs cannot overflow in ``val * val``.
        val = x * 1.0
        r = 0.0
        # Recurrence: accumulate -1/x terms while moving the argument above 5,
        # where the asymptotic series below is accurate.
        while val <= 5:
            r -= 1 / val
            val += 1
        f = 1 / (val * val)
        # Horner evaluation of the series in f = 1 / val**2.
        t = f * (-1 / 12.0 + f * (1 / 120.0 + f * (-1 / 252.0 + f * (
            1 / 240.0 + f * (-1 / 132.0 + f * (691 / 32760.0 + f * (
                -1 / 12.0 + f * 3617 / 8160.0)))))))
        res[i] = r + np.log(val) - 0.5 / val + t
    return res
@njit()
def py_fast_digamma(x):
    """Fast approximation of the digamma function for a single value.

    The argument is pushed above 5 with the recurrence
    ``psi(x) = psi(x + 1) - 1 / x`` and the remainder is evaluated with an
    asymptotic series in ``1 / x**2``.

    Parameters
    ----------
    x : scalar
        Point at which to evaluate; promoted to float via ``x * 1.0``.

    Returns
    -------
    float
        Approximate digamma value at ``x``.
    """
    shift = 0
    val = x * 1.0
    # Recurrence step: collect the -1/x corrections while raising the argument.
    while val <= 5:
        shift -= 1 / val
        val += 1

    inv_sq = 1 / (val * val)

    # Horner scheme for the series, unrolled from the innermost term outward.
    tail = inv_sq * 3617 / 8160.0
    tail = inv_sq * (-1 / 12.0 + tail)
    tail = inv_sq * (691 / 32760.0 + tail)
    tail = inv_sq * (-1 / 132.0 + tail)
    tail = inv_sq * (1 / 240.0 + tail)
    tail = inv_sq * (-1 / 252.0 + tail)
    tail = inv_sq * (1 / 120.0 + tail)
    tail = inv_sq * (-1 / 12.0 + tail)

    return shift + np.log(val) - 0.5 / val + tail
def binary_mi_score(contingency):
    """Compute mutual information (in nats) from a 2-D contingency table.

    Parameters
    ----------
    contingency : numpy.ndarray
        2-D table of joint counts or weights; rows index one variable and
        columns the other. Only the non-zero cells contribute.

    Returns
    -------
    float
        Mutual information using the natural logarithm, clipped to be
        non-negative. Returns ``0.0`` when the table has a single row or
        column, or when it has no non-zero entries.
    """
    nzx, nzy = np.nonzero(contingency)
    nz_val = contingency[nzx, nzy]

    contingency_sum = contingency.sum()
    pi = np.ravel(contingency.sum(axis=1))
    pj = np.ravel(contingency.sum(axis=0))

    # Since MI <= min(H(X), H(Y)), any labelling with zero entropy, i.e.
    # containing a single cluster, implies MI = 0.
    # An all-zero table carries no mass either; returning early avoids
    # evaluating log(0) on the totals.
    if pi.size == 1 or pj.size == 1 or nz_val.size == 0:
        return 0.0

    log_contingency_nm = np.log(nz_val)
    contingency_nm = nz_val / contingency_sum
    # Don't need to calculate the full outer product, just for non-zeroes.
    # The product is formed in float64: this avoids overflow for narrow
    # integer dtypes and, unlike an integer cast, does not truncate
    # fractional (weighted / normalised) marginals to zero.
    outer = pi.take(nzx).astype(np.float64, copy=False) * pj.take(nzy).astype(
        np.float64, copy=False
    )
    log_outer = -np.log(outer) + np.log(pi.sum()) + np.log(pj.sum())
    mi = (
        contingency_nm * (log_contingency_nm - np.log(contingency_sum))
        + contingency_nm * log_outer
    )
    # Treat contributions smaller than machine epsilon as exact zeros.
    mi = np.where(np.abs(mi) < np.finfo(mi.dtype).eps, 0.0, mi)
    # Rounding can leave a tiny negative total; MI is non-negative.
    return np.clip(mi.sum(), 0.0, None)