Coverage for src/driada/intense/distribution_investigation.py: 0.00%

297 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-25 15:40 +0300

1""" 

2Investigation of MI distribution types for INTENSE statistical testing. 

3 

4This module investigates why metric_distr_type='norm' works better than 'gamma'  

5for MI distributions in INTENSE, despite gamma being theoretically more appropriate. 

6 

7Key questions addressed: 

81. What are the statistical properties of actual MI shuffle distributions? 

92. How do different distributions (norm, gamma, lognorm) fit the data? 

103. Why does normal distribution give better detection performance? 

114. What are the root causes of this empirical observation? 

12""" 

13 

14import numpy as np 

15import matplotlib.pyplot as plt 

16from scipy import stats 

17from scipy.stats import normaltest, shapiro, anderson, kstest, gamma, norm, lognorm 

18from typing import Dict, List, Tuple, Optional, Union 

19import warnings 

20from dataclasses import dataclass 

21from pathlib import Path 

22 

23from .stats import get_mi_distr_pvalue, get_distribution_function 

24from .pipelines import compute_cell_feat_significance 

25from ..experiment import generate_mixed_population_exp, generate_circular_manifold_exp 

26 

27 

@dataclass
class DistributionFitResult:
    """
    Record describing how well one candidate distribution fits a sample.

    Attributes
    ----------
    distribution : str
        Name of the distribution that was fitted (e.g. ``'norm'``).
    parameters : Tuple
        Parameter tuple returned by the distribution's ``fit`` call.
    aic : float
        Akaike information criterion of the fit.
    bic : float
        Bayesian information criterion of the fit.
    ks_statistic : float
        Kolmogorov-Smirnov statistic of the sample against the fitted CDF.
    ks_pvalue : float
        p-value belonging to ``ks_statistic``.
    log_likelihood : float
        Summed log-density of the sample under the fitted parameters.
    fitted_distribution : object
        The distribution object built from ``parameters``.
    """

    # Identification of the fit
    distribution: str
    parameters: Tuple

    # Information criteria
    aic: float
    bic: float

    # Goodness-of-fit test
    ks_statistic: float
    ks_pvalue: float

    # Likelihood and the frozen distribution itself
    log_likelihood: float
    fitted_distribution: object

39 

40 

@dataclass
class ShuffleDistributionData:
    """
    One neuron/feature pair: its shuffle MI sample plus derived metadata.

    Attributes
    ----------
    shuffle_values : np.ndarray
        MI values obtained from the shuffles.
    true_mi : float
        MI value of the unshuffled pair.
    neuron_id : Union[str, int]
        Identifier of the neuron.
    feature_id : str
        Identifier of the feature.
    is_significant : bool
        Significance flag assigned when the record was created.
    p_value_norm : float
        p-value of ``true_mi`` under a normal model of the shuffles.
    p_value_gamma : float
        p-value of ``true_mi`` under a gamma model of the shuffles.
    statistical_properties : Dict
        Descriptive statistics of ``shuffle_values``.
    """

    # Raw measurements
    shuffle_values: np.ndarray
    true_mi: float

    # Which pair this record belongs to
    neuron_id: Union[str, int]
    feature_id: str

    # Significance information
    is_significant: bool
    p_value_norm: float
    p_value_gamma: float

    # Summary statistics of the shuffle sample
    statistical_properties: Dict

52 

53 

54class MIDistributionInvestigator: 

55 """ 

56 Investigates MI distribution fitting for INTENSE statistical testing. 

57  

58 This class provides comprehensive analysis of why normal distribution 

59 works better than gamma for MI shuffle distributions. 

60 """ 

61 

def __init__(self, random_state: int = 42):
    """
    Prepare an investigator: remember the seed, list the candidate
    distributions and create empty result containers.

    Parameters
    ----------
    random_state : int, optional
        Seed stored on the instance and handed to ``np.random.seed``.
        Default: 42.
    """
    self.random_state = random_state
    # Note: this seeds NumPy's global legacy generator, not a private one.
    np.random.seed(random_state)

    # Candidate scipy.stats distributions, keyed by the name used in results.
    candidates = [
        ('norm', norm),
        ('gamma', gamma),
        ('lognorm', lognorm),
        ('expon', stats.expon),
        ('weibull_min', stats.weibull_min),
        ('beta', stats.beta),
    ]
    self.distributions = dict(candidates)

    # Containers that later analysis steps populate.
    self.shuffle_data: List[ShuffleDistributionData] = []
    self.fit_results: Dict[str, Dict[str, DistributionFitResult]] = {}

