Coverage for src/driada/information/gcmi_jit_utils.py: 10.14%
207 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-25 15:40 +0300
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-25 15:40 +0300
1"""
2JIT-compiled copula transformation functions for GCMI.
3"""
5import numpy as np
6from numba import njit
7from scipy.special import ndtri
@njit
def ctransform_jit(x):
    """JIT-compiled copula transformation (empirical CDF).

    Ranks the samples with one stable O(n log n) sort and rescales the
    1-based ranks by ``n + 1`` so that every output lies strictly inside
    the open interval (0, 1).

    Parameters
    ----------
    x : ndarray
        1D array of values to transform.

    Returns
    -------
    ndarray
        float64 array of ``rank / (n + 1)`` values, same length as ``x``.
        Tied values receive distinct consecutive ranks; among equal values
        the one with the smaller index gets the smaller rank.
    """
    n = x.size

    # A stable sort keeps equal values in index order, which makes the
    # tie-breaking deterministic (the default quicksort gives no such
    # guarantee).
    order = np.argsort(x, kind='mergesort')

    # Position `pos` in the sorted order is rank `pos + 1`; write the scaled
    # rank straight into the output slot of the corresponding sample.
    result = np.empty(n, dtype=np.float64)
    scale = n + 1.0
    for pos in range(n):
        result[order[pos]] = (pos + 1) / scale

    return result
@njit
def ctransform_2d_jit(x):
    """JIT-compiled copula transformation for 2D arrays.

    Applies ``ctransform_jit`` to every row independently.

    Parameters
    ----------
    x : ndarray
        2D array where each row is transformed independently.

    Returns
    -------
    ndarray
        float64 array with the same shape as ``x`` holding the
        copula-transformed rows.
    """
    n_vars, n_samples = x.shape

    # Always allocate float64: the transformed values lie in (0, 1), so an
    # output buffer that inherited an integer dtype from `x` would truncate
    # every value to 0.
    result = np.empty((n_vars, n_samples), dtype=np.float64)

    for row in range(n_vars):
        result[row, :] = ctransform_jit(x[row, :])

    return result
@njit
def _ndtri_upper_tail(q):
    """Return z >= 0 whose upper-tail normal probability is roughly ``q``.

    Rational approximation from Abramowitz & Stegun, formula 26.2.23
    (quoted absolute error below 4.5e-4). Intended for 0 < q <= 0.5.
    """
    t = np.sqrt(-2.0 * np.log(q))
    numerator = 2.515517 + 0.802853 * t + 0.010328 * t * t
    denominator = 1.0 + 1.432788 * t + 0.189269 * t * t + 0.001308 * t * t * t
    return t - numerator / denominator


@njit
def ndtri_approx(p):
    """Approximate inverse normal CDF for JIT compilation.

    Evaluates a closed-form rational approximation (Abramowitz & Stegun
    26.2.23) on the tail probability ``min(p, 1 - p)`` and restores the sign
    by symmetry. It is much less accurate than ``scipy.special.ndtri`` but
    needs only ``log`` and ``sqrt``.

    Parameters
    ----------
    p : float or ndarray
        Probability values in (0, 1).

    Returns
    -------
    float or ndarray
        Approximate quantile values. Inputs ``<= 0`` map to ``-inf`` and
        inputs ``>= 1`` map to ``+inf``. Array input yields an array created
        with ``np.empty_like(p)`` (same shape and dtype as ``p``).
    """
    if hasattr(p, 'shape'):
        # Array input: walk the flattened view so any dimensionality works.
        result = np.empty_like(p)
        for i in range(p.size):
            value = p.flat[i]
            if value <= 0.0:
                result.flat[i] = -np.inf
            elif value >= 1.0:
                result.flat[i] = np.inf
            elif value < 0.5:
                # Lower half: mirror the upper-tail quantile.
                result.flat[i] = -_ndtri_upper_tail(value)
            else:
                result.flat[i] = _ndtri_upper_tail(1.0 - value)
        return result
    else:
        # Scalar input
        if p <= 0.0:
            return -np.inf
        elif p >= 1.0:
            return np.inf
        elif p < 0.5:
            return -_ndtri_upper_tail(p)
        else:
            return _ndtri_upper_tail(1.0 - p)
@njit
def copnorm_jit(x):
    """JIT-compiled copula normalization of a 1D sample.

    Chains the empirical-CDF transform (``ctransform_jit``) with the
    approximate inverse normal CDF (``ndtri_approx``).

    Parameters
    ----------
    x : ndarray
        1D array to normalize.

    Returns
    -------
    ndarray
        Approximate standard-normal scores of the ranks of ``x``.
    """
    return ndtri_approx(ctransform_jit(x))
