Coverage for src/driada/intense/distribution_investigation.py: 0.00%
297 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-25 15:40 +0300
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-25 15:40 +0300
1"""
Investigation of MI distribution types for INTENSE statistical testing.

This module investigates why metric_distr_type='norm' works better than 'gamma'
for MI distributions in INTENSE, despite gamma being theoretically more appropriate.

Key questions addressed:
1. What are the statistical properties of actual MI shuffle distributions?
2. How do different distributions (norm, gamma, lognorm) fit the data?
3. Why does the normal distribution give better detection performance?
4. What are the root causes of this empirical observation?
12"""
14import numpy as np
15import matplotlib.pyplot as plt
16from scipy import stats
17from scipy.stats import normaltest, shapiro, anderson, kstest, gamma, norm, lognorm
18from typing import Dict, List, Tuple, Optional, Union
19import warnings
20from dataclasses import dataclass
21from pathlib import Path
23from .stats import get_mi_distr_pvalue, get_distribution_function
24from .pipelines import compute_cell_feat_significance
25from ..experiment import generate_mixed_population_exp, generate_circular_manifold_exp
@dataclass
class DistributionFitResult:
    """Goodness-of-fit record for one candidate distribution on one sample.

    Attributes
    ----------
    distribution : str
        Name of the fitted distribution (e.g. ``'norm'``, ``'gamma'``).
    parameters : Tuple
        Parameter tuple describing the fitted distribution.
    aic : float
        Akaike information criterion of the fit.
    bic : float
        Bayesian information criterion of the fit.
    ks_statistic : float
        Kolmogorov-Smirnov test statistic of the sample against the fit.
    ks_pvalue : float
        Kolmogorov-Smirnov p-value of the sample against the fit.
    log_likelihood : float
        Log-likelihood of the sample under the fitted distribution.
    fitted_distribution : object
        The fitted distribution object itself, kept for later evaluation.
    """

    distribution: str
    parameters: Tuple
    aic: float
    bic: float
    ks_statistic: float
    ks_pvalue: float
    log_likelihood: float
    fitted_distribution: object
@dataclass
class ShuffleDistributionData:
    """One neuron-feature pair: its shuffle (null) MI sample plus metadata.

    Attributes
    ----------
    shuffle_values : np.ndarray
        MI values obtained from the shuffles.
    true_mi : float
        MI value of the unshuffled pair.
    neuron_id : Union[str, int]
        Identifier of the neuron.
    feature_id : str
        Identifier of the feature.
    is_significant : bool
        Whether the pair was labelled significant.
    p_value_norm : float
        P-value of ``true_mi`` under a normal model of the shuffles.
    p_value_gamma : float
        P-value of ``true_mi`` under a gamma model of the shuffles.
    statistical_properties : Dict
        Summary statistics of ``shuffle_values``.
    """

    shuffle_values: np.ndarray
    true_mi: float
    neuron_id: Union[str, int]
    feature_id: str
    is_significant: bool
    p_value_norm: float
    p_value_gamma: float
    statistical_properties: Dict