87 

def generate_test_data(self,
                       n_scenarios: int = 5,
                       n_shuffles: int = 1000) -> List[ShuffleDistributionData]:
    """
    Build synthetic experiments and collect their MI shuffle distributions.

    Two synthetic experiments are generated (a circular manifold and a
    mixed population) and each is passed through
    ``_extract_shuffle_distributions``. The combined list is stored on
    ``self.shuffle_data`` and returned.

    Parameters
    ----------
    n_scenarios : int, optional
        Number of different scenarios to test. Default: 5.
        NOTE(review): this argument is not read anywhere in the body;
        exactly two scenarios are always generated.
    n_shuffles : int, optional
        Number of shuffles per neuron/feature pair. Default: 1000.

    Returns
    -------
    shuffle_data : List[ShuffleDistributionData]
        Shuffle distribution records from both scenarios, circular first.
    """
    print("Generating test data for MI distribution investigation...")

    collected = []

    # --- Scenario 1: circular manifold (head direction cells) ---
    print(" - Scenario 1: Circular manifold")
    circular_exp = generate_circular_manifold_exp(
        n_neurons=20,
        duration=300,
        fps=20,
        seed=self.random_state,
    )
    collected += self._extract_shuffle_distributions(
        circular_exp,
        scenario_name="circular",
        n_shuffles=n_shuffles,
    )

    # --- Scenario 2: mixed population (spatial + feature components) ---
    print(" - Scenario 2: Mixed population")
    mixed_exp = generate_mixed_population_exp(
        n_neurons=50,
        manifold_type='2d_spatial',
        manifold_fraction=0.6,
        n_discrete_features=1,
        n_continuous_features=2,
        duration=300,
        fps=20,
        seed=self.random_state + 1,  # distinct seed from scenario 1
    )
    collected += self._extract_shuffle_distributions(
        mixed_exp,
        scenario_name="mixed",
        n_shuffles=n_shuffles,
    )

    self.shuffle_data = collected
    print(f"Generated {len(collected)} shuffle distributions for analysis")

    return collected

151 

