Coverage for src/driada/dim_reduction/manifold_metrics.py: 61.22%

245 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-25 15:40 +0300

1""" 

2Spatial Correspondence Metrics for Manifold Preservation Evaluation 

3 

4This module provides comprehensive metrics for evaluating how well dimensionality 

5reduction methods preserve manifold structure. 

6 

7Key metric categories: 

8--------------------- 

91. Neighborhood preservation: How well are local neighborhoods preserved? 

102. Distance preservation: How well are geodesic distances preserved? 

113. Topology preservation: How well is the global structure preserved? 

124. Shape matching: How similar are the shapes after optimal alignment? 

13""" 

14 

15import numpy as np 

16from sklearn.neighbors import NearestNeighbors 

17from scipy.spatial.distance import pdist, squareform 

18from scipy.stats import spearmanr 

19from scipy.linalg import orthogonal_procrustes 

20from typing import Optional, Tuple, Union 

21 

22 

def compute_distance_matrix(X: np.ndarray, metric: str = 'euclidean') -> np.ndarray:
    """
    Build the full pairwise distance matrix for the rows of ``X``.

    Parameters
    ----------
    X : np.ndarray
        Data matrix of shape (n_samples, n_features)
    metric : str
        Name of the distance metric, forwarded to
        ``scipy.spatial.distance.pdist`` (default: 'euclidean')

    Returns
    -------
    np.ndarray
        Square distance matrix of shape (n_samples, n_samples)

    Raises
    ------
    ValueError
        If ``X`` is not a two-dimensional array.
    """
    n_dims = X.ndim
    if n_dims == 2:
        # pdist returns the condensed (upper-triangle) vector; squareform
        # expands it into the square matrix.
        condensed = pdist(X, metric=metric)
        return squareform(condensed)

    raise ValueError(f"X must be 2D array, got shape {X.shape}")

44 

45 

def knn_preservation_rate(
    X_high: np.ndarray,
    X_low: np.ndarray,
    k: int = 10,
    flexible: bool = False,
    flexibility_factor: float = 2.0
) -> float:
    """
    Compute k-nearest neighbor preservation rate.

    This metric measures what fraction of the k nearest neighbors in the
    original high-dimensional space are also neighbors in the
    low-dimensional embedding.

    Parameters
    ----------
    X_high : np.ndarray
        Original high-dimensional data (n_samples, n_features_high)
    X_low : np.ndarray
        Low-dimensional embedding (n_samples, n_features_low)
    k : int
        Number of nearest neighbors to consider (1 <= k < n_samples)
    flexible : bool
        If True, a high-dimensional neighbor counts as preserved when it is
        among the ``int(k * flexibility_factor)`` nearest neighbors in the
        embedding (capped at ``n_samples - 1``)
    flexibility_factor : float
        Factor to multiply k for flexible matching (default: 2.0)

    Returns
    -------
    float
        Preserved neighbor pairs divided by ``n_samples * k``

    Raises
    ------
    ValueError
        If the sample counts differ, or if ``k`` is not in
        ``[1, n_samples - 1]``.
    """
    if X_high.shape[0] != X_low.shape[0]:
        raise ValueError("X_high and X_low must have same number of samples")

    n_samples = X_high.shape[0]
    if k < 1:
        raise ValueError(f"k={k} must be a positive integer")
    if k >= n_samples:
        raise ValueError(f"k={k} must be less than n_samples={n_samples}")

    # Calling kneighbors() without a query makes scikit-learn exclude each
    # point from its own neighbor list.  Slicing off column 0 instead is
    # only correct when the point itself is returned first, which is not
    # guaranteed when the data contain duplicate points.
    indices_high = (
        NearestNeighbors(n_neighbors=k)
        .fit(X_high)
        .kneighbors(return_distance=False)
    )

    # Neighborhood size in the embedding (enlarged in flexible mode), never
    # larger than the number of other points.
    k_low = int(k * flexibility_factor) if flexible else k
    k_low = min(k_low, n_samples - 1)
    if k_low < 1:
        # An empty embedding neighborhood cannot preserve anything.
        return 0.0

    indices_low = (
        NearestNeighbors(n_neighbors=k_low)
        .fit(X_low)
        .kneighbors(return_distance=False)
    )

    # Count, per sample, how many high-dimensional neighbors reappear.
    preserved = sum(
        len(set(high_row).intersection(low_row))
        for high_row, low_row in zip(indices_high, indices_low)
    )

    return preserved / (n_samples * k)

