Coverage for src/driada/information/gcmi.py: 51.00%
251 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-25 15:40 +0300
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-25 15:40 +0300
1import numpy as np
2import numba as nb
3from numba import njit
4import warnings
5from scipy.special import ndtri, psi, digamma
7from .info_utils import py_fast_digamma_arr
9# Import JIT versions if available
10try:
11 from .gcmi_jit_utils import (
12 ctransform_jit, ctransform_2d_jit, copnorm_jit, copnorm_2d_jit,
13 mi_gg_jit, cmi_ggg_jit, gcmi_cc_jit
14 )
15 _JIT_AVAILABLE = True
16except ImportError:
17 _JIT_AVAILABLE = False
19#TODO: credits to original GCMI: https://github.com/robince/gcmi
def ctransform(x):
    """Copula transformation (empirical CDF).

    The input is promoted to at least 2D. In the NumPy path each value is
    replaced by its rank along the last axis, scaled as
    ``rank / (n_samples + 1)`` so that results lie strictly inside (0, 1).

    C-contiguous float32/float64 inputs are delegated to the JIT helpers
    when they could be imported (presumably equivalent to the NumPy path;
    verify against ``gcmi_jit_utils``).
    """
    data = np.atleast_2d(x)

    jit_ok = (
        _JIT_AVAILABLE
        and data.flags.c_contiguous
        and (data.dtype == np.float32 or data.dtype == np.float64)
    )
    if jit_ok:
        if data.shape[0] != 1:
            # Several rows: dedicated 2D kernel.
            return ctransform_2d_jit(data)
        # Single row: run the 1D kernel and restore the (1, n) shape.
        return ctransform_jit(data.ravel()).reshape(1, -1)

    # NumPy path: argsort of argsort yields zero-based ranks on the last axis.
    ranks = np.argsort(np.argsort(data))
    return (ranks + 1).astype(float) / (ranks.shape[-1] + 1)
def copnorm(x):
    """Copula normalization.

    Returns standard normal samples that share the empirical CDF value of
    the input. The input is promoted to at least 2D; the NumPy path maps
    ``ctransform`` output through the inverse normal CDF (``ndtri``).

    C-contiguous float32/float64 inputs are delegated to the JIT helpers
    when they could be imported.
    """
    data = np.atleast_2d(x)

    if _JIT_AVAILABLE and data.flags.c_contiguous and data.dtype in (np.float32, np.float64):
        single_row = data.shape[0] == 1
        if single_row:
            # 1D kernel, then restore the (1, n) shape.
            return copnorm_jit(data.ravel()).reshape(1, -1)
        return copnorm_2d_jit(data)

    # NumPy path: empirical CDF followed by the standard normal quantile.
    return ndtri(ctransform(data))
@njit
def demean(x):
    """Subtract the mean of every row from that row.

    Parameters
    ----------
    x : ndarray
        2D array; rows are processed independently.

    Returns
    -------
    ndarray
        New array allocated with ``np.empty_like(x)`` (same shape and dtype
        as the input) holding the centred rows.
    """
    centered = np.empty_like(x)
    # Explicit row loop keeps the function compilable in nopython mode.
    for row in range(x.shape[0]):
        centered[row] = x[row] - np.mean(x[row])
    return centered
@njit
def regularized_cholesky(C, regularization=1e-12):
    """Cholesky factorisation of ``C`` after adding a diagonal ridge.

    The ridge is normally ``regularization``. When the determinant is
    positive but more than eight orders of magnitude below
    ``(trace / n) ** n``, the ridge is raised to at least
    ``trace * 1e-8 / n``.

    Parameters
    ----------
    C : ndarray
        Square covariance matrix to decompose.
    regularization : float, optional
        Base value added to the diagonal (default: 1e-12).

    Returns
    -------
    ndarray
        Lower triangular Cholesky factor of the regularised matrix.
    """
    size = C.shape[0]
    determinant = np.linalg.det(C)
    trace = np.trace(C)

    # Determinant a well-conditioned matrix with this trace would have.
    reference_det = (trace / size) ** size
    threshold = reference_det * 1e-8

    # NOTE(review): a determinant <= 0 keeps the base ridge only.
    ill_conditioned = determinant > 0 and determinant < threshold
    if not ill_conditioned:
        reg = regularization
    else:
        # Ridge proportional to the average diagonal magnitude.
        reg = max(regularization, trace * 1e-8 / size)

    return np.linalg.cholesky(C + reg * np.eye(size))
