Coverage for src/driada/information/entropy_jit.py: 0.00%

43 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-25 15:40 +0300

1""" 

2JIT-compiled entropy calculation functions for performance optimization. 

3""" 

4 

5import numpy as np 

6from numba import njit 

7 

8 

@njit
def entropy_d_jit(x):
    """JIT-compiled discrete entropy calculation.

    Sorts a copy of the samples and measures the length of each run of
    equal values, so counting costs a single sort instead of comparing
    every sample against every distinct value.

    Parameters
    ----------
    x : array-like
        Discrete variable values (indexed as a 1-D array).

    Returns
    -------
    float
        Entropy in bits. An empty input yields 0.0.
    """
    ordered = np.sort(x)  # sorted copy; the caller's array is not modified
    n = ordered.size

    h = 0.0
    run = 0
    for i in range(n):
        run += 1
        # A run ends at the last sample or where the next value differs.
        if i == n - 1 or ordered[i + 1] != ordered[i]:
            # NaN never compares equal to itself; such samples are left out
            # of the counts but still contribute to the total n.
            if ordered[i] == ordered[i]:
                p = run / n
                h -= p * np.log2(p)
            run = 0

    return h

41 

42 

@njit
def joint_entropy_dd_jit(x, y):
    """JIT-compiled joint entropy for two discrete variables.

    Parameters
    ----------
    x : array-like
        First discrete variable.
    y : array-like
        Second discrete variable, sampled in lockstep with ``x``.

    Returns
    -------
    float
        Joint entropy H(X,Y) in bits. Empty inputs yield 0.0.

    Raises
    ------
    ValueError
        If ``x`` and ``y`` do not contain the same number of samples.
    """
    # Samples are paired by position, so differing lengths would either
    # index past the end of y or silently drop its tail.
    if x.size != y.size:
        raise ValueError("x and y must have the same number of samples")

    # Create joint distribution
    unique_x = np.unique(x)
    unique_y = np.unique(y)
    joint_counts = np.zeros((unique_x.size, unique_y.size))

    # Count joint occurrences
    for i in range(x.size):
        # -1 marks "no matching entry found"; a value that equals no entry
        # (NaN compares unequal to everything) keeps this sentinel.
        x_idx = -1
        y_idx = -1

        # Find indices
        for j in range(unique_x.size):
            if x[i] == unique_x[j]:
                x_idx = j
                break

        for j in range(unique_y.size):
            if y[i] == unique_y[j]:
                y_idx = j
                break

        # Only count pairs where both values were located.
        if x_idx >= 0 and y_idx >= 0:
            joint_counts[x_idx, y_idx] += 1

    # Calculate joint entropy
    total = x.size
    h = 0.0

    for i in range(unique_x.size):
        for j in range(unique_y.size):
            # Empty cells are skipped: they add nothing and log2(0) is -inf.
            if joint_counts[i, j] > 0:
                p = joint_counts[i, j] / total
                h -= p * np.log2(p)

    return h

93 return h