Coverage for src/driada/dim_reduction/manifold_metrics.py: 61.22%
245 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-25 15:40 +0300
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-25 15:40 +0300
"""
Spatial Correspondence Metrics for Manifold Preservation Evaluation

This module provides comprehensive metrics for evaluating how well dimensionality
reduction methods preserve manifold structure.

Key metric categories:
---------------------
1. Neighborhood preservation: How well are local neighborhoods preserved?
2. Distance preservation: How well are geodesic distances preserved?
3. Topology preservation: How well is the global structure preserved?
4. Shape matching: How similar are the shapes after optimal alignment?
5. Reconstruction validation: How well can a known latent variable (an angle
   or a position) be recovered from the embedding?
"""
15import numpy as np
16from sklearn.neighbors import NearestNeighbors
17from scipy.spatial.distance import pdist, squareform
18from scipy.stats import spearmanr
19from scipy.linalg import orthogonal_procrustes
20from typing import Optional, Tuple, Union
def compute_distance_matrix(X: np.ndarray, metric: str = 'euclidean') -> np.ndarray:
    """
    Build the full pairwise distance matrix for the rows of ``X``.

    Parameters
    ----------
    X : np.ndarray
        Data matrix of shape (n_samples, n_features).
    metric : str
        Any metric name accepted by ``scipy.spatial.distance.pdist``
        (default: 'euclidean').

    Returns
    -------
    np.ndarray
        Square, symmetric matrix of shape (n_samples, n_samples) whose entry
        (i, j) is the distance between rows i and j.

    Raises
    ------
    ValueError
        If ``X`` is not two-dimensional.
    """
    if X.ndim == 2:
        # pdist yields the condensed upper triangle; expand it to square form.
        return squareform(pdist(X, metric=metric))
    raise ValueError(f"X must be 2D array, got shape {X.shape}")
def knn_preservation_rate(
    X_high: np.ndarray,
    X_low: np.ndarray,
    k: int = 10,
    flexible: bool = False,
    flexibility_factor: float = 2.0
) -> float:
    """
    Compute k-nearest neighbor preservation rate.

    This metric measures what fraction of k nearest neighbors in the original
    high-dimensional space are preserved in the low-dimensional embedding.

    Parameters
    ----------
    X_high : np.ndarray
        Original high-dimensional data (n_samples, n_features_high)
    X_low : np.ndarray
        Low-dimensional embedding (n_samples, n_features_low)
    k : int
        Number of nearest neighbors to consider (1 <= k < n_samples)
    flexible : bool
        If True, check if k-NN are within (k * flexibility_factor)-NN in embedding
    flexibility_factor : float
        Factor to multiply k for flexible matching (default: 2.0). The
        resulting neighbor count is capped at n_samples - 1.

    Returns
    -------
    float
        Preservation rate between 0 and 1: the number of original k-NN found
        among the embedded neighbors, divided by n_samples * k.

    Raises
    ------
    ValueError
        If the sample counts differ, or ``k`` is not in [1, n_samples - 1].
    """
    if X_high.shape[0] != X_low.shape[0]:
        raise ValueError("X_high and X_low must have same number of samples")

    n_samples = X_high.shape[0]
    if k >= n_samples:
        raise ValueError(f"k={k} must be less than n_samples={n_samples}")
    if k < 1:
        raise ValueError(f"k={k} must be at least 1")

    # Calling kneighbors() without a query set makes scikit-learn exclude each
    # point from its own neighbor list by index. Dropping column 0 of
    # kneighbors(X) instead is wrong when duplicate points exist, because a
    # duplicate may be returned ahead of the point itself.
    indices_high = (
        NearestNeighbors(n_neighbors=k).fit(X_high).kneighbors(return_distance=False)
    )

    # Neighbor-set size in the embedding; at most every other point.
    k_low = int(k * flexibility_factor) if flexible else k
    k_low = min(k_low, n_samples - 1)
    if k_low < 1:
        # An empty embedded neighbor set cannot preserve anything.
        return 0.0
    indices_low = (
        NearestNeighbors(n_neighbors=k_low).fit(X_low).kneighbors(return_distance=False)
    )

    preserved = sum(
        len(set(high_row).intersection(low_row))
        for high_row, low_row in zip(indices_high, indices_low)
    )
    return preserved / (n_samples * k)
