Coverage for src/driada/information/gcmi.py: 51.00%

251 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-25 15:40 +0300

1import numpy as np 

2import numba as nb 

3from numba import njit 

4import warnings 

5from scipy.special import ndtri, psi, digamma 

6 

7from .info_utils import py_fast_digamma_arr 

8 

9# Import JIT versions if available 

10try: 

11 from .gcmi_jit_utils import ( 

12 ctransform_jit, ctransform_2d_jit, copnorm_jit, copnorm_2d_jit, 

13 mi_gg_jit, cmi_ggg_jit, gcmi_cc_jit 

14 ) 

15 _JIT_AVAILABLE = True 

16except ImportError: 

17 _JIT_AVAILABLE = False 

18 

19#TODO: credits to original GCMI: https://github.com/robince/gcmi 

20 

def ctransform(x):
    """Copula transformation (empirical CDF).

    cx = ctransform(x) returns the empirical CDF value of every sample along
    the LAST axis of x. Data is ranked and scaled into the open interval
    (0, 1): the NumPy path maps 0-based rank r to (r + 1) / (n + 1).

    Parameters
    ----------
    x : array_like
        Input data, promoted with ``np.atleast_2d``. Samples lie along the
        last axis; each leading index is transformed independently.

    Returns
    -------
    ndarray
        Array with the shape of ``np.atleast_2d(x)`` holding values in (0, 1).

    Notes
    -----
    The JIT kernels are used only for 2D, C-contiguous float32/float64
    arrays. Everything else, including arrays of rank > 2, uses the NumPy
    fallback, which works along the last axis for any rank. In the fallback,
    ties are broken by ``np.argsort`` order, not averaged.
    """
    x = np.atleast_2d(x)

    # The JIT kernels are 1D/2D specific; check the rank first so other
    # shapes always go through the general NumPy fallback below.
    if (x.ndim == 2 and _JIT_AVAILABLE and x.flags.c_contiguous
            and x.dtype in (np.float32, np.float64)):
        if x.shape[0] == 1:
            # Single variable: use the 1D kernel, then restore the row shape.
            return ctransform_jit(x.ravel()).reshape(1, -1)
        return ctransform_2d_jit(x)

    # argsort of argsort yields the 0-based rank of each element along the last axis.
    order = np.argsort(x)
    ranks = np.argsort(order)
    return (ranks + 1).astype(float) / (ranks.shape[-1] + 1)

42 

43 

def copnorm(x):
    """Copula normalization.

    cx = copnorm(x) returns standard normal samples with the same empirical
    CDF value as the input, computed along the last axis.

    Parameters
    ----------
    x : array_like
        Input data, promoted with ``np.atleast_2d``; samples on the last axis.

    Returns
    -------
    ndarray
        Array with the shape of ``np.atleast_2d(x)`` holding the copula-
        normalized values.

    Notes
    -----
    The JIT kernels are used only for 2D, C-contiguous float32/float64
    arrays; every other input uses ``ndtri(ctransform(x))``.
    """
    x = np.atleast_2d(x)

    # Check the rank first: the JIT kernels handle only the 1D/2D case.
    if (x.ndim == 2 and _JIT_AVAILABLE and x.flags.c_contiguous
            and x.dtype in (np.float32, np.float64)):
        if x.shape[0] == 1:
            # Single variable: use the 1D kernel, then restore the row shape.
            return copnorm_jit(x.ravel()).reshape(1, -1)
        return copnorm_2d_jit(x)

    # Inverse standard-normal CDF of the empirical CDF
    # (same as scipy.stats.norm.ppf, but cheaper).
    return ndtri(ctransform(x))

65 

66 

@njit
def demean(x):
    """Return a copy of ``x`` in which every row has been centred on zero.

    Parameters
    ----------
    x : ndarray
        2D array; the mean of each row is subtracted from that row only.

    Returns
    -------
    ndarray
        New array created with ``np.empty_like(x)``, so it has the same shape
        and dtype as ``x``. The input array is left unmodified.
    """
    centred = np.empty_like(x)
    for r in range(x.shape[0]):
        # Subtract this row's own mean, computed independently per row.
        centred[r] = x[r] - np.mean(x[r])
    return centred