103 

104 

def _rank_based_neighborhood_score(
    X_neighbors: np.ndarray,
    X_ranks: np.ndarray,
    k: int
) -> float:
    """
    Shared core of :func:`trustworthiness` and :func:`continuity`.

    For every sample, take its k nearest neighbors in ``X_neighbors`` and
    penalise each neighbor whose distance rank ``r`` in ``X_ranks`` exceeds
    ``k`` by ``r - k``.  The summed penalty is divided by its largest
    attainable value so the score is 1 with no penalty and 0 in the worst
    case.
    """
    if X_neighbors.shape[0] != X_ranks.shape[0]:
        raise ValueError("X_high and X_low must have same number of samples")

    n_samples = X_neighbors.shape[0]
    if k < 0:
        raise ValueError(f"k={k} must be non-negative")
    if k >= n_samples:
        raise ValueError(f"k={k} must be less than n_samples={n_samples}")

    # Only min(k, n - 1 - k) neighbors per sample can carry a penalty: there
    # are k neighbor slots and n - 1 - k points ranked beyond k.
    n_worst = min(k, n_samples - 1 - k)
    if n_worst == 0:
        # k == 0 or k == n_samples - 1: no penalty is possible.
        return 1.0

    dist_neighbors = compute_distance_matrix(X_neighbors)
    dist_ranks = compute_distance_matrix(X_ranks)

    # Force every point to sort first in its own row (rank 0), even when
    # duplicate points produce other zero distances.
    np.fill_diagonal(dist_neighbors, -1.0)
    np.fill_diagonal(dist_ranks, -1.0)

    # k nearest neighbors (self in column 0 is skipped).
    order = np.argsort(dist_neighbors, axis=1, kind='stable')
    neighbors = order[:, 1:k + 1]

    # ranks[i, j] = position of j among points sorted by distance from i
    # (self has rank 0, the nearest other point rank 1).
    ranks = np.argsort(np.argsort(dist_ranks, axis=1, kind='stable'), axis=1)

    neighbor_ranks = np.take_along_axis(ranks, neighbors, axis=1)
    penalty = np.clip(neighbor_ranks - k, 0, None).sum()

    # Worst case per sample: the n_worst penalised neighbors hold the
    # largest ranks n - n_worst .. n - 1.  For k <= (n - 1) / 2 this equals
    # k * (2n - 3k - 1) / 2, the usual trustworthiness normaliser.
    worst_per_sample = n_worst * (2 * n_samples - n_worst - 1 - 2 * k) / 2

    return float(1.0 - penalty / (n_samples * worst_per_sample))


def trustworthiness(
    X_high: np.ndarray,
    X_low: np.ndarray,
    k: int = 10
) -> float:
    """
    Compute trustworthiness of the embedding.

    Trustworthiness measures how much we can trust that points nearby in the
    embedding are truly neighbors in the original space: each of the k
    nearest embedding neighbors is penalised by how far its rank in the
    original space lies beyond k.

    Parameters
    ----------
    X_high : np.ndarray
        Original high-dimensional data (n_samples, n_features_high)
    X_low : np.ndarray
        Low-dimensional embedding (n_samples, n_features_low)
    k : int
        Number of nearest neighbors to consider (k < n_samples)

    Returns
    -------
    float
        Trustworthiness score between 0 and 1 (1 means no intruding
        neighbors)

    Raises
    ------
    ValueError
        If the sample counts differ, ``k`` is negative, or
        ``k >= n_samples``.
    """
    # Neighborhoods come from the embedding, ranks from the original space.
    return _rank_based_neighborhood_score(X_low, X_high, k)


