Coverage for src/driada/information/info_base.py: 72.49%

389 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-25 15:40 +0300

1from sklearn.feature_selection import mutual_info_classif, mutual_info_regression 

2from sklearn.metrics.cluster import mutual_info_score 

3import scipy 

4 

5from .ksg import * 

6from .gcmi import * 

7from .info_utils import binary_mi_score 

8from ..utils.data import correlation_matrix 

9from .entropy import entropy_d, joint_entropy_dd, joint_entropy_cd, joint_entropy_cdd 

10from ..dim_reduction.data import MVData 

11 

12import numpy as np 

13import warnings 

14from typing import Optional 

15from sklearn.preprocessing import MinMaxScaler 

16from scipy.stats import entropy, differential_entropy 

17 

18from ..utils.data import to_numpy_array 

19 

20 

DEFAULT_NN = 5  # default number of nearest neighbours for k-NN based estimators

# TODO: add @property decorators to properly set getter-setter functionality


class TimeSeries():
    """One-dimensional time series with cached derived representations.

    On construction the raw data are converted to a numpy array. Discrete series
    also get an integer version (``int_data``) and, if they take exactly two
    levels, a boolean version (``bool_data``). Continuous series get a
    copula-normalized version (``copula_normal_data``). Entropies (per
    downsampling factor), the k-d tree and its k-NN query are computed lazily
    and cached on the instance.
    """

    @staticmethod
    def define_ts_type(ts):
        """Guess whether ``ts`` is discrete (returns True) or continuous (returns False).

        Two scores are compared against fixed thresholds:

        - the fraction of unique values;
        - the entropy of a ``len(ts)``-bin histogram divided by its maximum ``ln(len(ts))``.

        Raises
        ------
        ValueError
            If the scores are neither both high (> 0.7) nor both low (< 0.25).
        """
        if len(ts) < 100:
            warnings.warn('Time series is too short for accurate type (discrete/continuous) determination')

        unique_vals = np.unique(ts)
        sc1 = len(unique_vals) / len(ts)
        hist = np.histogram(ts, bins=len(ts))[0]
        ent = entropy(hist)
        maxent = entropy(np.ones(len(ts)))
        sc2 = ent / maxent

        # TODO: refactor thresholds
        if sc1 > 0.70 and sc2 > 0.70:
            return False  # both scores are high - the variable is most probably continuous
        elif sc1 < 0.25 and sc2 < 0.25:
            return True  # both scores are low - the variable is most probably discrete
        else:
            raise ValueError(f'Unable to determine time series type automatically: score 1 = {sc1}, score 2 = {sc2}')

    # TODO: complete this function
    def _check_input(self):
        pass

    def __init__(self, data, discrete=None, shuffle_mask=None):
        """
        Parameters
        ----------
        data : array-like
            Samples of the series (converted with ``to_numpy_array``).
        discrete : bool, optional
            Series type. If None it is inferred with ``define_ts_type``.
        shuffle_mask : array-like, optional
            Positions valid for shuffling; coerced to a bool array. Defaults to all True.
        """
        self.data = to_numpy_array(data)

        if discrete is None:
            self.discrete = TimeSeries.define_ts_type(self.data)
        else:
            self.discrete = discrete

        scaler = MinMaxScaler()
        self.scdata = scaler.fit_transform(self.data.reshape(-1, 1)).reshape(1, -1)[0]
        self.data_scale = scaler.scale_
        self.copula_normal_data = None

        if self.discrete:
            self.int_data = np.round(self.data).astype(int)
            # Decide binary-ness on the same rounded values stored in int_data
            # (truncation via astype(int) could disagree with np.round).
            unique_ints = np.unique(self.int_data)
            self.is_binary = len(unique_ints) == 2
            if self.is_binary:
                if 0 in unique_ints:
                    # Usual 0/x coding: zero -> False, the other level -> True.
                    self.bool_data = self.int_data.astype(bool)
                else:
                    # e.g. levels {1, 2}: astype(bool) would be constant True,
                    # so map the larger level to True instead.
                    self.bool_data = self.int_data == unique_ints[-1]
        else:
            self.copula_normal_data = copnorm(self.data).ravel()

        self.entropy = dict()  # supports various downsampling constants
        self.kdtree = None
        self.kdtree_query = None
        self._kdtree_query_k = None  # k used for the cached kdtree_query

        if shuffle_mask is None:
            # which shuffles are valid
            self.shuffle_mask = np.ones(len(self.data), dtype=bool)
        else:
            self.shuffle_mask = np.asarray(shuffle_mask).astype(bool)

    def _points(self):
        """Return the data as a 2D (n_samples, n_features) array for tree building/querying."""
        return self.data.reshape(self.data.shape[0], -1)

    def get_kdtree(self):
        """Return the cached tree built by ``build_tree``, creating it on first use."""
        if self.kdtree is None:
            self.kdtree = self._compute_kdtree()
        return self.kdtree

    def _compute_kdtree(self):
        return build_tree(self._points())

    def get_kdtree_query(self, k=DEFAULT_NN):
        """Return the cached ``k + 1`` nearest-neighbour query of the data against its own tree.

        The cache is recomputed when ``k`` differs from the ``k`` of the cached query.
        """
        cached_k = getattr(self, '_kdtree_query_k', None)
        if self.kdtree_query is None or cached_k != k:
            self.kdtree_query = self._compute_kdtree_query(k=k)
            self._kdtree_query_k = k
        return self.kdtree_query

    def _compute_kdtree_query(self, k=DEFAULT_NN):
        tree = self.get_kdtree()
        # Query with the same 2D layout the tree was built on.
        return tree.query(self._points(), k=k + 1)

    def get_entropy(self, ds=1):
        """Return the (cached) entropy of the series downsampled by ``ds``."""
        if ds not in self.entropy.keys():
            self._compute_entropy(ds=ds)
        return self.entropy[ds]

    def _compute_entropy(self, ds=1):
        """Compute the entropy of ``data[::ds]`` and store it under key ``ds``.

        The discrete branch uses scipy ``entropy(..., base=np.e)``. The continuous
        branch divides ``nonparam_entropy_c`` by ``ln 2``.

        NOTE(review): the two branches appear to use different units (nats vs
        bits, assuming nonparam_entropy_c returns nats) — confirm intended.
        """
        samples = self.data[::ds]
        if self.discrete:
            # TODO: rewrite this using int_data and via ent_d from driada.information.entropy
            _, counts = np.unique(samples, return_counts=True)
            self.entropy[ds] = entropy(counts, base=np.e)
        else:
            # Downsample before estimating so the cached value matches its ds key.
            self.entropy[ds] = nonparam_entropy_c(samples) / np.log(2)

    def filter(self, method='gaussian', **kwargs):
        """
        Apply filtering to the time series and return a new filtered TimeSeries.

        Parameters
        ----------
        method : str
            Filtering method: 'gaussian', 'savgol', 'wavelet', or 'none'
        **kwargs : dict
            Method-specific parameters:
            - gaussian: sigma (default: 1.0)
            - savgol: window_length (default: 5), polyorder (default: 2)
            - wavelet: wavelet (default: 'db4'), level (default: None)

        Returns
        -------
        TimeSeries
            New TimeSeries object with filtered data (discrete flag and a copy
            of the shuffle mask are carried over)
        """
        from ..utils.signals import filter_1d_timeseries

        if method == 'none':
            return TimeSeries(self.data.copy(), discrete=self.discrete,
                              shuffle_mask=self.shuffle_mask.copy())

        if self.discrete:
            warnings.warn("Filtering discrete time series may produce unexpected results")

        # Apply filtering to 1D time series
        filtered_data = filter_1d_timeseries(self.data, method=method, **kwargs)

        # Create new TimeSeries with filtered data
        return TimeSeries(filtered_data, discrete=self.discrete,
                          shuffle_mask=self.shuffle_mask.copy())

    def approximate_entropy(self, m: int = 2, r: Optional[float] = None) -> float:
        """
        Calculate approximate entropy (ApEn) of the time series.

        Approximate entropy is a regularity statistic that quantifies the
        unpredictability of fluctuations in a time series. A time series
        containing many repetitive patterns has a relatively small ApEn;
        a less predictable process has a higher ApEn.

        Parameters
        ----------
        m : int, optional
            Pattern length. Common values are 1 or 2. Default is 2.
        r : float, optional
            Tolerance threshold for pattern matching. If None, defaults to
            0.2 times the standard deviation of the data.

        Returns
        -------
        float
            The approximate entropy value. Higher values indicate more
            randomness/complexity.

        Raises
        ------
        ValueError
            If called on a discrete TimeSeries.

        Examples
        --------
        >>> ts = TimeSeries(np.random.randn(1000), discrete=False)
        >>> apen = ts.approximate_entropy(m=2)
        """
        if self.discrete:
            raise ValueError("approximate_entropy is only valid for continuous time series")

        # Use lazy import to avoid circular imports
        from ..utils.signals import approximate_entropy

        # Default r to 0.2 * std if not provided
        if r is None:
            r = 0.2 * np.std(self.data)

        return approximate_entropy(self.data, m=m, r=r)