def _extract_shuffle_distributions(self,
                                   exp,
                                   scenario_name: str,
                                   n_shuffles: int) -> List[ShuffleDistributionData]:
    """
    Run the INTENSE shuffle procedure on a subset of neuron/feature pairs.

    At most the first 3 dynamic features and the first 10 cells of ``exp``
    are used. For every pair the optimal delay is computed, ``scan_pairs``
    produces the MI values, p-values are computed under normal and gamma
    models, and the result is wrapped in a ``ShuffleDistributionData``.
    A pair that raises any ``Exception`` is reported with a printed warning
    and skipped.

    Parameters
    ----------
    exp : Experiment
        DRIADA experiment object; ``dynamic_features``, ``n_cells`` and
        ``calcium`` are read from it.
    scenario_name : str
        Prefix used when building each record's ``neuron_id``.
    n_shuffles : int
        Number of shuffles requested from ``scan_pairs``.

    Returns
    -------
    shuffle_data : List[ShuffleDistributionData]
        One record per successfully processed pair.
    """
    from .intense_base import scan_pairs, calculate_optimal_delays

    # Cap the workload: no more than 3 features and 10 cells.
    feature_names = list(exp.dynamic_features.keys())[:3]
    n_cells_used = min(10, exp.n_cells)

    records = []

    for cell_id in range(n_cells_used):
        for feat_idx, feature_id in enumerate(feature_names):
            try:
                raw_neural = exp.calcium[cell_id]
                raw_feature = exp.dynamic_features[feature_id]

                from ..information.info_base import TimeSeries
                ts_neural = TimeSeries(raw_neural)
                # Objects exposing `.data` are passed through unwrapped.
                if hasattr(raw_feature, 'data'):
                    ts_feature = raw_feature
                else:
                    ts_feature = TimeSeries(raw_feature)

                delays = calculate_optimal_delays(
                    [ts_neural],
                    [ts_feature],
                    metric='mi',
                    shift_window=20,  # small window for efficiency
                    ds=1,
                    verbose=False,
                )

                # Each pair gets its own deterministic seed.
                pair_seed = self.random_state + cell_id * 100 + feat_idx
                random_shifts, me_total = scan_pairs(
                    [ts_neural],
                    [ts_feature],
                    metric='mi',
                    nsh=n_shuffles,
                    optimal_delays=delays,
                    joint_distr=False,
                    ds=1,
                    mask=None,
                    noise_const=1e-3,
                    seed=pair_seed,
                    allow_mixed_dimensions=False,
                    enable_progressbar=False,
                )

                # Index 0 on the last axis is taken as the true MI and the
                # remaining entries as the shuffles.
                pair_values = me_total[0, 0]
                true_mi = pair_values[0]
                shuffle_values = pair_values[1:]

                p_val_norm = get_mi_distr_pvalue(shuffle_values, true_mi, 'norm')
                p_val_gamma = get_mi_distr_pvalue(shuffle_values, true_mi, 'gamma')

                stats_props = self._calculate_statistical_properties(shuffle_values)

                # Significant when either model gives p < 0.05.
                is_significant = min(p_val_norm, p_val_gamma) < 0.05

                records.append(
                    ShuffleDistributionData(
                        shuffle_values=shuffle_values,
                        true_mi=true_mi,
                        neuron_id=f"{scenario_name}_neuron_{cell_id}",
                        feature_id=feature_id,
                        is_significant=is_significant,
                        p_value_norm=p_val_norm,
                        p_value_gamma=p_val_gamma,
                        statistical_properties=stats_props,
                    )
                )

            except Exception as e:
                # Best effort: report the failing pair and move on.
                print(f" Warning: Failed to extract data for neuron {cell_id}, feature {feature_id}: {e}")

    return records

256 

def _calculate_statistical_properties(self, data: np.ndarray) -> Dict:
    """
    Calculate descriptive statistics and normality tests for a sample.

    Parameters
    ----------
    data : np.ndarray
        Input data array.

    Returns
    -------
    properties : Dict
        Dictionary with the keys ``mean``, ``std``, ``var``, ``skewness``,
        ``kurtosis``, ``median``, ``q25``, ``q75``, ``iqr``, ``range``,
        ``n_outliers``, ``shapiro_stat``, ``shapiro_pvalue``,
        ``normaltest_stat``, ``normaltest_pvalue``, ``anderson_stat``,
        ``anderson_critical``, ``min_value``, ``max_value`` and
        ``n_samples``. If a normality test raises, its entries are NaN.

    Notes
    -----
    ``std``/``var`` use NumPy's defaults (``ddof=0``) and ``kurtosis`` uses
    the ``scipy.stats.kurtosis`` default. ``n_outliers`` counts only values
    above ``q75 + 1.5 * iqr``; the lower tail is not examined.
    """
    # Basic moments
    mean = np.mean(data)
    std = np.std(data)
    var = np.var(data)

    # Shape statistics
    skewness = stats.skew(data)
    kurtosis = stats.kurtosis(data)

    # Quantiles
    q25, q50, q75 = np.percentile(data, [25, 50, 75])
    iqr = q75 - q25

    # Normality tests. Each test is best effort: a failure (e.g. too few
    # samples) yields NaN rather than aborting the whole summary. Only
    # ``Exception`` is caught so KeyboardInterrupt/SystemExit propagate.
    try:
        shapiro_stat, shapiro_p = shapiro(data)
    except Exception:
        shapiro_stat, shapiro_p = np.nan, np.nan

    try:
        normaltest_stat, normaltest_p = normaltest(data)
    except Exception:
        normaltest_stat, normaltest_p = np.nan, np.nan

    # Anderson-Darling test for normality
    try:
        anderson_result = anderson(data, dist='norm')
        anderson_stat = anderson_result.statistic
        anderson_critical = anderson_result.critical_values[2]  # 5% level
    except Exception:
        anderson_stat, anderson_critical = np.nan, np.nan

    # Range and (upper-tail) outliers
    min_value = np.min(data)
    max_value = np.max(data)
    data_range = max_value - min_value
    outlier_threshold = q75 + 1.5 * iqr
    n_outliers = np.sum(data > outlier_threshold)

    return {
        'mean': mean,
        'std': std,
        'var': var,
        'skewness': skewness,
        'kurtosis': kurtosis,
        'median': q50,
        'q25': q25,
        'q75': q75,
        'iqr': iqr,
        'range': data_range,
        'n_outliers': n_outliers,
        'shapiro_stat': shapiro_stat,
        'shapiro_pvalue': shapiro_p,
        'normaltest_stat': normaltest_stat,
        'normaltest_pvalue': normaltest_p,
        'anderson_stat': anderson_stat,
        'anderson_critical': anderson_critical,
        'min_value': min_value,
        'max_value': max_value,
        'n_samples': len(data)
    }

