Coverage for src/driada/information/gcmi_jit_utils.py: 10.14%

207 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-25 15:40 +0300

1""" 

2JIT-compiled copula transformation functions for GCMI. 

3""" 

4 

5import numpy as np 

6from numba import njit 

7from scipy.special import ndtri 

8 

9 

@njit
def ctransform_jit(x):
    """JIT-compiled copula transformation (empirical CDF) of a 1D sample.

    Each value is replaced by ``rank / (n + 1)``, where ranks run from 1 to
    ``n`` in ascending order of ``x``. Ranking uses a stable sort, so tied
    values receive distinct consecutive ranks in order of their position in
    ``x`` (the earliest occurrence gets the lowest rank). Runs in O(n log n).

    Parameters
    ----------
    x : ndarray
        1D array of values to transform.

    Returns
    -------
    ndarray
        float64 array of the same length with values strictly inside (0, 1).
        An empty input yields an empty output.
    """
    n = x.size

    # A stable sort makes the tie-breaking deterministic and index-based;
    # no separate tie-handling pass is needed.
    order = np.argsort(x, kind='mergesort')

    # Scatter 1-based ranks back to the original positions.
    ranks = np.empty(n, dtype=np.float64)
    for pos in range(n):
        ranks[order[pos]] = pos + 1

    # Dividing by n + 1 keeps every value strictly inside (0, 1).
    return ranks / (n + 1)

55 

56 

@njit
def ctransform_2d_jit(x):
    """JIT-compiled copula transformation applied to each row of a 2D array.

    Parameters
    ----------
    x : ndarray
        2D array of shape (n_vars, n_samples); every row is transformed
        independently with ``ctransform_jit``.

    Returns
    -------
    ndarray
        float64 array of shape (n_vars, n_samples) holding the copula values
        in (0, 1).
    """
    n_vars, n_samples = x.shape

    # Allocate float64 explicitly: np.empty_like(x) would inherit an integer
    # dtype from integer input and truncate every value in (0, 1) to 0.
    result = np.empty((n_vars, n_samples), dtype=np.float64)

    for i in range(n_vars):
        result[i, :] = ctransform_jit(x[i, :])

    return result

80 

81 

@njit
def _ndtri_approx_scalar(p):
    """Approximate standard-normal quantile of one probability (A&S 26.2.23)."""
    if p <= 0.0:
        return -np.inf
    if p >= 1.0:
        return np.inf
    # The rational approximation is defined for the tail mass q <= 0.5.
    # Evaluate it there and restore the sign for the lower half.
    q = p if p < 0.5 else 1.0 - p
    t = np.sqrt(-2.0 * np.log(q))
    num = 2.515517 + 0.802853 * t + 0.010328 * t * t
    den = 1.0 + 1.432788 * t + 0.189269 * t * t + 0.001308 * t * t * t
    z = t - num / den
    return -z if p < 0.5 else z


@njit
def ndtri_approx(p):
    """Approximate inverse standard-normal CDF, usable inside JIT code.

    Uses the rational approximation of Abramowitz & Stegun formula 26.2.23,
    whose absolute error is below 4.5e-4 on (0, 1). Probabilities ``<= 0``
    map to ``-inf``; probabilities ``>= 1`` map to ``+inf``.

    Parameters
    ----------
    p : float or ndarray
        Probability value(s), nominally in (0, 1).

    Returns
    -------
    float or ndarray
        Approximate quantile(s). Array input returns a float64 array with
        the same shape as ``p``.
    """
    if hasattr(p, 'shape'):
        # Evaluate element-wise on a flat view, then restore the input shape.
        flat = p.ravel()
        out = np.empty(flat.size, dtype=np.float64)
        for i in range(flat.size):
            out[i] = _ndtri_approx_scalar(flat[i])
        return out.reshape(p.shape)
    else:
        return _ndtri_approx_scalar(p)

135 

136 

@njit
def copnorm_jit(x):
    """Copula-normalize a 1D sample with JIT-friendly approximations.

    The sample is first mapped to its empirical CDF with ``ctransform_jit``.
    The resulting probabilities are then passed through the approximate
    inverse normal CDF ``ndtri_approx``.

    Parameters
    ----------
    x : ndarray
        1D array to normalize.

    Returns
    -------
    ndarray
        Standard-normal scores whose ordering matches that of ``x``.
    """
    return ndtri_approx(ctransform_jit(x))

155 

156 