93 

94 

@njit
def regularized_cholesky(C, regularization=1e-12):
    """Cholesky decomposition with diagonal regularization for stability.

    A small multiple of the identity is added to ``C`` before factorizing.
    For a positive semi-definite matrix, det(C) <= (trace(C) / n) ** n by
    AM-GM on the eigenvalues. When det(C) falls more than 1e-8 below that
    bound, the matrix is treated as severely ill-conditioned. The added
    term is then raised to ``max(regularization, trace(C) * 1e-8 / n)``.
    This also covers a determinant that is exactly zero or slightly
    negative due to round-off.

    Parameters
    ----------
    C : ndarray
        Square covariance matrix to decompose.
    regularization : float, optional
        Base value added to the diagonal (default: 1e-12).

    Returns
    -------
    ndarray
        Lower-triangular Cholesky factor of the regularized matrix.
    """
    n = C.shape[0]
    det_C = np.linalg.det(C)
    trace_C = np.trace(C)

    # Upper bound of det(C) for a PSD matrix with this trace.
    expected_det_scale = (trace_C / n) ** n

    # No "det_C > 0" guard: singular (det == 0) or round-off-negative
    # determinants are the worst-conditioned matrices of all and need the
    # stronger, trace-proportional regularization most. With only the base
    # term they could still make np.linalg.cholesky fail.
    if det_C < expected_det_scale * 1e-8:
        adaptive_reg = trace_C * 1e-8 / n
        reg = max(regularization, adaptive_reg)
    else:
        reg = regularization

    return np.linalg.cholesky(C + np.eye(n) * reg)

134 

135 

@njit()
def ent_g(x, biascorrect=True):
    """Entropy of a (possibly multivariate) Gaussian variable, in bits.

    H = ent_g(x) fits a Gaussian to the samples in ``x`` and returns its
    differential entropy. By default it also applies the analytic
    small-sample bias correction.

    Parameters
    ----------
    x : ndarray
        Rows are dimensions/variables, columns are samples (samples on the
        last axis). A 1D array is promoted to a single row.
    biascorrect : bool, optional
        Whether to subtract the bias-correction terms (default True).

    Returns
    -------
    float
        Entropy estimate in bits.

    Raises
    ------
    ValueError
        If ``x`` has more than two dimensions.
    """
    x = np.atleast_2d(x)
    if x.ndim > 2:
        raise ValueError("x must be at most 2d")
    n_samples = x.shape[1]
    n_vars = x.shape[0]

    centred = demean(x)
    cov = np.dot(centred, centred.T) / float(n_samples - 1)
    chol = regularized_cholesky(cov)

    # sum(log(diag(L))) is half the log-determinant of the (regularized)
    # covariance; an explicit loop over the diagonal keeps it Numba-friendly.
    half_logdet = 0.0
    for k in range(chol.shape[0]):
        half_logdet += np.log(chol[k, k])
    h_nats = half_logdet + 0.5 * n_vars * (np.log(2 * np.pi) + 1.0)

    ln2 = np.log(2)
    if biascorrect:
        offsets = np.arange(1, n_vars + 1, dtype=np.float64)
        psi_terms = py_fast_digamma_arr((n_samples - offsets) / 2.0) / 2.0
        d_term = (ln2 - np.log(n_samples - 1.0)) / 2.0
        h_nats = h_nats - n_vars * d_term - psi_terms.sum()

    # nats -> bits
    return h_nats / ln2

171 

172 