@njit
def copnorm_2d_jit(x):
    """JIT-compiled copula normalization for 2D arrays.

    Applies ``copnorm_jit`` to every row independently.

    Parameters
    ----------
    x : ndarray
        2D array where each row is normalized independently.

    Returns
    -------
    ndarray
        float64 array with the same shape as ``x`` holding the
        copula-normalized rows.
    """
    n_vars, n_samples = x.shape

    # Always allocate float64: the normal scores are real-valued, so an
    # output buffer that inherited an integer dtype from `x` would truncate
    # them.
    result = np.empty((n_vars, n_samples), dtype=np.float64)

    for row in range(n_vars):
        result[row, :] = copnorm_jit(x[row, :])

    return result
@njit
def mi_gg_jit(x, y, biascorrect=True, demeaned=False):
    """JIT-compiled Gaussian mutual information between two variables.

    Estimates I(X;Y) from the Cholesky factors of the sample covariance
    matrices of ``x``, ``y`` and the stacked ``(x, y)``. The inputs are never
    modified: all arithmetic is done on float64 working copies.

    Parameters
    ----------
    x : ndarray
        First variable data (n_vars_x, n_samples).
    y : ndarray
        Second variable data (n_vars_y, n_samples).
    biascorrect : bool
        Apply bias correction.
    demeaned : bool
        Whether data is already demeaned. When False, each row has its mean
        subtracted (on the working copy) before covariances are computed.

    Returns
    -------
    float
        Mutual information in bits.

    Raises
    ------
    ValueError
        If ``x`` and ``y`` have a different number of samples.
    """
    if x.shape[1] != y.shape[1]:
        raise ValueError("Number of samples must match")

    Ntrl = x.shape[1]
    Nvarx = x.shape[0]
    Nvary = y.shape[0]
    Nvarxy = Nvarx + Nvary

    # Private float64 copies: demeaning must not leak into the caller's
    # arrays, and integer input must not be truncated by in-place writes.
    xw = x.astype(np.float64)
    yw = y.astype(np.float64)

    if not demeaned:
        # Row-wise demeaning written as explicit loops for numba.
        for i in range(Nvarx):
            xw[i] = xw[i] - np.mean(xw[i])
        for i in range(Nvary):
            yw[i] = yw[i] - np.mean(yw[i])

    # Sample covariance blocks
    Cxx = np.dot(xw, xw.T) / (Ntrl - 1)
    Cyy = np.dot(yw, yw.T) / (Ntrl - 1)
    Cxy = np.dot(xw, yw.T) / (Ntrl - 1)

    # Joint covariance; the lower-left block is the transpose of the
    # upper-right one, which keeps C exactly symmetric.
    C = np.empty((Nvarxy, Nvarxy))
    C[:Nvarx, :Nvarx] = Cxx
    C[:Nvarx, Nvarx:] = Cxy
    C[Nvarx:, :Nvarx] = Cxy.T
    C[Nvarx:, Nvarx:] = Cyy

    # Small diagonal regularization to prevent numerical issues with
    # identical data before the Cholesky factorizations.
    C += np.eye(C.shape[0]) * 1e-12
    Cxx += np.eye(Cxx.shape[0]) * 1e-12
    Cyy += np.eye(Cyy.shape[0]) * 1e-12

    chC = np.linalg.cholesky(C)
    chCxx = np.linalg.cholesky(Cxx)
    chCyy = np.linalg.cholesky(Cyy)

    # Sum of the log of a Cholesky diagonal equals half the log-determinant.
    HX = np.sum(np.log(np.diag(chCxx)))
    HY = np.sum(np.log(np.diag(chCyy)))
    HXY = np.sum(np.log(np.diag(chC)))

    ln2 = np.log(2.0)

    if biascorrect:
        # Bias correction terms
        psiterms = np.zeros(Nvarxy)
        dterm = (ln2 - np.log(Ntrl - 1.0)) / 2.0

        for i in range(Nvarxy):
            psiterms[i] = digamma_approx((Ntrl - i - 1.0) / 2.0) / 2.0

        HX = HX - Nvarx * dterm - np.sum(psiterms[:Nvarx])
        HY = HY - Nvary * dterm - np.sum(psiterms[:Nvary])
        HXY = HXY - Nvarxy * dterm - np.sum(psiterms)

    # MI in bits
    return (HX + HY - HXY) / ln2