def continuity(
    X_high: np.ndarray,
    X_low: np.ndarray,
    k: int = 10
) -> float:
    """
    Compute continuity of the embedding.

    Continuity measures how well the embedding preserves the neighborhoods
    from the original space: each of the k nearest original-space neighbors
    is penalised by how far its rank in the embedding lies beyond k.

    Parameters
    ----------
    X_high : np.ndarray
        Original high-dimensional data (n_samples, n_features_high)
    X_low : np.ndarray
        Low-dimensional embedding (n_samples, n_features_low)
    k : int
        Number of nearest neighbors to consider (k < n_samples)

    Returns
    -------
    float
        Continuity score between 0 and 1 (1 means no neighbors were pushed
        away)

    Raises
    ------
    ValueError
        If the sample counts differ, ``k`` is negative, or
        ``k >= n_samples``.
    """
    # Neighborhoods come from the original space, ranks from the embedding.
    return _rank_based_neighborhood_score(X_high, X_low, k)

227 

228 

def geodesic_distance_correlation(
    X_high: np.ndarray,
    X_low: np.ndarray,
    k_neighbors: int = 10,
    method: str = 'spearman'
) -> float:
    """
    Compute correlation between geodesic distances on the manifold and
    Euclidean distances in the embedding.

    Uses a k-NN graph on ``X_high`` to approximate geodesic distances via
    shortest paths.  Every unordered pair of samples contributes once; pairs
    with no connecting path (infinite geodesic distance) are ignored.

    Parameters
    ----------
    X_high : np.ndarray
        Original high-dimensional data (n_samples, n_features_high)
    X_low : np.ndarray
        Low-dimensional embedding (n_samples, n_features_low)
    k_neighbors : int
        Number of neighbors for graph construction
    method : str
        Correlation method: 'spearman' selects Spearman rank correlation;
        any other value selects Pearson correlation

    Returns
    -------
    float
        Correlation coefficient between -1 and 1, or NaN when fewer than two
        connected pairs exist or the correlation is undefined

    Raises
    ------
    ValueError
        If ``X_high`` and ``X_low`` have different numbers of samples.
    """
    from sklearn.neighbors import kneighbors_graph
    from scipy.sparse.csgraph import shortest_path

    if X_high.shape[0] != X_low.shape[0]:
        raise ValueError("X_high and X_low must have same number of samples")

    # Build k-NN graph for geodesic approximation
    graph = kneighbors_graph(X_high, n_neighbors=k_neighbors, mode='distance')

    # Compute geodesic distances via shortest paths
    geodesic_dist = shortest_path(graph, directed=False)

    # Compute Euclidean distances in embedding
    euclidean_dist = compute_distance_matrix(X_low)

    # Always work on the strict upper triangle so each pair is counted once
    # and the zero diagonal never enters the correlation, whether or not the
    # graph is connected.
    pair_index = np.triu_indices_from(geodesic_dist, k=1)
    geodesic_flat = geodesic_dist[pair_index]
    euclidean_flat = euclidean_dist[pair_index]

    # Drop pairs lying in different connected components.
    reachable = np.isfinite(geodesic_flat)
    if not reachable.all():
        geodesic_flat = geodesic_flat[reachable]
        euclidean_flat = euclidean_flat[reachable]

    if geodesic_flat.size < 2:
        # A correlation needs at least two observations.
        return float('nan')

    # Compute correlation
    if method == 'spearman':
        corr, _ = spearmanr(geodesic_flat, euclidean_flat)
    else:  # pearson
        corr = np.corrcoef(geodesic_flat, euclidean_flat)[0, 1]

    return float(corr)

289 

290 

def stress(
    X_high: np.ndarray,
    X_low: np.ndarray,
    normalized: bool = True
) -> float:
    """
    Compute stress (sum of squared differences in distances).

    Both sums run over the full square distance matrices, so every
    unordered pair of samples is counted twice.

    Parameters
    ----------
    X_high : np.ndarray
        Original high-dimensional data (n_samples, n_features_high)
    X_low : np.ndarray
        Low-dimensional embedding (n_samples, n_features_low)
    normalized : bool
        If True, divide by the sum of squared original-space distances

    Returns
    -------
    float
        Stress value (lower is better).  In normalized mode with all
        original distances equal to zero, returns 0.0 when the embedding
        distances are also all zero and ``inf`` otherwise.

    Raises
    ------
    ValueError
        If ``X_high`` and ``X_low`` have different numbers of samples.
    """
    if X_high.shape[0] != X_low.shape[0]:
        raise ValueError("X_high and X_low must have same number of samples")

    # Compute distance matrices
    dist_high = compute_distance_matrix(X_high)
    dist_low = compute_distance_matrix(X_low)

    # Raw stress: squared mismatch summed over all matrix entries.
    stress_val = float(np.sum((dist_high - dist_low) ** 2))
    if not normalized:
        return stress_val

    total_high = float(np.sum(dist_high ** 2))
    if total_high == 0.0:
        # All original points coincide: avoid 0/0 and x/0.
        return 0.0 if stress_val == 0.0 else float('inf')

    return stress_val / total_high

