Coverage for src/driada/information/info_base.py: 72.49%
389 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-25 15:40 +0300
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-25 15:40 +0300
1from sklearn.feature_selection import mutual_info_classif, mutual_info_regression
2from sklearn.metrics.cluster import mutual_info_score
3import scipy
5from .ksg import *
6from .gcmi import *
7from .info_utils import binary_mi_score
8from ..utils.data import correlation_matrix
9from .entropy import entropy_d, joint_entropy_dd, joint_entropy_cd, joint_entropy_cdd
10from ..dim_reduction.data import MVData
12import numpy as np
13import warnings
14from typing import Optional
15from sklearn.preprocessing import MinMaxScaler
16from scipy.stats import entropy, differential_entropy
18from ..utils.data import to_numpy_array
21DEFAULT_NN = 5
23# TODO: add @property decorators to properly set getter-setter functionality
class TimeSeries():
    """Single one-dimensional time series with cached derived representations.

    On construction the samples are converted with ``to_numpy_array`` and
    several views are prepared: a min-max scaled copy (``scdata``) and either
    integer/boolean views (discrete series) or copula-normalized data
    (continuous series). Entropy values, the KD-tree and its neighbour query
    are computed lazily and cached on the instance.
    """

    @staticmethod
    def define_ts_type(ts):
        """Guess whether ``ts`` is discrete (True) or continuous (False).

        Two scores are computed: the fraction of unique values and the
        entropy of a ``len(ts)``-bin histogram normalized by the maximal
        entropy for that many bins. Both high -> continuous, both low ->
        discrete.

        Raises
        ------
        ValueError
            If the two scores do not agree on a type.
        """
        if len(ts) < 100:
            warnings.warn('Time series is too short for accurate type (discrete/continuous) determination')

        unique_vals = np.unique(ts)
        sc1 = len(unique_vals) / len(ts)
        hist = np.histogram(ts, bins=len(ts))[0]
        ent = entropy(hist)
        maxent = entropy(np.ones(len(ts)))
        sc2 = ent / maxent

        # TODO: refactor thresholds
        if sc1 > 0.70 and sc2 > 0.70:
            return False  # both scores are high - the variable is most probably continuous
        elif sc1 < 0.25 and sc2 < 0.25:
            return True  # both scores are low - the variable is most probably discrete
        else:
            raise ValueError(f'Unable to determine time series type automatically: score 1 = {sc1}, score 2 = {sc2}')

    # TODO: complete this function
    def _check_input(self):
        pass

    def __init__(self, data, discrete=None, shuffle_mask=None):
        """Build the time series and its derived views.

        Parameters
        ----------
        data : array-like
            Samples of the series (converted with ``to_numpy_array``).
        discrete : bool, optional
            Type of the series. If None it is inferred with
            ``define_ts_type``.
        shuffle_mask : array-like of bool, optional
            Positions that are valid for shuffling. Defaults to all True.
        """
        self.data = to_numpy_array(data)

        if discrete is None:
            self.discrete = TimeSeries.define_ts_type(self.data)
        else:
            self.discrete = discrete

        scaler = MinMaxScaler()
        self.scdata = scaler.fit_transform(self.data.reshape(-1, 1)).reshape(1, -1)[0]
        self.data_scale = scaler.scale_
        self.copula_normal_data = None
        # Defined for every instance so callers can test it without first
        # checking ``discrete``.
        self.is_binary = False

        if self.discrete:
            self.int_data = np.round(self.data).astype(int)
            # Decide binarity on the same rounded values that int_data and
            # bool_data are built from (truncation via astype(int) can
            # disagree with rounding, e.g. 0.9 -> 0 vs 1).
            if len(np.unique(self.int_data)) == 2:
                self.is_binary = True
                self.bool_data = self.int_data.astype(bool)
        else:
            self.copula_normal_data = copnorm(self.data).ravel()

        self.entropy = dict()  # supports various downsampling constants
        self.kdtree = None
        self.kdtree_query = None
        self._kdtree_query_k = None  # k for which kdtree_query was computed

        if shuffle_mask is None:
            # which shuffles are valid
            self.shuffle_mask = np.ones(len(self.data)).astype(bool)
        else:
            self.shuffle_mask = np.asarray(shuffle_mask).astype(bool)

    def get_kdtree(self):
        """Return the KD-tree of the data, building it on first use."""
        if self.kdtree is None:
            self.kdtree = self._compute_kdtree()

        return self.kdtree

    def _compute_kdtree(self):
        """Build a tree over the data reshaped to (n_samples, -1)."""
        d = self.data.reshape(self.data.shape[0], -1)
        return build_tree(d)

    def get_kdtree_query(self, k=DEFAULT_NN):
        """Return the cached nearest-neighbour query for ``k`` neighbours.

        The cached result is recomputed when a different ``k`` is requested,
        so a stale query for another neighbourhood size is never returned.
        """
        cached_k = getattr(self, '_kdtree_query_k', None)
        if self.kdtree_query is None or cached_k != k:
            self.kdtree_query = self._compute_kdtree_query(k=k)
            self._kdtree_query_k = k

        return self.kdtree_query

    def _compute_kdtree_query(self, k=DEFAULT_NN):
        """Query the tree with the data itself for ``k + 1`` neighbours."""
        tree = self.get_kdtree()
        return tree.query(self.data, k=k + 1)

    def get_entropy(self, ds=1):
        """Return the entropy for downsampling factor ``ds`` (cached per ds)."""
        if ds not in self.entropy:
            self._compute_entropy(ds=ds)
        return self.entropy[ds]

    def _compute_entropy(self, ds=1):
        """Compute and cache the entropy of every ``ds``-th sample.

        NOTE(review): the discrete branch uses ``base=np.e`` while the
        continuous branch divides by ``log(2)``; the units of the two
        branches look inconsistent - confirm before relying on them.
        """
        sampled = self.data[::ds]
        if self.discrete:
            # TODO: rewrite this using int_data and via ent_d from driada.information.entropy
            _, counts = np.unique(sampled, return_counts=True)
            self.entropy[ds] = entropy(counts, base=np.e)
        else:
            # Use the downsampled data so the value matches the cache key.
            self.entropy[ds] = nonparam_entropy_c(sampled) / np.log(2)

    def filter(self, method='gaussian', **kwargs):
        """
        Apply filtering to the time series and return a new filtered TimeSeries.

        Parameters
        ----------
        method : str
            Filtering method: 'gaussian', 'savgol', 'wavelet', or 'none'
        **kwargs : dict
            Method-specific parameters:
            - gaussian: sigma (default: 1.0)
            - savgol: window_length (default: 5), polyorder (default: 2)
            - wavelet: wavelet (default: 'db4'), level (default: None)

        Returns
        -------
        TimeSeries
            New TimeSeries object with filtered data
        """
        from ..utils.signals import filter_1d_timeseries

        if method == 'none':
            return TimeSeries(self.data.copy(), discrete=self.discrete,
                              shuffle_mask=self.shuffle_mask.copy())

        if self.discrete:
            warnings.warn("Filtering discrete time series may produce unexpected results")

        # Apply filtering to 1D time series
        filtered_data = filter_1d_timeseries(self.data, method=method, **kwargs)

        # Create new TimeSeries with filtered data
        return TimeSeries(filtered_data, discrete=self.discrete,
                          shuffle_mask=self.shuffle_mask.copy())

    def approximate_entropy(self, m: int = 2, r: Optional[float] = None) -> float:
        """
        Calculate approximate entropy (ApEn) of the time series.

        Approximate entropy is a regularity statistic that quantifies the
        unpredictability of fluctuations in a time series. A time series
        containing many repetitive patterns has a relatively small ApEn;
        a less predictable process has a higher ApEn.

        Parameters
        ----------
        m : int, optional
            Pattern length. Common values are 1 or 2. Default is 2.
        r : float, optional
            Tolerance threshold for pattern matching. If None, defaults to
            0.2 times the standard deviation of the data.

        Returns
        -------
        float
            The approximate entropy value. Higher values indicate more
            randomness/complexity.

        Raises
        ------
        ValueError
            If called on a discrete TimeSeries.

        Notes
        -----
        This method is only valid for continuous time series. For discrete
        time series, consider using other complexity measures.

        Examples
        --------
        >>> ts = TimeSeries(np.random.randn(1000), discrete=False)
        >>> apen = ts.approximate_entropy(m=2)
        >>> print(f"Approximate entropy: {apen:.3f}")
        """
        if self.discrete:
            raise ValueError("approximate_entropy is only valid for continuous time series")

        # Use lazy import to avoid circular imports
        from ..utils.signals import approximate_entropy

        # Default r to 0.2 * std if not provided
        if r is None:
            r = 0.2 * np.std(self.data)

        return approximate_entropy(self.data, m=m, r=r)