def trustworthiness(
    X_high: np.ndarray,
    X_low: np.ndarray,
    k: int = 10
) -> float:
    """
    Compute trustworthiness of the embedding (Venna & Kaski).

    Trustworthiness measures how much we can trust that points nearby in the
    embedding are truly neighbors in the original space. For every point i
    and each of its k nearest neighbors j in the embedding, the 1-based rank
    r(i, j) of j among i's neighbors in the original space is looked up, and
    ranks beyond k contribute a penalty r(i, j) - k. The total penalty is
    divided by the largest attainable penalty, so the score lies in [0, 1].
    For k < n_samples / 2 this equals the usual normalization
    2 / (n * k * (2n - 3k - 1)) used by ``sklearn.manifold.trustworthiness``.

    Parameters
    ----------
    X_high : np.ndarray
        Original high-dimensional data (n_samples, n_features_high)
    X_low : np.ndarray
        Low-dimensional embedding (n_samples, n_features_low)
    k : int
        Number of nearest neighbors to consider (1 <= k < n_samples)

    Returns
    -------
    float
        Trustworthiness score between 0 and 1 (1 = no intruding neighbors)

    Raises
    ------
    ValueError
        If the sample counts differ, or ``k`` is not in [1, n_samples - 1].
    """
    if X_high.shape[0] != X_low.shape[0]:
        raise ValueError("X_high and X_low must have same number of samples")

    n_samples = X_high.shape[0]
    if k >= n_samples:
        raise ValueError(f"k={k} must be less than n_samples={n_samples}")
    if k < 1:
        raise ValueError(f"k={k} must be at least 1")

    # 1-based neighbor ranks in the original space (nearest neighbor = 1).
    # The +inf diagonal sends each point to the last rank of its own row, so a
    # duplicate of the point cannot be mistaken for "self".
    dist_high = compute_distance_matrix(X_high)
    np.fill_diagonal(dist_high, np.inf)
    rows = np.arange(n_samples)[:, np.newaxis]
    ranks_high = np.empty((n_samples, n_samples), dtype=np.int64)
    ranks_high[rows, np.argsort(dist_high, axis=1)] = np.arange(1, n_samples + 1)

    # k-NN in the embedding; kneighbors() without a query set excludes self.
    indices_low = (
        NearestNeighbors(n_neighbors=k).fit(X_low).kneighbors(return_distance=False)
    )

    excess = ranks_high[rows, indices_low] - k
    penalty = float(excess[excess > 0].sum())

    # Worst case: every embedded neighbor comes from the highest original
    # ranks max(k + 1, n - k) .. n - 1; sum of (rank - k) over that range.
    first_rank = max(k + 1, n_samples - k)
    n_worst = n_samples - first_rank
    max_penalty = n_samples * n_worst * ((first_rank - k) + (n_samples - 1 - k)) / 2
    if max_penalty <= 0:
        return 1.0
    return 1.0 - penalty / max_penalty