@njit
def copnorm_2d_jit(x):
    """JIT-compiled copula normalization applied to each row of a 2D array.

    Parameters
    ----------
    x : ndarray
        2D array of shape (n_vars, n_samples); every row is normalized
        independently with ``copnorm_jit``.

    Returns
    -------
    ndarray
        float64 array of shape (n_vars, n_samples) of copula-normalized rows.
    """
    n_vars, n_samples = x.shape

    # Allocate float64 explicitly: np.empty_like(x) would inherit an integer
    # dtype from integer input and truncate the normal scores.
    result = np.empty((n_vars, n_samples), dtype=np.float64)

    for i in range(n_vars):
        result[i, :] = copnorm_jit(x[i, :])

    return result

178 

179 

@njit
def mi_gg_jit(x, y, biascorrect=True, demeaned=False):
    """JIT-compiled Gaussian mutual information between two variables.

    Entropies are computed from Cholesky factors of the sample covariance
    matrices: the sum of log-diagonals equals half the log-determinant. The
    inputs are never modified; demeaning, when requested, is done on float64
    copies.

    Parameters
    ----------
    x : ndarray
        First variable data (n_vars_x, n_samples).
    y : ndarray
        Second variable data (n_vars_y, n_samples).
    biascorrect : bool
        Apply the analytic (digamma-based) bias correction to each entropy.
    demeaned : bool
        Whether the rows of ``x`` and ``y`` already have zero mean.

    Returns
    -------
    float
        Mutual information in bits.

    Raises
    ------
    ValueError
        If the sample counts differ or fewer than two samples are given.
    """
    if x.shape[1] != y.shape[1]:
        raise ValueError("Number of samples must match")

    Ntrl = x.shape[1]
    Nvarx = x.shape[0]
    Nvary = y.shape[0]
    Nvarxy = Nvarx + Nvary

    # The unbiased covariance divides by Ntrl - 1.
    if Ntrl < 2:
        raise ValueError("At least two samples are required")

    # float64 copies: the caller's arrays stay untouched, and integer input
    # is neither truncated by demeaning nor rejected by np.dot.
    xd = x.astype(np.float64)
    yd = y.astype(np.float64)
    if not demeaned:
        for i in range(Nvarx):
            xd[i, :] = xd[i, :] - np.mean(xd[i, :])
        for i in range(Nvary):
            yd[i, :] = yd[i, :] - np.mean(yd[i, :])

    # Sample covariance blocks.
    Cxx = np.dot(xd, xd.T) / (Ntrl - 1)
    Cyy = np.dot(yd, yd.T) / (Ntrl - 1)
    Cxy = np.dot(xd, yd.T) / (Ntrl - 1)

    # Joint covariance; the off-diagonal blocks are exact transposes, so C
    # is exactly symmetric.
    C = np.empty((Nvarxy, Nvarxy))
    C[:Nvarx, :Nvarx] = Cxx
    C[:Nvarx, Nvarx:] = Cxy
    C[Nvarx:, :Nvarx] = Cxy.T
    C[Nvarx:, Nvarx:] = Cyy

    # A tiny ridge keeps Cholesky from failing on (near-)singular
    # covariances, e.g. identical variables.
    C += np.eye(Nvarxy) * 1e-12
    Cxx += np.eye(Nvarx) * 1e-12
    Cyy += np.eye(Nvary) * 1e-12

    # Sum of log-diagonals of the Cholesky factor = 0.5 * log det.
    HX = np.sum(np.log(np.diag(np.linalg.cholesky(Cxx))))
    HY = np.sum(np.log(np.diag(np.linalg.cholesky(Cyy))))
    HXY = np.sum(np.log(np.diag(np.linalg.cholesky(C))))

    ln2 = np.log(2.0)

    if biascorrect:
        dterm = (ln2 - np.log(Ntrl - 1.0)) / 2.0
        # One table of digamma terms; each entropy uses the prefix matching
        # its dimensionality.
        psiterms = np.empty(Nvarxy)
        for i in range(Nvarxy):
            psiterms[i] = digamma_approx((Ntrl - i - 1.0) / 2.0) / 2.0

        HX = HX - Nvarx * dterm - np.sum(psiterms[:Nvarx])
        HY = HY - Nvary * dterm - np.sum(psiterms[:Nvary])
        HXY = HXY - Nvarxy * dterm - np.sum(psiterms)

    # Convert from nats to bits.
    return (HX + HY - HXY) / ln2