325 

326 

def circular_structure_preservation(
    X_low: np.ndarray,
    true_angles: Optional[np.ndarray] = None,
    k_neighbors: int = 3
) -> dict:
    """
    Evaluate preservation of circular structure in embedding.

    Parameters
    ----------
    X_low : np.ndarray
        Low-dimensional embedding (n_samples, 2)
    true_angles : np.ndarray, optional
        True angles in radians, one per sample, if known (for synthetic
        data)
    k_neighbors : int
        Number of nearest neighbors searched for the angular predecessor or
        successor of each point (must be >= 1)

    Returns
    -------
    dict
        ``'distance_cv'``: coefficient of variation of the distances from
        the centroid (NaN if all points coincide with the centroid).
        ``'consecutive_preservation'``: fraction of points whose angular
        predecessor or successor is among their ``k_neighbors`` nearest
        neighbors.
        ``'circular_correlation'`` (only when ``true_angles`` is given):
        ``1 - mean(|wrapped angle difference|) / pi``.

    Raises
    ------
    ValueError
        If ``X_low`` is not of shape (n_samples, 2), if ``k_neighbors`` is
        smaller than 1, or if ``true_angles`` does not hold one angle per
        sample.
    """
    if X_low.ndim != 2 or X_low.shape[1] != 2:
        raise ValueError("Circular analysis requires 2D embedding")
    if k_neighbors < 1:
        raise ValueError(f"k_neighbors={k_neighbors} must be a positive integer")

    n_samples = X_low.shape[0]

    # Center the embedding
    centered = X_low - np.mean(X_low, axis=0)

    # Coefficient of variation of the radii (should be small for a circle)
    radii = np.linalg.norm(centered, axis=1)
    mean_radius = np.mean(radii)
    if mean_radius > 0:
        cv_distances = np.std(radii) / mean_radius
    else:
        # Every point sits on the centroid: the ratio is undefined.
        cv_distances = float('nan')

    # Angular position of every point and the resulting circular ordering
    angles = np.arctan2(centered[:, 1], centered[:, 0])
    angle_order = np.argsort(angles)

    # Inverse permutation: position[i] is where sample i sits in the angular
    # ordering.  This replaces a per-sample linear search.
    position = np.empty(n_samples, dtype=int)
    position[angle_order] = np.arange(n_samples)
    prev_idx = angle_order[(position - 1) % n_samples]
    next_idx = angle_order[(position + 1) % n_samples]

    # kneighbors() without a query excludes each point from its own
    # neighbor list, so no "drop column 0" assumption is needed.
    indices = (
        NearestNeighbors(n_neighbors=k_neighbors)
        .fit(X_low)
        .kneighbors(return_distance=False)
    )

    has_prev = (indices == prev_idx[:, None]).any(axis=1)
    has_next = (indices == next_idx[:, None]).any(axis=1)
    consecutive_preserved = int(np.count_nonzero(has_prev | has_next))

    results = {
        'distance_cv': cv_distances,
        'consecutive_preservation': consecutive_preserved / n_samples
    }

    # If true angles provided, compute angular agreement
    if true_angles is not None:
        true_angles = np.asarray(true_angles, dtype=float)
        if true_angles.shape != angles.shape:
            raise ValueError(
                f"true_angles must have shape {angles.shape}, "
                f"got {true_angles.shape}"
            )

        # Wrap the differences into [-pi, pi] to handle the discontinuity
        angle_diff = angles - true_angles
        angle_diff = np.arctan2(np.sin(angle_diff), np.cos(angle_diff))

        results['circular_correlation'] = 1 - np.mean(np.abs(angle_diff)) / np.pi

    return results

400 

401 