@njit()
def mi_gg(x, y, biascorrect=True, demeaned=False, max_dim=3):
    """Mutual information (MI) between two Gaussian variables, in bits.

    I = mi_gg(x, y) returns the MI between two (possibly multidimensional)
    Gaussian variables x and y. If x and/or y are multivariate, columns
    must correspond to samples and rows to dimensions/variables (samples on
    the last axis).

    Parameters
    ----------
    x, y : ndarray
        Data arrays with the same number of samples (columns).
    biascorrect : bool, optional
        Whether bias correction is applied to the estimated MI (default True).
    demeaned : bool, optional
        Whether the input already has zero mean, e.g. because it has been
        copula-normalized (default False). When False the joint data is
        centred row by row.
    max_dim : int, optional
        Upper limit compared against ``x.ndim`` and ``y.ndim`` after
        ``np.atleast_2d`` (default 3). NOTE(review): this checks array rank,
        not the number of variables (rows), so it does not limit the
        dimensionality the error message refers to. Confirm the intent.

    Returns
    -------
    float
        MI estimate in bits.

    Raises
    ------
    ValueError
        If either input's rank exceeds ``max_dim`` or the sample counts differ.
    """
    x = np.atleast_2d(x)
    y = np.atleast_2d(y)
    if x.ndim > max_dim or y.ndim > max_dim:
        raise ValueError(f"x and y must be at most {max_dim}d to prevent undersampling issues")

    n_samples = x.shape[1]
    if y.shape[1] != n_samples:
        raise ValueError("number of trials do not match")
    nx = x.shape[0]
    ny = y.shape[0]
    nxy = nx + ny

    joint = np.vstack((x, y))
    if not demeaned:
        joint = demean(joint)
    cov = np.dot(joint, joint.T) / float(n_samples - 1)

    # Entropies in nats without the 0.5*k*(log(2*pi)+1) normalization terms;
    # those cancel exactly in HX + HY - HXY.
    h_x = np.sum(np.log(np.diag(regularized_cholesky(cov[:nx, :nx]))))
    h_y = np.sum(np.log(np.diag(regularized_cholesky(cov[nx:, nx:]))))
    h_xy = np.sum(np.log(np.diag(regularized_cholesky(cov))))

    ln2 = np.log(2)
    if biascorrect:
        psi_terms = py_fast_digamma_arr((n_samples - np.arange(1, nxy + 1)) / 2.0) / 2.0
        d_term = (ln2 - np.log(n_samples - 1.0)) / 2.0
        h_x = h_x - nx * d_term - psi_terms[:nx].sum()
        h_y = h_y - ny * d_term - psi_terms[:ny].sum()
        h_xy = h_xy - nxy * d_term - psi_terms[:nxy].sum()

    # nats -> bits
    return (h_x + h_y - h_xy) / ln2

234 

235 