def continuity(
    X_high: np.ndarray,
    X_low: np.ndarray,
    k: int = 10
) -> float:
    """
    Compute continuity of the embedding (Venna & Kaski).

    Continuity measures how well the embedding preserves the neighborhoods
    from the original space. For every point i and each of its k nearest
    neighbors j in the original space, the 1-based rank r(i, j) of j among
    i's neighbors in the embedding is looked up, and ranks beyond k
    contribute a penalty r(i, j) - k. The total penalty is divided by the
    largest attainable penalty, so the score lies in [0, 1]. For
    k < n_samples / 2 this equals the usual normalization
    2 / (n * k * (2n - 3k - 1)).

    Parameters
    ----------
    X_high : np.ndarray
        Original high-dimensional data (n_samples, n_features_high)
    X_low : np.ndarray
        Low-dimensional embedding (n_samples, n_features_low)
    k : int
        Number of nearest neighbors to consider (1 <= k < n_samples)

    Returns
    -------
    float
        Continuity score between 0 and 1 (1 = no missing neighbors)

    Raises
    ------
    ValueError
        If the sample counts differ, or ``k`` is not in [1, n_samples - 1].
    """
    if X_high.shape[0] != X_low.shape[0]:
        raise ValueError("X_high and X_low must have same number of samples")

    n_samples = X_high.shape[0]
    if k >= n_samples:
        raise ValueError(f"k={k} must be less than n_samples={n_samples}")
    if k < 1:
        raise ValueError(f"k={k} must be at least 1")

    # 1-based neighbor ranks in the embedding (nearest neighbor = 1). The +inf
    # diagonal sends each point to the last rank of its own row, so a
    # duplicate of the point cannot be mistaken for "self".
    dist_low = compute_distance_matrix(X_low)
    np.fill_diagonal(dist_low, np.inf)
    rows = np.arange(n_samples)[:, np.newaxis]
    ranks_low = np.empty((n_samples, n_samples), dtype=np.int64)
    ranks_low[rows, np.argsort(dist_low, axis=1)] = np.arange(1, n_samples + 1)

    # k-NN in the original space; kneighbors() without a query set excludes self.
    indices_high = (
        NearestNeighbors(n_neighbors=k).fit(X_high).kneighbors(return_distance=False)
    )

    excess = ranks_low[rows, indices_high] - k
    penalty = float(excess[excess > 0].sum())

    # Worst case: every original neighbor lands at the highest embedded
    # ranks max(k + 1, n - k) .. n - 1; sum of (rank - k) over that range.
    first_rank = max(k + 1, n_samples - k)
    n_worst = n_samples - first_rank
    max_penalty = n_samples * n_worst * ((first_rank - k) + (n_samples - 1 - k)) / 2
    if max_penalty <= 0:
        return 1.0
    return 1.0 - penalty / max_penalty
def geodesic_distance_correlation(
    X_high: np.ndarray,
    X_low: np.ndarray,
    k_neighbors: int = 10,
    method: str = 'spearman'
) -> float:
    """
    Compute correlation between geodesic distances on the manifold and
    Euclidean distances in the embedding.

    Geodesic distances are approximated by shortest paths through the k-NN
    graph of ``X_high``. Each unordered pair of samples enters the correlation
    exactly once (upper triangle, diagonal excluded). Pairs lying in different
    connected components of the graph (infinite geodesic distance) are skipped.

    Parameters
    ----------
    X_high : np.ndarray
        Original high-dimensional data (n_samples, n_features_high)
    X_low : np.ndarray
        Low-dimensional embedding (n_samples, n_features_low)
    k_neighbors : int
        Number of neighbors for graph construction
    method : str
        Correlation method ('spearman' or 'pearson')

    Returns
    -------
    float
        Correlation coefficient between -1 and 1, or NaN when fewer than two
        finite pairs remain or the correlation is otherwise undefined.

    Raises
    ------
    ValueError
        If the sample counts differ or ``method`` is not recognized.
    """
    from sklearn.neighbors import kneighbors_graph
    from scipy.sparse.csgraph import shortest_path

    if X_high.shape[0] != X_low.shape[0]:
        raise ValueError("X_high and X_low must have same number of samples")
    if method not in ('spearman', 'pearson'):
        raise ValueError(f"Unknown correlation method: {method}")

    # Build k-NN graph for geodesic approximation; shortest paths through it
    # approximate distances along the manifold.
    graph = kneighbors_graph(X_high, n_neighbors=k_neighbors, mode='distance')
    geodesic_dist = shortest_path(graph, directed=False)
    euclidean_dist = compute_distance_matrix(X_low)

    # Take every unordered pair once, in the same way whether or not the graph
    # is connected. The earlier full finite mask also kept the zero diagonal
    # and counted each pair twice.
    upper = np.triu_indices(X_high.shape[0], k=1)
    geodesic_flat = geodesic_dist[upper]
    euclidean_flat = euclidean_dist[upper]

    # Disconnected components yield inf geodesic distances; drop those pairs.
    finite = np.isfinite(geodesic_flat)
    geodesic_flat = geodesic_flat[finite]
    euclidean_flat = euclidean_flat[finite]

    if geodesic_flat.size < 2:
        return float('nan')

    if method == 'spearman':
        corr, _ = spearmanr(geodesic_flat, euclidean_flat)
    else:
        corr = np.corrcoef(geodesic_flat, euclidean_flat)[0, 1]
    return float(corr)