def procrustes_analysis(
    X: np.ndarray,
    Y: np.ndarray,
    scaling: bool = True,
    reflection: bool = True
) -> Tuple[np.ndarray, float]:
    """
    Perform Procrustes analysis to find optimal alignment.

    Finds the translation, orthogonal transform and (optionally) uniform
    scale that map ``Y`` onto ``X`` with the smallest sum of squared
    residuals.

    Parameters
    ----------
    X : np.ndarray
        Reference configuration (n_samples, n_features)
    Y : np.ndarray
        Configuration to be aligned (n_samples, n_features)
    scaling : bool
        Whether to allow scaling.  When True the least-squares optimal
        scale factor is applied.
    reflection : bool
        Whether to allow reflection.  When False the transform is
        restricted to proper rotations (determinant +1).

    Returns
    -------
    Y_aligned : np.ndarray
        Aligned version of Y, translated onto the centroid of X
    disparity : float
        Procrustes distance after alignment (Frobenius norm of the residual
        between the centered configurations)

    Raises
    ------
    ValueError
        If ``X`` and ``Y`` differ in shape or are not two-dimensional.
    """
    X = np.asarray(X, dtype=float)
    Y = np.asarray(Y, dtype=float)
    if X.shape != Y.shape:
        raise ValueError("X and Y must have the same shape")
    if X.ndim != 2:
        raise ValueError(f"X and Y must be 2D arrays, got shape {X.shape}")

    # Center configurations
    X_mean = np.mean(X, axis=0)
    X_centered = X - X_mean
    Y_centered = Y - np.mean(Y, axis=0)

    if reflection:
        # Best orthogonal map (rotation or reflection); the second return
        # value is the sum of singular values of Y_centered.T @ X_centered.
        R, sv_sum = orthogonal_procrustes(Y_centered, X_centered)
    else:
        # Kabsch solution: best *proper* rotation.  If the unconstrained
        # optimum is a reflection, flip the direction belonging to the
        # smallest singular value; mirroring a coordinate axis of the
        # already-aligned result is not optimal in general.
        U, singular_values, Vt = np.linalg.svd(Y_centered.T @ X_centered)
        if np.linalg.det(U @ Vt) < 0:
            U[:, -1] *= -1
            singular_values[-1] *= -1
        R = U @ Vt
        sv_sum = singular_values.sum()

    # Apply transformation
    Y_aligned = Y_centered @ R

    if scaling:
        # Least-squares scale: trace(R.T @ Y.T @ X) / ||Y||^2, where the
        # trace equals the (sign-corrected) sum of singular values.  Clamp
        # at zero so scaling can never act as a hidden reflection.
        y_energy = np.sum(Y_centered ** 2)
        if y_energy > 0:
            Y_aligned *= max(sv_sum, 0.0) / y_energy

    # Compute disparity
    disparity = float(np.sqrt(np.sum((X_centered - Y_aligned) ** 2)))

    # Return aligned points on the centroid of the reference configuration
    Y_aligned += X_mean

    return Y_aligned, disparity

463 

464 

def manifold_preservation_score(
    X_high: np.ndarray,
    X_low: np.ndarray,
    k_neighbors: int = 10,
    weights: Optional[dict] = None
) -> dict:
    """
    Compute comprehensive manifold preservation score.

    Evaluates k-NN preservation, trustworthiness, continuity and geodesic
    distance correlation, then folds them into one weighted sum.

    Parameters
    ----------
    X_high : np.ndarray
        Original high-dimensional data (n_samples, n_features_high)
    X_low : np.ndarray
        Low-dimensional embedding (n_samples, n_features_low)
    k_neighbors : int
        Number of neighbors for local metrics
    weights : dict, optional
        Weight per metric name; metrics missing from the dict get weight 0
        (default: 0.25 for each of the four metrics)

    Returns
    -------
    dict
        The four individual metrics plus ``'overall_score'``
    """
    active_weights = weights
    if active_weights is None:
        active_weights = {
            'knn_preservation': 0.25,
            'trustworthiness': 0.25,
            'continuity': 0.25,
            'geodesic_correlation': 0.25
        }

    # Evaluate the individual metrics one after another.
    knn_rate = knn_preservation_rate(X_high, X_low, k=k_neighbors)
    trust = trustworthiness(X_high, X_low, k=k_neighbors)
    cont = continuity(X_high, X_low, k=k_neighbors)
    geo_corr = geodesic_distance_correlation(
        X_high, X_low, k_neighbors=k_neighbors
    )

    # An undefined geodesic correlation counts as zero.
    if np.isnan(geo_corr):
        geo_corr = 0.0

    scores = {
        'knn_preservation': knn_rate,
        'trustworthiness': trust,
        'continuity': cont,
        'geodesic_correlation': geo_corr
    }

    # Weighted sum over the metrics, in insertion order.
    overall = 0
    for metric_name, metric_value in scores.items():
        overall = overall + metric_value * active_weights.get(metric_name, 0)

    scores['overall_score'] = overall
    return scores