@njit()
def mi_model_gd(x, y, Ym, biascorrect=True, demeaned=False):
    """MI between a Gaussian and a discrete variable, in bits (ANOVA-style).

    I = mi_model_gd(x, y, Ym) returns the MI between the (possibly
    multidimensional) Gaussian variable x and the discrete variable y. It
    compares an unconditional Gaussian fit against class-conditional fits.
    For 1D x this is a lower bound to the mutual information.

    Parameters
    ----------
    x : ndarray
        Rows are dimensions/variables, columns are samples (samples on the
        last axis).
    y : ndarray
        1D array of class labels with integer values in [0, Ym-1].
    Ym : int
        Number of classes.
    biascorrect : bool, optional
        Whether bias correction is applied to the estimated MI (default True).
    demeaned : bool, optional
        Whether x already has zero mean, e.g. because it has been
        copula-normalized (default False). When False, x is centred before
        the unconditional covariance is computed.

    Returns
    -------
    float
        MI estimate in bits.

    Raises
    ------
    ValueError
        If x has more than 2 dimensions, y more than 1, Ym is not integral,
        or the sample counts differ.

    Notes
    -----
    Each class is presumably expected to have at least two samples. With
    fewer, the per-class covariance divides by ``n - 1 <= 0`` and the bias
    term takes ``log(n - 1)``, giving non-finite values.

    See also: mi_mixture_gd
    """
    x = np.atleast_2d(x)
    if x.ndim > 2:
        raise ValueError("x must be at most 2d")
    if y.ndim > 1:
        raise ValueError("only univariate discrete variables supported")
    if int(Ym) != Ym:
        raise ValueError("Ym should be an integer")

    Ntrl = x.shape[1]
    Nvarx = x.shape[0]

    if y.size != Ntrl:
        raise ValueError("number of trials do not match")

    # The unconditional covariance below is x @ x.T / (N - 1), which is only
    # a covariance for zero-mean rows. The `demeaned` flag was previously
    # ignored, so non-centred input inflated the unconditional entropy.
    if not demeaned:
        x = demean(x)

    # class-conditional entropies (normalization constants cancel in the MI)
    Ntrl_y = np.zeros(Ym)
    Hcond = np.zeros(Ym)
    for yi in range(Ym):
        idx = y == yi
        xm = demean(x[:, idx])
        Ntrl_y[yi] = xm.shape[1]
        Cm = np.dot(xm, xm.T) / float(Ntrl_y[yi] - 1)
        chCm = regularized_cholesky(Cm)
        Hcond[yi] = np.sum(np.log(np.diag(chCm)))

    # class weights
    w = Ntrl_y / float(Ntrl)

    # unconditional entropy from the unconditional Gaussian fit
    Cx = np.dot(x, x.T) / float(Ntrl - 1)
    chC = regularized_cholesky(Cx)
    Hunc = np.sum(np.log(np.diag(chC)))

    ln2 = np.log(2)
    if biascorrect:
        var_idx = np.arange(1, Nvarx + 1)

        psiterms = py_fast_digamma_arr((Ntrl - var_idx) / 2.0) / 2.0
        dterm = (ln2 - np.log(float(Ntrl - 1))) / 2.0
        Hunc = Hunc - Nvarx * dterm - psiterms.sum()

        # per-class bias terms, vectorized over classes
        dterm_y = (ln2 - np.log(Ntrl_y - 1)) / 2.0
        psiterms_y = np.zeros(Ym)
        for vi in var_idx:
            psiterms_y = psiterms_y + py_fast_digamma_arr((Ntrl_y - vi) / 2.0)
        Hcond = Hcond - Nvarx * dterm_y - (psiterms_y / 2.0)

    # MI in bits
    return (Hunc - np.sum(w * Hcond)) / ln2

315 

316 

def gcmi_cc(x, y):
    """Gaussian-Copula Mutual Information between two continuous variables.

    I = gcmi_cc(x, y) returns the MI between two (possibly multidimensional)
    continuous variables, x and y, estimated via a Gaussian copula. If x
    and/or y are multivariate, columns must correspond to samples and rows to
    dimensions/variables (samples on the last axis). This provides a lower
    bound to the true MI value.

    Parameters
    ----------
    x, y : array_like
        Data arrays (promoted with ``np.atleast_2d``) with the same number of
        samples.

    Returns
    -------
    float
        MI estimate. The NumPy path returns ``mi_gg``'s value (bits); the JIT
        path returns ``gcmi_cc_jit``'s value.

    Raises
    ------
    ValueError
        If either input has more than two dimensions or the sample counts
        differ.
    """
    x = np.atleast_2d(x)
    y = np.atleast_2d(y)
    if x.ndim > 2 or y.ndim > 2:
        raise ValueError("x and y must be at most 2d")
    if y.shape[1] != x.shape[1]:
        raise ValueError("number of trials do not match")

    # Use the JIT version for C-contiguous floating-point input.
    float_types = (np.float32, np.float64)
    if (_JIT_AVAILABLE
            and x.flags.c_contiguous and y.flags.c_contiguous
            and x.dtype in float_types and y.dtype in float_types):
        return gcmi_cc_jit(x, y)

    # Copula-normalize each variable, then compute the parametric Gaussian MI.
    # The normalized data is passed as already demeaned (demeaned=True).
    cx = copnorm(x)
    cy = copnorm(y)
    return mi_gg(cx, cy, True, True)