class MultiTimeSeries(MVData):
    """
    MultiTimeSeries represents multiple aligned time series.
    Now inherits from MVData to enable direct dimensionality reduction.
    Supports either all-continuous or all-discrete components (no mixing).
    """

    def __init__(self, data_or_tslist, labels=None, distmat=None, rescale_rows=False,
                 data_name=None, downsampling=None, discrete=None, shuffle_mask=None):
        """Build from a 2D array (rows are series) or an iterable of TimeSeries.

        Parameters
        ----------
        data_or_tslist : np.ndarray or iterable of TimeSeries
            Either a 2D array of shape (n_series, n_timepoints), in which
            case ``discrete`` is mandatory, or TimeSeries objects of equal
            length and equal type.
        labels, distmat, rescale_rows, data_name, downsampling
            Forwarded unchanged to ``MVData.__init__``.
        discrete : bool, optional
            Type of all components; required for array input.
        shuffle_mask : array-like of bool, optional
            Explicit mask of positions valid for shuffling. If omitted, the
            masks of the components are combined with a logical AND.

        Raises
        ------
        ValueError
            On malformed input or when the combined mask has no valid
            position.
        """
        # Handle both numpy array and list of TimeSeries inputs
        if isinstance(data_or_tslist, np.ndarray):
            # Direct numpy array input: each row is a time series
            if data_or_tslist.ndim != 2:
                raise ValueError("When providing numpy array, it must be 2D with shape (n_series, n_timepoints)")
            if discrete is None:
                raise ValueError("When providing numpy array, 'discrete' parameter must be specified")

            # Set discrete flag early for numpy array input
            self.discrete = discrete

            # Create TimeSeries objects from numpy array rows for processing
            tslist = [TimeSeries(row, discrete=discrete) for row in data_or_tslist]
            data = data_or_tslist
        else:
            # Iterable of TimeSeries objects (materialized so it can be
            # traversed several times); _check_input also sets self.discrete
            tslist = list(data_or_tslist)
            self._check_input(tslist)
            # Stack data from all TimeSeries
            data = np.vstack([ts.data for ts in tslist])

        # Kept as an attribute for backward compatibility
        self._provided_shuffle_mask = shuffle_mask

        # Initialize MVData parent class
        super().__init__(data, labels=labels, distmat=distmat,
                         rescale_rows=rescale_rows, data_name=data_name,
                         downsampling=downsampling)

        # Additional MultiTimeSeries specific attributes
        self.scdata = np.vstack([ts.scdata for ts in tslist])

        if not self.discrete:
            # Copula-normalized data exists only for continuous components
            self.copula_normal_data = np.vstack([ts.copula_normal_data for ts in tslist])
        else:
            # For discrete MultiTimeSeries, store integer data
            self.int_data = np.vstack([ts.int_data for ts in tslist])
            self.copula_normal_data = None

        if shuffle_mask is not None:
            # An explicit mask takes precedence over the component masks
            self.shuffle_mask = np.asarray(shuffle_mask).astype(bool)
            if not np.any(self.shuffle_mask):
                warnings.warn('Provided shuffle_mask has no valid positions for shuffling!')
        else:
            # Restrictive combination: ALL masks must allow shuffling at a position
            shuffle_masks = np.vstack([ts.shuffle_mask for ts in tslist])
            self.shuffle_mask = np.all(shuffle_masks, axis=0)

            # Check if the combined mask is problematic
            valid_positions = np.sum(self.shuffle_mask)
            total_positions = len(self.shuffle_mask)

            if valid_positions == 0:
                raise ValueError(f'Combined shuffle_mask has NO valid positions for shuffling! '
                                 f'This typically happens when combining many neurons with restrictive individual masks. '
                                 f'Consider providing an explicit shuffle_mask parameter to MultiTimeSeries.')
            elif valid_positions < 0.1 * total_positions:
                warnings.warn(f'Combined shuffle_mask is extremely restrictive: only {valid_positions}/{total_positions} '
                              f'({100*valid_positions/total_positions:.1f}%) positions are valid for shuffling. '
                              f'This may cause issues with shuffle-based significance testing.')

        self.entropy = dict()  # supports various downsampling constants

    @property
    def shape(self):
        """Return shape of the data for compatibility with numpy-like access."""
        return self.data.shape

    def _check_input(self, tslist):
        """Validate the components and set ``self.discrete`` from them.

        Raises
        ------
        ValueError
            If the input is empty, contains non-TimeSeries items, has
            components of different lengths, or mixes discrete and
            continuous components.
        """
        if len(tslist) == 0:
            raise ValueError('Input to MultiTimeSeries must be iterable of TimeSeries')

        if not all(isinstance(ts, TimeSeries) for ts in tslist):
            raise ValueError('Input to MultiTimeSeries must be iterable of TimeSeries')

        # Check all TimeSeries have same length
        lengths = np.array([len(ts.data) for ts in tslist])
        if not np.all(lengths == lengths[0]):
            raise ValueError('All TimeSeries must have the same length')

        # Check all TimeSeries have same discrete/continuous type. Cast to
        # bool first: bitwise inversion of integer flags (0/1) would not be
        # a logical negation.
        is_discrete = np.array([bool(ts.discrete) for ts in tslist])
        if not (np.all(is_discrete) or np.all(~is_discrete)):
            raise ValueError('All components of MultiTimeSeries must be either continuous or discrete (no mixing)')

        # Set discrete flag based on components
        self.discrete = bool(is_discrete[0])

    def get_entropy(self, ds=1):
        """Return the joint entropy for downsampling factor ``ds`` (cached)."""
        if ds not in self.entropy:
            self._compute_entropy(ds=ds)
        return self.entropy[ds]

    def _compute_entropy(self, ds=1):
        """Compute and cache the joint entropy of every ``ds``-th sample."""
        if self.discrete:
            # All components are discrete - use joint discrete entropy
            self.entropy[ds] = entropy_d(self.int_data[:, ::ds])
        else:
            # All continuous - use existing continuous entropy
            self.entropy[ds] = ent_g(self.data[:, ::ds])

    def filter(self, method='gaussian', **kwargs):
        """
        Apply filtering to all time series components and return a new filtered MultiTimeSeries.

        Parameters
        ----------
        method : str
            Filtering method: 'gaussian', 'savgol', 'wavelet', or 'none'
        **kwargs : dict
            Method-specific parameters (see TimeSeries.filter for details)

        Returns
        -------
        MultiTimeSeries
            New MultiTimeSeries object with all components filtered. The
            shuffle mask of this object is carried over, mirroring
            TimeSeries.filter.
        """
        from ..signals.neural_filtering import filter_neural_signals

        # Apply filtering to all time series at once
        filtered_data = filter_neural_signals(self.data, method=method, **kwargs)

        # Create new MultiTimeSeries from filtered data, preserving the mask
        return MultiTimeSeries(filtered_data, labels=self.labels,
                               rescale_rows=self.rescale_rows,
                               data_name=self.data_name, discrete=self.discrete,
                               shuffle_mask=self.shuffle_mask.copy())