330 

def fit_distributions(self,
                      data: np.ndarray,
                      distributions: Optional[List[str]] = None) -> Dict[str, DistributionFitResult]:
    """
    Fit multiple distributions to data and compare goodness of fit.

    Parameters
    ----------
    data : np.ndarray
        Data to fit distributions to.
    distributions : List[str], optional
        Names (keys of ``self.distributions``) to test. If None, every
        distribution in ``self.distributions`` is tried.

    Returns
    -------
    results : Dict[str, DistributionFitResult]
        Dictionary mapping distribution names to fit results. A
        distribution is absent from the result when fitting raised (a
        warning is printed) or when ``'beta'`` was requested for data
        outside [0, 1].

    Notes
    -----
    ``'gamma'`` and ``'lognorm'`` are fitted with the location pinned at
    zero (``floc=0``). The pinned location is still part of the returned
    parameter tuple, but it is not an estimated quantity, so it is excluded
    from the parameter count ``k`` used for AIC and BIC.

    The Kolmogorov-Smirnov test is run against a CDF whose parameters were
    estimated from the same sample, so its p-value should be read as a
    relative indicator rather than an exact significance level.
    """
    if distributions is None:
        distributions = list(self.distributions.keys())

    results = {}
    n_samples = len(data)

    for dist_name in distributions:
        try:
            dist = self.distributions[dist_name]

            # Number of entries in `params` that were fixed, not estimated.
            n_fixed = 0

            if dist_name in ('gamma', 'lognorm'):
                # Use floc=0 for positive distributions
                params = dist.fit(data, floc=0)
                n_fixed = 1
            elif dist_name == 'beta':
                # Beta distribution needs data in [0,1]; skip otherwise.
                if np.min(data) < 0 or np.max(data) > 1:
                    continue
                params = dist.fit(data)
            else:
                params = dist.fit(data)

            # Create fitted (frozen) distribution
            fitted_dist = dist(*params)

            # Calculate log-likelihood
            log_likelihood = np.sum(dist.logpdf(data, *params))

            # AIC and BIC penalise only the parameters actually estimated.
            k = len(params) - n_fixed
            aic = 2 * k - 2 * log_likelihood
            bic = k * np.log(n_samples) - 2 * log_likelihood

            # Kolmogorov-Smirnov test against the fitted CDF
            ks_stat, ks_p = kstest(data, fitted_dist.cdf)

            results[dist_name] = DistributionFitResult(
                distribution=dist_name,
                parameters=params,
                aic=aic,
                bic=bic,
                ks_statistic=ks_stat,
                ks_pvalue=ks_p,
                log_likelihood=log_likelihood,
                fitted_distribution=fitted_dist
            )

        except Exception as e:
            # Best effort: one failing candidate must not stop the others.
            print(f" Warning: Failed to fit {dist_name}: {e}")
            continue

    return results

403 