361 

362# TODO: integrate into numba everything below this line 

def cmi_ggg(x, y, z, biascorrect=True, demeaned=False):
    """Conditional MI (CMI) between two Gaussian variables given a third.

    I = cmi_ggg(x, y, z) returns the CMI between two (possibly
    multidimensional) Gaussian variables, x and y, conditioned on a third, z,
    with bias correction. If x / y / z are multivariate, columns must
    correspond to samples and rows to dimensions/variables (samples on the
    last axis).

    Parameters
    ----------
    x, y, z : array_like
        Data arrays (promoted with ``np.atleast_2d``) sharing the same number
        of samples.
    biascorrect : bool, optional
        Whether bias correction is applied to the estimated CMI (default True).
    demeaned : bool, optional
        Whether the input already has zero mean, e.g. because it has been
        copula-normalized (default False).

    Returns
    -------
    float
        CMI estimate. The NumPy path returns bits; the JIT path returns
        ``cmi_ggg_jit``'s value (presumably also bits; confirm in
        gcmi_jit_utils).

    Raises
    ------
    ValueError
        If any input has more than two dimensions or the sample counts differ.
    """
    x = np.atleast_2d(x)
    y = np.atleast_2d(y)
    z = np.atleast_2d(z)

    # Validate before choosing an implementation, so the JIT kernel never
    # receives input the NumPy path would reject.
    if x.ndim > 2 or y.ndim > 2 or z.ndim > 2:
        raise ValueError("x, y and z must be at most 2d")
    Ntrl = x.shape[1]
    if y.shape[1] != Ntrl or z.shape[1] != Ntrl:
        raise ValueError("number of trials do not match")

    # Use the JIT version for C-contiguous floating-point input.
    float_types = (np.float32, np.float64)
    if (_JIT_AVAILABLE
            and x.flags.c_contiguous and y.flags.c_contiguous and z.flags.c_contiguous
            and x.dtype in float_types
            and y.dtype in float_types
            and z.dtype in float_types):
        return cmi_ggg_jit(x, y, z, biascorrect, demeaned)

    Nvarx = x.shape[0]
    Nvary = y.shape[0]
    Nvarz = z.shape[0]
    Nvaryz = Nvary + Nvarz
    Nvarxy = Nvarx + Nvary
    Nvarxz = Nvarx + Nvarz
    Nvarxyz = Nvarx + Nvaryz

    # joint variable, rows ordered x, y, z
    xyz = np.vstack((x, y, z))
    if not demeaned:
        xyz = xyz - xyz.mean(axis=1)[:, np.newaxis]
    Cxyz = np.dot(xyz, xyz.T) / float(Ntrl - 1)

    # submatrices of the joint covariance
    Cz = Cxyz[Nvarxy:, Nvarxy:]
    Cyz = Cxyz[Nvarx:, Nvarx:]
    # Cxz is the joint covariance with the y rows/columns removed.
    Cxz = np.zeros((Nvarxz, Nvarxz))
    Cxz[:Nvarx, :Nvarx] = Cxyz[:Nvarx, :Nvarx]
    Cxz[:Nvarx, Nvarx:] = Cxyz[:Nvarx, Nvarxy:]
    Cxz[Nvarx:, :Nvarx] = Cxyz[Nvarxy:, :Nvarx]
    Cxz[Nvarx:, Nvarx:] = Cxyz[Nvarxy:, Nvarxy:]

    chCz = regularized_cholesky(Cz)
    chCxz = regularized_cholesky(Cxz)
    chCyz = regularized_cholesky(Cyz)
    chCxyz = regularized_cholesky(Cxyz)

    # entropies in nats; the Gaussian normalization constants cancel in the CMI
    HZ = np.sum(np.log(np.diagonal(chCz)))
    HXZ = np.sum(np.log(np.diagonal(chCxz)))
    HYZ = np.sum(np.log(np.diagonal(chCyz)))
    HXYZ = np.sum(np.log(np.diagonal(chCxyz)))

    ln2 = np.log(2)
    if biascorrect:
        psiterms = psi((Ntrl - np.arange(1, Nvarxyz + 1)).astype(float) / 2.0) / 2.0
        dterm = (ln2 - np.log(Ntrl - 1.0)) / 2.0
        HZ = HZ - Nvarz * dterm - psiterms[:Nvarz].sum()
        HXZ = HXZ - Nvarxz * dterm - psiterms[:Nvarxz].sum()
        HYZ = HYZ - Nvaryz * dterm - psiterms[:Nvaryz].sum()
        HXYZ = HXYZ - Nvarxyz * dterm - psiterms[:Nvarxyz].sum()

    # CMI in bits
    return (HXZ + HYZ - HXYZ - HZ) / ln2