211 return approximate_entropy(self.data, m=m, r=r) 

212 

213 

class MultiTimeSeries(MVData):
    """
    MultiTimeSeries represents multiple aligned time series.
    Now inherits from MVData to enable direct dimensionality reduction.
    Supports either all-continuous or all-discrete components (no mixing).

    Parameters
    ----------
    data_or_tslist : np.ndarray or iterable of TimeSeries
        Either a 2D array of shape (n_series, n_timepoints) or an iterable of
        equal-length TimeSeries that are all discrete or all continuous.
    labels, distmat, rescale_rows, data_name, downsampling
        Forwarded unchanged to ``MVData.__init__``.
    discrete : bool, optional
        Required for array input; not used for TimeSeries input (the type is
        taken from the components).
    shuffle_mask : array-like, optional
        Positions valid for shuffling (coerced to bool). If omitted, the
        component masks are combined with a logical AND.
    """

    def __init__(self, data_or_tslist, labels=None, distmat=None, rescale_rows=False,
                 data_name=None, downsampling=None, discrete=None, shuffle_mask=None):
        if isinstance(data_or_tslist, np.ndarray):
            # Direct numpy array input: each row is a time series
            if data_or_tslist.ndim != 2:
                raise ValueError("When providing numpy array, it must be 2D with shape (n_series, n_timepoints)")
            if discrete is None:
                raise ValueError("When providing numpy array, 'discrete' parameter must be specified")

            self.discrete = discrete
            tslist = [TimeSeries(row, discrete=discrete) for row in data_or_tslist]
            data = data_or_tslist
        else:
            # Materialize once: the components are iterated several times below,
            # which would silently exhaust a generator.
            tslist = list(data_or_tslist)
            self._check_input(tslist)
            data = np.vstack([ts.data for ts in tslist])

        # Applied after the component data are processed (see below)
        self._provided_shuffle_mask = shuffle_mask

        # Initialize MVData parent class
        super().__init__(data, labels=labels, distmat=distmat,
                         rescale_rows=rescale_rows, data_name=data_name,
                         downsampling=downsampling)

        # Additional MultiTimeSeries specific attributes
        self.scdata = np.vstack([ts.scdata for ts in tslist])

        if not self.discrete:
            self.copula_normal_data = np.vstack([ts.copula_normal_data for ts in tslist])
        else:
            # For discrete MultiTimeSeries, store integer data
            self.int_data = np.vstack([ts.int_data for ts in tslist])
            self.copula_normal_data = None

        if self._provided_shuffle_mask is not None:
            # An explicit mask wins; coerce so it can be used for boolean indexing
            self.shuffle_mask = np.asarray(self._provided_shuffle_mask).astype(bool)
            if not np.any(self.shuffle_mask):
                warnings.warn('Provided shuffle_mask has no valid positions for shuffling!')
        else:
            # Restrictive combination: ALL masks must allow shuffling at a position
            shuffle_masks = np.vstack([ts.shuffle_mask for ts in tslist])
            self.shuffle_mask = np.all(shuffle_masks, axis=0)

            valid_positions = np.sum(self.shuffle_mask)
            total_positions = len(self.shuffle_mask)

            if valid_positions == 0:
                raise ValueError(f'Combined shuffle_mask has NO valid positions for shuffling! '
                                 f'This typically happens when combining many neurons with restrictive individual masks. '
                                 f'Consider providing an explicit shuffle_mask parameter to MultiTimeSeries.')
            elif valid_positions < 0.1 * total_positions:
                warnings.warn(f'Combined shuffle_mask is extremely restrictive: only {valid_positions}/{total_positions} '
                              f'({100*valid_positions/total_positions:.1f}%) positions are valid for shuffling. '
                              f'This may cause issues with shuffle-based significance testing.')

        self.entropy = dict()  # supports various downsampling constants

    @property
    def shape(self):
        """Return shape of the data for compatibility with numpy-like access."""
        return self.data.shape

    def _check_input(self, tslist):
        """Validate components and set ``self.discrete`` from them.

        Raises
        ------
        ValueError
            If an element is not a TimeSeries, the list is empty, lengths differ,
            or discrete and continuous components are mixed.
        """
        tslist = list(tslist)
        if not all(isinstance(ts, TimeSeries) for ts in tslist):
            raise ValueError('Input to MultiTimeSeries must be iterable of TimeSeries')
        if not tslist:
            raise ValueError('Input to MultiTimeSeries must contain at least one TimeSeries')

        if len({len(ts.data) for ts in tslist}) > 1:
            raise ValueError('All TimeSeries must have the same length')

        # bool() guards against non-bool flags (e.g. 0/1), where `~` would misbehave
        kinds = {bool(ts.discrete) for ts in tslist}
        if len(kinds) > 1:
            raise ValueError('All components of MultiTimeSeries must be either continuous or discrete (no mixing)')

        self.discrete = kinds.pop()

    def get_entropy(self, ds=1):
        """Return the (cached) joint entropy of all components downsampled by ``ds``."""
        if ds not in self.entropy.keys():
            self._compute_entropy(ds=ds)
        return self.entropy[ds]

    def _compute_entropy(self, ds=1):
        if self.discrete:
            # All components are discrete - use joint discrete entropy
            self.entropy[ds] = entropy_d(self.int_data[:, ::ds])
        else:
            # All continuous - Gaussian entropy estimate
            self.entropy[ds] = ent_g(self.data[:, ::ds])

    def filter(self, method='gaussian', **kwargs):
        """
        Apply filtering to all time series components and return a new filtered MultiTimeSeries.

        Parameters
        ----------
        method : str
            Filtering method: 'gaussian', 'savgol', 'wavelet', or 'none'
        **kwargs : dict
            Method-specific parameters (see TimeSeries.filter for details)

        Returns
        -------
        MultiTimeSeries
            New MultiTimeSeries object with all components filtered; labels,
            rescale_rows, data_name, discrete flag and a copy of the shuffle
            mask are carried over.
        """
        from ..signals.neural_filtering import filter_neural_signals

        filtered_data = filter_neural_signals(self.data, method=method, **kwargs)

        # Keep the shuffle mask, consistent with TimeSeries.filter
        return MultiTimeSeries(filtered_data, labels=self.labels,
                               rescale_rows=self.rescale_rows,
                               data_name=self.data_name, discrete=self.discrete,
                               shuffle_mask=self.shuffle_mask.copy())