260 

261 

@njit
def cmi_ggg_jit(x, y, z, biascorrect=True, demeaned=False):
    """JIT-compiled conditional mutual information for Gaussian variables.

    Computes I(X;Y|Z) = H(X,Z) + H(Y,Z) - H(X,Y,Z) - H(Z). Each entropy
    comes from the Cholesky factor of the corresponding sample covariance
    matrix. The inputs are never modified; demeaning, when requested, is
    done on float64 copies.

    Parameters
    ----------
    x : ndarray
        First variable (n_vars_x, n_samples).
    y : ndarray
        Second variable (n_vars_y, n_samples).
    z : ndarray
        Conditioning variable (n_vars_z, n_samples).
    biascorrect : bool
        Apply the analytic (digamma-based) bias correction to each entropy.
    demeaned : bool
        Whether the rows of ``x``, ``y`` and ``z`` already have zero mean.

    Returns
    -------
    float
        Conditional mutual information in bits.

    Raises
    ------
    ValueError
        If the sample counts differ or fewer than two samples are given.
    """
    if x.shape[1] != y.shape[1] or x.shape[1] != z.shape[1]:
        raise ValueError("Number of samples must match")

    Ntrl = x.shape[1]
    Nvarx = x.shape[0]
    Nvary = y.shape[0]
    Nvarz = z.shape[0]
    Nvaryz = Nvary + Nvarz
    Nvarxz = Nvarx + Nvarz
    Nvarxyz = Nvarx + Nvary + Nvarz

    # The unbiased covariance divides by Ntrl - 1.
    if Ntrl < 2:
        raise ValueError("At least two samples are required")

    # float64 copies: the caller's arrays stay untouched, and integer input
    # is neither truncated by demeaning nor rejected by np.dot.
    xd = x.astype(np.float64)
    yd = y.astype(np.float64)
    zd = z.astype(np.float64)
    if not demeaned:
        for i in range(Nvarx):
            xd[i, :] = xd[i, :] - np.mean(xd[i, :])
        for i in range(Nvary):
            yd[i, :] = yd[i, :] - np.mean(yd[i, :])
        for i in range(Nvarz):
            zd[i, :] = zd[i, :] - np.mean(zd[i, :])

    # Sample covariance blocks.
    Cxx = np.dot(xd, xd.T) / (Ntrl - 1)
    Cyy = np.dot(yd, yd.T) / (Ntrl - 1)
    Czz = np.dot(zd, zd.T) / (Ntrl - 1)
    Cxy = np.dot(xd, yd.T) / (Ntrl - 1)
    Cxz = np.dot(xd, zd.T) / (Ntrl - 1)
    Cyz = np.dot(yd, zd.T) / (Ntrl - 1)

    # Joint covariance of (y, z).
    Cyz_joint = np.empty((Nvaryz, Nvaryz))
    Cyz_joint[:Nvary, :Nvary] = Cyy
    Cyz_joint[:Nvary, Nvary:] = Cyz
    Cyz_joint[Nvary:, :Nvary] = Cyz.T
    Cyz_joint[Nvary:, Nvary:] = Czz

    # Joint covariance of (x, z).
    Cxz_joint = np.empty((Nvarxz, Nvarxz))
    Cxz_joint[:Nvarx, :Nvarx] = Cxx
    Cxz_joint[:Nvarx, Nvarx:] = Cxz
    Cxz_joint[Nvarx:, :Nvarx] = Cxz.T
    Cxz_joint[Nvarx:, Nvarx:] = Czz

    # Joint covariance of (x, y, z).
    xy_end = Nvarx + Nvary
    Cxyz = np.empty((Nvarxyz, Nvarxyz))
    Cxyz[:Nvarx, :Nvarx] = Cxx
    Cxyz[:Nvarx, Nvarx:xy_end] = Cxy
    Cxyz[:Nvarx, xy_end:] = Cxz
    Cxyz[Nvarx:xy_end, :Nvarx] = Cxy.T
    Cxyz[Nvarx:xy_end, Nvarx:xy_end] = Cyy
    Cxyz[Nvarx:xy_end, xy_end:] = Cyz
    Cxyz[xy_end:, :Nvarx] = Cxz.T
    Cxyz[xy_end:, Nvarx:xy_end] = Cyz.T
    Cxyz[xy_end:, xy_end:] = Czz

    # A tiny ridge keeps Cholesky from failing on (near-)singular
    # covariances, e.g. identical variables. It is applied after the joints
    # are assembled, so each matrix receives it exactly once.
    Czz += np.eye(Nvarz) * 1e-12
    Cyz_joint += np.eye(Nvaryz) * 1e-12
    Cxz_joint += np.eye(Nvarxz) * 1e-12
    Cxyz += np.eye(Nvarxyz) * 1e-12

    # Sum of log-diagonals of the Cholesky factor = 0.5 * log det.
    HZ = np.sum(np.log(np.diag(np.linalg.cholesky(Czz))))
    HYZ = np.sum(np.log(np.diag(np.linalg.cholesky(Cyz_joint))))
    HXZ = np.sum(np.log(np.diag(np.linalg.cholesky(Cxz_joint))))
    HXYZ = np.sum(np.log(np.diag(np.linalg.cholesky(Cxyz))))

    ln2 = np.log(2.0)

    if biascorrect:
        dterm = (ln2 - np.log(Ntrl - 1.0)) / 2.0
        # One table of digamma terms; each entropy uses the prefix matching
        # its dimensionality.
        psiterms = np.empty(Nvarxyz)
        for i in range(Nvarxyz):
            psiterms[i] = digamma_approx((Ntrl - i - 1.0) / 2.0) / 2.0

        HZ = HZ - Nvarz * dterm - np.sum(psiterms[:Nvarz])
        HYZ = HYZ - Nvaryz * dterm - np.sum(psiterms[:Nvaryz])
        HXZ = HXZ - Nvarxz * dterm - np.sum(psiterms[:Nvarxz])
        HXYZ = HXYZ - Nvarxyz * dterm - np.sum(psiterms)

    # I(X;Y|Z) = H(X,Z) + H(Y,Z) - H(X,Y,Z) - H(Z), converted to bits.
    return (HXZ + HYZ - HXYZ - HZ) / ln2