443 

444 

def gccmi_ccd(x, y, z, Zm):
    """Gaussian-Copula CMI between 2 continuous variables given a discrete one.

    I = gccmi_ccd(x, y, z, Zm) returns the CMI between two (possibly
    multidimensional) continuous variables, x and y, conditioned on a third
    discrete variable z, estimated via a Gaussian copula. If x and/or y are
    multivariate, columns must correspond to samples and rows to
    dimensions/variables (samples on the last axis).

    Parameters
    ----------
    x, y : array_like
        Continuous data (promoted with ``np.atleast_2d``) with the same number
        of samples.
    z : array_like
        1D integer class labels with values in [0, Zm-1]; the minimum must be
        0 and the maximum Zm-1.
    Zm : int
        Number of classes. Python ints and NumPy integer scalars are accepted.

    Returns
    -------
    float
        Sum over classes of P(z) * MI(x; y | z), with each class MI computed by
        ``mi_gg`` on copula-normalized data.

    Raises
    ------
    ValueError
        On bad dimensions, a non-integer z or Zm, mismatched sample counts, or
        z values out of bounds.

    Warns
    -----
    UserWarning
        If any row of x or y has more than 10% repeated values.
    """
    x = np.atleast_2d(x)
    y = np.atleast_2d(y)
    z = np.asarray(z)  # allow lists / tuples of labels
    if x.ndim > 2 or y.ndim > 2:
        raise ValueError("x and y must be at most 2d")
    if z.ndim > 1:
        raise ValueError("only univariate discrete variables supported")
    if not np.issubdtype(z.dtype, np.integer):
        raise ValueError("z should be an integer array")
    # NumPy integer scalars (e.g. ``z.max() + 1``) are valid class counts too.
    if not isinstance(Zm, (int, np.integer)):
        raise ValueError("Zm should be an integer")

    Ntrl = x.shape[1]
    if y.shape[1] != Ntrl or z.size != Ntrl:
        raise ValueError("number of trials do not match")

    # Warn (at most once per variable) about heavily repeated values, which
    # distort the rank-based copula transform.
    for label, data in (("x", x), ("y", y)):
        for row in data:
            if np.unique(row).size / float(Ntrl) < 0.9:
                warnings.warn(f"Input {label} has more than 10% repeated values")
                break

    # check values of discrete variable
    if z.min() != 0 or z.max() != (Zm - 1):
        raise ValueError("values of discrete variable z are out of bounds")

    # Per-class Gaussian-copula MI, weighted by class probability. Classes
    # with no samples have weight 0. They are skipped because mi_gg on an
    # empty slice yields NaN, which would poison the weighted sum.
    Icond = np.zeros(Zm)
    Pz = np.zeros(Zm)
    for zi in range(Zm):
        idx = z == zi
        n_zi = idx.sum()
        if n_zi == 0:
            continue
        Pz[zi] = n_zi
        Icond[zi] = mi_gg(copnorm(x[:, idx]), copnorm(y[:, idx]), True, True)

    Pz = Pz / float(Ntrl)
    return np.sum(Pz * Icond)
508 return CMI