def stress(
    X_high: np.ndarray,
    X_low: np.ndarray,
    normalized: bool = True
) -> float:
    """
    Compute stress (sum of squared differences in distances).

    The sum runs over the full symmetric distance matrices, so the raw value
    counts each unordered pair twice. The normalized value is unaffected,
    because the denominator counts pairs the same way.

    Parameters
    ----------
    X_high : np.ndarray
        Original high-dimensional data (n_samples, n_features_high)
    X_low : np.ndarray
        Low-dimensional embedding (n_samples, n_features_low)
    normalized : bool
        If True, normalize by sum of squared distances in ``X_high``

    Returns
    -------
    float
        Stress value (lower is better). In normalized mode, if all points of
        ``X_high`` coincide, returns 0.0 when the embedding is equally
        collapsed and ``inf`` otherwise.

    Raises
    ------
    ValueError
        If ``X_high`` and ``X_low`` have different numbers of samples.
    """
    if X_high.shape[0] != X_low.shape[0]:
        raise ValueError("X_high and X_low must have same number of samples")

    dist_high = compute_distance_matrix(X_high)
    dist_low = compute_distance_matrix(X_low)

    stress_val = float(np.sum((dist_high - dist_low) ** 2))
    if not normalized:
        return stress_val

    scale = float(np.sum(dist_high ** 2))
    if scale == 0.0:
        # Degenerate reference configuration: avoid 0/0 (NaN) and x/0
        # RuntimeWarnings and return a well-defined value instead.
        return 0.0 if stress_val == 0.0 else float('inf')
    return stress_val / scale
def circular_structure_preservation(
    X_low: np.ndarray,
    true_angles: Optional[np.ndarray] = None,
    k_neighbors: int = 3
) -> dict:
    """
    Evaluate preservation of circular structure in a 2-D embedding.

    Parameters
    ----------
    X_low : np.ndarray
        Low-dimensional embedding (n_samples, 2)
    true_angles : np.ndarray, optional
        True angles in radians if known (for synthetic data)
    k_neighbors : int
        Number of nearest neighbors (excluding the point itself) searched for
        each point's angular predecessor/successor; must be < n_samples.

    Returns
    -------
    dict
        'distance_cv': coefficient of variation of the distances to the
        centroid (small for a circle).
        'consecutive_preservation': fraction of points whose angular
        predecessor or successor is among their k nearest neighbors.
        'circular_correlation': present only when ``true_angles`` is given;
        1 - mean(|wrapped angle difference|) / pi. Embedding angles are
        compared to ``true_angles`` directly, with no rotation-offset or
        direction fitting.

    Raises
    ------
    ValueError
        If ``X_low`` does not have exactly two columns.
    """
    if X_low.shape[1] != 2:
        raise ValueError("Circular analysis requires 2D embedding")

    n_samples = X_low.shape[0]

    # Center the embedding and measure how uniform the radius is.
    centered = X_low - np.mean(X_low, axis=0)
    distances = np.linalg.norm(centered, axis=1)
    cv_distances = np.std(distances) / np.mean(distances)

    angles = np.arctan2(centered[:, 1], centered[:, 0])

    # angle_order[p] is the point at angular position p. position is its
    # inverse permutation, which replaces an O(n) np.where search per point.
    angle_order = np.argsort(angles)
    position = np.empty(n_samples, dtype=np.intp)
    position[angle_order] = np.arange(n_samples)
    prev_idx = angle_order[(position - 1) % n_samples]
    next_idx = angle_order[(position + 1) % n_samples]

    # kneighbors() without a query set excludes each point from its own list,
    # which stays correct even when duplicate points exist.
    indices = (
        NearestNeighbors(n_neighbors=k_neighbors)
        .fit(X_low)
        .kneighbors(return_distance=False)
    )

    hits = (indices == prev_idx[:, np.newaxis]) | (indices == next_idx[:, np.newaxis])
    consecutive_preserved = int(np.count_nonzero(hits.any(axis=1)))

    results = {
        'distance_cv': cv_distances,
        'consecutive_preservation': consecutive_preserved / n_samples
    }

    if true_angles is not None:
        # Wrap differences into [-pi, pi] to handle the 2*pi discontinuity.
        angle_diff = angles - true_angles
        angle_diff = np.arctan2(np.sin(angle_diff), np.cos(angle_diff))
        results['circular_correlation'] = 1 - np.mean(np.abs(angle_diff)) / np.pi

    return results