@njit()
def ent_g(x, biascorrect=True):
    """Entropy of a Gaussian variable in bits.

    ``ent_g(x)`` returns the entropy of a (possibly multidimensional)
    Gaussian variable. Rows of ``x`` are dimensions/variables, columns are
    samples (samples on the last axis).

    biascorrect : when true (default) the analytic bias correction based
    on digamma terms is applied.

    Raises ValueError if ``x`` has more than two dimensions.
    """
    x = np.atleast_2d(x)
    if x.ndim > 2:
        raise ValueError("x must be at most 2d")
    n_samples = x.shape[1]
    n_vars = x.shape[0]

    # Sample covariance of the centred data.
    centered = demean(x)
    cov = np.dot(centered, centered.T) / float(n_samples - 1)
    chol = regularized_cholesky(cov)

    # Sum of log-diagonal of the Cholesky factor equals 0.5 * log(det(cov)).
    log_diag_total = 0.0
    for k in range(chol.shape[0]):
        log_diag_total += np.log(chol[k, k])
    entropy_nats = log_diag_total + 0.5 * n_vars * (np.log(2 * np.pi) + 1.0)

    ln2 = np.log(2)
    if biascorrect:
        dof = n_samples - np.arange(1, n_vars + 1, dtype=np.float64)
        psi_half = py_fast_digamma_arr(dof / 2.0) / 2.0
        scale_term = (ln2 - np.log(n_samples - 1.0)) / 2.0
        entropy_nats = entropy_nats - n_vars * scale_term - psi_half.sum()

    # nats -> bits
    return entropy_nats / ln2
@njit()
def mi_gg(x, y, biascorrect=True, demeaned=False, max_dim=3):
    """Mutual information (MI) between two Gaussian variables in bits.

    ``mi_gg(x, y)`` returns the MI between two (possibly multidimensional)
    Gaussian variables. Rows are dimensions/variables, columns are samples
    (samples on the last axis).

    biascorrect : when true (default) bias correction is applied to the
        estimated MI.
    demeaned : set to true when the inputs already have zero mean (for
        example after copula normalization); the joint data is then not
        centred again. Default false.
    max_dim : int (default 3) upper bound used by the input check below.

    Raises ValueError when the check on ``max_dim`` fails or when ``x`` and
    ``y`` have a different number of samples.
    """
    x = np.atleast_2d(x)
    y = np.atleast_2d(y)
    # NOTE(review): this compares the number of array axes (ndim) with
    # max_dim, not the number of variables (rows) -- confirm the intent.
    if x.ndim > max_dim or y.ndim > max_dim:
        raise ValueError(f"x and y must be at most {max_dim}d to prevent undersampling issues")
    n_samples = x.shape[1]
    n_x = x.shape[0]
    n_y = y.shape[0]
    n_xy = n_x + n_y

    if y.shape[1] != n_samples:
        raise ValueError("number of trials do not match")

    # Stack both variables and estimate the joint covariance once.
    joint = np.vstack((x, y))
    if not demeaned:
        joint = demean(joint)
    cov_joint = np.dot(joint, joint.T) / float(n_samples - 1)

    # Marginal covariances are the diagonal blocks of the joint one.
    cov_x = cov_joint[:n_x, :n_x]
    cov_y = cov_joint[n_x:, n_x:]

    chol_joint = regularized_cholesky(cov_joint)
    chol_x = regularized_cholesky(cov_x)
    chol_y = regularized_cholesky(cov_y)

    # Entropies in nats; the (2*pi*e) normalisations cancel in the MI.
    h_x = np.sum(np.log(np.diag(chol_x)))
    h_y = np.sum(np.log(np.diag(chol_y)))
    h_xy = np.sum(np.log(np.diag(chol_joint)))

    ln2 = np.log(2)
    if biascorrect:
        psi_half = py_fast_digamma_arr((n_samples - np.arange(1, n_xy + 1)) / 2.0) / 2.0
        scale_term = (ln2 - np.log(n_samples - 1.0)) / 2.0
        h_x = h_x - n_x * scale_term - psi_half[:n_x].sum()
        h_y = h_y - n_y * scale_term - psi_half[:n_y].sum()
        h_xy = h_xy - n_xy * scale_term - psi_half[:n_xy].sum()

    # MI in bits
    return (h_x + h_y - h_xy) / ln2