354 

355 

def get_stats_function(sname):
    """Look up a function in :mod:`scipy.stats` by its name.

    Parameters
    ----------
    sname : str
        Attribute name inside ``scipy.stats`` (e.g. ``'pearsonr'``).

    Returns
    -------
    object
        The ``scipy.stats`` attribute called ``sname``.

    Raises
    ------
    ValueError
        If ``scipy.stats`` has no attribute called ``sname``.
    """
    stats_module = scipy.stats
    try:
        found = getattr(stats_module, sname)
    except AttributeError:
        raise ValueError(f"Metric '{sname}' not found in scipy.stats")
    return found

361 

362 

def calc_signal_ratio(binary_ts, continuous_ts):
    """Ratio of the mean of ``continuous_ts`` when ``binary_ts`` is 1 vs when it is 0.

    Parameters
    ----------
    binary_ts : array-like
        Same-length indicator series; samples equal to 1 form the ON group and
        samples equal to 0 form the OFF group.
    continuous_ts : array-like
        Continuous signal aligned with ``binary_ts``.

    Returns
    -------
    float
        ``mean(ON) / mean(OFF)``. Returns ``inf`` if the OFF mean is 0 and the ON
        mean is not. Returns ``nan`` if both means are 0 or if either group is empty.
    """
    # Lists would otherwise be compared as a whole (`[...] == 1` -> False)
    # and then index a single element.
    binary_ts = np.asarray(binary_ts)
    continuous_ts = np.asarray(continuous_ts)

    on_mask = binary_ts == 1
    off_mask = binary_ts == 0

    # A mean over an empty group is undefined (np.mean([]) warns and gives nan).
    if not np.any(on_mask) or not np.any(off_mask):
        return np.nan

    avg_on = np.mean(continuous_ts[on_mask])
    avg_off = np.mean(continuous_ts[off_mask])

    if avg_off == 0:
        return np.inf if avg_on != 0 else np.nan

    return avg_on / avg_off