def procrustes_analysis(
    X: np.ndarray,
    Y: np.ndarray,
    scaling: bool = True,
    reflection: bool = True
) -> Tuple[np.ndarray, float]:
    """
    Perform Procrustes analysis to find the optimal alignment of Y onto X.

    Both configurations are centered. ``Y`` is then rotated, and uniformly
    scaled when ``scaling`` is True, to minimize the Frobenius norm
    ``||X_c - s * Y_c @ R||``. The result is translated onto the centroid
    of ``X``.

    Parameters
    ----------
    X : np.ndarray
        Reference configuration (n_samples, n_features)
    Y : np.ndarray
        Configuration to be aligned (n_samples, n_features)
    scaling : bool
        Whether to allow scaling. The scale used is the least-squares optimum
        trace(R.T @ Y_c.T @ X_c) / ||Y_c||_F^2.
    reflection : bool
        Whether the orthogonal transform may be improper (det = -1). When
        False, the best proper rotation is used (Kabsch sign correction).

    Returns
    -------
    Y_aligned : np.ndarray
        Aligned version of Y, expressed around the centroid of X
    disparity : float
        Frobenius norm of the residual between centered X and aligned Y

    Raises
    ------
    ValueError
        If ``X`` and ``Y`` have different shapes.
    """
    if X.shape != Y.shape:
        raise ValueError("X and Y must have the same shape")

    X_mean = np.mean(X, axis=0)
    X_centered = X - X_mean
    Y_centered = Y - np.mean(Y, axis=0)

    # SVD of the cross-covariance; the unrestricted optimum is R = U @ Vt.
    U, S, Vt = np.linalg.svd(Y_centered.T @ X_centered)
    signs = np.ones_like(S)
    if not reflection and np.linalg.det(U @ Vt) < 0:
        # Flip the direction paired with the smallest singular value. This is
        # the optimal proper rotation; flipping an output axis after the fact
        # is not.
        signs[-1] = -1.0
    R = (U * signs) @ Vt
    Y_aligned = Y_centered @ R

    if scaling:
        norm_sq = np.sum(Y_centered ** 2)
        if norm_sq > 0:
            Y_aligned = Y_aligned * (np.sum(S * signs) / norm_sq)

    disparity = float(np.sqrt(np.sum((X_centered - Y_aligned) ** 2)))

    # Return aligned points around the reference centroid.
    return Y_aligned + X_mean, disparity