@njit()
def mi_model_gd(x, y, Ym, biascorrect=True, demeaned=False):
    """Mutual information (MI) between a Gaussian and a discrete variable in bits
    based on ANOVA style model comparison.

    ``mi_model_gd(x, y, Ym)`` returns the MI between the (possibly
    multidimensional) Gaussian variable ``x`` and the discrete variable
    ``y``. For 1D ``x`` this is a lower bound to the mutual information.
    Rows of ``x`` are dimensions/variables, columns are samples (samples on
    the last axis).

    y : 1D array that should contain integer values in ``[0, Ym-1]``.
    Ym : number of discrete classes; must be integer valued.
    biascorrect : when true (default) bias correction is applied.
    demeaned : set to true when ``x`` already has zero-mean rows (for
        example after copula normalization). When false (default) the rows
        are centred before the unconditional Gaussian fit.

    Raises ValueError for inputs with too many dimensions, a non-integer
    ``Ym``, or a mismatch between ``y.size`` and the number of samples.

    NOTE(review): a class with fewer than two samples makes the
    class-conditional covariance divide by zero; callers are presumably
    expected to avoid such classes.

    See also: mi_mixture_gd
    """
    x = np.atleast_2d(x)
    if x.ndim > 2:
        raise ValueError("x must be at most 2d")
    if y.ndim > 1:
        raise ValueError("only univariate discrete variables supported")
    if int(Ym) != Ym:
        raise ValueError("Ym should be an integer")

    Ntrl = x.shape[1]
    Nvarx = x.shape[0]

    if y.size != Ntrl:
        raise ValueError("number of trials do not match")

    # The unconditional covariance below is x @ x.T, which is only a
    # covariance for zero-mean rows; honour the `demeaned` flag.
    if not demeaned:
        x = demean(x)

    # class-conditional entropies (each class is centred on its own mean)
    Ntrl_y = np.zeros(Ym)
    Hcond = np.zeros(Ym)
    for yi in range(Ym):
        mask = y == yi
        xm = x[:, mask]
        Ntrl_y[yi] = xm.shape[1]
        xm = demean(xm)
        Cm = np.dot(xm, xm.T) / float(Ntrl_y[yi] - 1)
        chCm = regularized_cholesky(Cm)
        # constant term c*Nvarx omitted: it cancels in the MI
        Hcond[yi] = np.sum(np.log(np.diag(chCm)))

    # class weights
    w = Ntrl_y / float(Ntrl)

    # unconditional entropy from unconditional Gaussian fit
    Cx = np.dot(x, x.T) / float(Ntrl - 1)
    chC = regularized_cholesky(Cx)
    Hunc = np.sum(np.log(np.diag(chC)))

    ln2 = np.log(2)
    if biascorrect:
        var_idx = np.arange(1, Nvarx + 1)

        # correction of the unconditional entropy
        psiterms = py_fast_digamma_arr((Ntrl - var_idx) / 2.0) / 2.0
        dterm = (ln2 - np.log(float(Ntrl - 1))) / 2.0
        Hunc = Hunc - Nvarx * dterm - psiterms.sum()

        # per-class correction, using each class's own sample count
        dterm_y = (ln2 - np.log((Ntrl_y - 1))) / 2.0
        psiterms_y = np.zeros(Ym)
        for vi in var_idx:
            dof = Ntrl_y - vi
            psiterms_y = psiterms_y + py_fast_digamma_arr(dof / 2.0)
        Hcond = Hcond - Nvarx * dterm_y - (psiterms_y / 2.0)

    # MI in bits
    I = (Hunc - np.sum(w * Hcond)) / ln2
    return I