def get_stats_function(sname):
    """Look up a function of ``scipy.stats`` by its name.

    Parameters
    ----------
    sname : str
        Attribute name inside ``scipy.stats`` (e.g. ``'pearsonr'``).

    Returns
    -------
    The attribute of ``scipy.stats`` with that name.

    Raises
    ------
    ValueError
        If ``scipy.stats`` has no attribute called ``sname``.
    """
    not_found = object()  # sentinel: any real attribute differs from it
    stats_attr = getattr(scipy.stats, sname, not_found)
    if stats_attr is not_found:
        raise ValueError(f"Metric '{sname}' not found in scipy.stats")
    return stats_attr
def calc_signal_ratio(binary_ts, continuous_ts):
    """Ratio of the mean of ``continuous_ts`` where ``binary_ts`` is 1 vs 0.

    Parameters
    ----------
    binary_ts : array-like
        Indicator values; entries equal to 1 form the "on" group and entries
        equal to 0 form the "off" group.
    continuous_ts : array-like
        Signal values aligned with ``binary_ts``.

    Returns
    -------
    float
        ``mean(on) / mean(off)``. ``nan`` if either group is empty or if
        both means are zero; ``inf`` if only the "off" mean is zero.
    """
    binary_ts = np.asarray(binary_ts)
    continuous_ts = np.asarray(continuous_ts)

    on_values = continuous_ts[binary_ts == 1]
    off_values = continuous_ts[binary_ts == 0]

    # The ratio is undefined without samples in both states. Checking up
    # front avoids the mean-of-empty-slice warning and the spurious ``inf``
    # that a NaN "on" mean would yield below (NaN != 0 is True).
    if on_values.size == 0 or off_values.size == 0:
        return np.nan

    avg_on = np.mean(on_values)
    avg_off = np.mean(off_values)

    # Calculate ratio (handle division by zero)
    if avg_off == 0:
        return np.inf if avg_on != 0 else np.nan

    return avg_on / avg_off
