Coverage for src/driada/information/info_utils.py: 48.78%

41 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-25 15:40 +0300

1import numpy as np 

2from numba import njit 

3 

@njit()
def py_fast_digamma_arr(data):
    """Evaluate a fast digamma approximation for every element of ``data``.

    Each argument is shifted upward with the recurrence
    ``psi(x) = psi(x + 1) - 1 / x`` until it exceeds 5; the asymptotic
    series in powers of ``1 / x**2`` is then applied to the shifted value.

    Parameters
    ----------
    data : one-dimensional array-like of numbers
        Arguments of the digamma function. The algorithm assumes every
        value is > 0.

    Returns
    -------
    numpy.ndarray
        Float array of length ``len(data)`` holding the digamma values.
    """
    res = np.zeros(len(data))
    for i, x in enumerate(data):
        # Work in floating point, as the scalar py_fast_digamma does.
        # Integer elements would otherwise be squared in fixed-width
        # integer arithmetic, which overflows for large values
        # (e.g. 2**32 squared wraps to 0 in int64).
        v = x * 1.0
        # Float accumulator keeps the variable's type stable.
        r = 0.0
        while v <= 5:
            r -= 1 / v
            v += 1
        f = 1 / (v * v)
        # Horner evaluation of the asymptotic series coefficients.
        t = f * (-1 / 12.0 + f * (1 / 120.0 + f * (-1 / 252.0 + f * (1 / 240.0 + f * (-1 / 132.0
            + f * (691 / 32760.0 + f * (-1 / 12.0 + f * 3617 / 8160.0)))))))

        res[i] = r + np.log(v) - 0.5 / v + t

    return res

20 

21 

@njit()
def py_fast_digamma(x):
    """Approximate the digamma function at a single point ``x``.

    The argument is converted to float and pushed above 5 with the
    recurrence ``psi(x) = psi(x + 1) - 1 / x``. The remaining value is
    evaluated with the asymptotic series in powers of ``1 / x**2``.

    Parameters
    ----------
    x : number
        Argument of the digamma function (the method assumes x > 0).

    Returns
    -------
    float
        Approximate value of digamma at ``x``.
    """
    arg = x * 1.0
    shift = 0.0

    # Recurrence step: accumulate -1/arg while moving arg past 5.
    while arg <= 5:
        shift = shift - 1 / arg
        arg = arg + 1

    inv_sq = 1 / (arg * arg)

    # Horner scheme, innermost coefficient first.
    poly = -1 / 12.0 + inv_sq * 3617 / 8160.0
    poly = 691 / 32760.0 + inv_sq * poly
    poly = -1 / 132.0 + inv_sq * poly
    poly = 1 / 240.0 + inv_sq * poly
    poly = -1 / 252.0 + inv_sq * poly
    poly = 1 / 120.0 + inv_sq * poly
    poly = -1 / 12.0 + inv_sq * poly
    series = inv_sq * poly

    return shift + np.log(arg) - 0.5 / arg + series

35 

36 

def binary_mi_score(contingency):
    """Compute mutual information (in nats) from a 2-D contingency table.

    Parameters
    ----------
    contingency : 2-D array-like
        Joint table of the two variables: rows index the values of one
        variable, columns the values of the other. Entries may be integer
        counts or non-integer weights/probabilities; they are assumed to
        be non-negative. Nested lists/tuples are converted to an ndarray.

    Returns
    -------
    float
        Mutual information computed with natural logarithms, clipped
        below at 0. Returns 0.0 when the table has a single row or
        column, or when it has no non-zero entries.
    """
    # Nested lists/tuples have no ``.sum``; promote them to an ndarray.
    # Other inputs are left untouched so objects that already provide
    # ``.sum``/indexing keep working as before.
    if isinstance(contingency, (list, tuple)):
        contingency = np.asarray(contingency)

    nzx, nzy = np.nonzero(contingency)
    nz_val = contingency[nzx, nzy]

    contingency_sum = contingency.sum()
    pi = np.ravel(contingency.sum(axis=1))
    pj = np.ravel(contingency.sum(axis=0))

    # Since MI <= min(H(X), H(Y)), any labelling with zero entropy, i.e.
    # containing a single cluster, implies MI = 0. A table without any
    # non-zero cell carries no information either; returning here avoids
    # evaluating log(0) further down.
    if pi.size == 1 or pj.size == 1 or nz_val.size == 0:
        return 0.0

    log_contingency_nm = np.log(nz_val)
    contingency_nm = nz_val / contingency_sum

    # Don't need to calculate the full outer product, just for non-zeroes.
    # Integer marginals are widened to int64 so the product does not
    # overflow a narrower integer dtype. Non-integer marginals stay in
    # float64, because an integer cast would truncate them (0.5 -> 0).
    wide = np.int64 if np.issubdtype(pi.dtype, np.integer) else np.float64
    outer = pi.take(nzx).astype(wide, copy=False) * pj.take(nzy).astype(
        wide, copy=False
    )
    log_outer = -np.log(outer) + np.log(pi.sum()) + np.log(pj.sum())
    mi = (
        contingency_nm * (log_contingency_nm - np.log(contingency_sum))
        + contingency_nm * log_outer
    )
    # Flush terms that are pure rounding noise to exactly zero.
    mi = np.where(np.abs(mi) < np.finfo(mi.dtype).eps, 0.0, mi)
    return np.clip(mi.sum(), 0.0, None)