def gcmi_cc(x, y):
    """Gaussian-Copula Mutual Information between two continuous variables.

    ``gcmi_cc(x, y)`` returns the MI between two (possibly
    multidimensional) continuous variables, estimated via a Gaussian
    copula. Rows are dimensions/variables, columns are samples (the number
    of samples is read from ``shape[1]``).
    This provides a lower bound to the true MI value.

    Raises ValueError when an input has more than two dimensions or when
    the sample counts differ.
    """
    x = np.atleast_2d(x)
    y = np.atleast_2d(y)
    if not (x.ndim <= 2 and y.ndim <= 2):
        raise ValueError("x and y must be at most 2d")
    if y.shape[1] != x.shape[1]:
        raise ValueError("number of trials do not match")

    # Fast path: the compiled kernel handles contiguous float arrays.
    float_types = (np.float32, np.float64)
    use_jit = (
        _JIT_AVAILABLE
        and x.flags.c_contiguous
        and y.flags.c_contiguous
        and x.dtype in float_types
        and y.dtype in float_types
    )
    if use_jit:
        return gcmi_cc_jit(x, y)

    # Copula-normalise both variables, then use the parametric Gaussian MI
    # (bias-corrected; data is already zero mean after normalisation).
    return mi_gg(copnorm(x), copnorm(y), True, True)
362# TODO: integrate into numba everything below this line
def cmi_ggg(x, y, z, biascorrect=True, demeaned=False):
    """Conditional mutual information (CMI) between two Gaussian variables
    conditioned on a third.

    ``cmi_ggg(x, y, z)`` returns the CMI between two (possibly
    multidimensional) Gaussian variables, ``x`` and ``y``, conditioned on a
    third, ``z``. Rows are dimensions/variables, columns are samples
    (samples on the last axis).

    biascorrect : when true (default) bias correction is applied to the
        estimated CMI.
    demeaned : set to true when the inputs already have zero mean (for
        example after copula normalization). Default false.

    Raises ValueError when an input has more than two dimensions or when
    the sample counts differ. Inputs are validated before the JIT kernel is
    considered, so both code paths reject malformed input the same way.
    """
    x = np.atleast_2d(x)
    y = np.atleast_2d(y)
    z = np.atleast_2d(z)

    # Validate first, as gcmi_cc does, so the JIT path never sees bad shapes.
    if x.ndim > 2 or y.ndim > 2 or z.ndim > 2:
        raise ValueError("x, y and z must be at most 2d")
    Ntrl = x.shape[1]
    if y.shape[1] != Ntrl or z.shape[1] != Ntrl:
        raise ValueError("number of trials do not match")

    # Use JIT version if available and suitable
    float_types = (np.float32, np.float64)
    if (_JIT_AVAILABLE and
            x.flags.c_contiguous and y.flags.c_contiguous and z.flags.c_contiguous and
            x.dtype in float_types and
            y.dtype in float_types and
            z.dtype in float_types):
        return cmi_ggg_jit(x, y, z, biascorrect, demeaned)

    Nvarx = x.shape[0]
    Nvary = y.shape[0]
    Nvarz = z.shape[0]
    Nvaryz = Nvary + Nvarz
    Nvarxy = Nvarx + Nvary
    Nvarxz = Nvarx + Nvarz
    Nvarxyz = Nvarx + Nvaryz

    # joint variable, rows ordered as [x; y; z]
    xyz = np.vstack((x, y, z))
    if not demeaned:
        xyz = xyz - xyz.mean(axis=1)[:, np.newaxis]
    Cxyz = np.dot(xyz, xyz.T) / float(Ntrl - 1)

    # submatrices of joint covariance
    Cz = Cxyz[Nvarxy:, Nvarxy:]
    Cyz = Cxyz[Nvarx:, Nvarx:]
    # x and z rows are not adjacent in the joint matrix, so the (x, z)
    # covariance is assembled from its four blocks.
    Cxz = np.zeros((Nvarxz, Nvarxz))
    Cxz[:Nvarx, :Nvarx] = Cxyz[:Nvarx, :Nvarx]
    Cxz[:Nvarx, Nvarx:] = Cxyz[:Nvarx, Nvarxy:]
    Cxz[Nvarx:, :Nvarx] = Cxyz[Nvarxy:, :Nvarx]
    Cxz[Nvarx:, Nvarx:] = Cxyz[Nvarxy:, Nvarxy:]

    chCz = regularized_cholesky(Cz)
    chCxz = regularized_cholesky(Cxz)
    chCyz = regularized_cholesky(Cyz)
    chCxyz = regularized_cholesky(Cxyz)

    # entropies in nats; the (2*pi*e) normalizations cancel for the CMI
    HZ = np.sum(np.log(np.diagonal(chCz)))
    HXZ = np.sum(np.log(np.diagonal(chCxz)))
    HYZ = np.sum(np.log(np.diagonal(chCyz)))
    HXYZ = np.sum(np.log(np.diagonal(chCxyz)))

    ln2 = np.log(2)
    if biascorrect:
        psiterms = psi((Ntrl - np.arange(1, Nvarxyz + 1)).astype(float) / 2.0) / 2.0
        dterm = (ln2 - np.log(Ntrl - 1.0)) / 2.0
        HZ = HZ - Nvarz * dterm - psiterms[:Nvarz].sum()
        HXZ = HXZ - Nvarxz * dterm - psiterms[:Nvarxz].sum()
        HYZ = HYZ - Nvaryz * dterm - psiterms[:Nvaryz].sum()
        HXYZ = HXYZ - Nvarxyz * dterm - psiterms[:Nvarxyz].sum()

    # CMI in bits
    I = (HXZ + HYZ - HXYZ - HZ) / ln2
    return I