def get_sim(x, y, metric, shift=0, ds=1, k=5, estimator='gcmi', check_for_coincidence=False):
    """Computes similarity between two (possibly multidimensional) variables efficiently

    Parameters
    ----------
    x: TimeSeries/MultiTimeSeries instance or numpy array

    y: TimeSeries/MultiTimeSeries instance or numpy array

    metric: similarity metric between time series
        'mi' (mutual information, delegated to ``get_mi``), 'fast_pearsonr',
        'av' (signal ratio, binary vs continuous only) or the name of a
        function in ``scipy.stats`` whose first returned element is the
        similarity value.

    shift: int
        y will be roll-moved by the number 'shift' after downsampling by 'ds' factor

    ds: int
        downsampling constant (take every 'ds'-th point)

    k: int
        forwarded to ``get_mi`` (used for metric='mi' only)

    estimator: str
        forwarded to ``get_mi`` (used for metric='mi' only)

    check_for_coincidence: bool
        forwarded to ``get_mi`` (used for metric='mi' only)

    Returns
    -------
    me: similarity metric between x and (possibly) shifted y

    Raises
    ------
    ValueError
        If the metric is not applicable to the given combination of types.
    Exception
        If a raw multidimensional array is given or a non-'mi' metric is
        requested for MultiTimeSeries.
    """
    def _check_input(ts):
        if not isinstance(ts, TimeSeries) and not isinstance(ts, MultiTimeSeries):
            if np.ndim(ts) == 1:
                ts = TimeSeries(ts)
            else:
                raise Exception('Multidimensional inputs must be provided as MultiTimeSeries')
        return ts

    ts1 = _check_input(x)
    ts2 = _check_input(y)

    if metric == 'mi':
        return get_mi(ts1, ts2, shift=shift, ds=ds, k=k, estimator=estimator,
                      check_for_coincidence=check_for_coincidence)

    if not (isinstance(ts1, TimeSeries) and isinstance(ts2, TimeSeries)):
        raise Exception("Metrics except 'mi' are not supported for multi-dimensional data")

    if ts1.discrete and ts2.discrete:
        raise ValueError(f'Metric={metric} is not supported for two discrete ts')

    # In every branch below x stays in place and y is the shifted series.
    x_data = ts1.data[::ds]
    y_data = np.roll(ts2.data[::ds], shift)

    if not ts1.discrete and not ts2.discrete:
        if metric == 'fast_pearsonr':
            return correlation_matrix(np.vstack([x_data, y_data]))[0, 1]
        metric_func = get_stats_function(metric)
        return metric_func(x_data, y_data)[0]

    # Exactly one of the two series is discrete from here on.
    if metric != 'av':
        raise ValueError("Only 'av' and 'mi' metrics are supported for binary-continuous similarity")

    if ts1.discrete:
        binary_ts, binary_data, continuous_data = ts1, x_data, y_data
    else:
        # Previously x was rolled here instead of y, which reversed the
        # direction of the lag relative to all other branches.
        binary_ts, binary_data, continuous_data = ts2, y_data, x_data

    if not binary_ts.is_binary:
        raise ValueError(f'Discrete ts must be binary for metric={metric}')

    return calc_signal_ratio(binary_data, continuous_data)