@njit
def cmi_ggg_jit(x, y, z, biascorrect=True, demeaned=False):
    """JIT-compiled conditional mutual information for Gaussian variables.

    Computes I(X;Y|Z) for continuous variables from the Cholesky factors of
    the sample covariance matrices of ``z``, ``(y, z)``, ``(x, z)`` and
    ``(x, y, z)``. The inputs are never modified: all arithmetic is done on
    float64 working copies.

    Parameters
    ----------
    x : ndarray
        First variable (n_vars_x, n_samples).
    y : ndarray
        Second variable (n_vars_y, n_samples).
    z : ndarray
        Conditioning variable (n_vars_z, n_samples).
    biascorrect : bool
        Apply bias correction.
    demeaned : bool
        Whether data is already demeaned. When False, each row has its mean
        subtracted (on the working copy) before covariances are computed.

    Returns
    -------
    float
        Conditional mutual information in bits.

    Raises
    ------
    ValueError
        If ``x``, ``y`` and ``z`` do not share the same number of samples.
    """
    if x.shape[1] != y.shape[1] or x.shape[1] != z.shape[1]:
        raise ValueError("Number of samples must match")

    Ntrl = x.shape[1]
    Nvarx = x.shape[0]
    Nvary = y.shape[0]
    Nvarz = z.shape[0]
    Nvaryz = Nvary + Nvarz
    Nvarxz = Nvarx + Nvarz
    Nvarxy = Nvarx + Nvary
    Nvarxyz = Nvarx + Nvary + Nvarz

    # Private float64 copies: demeaning must not leak into the caller's
    # arrays, and integer input must not be truncated by in-place writes.
    xw = x.astype(np.float64)
    yw = y.astype(np.float64)
    zw = z.astype(np.float64)

    if not demeaned:
        # Row-wise demeaning written as explicit loops for numba.
        for i in range(Nvarx):
            xw[i] = xw[i] - np.mean(xw[i])
        for i in range(Nvary):
            yw[i] = yw[i] - np.mean(yw[i])
        for i in range(Nvarz):
            zw[i] = zw[i] - np.mean(zw[i])

    # Sample covariance blocks
    Cxx = np.dot(xw, xw.T) / (Ntrl - 1)
    Cyy = np.dot(yw, yw.T) / (Ntrl - 1)
    Czz = np.dot(zw, zw.T) / (Ntrl - 1)
    Cxy = np.dot(xw, yw.T) / (Ntrl - 1)
    Cxz = np.dot(xw, zw.T) / (Ntrl - 1)
    Cyz = np.dot(yw, zw.T) / (Ntrl - 1)

    # Joint covariance of (y, z)
    Cyz_joint = np.empty((Nvaryz, Nvaryz))
    Cyz_joint[:Nvary, :Nvary] = Cyy
    Cyz_joint[:Nvary, Nvary:] = Cyz
    Cyz_joint[Nvary:, :Nvary] = Cyz.T
    Cyz_joint[Nvary:, Nvary:] = Czz

    # Joint covariance of (x, z)
    Cxz_joint = np.empty((Nvarxz, Nvarxz))
    Cxz_joint[:Nvarx, :Nvarx] = Cxx
    Cxz_joint[:Nvarx, Nvarx:] = Cxz
    Cxz_joint[Nvarx:, :Nvarx] = Cxz.T
    Cxz_joint[Nvarx:, Nvarx:] = Czz

    # Joint covariance of (x, y, z), laid out as 3x3 blocks
    Cxyz = np.empty((Nvarxyz, Nvarxyz))
    Cxyz[:Nvarx, :Nvarx] = Cxx
    Cxyz[:Nvarx, Nvarx:Nvarxy] = Cxy
    Cxyz[:Nvarx, Nvarxy:] = Cxz
    Cxyz[Nvarx:Nvarxy, :Nvarx] = Cxy.T
    Cxyz[Nvarx:Nvarxy, Nvarx:Nvarxy] = Cyy
    Cxyz[Nvarx:Nvarxy, Nvarxy:] = Cyz
    Cxyz[Nvarxy:, :Nvarx] = Cxz.T
    Cxyz[Nvarxy:, Nvarx:Nvarxy] = Cyz.T
    Cxyz[Nvarxy:, Nvarxy:] = Czz

    # Small diagonal regularization to prevent numerical issues with
    # identical data. Czz is regularized only after it has been copied into
    # the joint matrices, so every matrix receives the term exactly once.
    Czz += np.eye(Czz.shape[0]) * 1e-12
    Cyz_joint += np.eye(Cyz_joint.shape[0]) * 1e-12
    Cxz_joint += np.eye(Cxz_joint.shape[0]) * 1e-12
    Cxyz += np.eye(Cxyz.shape[0]) * 1e-12

    chCz = np.linalg.cholesky(Czz)
    chCyz = np.linalg.cholesky(Cyz_joint)
    chCxz = np.linalg.cholesky(Cxz_joint)
    chCxyz = np.linalg.cholesky(Cxyz)

    # Sum of the log of a Cholesky diagonal equals half the log-determinant.
    HZ = np.sum(np.log(np.diag(chCz)))
    HYZ = np.sum(np.log(np.diag(chCyz)))
    HXZ = np.sum(np.log(np.diag(chCxz)))
    HXYZ = np.sum(np.log(np.diag(chCxyz)))

    ln2 = np.log(2.0)

    if biascorrect:
        # Bias correction
        dterm = (ln2 - np.log(Ntrl - 1.0)) / 2.0

        # The i-th psi term depends only on i and Ntrl, so one table sized
        # for the largest matrix serves all four entropies via prefixes.
        psiterms = np.zeros(Nvarxyz)
        for i in range(Nvarxyz):
            psiterms[i] = digamma_approx((Ntrl - i - 1.0) / 2.0) / 2.0

        HZ = HZ - Nvarz * dterm - np.sum(psiterms[:Nvarz])
        HYZ = HYZ - Nvaryz * dterm - np.sum(psiterms[:Nvaryz])
        HXZ = HXZ - Nvarxz * dterm - np.sum(psiterms[:Nvarxz])
        HXYZ = HXYZ - Nvarxyz * dterm - np.sum(psiterms)

    # CMI in bits: I(X;Y|Z) = H(X,Z) + H(Y,Z) - H(X,Y,Z) - H(Z)
    return (HXZ + HYZ - HXYZ - HZ) / ln2