def gccmi_ccd(x,y,z,Zm):
    """Gaussian-Copula CMI between 2 continuous variables conditioned on a discrete variable.

    ``gccmi_ccd(x, y, z, Zm)`` returns the CMI between two (possibly
    multidimensional) continuous variables, ``x`` and ``y``, conditioned on
    a discrete variable ``z``, estimated via a Gaussian copula applied
    separately within each value of ``z``.
    Rows of ``x`` / ``y`` are dimensions/variables, columns are samples
    (the number of samples is read from ``shape[1]``).

    z : 1D integer array (or sequence) with values in ``[0, Zm-1]``; its
        minimum must be 0 and its maximum ``Zm-1``.
    Zm : number of discrete classes (Python int or NumPy integer scalar).

    Classes that have no samples receive zero weight and are skipped
    instead of being passed to the estimator.

    Raises ValueError on malformed input; warns when more than 10% of the
    values of any row of ``x`` or ``y`` are repeated.
    """
    x = np.atleast_2d(x)
    y = np.atleast_2d(y)
    # Accept any integer sequence, not only ndarrays.
    z = np.asarray(z)
    if x.ndim > 2 or y.ndim > 2:
        raise ValueError("x and y must be at most 2d")
    if z.ndim > 1:
        raise ValueError("only univariate discrete variables supported")
    if not np.issubdtype(z.dtype, np.integer):
        raise ValueError("z should be an integer array")
    # NumPy integer scalars (e.g. the result of z.max() + 1) are valid too.
    if not isinstance(Zm, (int, np.integer)):
        raise ValueError("Zm should be an integer")
    Zm = int(Zm)

    Ntrl = x.shape[1]
    Nvarx = x.shape[0]
    Nvary = y.shape[0]

    if y.shape[1] != Ntrl or z.size != Ntrl:
        raise ValueError("number of trials do not match")

    # check for repeated values (rank-based copula assumes few ties)
    for xi in range(Nvarx):
        if (np.unique(x[xi,:]).size / float(Ntrl)) < 0.9:
            warnings.warn("Input x has more than 10% repeated values")
            break
    for yi in range(Nvary):
        if (np.unique(y[yi,:]).size / float(Ntrl)) < 0.9:
            warnings.warn("Input y has more than 10% repeated values")
            break

    # check values of discrete variable
    if z.min()!=0 or z.max()!=(Zm-1):
        raise ValueError("values of discrete variable z are out of bounds")

    # calculate gcmi for each z value
    Icond = np.zeros(Zm)
    Pz = np.zeros(Zm)
    for zi in range(Zm):
        idx = z==zi
        n_zi = idx.sum()
        if n_zi == 0:
            # Absent class: weight 0; skipping keeps a NaN/inf estimate
            # from an empty sample out of the weighted sum.
            continue
        thsx = copnorm(x[:,idx])
        thsy = copnorm(y[:,idx])
        Pz[zi] = n_zi
        Icond[zi] = mi_gg(thsx,thsy,True,True)

    Pz = Pz / float(Ntrl)

    # conditional mutual information: class-probability weighted average
    CMI = np.sum(Pz*Icond)
    return CMI