def get_mi(x, y, shift=0, ds=1, k=5, estimator='gcmi', check_for_coincidence=False):
    """Computes mutual information between two (possibly multidimensional) variables efficiently

    Parameters
    ----------
    x: TimeSeries/MultiTimeSeries instance or numpy array
    y: TimeSeries/MultiTimeSeries instance or numpy array
    shift: int
        y will be roll-moved by the number 'shift' after downsampling by 'ds' factor
    ds: int
        downsampling constant (take every 'ds'-th point)
    k: int
        number of neighbors for ksg estimator
    estimator: str
        Estimation method. Should be 'ksg' (accurate but slow) and 'gcmi' (fast, but estimates the lower bound on MI).
        In most cases 'gcmi' should be preferred.
    check_for_coincidence: bool
        Forwarded to ``get_1d_mi`` for two TimeSeries; for two
        MultiTimeSeries identical data at zero shift yields a warning and 0.

    Returns
    -------
    mi: mutual information (or its lower bound in case of 'gcmi' estimator) between x and (possibly) shifted y

    Raises
    ------
    NotImplementedError
        For the 'ksg' estimator with a MultiTimeSeries operand, or for two
        MultiTimeSeries when either one is discrete.
    Exception
        If a raw multidimensional array is given.
    """

    def _check_input(ts):
        if not isinstance(ts, TimeSeries) and not isinstance(ts, MultiTimeSeries):
            if np.ndim(ts) == 1:
                ts = TimeSeries(ts)
            else:
                raise Exception('Multidimensional inputs must be provided as MultiTimeSeries')
        return ts

    def multi_single_mi(mts, ts, ds=1, k=5, estimator='gcmi'):
        """MI between a MultiTimeSeries and a single TimeSeries."""
        if estimator == 'ksg':
            raise NotImplementedError('KSG estimator is not supported for dim>1 yet')

        # Safety check: if single TimeSeries data is contained in MultiTimeSeries
        # This should not happen due to aggregate_multiple_ts adding noise, but add as safety net
        if not ts.discrete and shift == 0:
            single = ts.data[::ds]
            for row in mts.data:
                if np.allclose(row[::ds], single, rtol=1e-10, atol=1e-10):
                    warnings.warn('MI computation between MultiTimeSeries containing identical data detected, returning 0')
                    return 0.0

        if ts.discrete:
            # Roll along the time axis only: np.roll without ``axis`` works
            # on the flattened array and would leak samples between rows.
            # NOTE(review): here the multivariate operand is the one being
            # shifted, whichever of x/y it is - confirm intended direction.
            ny1 = np.roll(mts.copula_normal_data[:, ::ds], shift, axis=1)
            # Ensure ny1 is contiguous for better performance with Numba
            ny1 = np.ascontiguousarray(ny1)
            ny2 = ts.int_data[::ds]
            return mi_model_gd(ny1, ny2, np.max(ny2), biascorrect=True, demeaned=True)

        ny1 = mts.copula_normal_data[:, ::ds]
        ny2 = np.roll(ts.copula_normal_data[::ds], shift)
        return mi_gg(ny1, ny2, True, True)

    def multi_multi_mi(mts1, mts2, ds=1, k=5, estimator='gcmi', check_for_coincidence=False):
        """MI between two continuous MultiTimeSeries."""
        if estimator == 'ksg':
            raise NotImplementedError('KSG estimator is not supported for dim>1 yet')

        # Compare the arguments themselves (not closure variables) and only
        # when the shapes match, so np.allclose cannot fail or broadcast.
        if (check_for_coincidence and shift == 0
                and mts1.data.shape == mts2.data.shape
                and np.allclose(mts1.data, mts2.data)):
            warnings.warn('MI computation of a MultiTimeSeries with itself is meaningless, 0 will be returned forcefully')
            return 0.0

        if mts1.discrete or mts2.discrete:
            raise NotImplementedError('MI computation between MultiTimeSeries '
                                      'is currently supported for continuous data only')

        ny1 = mts1.copula_normal_data[:, ::ds]
        ny2 = np.roll(mts2.copula_normal_data[:, ::ds], shift, axis=1)
        return mi_gg(ny1, ny2, True, True)

    ts1 = _check_input(x)
    ts2 = _check_input(y)

    # After _check_input each operand is a TimeSeries or a MultiTimeSeries.
    ts1_is_multi = isinstance(ts1, MultiTimeSeries)
    ts2_is_multi = isinstance(ts2, MultiTimeSeries)

    if ts1_is_multi and ts2_is_multi:
        mi = multi_multi_mi(ts1, ts2, ds=ds, k=k, estimator=estimator,
                            check_for_coincidence=check_for_coincidence)
    elif ts1_is_multi:
        mi = multi_single_mi(ts1, ts2, ds=ds, k=k, estimator=estimator)
    elif ts2_is_multi:
        mi = multi_single_mi(ts2, ts1, ds=ds, k=k, estimator=estimator)
    else:
        # Pass the already wrapped objects so they are not rebuilt
        mi = get_1d_mi(ts1, ts2, shift=shift, ds=ds, k=k, estimator=estimator,
                       check_for_coincidence=check_for_coincidence)

    # Estimators can return small negative values; MI is non-negative
    if mi < 0:
        mi = 0

    return mi