373 

374 

def get_sim(x, y, metric, shift=0, ds=1, k=5, estimator='gcmi', check_for_coincidence=False):
    """Computes similarity between two (possibly multidimensional) variables efficiently

    Parameters
    ----------
    x: TimeSeries/MultiTimeSeries instance or numpy array

    y: TimeSeries/MultiTimeSeries instance or numpy array

    metric: similarity metric between time series
        'mi' (delegated to get_mi), 'fast_pearsonr', 'av' (binary vs continuous
        mean ratio), or the name of a scipy.stats function whose result's
        first element is used

    shift: int
        y will be roll-moved by the number 'shift' after downsampling by 'ds' factor

    ds: int
        downsampling constant (take every 'ds'-th point)

    k, estimator, check_for_coincidence:
        forwarded to get_mi when metric == 'mi'

    Returns
    -------
    me: similarity metric between x and (possibly) shifted y

    Raises
    ------
    Exception
        For non-'mi' metrics on MultiTimeSeries, or for multidimensional numpy input.
    ValueError
        For unsupported metric / type combinations.
    """
    def _check_input(ts):
        if not isinstance(ts, TimeSeries) and not isinstance(ts, MultiTimeSeries):
            if np.ndim(ts) == 1:
                ts = TimeSeries(ts)
            else:
                raise Exception('Multidimensional inputs must be provided as MultiTimeSeries')
        return ts

    ts1 = _check_input(x)
    ts2 = _check_input(y)

    if metric == 'mi':
        return get_mi(ts1, ts2, shift=shift, ds=ds, k=k, estimator=estimator,
                      check_for_coincidence=check_for_coincidence)

    if not (isinstance(ts1, TimeSeries) and isinstance(ts2, TimeSeries)):
        raise Exception("Metrics except 'mi' are not supported for multi-dimensional data")

    x_ds = ts1.data[::ds]
    # y is always the series that gets shifted, whichever of the two is binary.
    y_shifted = np.roll(ts2.data[::ds], shift)

    if not ts1.discrete and not ts2.discrete:
        if metric == 'fast_pearsonr':
            return correlation_matrix(np.vstack([x_ds, y_shifted]))[0, 1]
        metric_func = get_stats_function(metric)
        return metric_func(x_ds, y_shifted)[0]

    if ts1.discrete and ts2.discrete:
        raise ValueError(f'Metric={metric} is not supported for two discrete ts')

    # Exactly one of the series is discrete from here on.
    if metric != 'av':
        raise ValueError("Only 'av' and 'mi' metrics are supported for binary-continuous similarity")

    if ts1.discrete:
        discrete_ts, binary_data, continuous_data = ts1, x_ds, y_shifted
    else:
        discrete_ts, binary_data, continuous_data = ts2, y_shifted, x_ds

    if not discrete_ts.is_binary:
        raise ValueError(f'Discrete ts must be binary for metric={metric}')

    return calc_signal_ratio(binary_data, continuous_data)

448 

449 