524 

525 

526# ============================================================================= 

527# MANIFOLD RECONSTRUCTION VALIDATION 

528# ============================================================================= 

529 

def circular_distance(angles1: np.ndarray, angles2: np.ndarray) -> np.ndarray:
    """Return the absolute angular separation between paired angles.

    Parameters
    ----------
    angles1, angles2 : np.ndarray
        Arrays of angles in radians

    Returns
    -------
    np.ndarray
        Circular distances between corresponding angles, each in [0, pi]
    """
    delta = angles1 - angles2
    # atan2(sin, cos) folds the raw difference back into [-pi, pi].
    wrapped = np.arctan2(np.sin(delta), np.cos(delta))
    return np.abs(wrapped)

545 

546 

def extract_angles_from_embedding(embedding: np.ndarray) -> np.ndarray:
    """Return the polar angle of every point around the embedding centroid.

    Parameters
    ----------
    embedding : np.ndarray
        2D embedding with shape (n_timepoints, 2)

    Returns
    -------
    np.ndarray
        Extracted angles in radians, as produced by ``np.arctan2``

    Raises
    ------
    ValueError
        If the embedding does not have exactly two columns.
    """
    n_columns = embedding.shape[1]
    if n_columns != 2:
        raise ValueError("Embedding must be 2D for angle extraction")

    # Shift each coordinate so the centroid sits at the origin.
    centroid = np.mean(embedding, axis=0)
    x_offset = embedding[:, 0] - centroid[0]
    y_offset = embedding[:, 1] - centroid[1]

    return np.arctan2(y_offset, x_offset)

570 

571 

def compute_reconstruction_error(
    embedding: np.ndarray,
    true_variable: np.ndarray,
    manifold_type: str = 'circular'
) -> float:
    """Measure the mean mismatch between an embedding and the ground truth.

    Parameters
    ----------
    embedding : np.ndarray
        Low-dimensional embedding
    true_variable : np.ndarray
        Ground truth variable (angles or positions)
    manifold_type : str
        Type of manifold ('circular' or 'spatial')

    Returns
    -------
    float
        Mean circular distance between embedding angles and true angles
        ('circular'), or mean Euclidean distance between the
        Procrustes-aligned embedding and the true positions ('spatial')

    Raises
    ------
    ValueError
        If ``manifold_type`` is neither 'circular' nor 'spatial'.
    """
    if manifold_type not in ('circular', 'spatial'):
        raise ValueError(f"Unknown manifold type: {manifold_type}")

    if manifold_type == 'circular':
        # Compare the angle of each embedded point with the true angle.
        per_sample = circular_distance(
            extract_angles_from_embedding(embedding), true_variable
        )
    else:
        # Align the embedding to the true positions first, then measure the
        # remaining point-wise offsets.
        aligned, _ = procrustes_analysis(true_variable, embedding)
        per_sample = np.linalg.norm(aligned - true_variable, axis=1)

    return np.mean(per_sample)

612 

613 