def get_1d_mi(ts1, ts2, shift=0, ds=1, k=5, estimator='gcmi', check_for_coincidence=True):
    """Computes mutual information between two 1d variables efficiently

    Parameters
    ----------
    ts1: TimeSeries/MultiTimeSeries instance or numpy array
    ts2: TimeSeries/MultiTimeSeries instance or numpy array
    shift: int
        ts2 will be roll-moved by the number 'shift' after downsampling by 'ds' factor
    ds: int
        downsampling constant (take every 'ds'-th point)
    k: int
        number of neighbors for ksg estimator
    estimator: str
        Estimation method. Should be 'ksg' (accurate but slow) and 'gcmi' (fast, but estimates the lower bound on MI).
        In most cases 'gcmi' should be preferred.
    check_for_coincidence : bool, optional
        If True, raises error when computing MI of a signal with itself at zero shift. Default: True.

    Returns
    -------
    mi: mutual information (or its lower bound in case of 'gcmi' estimator) between ts1 and (possibly) shifted ts2

    Raises
    ------
    ValueError
        If the two inputs coincide at zero shift (with
        ``check_for_coincidence``) or the estimator name is unknown.
    Exception
        If a raw multidimensional array is given.
    """

    def _check_input(ts):
        if not isinstance(ts, TimeSeries) and not isinstance(ts, MultiTimeSeries):
            if np.ndim(ts) == 1:
                ts = TimeSeries(ts)
            else:
                raise Exception('Multidimensional inputs must be provided as MultiTimeSeries')
        return ts

    ts1 = _check_input(ts1)
    ts2 = _check_input(ts2)

    if check_for_coincidence and ts1.data.shape == ts2.data.shape:
        if np.allclose(ts1.data, ts2.data) and shift == 0:
            raise ValueError('MI computation of a TimeSeries or MultiTimeSeries with itself is not allowed')

    if estimator == 'ksg':
        # Every operand is downsampled exactly once; ts2 is rolled by
        # ``shift`` in every sub-branch.
        if not ts1.discrete and not ts2.discrete:
            x = ts1.data[::ds]
            y = ts2.data[::ds]
            if shift != 0:
                y = np.roll(y, shift)

            if ds == 1:
                # Cached trees are built over the full-resolution data
                tree_x = ts1.get_kdtree()
                tree_y = ts2.get_kdtree()
            else:
                # Cached trees do not match downsampled data; build fresh
                # ones over exactly the samples passed to the estimator.
                tree_x = build_tree(x.reshape(-1, 1))
                tree_y = build_tree(y.reshape(-1, 1))

            mi = nonparam_mi_cc_mod(x, y, k=k,
                                    precomputed_tree_x=tree_x,
                                    precomputed_tree_y=tree_y)

        elif ts1.discrete and ts2.discrete:
            mi = mutual_info_classif(ts1.int_data[::ds].reshape(-1, 1),
                                     np.roll(ts2.int_data[::ds], shift),
                                     discrete_features=True,
                                     n_neighbors=k)[0]

        # TODO: refactor using ksg functions
        elif ts1.discrete and not ts2.discrete:
            # sklearn expects a 2D feature matrix; the feature (ts1) is the
            # discrete variable and the target (ts2) is continuous.
            mi = mutual_info_regression(ts1.int_data[::ds].reshape(-1, 1),
                                        np.roll(ts2.data[::ds], shift),
                                        discrete_features=True,
                                        n_neighbors=k)[0]

        else:
            # Continuous feature (ts1) against a discrete target (ts2)
            mi = mutual_info_classif(ts1.data[::ds].reshape(-1, 1),
                                     np.roll(ts2.int_data[::ds], shift),
                                     discrete_features=False,
                                     n_neighbors=k)[0]

        return mi

    elif estimator == 'gcmi':
        if not ts1.discrete and not ts2.discrete:
            ny1 = ts1.copula_normal_data[::ds]
            ny2 = np.roll(ts2.copula_normal_data[::ds], shift)
            mi = mi_gg(ny1, ny2, True, True)

        elif ts1.discrete and ts2.discrete:
            if ts1.is_binary and ts2.is_binary:
                # Binary features: MI from the 2x2 contingency table
                ny1 = ts1.bool_data[::ds]
                ny2 = np.roll(ts2.bool_data[::ds], shift)

                contingency = np.zeros((2, 2))
                contingency[0, 0] = (ny1 & ny2).sum()
                contingency[0, 1] = (~ny1 & ny2).sum()
                contingency[1, 0] = (ny1 & ~ny2).sum()
                contingency[1, 1] = (~ny1 & ~ny2).sum()

                mi = binary_mi_score(contingency)

            else:
                ny1 = ts1.int_data[::ds]
                ny2 = np.roll(ts2.int_data[::ds], shift)
                mi = mutual_info_score(ny1, ny2)
                # Ensure float type for consistency
                mi = float(mi)

        elif ts1.discrete and not ts2.discrete:
            ny1 = ts1.int_data[::ds]
            ny2 = np.roll(ts2.copula_normal_data[::ds], shift)
            # Ensure ny2 is contiguous for better performance with Numba
            ny2 = np.ascontiguousarray(ny2)
            mi = mi_model_gd(ny2, ny1, np.max(ny1), biascorrect=True, demeaned=True)

        else:
            ny1 = ts1.copula_normal_data[::ds]
            # TODO: fix zd error
            ny2 = np.roll(ts2.int_data[::ds], shift)
            # Ensure ny1 is contiguous for better performance with Numba
            ny1 = np.ascontiguousarray(ny1)
            mi = mi_model_gd(ny1, ny2, np.max(ny2), biascorrect=True, demeaned=True)

        # Estimators can return small negative values; MI is non-negative
        if mi < 0:
            mi = 0.0

        return mi

    else:
        # Previously an unknown name fell through to an unbound local
        raise ValueError(f"Unknown estimator '{estimator}': expected 'ksg' or 'gcmi'")