def get_mi(x, y, shift=0, ds=1, k=5, estimator='gcmi', check_for_coincidence=False):
    """Computes mutual information between two (possibly multidimensional) variables efficiently

    Parameters
    ----------
    x: TimeSeries/MultiTimeSeries instance or numpy array
    y: TimeSeries/MultiTimeSeries instance or numpy array
    shift: int
        y will be roll-moved by the number 'shift' after downsampling by 'ds' factor.
        For a TimeSeries x and MultiTimeSeries y this is realized by rolling x by
        -shift, which produces exactly the same (circular) sample pairs.
    ds: int
        downsampling constant (take every 'ds'-th point)
    k: int
        number of neighbors for ksg estimator
    estimator: str
        Estimation method. Should be 'ksg' (accurate but slow) or 'gcmi' (fast, but estimates the lower bound on MI).
        In most cases 'gcmi' should be preferred. Only 'gcmi' is supported when
        either input is a MultiTimeSeries.
    check_for_coincidence: bool
        For 1d inputs, forwarded to get_1d_mi. For two MultiTimeSeries of identical
        data at zero shift, a warning is issued and 0 is returned.

    Returns
    -------
    mi: mutual information (or its lower bound in case of 'gcmi' estimator) between x and (possibly) shifted y,
        clipped at zero

    """

    def _check_input(ts):
        if not isinstance(ts, TimeSeries) and not isinstance(ts, MultiTimeSeries):
            if np.ndim(ts) == 1:
                ts = TimeSeries(ts)
            else:
                raise Exception('Multidimensional inputs must be provided as MultiTimeSeries')
        return ts

    def multi_single_mi(mts, ts, ts_shift, ds=1, k=5, estimator='gcmi'):
        """MI between a continuous MultiTimeSeries and a TimeSeries rolled by ``ts_shift``."""
        if estimator == 'ksg':
            raise NotImplementedError('KSG estimator is not supported for dim>1 yet')
        if mts.discrete:
            # copula_normal_data is None for discrete MultiTimeSeries
            raise NotImplementedError('MI between a discrete MultiTimeSeries and a TimeSeries is not supported yet')

        # Safety check: if single TimeSeries data is contained in MultiTimeSeries
        # This should not happen due to aggregate_multiple_ts adding noise, but add as safety net
        if not ts.discrete and ts_shift == 0:
            for i in range(mts.data.shape[0]):
                if np.allclose(mts.data[i, ::ds], ts.data[::ds], rtol=1e-10, atol=1e-10):
                    warnings.warn('MI computation between MultiTimeSeries containing identical data detected, returning 0')
                    return 0.0

        ny1 = mts.copula_normal_data[:, ::ds]
        if ts.discrete:
            # Roll the 1d discrete series: np.roll on the 2D copula array without
            # an axis would flatten it and mix rows.
            ny1 = np.ascontiguousarray(ny1)  # contiguous input for the Numba-backed estimator
            ny2 = np.roll(ts.int_data[::ds], ts_shift)
            # NOTE(review): Ym is passed as max label (not max+1), as elsewhere in this file - confirm mi_model_gd's convention.
            return mi_model_gd(ny1, ny2, np.max(ny2), biascorrect=True, demeaned=True)

        ny2 = np.roll(ts.copula_normal_data[::ds], ts_shift)
        return mi_gg(ny1, ny2, True, True)

    def multi_multi_mi(mts1, mts2, mts_shift, ds=1, k=5, estimator='gcmi', check_for_coincidence=False):
        """MI between two continuous MultiTimeSeries; ``mts2`` is rolled by ``mts_shift`` along time."""
        if estimator == 'ksg':
            raise NotImplementedError('KSG estimator is not supported for dim>1 yet')

        # Compare shapes first: np.allclose would fail to broadcast differently-sized data.
        if check_for_coincidence and mts_shift == 0 and mts1.data.shape == mts2.data.shape:
            if np.allclose(mts1.data, mts2.data):
                warnings.warn('MI computation of a MultiTimeSeries with itself is meaningless, 0 will be returned forcefully')
                return 0.0

        if mts1.discrete or mts2.discrete:
            raise NotImplementedError('MI computation between MultiTimeSeries '
                                      'is currently supported for continuous data only')

        ny1 = mts1.copula_normal_data[:, ::ds]
        ny2 = np.roll(mts2.copula_normal_data[:, ::ds], mts_shift, axis=1)
        return mi_gg(ny1, ny2, True, True)

    ts1 = _check_input(x)
    ts2 = _check_input(y)

    if isinstance(ts1, TimeSeries) and isinstance(ts2, TimeSeries):
        # Pass the already-wrapped objects so get_1d_mi does not rebuild them.
        mi = get_1d_mi(ts1, ts2, shift=shift, ds=ds, k=k, estimator=estimator,
                       check_for_coincidence=check_for_coincidence)
    elif isinstance(ts1, MultiTimeSeries) and isinstance(ts2, TimeSeries):
        mi = multi_single_mi(ts1, ts2, shift, ds=ds, k=k, estimator=estimator)
    elif isinstance(ts1, TimeSeries) and isinstance(ts2, MultiTimeSeries):
        # y (the MultiTimeSeries) should move by `shift`; rolling x by -shift
        # yields the same circular pairing of samples.
        mi = multi_single_mi(ts2, ts1, -shift, ds=ds, k=k, estimator=estimator)
    else:
        mi = multi_multi_mi(ts1, ts2, shift, ds=ds, k=k, estimator=estimator,
                            check_for_coincidence=check_for_coincidence)

    if mi < 0:
        mi = 0.0

    return mi

552 

553 