def analyze_all_distributions(self) -> Dict[str, Dict[str, DistributionFitResult]]:
    """
    Fit the candidate distributions to every stored shuffle sample.

    Each entry of ``self.shuffle_data`` is passed to
    ``fit_distributions``; the outcome is stored under the key
    ``"<neuron_id>_<feature_id>"``. The full mapping is saved on
    ``self.fit_results`` and returned. Progress is printed after every
    tenth dataset.

    Returns
    -------
    fit_results : Dict[str, Dict[str, DistributionFitResult]]
        Fit results per dataset, then per distribution name.
    """
    print("Analyzing distribution fitting for all shuffle data...")

    per_dataset = {}

    for done, entry in enumerate(self.shuffle_data, start=1):
        key = f"{entry.neuron_id}_{entry.feature_id}"
        per_dataset[key] = self.fit_distributions(entry.shuffle_values)

        # Progress message every 10 datasets
        if done % 10 == 0:
            print(f" Analyzed {done}/{len(self.shuffle_data)} distributions")

    self.fit_results = per_dataset
    print(f"Completed distribution analysis for {len(per_dataset)} datasets")

    return per_dataset

432 

def compare_detection_performance(self) -> Dict[str, Dict[str, float]]:
    """
    Compare detection performance between different distributions.

    For each of ``'norm'``, ``'gamma'`` and ``'lognorm'`` the p-value of
    every stored pair is recomputed with ``get_mi_distr_pvalue`` and
    thresholded at 0.05. The prediction is then compared with the pair's
    stored ``is_significant`` flag to build a confusion matrix.

    Returns
    -------
    performance : Dict[str, Dict[str, float]]
        Per distribution: ``sensitivity``, ``specificity``, ``precision``,
        ``accuracy`` and the four raw counts ``true_positives``,
        ``false_positives``, ``true_negatives``, ``false_negatives``.
        Any ratio whose denominator is zero is reported as 0, which
        includes ``accuracy`` when ``self.shuffle_data`` is empty.

    Notes
    -----
    The reference label is whatever ``is_significant`` holds on each
    record. When the records come from ``_extract_shuffle_distributions``
    that flag is itself derived from the norm and gamma p-values, so it is
    not an independent ground truth.
    """
    print("Comparing detection performance between distributions...")

    def _ratio(numerator, denominator):
        # Guarded division: an empty denominator yields 0 instead of raising.
        return numerator / denominator if denominator > 0 else 0

    performance = {}
    total = len(self.shuffle_data)

    for dist_name in ['norm', 'gamma', 'lognorm']:
        true_positives = 0
        false_positives = 0
        true_negatives = 0
        false_negatives = 0

        for data in self.shuffle_data:
            # Get p-value for this distribution
            p_val = get_mi_distr_pvalue(data.shuffle_values, data.true_mi, dist_name)
            predicted_significant = p_val < 0.05

            # Reference label stored on the record
            actual_significant = data.is_significant

            if predicted_significant and actual_significant:
                true_positives += 1
            elif predicted_significant and not actual_significant:
                false_positives += 1
            elif not predicted_significant and actual_significant:
                false_negatives += 1
            else:
                true_negatives += 1

        performance[dist_name] = {
            'sensitivity': _ratio(true_positives, true_positives + false_negatives),
            'specificity': _ratio(true_negatives, true_negatives + false_positives),
            'precision': _ratio(true_positives, true_positives + false_positives),
            'accuracy': _ratio(true_positives + true_negatives, total),
            'true_positives': true_positives,
            'false_positives': false_positives,
            'true_negatives': true_negatives,
            'false_negatives': false_negatives
        }

    return performance

489 