def get_tdmi(data, min_shift=1, max_shift=100, nn=DEFAULT_NN):
    """Time-delayed mutual information of a continuous signal with itself.

    Parameters
    ----------
    data : array-like
        Samples of the signal; wrapped into a continuous TimeSeries.
    min_shift : int
        First shift to evaluate (inclusive).
    max_shift : int
        Last shift bound (exclusive).
    nn : int
        Passed to ``get_1d_mi`` as ``k``.

    Returns
    -------
    list
        One ``get_1d_mi`` value per shift in ``range(min_shift, max_shift)``.
    """
    signal = TimeSeries(data, discrete=False)

    delayed_mi = []
    for lag in range(min_shift, max_shift):
        delayed_mi.append(get_1d_mi(signal, signal, shift=lag, k=nn))

    return delayed_mi
def get_multi_mi(tslist, ts2, shift=0, ds=1, k=DEFAULT_NN, estimator='gcmi'):
    """Mutual information between several continuous series and one series.

    Parameters
    ----------
    tslist : iterable of TimeSeries
        Continuous series that are stacked into one multivariate variable.
    ts2 : TimeSeries
        Continuous series compared against the stacked variable.
    shift : int
        ts2 is roll-moved by ``shift`` BEFORE downsampling.
    ds : int
        Downsampling constant (take every 'ds'-th point).
    k, estimator
        Accepted for interface compatibility; not used by this function.

    Returns
    -------
    Non-negative value returned by ``mi_gg`` (negative estimates are
    clipped to 0).

    Raises
    ------
    ValueError
        If any of the series is discrete.
    """
    # TODO: make shift the same as in get_1d_mi

    # Every component must be continuous. The previous ``~np.all(...)``
    # test only rejected lists in which ALL components were discrete, so a
    # mixed list crashed later on missing copula data.
    if any(ts.discrete for ts in tslist) or ts2.discrete:
        raise ValueError('Multidimensional MI only implemented for continuous data!')

    ny1 = np.vstack([ts.copula_normal_data[::ds] for ts in tslist])
    ny2 = np.roll(ts2.copula_normal_data, shift)[::ds]
    mi = mi_gg(ny1, ny2, True, True)

    if mi < 0:
        mi = 0

    return mi
def aggregate_multiple_ts(*ts_args, noise=1e-5):
    """Combine continuous TimeSeries objects into one MultiTimeSeries.

    Each input is copied with a small uniform random perturbation
    (``np.random.random`` scaled by ``noise``) added to its data, which
    breaks degeneracy between components, and the perturbed copies are
    wrapped in a MultiTimeSeries.

    Parameters
    ----------
    *ts_args : TimeSeries
        Variable number of TimeSeries objects to aggregate.
    noise : float, optional
        Scale of the added perturbation. Default: 1e-5.

    Returns
    -------
    MultiTimeSeries
        Aggregated multi-dimensional time series.

    Raises
    ------
    ValueError
        If any input TimeSeries is discrete.

    Examples
    --------
    >>> ts1 = TimeSeries(np.random.randn(100), discrete=False)
    >>> ts2 = TimeSeries(np.random.randn(100), discrete=False)
    >>> mts = aggregate_multiple_ts(ts1, ts2)
    """
    def _perturbed_copy(component):
        # Reject discrete input before drawing any random numbers for it
        if component.discrete:
            raise ValueError('this is not applicable to discrete TimeSeries')
        jitter = np.random.random(size=len(component.data)) * noise
        return TimeSeries(component.data + jitter, discrete=False)

    perturbed = [_perturbed_copy(component) for component in ts_args]
    return MultiTimeSeries(perturbed)