def get_1d_mi(ts1, ts2, shift=0, ds=1, k=5, estimator='gcmi', check_for_coincidence=True):
    """Computes mutual information between two 1d variables efficiently

    Parameters
    ----------
    ts1: TimeSeries/MultiTimeSeries instance or numpy array
    ts2: TimeSeries/MultiTimeSeries instance or numpy array
        (1d numpy arrays are wrapped into TimeSeries with automatic type detection)
    shift: int
        ts2 will be roll-moved by the number 'shift' after downsampling by 'ds' factor
    ds: int
        downsampling constant (take every 'ds'-th point)
    k: int
        number of neighbors for ksg estimator
    estimator: str
        Estimation method: 'ksg' (accurate but slow) or 'gcmi' (fast, but estimates the lower bound on MI).
        In most cases 'gcmi' should be preferred.
    check_for_coincidence : bool, optional
        If True, raises error when computing MI of a signal with itself at zero shift. Default: True.

    Returns
    -------
    mi: mutual information (or its lower bound in case of 'gcmi' estimator) between ts1 and (possibly) shifted ts2,
        clipped at zero. NOTE(review): the 'ksg' mixed/discrete paths use sklearn
        (nats) while 'gcmi' uses gcmi routines - units may differ, confirm.

    Raises
    ------
    ValueError
        If check_for_coincidence is True and the signals coincide at zero shift,
        or if ``estimator`` is neither 'ksg' nor 'gcmi'.
    """

    def _check_input(ts):
        if not isinstance(ts, TimeSeries) and not isinstance(ts, MultiTimeSeries):
            if np.ndim(ts) == 1:
                ts = TimeSeries(ts)
            else:
                raise Exception('Multidimensional inputs must be provided as MultiTimeSeries')
        return ts

    ts1 = _check_input(ts1)
    ts2 = _check_input(ts2)

    if check_for_coincidence and ts1.data.shape == ts2.data.shape:
        if np.allclose(ts1.data, ts2.data) and shift == 0:
            raise ValueError('MI computation of a TimeSeries or MultiTimeSeries with itself is not allowed')

    if estimator == 'ksg':
        # Both series are downsampled exactly once; only ts2 is shifted.
        x = ts1.data[::ds]
        y = np.roll(ts2.data[::ds], shift)

        if not ts1.discrete and not ts2.discrete:
            if ds == 1:
                # Cached trees are built on full-resolution data, so they only
                # describe the sample set when no downsampling is applied.
                mi = nonparam_mi_cc_mod(x, y, k=k,
                                        precomputed_tree_x=ts1.get_kdtree(),
                                        precomputed_tree_y=ts2.get_kdtree())
            else:
                mi = nonparam_mi_cc_mod(x, y, k=k)

        elif ts1.discrete and ts2.discrete:
            mi = mutual_info_classif(ts1.int_data[::ds].reshape(-1, 1),
                                     np.roll(ts2.int_data[::ds], shift),
                                     discrete_features=True,
                                     n_neighbors=k)[0]

        # TODO: refactor using ksg functions
        elif ts1.discrete:
            # discrete feature (ts1) vs continuous target (ts2); sklearn needs 2D X
            mi = mutual_info_regression(ts1.int_data[::ds].reshape(-1, 1),
                                        y,
                                        discrete_features=True,
                                        n_neighbors=k)[0]

        else:
            # continuous feature (ts1) vs discrete target (ts2)
            mi = mutual_info_classif(x.reshape(-1, 1),
                                     np.roll(ts2.int_data[::ds], shift),
                                     discrete_features=False,
                                     n_neighbors=k)[0]

    elif estimator == 'gcmi':
        if not ts1.discrete and not ts2.discrete:
            ny1 = ts1.copula_normal_data[::ds]
            ny2 = np.roll(ts2.copula_normal_data[::ds], shift)
            mi = mi_gg(ny1, ny2, True, True)

        elif ts1.discrete and ts2.discrete:
            if ts1.is_binary and ts2.is_binary:
                ny1 = ts1.bool_data[::ds]
                ny2 = np.roll(ts2.bool_data[::ds], shift)

                contingency = np.zeros((2, 2))
                contingency[0, 0] = (ny1 & ny2).sum()
                contingency[0, 1] = (~ny1 & ny2).sum()
                contingency[1, 0] = (ny1 & ~ny2).sum()
                contingency[1, 1] = (~ny1 & ~ny2).sum()

                mi = binary_mi_score(contingency)
            else:
                ny1 = ts1.int_data[::ds]
                ny2 = np.roll(ts2.int_data[::ds], shift)
                # Ensure float type for consistency
                mi = float(mutual_info_score(ny1, ny2))

        elif ts1.discrete:
            ny1 = ts1.int_data[::ds]
            # Contiguous input for better performance with Numba
            ny2 = np.ascontiguousarray(np.roll(ts2.copula_normal_data[::ds], shift))
            mi = mi_model_gd(ny2, ny1, np.max(ny1), biascorrect=True, demeaned=True)

        else:
            # TODO: fix zd error
            # Contiguous input for better performance with Numba
            ny1 = np.ascontiguousarray(ts1.copula_normal_data[::ds])
            ny2 = np.roll(ts2.int_data[::ds], shift)
            mi = mi_model_gd(ny1, ny2, np.max(ny2), biascorrect=True, demeaned=True)

    else:
        raise ValueError(f"Unknown estimator '{estimator}': expected 'ksg' or 'gcmi'")

    if mi < 0:
        mi = 0.0

    return mi

685 

686 

def get_tdmi(data, min_shift=1, max_shift=100, nn=DEFAULT_NN):
    """Time-delayed mutual information of a continuous signal with itself.

    Parameters
    ----------
    data : array-like
        The signal; wrapped as a continuous TimeSeries.
    min_shift, max_shift : int
        Shifts ``range(min_shift, max_shift)`` are evaluated (``max_shift`` excluded).
        A shift of 0 makes ``get_1d_mi`` raise its self-MI ValueError.
    nn : int
        Passed to ``get_1d_mi`` as ``k``.

    Returns
    -------
    list
        One ``get_1d_mi`` value per shift, in increasing shift order.
    """
    series = TimeSeries(data, discrete=False)
    delays = range(min_shift, max_shift)
    return [get_1d_mi(series, series, shift=lag, k=nn) for lag in delays]