def generate_summary_report(self) -> str:
    """
    Generate a text summary of the investigation.

    The report contains, in order: a data summary, averaged statistical
    properties of the shuffle samples, averaged fit quality per
    distribution (only when ``self.fit_results`` is non-empty), the
    detection performance from ``compare_detection_performance`` and a
    short recommendations section.

    Returns
    -------
    report : str
        The report lines joined with newlines. An averaged statistic for
        which no non-NaN value exists is printed as ``N/A``.
    """
    report = []
    report.append("=" * 80)
    report.append("MI DISTRIBUTION INVESTIGATION REPORT")
    report.append("=" * 80)

    n_total = len(self.shuffle_data)

    # Data summary
    report.append("\nDATA SUMMARY:")
    report.append(f" Total shuffle distributions analyzed: {n_total}")

    significant_count = sum(1 for d in self.shuffle_data if d.is_significant)
    report.append(f" Significant pairs: {significant_count}")
    report.append(f" Non-significant pairs: {n_total - significant_count}")

    # Statistical properties summary
    report.append("\nSTATISTICAL PROPERTIES SUMMARY:")

    # Average each property over all samples, ignoring NaN entries.
    avg_stats = {}
    for prop in ['mean', 'std', 'skewness', 'kurtosis', 'shapiro_pvalue', 'normaltest_pvalue']:
        values = [d.statistical_properties[prop] for d in self.shuffle_data
                  if not np.isnan(d.statistical_properties[prop])]
        if values:
            avg_stats[prop] = np.mean(values)

    def _fmt_avg(prop):
        # A missing average cannot take a float format spec; show N/A.
        value = avg_stats.get(prop)
        return 'N/A' if value is None else f"{value:.3f}"

    report.append(f" Average skewness: {_fmt_avg('skewness')}")
    report.append(f" Average kurtosis: {_fmt_avg('kurtosis')}")
    report.append(f" Average Shapiro p-value: {_fmt_avg('shapiro_pvalue')}")
    report.append(f" Average normaltest p-value: {_fmt_avg('normaltest_pvalue')}")

    # Distribution fitting summary
    if self.fit_results:
        report.append("\nDISTRIBUTION FITTING SUMMARY:")

        # Average AIC/BIC/KS per distribution over the datasets that
        # contain a fit for it.
        dist_summary = {}
        for dist_name in ['norm', 'gamma', 'lognorm']:
            fits = [results[dist_name] for results in self.fit_results.values()
                    if dist_name in results]
            if fits:
                dist_summary[dist_name] = {
                    'avg_aic': np.mean([fit.aic for fit in fits]),
                    'avg_bic': np.mean([fit.bic for fit in fits]),
                    'avg_ks_stat': np.mean([fit.ks_statistic for fit in fits])
                }

        for dist_name, summary in dist_summary.items():
            report.append(f" {dist_name.upper()}:")
            report.append(f" Average AIC: {summary['avg_aic']:.2f}")
            report.append(f" Average BIC: {summary['avg_bic']:.2f}")
            report.append(f" Average KS statistic: {summary['avg_ks_stat']:.3f}")

    # Detection performance summary
    performance = self.compare_detection_performance()
    report.append("\nDETECTION PERFORMANCE COMPARISON:")

    for dist_name, metrics in performance.items():
        report.append(f" {dist_name.upper()}:")
        report.append(f" Sensitivity: {metrics['sensitivity']:.3f}")
        report.append(f" Specificity: {metrics['specificity']:.3f}")
        report.append(f" Precision: {metrics['precision']:.3f}")
        report.append(f" Accuracy: {metrics['accuracy']:.3f}")

    # Recommendations
    report.append("\nRECOMMENDATIONS:")

    # Highest accuracy wins; on ties the first distribution is kept.
    best_dist = max(performance.keys(), key=lambda d: performance[d]['accuracy'])
    report.append(f" Best performing distribution: {best_dist.upper()}")

    # How often the normal model yields the smaller p-value
    norm_better_count = sum(1 for d in self.shuffle_data if d.p_value_norm < d.p_value_gamma)
    report.append(f" Cases where norm gives lower p-value: {norm_better_count}/{n_total}")

    # Samples not rejected by Shapiro-Wilk (NaN p-values never count)
    normal_like_count = sum(1 for d in self.shuffle_data
                            if d.statistical_properties['shapiro_pvalue'] > 0.05)
    report.append(f" Distributions that appear normal (Shapiro p>0.05): {normal_like_count}/{n_total}")

    return "\n".join(report)

585 