def conditional_mi(ts1, ts2, ts3, ds=1, k=5):
    """Calculate conditional mutual information I(X;Y|Z).

    X is ``ts1``, Y is ``ts2`` and the conditioning variable Z is ``ts3``.
    The estimation route depends on which of Y and Z are discrete.

    Parameters
    ----------
    ts1 : TimeSeries
        First variable (X). Must be continuous.
    ts2 : TimeSeries
        Second variable (Y). Can be continuous or discrete.
    ts3 : TimeSeries
        Conditioning variable (Z). Can be continuous or discrete.
    ds : int, optional
        Downsampling factor. Default: 1.
    k : int, optional
        Number of neighbors for entropy estimation (used in the CDD case).
        Default: 5.

    Returns
    -------
    float
        Conditional mutual information I(X;Y|Z) in bits.

    Raises
    ------
    ValueError
        If ts1 is discrete (only continuous X is currently supported).

    Notes
    -----
    Supports four cases:
    - CCC: All continuous - ``cmi_ggg`` on copula-normalized data
    - CCD: X,Y continuous, Z discrete - ``gccmi_ccd``
    - CDC: X,Z continuous, Y discrete - H(X|Z) - H(X|Y,Z) via ``ent_g``
    - CDD: X continuous, Y,Z discrete - H(X,Z) + H(Y,Z) - H(X,Y,Z) - H(Z)

    For the CDD case, GCMI estimator has limitations due to uncontrollable
    biases (copula transform does not conserve entropy). See
    https://doi.org/10.1002/hbm.23471 for details.
    """
    if ts1.discrete:
        raise ValueError('conditional MI(X,Y|Z) is currently implemented for continuous X only')

    if not ts2.discrete and not ts3.discrete:
        # CCC: all three variables are continuous
        return cmi_ggg(ts1.copula_normal_data[::ds],
                       ts2.copula_normal_data[::ds],
                       ts3.copula_normal_data[::ds],
                       biascorrect=True, demeaned=True)

    if not ts2.discrete:
        # CCD: X,Y continuous, Z discrete
        z_label_count = len(np.unique(ts3.int_data[::ds]))
        return gccmi_ccd(ts1.data[::ds],
                         ts2.data[::ds],
                         ts3.int_data[::ds],
                         z_label_count)

    if not ts3.discrete:
        # CDC: X,Z continuous, Y discrete.
        # Entropy identity I(X;Y|Z) = H(X|Z) - H(X|Y,Z); a single estimator
        # is used for both terms to avoid inconsistent biases.
        x_row = ts1.data[::ds].reshape(1, -1)
        z_row = ts3.data[::ds].reshape(1, -1)

        # H(X|Z) = H(X,Z) - H(Z)
        joint_xz_entropy = ent_g(np.vstack([x_row, z_row]), biascorrect=True)
        z_entropy = ent_g(z_row, biascorrect=True)
        h_x_given_z = joint_xz_entropy - z_entropy

        # H(X|Y,Z): average of H(X|Z, Y=y) over the values of Y
        y_labels = ts2.int_data[::ds]
        h_x_given_yz = 0.0

        for label in np.unique(y_labels):
            members = (y_labels == label)
            group_size = np.sum(members)

            # Groups with too few samples are left out of the sum
            if not group_size > 2:
                continue

            x_group = x_row[:, members]
            z_group = z_row[:, members]

            # H(X|Z, Y=label) = H(X,Z|Y=label) - H(Z|Y=label)
            group_joint_entropy = ent_g(np.vstack([x_group, z_group]), biascorrect=True)
            group_z_entropy = ent_g(z_group, biascorrect=True)
            group_conditional = group_joint_entropy - group_z_entropy

            # Weight by the empirical probability P(Y=label)
            weight = group_size / len(ts2.int_data[::ds])
            h_x_given_yz += weight * group_conditional

        cmi = h_x_given_z - h_x_given_yz

        # CMI is non-negative in theory; tiny negative values come from
        # estimation noise and are snapped to zero.
        if -0.01 < cmi < 0:
            cmi = 0.0

        return cmi

    # CDD: X continuous, Y,Z discrete
    # I(X;Y|Z) = H(X,Z) + H(Y,Z) - H(X,Y,Z) - H(Z), computed on raw data
    # (not copula-normalized) because the entropy helpers expect raw values.
    joint_xz = joint_entropy_cd(ts3.int_data[::ds], ts1.data[::ds], k=k)
    joint_yz = joint_entropy_dd(ts2.int_data[::ds], ts3.int_data[::ds])
    joint_xyz = joint_entropy_cdd(ts2.int_data[::ds], ts3.int_data[::ds], ts1.data[::ds], k=k)
    z_only = entropy_d(ts3.int_data[::ds])

    return joint_xz + joint_yz - joint_xyz - z_only
def interaction_information(ts1, ts2, ts3, ds=1, k=5):
    """Calculate three-way interaction information II(X;Y;Z).

    The value is obtained from two expressions,
    I(X;Y|Z) - I(X;Y) and I(X;Z|Y) - I(X;Z), which are averaged.

    Parameters
    ----------
    ts1 : TimeSeries
        First variable (X). Must be continuous.
    ts2 : TimeSeries
        Second variable (Y). Can be continuous or discrete.
    ts3 : TimeSeries
        Third variable (Z). Can be continuous or discrete.
    ds : int, optional
        Downsampling factor. Default: 1.
    k : int, optional
        Number of neighbors, forwarded to ``conditional_mi``. Default: 5.

    Returns
    -------
    float
        Interaction information II(X;Y;Z) in bits.
        - II < 0: Redundancy (Y and Z provide overlapping information about X)
        - II > 0: Synergy (Y and Z together provide more information than separately)

    Notes
    -----
    The interaction information is computed using Williams & Beer convention:
    II(X;Y;Z) = I(X;Y|Z) - I(X;Y) = I(X;Z|Y) - I(X;Z)

    This implementation assumes X is the target variable (e.g., neural activity)
    and Y, Z are predictor variables (e.g., behavioral features).
    """
    # Pairwise terms first, then the conditional ones
    pairwise_xy = get_mi(ts1, ts2, ds=ds)
    pairwise_xz = get_mi(ts1, ts3, ds=ds)
    conditional_xy = conditional_mi(ts1, ts2, ts3, ds=ds, k=k)
    conditional_xz = conditional_mi(ts1, ts3, ts2, ds=ds, k=k)

    # Both differences estimate the same quantity; negative means
    # redundancy, positive means synergy.
    estimate_via_y = conditional_xy - pairwise_xy
    estimate_via_z = conditional_xz - pairwise_xz

    # Average the two estimates for numerical stability
    return (estimate_via_y + estimate_via_z) / 2.0