def manifold_preservation_score(
    X_high: np.ndarray,
    X_low: np.ndarray,
    k_neighbors: int = 10,
    weights: Optional[dict] = None
) -> dict:
    """
    Aggregate several neighborhood and distance metrics into one score.

    The individual metrics are k-NN preservation, trustworthiness, continuity
    and geodesic distance correlation, all computed with ``k_neighbors``.
    A NaN geodesic correlation is replaced by 0.0 before aggregation.

    Parameters
    ----------
    X_high : np.ndarray
        Original high-dimensional data (n_samples, n_features_high)
    X_low : np.ndarray
        Low-dimensional embedding (n_samples, n_features_low)
    k_neighbors : int
        Number of neighbors for local metrics
    weights : dict, optional
        Mapping from metric name to weight. Metrics absent from the mapping get
        weight 0. Defaults to 0.25 for each of the four metrics.

    Returns
    -------
    dict
        The four metric values plus 'overall_score', their weighted sum.
    """
    if weights is None:
        weights = dict.fromkeys(
            ('knn_preservation', 'trustworthiness', 'continuity', 'geodesic_correlation'),
            0.25,
        )

    knn = knn_preservation_rate(X_high, X_low, k=k_neighbors)
    trust = trustworthiness(X_high, X_low, k=k_neighbors)
    cont = continuity(X_high, X_low, k=k_neighbors)
    geo = geodesic_distance_correlation(X_high, X_low, k_neighbors=k_neighbors)
    if np.isnan(geo):
        # Undefined correlation (e.g. degenerate graph) counts as no agreement.
        geo = 0.0

    metrics = {
        'knn_preservation': knn,
        'trustworthiness': trust,
        'continuity': cont,
        'geodesic_correlation': geo,
    }

    overall = 0
    for name, value in metrics.items():
        overall += value * weights.get(name, 0)
    metrics['overall_score'] = overall

    return metrics
# =============================================================================
# MANIFOLD RECONSTRUCTION VALIDATION
# =============================================================================
def circular_distance(angles1: np.ndarray, angles2: np.ndarray) -> np.ndarray:
    """Return the absolute angular separation between paired angles

    Parameters:
    -----------
    angles1, angles2 : np.ndarray
        Arrays of angles in radians

    Returns:
    --------
    np.ndarray
        Element-wise shortest distance around the circle, in [0, pi]
    """
    delta = angles1 - angles2
    # arctan2(sin, cos) wraps the raw difference into [-pi, pi].
    wrapped = np.arctan2(np.sin(delta), np.cos(delta))
    return np.abs(wrapped)
def extract_angles_from_embedding(embedding: np.ndarray) -> np.ndarray:
    """Read off each point's polar angle around the embedding centroid

    Parameters:
    -----------
    embedding : np.ndarray
        2D embedding with shape (n_timepoints, 2)

    Returns:
    --------
    np.ndarray
        Angles in radians in [-pi, pi], one per row of ``embedding``
    """
    if embedding.shape[1] == 2:
        x_rel, y_rel = (embedding - np.mean(embedding, axis=0)).T
        return np.arctan2(y_rel, x_rel)
    raise ValueError("Embedding must be 2D for angle extraction")
def compute_reconstruction_error(
    embedding: np.ndarray,
    true_variable: np.ndarray,
    manifold_type: str = 'circular'
) -> float:
    """Measure how far the embedding is from the ground-truth variable

    Parameters:
    -----------
    embedding : np.ndarray
        Low-dimensional embedding
    true_variable : np.ndarray
        Ground truth variable (angles in radians, or positions)
    manifold_type : str
        'circular': angles are extracted around the embedding centroid and
        compared to ``true_variable`` directly. No rotation offset or direction
        is fitted.
        'spatial': the embedding is Procrustes-aligned to ``true_variable``
        first, then point-wise Euclidean distances are averaged.

    Returns:
    --------
    float
        Mean circular distance (circular) or mean Euclidean distance (spatial)

    Raises:
    -------
    ValueError
        If ``manifold_type`` is not 'circular' or 'spatial'.
    """
    if manifold_type not in ('circular', 'spatial'):
        raise ValueError(f"Unknown manifold type: {manifold_type}")

    if manifold_type == 'spatial':
        aligned, _ = procrustes_analysis(true_variable, embedding)
        return np.mean(np.linalg.norm(aligned - true_variable, axis=1))

    decoded_angles = extract_angles_from_embedding(embedding)
    return np.mean(circular_distance(decoded_angles, true_variable))