def create_visualizations(self, save_path: Optional[str] = None) -> None:
    """
    Draw a 2x3 overview figure of the investigation results.

    Panels: skewness histogram, kurtosis histogram, norm-vs-gamma p-value
    scatter, the first stored shuffle distribution with its true MI,
    average AIC per distribution (needs ``self.fit_results``) and
    detection accuracy per distribution. Panels whose data are missing are
    left empty.

    Parameters
    ----------
    save_path : str, optional
        Path to save the figure to. If given, the figure is written there
        and then closed so repeated calls do not accumulate open figures.
        If None, the figure is displayed with ``plt.show()``.
    """
    print("Creating visualizations...")

    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    fig.suptitle('MI Distribution Investigation Results', fontsize=16)

    bar_colors = ['skyblue', 'lightgreen', 'lightcoral']

    # 1. Distribution of skewness values
    ax1 = axes[0, 0]
    skewness_values = [d.statistical_properties['skewness'] for d in self.shuffle_data]
    ax1.hist(skewness_values, bins=20, alpha=0.7, color='skyblue')
    ax1.set_xlabel('Skewness')
    ax1.set_ylabel('Frequency')
    ax1.set_title('Distribution of Skewness Values')
    ax1.axvline(0, color='red', linestyle='--', alpha=0.7, label='Normal (skew=0)')
    ax1.legend()

    # 2. Kurtosis distribution
    ax2 = axes[0, 1]
    kurtosis_values = [d.statistical_properties['kurtosis'] for d in self.shuffle_data]
    ax2.hist(kurtosis_values, bins=20, alpha=0.7, color='lightgreen')
    ax2.set_xlabel('Kurtosis')
    ax2.set_ylabel('Frequency')
    ax2.set_title('Distribution of Kurtosis Values')
    ax2.axvline(0, color='red', linestyle='--', alpha=0.7, label='Normal (kurt=0)')
    ax2.legend()

    # 3. P-value comparison; the dashed diagonal marks equal p-values
    ax3 = axes[0, 2]
    p_norm = [d.p_value_norm for d in self.shuffle_data]
    p_gamma = [d.p_value_gamma for d in self.shuffle_data]
    ax3.scatter(p_norm, p_gamma, alpha=0.6)
    ax3.plot([0, 1], [0, 1], 'r--', alpha=0.7)
    ax3.set_xlabel('P-value (norm)')
    ax3.set_ylabel('P-value (gamma)')
    ax3.set_title('P-value Comparison: Norm vs Gamma')
    ax3.set_xlim(0, 1)
    ax3.set_ylim(0, 1)

    # 4. Example shuffle distribution (first stored record)
    ax4 = axes[1, 0]
    if self.shuffle_data:
        example_data = self.shuffle_data[0]
        ax4.hist(example_data.shuffle_values, bins=30, alpha=0.7, density=True, color='lightcoral')
        ax4.axvline(example_data.true_mi, color='red', linestyle='-', linewidth=2, label='True MI')
        ax4.set_xlabel('MI Value')
        ax4.set_ylabel('Density')
        ax4.set_title('Example Shuffle Distribution')
        ax4.legend()

    # 5. Goodness of fit comparison
    ax5 = axes[1, 1]
    if self.fit_results:
        # Average AIC per distribution over datasets that have that fit
        avg_aics = {}
        for dist_name in ['norm', 'gamma', 'lognorm']:
            aics = [results[dist_name].aic
                    for results in self.fit_results.values()
                    if dist_name in results]
            if aics:
                avg_aics[dist_name] = np.mean(aics)

        if avg_aics:
            ax5.bar(list(avg_aics.keys()), list(avg_aics.values()), color=bar_colors)
            ax5.set_ylabel('Average AIC')
            ax5.set_title('Distribution Fit Quality (lower AIC = better)')
            ax5.tick_params(axis='x', rotation=45)

    # 6. Detection performance. Skipped when there is no data: there is
    # nothing to score and the panel would only show meaningless zeros.
    ax6 = axes[1, 2]
    performance = self.compare_detection_performance() if self.shuffle_data else {}
    if performance:
        dists = list(performance.keys())
        accuracies = [performance[d]['accuracy'] for d in dists]
        ax6.bar(dists, accuracies, color=bar_colors)
        ax6.set_ylabel('Accuracy')
        ax6.set_title('Detection Performance')
        ax6.set_ylim(0, 1)
        ax6.tick_params(axis='x', rotation=45)

    fig.tight_layout()

    if save_path:
        # Save this figure explicitly (not whichever one is "current")
        # and release it afterwards.
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close(fig)
        print(f"Visualizations saved to: {save_path}")
    else:
        plt.show()