692 

693 

def get_multi_mi(tslist, ts2, shift=0, ds=1, k=DEFAULT_NN, estimator='gcmi'):
    """Gaussian-copula MI between the joint variable formed by ``tslist`` and ``ts2``.

    Parameters
    ----------
    tslist : iterable of TimeSeries
        Continuous components stacked into one multidimensional variable.
    ts2 : TimeSeries
        Continuous target series.
    shift : int
        ts2 is rolled by ``shift`` *before* downsampling (unlike get_1d_mi).
    ds : int
        Downsampling constant (take every ``ds``-th point).
    k, estimator
        Accepted for API symmetry; not used (GCMI is always applied).

    Returns
    -------
    float
        MI estimate clipped at zero.

    Raises
    ------
    ValueError
        If ``tslist`` is empty or any series (including ts2) is discrete.
    """
    tslist = list(tslist)
    if not tslist:
        raise ValueError('tslist must contain at least one TimeSeries')

    # copula_normal_data only exists for continuous series, so every component
    # must be continuous (the old `~np.all(...)` test let mixed lists through).
    if any(ts.discrete for ts in tslist) or ts2.discrete:
        raise ValueError('Multidimensional MI only implemented for continuous data!')

    # TODO: make shift the same as in get_1d_mi
    ny1 = np.vstack([ts.copula_normal_data[::ds] for ts in tslist])
    ny2 = np.roll(ts2.copula_normal_data, shift)[::ds]
    mi = mi_gg(ny1, ny2, True, True)

    if mi < 0:
        mi = 0.0

    return mi

709 

710 

def aggregate_multiple_ts(*ts_args, noise=1e-5):
    """Aggregate multiple continuous TimeSeries into a single MultiTimeSeries.

    Adds small uniform noise in ``[0, noise)`` (drawn from ``np.random``) to each
    component to break degeneracy. It then builds a MultiTimeSeries whose
    components keep the shuffle masks of the inputs.

    Parameters
    ----------
    *ts_args : TimeSeries
        Variable number of TimeSeries objects to aggregate.
    noise : float, optional
        Amount of noise to add to break degeneracy. Default: 1e-5.

    Returns
    -------
    MultiTimeSeries
        Aggregated multi-dimensional time series.

    Raises
    ------
    ValueError
        If no TimeSeries is given or any input TimeSeries is discrete.

    Examples
    --------
    >>> ts1 = TimeSeries(np.random.randn(100), discrete=False)
    >>> ts2 = TimeSeries(np.random.randn(100), discrete=False)
    >>> mts = aggregate_multiple_ts(ts1, ts2)
    """
    if not ts_args:
        raise ValueError('At least one TimeSeries is required for aggregation')

    mod_tslist = []
    for ts in ts_args:
        if ts.discrete:
            raise ValueError('this is not applicable to discrete TimeSeries')
        jitter = np.random.random(size=len(ts.data)) * noise
        # Carry the shuffle mask over so MultiTimeSeries combines the original
        # constraints instead of silently allowing every position.
        mod_tslist.append(TimeSeries(ts.data + jitter, discrete=False,
                                     shuffle_mask=ts.shuffle_mask.copy()))

    return MultiTimeSeries(mod_tslist)

750 

751 