def compute_temporal_consistency(
    embedding: np.ndarray,
    true_variable: np.ndarray,
    manifold_type: str = 'circular'
) -> float:
    """Correlate step-to-step changes of the embedding with the ground truth

    Parameters:
    -----------
    embedding : np.ndarray
        Low-dimensional embedding
    true_variable : np.ndarray
        Ground truth variable (angles in radians, or positions)
    manifold_type : str
        'circular': correlates wrapped angular velocities (angles taken around
        the embedding centroid).
        'spatial': correlates per-step speeds after Procrustes alignment.

    Returns:
    --------
    float
        Pearson correlation of the two velocity series; 0.0 if it is NaN

    Raises:
    -------
    ValueError
        If ``manifold_type`` is not 'circular' or 'spatial'.
    """
    def wrap(delta):
        # Map angular steps into [-pi, pi] so 2*pi jumps do not look like motion.
        return np.arctan2(np.sin(delta), np.cos(delta))

    if manifold_type == 'circular':
        decoded_angles = extract_angles_from_embedding(embedding)
        true_steps = wrap(np.diff(true_variable))
        decoded_steps = wrap(np.diff(decoded_angles))
    elif manifold_type == 'spatial':
        aligned, _ = procrustes_analysis(true_variable, embedding)
        true_steps = np.linalg.norm(np.diff(true_variable, axis=0), axis=1)
        decoded_steps = np.linalg.norm(np.diff(aligned, axis=0), axis=1)
    else:
        raise ValueError(f"Unknown manifold type: {manifold_type}")

    r = np.corrcoef(true_steps, decoded_steps)[0, 1]
    return 0.0 if np.isnan(r) else r
def train_simple_decoder(embedding: np.ndarray, true_variable: np.ndarray, manifold_type: str = 'circular'):
    """Fit a linear decoder mapping the standardized embedding to the ground truth

    Parameters:
    -----------
    embedding : np.ndarray
        Low-dimensional embedding with shape (n_timepoints, n_features)
    true_variable : np.ndarray
        Ground truth variable (angles in radians, or positions)
    manifold_type : str
        'circular': separate linear regressions predict sin and cos of the
        angle, and the decoder returns arctan2(sin, cos).
        'spatial': a single linear regression predicts the variable directly.

    Returns:
    --------
    callable
        decoder(new_embedding) -> predictions. It reuses the scaler fitted
        on ``embedding``.

    Raises:
    -------
    ValueError
        On a row-count mismatch or an unknown ``manifold_type``.
    """
    from sklearn.linear_model import LinearRegression
    from sklearn.preprocessing import StandardScaler

    if embedding.shape[0] != true_variable.shape[0]:
        raise ValueError(f"Embedding and true_variable must have same number of timepoints. "
                         f"Got embedding: {embedding.shape}, true_variable: {true_variable.shape}")

    scaler = StandardScaler()
    features = scaler.fit_transform(embedding)

    if manifold_type == 'circular':
        sin_model = LinearRegression().fit(features, np.sin(true_variable))
        cos_model = LinearRegression().fit(features, np.cos(true_variable))

        def decoder(new_embedding):
            z = scaler.transform(new_embedding)
            return np.arctan2(sin_model.predict(z), cos_model.predict(z))

        return decoder

    if manifold_type == 'spatial':
        model = LinearRegression().fit(features, true_variable)

        def decoder(new_embedding):
            return model.predict(scaler.transform(new_embedding))

        return decoder

    raise ValueError(f"Unknown manifold type: {manifold_type}")