@njit
def digamma_approx(x):
    """Approximate digamma function for JIT compilation.

    Shifts the argument upward with the recurrence
    ``psi(x) = psi(x + 1) - 1/x`` until it is at least 6, then evaluates the
    asymptotic expansion there.

    Parameters
    ----------
    x : float
        Input value.

    Returns
    -------
    float
        Approximate digamma value. Any ``x <= 0`` returns ``-inf``.
    """
    if x <= 0:
        return -np.inf

    # Accumulate the -1/x corrections while stepping x up to >= 6.
    result = 0.0
    while x < 6:
        result -= 1.0 / x
        x += 1.0

    # Asymptotic expansion
    x_inv = 1.0 / x
    x_inv2 = x_inv * x_inv
    x_inv4 = x_inv2 * x_inv2

    # psi(x) ≈ ln(x) - 1/(2x) - 1/(12x²) + 1/(120x⁴) - 1/(252x⁶)
    result += (
        np.log(x)
        - 0.5 * x_inv
        - x_inv2 / 12.0
        + x_inv4 / 120.0
        - x_inv4 * x_inv2 / 252.0
    )

    return result
@njit
def gcmi_cc_jit(x, y):
    """JIT-compiled Gaussian-Copula MI between continuous variables.

    Copula-normalizes both inputs, then hands the normalized data to
    ``mi_gg_jit`` with bias correction enabled and demeaning skipped.

    Parameters
    ----------
    x : ndarray
        First variable (n_vars_x, n_samples); a 1D array is treated as a
        single variable.
    y : ndarray
        Second variable (n_vars_y, n_samples); a 1D array is treated as a
        single variable.

    Returns
    -------
    float
        GCMI in bits.
    """
    # Normalize x, promoting a 1D input to a single-row 2D array.
    if x.ndim == 1:
        normed_x = np.empty((1, x.shape[0]))
        normed_x[0, :] = copnorm_jit(x)
    else:
        normed_x = copnorm_2d_jit(x)

    # Same treatment for y.
    if y.ndim == 1:
        normed_y = np.empty((1, y.shape[0]))
        normed_y[0, :] = copnorm_jit(y)
    else:
        normed_y = copnorm_2d_jit(y)

    # The normalized data is passed as already demeaned.
    gcmi_bits = mi_gg_jit(normed_x, normed_y, biascorrect=True, demeaned=True)
    return gcmi_bits