def conditional_mi(ts1, ts2, ts3, ds=1, k=5):
    """Calculate conditional mutual information I(X;Y|Z).

    Computes the conditional mutual information between ts1 (X) and ts2 (Y)
    given ts3 (Z) for various combinations of continuous and discrete variables.

    Parameters
    ----------
    ts1 : TimeSeries
        First variable (X). Must be continuous.
    ts2 : TimeSeries
        Second variable (Y). Can be continuous or discrete.
    ts3 : TimeSeries
        Conditioning variable (Z). Can be continuous or discrete.
    ds : int, optional
        Downsampling factor. Default: 1.
    k : int, optional
        Number of neighbors for entropy estimation (CDD case). Default: 5.

    Returns
    -------
    float
        Conditional mutual information I(X;Y|Z) in bits.

    Raises
    ------
    ValueError
        If ts1 is discrete (only continuous X is currently supported).

    Notes
    -----
    Supports four cases:
    - CCC: All continuous - uses Gaussian copula
    - CCD: X,Y continuous, Z discrete - uses Gaussian copula per Z value
      (Z levels are recoded to 0..Zm-1 before calling gccmi_ccd)
    - CDC: X,Z continuous, Y discrete - uses entropy identity
      I(X;Y|Z) = H(X|Z) - H(X|Y,Z); Y groups with <= 2 samples are skipped
    - CDD: X continuous, Y,Z discrete - uses entropy decomposition

    For the CDD case, GCMI estimator has limitations due to uncontrollable
    biases (copula transform does not conserve entropy). See
    https://doi.org/10.1002/hbm.23471 for details.
    """
    if ts1.discrete:
        raise ValueError('conditional MI(X,Y|Z) is currently implemented for continuous X only')

    if not ts2.discrete and not ts3.discrete:
        # CCC: All continuous
        g1 = ts1.copula_normal_data[::ds]
        g2 = ts2.copula_normal_data[::ds]
        g3 = ts3.copula_normal_data[::ds]
        cmi = cmi_ggg(g1, g2, g3, biascorrect=True, demeaned=True)

    elif not ts2.discrete and ts3.discrete:
        # CCD: X,Y continuous, Z discrete.
        # gccmi_ccd receives the number of Z levels, so Z must be coded 0..Zm-1;
        # recode arbitrary integer levels (e.g. {1, 3}) - CMI is invariant to relabelling.
        z_levels, z_codes = np.unique(ts3.int_data[::ds], return_inverse=True)
        cmi = gccmi_ccd(ts1.data[::ds],
                        ts2.data[::ds],
                        z_codes.ravel(),
                        len(z_levels))

    elif ts2.discrete and not ts3.discrete:
        # CDC: X,Z continuous, Y discrete
        # Use entropy-based identity: I(X;Y|Z) = H(X|Z) - H(X|Y,Z)
        # This avoids mixing different MI estimators that cause bias inconsistency
        x_data = ts1.data[::ds].reshape(1, -1)
        z_data = ts3.data[::ds].reshape(1, -1)

        # H(X|Z) = H(X,Z) - H(Z) with Gaussian entropy estimates
        xz_joint = np.vstack([x_data, z_data])
        H_xz = ent_g(xz_joint, biascorrect=True)
        H_z = ent_g(z_data, biascorrect=True)
        H_x_given_z = H_xz - H_z

        # H(X|Y,Z) = sum_y P(Y=y) * H(X|Z, Y=y)
        y_codes = ts2.int_data[::ds]
        n_total = len(y_codes)
        H_x_given_yz = 0.0

        for y_val in np.unique(y_codes):
            y_mask = (y_codes == y_val)
            n_y = np.sum(y_mask)

            if n_y > 2:  # Need sufficient samples for entropy estimation
                x_subset = x_data[:, y_mask]
                z_subset = z_data[:, y_mask]

                H_xz_given_y = ent_g(np.vstack([x_subset, z_subset]), biascorrect=True)
                H_z_given_y = ent_g(z_subset, biascorrect=True)

                # H(X|Z,Y=y_val) = H(X,Z|Y=y_val) - H(Z|Y=y_val), weighted by P(Y=y_val)
                H_x_given_yz += (n_y / n_total) * (H_xz_given_y - H_z_given_y)

        cmi = H_x_given_z - H_x_given_yz

        # Small negative values are due to numerical precision and estimation noise
        if cmi < 0 and abs(cmi) < 0.01:
            cmi = 0.0

    else:
        # CDD: X continuous, Y,Z discrete
        # Identity: I(X;Y|Z) = H(X,Z) + H(Y,Z) - H(X,Y,Z) - H(Z)
        # GCMI is poorly applicable here (copula transform does not conserve
        # entropy, see https://doi.org/10.1002/hbm.23471), so joint entropies use
        # the k-NN based entropy functions on raw (not copula-normalized) X.
        # TODO: check this
        H_xz = joint_entropy_cd(ts3.int_data[::ds], ts1.data[::ds], k=k)
        H_yz = joint_entropy_dd(ts2.int_data[::ds], ts3.int_data[::ds])
        H_xyz = joint_entropy_cdd(ts2.int_data[::ds], ts3.int_data[::ds], ts1.data[::ds], k=k)
        H_z = entropy_d(ts3.int_data[::ds])
        cmi = H_xz + H_yz - H_xyz - H_z

    return cmi

883 

884 

def interaction_information(ts1, ts2, ts3, ds=1, k=5):
    """Calculate three-way interaction information II(X;Y;Z).

    Quantifies information shared among all three variables; negative values
    indicate redundancy, positive values synergy.

    Parameters
    ----------
    ts1 : TimeSeries
        First variable (X). Must be continuous.
    ts2 : TimeSeries
        Second variable (Y). Can be continuous or discrete.
    ts3 : TimeSeries
        Third variable (Z). Can be continuous or discrete.
    ds : int, optional
        Downsampling factor. Default: 1.
    k : int, optional
        Number of neighbors, forwarded to ``conditional_mi``. Default: 5.

    Returns
    -------
    float
        Interaction information II(X;Y;Z):
        - II < 0: Redundancy (Y and Z provide overlapping information about X)
        - II > 0: Synergy (Y and Z together provide more information than separately)

    Notes
    -----
    Williams & Beer convention: II(X;Y;Z) = I(X;Y|Z) - I(X;Y) = I(X;Z|Y) - I(X;Z).
    Both forms are computed and their mean is returned. X is treated as the
    target (e.g. neural activity), Y and Z as predictors (e.g. behaviour).
    """
    # Unconditional pairwise terms (get_mi with its default estimator)
    pair_xy = get_mi(ts1, ts2, ds=ds)
    pair_xz = get_mi(ts1, ts3, ds=ds)

    # Conditional terms I(X;Y|Z) and I(X;Z|Y)
    cond_xy_z = conditional_mi(ts1, ts2, ts3, ds=ds, k=k)
    cond_xz_y = conditional_mi(ts1, ts3, ts2, ds=ds, k=k)

    # Two estimates of the same quantity; averaging reduces estimator noise
    via_z = cond_xy_z - pair_xy
    via_y = cond_xz_y - pair_xz
    return (via_z + via_y) / 2.0