def compute_decoding_accuracy(
    embedding: np.ndarray,
    true_variable: np.ndarray,
    manifold_type: str = 'circular',
    train_fraction: float = 0.8
) -> dict:
    """Compute decoding accuracy using a contiguous train/test split

    The first ``int(n_samples * train_fraction)`` samples train a decoder
    (see ``train_simple_decoder``). The same error metric is then applied to
    the decoder's predictions on both splits, so ``generalization_gap``
    compares like with like.

    Parameters:
    -----------
    embedding : np.ndarray
        Low-dimensional embedding
    true_variable : np.ndarray
        Ground truth variable (angles, or positions given as a 1-D or
        (n_samples, n_dims) array)
    manifold_type : str
        Type of manifold ('circular' or 'spatial')
    train_fraction : float
        Fraction of data to use for training

    Returns:
    --------
    dict
        'train_error', 'test_error' (mean circular distance or mean Euclidean
        distance between decoder predictions and truth) and
        'generalization_gap' (test_error - train_error)

    Raises:
    -------
    ValueError
        If ``manifold_type`` is unknown, or ``train_fraction`` leaves the
        training or test split empty.
    """
    if manifold_type not in ('circular', 'spatial'):
        raise ValueError(f"Unknown manifold type: {manifold_type}")

    n_samples = embedding.shape[0]
    n_train = int(n_samples * train_fraction)
    if not 0 < n_train < n_samples:
        raise ValueError(
            f"train_fraction={train_fraction} leaves an empty training or test "
            f"split for n_samples={n_samples}"
        )

    def _error(predictions, targets):
        # One metric for both splits. The old circular train error measured
        # raw angle extraction, not the decoder.
        if manifold_type == 'circular':
            return np.mean(circular_distance(predictions, targets))
        # Reshape so 1-D (scalar) positions also get a per-sample norm.
        residual = np.reshape(predictions - targets, (len(targets), -1))
        return np.mean(np.linalg.norm(residual, axis=1))

    train_embedding, test_embedding = embedding[:n_train], embedding[n_train:]
    train_variable, test_variable = true_variable[:n_train], true_variable[n_train:]

    decoder = train_simple_decoder(train_embedding, train_variable, manifold_type)

    train_error = _error(decoder(train_embedding), train_variable)
    test_error = _error(decoder(test_embedding), test_variable)

    return {
        'train_error': train_error,
        'test_error': test_error,
        'generalization_gap': test_error - train_error
    }
def manifold_reconstruction_score(
    embedding: np.ndarray,
    true_variable: np.ndarray,
    manifold_type: str = 'circular',
    weights: Optional[dict] = None
) -> dict:
    """Compute comprehensive manifold reconstruction score

    Combines reconstruction error, temporal consistency and held-out decoding
    error into a weighted score. Both errors are mapped to [0, 1] as
    ``1 - min(error / max_error, 1)``, where max_error is pi for circular
    data and 1.0 for spatial data.

    Parameters:
    -----------
    embedding : np.ndarray
        Low-dimensional embedding
    true_variable : np.ndarray
        Ground truth variable (angles or positions)
    manifold_type : str
        Type of manifold ('circular' or 'spatial')
    weights : dict, optional
        Weights keyed by 'reconstruction_error', 'temporal_consistency' and
        'decoding_accuracy'. Missing keys count as 0, matching
        ``manifold_preservation_score``. Defaults to 0.4 / 0.3 / 0.3.

    Returns:
    --------
    dict
        Raw reconstruction error, temporal consistency (clipped at 0),
        decoder train/test errors, generalization gap and
        'overall_reconstruction_score'
    """
    if weights is None:
        weights = {
            'reconstruction_error': 0.4,
            'temporal_consistency': 0.3,
            'decoding_accuracy': 0.3
        }

    reconstruction_error = compute_reconstruction_error(embedding, true_variable, manifold_type)
    temporal_consistency = compute_temporal_consistency(embedding, true_variable, manifold_type)
    decoding_results = compute_decoding_accuracy(embedding, true_variable, manifold_type)

    # pi is the largest possible circular distance.
    # NOTE(review): the 1.0 cap for spatial errors presumes positions on a
    # roughly unit scale; confirm against callers.
    max_error = np.pi if manifold_type == 'circular' else 1.0
    normalized_error = 1.0 - min(reconstruction_error / max_error, 1.0)
    normalized_decode = 1.0 - min(decoding_results['test_error'] / max_error, 1.0)

    # Anti-correlated velocities earn no credit.
    temporal_consistency = max(temporal_consistency, 0.0)

    overall_score = (
        weights.get('reconstruction_error', 0.0) * normalized_error
        + weights.get('temporal_consistency', 0.0) * temporal_consistency
        + weights.get('decoding_accuracy', 0.0) * normalized_decode
    )

    return {
        'reconstruction_error': reconstruction_error,
        'temporal_consistency': temporal_consistency,
        'decoding_train_error': decoding_results['train_error'],
        'decoding_test_error': decoding_results['test_error'],
        'generalization_gap': decoding_results['generalization_gap'],
        'overall_reconstruction_score': overall_score
    }