Coverage for src/driada/utils/spatial.py: 93.46%

214 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-25 15:40 +0300

1""" 

2Spatial Analysis Utilities for Neural Data 

3========================================== 

4 

5This module provides comprehensive spatial analysis tools for neural data, 

6particularly for analyzing place cells, grid cells, and other spatially-modulated neurons. 

7 

8Key functionality: 

9- Place field detection and analysis 

10- Grid score computation 

11- Spatial information metrics 

12- Position decoding 

13- Speed/direction filtering 

14- High-level spatial analysis pipelines 

15""" 

16 

17import numpy as np 

18from scipy import ndimage, signal, stats 

19from scipy.spatial.distance import pdist, squareform 

20from sklearn.model_selection import train_test_split 

21from sklearn.ensemble import RandomForestRegressor 

22from sklearn.metrics import r2_score 

23from typing import Optional, Tuple, Dict, List, Union 

24import logging 

25 

26from ..information import TimeSeries, MultiTimeSeries, get_sim 

27 

28 

def compute_occupancy_map(
    positions: np.ndarray,
    arena_bounds: Optional[Tuple[Tuple[float, float], Tuple[float, float]]] = None,
    bin_size: float = 0.025,
    min_occupancy: float = 0.1,
    smooth_sigma: Optional[float] = None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Compute 2D occupancy map from position data.

    Parameters
    ----------
    positions : np.ndarray, shape (n_samples, 2)
        X, Y positions over time
    arena_bounds : tuple of tuples, optional
        ((x_min, x_max), (y_min, y_max)). If None, inferred from the data
        range and padded on every side by 5% of the larger extent
    bin_size : float
        Size of spatial bins in same units as positions. Must be positive
    min_occupancy : float
        Bins whose (optionally smoothed) occupancy is below this value are
        marked as unvisited (NaN)
    smooth_sigma : float, optional
        Gaussian smoothing sigma in bins. If None, no smoothing

    Returns
    -------
    occupancy_map : np.ndarray
        2D occupancy map indexed as [y_bin, x_bin]. Values are sample counts,
        i.e. they equal seconds only when one sample lasts one second.
        Under-visited bins are NaN
    x_edges : np.ndarray
        X bin edges
    y_edges : np.ndarray
        Y bin edges

    Raises
    ------
    ValueError
        If positions is not an (n_samples, 2) array, if bin_size is not
        positive, or if bounds must be inferred from an empty array
    """
    positions = np.asarray(positions)
    # Check ndim first so that 1-D input fails with the intended message
    # instead of an IndexError on shape[1].
    if positions.ndim != 2 or positions.shape[1] != 2:
        raise ValueError(f"Positions must be 2D, got shape {positions.shape}")
    if not bin_size > 0:
        raise ValueError(f"bin_size must be positive, got {bin_size}")

    # Determine arena bounds
    if arena_bounds is None:
        if positions.shape[0] == 0:
            raise ValueError("Cannot infer arena bounds from empty positions")
        x_min, x_max = positions[:, 0].min(), positions[:, 0].max()
        y_min, y_max = positions[:, 1].min(), positions[:, 1].max()
        # Add small margin
        margin = 0.05 * max(x_max - x_min, y_max - y_min)
        if margin == 0:
            # Stationary trajectory: without padding the arena has zero width
            # and no valid bin could be built around the single location.
            margin = bin_size
        arena_bounds = ((x_min - margin, x_max + margin),
                        (y_min - margin, y_max + margin))

    # Create bins; the upper bound is extended by one bin so the last edge
    # reaches (or passes) the arena maximum.
    (x_low, x_high), (y_low, y_high) = arena_bounds
    x_edges = np.arange(x_low, x_high + bin_size, bin_size)
    y_edges = np.arange(y_low, y_high + bin_size, bin_size)

    # Count samples per bin (no sampling-rate scaling is applied here)
    occupancy, _, _ = np.histogram2d(
        positions[:, 0], positions[:, 1],
        bins=[x_edges, y_edges]
    )

    # Smooth if requested
    if smooth_sigma is not None and smooth_sigma > 0:
        occupancy = ndimage.gaussian_filter(occupancy, sigma=smooth_sigma)

    # Set unvisited bins to NaN
    occupancy[occupancy < min_occupancy] = np.nan

    # histogram2d returns [x_bin, y_bin]; transpose to image-like [y_bin, x_bin]
    return occupancy.T, x_edges, y_edges

94 

95 

def compute_rate_map(
    neural_signal: np.ndarray,
    positions: np.ndarray,
    occupancy_map: np.ndarray,
    x_edges: np.ndarray,
    y_edges: np.ndarray,
    smooth_sigma: Optional[float] = 1.5
) -> np.ndarray:
    """
    Compute firing rate map from neural signal and positions.

    The map holds the mean signal of the samples that fall into each spatial
    bin. Samples whose signal value is NaN or infinite are ignored, so a few
    missing frames do not invalidate a whole bin (or, after smoothing, its
    neighbourhood).

    Parameters
    ----------
    neural_signal : np.ndarray, shape (n_samples,)
        Neural signal (e.g., calcium fluorescence, firing rates, spike counts)
    positions : np.ndarray, shape (n_samples, 2)
        X, Y positions corresponding to neural signal
    occupancy_map : np.ndarray
        2D occupancy map from compute_occupancy_map. Currently not used in
        the computation (per-bin sample counts are recomputed from
        positions); kept for interface compatibility
    x_edges : np.ndarray
        X bin edges from compute_occupancy_map
    y_edges : np.ndarray
        Y bin edges from compute_occupancy_map
    smooth_sigma : float, optional
        Gaussian smoothing sigma in bins. None or 0 disables smoothing

    Returns
    -------
    rate_map : np.ndarray
        2D activity map (mean signal per spatial bin), indexed as
        [y_bin, x_bin]. Bins without any usable sample are 0

    Raises
    ------
    ValueError
        If neural_signal and positions do not have the same number of samples
    """
    signal_values = np.asarray(neural_signal, dtype=float)
    xs, ys = positions[:, 0], positions[:, 1]
    if signal_values.shape != xs.shape:
        raise ValueError(
            f"neural_signal has shape {signal_values.shape} but positions "
            f"provide {xs.shape[0]} samples"
        )

    # Drop samples without a usable signal value so that the per-bin sum and
    # the per-bin count below refer to exactly the same samples.
    usable = np.isfinite(signal_values)
    if not usable.all():
        xs, ys, signal_values = xs[usable], ys[usable], signal_values[usable]

    bins = [x_edges, y_edges]

    # Weighted 2D histogram sums the activity in each bin
    activity_sum, _, _ = np.histogram2d(xs, ys, bins=bins, weights=signal_values)
    # Unweighted histogram counts the contributing samples
    sample_count, _, _ = np.histogram2d(xs, ys, bins=bins)

    # Transpose for [y_bin, x_bin] orientation
    activity_sum = activity_sum.T
    sample_count = sample_count.T
    visited = sample_count > 0

    # Compute mean activity per bin
    with np.errstate(divide='ignore', invalid='ignore'):
        rate_map = activity_sum / sample_count
    rate_map[~visited] = 0  # Set unvisited bins to 0

    # Smooth if requested
    if smooth_sigma is not None and smooth_sigma > 0 and np.any(visited):
        # Normalised smoothing: dividing the smoothed map by the smoothed
        # visit mask keeps the zeros of unvisited bins from dragging down
        # the values of their visited neighbours.
        smoothed = ndimage.gaussian_filter(rate_map, sigma=smooth_sigma)
        smoothed_mask = ndimage.gaussian_filter(visited.astype(float), sigma=smooth_sigma)
        with np.errstate(divide='ignore', invalid='ignore'):
            rate_map = smoothed / smoothed_mask
        rate_map[~visited] = 0

    return rate_map

161 

162 

def extract_place_fields(
    rate_map: np.ndarray,
    min_peak_rate: float = 1.0,
    min_field_size: int = 9,
    peak_to_mean_ratio: float = 1.5
) -> List[Dict[str, Union[float, Tuple[int, int]]]]:
    """
    Detect place fields as connected high-rate regions of a rate map.

    A bin belongs to a candidate field when its rate exceeds
    ``peak_to_mean_ratio`` times the NaN-ignoring mean of the whole map.
    Connected candidate regions are kept when they are large enough and
    their peak is high enough.

    Parameters
    ----------
    rate_map : np.ndarray
        2D firing rate map
    min_peak_rate : float
        Minimum peak firing rate for valid place field
    min_field_size : int
        Minimum number of contiguous bins for valid field
    peak_to_mean_ratio : float
        Multiplier applied to the map mean to obtain the field threshold

    Returns
    -------
    place_fields : list of dict
        One dictionary per accepted field with the keys:
        - peak_rate: Peak firing rate
        - center: (x, y) bin indices of the rate-weighted center of mass
        - size: Number of bins in field
        - mean_rate: Mean rate within field
    """
    # Field threshold relative to the overall mean rate
    cutoff = np.nanmean(rate_map) * peak_to_mean_ratio

    # Label connected regions above the threshold
    labels, n_regions = ndimage.label(rate_map > cutoff)

    place_fields = []
    for region in range(1, n_regions + 1):
        member = labels == region

        # Reject regions that are too small
        n_bins = np.sum(member)
        if n_bins < min_field_size:
            continue

        # Reject regions whose peak is too weak
        rates = rate_map[member]
        top_rate = np.max(rates)
        if top_rate < min_peak_rate:
            continue

        # Rate-weighted center of mass, rounded to the nearest bin index
        rows, cols = np.where(member)
        row_center = int(np.round(np.average(rows, weights=rates)))
        col_center = int(np.round(np.average(cols, weights=rates)))

        place_fields.append({
            'peak_rate': top_rate,
            'center': (col_center, row_center),
            'size': n_bins,
            'mean_rate': np.mean(rates)
        })

    return place_fields

229 

230 

def compute_spatial_information_rate(
    rate_map: np.ndarray,
    occupancy_map: np.ndarray
) -> float:
    """
    Compute spatial information rate (bits/spike).

    Implements Skaggs et al. (1993) spatial information metric:
    ``sum_i p_i * (r_i / r_mean) * log2(r_i / r_mean)``, where ``p_i`` is the
    occupancy probability of bin ``i``, ``r_i`` its rate and ``r_mean`` the
    occupancy-weighted mean rate.

    Parameters
    ----------
    rate_map : np.ndarray
        2D firing rate map
    occupancy_map : np.ndarray
        2D occupancy map (time spent in each bin), same shape as rate_map.
        NaN marks bins to exclude

    Returns
    -------
    spatial_info : float
        Spatial information in bits/spike, never negative. 0.0 is returned
        when there is no usable bin, no occupancy, or a zero mean rate
    """
    rate_map = np.asarray(rate_map, dtype=float)
    occupancy_map = np.asarray(occupancy_map, dtype=float)

    # Use only bins where both maps are defined. A NaN rate in an occupied
    # bin would otherwise turn the mean rate into NaN and silently collapse
    # the whole result to 0.
    valid_mask = ~np.isnan(occupancy_map) & ~np.isnan(rate_map)
    occupancy = occupancy_map[valid_mask]
    r_i = rate_map[valid_mask]

    # Normalize occupancy to get probability
    total_occupancy = np.sum(occupancy)
    if not np.isfinite(total_occupancy) or total_occupancy <= 0:
        return 0.0
    p_i = occupancy / total_occupancy

    # Mean firing rate
    r_mean = np.sum(p_i * r_i)
    if not np.isfinite(r_mean) or r_mean == 0:
        return 0.0

    # Per-bin information; bins with a non-positive ratio give NaN/inf from
    # the logarithm and contribute nothing.
    with np.errstate(divide='ignore', invalid='ignore', over='ignore'):
        ratio = r_i / r_mean
        info_per_bin = p_i * ratio * np.log2(ratio)
    spatial_info = np.sum(info_per_bin[np.isfinite(info_per_bin)])

    return max(0.0, float(spatial_info))  # Ensure non-negative

272 

273 

def compute_grid_score(
    rate_map: np.ndarray,
    min_peaks: int = 3,
    max_field_size_ratio: float = 0.5
) -> float:
    """
    Compute grid score from spatial autocorrelation.

    Grid score measures hexagonal regularity (Sargolini et al., 2006): the
    autocorrelogram of the rate map is correlated with rotated copies of
    itself, and the score is the weakest correlation at 60°/120° minus the
    strongest correlation at 30°/90°/150°.

    Parameters
    ----------
    rate_map : np.ndarray
        2D firing rate map. NaNs are replaced by 0 before correlating
    min_peaks : int
        Peak-count threshold of the early exit below. Only a map with no
        off-center autocorrelation peak at all is rejected (returns -2.0)
        NOTE(review): maps with 1..min_peaks-1 peaks are still scored -
        confirm that this is the intended use of the parameter
    max_field_size_ratio : float
        Currently not used by the computation

    Returns
    -------
    grid_score : float
        Grid score (-2 to 2, higher = more grid-like). -2.0 is returned for
        maps without structure (flat or all-zero maps, or no autocorrelation
        peak away from the center)
    """
    # Compute 2D autocorrelation
    rate_map_clean = np.nan_to_num(rate_map)
    autocorr = signal.correlate2d(rate_map_clean, rate_map_clean, mode='same')

    # A map without any signal has nothing to normalize by; report it as
    # unstructured instead of dividing by zero.
    peak_value = np.max(autocorr)
    if not np.isfinite(peak_value) or peak_value <= 0:
        return -2.0
    autocorr = autocorr / peak_value  # Normalize

    # Find peaks in autocorrelation
    local_maxima = autocorr == ndimage.maximum_filter(autocorr, size=3)

    # Remove central peak. The start indices are clamped at 0: on small maps
    # a negative start would wrap around and leave the center unmasked.
    center_row, center_col = (dim // 2 for dim in autocorr.shape)
    local_maxima[max(center_row - 2, 0):center_row + 3,
                 max(center_col - 2, 0):center_col + 3] = False

    # For very sparse patterns there may be no structure beyond the center
    n_peaks = int(np.count_nonzero(local_maxima))
    if n_peaks < min_peaks and n_peaks == 0:
        return -2.0

    # Rotational correlations of the autocorrelogram with itself
    correlations = {}
    flat_autocorr = autocorr.flatten()
    for angle in (30, 60, 90, 120, 150):
        rotated = ndimage.rotate(autocorr, angle, reshape=False, order=1)
        correlations[angle] = np.corrcoef(flat_autocorr, rotated.flatten())[0, 1]

    # Grid score is minimum of 60° and 120° minus maximum of 30°, 90°, 150°
    on_axis = min(correlations[60], correlations[120])
    off_axis = max(correlations[30], correlations[90], correlations[150])

    return float(on_axis - off_axis)

333 

334 

def compute_spatial_decoding_accuracy(
    neural_activity: np.ndarray,
    positions: np.ndarray,
    test_size: float = 0.5,
    n_estimators: int = 20,
    max_depth: int = 3,
    min_samples_leaf: int = 50,
    random_state: int = 42,
    logger: Optional[logging.Logger] = None
) -> Dict[str, float]:
    """
    Estimate how well position can be decoded from population activity.

    A random forest regressor is fitted on one part of the samples and
    evaluated on the held-out part.

    Parameters
    ----------
    neural_activity : np.ndarray, shape (n_neurons, n_samples)
        Neural activity matrix
    positions : np.ndarray, shape (n_samples, 2)
        True X, Y positions
    test_size : float
        Fraction of data for testing
    n_estimators : int
        Number of trees in random forest
    max_depth : int
        Maximum tree depth
    min_samples_leaf : int
        Minimum samples per leaf
    random_state : int
        Seed used for both the train/test split and the forest
    logger : logging.Logger, optional
        Logger for progress messages

    Returns
    -------
    metrics : dict
        Decoding accuracy metrics:
        - r2_x: R² score for X position, clipped at 0
        - r2_y: R² score for Y position, clipped at 0
        - r2_avg: Mean of the unclipped X and Y scores, clipped at 0
        - mse: Mean squared error over both coordinates
    """
    if logger:
        logger.info(f"Computing spatial decoding with {neural_activity.shape[0]} neurons")

    # sklearn expects samples along the first axis: (n_samples, n_neurons)
    features = neural_activity.T
    targets = positions

    # Hold out part of the samples for evaluation
    feat_train, feat_test, pos_train, pos_test = train_test_split(
        features, targets, test_size=test_size, random_state=random_state
    )

    # Fit the decoder on the training part and predict the held-out part
    forest = RandomForestRegressor(
        n_estimators=n_estimators,
        max_depth=max_depth,
        min_samples_leaf=min_samples_leaf,
        random_state=random_state,
        n_jobs=-1
    )
    forest.fit(feat_train, pos_train)
    pos_pred = forest.predict(feat_test)

    # Per-coordinate goodness of fit and overall squared error
    r2_x = r2_score(pos_test[:, 0], pos_pred[:, 0])
    r2_y = r2_score(pos_test[:, 1], pos_pred[:, 1])
    residual = pos_test - pos_pred

    # Negative R² values are reported as 0
    metrics = {
        'r2_x': max(0.0, r2_x),
        'r2_y': max(0.0, r2_y),
        'r2_avg': max(0.0, (r2_x + r2_y) / 2),
        'mse': np.mean(residual**2)
    }

    if logger:
        logger.info(f"Decoding accuracy: R²_avg = {metrics['r2_avg']:.3f}")

    return metrics

416 

417 

def compute_spatial_information(
    neural_activity: Union[np.ndarray, TimeSeries, MultiTimeSeries],
    positions: Union[np.ndarray, TimeSeries, MultiTimeSeries],
    logger: Optional[logging.Logger] = None
) -> Dict[str, float]:
    """
    Compute mutual information between neural activity and spatial position.

    Parameters
    ----------
    neural_activity : array-like or TimeSeries
        Neural activity data. A 1D array is wrapped as a single continuous
        TimeSeries; a 2D array is wrapped row by row into a MultiTimeSeries.
        Anything else is passed on unchanged
    positions : array-like or TimeSeries
        Spatial position data (X, Y): an (n_samples, 2) array or a
        MultiTimeSeries
    logger : logging.Logger, optional
        Logger for debugging

    Returns
    -------
    metrics : dict
        Spatial information metrics:
        - mi_x: MI with X position
        - mi_y: MI with Y position
        - mi_total: MI with 2D position
        All three are 0.0 if the estimation fails; the failure is reported
        as a warning

    Raises
    ------
    ValueError
        If positions is an array that is not (n_samples, 2), or is neither
        an array nor a MultiTimeSeries
    """
    # Convert to TimeSeries if needed
    if isinstance(neural_activity, np.ndarray):
        if neural_activity.ndim == 1:
            neural_ts = TimeSeries(neural_activity, discrete=False)
        else:
            # Create MultiTimeSeries from multiple neurons
            neural_ts_list = [TimeSeries(neural_activity[i], discrete=False)
                              for i in range(neural_activity.shape[0])]
            neural_ts = MultiTimeSeries(neural_ts_list)
    else:
        neural_ts = neural_activity

    if isinstance(positions, np.ndarray):
        # Check ndim first so that 1-D input raises the intended ValueError
        # instead of an IndexError on shape[1].
        if positions.ndim != 2 or positions.shape[1] != 2:
            raise ValueError("Positions must be 2D (X, Y)")
        x_ts = TimeSeries(positions[:, 0], discrete=False)
        y_ts = TimeSeries(positions[:, 1], discrete=False)
        pos_2d_ts = MultiTimeSeries([x_ts, y_ts])
    elif isinstance(positions, MultiTimeSeries):
        # NOTE(review): assumes positions.data[0] / positions.data[1] are
        # objects get_sim accepts for X and Y - confirm against
        # MultiTimeSeries.
        x_ts = positions.data[0]
        y_ts = positions.data[1]
        pos_2d_ts = positions
    else:
        raise ValueError("Positions must be 2D")

    # Compute mutual information
    try:
        mi_x = get_sim(neural_ts, x_ts, metric='mi', estimator='gcmi')
        mi_y = get_sim(neural_ts, y_ts, metric='mi', estimator='gcmi')
        mi_total = get_sim(neural_ts, pos_2d_ts, metric='mi', estimator='gcmi')

        metrics = {
            'mi_x': mi_x,
            'mi_y': mi_y,
            'mi_total': mi_total
        }

        if logger:
            logger.info(f"Spatial MI: X={mi_x:.3f}, Y={mi_y:.3f}, Total={mi_total:.3f}")

    except Exception as e:
        # Deliberate best-effort: a failed estimate must not abort the
        # caller's pipeline. It must not be silent either, because all-zero
        # metrics look exactly like a genuine "no information" result, so
        # fall back to the module logger when none was passed in.
        failure_log = logger if logger else logging.getLogger(__name__)
        failure_log.warning(f"MI calculation failed: {e}")
        metrics = {'mi_x': 0.0, 'mi_y': 0.0, 'mi_total': 0.0}

    return metrics

491 

492 

def filter_by_speed(
    data: Dict[str, np.ndarray],
    speed_range: Tuple[float, float] = (0.05, float('inf')),
    position_key: str = 'positions',
    smooth_window: int = 5
) -> Dict[str, np.ndarray]:
    """
    Keep only the samples recorded while moving within a speed range.

    Speed is the displacement between consecutive samples (position units
    per sample); the first sample, which has no predecessor, gets speed 0.

    Parameters
    ----------
    data : dict
        Dictionary holding the position array under ``position_key``
    speed_range : tuple
        (min_speed, max_speed) to include, both bounds inclusive
    position_key : str
        Key for position data in dictionary
    smooth_window : int
        Window size (in samples) of the moving average applied to the
        speed. Values of 1 or less disable smoothing

    Returns
    -------
    filtered_data : dict
        Copy of the dictionary in which every NumPy array with one entry
        per sample is restricted to the selected samples; all other values
        are passed through unchanged. The selected speeds are stored under
        the 'speed' key
    """
    trajectory = data[position_key]

    # Displacement length between consecutive samples
    steps = np.diff(trajectory, axis=0)
    step_length = np.sqrt(np.sum(steps**2, axis=1))

    # Moving-average smoothing of the speed trace
    if smooth_window > 1:
        step_length = ndimage.uniform_filter1d(step_length, size=smooth_window)

    # The first sample has no preceding step, so it is given zero speed
    speed = np.concatenate([[0], step_length])

    # Samples inside the requested range
    keep = (speed >= speed_range[0]) & (speed <= speed_range[1])
    n_samples = len(keep)

    # Restrict every per-sample array; leave everything else untouched
    filtered_data = {
        key: value[keep]
        if isinstance(value, np.ndarray) and len(value) == n_samples
        else value
        for key, value in data.items()
    }
    filtered_data['speed'] = speed[keep]

    return filtered_data

545 

546 

def filter_by_direction(
    data: Dict[str, np.ndarray],
    direction_range: Tuple[float, float],
    position_key: str = 'positions',
    smooth_window: int = 5
) -> Dict[str, np.ndarray]:
    """
    Filter data to include only periods of specific movement directions.

    Parameters
    ----------
    data : dict
        Dictionary holding the position array under ``position_key``
    direction_range : tuple
        (min_angle, max_angle) in radians, where 0 = east, π/2 = north.
        Angles are compared against headings in [-π, π]. If
        min_angle > max_angle the range is taken to wrap around ±π
        (e.g. (3π/4, -3π/4) selects westward movement)
    position_key : str
        Key for position data
    smooth_window : int
        Window size (in samples) of the moving average applied to the
        velocity components before the heading is computed. Values of 1 or
        less disable smoothing

    Returns
    -------
    filtered_data : dict
        Copy of the dictionary in which every NumPy array with one entry
        per sample is restricted to the selected samples; all other values
        are passed through unchanged. The selected headings are stored
        under the 'direction' key
    """
    # Float positions keep the moving average below from being truncated
    # for integer-valued coordinates.
    positions = np.asarray(data[position_key], dtype=float)

    # Compute movement direction
    velocity = np.diff(positions, axis=0)

    # Smooth the velocity vectors rather than the angles: averaging angles
    # directly breaks at the ±π discontinuity, averaging vectors does not.
    if smooth_window > 1 and len(velocity) > 0:
        velocity = ndimage.uniform_filter1d(velocity, size=smooth_window, axis=0)

    direction = np.arctan2(velocity[:, 1], velocity[:, 0])

    # The first sample has no preceding step; it is assigned direction 0.
    # (Steps with zero displacement also evaluate to 0 via arctan2(0, 0).)
    direction = np.concatenate([[0], direction])

    # Handle angle wrapping
    min_dir, max_dir = direction_range
    if min_dir <= max_dir:
        mask = (direction >= min_dir) & (direction <= max_dir)
    else:
        # Wrapped range crossing ±π (e.g., 3π/4 to -3π/4)
        mask = (direction >= min_dir) | (direction <= max_dir)

    # Filter all arrays that have one entry per sample
    filtered_data = {}
    for key, value in data.items():
        if isinstance(value, np.ndarray) and len(value) == len(mask):
            filtered_data[key] = value[mask]
        else:
            filtered_data[key] = value

    filtered_data['direction'] = direction[mask]

    return filtered_data

600 

601 

def analyze_spatial_coding(
    neural_activity: np.ndarray,
    positions: np.ndarray,
    arena_bounds: Optional[Tuple[Tuple[float, float], Tuple[float, float]]] = None,
    bin_size: float = 0.025,
    min_peak_rate: float = 1.0,
    speed_range: Optional[Tuple[float, float]] = (0.05, float('inf')),
    peak_to_mean_ratio: float = 1.5,
    min_field_size: int = 9,
    logger: Optional[logging.Logger] = None
) -> Dict[str, Union[np.ndarray, List, Dict, float]]:
    """
    Run the full spatial coding analysis on a population of neurons.

    Steps: optional speed filtering, occupancy map, then per neuron a rate
    map, place fields, spatial information and grid score, followed by
    population-level position decoding and mutual information.

    Parameters
    ----------
    neural_activity : np.ndarray, shape (n_neurons, n_samples)
        Neural activity matrix
    positions : np.ndarray, shape (n_samples, 2)
        Position data
    arena_bounds : tuple, optional
        Arena boundaries, forwarded to compute_occupancy_map
    bin_size : float
        Spatial bin size
    min_peak_rate : float
        Minimum peak rate for place fields
    speed_range : tuple, optional
        Speed filter range. None disables speed filtering
    peak_to_mean_ratio : float
        Minimum ratio of peak to mean rate in field
    min_field_size : int
        Minimum number of contiguous bins for valid field
    logger : logging.Logger, optional
        Logger for progress

    Returns
    -------
    results : dict
        Comprehensive spatial analysis results:
        - rate_maps: List of rate maps per neuron
        - place_fields: List of place fields per neuron
        - spatial_info: Spatial information per neuron
        - grid_scores: Grid scores per neuron
        - decoding_accuracy: Position decoding metrics
        - spatial_mi: Mutual information metrics
        - summary: Counts of place cells (at least one field) and grid
          cells (grid score above 0.3) plus mean spatial information and
          mean grid score
    """
    if logger:
        logger.info(f"Analyzing spatial coding for {neural_activity.shape[0]} neurons")

    # Optional restriction to samples within the speed range. The activity
    # matrix is transposed so samples lie along the first axis for filtering.
    if speed_range is not None:
        moving = filter_by_speed(
            {'positions': positions, 'neural_activity': neural_activity.T},
            speed_range
        )
        positions = moving['positions']
        neural_activity = moving['neural_activity'].T

    # Occupancy is shared by all neurons
    occupancy_map, x_edges, y_edges = compute_occupancy_map(
        positions, arena_bounds, bin_size
    )

    rate_maps = []
    place_fields = []
    spatial_info = []
    grid_scores = []

    # Per-neuron analyses
    for trace in neural_activity:
        neuron_map = compute_rate_map(
            trace, positions, occupancy_map, x_edges, y_edges
        )
        rate_maps.append(neuron_map)

        place_fields.append(extract_place_fields(
            neuron_map,
            min_peak_rate=min_peak_rate,
            min_field_size=min_field_size,
            peak_to_mean_ratio=peak_to_mean_ratio
        ))

        spatial_info.append(
            compute_spatial_information_rate(neuron_map, occupancy_map)
        )

        grid_scores.append(compute_grid_score(neuron_map))

    results = {
        'rate_maps': rate_maps,
        'place_fields': place_fields,
        'spatial_info': spatial_info,
        'grid_scores': grid_scores
    }

    # Population-level analyses
    results['decoding_accuracy'] = compute_spatial_decoding_accuracy(
        neural_activity, positions, logger=logger
    )
    results['spatial_mi'] = compute_spatial_information(
        neural_activity, positions, logger=logger
    )

    # Summary statistics
    n_place_cells = sum(len(pf) > 0 for pf in place_fields)
    n_grid_cells = sum(gs > 0.3 for gs in grid_scores)
    results['summary'] = {
        'n_place_cells': n_place_cells,
        'n_grid_cells': n_grid_cells,
        'mean_spatial_info': np.mean(spatial_info),
        'mean_grid_score': np.mean(grid_scores)
    }

    if logger:
        logger.info(f"Found {results['summary']['n_place_cells']} place cells, "
                    f"{results['summary']['n_grid_cells']} grid cells")

    return results

721 

722 

def _kwargs_accepted_by(func, kwargs):
    """Return the entries of ``kwargs`` that ``func`` accepts as keyword arguments."""
    import inspect  # local import: only needed for keyword routing

    parameters = inspect.signature(func).parameters
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in parameters.values()):
        return dict(kwargs)
    keyword_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    return {
        name: value
        for name, value in kwargs.items()
        if name in parameters and parameters[name].kind in keyword_kinds
    }


def compute_spatial_metrics(
    neural_activity: np.ndarray,
    positions: np.ndarray,
    metrics: Optional[List[str]] = None,
    **kwargs
) -> Dict[str, Union[float, Dict]]:
    """
    Compute selected spatial metrics.

    Parameters
    ----------
    neural_activity : np.ndarray
        Neural activity data
    positions : np.ndarray
        Position data
    metrics : list of str, optional
        Metrics to compute. If None, computes all.
        Options: 'decoding', 'information', 'place_fields', 'grid_scores'.
        Names outside this list are ignored
    **kwargs
        Additional arguments for the analysis functions. Each argument is
        forwarded only to the functions whose signature accepts it, so
        function-specific options (e.g. ``n_estimators`` for decoding or
        ``bin_size`` for the place-field analysis) can be combined in one
        call

    Returns
    -------
    results : dict
        Computed metrics, keyed by metric name. 'place_fields' and
        'grid_scores' additionally add 'n_place_cells' and 'n_grid_cells'

    Raises
    ------
    TypeError
        If a keyword argument is accepted by none of the analysis functions
        selected through ``metrics``
    """
    if metrics is None:
        metrics = ['decoding', 'information', 'place_fields', 'grid_scores']

    want_decoding = 'decoding' in metrics
    want_information = 'information' in metrics
    want_place_fields = 'place_fields' in metrics
    want_grid_scores = 'grid_scores' in metrics

    # Functions that will actually be called for this selection
    callees = []
    if want_decoding:
        callees.append(compute_spatial_decoding_accuracy)
    if want_information:
        callees.append(compute_spatial_information)
    if want_place_fields or want_grid_scores:
        callees.append(analyze_spatial_coding)

    # Route every keyword argument to the functions that understand it. The
    # callees have different signatures, so passing everything to all of
    # them would raise TypeError for any function-specific option.
    routed = {func: _kwargs_accepted_by(func, kwargs) for func in callees}
    if callees:
        unknown = set(kwargs).difference(*routed.values())
        if unknown:
            raise TypeError(
                "compute_spatial_metrics() got unexpected keyword argument(s): "
                + ", ".join(sorted(unknown))
            )

    results = {}

    if want_decoding:
        results['decoding'] = compute_spatial_decoding_accuracy(
            neural_activity, positions,
            **routed[compute_spatial_decoding_accuracy]
        )

    if want_information:
        results['information'] = compute_spatial_information(
            neural_activity, positions,
            **routed[compute_spatial_information]
        )

    if want_place_fields or want_grid_scores:
        # Place fields and grid scores both come from the full pipeline
        analysis = analyze_spatial_coding(
            neural_activity, positions,
            **routed[analyze_spatial_coding]
        )

        if want_place_fields:
            results['place_fields'] = analysis['place_fields']
            results['n_place_cells'] = analysis['summary']['n_place_cells']

        if want_grid_scores:
            results['grid_scores'] = analysis['grid_scores']
            results['n_grid_cells'] = analysis['summary']['n_grid_cells']

    return results