387 

388 

@njit
def digamma_approx(x):
    """Approximate digamma function psi(x) for JIT compilation.

    Uses the recurrence psi(x) = psi(x + 1) - 1/x to shift the argument to
    x >= 6, then applies the asymptotic expansion through the 1/x**6 term.
    The truncation error there is below 1/(240 x**8), about 2.5e-9 at x = 6.

    Parameters
    ----------
    x : float
        Input value.

    Returns
    -------
    float
        Approximate digamma value. Returns ``-inf`` for ``x <= 0``; the
        non-positive axis is not evaluated.
    """
    if x <= 0:
        return -np.inf

    # Recurrence: accumulate -1/x while moving x up to where the series is accurate.
    result = 0.0
    while x < 6:
        result -= 1.0 / x
        x += 1.0

    inv = 1.0 / x
    inv2 = inv * inv

    # psi(x) ~ ln(x) - 1/(2x) - 1/(12x^2) + 1/(120x^4) - 1/(252x^6),
    # with the even-power terms evaluated in Horner form.
    series = inv2 * (1.0 / 12.0 - inv2 * (1.0 / 120.0 - inv2 / 252.0))
    result += np.log(x) - 0.5 * inv - series

    return result

422 

423 

@njit
def gcmi_cc_jit(x, y):
    """Gaussian-copula mutual information between two continuous variables.

    Each input is copula-normalized (empirical CDF followed by the inverse
    normal CDF). A 1D input becomes a single-row 2D array; a 2D input is
    normalized row by row. The normalized arrays are passed to ``mi_gg_jit``
    with bias correction on and ``demeaned=True``, so no extra centring
    step is applied.

    Parameters
    ----------
    x : ndarray
        First variable, (n_samples,) or (n_vars_x, n_samples).
    y : ndarray
        Second variable, (n_samples,) or (n_vars_y, n_samples).

    Returns
    -------
    float
        GCMI in bits.
    """
    if x.ndim == 1:
        gx = np.empty((1, x.shape[0]))
        gx[0, :] = copnorm_jit(x)
    else:
        gx = copnorm_2d_jit(x)

    if y.ndim == 1:
        gy = np.empty((1, y.shape[0]))
        gy[0, :] = copnorm_jit(y)
    else:
        gy = copnorm_2d_jit(y)

    mi_bits = mi_gg_jit(gx, gy, biascorrect=True, demeaned=True)
    return mi_bits

455 # Compute MI with bias correction 

456 return mi_gg_jit(cx, cy, biascorrect=True, demeaned=True)