class MIDistributionInvestigator:
    """
    Investigates MI distribution fitting for INTENSE statistical testing.

    The investigator collects shuffle (null) distributions of mutual
    information from synthetic experiments, summarises their statistical
    properties, fits several candidate distributions to them and compares
    the significance calls obtained under each distributional assumption.
    The goal is to analyse why a normal fit works better than a gamma fit
    for MI shuffle distributions.

    Attributes
    ----------
    random_state : int
        Seed used for data generation and shuffling.
    distributions : dict
        Mapping from distribution name to the scipy.stats distribution
        object used for fitting.
    shuffle_data : List[ShuffleDistributionData]
        Shuffle distributions collected by :meth:`generate_test_data`.
    fit_results : Dict[str, Dict[str, DistributionFitResult]]
        Fit results stored by :meth:`analyze_all_distributions`.
    """

    # p-value threshold used for every significance call in this class.
    _ALPHA = 0.05
    # Distributions compared in the report, the plots and the detection table.
    _COMPARED_DISTRIBUTIONS = ('norm', 'gamma', 'lognorm')
    # Distributions fitted with the location pinned at zero (positive support).
    _ZERO_LOC_DISTRIBUTIONS = ('gamma', 'lognorm')

    def __init__(self, random_state: int = 42):
        """
        Initialize the MI distribution investigator.

        Parameters
        ----------
        random_state : int, optional
            Random seed for reproducibility. Default: 42.

        Notes
        -----
        This seeds NumPy's *global* random number generator as a side effect.
        """
        self.random_state = random_state
        np.random.seed(random_state)

        # Distributions to test
        self.distributions = {
            'norm': norm,
            'gamma': gamma,
            'lognorm': lognorm,
            'expon': stats.expon,
            'weibull_min': stats.weibull_min,
            'beta': stats.beta,
        }

        # Results storage
        self.shuffle_data: List[ShuffleDistributionData] = []
        self.fit_results: Dict[str, Dict[str, DistributionFitResult]] = {}

    def generate_test_data(self,
                           n_scenarios: int = 5,
                           n_shuffles: int = 1000) -> List[ShuffleDistributionData]:
        """
        Generate test data with known MI distributions.

        Two synthetic experiments are generated (a circular manifold and a
        mixed population) and shuffle distributions are extracted from each.
        The result is also stored in ``self.shuffle_data``.

        Parameters
        ----------
        n_scenarios : int, optional
            Currently unused: exactly two scenarios are always generated.
            Kept for backward compatibility. Default: 5.
        n_shuffles : int, optional
            Number of shuffles per scenario. Default: 1000.

        Returns
        -------
        shuffle_data : List[ShuffleDistributionData]
            List of shuffle distribution data for analysis.
        """
        print("Generating test data for MI distribution investigation...")

        shuffle_data = []

        # Scenario 1: Circular manifold (head direction cells)
        print(" - Scenario 1: Circular manifold")
        exp_circular = generate_circular_manifold_exp(
            n_neurons=20,
            duration=300,
            fps=20,
            seed=self.random_state,
        )
        shuffle_data.extend(
            self._extract_shuffle_distributions(
                exp_circular,
                scenario_name="circular",
                n_shuffles=n_shuffles,
            )
        )

        # Scenario 2: Mixed population with spatial and feature components
        print(" - Scenario 2: Mixed population")
        exp_mixed = generate_mixed_population_exp(
            n_neurons=50,
            manifold_type='2d_spatial',
            manifold_fraction=0.6,
            n_discrete_features=1,
            n_continuous_features=2,
            duration=300,
            fps=20,
            seed=self.random_state + 1,  # distinct seed per scenario
        )
        shuffle_data.extend(
            self._extract_shuffle_distributions(
                exp_mixed,
                scenario_name="mixed",
                n_shuffles=n_shuffles,
            )
        )

        self.shuffle_data = shuffle_data
        print(f"Generated {len(shuffle_data)} shuffle distributions for analysis")

        return shuffle_data

    def _extract_shuffle_distributions(self,
                                       exp,
                                       scenario_name: str,
                                       n_shuffles: int) -> List[ShuffleDistributionData]:
        """
        Extract shuffle distributions from INTENSE analysis.

        At most the first 10 cells and the first 3 dynamic features of the
        experiment are analysed. A pair that fails for any reason is reported
        with a printed warning and skipped.

        Parameters
        ----------
        exp : Experiment
            DRIADA experiment object.
        scenario_name : str
            Name of the scenario for identification.
        n_shuffles : int
            Number of shuffles to use.

        Returns
        -------
        shuffle_data : List[ShuffleDistributionData]
            Extracted shuffle distribution data.
        """
        # Imported here (once per call) rather than at module level or per pair.
        from .intense_base import scan_pairs, calculate_optimal_delays
        from ..information.info_base import TimeSeries

        # Limit the workload for efficiency.
        available_features = list(exp.dynamic_features.keys())[:3]
        cell_ids = list(range(min(10, exp.n_cells)))

        shuffle_data = []

        for i, cell_id in enumerate(cell_ids):
            for j, feature_id in enumerate(available_features):
                try:
                    # Get neural and feature time series
                    neural_ts = exp.calcium[cell_id]
                    feature_ts = exp.dynamic_features[feature_id]

                    # Wrap raw data; objects exposing `.data` are used as-is.
                    ts_neural = TimeSeries(neural_ts)
                    ts_feature = feature_ts if hasattr(feature_ts, 'data') else TimeSeries(feature_ts)

                    # Calculate optimal delays
                    optimal_delays = calculate_optimal_delays(
                        [ts_neural],
                        [ts_feature],
                        metric='mi',
                        shift_window=20,  # Small window for efficiency
                        ds=1,
                        verbose=False,
                    )

                    # Run scan_pairs to get MI shuffle distributions
                    _, me_total = scan_pairs(
                        [ts_neural],
                        [ts_feature],
                        metric='mi',
                        nsh=n_shuffles,
                        optimal_delays=optimal_delays,
                        joint_distr=False,
                        ds=1,
                        mask=None,
                        noise_const=1e-3,
                        # Distinct, reproducible seed per (cell, feature) pair.
                        seed=self.random_state + i * 100 + j,
                        allow_mixed_dimensions=False,
                        enable_progressbar=False,
                    )

                    # Extract true MI and shuffle values
                    true_mi = me_total[0, 0, 0]  # First element is true MI
                    shuffle_values = me_total[0, 0, 1:]  # Rest are shuffles

                    # Calculate p-values with both distributions
                    p_val_norm = get_mi_distr_pvalue(shuffle_values, true_mi, 'norm')
                    p_val_gamma = get_mi_distr_pvalue(shuffle_values, true_mi, 'gamma')

                    shuffle_data.append(
                        ShuffleDistributionData(
                            shuffle_values=shuffle_values,
                            true_mi=true_mi,
                            neuron_id=f"{scenario_name}_neuron_{cell_id}",
                            feature_id=feature_id,
                            # Significant if either model rejects at the threshold.
                            is_significant=min(p_val_norm, p_val_gamma) < self._ALPHA,
                            p_value_norm=p_val_norm,
                            p_value_gamma=p_val_gamma,
                            statistical_properties=self._calculate_statistical_properties(shuffle_values),
                        )
                    )

                except Exception as e:
                    # Deliberately best-effort: report the pair and move on.
                    print(f" Warning: Failed to extract data for neuron {cell_id}, feature {feature_id}: {e}")
                    continue

        return shuffle_data

    def _calculate_statistical_properties(self, data: np.ndarray) -> Dict:
        """
        Calculate comprehensive statistical properties of data.

        Parameters
        ----------
        data : np.ndarray
            Input data array.

        Returns
        -------
        properties : Dict
            Dictionary containing moments, quartiles, range, the number of
            upper outliers (values above ``q75 + 1.5 * iqr``) and the results
            of the Shapiro-Wilk, D'Agostino-Pearson and Anderson-Darling
            normality tests. A test that fails yields ``nan`` entries.
        """
        data = np.asarray(data)

        # Quantiles
        q25, q50, q75 = np.percentile(data, [25, 50, 75])
        iqr = q75 - q25

        # Normality tests: each one is optional, a failure is recorded as NaN.
        try:
            shapiro_stat, shapiro_p = shapiro(data)
        except Exception:
            shapiro_stat, shapiro_p = np.nan, np.nan

        try:
            normaltest_stat, normaltest_p = normaltest(data)
        except Exception:
            normaltest_stat, normaltest_p = np.nan, np.nan

        # Anderson-Darling test for normality
        try:
            anderson_result = anderson(data, dist='norm')
            anderson_stat = anderson_result.statistic
            anderson_critical = anderson_result.critical_values[2]  # 5% level
        except Exception:
            anderson_stat, anderson_critical = np.nan, np.nan

        # Range and outliers (upper tail only)
        min_value = np.min(data)
        max_value = np.max(data)
        outlier_threshold = q75 + 1.5 * iqr

        return {
            'mean': np.mean(data),
            'std': np.std(data),
            'var': np.var(data),
            'skewness': stats.skew(data),
            'kurtosis': stats.kurtosis(data),
            'median': q50,
            'q25': q25,
            'q75': q75,
            'iqr': iqr,
            'range': max_value - min_value,
            'n_outliers': np.sum(data > outlier_threshold),
            'shapiro_stat': shapiro_stat,
            'shapiro_pvalue': shapiro_p,
            'normaltest_stat': normaltest_stat,
            'normaltest_pvalue': normaltest_p,
            'anderson_stat': anderson_stat,
            'anderson_critical': anderson_critical,
            'min_value': min_value,
            'max_value': max_value,
            'n_samples': len(data),
        }

    def fit_distributions(self,
                          data: np.ndarray,
                          distributions: Optional[List[str]] = None) -> Dict[str, DistributionFitResult]:
        """
        Fit multiple distributions to data and compare goodness of fit.

        Parameters
        ----------
        data : np.ndarray
            Data to fit distributions to.
        distributions : List[str], optional
            List of distribution names to test. If None, uses default set.

        Returns
        -------
        results : Dict[str, DistributionFitResult]
            Dictionary mapping distribution names to fit results.
            Distributions that cannot be fitted are reported with a printed
            warning and omitted; 'beta' is silently omitted when the data
            fall outside [0, 1].

        Notes
        -----
        AIC and BIC count only the parameters that were actually estimated:
        a location fixed with ``floc=0`` is not a free parameter.
        The KS test is run against a distribution fitted to the same data.
        """
        if distributions is None:
            distributions = list(self.distributions.keys())

        results = {}
        n_samples = len(data)

        for dist_name in distributions:
            try:
                dist = self.distributions[dist_name]

                # Beta distribution needs data in [0,1]; skip it otherwise.
                if dist_name == 'beta' and (np.min(data) < 0 or np.max(data) > 1):
                    continue

                # Use floc=0 for positive distributions
                fixed = {'floc': 0} if dist_name in self._ZERO_LOC_DISTRIBUTIONS else {}
                params = dist.fit(data, **fixed)

                # Create fitted distribution
                fitted_dist = dist(*params)

                # Calculate log-likelihood
                log_likelihood = np.sum(dist.logpdf(data, *params))

                # `params` still contains the fixed location, so subtract the
                # fixed entries to get the number of free parameters.
                k = len(params) - len(fixed)
                aic = 2 * k - 2 * log_likelihood
                bic = k * np.log(n_samples) - 2 * log_likelihood

                # Kolmogorov-Smirnov test
                ks_stat, ks_p = kstest(data, fitted_dist.cdf)

                results[dist_name] = DistributionFitResult(
                    distribution=dist_name,
                    parameters=params,
                    aic=aic,
                    bic=bic,
                    ks_statistic=ks_stat,
                    ks_pvalue=ks_p,
                    log_likelihood=log_likelihood,
                    fitted_distribution=fitted_dist,
                )

            except Exception as e:
                # Best-effort: one failing candidate must not abort the others.
                print(f" Warning: Failed to fit {dist_name}: {e}")
                continue

        return results

    def analyze_all_distributions(self) -> Dict[str, Dict[str, DistributionFitResult]]:
        """
        Analyze distribution fitting for all collected shuffle data.

        The result is also stored in ``self.fit_results``.

        Returns
        -------
        fit_results : Dict[str, Dict[str, DistributionFitResult]]
            Nested dictionary with fit results for each data sample, keyed by
            ``"<neuron_id>_<feature_id>"``.
        """
        print("Analyzing distribution fitting for all shuffle data...")

        fit_results = {}
        n_total = len(self.shuffle_data)

        for done, data in enumerate(self.shuffle_data, start=1):
            data_id = f"{data.neuron_id}_{data.feature_id}"

            # Fit distributions to shuffle values
            fit_results[data_id] = self.fit_distributions(data.shuffle_values)

            # Print progress
            if done % 10 == 0:
                print(f" Analyzed {done}/{n_total} distributions")

        self.fit_results = fit_results
        print(f"Completed distribution analysis for {len(fit_results)} datasets")

        return fit_results

    @staticmethod
    def _safe_ratio(numerator, denominator):
        """Return ``numerator / denominator``, or 0.0 when the denominator is not positive."""
        return numerator / denominator if denominator > 0 else 0.0

    def compare_detection_performance(self) -> Dict[str, Dict[str, float]]:
        """
        Compare detection performance between different distributions.

        Returns
        -------
        performance : Dict[str, Dict[str, float]]
            Performance metrics for each distribution type. Every ratio is 0
            when its denominator is zero, including when no shuffle data has
            been collected.

        Notes
        -----
        The reference label is ``is_significant`` of each stored sample, not
        an independent ground truth.
        """
        print("Comparing detection performance between distributions...")

        performance = {}
        total = len(self.shuffle_data)

        for dist_name in self._COMPARED_DISTRIBUTIONS:
            true_positives = false_positives = 0
            true_negatives = false_negatives = 0

            for data in self.shuffle_data:
                # Get p-value for this distribution
                p_val = get_mi_distr_pvalue(data.shuffle_values, data.true_mi, dist_name)
                predicted_significant = p_val < self._ALPHA
                actual_significant = data.is_significant

                if predicted_significant and actual_significant:
                    true_positives += 1
                elif predicted_significant:
                    false_positives += 1
                elif actual_significant:
                    false_negatives += 1
                else:
                    true_negatives += 1

            performance[dist_name] = {
                'sensitivity': self._safe_ratio(true_positives, true_positives + false_negatives),
                'specificity': self._safe_ratio(true_negatives, true_negatives + false_positives),
                'precision': self._safe_ratio(true_positives, true_positives + false_positives),
                'accuracy': self._safe_ratio(true_positives + true_negatives, total),
                'true_positives': true_positives,
                'false_positives': false_positives,
                'true_negatives': true_negatives,
                'false_negatives': false_negatives,
            }

        return performance

    @staticmethod
    def _format_optional(value, spec: str = ".3f") -> str:
        """Format ``value`` with ``spec``, or return 'N/A' when it is None."""
        return "N/A" if value is None else format(value, spec)

    def generate_summary_report(self) -> str:
        """
        Generate a comprehensive summary report of the investigation.

        Statistics for which no finite value is available are printed as
        'N/A'. The distribution fitting section is included only when
        ``self.fit_results`` is non-empty.

        Returns
        -------
        report : str
            Formatted summary report.
        """
        n_total = len(self.shuffle_data)
        rule = "=" * 80
        report = [rule, "MI DISTRIBUTION INVESTIGATION REPORT", rule]

        # Data summary
        significant_count = sum(1 for d in self.shuffle_data if d.is_significant)
        report.append("\nDATA SUMMARY:")
        report.append(f" Total shuffle distributions analyzed: {n_total}")
        report.append(f" Significant pairs: {significant_count}")
        report.append(f" Non-significant pairs: {n_total - significant_count}")

        # Statistical properties summary: average each property over the
        # samples for which it is not NaN.
        report.append("\nSTATISTICAL PROPERTIES SUMMARY:")
        reported_properties = (
            ('skewness', 'Average skewness'),
            ('kurtosis', 'Average kurtosis'),
            ('shapiro_pvalue', 'Average Shapiro p-value'),
            ('normaltest_pvalue', 'Average normaltest p-value'),
        )
        for prop, label in reported_properties:
            values = [
                d.statistical_properties[prop]
                for d in self.shuffle_data
                if not np.isnan(d.statistical_properties[prop])
            ]
            average = np.mean(values) if values else None
            report.append(f" {label}: {self._format_optional(average)}")

        # Distribution fitting summary
        if self.fit_results:
            report.append("\nDISTRIBUTION FITTING SUMMARY:")

            for dist_name in self._COMPARED_DISTRIBUTIONS:
                fits = [
                    results[dist_name]
                    for results in self.fit_results.values()
                    if dist_name in results
                ]
                if not fits:
                    continue
                report.append(f" {dist_name.upper()}:")
                report.append(f" Average AIC: {np.mean([f.aic for f in fits]):.2f}")
                report.append(f" Average BIC: {np.mean([f.bic for f in fits]):.2f}")
                report.append(f" Average KS statistic: {np.mean([f.ks_statistic for f in fits]):.3f}")

        # Detection performance summary
        performance = self.compare_detection_performance()
        report.append("\nDETECTION PERFORMANCE COMPARISON:")

        for dist_name, metrics in performance.items():
            report.append(f" {dist_name.upper()}:")
            report.append(f" Sensitivity: {metrics['sensitivity']:.3f}")
            report.append(f" Specificity: {metrics['specificity']:.3f}")
            report.append(f" Precision: {metrics['precision']:.3f}")
            report.append(f" Accuracy: {metrics['accuracy']:.3f}")

        # Recommendations
        report.append("\nRECOMMENDATIONS:")

        # Find best performing distribution (first one wins on ties)
        best_dist = max(performance, key=lambda d: performance[d]['accuracy'])
        report.append(f" Best performing distribution: {best_dist.upper()}")

        # Analyze why norm might be better
        norm_better_count = sum(1 for d in self.shuffle_data if d.p_value_norm < d.p_value_gamma)
        report.append(f" Cases where norm gives lower p-value: {norm_better_count}/{n_total}")

        # A NaN Shapiro p-value compares False and is therefore not counted.
        normal_like_count = sum(
            1 for d in self.shuffle_data
            if d.statistical_properties['shapiro_pvalue'] > self._ALPHA
        )
        report.append(f" Distributions that appear normal (Shapiro p>0.05): {normal_like_count}/{n_total}")

        return "\n".join(report)

    def create_visualizations(self, save_path: Optional[str] = None) -> None:
        """
        Create comprehensive visualizations of the investigation results.

        Draws a 2x3 grid: skewness histogram, kurtosis histogram, norm-vs-gamma
        p-value scatter, one example shuffle distribution, average AIC per
        distribution and detection accuracy per distribution.

        Parameters
        ----------
        save_path : str, optional
            Path to save visualizations. If None, displays plots.
        """
        print("Creating visualizations...")

        fig, axes = plt.subplots(2, 3, figsize=(18, 12))
        fig.suptitle('MI Distribution Investigation Results', fontsize=16)
        bar_colors = ['skyblue', 'lightgreen', 'lightcoral']

        # 1-2. Histograms of shape statistics. Non-finite values are dropped
        # because a histogram range cannot be computed from NaN.
        shape_panels = (
            (axes[0, 0], 'skewness', 'Skewness', 'skyblue', 'Normal (skew=0)'),
            (axes[0, 1], 'kurtosis', 'Kurtosis', 'lightgreen', 'Normal (kurt=0)'),
        )
        for ax, key, label, color, reference_label in shape_panels:
            values = [
                d.statistical_properties[key]
                for d in self.shuffle_data
                if np.isfinite(d.statistical_properties[key])
            ]
            ax.hist(values, bins=20, alpha=0.7, color=color)
            ax.set_xlabel(label)
            ax.set_ylabel('Frequency')
            ax.set_title(f'Distribution of {label} Values')
            ax.axvline(0, color='red', linestyle='--', alpha=0.7, label=reference_label)
            ax.legend()

        # 3. P-value comparison
        ax3 = axes[0, 2]
        p_norm = [d.p_value_norm for d in self.shuffle_data]
        p_gamma = [d.p_value_gamma for d in self.shuffle_data]
        ax3.scatter(p_norm, p_gamma, alpha=0.6)
        ax3.plot([0, 1], [0, 1], 'r--', alpha=0.7)  # identity line
        ax3.set_xlabel('P-value (norm)')
        ax3.set_ylabel('P-value (gamma)')
        ax3.set_title('P-value Comparison: Norm vs Gamma')
        ax3.set_xlim(0, 1)
        ax3.set_ylim(0, 1)

        # 4. Example shuffle distribution (first collected sample)
        ax4 = axes[1, 0]
        if self.shuffle_data:
            example_data = self.shuffle_data[0]
            ax4.hist(example_data.shuffle_values, bins=30, alpha=0.7, density=True, color='lightcoral')
            ax4.axvline(example_data.true_mi, color='red', linestyle='-', linewidth=2, label='True MI')
            ax4.set_xlabel('MI Value')
            ax4.set_ylabel('Density')
            ax4.set_title('Example Shuffle Distribution')
            ax4.legend()

        # 5. Goodness of fit comparison
        ax5 = axes[1, 1]
        if self.fit_results:
            # Average AIC per distribution over the samples where it was fitted.
            avg_aics = {}
            for dist_name in self._COMPARED_DISTRIBUTIONS:
                aics = [
                    results[dist_name].aic
                    for results in self.fit_results.values()
                    if dist_name in results
                ]
                if aics:
                    avg_aics[dist_name] = np.mean(aics)

            if avg_aics:
                ax5.bar(list(avg_aics.keys()), list(avg_aics.values()), color=bar_colors)
                ax5.set_ylabel('Average AIC')
                ax5.set_title('Distribution Fit Quality (lower AIC = better)')
                ax5.tick_params(axis='x', rotation=45)

        # 6. Detection performance
        ax6 = axes[1, 2]
        performance = self.compare_detection_performance()
        if performance:
            dists = list(performance.keys())
            accuracies = [performance[d]['accuracy'] for d in dists]
            ax6.bar(dists, accuracies, color=bar_colors)
            ax6.set_ylabel('Accuracy')
            ax6.set_title('Detection Performance')
            ax6.set_ylim(0, 1)
            ax6.tick_params(axis='x', rotation=45)

        plt.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"Visualizations saved to: {save_path}")
        else:
            plt.show()