def compute_temporal_consistency(
    embedding: np.ndarray,
    true_variable: np.ndarray,
    manifold_type: str = 'circular'
) -> float:
    """Correlate step-to-step changes of the embedding with the ground truth.

    Parameters
    ----------
    embedding : np.ndarray
        Low-dimensional embedding
    true_variable : np.ndarray
        Ground truth variable (angles or positions)
    manifold_type : str
        Type of manifold ('circular' or 'spatial')

    Returns
    -------
    float
        Pearson correlation between true and reconstructed wrapped angular
        steps ('circular') or step lengths ('spatial'); 0.0 when the
        correlation is NaN

    Raises
    ------
    ValueError
        If ``manifold_type`` is neither 'circular' nor 'spatial'.
    """
    if manifold_type == 'circular':
        decoded_angles = extract_angles_from_embedding(embedding)

        # Successive differences, wrapped into [-pi, pi] so a crossing of
        # the +/-pi boundary is not seen as a huge jump.
        wrapped_steps = []
        for series in (true_variable, decoded_angles):
            step = np.diff(series)
            wrapped_steps.append(np.arctan2(np.sin(step), np.cos(step)))
        true_rate, decoded_rate = wrapped_steps

    elif manifold_type == 'spatial':
        # Align the embedding to the true positions before differencing.
        aligned, _ = procrustes_analysis(true_variable, embedding)

        # Step lengths (speeds) along each trajectory.
        true_rate = np.linalg.norm(np.diff(true_variable, axis=0), axis=1)
        decoded_rate = np.linalg.norm(np.diff(aligned, axis=0), axis=1)

    else:
        raise ValueError(f"Unknown manifold type: {manifold_type}")

    r_value = np.corrcoef(true_rate, decoded_rate)[0, 1]
    if np.isnan(r_value):
        return 0.0
    return r_value

668 

669 

def train_simple_decoder(embedding: np.ndarray, true_variable: np.ndarray, manifold_type: str = 'circular'):
    """Fit a linear decoder that maps embedding coordinates to the variable.

    The embedding is standardized with a ``StandardScaler`` fitted on the
    training data; the same scaler is reused inside the returned decoder.

    Parameters
    ----------
    embedding : np.ndarray
        Low-dimensional embedding with shape (n_timepoints, n_features)
    true_variable : np.ndarray
        Ground truth variable (angles or positions)
    manifold_type : str
        Type of manifold ('circular' or 'spatial')

    Returns
    -------
    callable
        Function taking a new embedding and returning predicted angles
        ('circular', via ``arctan2`` of regressed sine and cosine) or
        predicted positions ('spatial')

    Raises
    ------
    ValueError
        If the numbers of timepoints differ or ``manifold_type`` is unknown.
    """
    from sklearn.linear_model import LinearRegression
    from sklearn.preprocessing import StandardScaler

    n_embedded, n_targets = embedding.shape[0], true_variable.shape[0]
    if n_embedded != n_targets:
        raise ValueError(f"Embedding and true_variable must have same number of timepoints. "
                         f"Got embedding: {embedding.shape}, true_variable: {true_variable.shape}")

    # Standardize the features; the fitted scaler is captured by the decoder.
    scaler = StandardScaler()
    features = scaler.fit_transform(embedding)

    if manifold_type == 'circular':
        # Regress sine and cosine separately so the target is continuous
        # across the angular wrap-around.
        sin_model = LinearRegression().fit(features, np.sin(true_variable))
        cos_model = LinearRegression().fit(features, np.cos(true_variable))

        def decoder(new_embedding):
            scaled = scaler.transform(new_embedding)
            sin_hat = sin_model.predict(scaled)
            cos_hat = cos_model.predict(scaled)
            return np.arctan2(sin_hat, cos_hat)

    elif manifold_type == 'spatial':
        # Positions are regressed directly.
        position_model = LinearRegression().fit(features, true_variable)

        def decoder(new_embedding):
            return position_model.predict(scaler.transform(new_embedding))

    else:
        raise ValueError(f"Unknown manifold type: {manifold_type}")

    return decoder

726 

727 

