Coverage for src/driada/utils/spatial.py: 93.46%
214 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-25 15:40 +0300
1"""
2Spatial Analysis Utilities for Neural Data
3==========================================
5This module provides comprehensive spatial analysis tools for neural data,
6particularly for analyzing place cells, grid cells, and other spatially-modulated neurons.
8Key functionality:
9- Place field detection and analysis
10- Grid score computation
11- Spatial information metrics
12- Position decoding
13- Speed/direction filtering
14- High-level spatial analysis pipelines
15"""
17import numpy as np
18from scipy import ndimage, signal, stats
19from scipy.spatial.distance import pdist, squareform
20from sklearn.model_selection import train_test_split
21from sklearn.ensemble import RandomForestRegressor
22from sklearn.metrics import r2_score
23from typing import Optional, Tuple, Dict, List, Union
24import logging
26from ..information import TimeSeries, MultiTimeSeries, get_sim
def compute_occupancy_map(
    positions: np.ndarray,
    arena_bounds: Optional[Tuple[Tuple[float, float], Tuple[float, float]]] = None,
    bin_size: float = 0.025,
    min_occupancy: float = 0.1,
    smooth_sigma: Optional[float] = None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Compute 2D occupancy map from position data.

    Each sample contributes one unit of occupancy to the bin it falls in,
    i.e. one sample is treated as one time unit. Multiply by the sampling
    interval to obtain seconds.

    Parameters
    ----------
    positions : np.ndarray, shape (n_samples, 2)
        X, Y positions over time
    arena_bounds : tuple of tuples, optional
        ((x_min, x_max), (y_min, y_max)). If None, inferred from the data with
        a margin of 5% of the larger extent (one ``bin_size`` when all
        positions coincide).
    bin_size : float
        Size of spatial bins in same units as positions
    min_occupancy : float
        Minimum occupancy (in the map's units, samples by default) for a bin
        to be valid. Bins below this value are set to NaN.
    smooth_sigma : float, optional
        Gaussian smoothing sigma in bins. If None, no smoothing

    Returns
    -------
    occupancy_map : np.ndarray, shape (n_y_bins, n_x_bins)
        2D occupancy map (rows = Y bins, columns = X bins); bins below
        ``min_occupancy`` are NaN.
    x_edges : np.ndarray
        X bin edges
    y_edges : np.ndarray
        Y bin edges

    Raises
    ------
    ValueError
        If ``positions`` is not a 2D array with exactly two columns.
    """
    positions = np.asarray(positions)
    # Check ndim first: on a 1D array ``positions.shape[1]`` would raise an
    # IndexError instead of the documented ValueError.
    if positions.ndim != 2 or positions.shape[1] != 2:
        raise ValueError(f"Positions must be 2D, got shape {positions.shape}")

    # Determine arena bounds
    if arena_bounds is None:
        x_min, x_max = positions[:, 0].min(), positions[:, 0].max()
        y_min, y_max = positions[:, 1].min(), positions[:, 1].max()
        span = max(x_max - x_min, y_max - y_min)
        # 5% margin. When every sample is at one point the span is 0, and a
        # zero margin would collapse each edge array to a single edge (zero
        # bins), so use one bin of margin instead.
        margin = 0.05 * span if span > 0 else bin_size
        arena_bounds = ((x_min - margin, x_max + margin),
                        (y_min - margin, y_max + margin))

    # Extend the stop by one bin so the upper bound is always covered.
    x_edges = np.arange(arena_bounds[0][0], arena_bounds[0][1] + bin_size, bin_size)
    y_edges = np.arange(arena_bounds[1][0], arena_bounds[1][1] + bin_size, bin_size)

    # Sample counts per bin, indexed [x_bin, y_bin].
    occupancy, _, _ = np.histogram2d(
        positions[:, 0], positions[:, 1],
        bins=[x_edges, y_edges]
    )

    if smooth_sigma is not None and smooth_sigma > 0:
        occupancy = ndimage.gaussian_filter(occupancy, sigma=smooth_sigma)

    # Mark unvisited / under-sampled bins as invalid.
    occupancy[occupancy < min_occupancy] = np.nan

    # Transpose to [y_bin, x_bin] so rows correspond to Y.
    return occupancy.T, x_edges, y_edges
def compute_rate_map(
    neural_signal: np.ndarray,
    positions: np.ndarray,
    occupancy_map: np.ndarray,
    x_edges: np.ndarray,
    y_edges: np.ndarray,
    smooth_sigma: Optional[float] = 1.5
) -> np.ndarray:
    """
    Bin a neural signal over space and return the mean signal per bin.

    Parameters
    ----------
    neural_signal : np.ndarray
        Neural signal, one value per position sample (e.g. calcium
        fluorescence, firing rates, spike counts).
    positions : np.ndarray, shape (n_samples, 2)
        X, Y positions corresponding to neural signal
    occupancy_map : np.ndarray
        2D occupancy map from compute_occupancy_map. Accepted for API
        symmetry; per-bin sample counts are recomputed here from
        ``positions`` rather than read from this array.
    x_edges : np.ndarray
        X bin edges from compute_occupancy_map
    y_edges : np.ndarray
        Y bin edges from compute_occupancy_map
    smooth_sigma : float, optional
        Gaussian smoothing sigma in bins. None or a non-positive value
        disables smoothing.

    Returns
    -------
    rate_map : np.ndarray, shape (n_y_bins, n_x_bins)
        Mean signal per spatial bin (rows = Y bins). Bins with no samples
        are 0.
    """
    xs, ys = positions[:, 0], positions[:, 1]
    edges = [x_edges, y_edges]

    # histogram2d indexes results as [x_bin, y_bin]; transpose to
    # [y_bin, x_bin] to match compute_occupancy_map's orientation.
    counts = np.histogram2d(xs, ys, bins=edges)[0].T
    totals = np.histogram2d(xs, ys, bins=edges, weights=neural_signal)[0].T

    visited = counts > 0
    # Mean signal in visited bins; unvisited bins keep the initial 0.
    rate_map = np.zeros_like(totals)
    np.divide(totals, counts, out=rate_map, where=visited)

    if smooth_sigma is None or not smooth_sigma > 0 or not visited.any():
        return rate_map

    # Normalized convolution: blur the map and the visit mask separately and
    # divide, so empty bins do not pull neighbouring values toward zero.
    blurred = ndimage.gaussian_filter(rate_map, sigma=smooth_sigma)
    weight = ndimage.gaussian_filter(visited.astype(float), sigma=smooth_sigma)
    with np.errstate(divide='ignore', invalid='ignore'):
        rate_map = blurred / weight
    rate_map[~visited] = 0
    return rate_map
def extract_place_fields(
    rate_map: np.ndarray,
    min_peak_rate: float = 1.0,
    min_field_size: int = 9,
    peak_to_mean_ratio: float = 1.5
) -> List[Dict[str, Union[float, Tuple[int, int]]]]:
    """
    Extract place fields from a rate map.

    Candidate bins are those whose rate exceeds ``peak_to_mean_ratio`` times
    the NaN-ignoring mean of the whole map. Connected groups of candidate
    bins (scipy.ndimage.label default connectivity) are kept as fields when
    they are large enough and their peak is high enough.

    Parameters
    ----------
    rate_map : np.ndarray
        2D firing rate map (rows = Y, columns = X)
    min_peak_rate : float
        Minimum peak firing rate for valid place field
    min_field_size : int
        Minimum number of contiguous bins for valid field
    peak_to_mean_ratio : float
        Multiple of the map-wide mean rate used as the candidate threshold

    Returns
    -------
    place_fields : list of dict
        List of place fields with properties:
        - peak_rate: Peak firing rate
        - center: (x, y) = (column, row) index of the rate-weighted centre of
          mass (unweighted centroid if the field has no positive rate)
        - size: Number of bins in field
        - mean_rate: Mean rate within field
    """
    mean_rate = np.nanmean(rate_map)
    threshold = mean_rate * peak_to_mean_ratio

    # NaN bins compare False, so they never join a field.
    labeled_map, num_fields = ndimage.label(rate_map > threshold)

    place_fields = []
    for field_id in range(1, num_fields + 1):
        field_mask = labeled_map == field_id
        field_size = np.sum(field_mask)
        if field_size < min_field_size:
            continue

        field_rates = rate_map[field_mask]
        peak_rate = np.max(field_rates)
        if peak_rate < min_peak_rate:
            continue

        # np.where and boolean indexing both walk the mask in row-major
        # order, so rows/cols line up with field_rates.
        rows, cols = np.where(field_mask)
        # Negative rates are not meaningful mass for a centre of mass, so
        # clip them. If no positive mass is left, np.average would raise
        # ZeroDivisionError, so fall back to the unweighted centroid.
        weights = np.clip(field_rates, 0, None)
        if weights.sum() <= 0:
            weights = None
        center_y = int(np.round(np.average(rows, weights=weights)))
        center_x = int(np.round(np.average(cols, weights=weights)))

        place_fields.append({
            'peak_rate': peak_rate,
            'center': (center_x, center_y),
            'size': field_size,
            'mean_rate': np.mean(field_rates)
        })

    return place_fields
def compute_spatial_information_rate(
    rate_map: np.ndarray,
    occupancy_map: np.ndarray
) -> float:
    """
    Skaggs et al. (1993) spatial information, in bits/spike.

    Evaluates ``sum_i p_i * (r_i / r) * log2(r_i / r)``. Here ``p_i`` is the
    occupancy probability of bin ``i`` (bins whose occupancy is NaN are
    ignored), ``r_i`` is its rate, and ``r`` is the occupancy-weighted mean
    rate.

    Parameters
    ----------
    rate_map : np.ndarray
        2D firing rate map
    occupancy_map : np.ndarray
        2D occupancy map (time spent in each bin), same shape as rate_map;
        NaN marks invalid bins.

    Returns
    -------
    spatial_info : float
        Spatial information in bits/spike, clipped at 0. Returns 0.0 when the
        mean rate is zero; bins with non-positive rate contribute nothing.
    """
    visited = ~np.isnan(occupancy_map)
    occupancy = occupancy_map[visited]
    rates = rate_map[visited]
    prob = occupancy / np.sum(occupancy)

    mean_rate = np.sum(prob * rates)
    if mean_rate == 0:
        return 0.0

    with np.errstate(divide='ignore', invalid='ignore'):
        ratio = rates / mean_rate
        # x*log2(x) -> 0 as x -> 0 and log2 is undefined for negatives, so
        # only strictly positive ratios produce a term.
        positive = ratio > 0
        terms = np.zeros(np.shape(ratio), dtype=np.result_type(prob, ratio))
        terms[positive] = prob[positive] * ratio[positive] * np.log2(ratio[positive])
        # Non-finite terms (overflow / undefined products) contribute nothing.
        terms[~np.isfinite(terms)] = 0

    return max(0.0, np.sum(terms))
def compute_grid_score(
    rate_map: np.ndarray,
    min_peaks: int = 3,
    max_field_size_ratio: float = 0.5
) -> float:
    """
    Compute grid score from spatial autocorrelation.

    Grid score measures hexagonal regularity (Sargolini et al., 2006). The
    autocorrelogram is rotated by 30, 60, 90, 120 and 150 degrees, and the
    score is ``min(r60, r120) - max(r30, r90, r150)``, where ``rX`` is the
    Pearson correlation between the autocorrelogram and its rotation by X
    degrees.

    Parameters
    ----------
    rate_map : np.ndarray
        2D firing rate map. NaNs are treated as 0.
    min_peaks : int
        Nominal minimum number of non-central autocorrelogram peaks. Note:
        the function short-circuits (to -2.0) only when there are *no*
        non-central peaks. Maps with between 1 and ``min_peaks - 1`` peaks
        are still scored.
    max_field_size_ratio : float
        Currently unused; kept for API compatibility.

    Returns
    -------
    grid_score : float
        Grid score (-2 to 2, higher = more grid-like). Returns -2.0 for an
        all-zero map or when no peak remains besides the central one.
    """
    rate_map_clean = np.nan_to_num(rate_map)
    autocorr = signal.correlate2d(rate_map_clean, rate_map_clean, mode='same')

    # The 'same' window contains the zero-lag term (sum of squares), so the
    # maximum is >= 0 and is 0 only for an all-zero map. Bail out before
    # normalising; otherwise 0/0 fills the array with NaN and emits warnings.
    peak = np.max(autocorr)
    if peak <= 0:
        return -2.0
    autocorr = autocorr / peak

    # Local maxima over a 3x3 neighbourhood.
    local_maxima = autocorr == ndimage.maximum_filter(autocorr, size=3)

    # Suppress the central (zero-lag) peak and its 5x5 surroundings.
    center = np.array(autocorr.shape) // 2
    local_maxima[center[0]-2:center[0]+3, center[1]-2:center[1]+3] = False

    n_peaks = int(np.count_nonzero(local_maxima))
    if n_peaks < min_peaks and n_peaks == 0:
        return -2.0  # No structure beyond the centre

    # Rotate the autocorrelogram about the array centre (bilinear, same size).
    angles = [30, 60, 90, 120, 150]
    correlations = []
    for angle in angles:
        rotated = ndimage.rotate(autocorr, angle, reshape=False, order=1)
        corr = np.corrcoef(autocorr.flatten(), rotated.flatten())[0, 1]
        correlations.append(corr)

    # Hexagonal symmetry: high at 60/120 degrees, low at 30/90/150 degrees.
    grid_score = min(correlations[1], correlations[3]) - max(correlations[0], correlations[2], correlations[4])

    return grid_score
def compute_spatial_decoding_accuracy(
    neural_activity: np.ndarray,
    positions: np.ndarray,
    test_size: float = 0.5,
    n_estimators: int = 20,
    max_depth: int = 3,
    min_samples_leaf: int = 50,
    random_state: int = 42,
    logger: Optional[logging.Logger] = None
) -> Dict[str, float]:
    """
    Compute position decoding accuracy from neural activity.

    A RandomForestRegressor is trained on a random train/test split to
    predict (X, Y) from the population activity at each sample.

    Parameters
    ----------
    neural_activity : np.ndarray, shape (n_neurons, n_samples) or (n_samples,)
        Neural activity matrix. A 1D array is treated as a single neuron.
    positions : np.ndarray, shape (n_samples, 2)
        True X, Y positions
    test_size : float
        Fraction of data for testing
    n_estimators : int
        Number of trees in random forest
    max_depth : int
        Maximum tree depth
    min_samples_leaf : int
        Minimum samples per leaf
    random_state : int
        Random seed for both the split and the forest
    logger : logging.Logger, optional
        Logger for debugging

    Returns
    -------
    metrics : dict
        Decoding accuracy metrics on the test split:
        - r2_x: R² score for X position, clipped at 0
        - r2_y: R² score for Y position, clipped at 0
        - r2_avg: mean of the unclipped X and Y R², clipped at 0
        - mse: Mean squared error over both coordinates
    """
    # A single neuron's trace arrives as 1D; promote it to (1, n_samples) so
    # the transpose below yields the 2D (n_samples, n_features) matrix that
    # sklearn requires (a 1D X makes sklearn raise).
    neural_activity = np.atleast_2d(np.asarray(neural_activity))
    positions = np.asarray(positions)

    if logger:
        logger.info(f"Computing spatial decoding with {neural_activity.shape[0]} neurons")

    # sklearn expects samples along the first axis.
    X = neural_activity.T  # (n_samples, n_neurons)
    y = positions

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state
    )

    decoder = RandomForestRegressor(
        n_estimators=n_estimators,
        max_depth=max_depth,
        min_samples_leaf=min_samples_leaf,
        random_state=random_state,
        n_jobs=-1
    )

    decoder.fit(X_train, y_train)
    y_pred = decoder.predict(X_test)

    r2_x = r2_score(y_test[:, 0], y_pred[:, 0])
    r2_y = r2_score(y_test[:, 1], y_pred[:, 1])
    mse = np.mean((y_test - y_pred)**2)

    metrics = {
        'r2_x': max(0.0, r2_x),  # Avoid negative R²
        'r2_y': max(0.0, r2_y),
        'r2_avg': max(0.0, (r2_x + r2_y) / 2),
        'mse': mse
    }

    if logger:
        logger.info(f"Decoding accuracy: R²_avg = {metrics['r2_avg']:.3f}")

    return metrics
def compute_spatial_information(
    neural_activity: Union[np.ndarray, TimeSeries, MultiTimeSeries],
    positions: Union[np.ndarray, TimeSeries, MultiTimeSeries],
    logger: Optional[logging.Logger] = None
) -> Dict[str, float]:
    """
    Compute mutual information between neural activity and spatial position.

    MI is computed with ``get_sim(..., metric='mi', estimator='gcmi')``
    against the X position, the Y position and the joint 2D position. The
    computation is best-effort: if any MI call raises, a warning is logged
    and all three metrics are returned as 0.0.

    Parameters
    ----------
    neural_activity : np.ndarray, TimeSeries or MultiTimeSeries
        Neural activity. A 1D array becomes a single continuous TimeSeries.
        A 2D array is treated as (n_neurons, n_samples) and becomes a
        MultiTimeSeries with one continuous TimeSeries per row. Other inputs
        are used as-is.
    positions : np.ndarray or MultiTimeSeries
        Spatial position data. An array must have shape (n_samples, 2) with
        columns (X, Y); a MultiTimeSeries is used directly.
    logger : logging.Logger, optional
        Logger for progress and failure messages. When None, failures are
        reported through this module's logger instead of being swallowed
        silently.

    Returns
    -------
    metrics : dict
        Spatial information metrics:
        - mi_x: MI with X position
        - mi_y: MI with Y position
        - mi_total: MI with 2D position

    Raises
    ------
    ValueError
        If ``positions`` is an array that is not (n_samples, 2), or is
        neither an array nor a MultiTimeSeries.
    """
    if isinstance(neural_activity, np.ndarray):
        if neural_activity.ndim == 1:
            neural_ts = TimeSeries(neural_activity, discrete=False)
        else:
            # One continuous TimeSeries per neuron (row).
            neural_ts_list = [TimeSeries(neural_activity[i], discrete=False)
                              for i in range(neural_activity.shape[0])]
            neural_ts = MultiTimeSeries(neural_ts_list)
    else:
        neural_ts = neural_activity

    if isinstance(positions, np.ndarray):
        # Check ndim first: ``positions.shape[1]`` on a 1D array would raise
        # IndexError instead of the documented ValueError.
        if positions.ndim != 2 or positions.shape[1] != 2:
            raise ValueError("Positions must be 2D (X, Y)")
        x_ts = TimeSeries(positions[:, 0], discrete=False)
        y_ts = TimeSeries(positions[:, 1], discrete=False)
        pos_2d_ts = MultiTimeSeries([x_ts, y_ts])
    elif isinstance(positions, MultiTimeSeries):
        # NOTE(review): assumes positions.data[0] / [1] are the X / Y
        # components in a form get_sim accepts - confirm against
        # MultiTimeSeries.data.
        x_ts = positions.data[0]
        y_ts = positions.data[1]
        pos_2d_ts = positions
    else:
        raise ValueError("Positions must be 2D")

    try:
        mi_x = get_sim(neural_ts, x_ts, metric='mi', estimator='gcmi')
        mi_y = get_sim(neural_ts, y_ts, metric='mi', estimator='gcmi')
        mi_total = get_sim(neural_ts, pos_2d_ts, metric='mi', estimator='gcmi')

        metrics = {
            'mi_x': mi_x,
            'mi_y': mi_y,
            'mi_total': mi_total
        }

        if logger:
            logger.info(f"Spatial MI: X={mi_x:.3f}, Y={mi_y:.3f}, Total={mi_total:.3f}")

    except Exception as e:
        # Deliberate best-effort: degrade to zeros rather than aborting the
        # caller's pipeline, but always report the failure somewhere.
        (logger or logging.getLogger(__name__)).warning(f"MI calculation failed: {e}")
        metrics = {'mi_x': 0.0, 'mi_y': 0.0, 'mi_total': 0.0}

    return metrics
def filter_by_speed(
    data: Dict[str, np.ndarray],
    speed_range: Tuple[float, float] = (0.05, float('inf')),
    position_key: str = 'positions',
    smooth_window: int = 5
) -> Dict[str, np.ndarray]:
    """
    Keep only the samples whose movement speed lies inside ``speed_range``.

    Speed is the Euclidean displacement between consecutive position samples
    (distance units per sample, not per second). When ``smooth_window > 1``
    it is smoothed with a moving average. The first sample has no
    predecessor and is assigned speed 0.

    Parameters
    ----------
    data : dict
        Dictionary whose ``position_key`` entry is an (n_samples, 2) array
    speed_range : tuple
        (min_speed, max_speed), both inclusive
    position_key : str
        Key for position data in dictionary
    smooth_window : int
        Moving-average window (in samples) for the speed; <= 1 disables it

    Returns
    -------
    filtered_data : dict
        New dictionary with the same keys. Every numpy array whose length
        equals the number of samples is masked along its first axis; all
        other values pass through unchanged. A ``'speed'`` entry holding the
        retained speeds is added, overwriting any existing one.
    """
    steps = np.diff(data[position_key], axis=0)
    step_speed = np.sqrt(np.sum(steps ** 2, axis=1))
    if smooth_window > 1:
        step_speed = ndimage.uniform_filter1d(step_speed, size=smooth_window)

    # Prepend 0 so there is one speed value per position sample.
    speed = np.concatenate([[0], step_speed])

    keep = (speed >= speed_range[0]) & (speed <= speed_range[1])
    n_samples = len(keep)

    filtered_data = {
        key: (value[keep]
              if isinstance(value, np.ndarray) and len(value) == n_samples
              else value)
        for key, value in data.items()
    }
    filtered_data['speed'] = speed[keep]
    return filtered_data
def filter_by_direction(
    data: Dict[str, np.ndarray],
    direction_range: Tuple[float, float],
    position_key: str = 'positions',
    smooth_window: int = 5
) -> Dict[str, np.ndarray]:
    """
    Filter data to include only periods of specific movement directions.

    Movement direction is the angle (``np.arctan2``, in [-π, π]) of the
    displacement between consecutive position samples. When
    ``smooth_window > 1`` the displacement vectors are smoothed with a moving
    average *before* the angle is taken. Averaging vectors rather than angles
    handles the ±π wrap-around correctly. The first sample has no predecessor
    and is assigned direction 0 (east). Samples with zero (smoothed)
    displacement also get direction 0, because arctan2(0, 0) == 0.

    Parameters
    ----------
    data : dict
        Dictionary whose ``position_key`` entry is an (n_samples, 2) array
    direction_range : tuple
        (min_angle, max_angle) in radians, where 0 = east, π/2 = north.
        Bounds are inclusive. If min_angle > max_angle the range wraps
        through ±π, e.g. (3π/4, -3π/4) selects westward movement.
    position_key : str
        Key for position data
    smooth_window : int
        Moving-average window (in samples) applied to the displacement
        vectors; values <= 1 disable smoothing.

    Returns
    -------
    filtered_data : dict
        New dictionary with the same keys. Every numpy array whose length
        equals the number of samples is masked along its first axis; other
        values pass through unchanged. A ``'direction'`` entry holding the
        retained directions is added.
    """
    # Cast to float: with integer positions np.diff yields integers, and
    # uniform_filter1d would then truncate the averaged velocities.
    positions = np.asarray(data[position_key], dtype=float)

    velocity = np.diff(positions, axis=0)
    if smooth_window > 1 and len(velocity) > 0:
        velocity = ndimage.uniform_filter1d(velocity, size=smooth_window, axis=0)
    direction = np.arctan2(velocity[:, 1], velocity[:, 0])

    # One direction per position sample: the first sample gets 0.
    direction = np.concatenate([[0], direction])

    min_dir, max_dir = direction_range
    if min_dir <= max_dir:
        mask = (direction >= min_dir) & (direction <= max_dir)
    else:
        # Wrapped range crossing ±π (e.g. 3π/4 to -3π/4 for westward).
        mask = (direction >= min_dir) | (direction <= max_dir)

    filtered_data = {}
    for key, value in data.items():
        if isinstance(value, np.ndarray) and len(value) == len(mask):
            filtered_data[key] = value[mask]
        else:
            filtered_data[key] = value

    filtered_data['direction'] = direction[mask]

    return filtered_data
def analyze_spatial_coding(
    neural_activity: np.ndarray,
    positions: np.ndarray,
    arena_bounds: Optional[Tuple[Tuple[float, float], Tuple[float, float]]] = None,
    bin_size: float = 0.025,
    min_peak_rate: float = 1.0,
    speed_range: Optional[Tuple[float, float]] = (0.05, float('inf')),
    peak_to_mean_ratio: float = 1.5,
    min_field_size: int = 9,
    logger: Optional[logging.Logger] = None
) -> Dict[str, Union[np.ndarray, List, Dict, float]]:
    """
    Run the full single-neuron and population spatial-coding pipeline.

    The steps are:
    1. Optionally speed-filter positions and activity (filter_by_speed).
    2. Build one occupancy map shared by all neurons.
    3. For each neuron (row), compute the rate map, place fields, spatial
       information rate and grid score.
    4. Run population position decoding and MI with position.
    5. Summarise the results.

    Parameters
    ----------
    neural_activity : np.ndarray, shape (n_neurons, n_samples)
        Neural activity matrix
    positions : np.ndarray, shape (n_samples, 2)
        Position data
    arena_bounds : tuple, optional
        Arena boundaries passed to compute_occupancy_map
    bin_size : float
        Spatial bin size
    min_peak_rate : float
        Minimum peak rate for place fields
    speed_range : tuple, optional
        Speed filter range for filter_by_speed; None disables speed filtering
    peak_to_mean_ratio : float
        Threshold multiple of the map mean used by extract_place_fields
    min_field_size : int
        Minimum number of contiguous bins for valid field
    logger : logging.Logger, optional
        Logger for progress

    Returns
    -------
    results : dict
        - rate_maps: List of rate maps per neuron
        - place_fields: List of place-field lists per neuron
        - spatial_info: Spatial information per neuron
        - grid_scores: Grid scores per neuron
        - decoding_accuracy: Position decoding metrics
        - spatial_mi: Mutual information metrics
        - summary: n_place_cells (neurons with >= 1 field), n_grid_cells
          (grid score > 0.3), mean_spatial_info, mean_grid_score
    """
    if logger:
        logger.info(f"Analyzing spatial coding for {neural_activity.shape[0]} neurons")

    if speed_range is not None:
        # filter_by_speed masks along the first axis, so pass activity as
        # (n_samples, n_neurons) and transpose back afterwards.
        kept = filter_by_speed(
            {'positions': positions, 'neural_activity': neural_activity.T},
            speed_range,
        )
        positions = kept['positions']
        neural_activity = kept['neural_activity'].T

    occupancy_map, x_edges, y_edges = compute_occupancy_map(
        positions, arena_bounds=arena_bounds, bin_size=bin_size
    )

    rate_maps, place_fields, spatial_info, grid_scores = [], [], [], []
    for trace in neural_activity:
        rate_map = compute_rate_map(trace, positions, occupancy_map, x_edges, y_edges)
        rate_maps.append(rate_map)
        place_fields.append(
            extract_place_fields(
                rate_map,
                min_peak_rate=min_peak_rate,
                min_field_size=min_field_size,
                peak_to_mean_ratio=peak_to_mean_ratio,
            )
        )
        spatial_info.append(compute_spatial_information_rate(rate_map, occupancy_map))
        grid_scores.append(compute_grid_score(rate_map))

    results = {
        'rate_maps': rate_maps,
        'place_fields': place_fields,
        'spatial_info': spatial_info,
        'grid_scores': grid_scores,
    }

    # Population-level analyses (decoding first, then MI).
    results['decoding_accuracy'] = compute_spatial_decoding_accuracy(
        neural_activity, positions, logger=logger
    )
    results['spatial_mi'] = compute_spatial_information(
        neural_activity, positions, logger=logger
    )

    summary = {
        'n_place_cells': sum(len(fields) > 0 for fields in place_fields),
        'n_grid_cells': sum(score > 0.3 for score in grid_scores),
        'mean_spatial_info': np.mean(spatial_info),
        'mean_grid_score': np.mean(grid_scores),
    }
    results['summary'] = summary

    if logger:
        logger.info(f"Found {summary['n_place_cells']} place cells, "
                    f"{summary['n_grid_cells']} grid cells")

    return results
723def compute_spatial_metrics(
724 neural_activity: np.ndarray,
725 positions: np.ndarray,
726 metrics: Optional[List[str]] = None,
727 **kwargs
728) -> Dict[str, Union[float, Dict]]:
729 """
730 Compute selected spatial metrics.
732 Parameters
733 ----------
734 neural_activity : np.ndarray
735 Neural activity data
736 positions : np.ndarray
737 Position data
738 metrics : list of str, optional
739 Metrics to compute. If None, computes all.
740 Options: 'decoding', 'information', 'place_fields', 'grid_scores'
741 **kwargs
742 Additional arguments passed to analysis functions
744 Returns
745 -------
746 results : dict
747 Computed metrics
748 """
749 if metrics is None:
750 metrics = ['decoding', 'information', 'place_fields', 'grid_scores']
752 results = {}
754 if 'decoding' in metrics:
755 results['decoding'] = compute_spatial_decoding_accuracy(
756 neural_activity, positions, **kwargs
757 )
759 if 'information' in metrics:
760 results['information'] = compute_spatial_information(
761 neural_activity, positions, **kwargs
762 )
764 if 'place_fields' in metrics or 'grid_scores' in metrics:
765 analysis = analyze_spatial_coding(
766 neural_activity, positions, **kwargs
767 )
769 if 'place_fields' in metrics:
770 results['place_fields'] = analysis['place_fields']
771 results['n_place_cells'] = analysis['summary']['n_place_cells']
773 if 'grid_scores' in metrics:
774 results['grid_scores'] = analysis['grid_scores']
775 results['n_grid_cells'] = analysis['summary']['n_grid_cells']
777 return results