def compute_decoding_accuracy(
    embedding: np.ndarray,
    true_variable: np.ndarray,
    manifold_type: str = 'circular',
    train_fraction: float = 0.8
) -> dict:
    """Compute decoding accuracy using train/test split

    The first ``int(n_samples * train_fraction)`` samples train a decoder
    (see ``train_simple_decoder``); the remaining samples test it.  Training
    and testing errors are both computed from the decoder's predictions, so
    their difference is a meaningful generalization gap.

    Parameters
    ----------
    embedding : np.ndarray
        Low-dimensional embedding
    true_variable : np.ndarray
        Ground truth variable (angles or positions)
    manifold_type : str
        Type of manifold ('circular' or 'spatial')
    train_fraction : float
        Fraction of data to use for training

    Returns
    -------
    dict
        ``'train_error'``, ``'test_error'`` and ``'generalization_gap'``
        (test error minus train error).  Errors are mean circular distances
        for 'circular' and mean Euclidean distances for 'spatial'.

    Raises
    ------
    ValueError
        If the split leaves the training or the test set empty, or if
        ``manifold_type`` is unknown (raised by ``train_simple_decoder``).
    """
    n_samples = embedding.shape[0]
    n_train = int(n_samples * train_fraction)
    if not 0 < n_train < n_samples:
        raise ValueError(
            f"train_fraction={train_fraction} leaves an empty train or test "
            f"set for n_samples={n_samples}"
        )

    # Split data (contiguous split, no shuffling)
    train_embedding = embedding[:n_train]
    test_embedding = embedding[n_train:]
    train_variable = true_variable[:n_train]
    test_variable = true_variable[n_train:]

    # Train decoder
    decoder = train_simple_decoder(train_embedding, train_variable, manifold_type)

    def _mean_error(predictions, targets):
        # Mean prediction error in the metric matching the manifold type.
        if manifold_type == 'circular':
            return np.mean(circular_distance(predictions, targets))
        residual = np.asarray(predictions) - np.asarray(targets)
        if residual.ndim == 1:
            # One-dimensional positions: treat each value as a 1-D point.
            residual = residual[:, np.newaxis]
        return np.mean(np.linalg.norm(residual, axis=1))

    # Both errors are measured on decoder output.  Using the raw embedding
    # angles for the training error would compare two different quantities.
    train_error = _mean_error(decoder(train_embedding), train_variable)
    test_error = _mean_error(decoder(test_embedding), test_variable)

    return {
        'train_error': train_error,
        'test_error': test_error,
        'generalization_gap': test_error - train_error
    }

784 

785 

def manifold_reconstruction_score(
    embedding: np.ndarray,
    true_variable: np.ndarray,
    manifold_type: str = 'circular',
    weights: Optional[dict] = None
) -> dict:
    """Combine reconstruction, temporal and decoding metrics into one score.

    Parameters
    ----------
    embedding : np.ndarray
        Low-dimensional embedding
    true_variable : np.ndarray
        Ground truth variable (angles or positions)
    manifold_type : str
        Type of manifold ('circular' or 'spatial')
    weights : dict, optional
        Weights under the keys 'reconstruction_error',
        'temporal_consistency' and 'decoding_accuracy'
        (default: 0.4, 0.3, 0.3)

    Returns
    -------
    dict
        Raw reconstruction error, temporal consistency clipped at zero,
        decoding train/test errors, generalization gap and the weighted
        ``'overall_reconstruction_score'``
    """
    if weights is None:
        weights = {
            'reconstruction_error': 0.4,
            'temporal_consistency': 0.3,
            'decoding_accuracy': 0.3
        }

    # Raw metrics
    recon_error = compute_reconstruction_error(embedding, true_variable, manifold_type)
    consistency = compute_temporal_consistency(embedding, true_variable, manifold_type)
    decoding = compute_decoding_accuracy(embedding, true_variable, manifold_type)

    # Errors are turned into "higher is better" qualities: divide by the
    # ceiling (pi for angles, 1.0 otherwise), cap at 1 and invert.
    is_circular = manifold_type == 'circular'
    recon_ceiling = np.pi if is_circular else 1.0
    recon_quality = 1.0 - min(recon_error / recon_ceiling, 1.0)

    decode_ceiling = np.pi if is_circular else 1.0
    decode_quality = 1.0 - min(decoding['test_error'] / decode_ceiling, 1.0)

    # Negative correlations contribute nothing to the score.
    consistency = max(consistency, 0.0)

    overall = (
        weights['reconstruction_error'] * recon_quality +
        weights['temporal_consistency'] * consistency +
        weights['decoding_accuracy'] * decode_quality
    )

    summary = {'reconstruction_error': recon_error}
    summary['temporal_consistency'] = consistency
    summary['decoding_train_error'] = decoding['train_error']
    summary['decoding_test_error'] = decoding['test_error']
    summary['generalization_gap'] = decoding['generalization_gap']
    summary['overall_reconstruction_score'] = overall
    return summary