Coverage for copick_torch/copick.py: 5%

671 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-16 16:14 -0700

1import logging 

2import os 

3import pickle 

4import random 

5from collections import Counter 

6from datetime import datetime 

7from pathlib import Path 

8from typing import Any, Dict, List, Optional, Tuple, Union 

9 

10import copick 

11import numpy as np 

12import pandas as pd 

13import torch 

14import zarr 

15from scipy.ndimage import gaussian_filter 

16from torch.utils.data import ConcatDataset, Dataset, Subset 

17 

18 

19class CopickDataset(Dataset): 

20 """ 

21 A PyTorch dataset for working with copick data for particle picking tasks. 

22 

23 This implementation focuses on extracting subvolumes around pick coordinates 

24 with support for data augmentation, caching, and class balancing. 

25 """ 

26 

def __init__(
    self,
    config_path: Union[str, Any] = None,
    copick_root: Optional[Any] = None,
    boxsize: Tuple[int, int, int] = (32, 32, 32),
    augment: bool = False,
    cache_dir: Optional[str] = None,
    cache_format: str = "parquet",  # Can be "pickle" or "parquet"
    seed: Optional[int] = 1717,
    max_samples: Optional[int] = None,
    voxel_spacing: float = 10.0,
    include_background: bool = False,
    background_ratio: float = 0.2,  # Background samples as proportion of particle samples
    min_background_distance: Optional[float] = None,  # Min distance in voxels from particles
    patch_strategy: str = "centered",  # Can be "centered", "random", or "jittered"
    augmentations: Optional[List[str]] = None,  # List of augmentation types to apply
    augmentation_prob: float = 0.2,  # Probability of applying each augmentation
    mixup_alpha: Optional[float] = None,  # Alpha parameter for mixup augmentation
    rotate_axes: Tuple[int, int, int] = (1, 1, 1),  # Per-array-axis flags used to pick 90-degree rotation planes
    debug_mode: bool = False,
):
    """Configure the dataset, seed the RNGs, then load or build the samples.

    Args:
        config_path: Path to a copick configuration, opened with
            ``copick.from_file`` when ``copick_root`` is not given.
        copick_root: An already-opened copick root; used in preference to
            ``config_path``.
        boxsize: Shape of every extracted subvolume.
        augment: Apply random augmentations in ``__getitem__``.
        cache_dir: Directory for the on-disk cache; ``None`` disables caching.
        cache_format: ``"pickle"`` or ``"parquet"`` (case-insensitive).
        seed: Seed for the Python, NumPy and PyTorch RNGs; ``None`` leaves
            them untouched.
        max_samples: If set, keep a random subset of at most this many samples.
        voxel_spacing: Voxel spacing whose tomograms are used; pick
            coordinates are divided by it to obtain voxel positions.
        include_background: Also sample background patches (label -1).
        background_ratio: Background samples per particle sample, per run.
        min_background_distance: Minimum distance (voxels) between a
            background centre and any particle; defaults to ``max(boxsize)``.
        patch_strategy: ``"centered"``, ``"random"`` or ``"jittered"`` shift
            applied to each patch centre.
        augmentations: Augmentation names to use. Defaults to brightness,
            blur, intensity, flip and rotate; pass ``[]`` for none.
        augmentation_prob: Probability of applying each augmentation.
        mixup_alpha: Beta-distribution parameter for mixup. Mixup is only used
            when this is set and ``"mixup"`` is in ``augmentations``.
        rotate_axes: Flags, indexed by array axis, selecting which axes may be
            paired for 90-degree rotations.
        debug_mode: Record the applied augmentations in the label dict.

    Raises:
        ValueError: If neither ``config_path`` nor ``copick_root`` is given, or
            if an augmentation name, ``cache_format`` or ``patch_strategy`` is
            invalid.
    """
    if config_path is None and copick_root is None:
        raise ValueError("Either config_path or copick_root must be provided")

    self.config_path = config_path
    self.copick_root = copick_root
    # Keep boxsize a tuple so `subvolume.shape != self.boxsize` compares like with
    # like even when the caller passes a list.
    self.boxsize = tuple(boxsize)
    self.augment = augment
    self.cache_dir = cache_dir
    self.cache_format = cache_format.lower()
    self.seed = seed
    self.max_samples = max_samples
    self.voxel_spacing = voxel_spacing
    self.include_background = include_background
    self.background_ratio = background_ratio
    # Explicit None check: `value or default` would silently replace a requested 0.
    self.min_background_distance = (
        max(self.boxsize) if min_background_distance is None else min_background_distance
    )
    self.patch_strategy = patch_strategy
    self.debug_mode = debug_mode

    # Augmentation settings
    self.augmentation_prob = augmentation_prob
    self.mixup_alpha = mixup_alpha
    self.rotate_axes = rotate_axes

    self.default_augmentations = ["brightness", "blur", "intensity", "flip", "rotate"]
    # Copy so later mutation of self.augmentations cannot alter the defaults or the
    # caller's list. An explicit empty list means "no augmentations".
    self.augmentations = list(self.default_augmentations if augmentations is None else augmentations)

    # Augmentations that need additional handling
    self.special_augmentations = ["mixup", "rotate_z"]

    valid_augmentations = self.default_augmentations + self.special_augmentations
    for aug in self.augmentations:
        if aug not in valid_augmentations:
            raise ValueError(f"Unknown augmentation type: {aug}. Valid options are: {valid_augmentations}")

    if self.cache_format not in ["pickle", "parquet"]:
        raise ValueError("cache_format must be either 'pickle' or 'parquet'")

    if self.patch_strategy not in ["centered", "random", "jittered"]:
        raise ValueError("patch_strategy must be one of 'centered', 'random', or 'jittered'")

    # Initialize dataset
    self._set_random_seed()
    self._subvolumes = []
    self._molecule_ids = []
    self._keys = []
    self._is_background = []
    self._load_or_process_data()
    self._compute_sample_weights()

100 

def _set_random_seed(self):
    """Seed Python's, NumPy's and PyTorch's global RNGs with ``self.seed``.

    A ``None`` seed leaves every generator untouched. The CUDA generator is
    only seeded when a CUDA device is available.
    """
    seed_value = self.seed
    if seed_value is None:
        return
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed_value)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed_value)

109 

def _compute_sample_weights(self):
    """Set ``self.sample_weights`` to inverse-class-frequency weights.

    Every sample receives ``N / count(its class)``, where N is the number of
    samples. Rarer labels, including the background label, therefore get
    larger weights.
    """
    labels = self._molecule_ids
    n_total = len(labels)
    weight_of = {}
    for label, count in Counter(labels).items():
        weight_of[label] = n_total / count
    self.sample_weights = [weight_of[label] for label in labels]

121 

def _get_cache_path(self):
    """Return the path of the cache file for the current configuration.

    The file name combines:
    - a source key (the config path, or the copick root's dataset IDs);
    - the box size and the voxel spacing;
    - a ``_with_bg`` marker when background samples are included;
    - the strategy name, for non-centered patch strategies.

    The extension is ``.pkl`` for the pickle format and ``.parquet``
    otherwise. The file is always placed directly inside ``self.cache_dir``.
    """
    import hashlib

    cache_key = self.config_path
    if cache_key is None and self.copick_root is not None:
        try:
            dataset_ids = [str(dataset.id) for dataset in self.copick_root.datasets if hasattr(dataset, "id")]
            cache_key = f"datasets_{'_'.join(dataset_ids)}" if dataset_ids else "copick_root_unknown"
        except (AttributeError, IndexError, TypeError):
            # `datasets` missing or not iterable: try the plain ID list instead.
            if hasattr(self.copick_root, "dataset_ids"):
                dataset_ids = [str(did) for did in self.copick_root.dataset_ids]
                cache_key = f"datasets_{'_'.join(dataset_ids)}"
            else:
                # Built-in hash() of a str is salted per process (PYTHONHASHSEED), so it
                # produced a new file name on every run. Use a stable digest instead.
                digest = hashlib.sha256(str(self.copick_root).encode("utf-8")).hexdigest()[:16]
                cache_key = f"copick_root_{digest}"

    # A config path may be absolute or contain directories. os.path.join discards
    # cache_dir when the second part is absolute, and a relative "dir/file" would
    # point into a subdirectory that was never created. Flatten separators so the
    # key is a plain file name.
    cache_key = str(cache_key).translate(str.maketrans({"/": "_", "\\": "_", ":": "_"}))

    file_name = f"{cache_key}_{self.boxsize[0]}x{self.boxsize[1]}x{self.boxsize[2]}_{self.voxel_spacing}"
    if self.include_background:
        file_name += "_with_bg"
    # Random/jittered offsets are baked into cached patches, so such caches must not
    # be shared with centered ones. "centered" keeps the historical file name.
    if self.patch_strategy != "centered":
        file_name += f"_{self.patch_strategy}"
    extension = ".pkl" if self.cache_format == "pickle" else ".parquet"
    return os.path.join(self.cache_dir, file_name + extension)

164 

def _load_or_process_data(self):
    """Populate the dataset from the on-disk cache, or build it and write the cache.

    Without ``self.cache_dir`` the data is always processed from the copick
    project. With a cache directory, an existing cache file is loaded and
    randomly subsampled to ``self.max_samples`` if needed. Otherwise the data
    is processed and, if any samples were produced, written to the cache in
    ``self.cache_format``.
    """
    if self.cache_dir is None:
        print("Cache directory not specified. Processing data without caching...")
        self._load_data()
        return

    os.makedirs(self.cache_dir, exist_ok=True)
    cache_file = self._get_cache_path()

    if os.path.exists(cache_file):
        print(f"Loading cached data from {cache_file}")
        if self.cache_format == "pickle":
            self._load_from_pickle(cache_file)
        else:  # parquet
            self._load_from_parquet(cache_file)

        # Apply max_samples limit if specified
        if self.max_samples is not None and len(self._subvolumes) > self.max_samples:
            indices = np.random.choice(len(self._subvolumes), self.max_samples, replace=False)
            self._subvolumes = np.array(self._subvolumes)[indices]
            self._molecule_ids = np.array(self._molecule_ids)[indices]
            # The flags may be a list or a numpy array (a pickle stores whatever
            # _load_data produced). len() works for both, whereas truth-testing a
            # multi-element array raises ValueError.
            if len(self._is_background) > 0:
                self._is_background = np.array(self._is_background)[indices]
        return

    print("Processing data and creating cache...")
    self._load_data()

    # Only save to cache if we actually loaded some data
    if len(self._subvolumes) == 0:
        print("No data loaded, skipping cache creation")
        return
    if self.cache_format == "pickle":
        self._save_to_pickle(cache_file)
    else:  # parquet
        self._save_to_parquet(cache_file)
    print(f"Cached data saved to {cache_file}")

205 

def _load_from_pickle(self, cache_file):
    """Load subvolumes, labels, class keys and background flags from a pickle cache.

    Missing entries default to empty lists. If no background flags were saved
    but background samples are requested, every sample is marked as
    non-background.

    NOTE: ``pickle.load`` can execute arbitrary code. Only load cache files
    written by this class, never files from an untrusted source.
    """
    with open(cache_file, "rb") as f:
        cached_data = pickle.load(f)
    self._subvolumes = cached_data.get("subvolumes", [])
    self._molecule_ids = cached_data.get("molecule_ids", [])
    self._keys = cached_data.get("keys", [])
    self._is_background = cached_data.get("is_background", [])

    # Use len(): saved flags are a numpy array when produced by _load_data, and
    # `not array` raises ValueError for multi-element arrays.
    if len(self._is_background) == 0 and self.include_background:
        self._is_background = [False] * len(self._subvolumes)

219 

def _save_to_pickle(self, cache_file):
    """Write subvolumes, labels, class keys and background flags to ``cache_file``.

    The data is stored as a single pickled dict with the keys
    ``subvolumes``, ``molecule_ids``, ``keys`` and ``is_background``.
    """
    payload = dict(
        subvolumes=self._subvolumes,
        molecule_ids=self._molecule_ids,
        keys=self._keys,
        is_background=self._is_background,
    )
    with open(cache_file, "wb") as handle:
        pickle.dump(payload, handle)

232 

def _load_from_parquet(self, cache_file):
    """Load the dataset from a parquet cache written by ``_save_to_parquet``.

    Each row is one sample. It has the raw ``subvolume`` bytes, their
    ``shape`` and the ``molecule_id``. When present, it also has
    ``is_background``, ``key`` (the sample's class name, or "background") and
    ``dtype``. Files without a ``dtype`` column are decoded as float32, as
    before.

    Raises:
        ValueError: If a subvolume cell does not hold raw bytes.
    """
    try:
        df = pd.read_parquet(cache_file)

        # Older caches did not store the dtype; float32 was assumed for them.
        dtypes = df["dtype"].tolist() if "dtype" in df.columns else ["float32"] * len(df)
        subvolumes = []
        for raw, shape, dtype in zip(df["subvolume"], df["shape"], dtypes):
            if not isinstance(raw, bytes):
                raise ValueError(f"Invalid subvolume format: {type(raw)}")
            # The shape cell may come back as a list or an array; normalise it.
            shape = tuple(int(dim) for dim in shape)
            subvolumes.append(np.frombuffer(raw, dtype=np.dtype(dtype)).reshape(shape))
        self._subvolumes = np.array(subvolumes)

        self._molecule_ids = df["molecule_id"].tolist()

        if "is_background" in df.columns:
            self._is_background = df["is_background"].tolist()
        else:
            self._is_background = [False] * len(self._subvolumes)

        # The "key" column holds one class name per ROW (as written by
        # _save_to_parquet), not the class list. Rebuild the list so that
        # keys[molecule_id] is the sample's class name, as after _load_data.
        self._keys = []
        if "key" in df.columns:
            name_of = {}
            for mol_id, key in zip(self._molecule_ids, df["key"]):
                if mol_id >= 0:
                    name_of.setdefault(int(mol_id), key)
            if name_of:
                # A class with no surviving row (e.g. dropped by max_samples before
                # saving) gets a placeholder name so that indices stay aligned.
                self._keys = [name_of.get(i, str(i)) for i in range(max(name_of) + 1)]

        if not self._keys:
            # No usable names: fall back to the sorted particle class indices.
            self._keys = sorted(
                {int(m) for m, bg in zip(self._molecule_ids, self._is_background) if m != -1 and not bg},
            )

    except Exception as e:
        print(f"Error loading from parquet: {str(e)}")
        raise

276 

def _save_to_parquet(self, cache_file):
    """Save the dataset to ``cache_file`` as parquet, plus a one-row metadata file.

    Each row stores one sample:
    - the raw subvolume bytes with their ``shape`` and ``dtype``;
    - the ``molecule_id`` and the ``is_background`` flag;
    - the class name in ``key`` ("background" for ids outside ``self._keys``).

    The metadata is written next to the data file as
    ``<name>_metadata.parquet``. Errors are printed and re-raised.
    """
    try:
        # Check if we have any data to save
        if len(self._subvolumes) == 0:
            print("No data to save to parquet")
            return

        # Without flags zip() would yield no rows and an empty table would be
        # written. Missing flags mean "not background".
        flags = self._is_background if len(self._is_background) > 0 else [False] * len(self._subvolumes)

        records = [
            {
                "subvolume": subvol.tobytes(),
                "shape": list(subvol.shape),
                # Stored so the loader can decode non-float32 volumes correctly.
                "dtype": str(subvol.dtype),
                "molecule_id": mol_id,
                "is_background": is_bg,
            }
            for subvol, mol_id, is_bg in zip(self._subvolumes, self._molecule_ids, flags)
        ]

        df = pd.DataFrame(records)
        num_keys = len(self._keys)
        df["key"] = df["molecule_id"].apply(
            lambda x: self._keys[x] if 0 <= x < num_keys else "background",
        )
        df.to_parquet(cache_file, index=False)

        # Replace only a trailing ".parquet". str.replace would also rewrite a
        # ".parquet" inside a directory name, and would leave a name without the
        # suffix unchanged, so the metadata would overwrite the data file.
        stem = cache_file[: -len(".parquet")] if cache_file.endswith(".parquet") else cache_file
        metadata_file = f"{stem}_metadata.parquet"
        metadata = {
            "creation_date": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "total_samples": len(records),
            "unique_molecules": len(self._keys),
            "boxsize": list(self.boxsize),
            "include_background": self.include_background,
            "background_samples": int(sum(bool(flag) for flag in flags)),
        }
        pd.DataFrame([metadata]).to_parquet(metadata_file, index=False)

    except Exception as e:
        print(f"Error saving to parquet: {str(e)}")
        raise

326 

def _extract_subvolume_with_validation(self, tomogram_array, x, y, z):
    """Extract a ``self.boxsize`` patch around (x, y, z), zero-padding at the borders.

    ``tomogram_array`` is indexed ``[z, y, x]`` and the returned patch has
    shape ``self.boxsize``, so boxsize is read in (z, y, x) order. Before
    extraction the centre is shifted according to ``self.patch_strategy``:
    - "centered": no shift;
    - "random": up to a quarter of the half box per axis;
    - "jittered": up to 2 voxels per axis.

    Args:
        tomogram_array: 3D volume indexed [z, y, x].
        x, y, z: Patch centre in voxel coordinates (may be fractional).

    Returns:
        ``(patch, True, "valid")`` when the patch lies inside the tomogram.
        ``(patch, True, "padded")`` when part of it had to be zero-filled.
        ``(None, False, "Invalid slice range")`` when it does not overlap the
        tomogram at all.
    """
    box = tuple(int(size) for size in self.boxsize)  # (z, y, x)
    half_z, half_y, half_x = (size // 2 for size in box)

    if self.patch_strategy == "centered":
        offset_x, offset_y, offset_z = 0, 0, 0
    elif self.patch_strategy == "random":
        # Random offsets within a quarter of the half box along each axis.
        max_x, max_y, max_z = half_x // 4, half_y // 4, half_z // 4
        offset_x = np.random.randint(-max_x, max_x + 1)
        offset_y = np.random.randint(-max_y, max_y + 1)
        offset_z = np.random.randint(-max_z, max_z + 1)
    elif self.patch_strategy == "jittered":
        # Small random jitter for data augmentation
        offset_x = np.random.randint(-2, 3)
        offset_y = np.random.randint(-2, 3)
        offset_z = np.random.randint(-2, 3)
    else:
        raise ValueError(f"Unknown patch_strategy: {self.patch_strategy}")

    # Unclipped start of the patch along (z, y, x). The patch spans
    # [start, start + size), so odd box sizes get their full extent. For
    # non-negative centres and even sizes this is the same region as before.
    raw_starts = (
        int(np.floor(z + offset_z)) - half_z,
        int(np.floor(y + offset_y)) - half_y,
        int(np.floor(x + offset_x)) - half_x,
    )
    source, target = [], []
    for start, size, extent in zip(raw_starts, box, tomogram_array.shape):
        lo, hi = max(0, start), min(extent, start + size)
        if hi <= lo:
            return None, False, "Invalid slice range"
        source.append(slice(lo, hi))
        # Where the in-bounds part sits inside the patch, measured from the
        # SHIFTED centre. The old code used the unshifted centre, so random or
        # jittered patches near a border were mis-padded or failed to broadcast.
        target.append(slice(lo - start, hi - start))

    subvolume = tomogram_array[tuple(source)]
    if subvolume.shape == box:
        # Copy so the patch does not keep the whole tomogram alive through a view.
        return np.array(subvolume), True, "valid"

    padded = np.zeros(box, dtype=subvolume.dtype)
    padded[tuple(target)] = subvolume
    return padded, True, "padded"

387 

def _load_data(self):
    """Build the dataset in memory by walking every run of the copick project.

    For each run, the first tomogram at ``self.voxel_spacing`` is used.
    Pick sets are only kept when ``picks.from_tool`` is set. Around each of
    their points (pick coordinates divided by the voxel spacing) a subvolume
    is extracted. Class names are added to ``self._keys`` in first-seen
    order, and each sample is labelled with its class's index there. When
    ``self.include_background`` is set, background samples are drawn per run.

    Finally the per-sample lists become numpy arrays and, when
    ``self.max_samples`` is exceeded, a random subset is kept. Failures for
    individual runs, pick sets or points are printed and skipped.
    """
    if self.copick_root is None:
        try:
            root = copick.from_file(self.config_path)
            print(f"Loading data from {self.config_path}")
        except Exception as e:
            print(f"Failed to load copick root: {str(e)}")
            return
    else:
        root = self.copick_root
        print("Using provided copick root object")

    for run in root.runs:
        print(f"Processing run: {run.name}")

        try:
            spacing = run.get_voxel_spacing(self.voxel_spacing)
            if spacing is None or not hasattr(spacing, "tomograms") or not spacing.tomograms:
                print(f"No tomograms found for run {run.name} at voxel spacing {self.voxel_spacing}")
                continue
            volume = spacing.tomograms[0].numpy()
        except Exception as e:
            print(f"Error loading tomogram for run {run.name}: {str(e)}")
            continue

        run_coords = []  # particle positions of this run, reused for background sampling
        tool_picks = (candidate for candidate in run.get_picks() if candidate.from_tool)
        for picks in tool_picks:
            label_name = picks.pickable_object_name
            try:
                points, _ = picks.numpy()
                scaled = points / self.voxel_spacing
                for point in scaled:
                    try:
                        x, y, z = point
                        run_coords.append((x, y, z))
                        patch, ok, _ = self._extract_subvolume_with_validation(volume, x, y, z)
                        if not ok:
                            continue
                        self._subvolumes.append(patch)
                        if label_name not in self._keys:
                            self._keys.append(label_name)
                        self._molecule_ids.append(self._keys.index(label_name))
                        self._is_background.append(False)
                    except Exception as e:
                        print(f"Error extracting subvolume: {str(e)}")
            except Exception as e:
                print(f"Error processing picks for {label_name}: {str(e)}")

        if self.include_background and run_coords:
            self._sample_background_points(volume, run_coords)

    self._subvolumes = np.array(self._subvolumes)
    self._molecule_ids = np.array(self._molecule_ids)
    self._is_background = np.array(self._is_background)

    limit = self.max_samples
    if limit is not None and len(self._subvolumes) > limit:
        keep = np.random.choice(len(self._subvolumes), limit, replace=False)
        self._subvolumes = self._subvolumes[keep]
        self._molecule_ids = self._molecule_ids[keep]
        self._is_background = self._is_background[keep]

    print(f"Loaded {len(self._subvolumes)} subvolumes with {len(self._keys)} classes")
    print(f"Background samples: {sum(self._is_background)}")

479 

def _sample_background_points(self, tomogram_array, particle_coords):
    """Append background samples (label -1) drawn away from known particles.

    ``int(len(particle_coords) * self.background_ratio)`` centres are drawn
    uniformly inside the tomogram, keeping half a box from each border. A
    centre is kept only if it is at least ``self.min_background_distance``
    voxels from every particle. At most ten draws are made per requested
    sample.

    Args:
        tomogram_array: 3D volume indexed [z, y, x].
        particle_coords: (x, y, z) particle positions in voxels.
    """
    if len(particle_coords) == 0:
        return

    # Convert to numpy array for distance calculations
    particle_coords = np.array(particle_coords)
    num_background = int(len(particle_coords) * self.background_ratio)

    # boxsize is in (z, y, x) order like the tomogram axes. The sampling bounds
    # are listed per (x, y, z) to match the particle coordinates.
    half_z, half_y, half_x = np.array(self.boxsize) // 2
    depth, height, width = tomogram_array.shape[:3]
    low = (half_x, half_y, half_z)
    high = (width - half_x, height - half_y, depth - half_z)
    if num_background > 0 and any(lo >= hi for lo, hi in zip(low, high)):
        # np.random.randint raises on an empty range; skip instead of crashing.
        print(f"Tomogram {tomogram_array.shape} too small for background sampling with box {self.boxsize}")
        return

    # Limit attempts to avoid an infinite loop
    max_attempts = num_background * 10
    attempts = 0
    bg_points_found = 0

    while bg_points_found < num_background and attempts < max_attempts:
        random_point = np.array([np.random.randint(lo, hi) for lo, hi in zip(low, high)])

        # Keep the point only if it is far enough from every particle
        distances = np.linalg.norm(particle_coords - random_point, axis=1)
        if np.min(distances) >= self.min_background_distance:
            x, y, z = random_point
            subvolume, is_valid, _ = self._extract_subvolume_with_validation(tomogram_array, x, y, z)
            if is_valid:
                self._subvolumes.append(subvolume)
                self._molecule_ids.append(-1)  # -1 marks background
                self._is_background.append(True)
                bg_points_found += 1

        attempts += 1

    print(f"Added {bg_points_found} background points after {attempts} attempts")

527 

def _augment_subvolume(self, subvolume, idx=None):
    """Apply the configured random augmentations to ``subvolume``.

    Every name in ``self.augmentations`` is tried in order and applied with
    probability ``self.augmentation_prob``. Mixup is then applied, with the
    same probability, when it is listed, ``self.mixup_alpha`` is set and
    ``idx`` is given.

    Args:
        subvolume: The 3D volume to augment.
        idx: Index of this sample, which mixup needs to pick a different partner.

    Returns:
        The augmented subvolume. In debug mode, the tuple
        ``(subvolume, applied_augmentations, mixup_info)`` instead; the
        recorded parameters are exactly the ones that were applied.
    """
    applied_augmentations = []
    mixup_info = None

    for aug in self.augmentations:
        if random.random() >= self.augmentation_prob:
            continue

        # Each parameter is drawn once, and that same value is applied and
        # recorded. Previously the helper methods drew a second, unrelated value,
        # so the debug record did not describe the transform actually applied.
        if aug == "brightness":
            delta = np.random.uniform(-0.5, 0.5)
            subvolume = subvolume + delta
            record = {"type": "brightness", "delta": float(delta)}
        elif aug == "blur":
            sigma = np.random.uniform(0.5, 1.5)
            subvolume = gaussian_filter(subvolume, sigma=sigma)
            record = {"type": "blur", "sigma": float(sigma)}
        elif aug == "intensity":
            factor = np.random.uniform(0.5, 1.5)
            subvolume = subvolume * factor
            record = {"type": "intensity", "factor": float(factor)}
        elif aug == "flip":
            axis = random.randint(0, 2)
            subvolume = self._flip(subvolume, axis=axis)
            record = {"type": "flip", "axis": int(axis)}
        elif aug == "rotate":
            # Only axes enabled in rotate_axes may form the rotation plane.
            available_axes = [i for i, allowed in enumerate(self.rotate_axes) if allowed]
            if len(available_axes) < 2:
                axis1, axis2 = 0, 1
            else:
                axis1, axis2 = random.sample(available_axes, 2)
            k = random.randint(1, 3)  # 90, 180, or 270 degrees
            subvolume = self._rotate(subvolume, axes=(axis1, axis2), k=k)
            record = {"type": "rotate", "axes": (int(axis1), int(axis2)), "k": int(k)}
        elif aug == "rotate_z":
            angle = np.random.uniform(0, 360)
            subvolume = self._rotate_z(subvolume, angle=angle)
            record = {"type": "rotate_z", "angle": float(angle)}
        else:
            # "mixup" is handled after the loop.
            continue

        if self.debug_mode:
            applied_augmentations.append(record)

    # Apply mixup if enabled and we have an index to work with
    if "mixup" in self.augmentations and self.mixup_alpha is not None and idx is not None:  # noqa: SIM102
        if random.random() < self.augmentation_prob:
            subvolume, mixup_other_idx, mixup_lambda = self._apply_mixup(subvolume, idx)
            if self.debug_mode and mixup_other_idx is not None:
                mixup_info = {"other_idx": int(mixup_other_idx), "lambda": float(mixup_lambda)}

    if self.debug_mode:
        return subvolume, applied_augmentations, mixup_info
    return subvolume

598 

def _brightness(self, volume, max_delta=0.5):
    """Return ``volume`` shifted by one offset drawn uniformly from [-max_delta, max_delta)."""
    shift = np.random.uniform(low=-max_delta, high=max_delta)
    return volume + shift

603 

def _gaussian_blur(self, volume, sigma_range=(0.5, 1.5)):
    """Blur ``volume`` with an isotropic Gaussian whose sigma is drawn uniformly from ``sigma_range``."""
    low, high = sigma_range
    return gaussian_filter(volume, sigma=np.random.uniform(low, high))

608 

def _intensity_scaling(self, volume, intensity_range=(0.5, 1.5)):
    """Multiply ``volume`` by a factor drawn uniformly from ``intensity_range``."""
    low, high = intensity_range
    scale = np.random.uniform(low, high)
    return scale * volume

613 

def _flip(self, volume, axis=None):
    """Reverse ``volume`` along ``axis``; a random axis in {0, 1, 2} is used when it is omitted."""
    chosen_axis = random.randint(0, 2) if axis is None else axis
    return np.flip(volume, axis=chosen_axis)

619 

def _rotate(self, volume, axes=None, k=None):
    """Rotate ``volume`` by ``k`` quarter turns in the plane spanned by ``axes``.

    Args:
        volume: The 3D volume to rotate.
        axes: Pair of array axes defining the rotation plane. When omitted, two
            distinct axes are sampled from those enabled in ``self.rotate_axes``;
            if fewer than two are enabled, axes (0, 1) are used.
        k: Number of 90-degree turns; drawn from {1, 2, 3} when omitted.

    Returns:
        The rotated volume.
    """
    if axes is None:
        enabled = [axis for axis, on in enumerate(self.rotate_axes) if on]
        axes = tuple(random.sample(enabled, 2)) if len(enabled) >= 2 else (0, 1)

    if k is None:
        k = random.randint(1, 3)

    return np.rot90(volume, k=k, axes=axes)

649 

def _rotate_z(self, volume, angle=None):
    """Rotate ``volume`` by an arbitrary angle about its first (z) axis.

    The volume is indexed [z, y, x]. Each z-slice is rotated in the (y, x)
    plane about the geometric centre of the volume using linear
    interpolation. Samples taken from outside the volume are filled with 0.

    Args:
        volume: The 3D volume to rotate.
        angle: Rotation angle in degrees; drawn uniformly from [0, 360) when None.

    Returns:
        The rotated volume, with the same shape as ``volume``.
    """
    from scipy.ndimage import map_coordinates

    if angle is None:
        angle = np.random.uniform(0, 360)  # Random angle in degrees

    # Rotate about the geometric centre ((n - 1) / 2). With shape // 2 a
    # quarter turn of an even-sized volume was shifted by half a voxel.
    center = (np.array(volume.shape, dtype=float) - 1.0) / 2.0

    theta = np.radians(angle)
    c, s = np.cos(theta), np.sin(theta)
    # Coordinates are stacked as (z, y, x): keep z fixed and rotate (y, x).
    # The previous matrix kept x fixed and rotated the (z, y) plane instead.
    rotation_matrix = np.array([[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]])

    # Output voxel -> input sampling position, relative to the centre.
    grid = np.indices(volume.shape, dtype=float).reshape(3, -1)
    offsets = grid - center[:, None]
    sample_at = rotation_matrix @ offsets + center[:, None]

    rotated = map_coordinates(volume, sample_at, order=1, mode="constant")
    return rotated.reshape(volume.shape)

703 

def _apply_mixup(self, subvolume, idx):
    """Blend ``subvolume`` with a different, randomly chosen stored sample.

    The mixing weight ``lam`` is drawn from Beta(alpha, alpha) with
    ``alpha = self.mixup_alpha`` when that is positive; otherwise 0.5 is used.

    Args:
        subvolume: The (possibly already augmented) current subvolume.
        idx: Index of the current sample, so it is never mixed with itself.

    Returns:
        ``(lam * subvolume + (1 - lam) * other, other_idx, lam)``. When the
        dataset holds fewer than two samples, ``(subvolume, None, 1.0)``.
    """
    sample_count = len(self._subvolumes)
    if sample_count < 2:
        return subvolume, None, 1.0

    partner = idx
    while partner == idx:
        partner = random.randint(0, sample_count - 1)

    partner_volume = self._subvolumes[partner].copy()

    if self.mixup_alpha > 0:
        lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
    else:
        lam = 0.5

    blended = lam * subvolume + (1 - lam) * partner_volume
    return blended, partner, lam

740 

def __len__(self):
    """Return the number of stored samples (particle and background subvolumes)."""
    n_samples = len(self._subvolumes)
    return n_samples

744 

def __getitem__(self, idx):
    """Return ``(tensor, label_dict)`` for sample ``idx``.

    The stored subvolume is copied and augmented when ``self.augment`` is
    set. It is then standardised to zero mean and unit variance (with a 1e-6
    guard on the denominator) and returned as a float32 tensor with a leading
    channel axis.

    ``label_dict`` contains:
        - 'class_idx': class index of the sample
        - 'is_mixed': whether mixup was applied
        - 'mix_lambda': mixup weight of this sample (1.0 without mixup)
        - 'mix_class_idx': class index of the mixup partner (None without mixup)
        - 'applied_augmentations': applied augmentations (only when debug_mode=True)
    """
    volume = self._subvolumes[idx].copy()
    label_dict = {
        "class_idx": self._molecule_ids[idx],
        "is_mixed": False,
        "mix_lambda": 1.0,
        "mix_class_idx": None,
    }
    if self.debug_mode:
        label_dict["applied_augmentations"] = []

    def mark_mixed(partner_idx, weight):
        # Record mixup with the partner's class so the loss can blend the labels.
        label_dict.update(
            {"is_mixed": True, "mix_lambda": weight, "mix_class_idx": self._molecule_ids[partner_idx]},
        )

    if self.augment and self.debug_mode:
        # The augmentation routine reports what it applied in debug mode.
        volume, applied, mixup_info = self._augment_subvolume(volume, idx)
        label_dict["applied_augmentations"] = applied
        if mixup_info is not None:
            mark_mixed(mixup_info["other_idx"], mixup_info["lambda"])
    elif self.augment:
        method_names = {
            "brightness": "_brightness",
            "blur": "_gaussian_blur",
            "intensity": "_intensity_scaling",
            "flip": "_flip",
            "rotate": "_rotate",
            "rotate_z": "_rotate_z",
        }
        for aug in self.augmentations:
            # The probability draw happens for every listed name, including "mixup".
            if random.random() < self.augmentation_prob and aug in method_names:
                volume = getattr(self, method_names[aug])(volume)

        mixup_enabled = "mixup" in self.augmentations and self.mixup_alpha is not None
        if mixup_enabled and random.random() < self.augmentation_prob:
            volume, partner_idx, weight = self._apply_mixup(volume, idx)
            if partner_idx is not None:
                mark_mixed(partner_idx, weight)

    volume = (volume - np.mean(volume)) / (np.std(volume) + 1e-6)
    return torch.as_tensor(volume[None, ...], dtype=torch.float32), label_dict

820 

def get_sample_weights(self):
    """Return the per-sample inverse-class-frequency weights, e.g. for a ``WeightedRandomSampler``."""
    weights = self.sample_weights
    return weights

824 

def keys(self):
    """Return the list of pickable-object (class) names; sample class indices index into it."""
    class_names = self._keys
    return class_names

828 

def examples(self):
    """Return one example tensor per class, plus one background example if present.

    Examples are produced through ``self[i]`` (so augmentation settings
    apply). A sample that fails to load is reported and the next candidate
    of the same class is tried.

    Returns:
        ``(stacked_tensors, names)``, where ``names[i]`` is the class name (or
        "background") of ``stacked_tensors[i]``. Returns ``(None, [])`` when
        the dataset is empty or no example could be produced.
    """
    if len(self._subvolumes) == 0 or len(self._molecule_ids) == 0:
        return None, []

    # Background flags may be a list or a numpy array (as set by _load_data).
    # Use len() and element access: truth-testing a multi-element numpy array
    # raises ValueError.
    flags = self._is_background
    has_flags = len(flags) > 0

    example_tensors = []
    example_labels = []

    # First non-background sample of every regular class
    for cls in range(len(self._keys)):
        for i, mol_id in enumerate(self._molecule_ids):
            if mol_id != cls or (has_flags and flags[i]):
                continue
            try:
                volume, _ = self[i]
            except Exception as e:
                print(f"Error extracting example for class {cls}: {str(e)}")
                continue
            example_tensors.append(volume)
            example_labels.append(cls)
            break

    # Add background example if present
    if self.include_background and has_flags and any(flags):
        for i, is_bg in enumerate(flags):
            if not is_bg:
                continue
            try:
                volume, _ = self[i]
            except Exception as e:
                print(f"Error extracting background example: {str(e)}")
                continue
            example_tensors.append(volume)
            example_labels.append(-1)  # -1 marks background
            break

    if example_tensors:
        return torch.stack(example_tensors), [
            "background" if label == -1 else self._keys[label] for label in example_labels
        ]
    return None, []

877 

def get_class_distribution(self):
    """Count how many samples belong to each class.

    Returns:
        Dict mapping class name to sample count.
        - Samples with molecule id ``-1`` are reported under ``"background"``,
          which is inserted first.
        - Ids that are not valid indices into ``self._keys`` are left out.
        - Remaining classes appear in the order they first occur in the
          dataset.
    """
    tally = Counter(self._molecule_ids)
    distribution = {}

    background_count = tally.pop(-1, None)
    if background_count is not None:
        distribution["background"] = background_count

    n_names = len(self._keys)
    distribution.update(
        (self._keys[cls_idx], n) for cls_idx, n in tally.items() if 0 <= cls_idx < n_names
    )
    return distribution

896 

def stratified_split(self, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15, seed=None):
    """Split the dataset into train, validation, and test sets while preserving class distributions.

    How the split is formed:
    - Samples are grouped by molecule id; background (id -1) forms its own group.
    - Each group is shuffled.
    - Each group contributes ``int(n * train_ratio)`` samples to train and
      ``int(n * val_ratio)`` to validation; the rest of the group goes to test.
    - The three resulting index lists are shuffled once more.

    Args:
        train_ratio: Proportion of data to use for training
        val_ratio: Proportion of data to use for validation
        test_ratio: Proportion of data to use for testing
        seed: Random seed for reproducibility. When given, a private
            ``np.random.RandomState(seed)`` is used. This yields the same split
            as seeding NumPy's global RNG with that value, without modifying
            the global RNG state. When None, NumPy's global RNG is used.

    Returns:
        Tuple of (train_dataset, val_dataset, test_dataset) as Subset objects

    Raises:
        ValueError: if the ratios do not sum to 1.0 or any ratio is negative.
    """
    if not np.isclose(train_ratio + val_ratio + test_ratio, 1.0):
        raise ValueError("Ratios must sum to 1.0")
    if min(train_ratio, val_ratio, test_ratio) < 0:
        raise ValueError("Ratios must be non-negative")

    # Local generator: same sequence as np.random.seed(seed) + np.random.shuffle,
    # but without resetting the global RNG for the rest of the program.
    rng = np.random if seed is None else np.random.RandomState(seed)

    # Group indices per class (background included), keeping first-seen order.
    class_indices = {}
    for i, mol_id in enumerate(self._molecule_ids):
        class_indices.setdefault(mol_id, []).append(i)

    for indices in class_indices.values():
        rng.shuffle(indices)

    train_indices, val_indices, test_indices = [], [], []
    for indices in class_indices.values():
        n_samples = len(indices)
        n_train = int(n_samples * train_ratio)
        n_val = int(n_samples * val_ratio)
        train_indices.extend(indices[:n_train])
        val_indices.extend(indices[n_train : n_train + n_val])
        # Truncation in the two cuts above leaves the remainder for test.
        test_indices.extend(indices[n_train + n_val :])

    rng.shuffle(train_indices)
    rng.shuffle(val_indices)
    rng.shuffle(test_indices)

    train_dataset = Subset(self, train_indices)
    val_dataset = Subset(self, val_indices)
    test_dataset = Subset(self, test_indices)

    print(
        f"Dataset split: {len(train_dataset)} train, {len(val_dataset)} validation, {len(test_dataset)} test samples",
    )

    return train_dataset, val_dataset, test_dataset

959 

def balance_classes(self, method="oversample", target_ratio=1.0, exclude_background=False):
    """Balance class distribution in the dataset.

    Args:
        method: Balancing method to use.
            - 'oversample': keep every sample and add lightly augmented
              duplicates of minority-class samples.
            - 'undersample': randomly keep a subset of each majority class.
        target_ratio: Fraction of the gap to the largest (oversample) or
            smallest (undersample) class size to close, in (0, 1].
            1.0 = perfect balance.
        exclude_background: If True, background samples (molecule id -1) take
            no part in the balancing and are copied into the result unchanged.

    Returns:
        A new CopickDataset instance with balanced classes

    Raises:
        ValueError: on an invalid ``method``/``target_ratio`` or when there are
            no samples to balance.
    """
    if method not in ["oversample", "undersample"]:
        raise ValueError("method must be either 'oversample' or 'undersample'")

    if target_ratio <= 0 or target_ratio > 1.0:
        raise ValueError("target_ratio must be between 0 and 1.0")

    # Group sample indices by class. Excluded background samples are set aside
    # and carried over as-is instead of being dropped from the result.
    class_indices = {}
    untouched_indices = []
    for i, mol_id in enumerate(self._molecule_ids):
        if exclude_background and mol_id == -1:
            untouched_indices.append(i)
        else:
            class_indices.setdefault(mol_id, []).append(i)

    if not class_indices:
        raise ValueError("No samples available to balance")

    class_counts = {mol_id: len(indices) for mol_id, indices in class_indices.items()}
    print("Original class distribution:")
    for mol_id, count in class_counts.items():
        class_name = "background" if mol_id == -1 else self._keys[mol_id]
        print(f" {class_name}: {count} samples")

    if method == "oversample":
        # Raise each class towards the largest one; target_ratio=1.0 reaches it.
        max_count = max(class_counts.values())
        target_counts = {
            mol_id: count + int((max_count - count) * target_ratio) for mol_id, count in class_counts.items()
        }
    else:
        # Lower each class towards the smallest one; target_ratio=1.0 reaches it.
        min_count = min(class_counts.values())
        target_counts = {
            mol_id: count - int((count - min_count) * target_ratio) for mol_id, count in class_counts.items()
        }

    new_subvolumes = []
    new_molecule_ids = []
    new_is_background = []

    def _append(idx, subvolume):
        # Record one output sample, copying labels from source index `idx`.
        new_subvolumes.append(subvolume)
        new_molecule_ids.append(self._molecule_ids[idx])
        new_is_background.append(self._is_background[idx])

    for mol_id, indices in class_indices.items():
        current_count = len(indices)
        target_count = target_counts[mol_id]

        if target_count <= current_count:
            # Random subset without replacement (a permutation when equal).
            for idx in np.random.choice(indices, target_count, replace=False):
                _append(idx, self._subvolumes[idx].copy())
            continue

        # Keep every original sample ...
        for idx in indices:
            _append(idx, self._subvolumes[idx].copy())

        # ... then add augmented duplicates to reach the target count.
        n_duplicates = target_count - current_count
        for idx in np.random.choice(indices, n_duplicates, replace=True):
            augmented = self._subvolumes[idx].copy()
            if self.debug_mode:
                augmented, _, _ = self._augment_subvolume(augmented)
            else:
                # Light augmentation so duplicates are not exact copies.
                augmented = self._flip(augmented)
                if random.random() < 0.5:
                    augmented = self._brightness(augmented)
                if random.random() < 0.5:
                    augmented = self._intensity_scaling(augmented)
            _append(idx, augmented)

    for idx in untouched_indices:
        _append(idx, self._subvolumes[idx].copy())

    new_subvolumes = np.array(new_subvolumes)
    new_molecule_ids = np.array(new_molecule_ids)
    new_is_background = np.array(new_is_background)

    # NOTE(review): this runs the full constructor (including its data loading)
    # only to overwrite the samples below. It also relies on self.config_path,
    # so a dataset built from copick_root alone may not be reconstructible
    # here. Confirm.
    balanced_dataset = CopickDataset(
        config_path=self.config_path,
        boxsize=self.boxsize,
        augment=self.augment,
        cache_dir=None,  # Don't use caching for the balanced dataset
        seed=self.seed,
        voxel_spacing=self.voxel_spacing,
        include_background=self.include_background,
        patch_strategy=self.patch_strategy,
        # Carry over the augmentation settings read by __getitem__.
        augmentations=getattr(self, "augmentations", None),
        augmentation_prob=getattr(self, "augmentation_prob", 0.2),
        mixup_alpha=getattr(self, "mixup_alpha", None),
        debug_mode=self.debug_mode,
    )

    balanced_dataset._subvolumes = new_subvolumes
    balanced_dataset._molecule_ids = new_molecule_ids
    balanced_dataset._is_background = new_is_background
    balanced_dataset._keys = self._keys.copy()

    balanced_dataset._compute_sample_weights()

    balanced_dist = balanced_dataset.get_class_distribution()
    print("Balanced class distribution:")
    for class_name, count in balanced_dist.items():
        print(f" {class_name}: {count} samples")

    return balanced_dataset

1099 

def extract_grid_patches(self, patch_size, overlap=0.25, normalize=True, run_index=0, tomo_type="raw", cover_edges=False):
    """Extract a grid of patches from a tomogram.

    Args:
        patch_size: Int or 3-sequence (z, y, x) of positive patch dimensions.
        overlap: Overlap ratio between adjacent patches, in [0, 1).
        normalize: If True, each patch is shifted to zero mean and divided by
            its standard deviation (+1e-6).
        run_index: Index into ``root.runs`` of the run to extract from.
        tomo_type: 'raw' or 'filtered'. 'filtered' falls back to raw, with a
            warning, when the tomogram has no usable ``filtered`` attribute.
        cover_edges: If True, add one edge-aligned patch per axis whenever the
            regular stride would leave trailing voxels uncovered. The default
            False keeps the plain stride grid.

    Returns:
        ``(patches, coordinates)``: a list of numpy patches and a list of
        their centre coordinates (z, y, x) in voxels.

    Raises:
        ValueError: on invalid arguments, or when the run or tomogram is
            missing. Any error raised while loading/extracting is also
            printed before being re-raised.
    """
    if isinstance(patch_size, (int, np.integer)):
        patch_size = (patch_size, patch_size, patch_size)
    elif len(patch_size) != 3:
        raise ValueError("patch_size must be an integer or tuple of 3 integers")
    patch_size = tuple(int(p) for p in patch_size)
    if any(p <= 0 for p in patch_size):
        raise ValueError("patch_size values must be positive")

    if overlap < 0 or overlap >= 1:
        raise ValueError("overlap must be between 0 and 1")

    try:
        root = copick.from_file(self.config_path)
        if not root.runs:
            raise ValueError("No runs found in the copick project")

        if run_index >= len(root.runs):
            raise ValueError(f"Run index {run_index} out of range. Only {len(root.runs)} runs available.")

        run = root.runs[run_index]

        # Fail with a clear message instead of AttributeError/IndexError when
        # the requested voxel spacing or its tomogram does not exist.
        voxel_spacing = run.get_voxel_spacing(self.voxel_spacing)
        if voxel_spacing is None or not voxel_spacing.tomograms:
            raise ValueError(
                f"No tomogram available at voxel spacing {self.voxel_spacing} for run index {run_index}",
            )
        tomogram = voxel_spacing.tomograms[0]

        if tomo_type == "raw":
            tomogram_array = tomogram.numpy()
        elif tomo_type == "filtered":
            if hasattr(tomogram, "filtered") and tomogram.filtered is not None:
                tomogram_array = tomogram.filtered.numpy()
            else:
                print("Warning: Filtered tomogram not available, using raw tomogram instead")
                tomogram_array = tomogram.numpy()
        else:
            raise ValueError(f"Invalid tomogram type: {tomo_type}. Must be 'raw' or 'filtered'")

        # Step between neighbouring patches; at least 1 voxel.
        strides = tuple(max(1, int(size * (1 - overlap))) for size in patch_size)

        def _axis_starts(length, size, stride):
            # Start offsets of patches that fit entirely inside [0, length).
            if length < size:
                return []
            starts = list(range(0, length - size + 1, stride))
            if cover_edges and starts[-1] != length - size:
                # Edge-aligned extra patch so trailing voxels are not skipped.
                starts.append(length - size)
            return starts

        z_starts, y_starts, x_starts = (
            _axis_starts(tomogram_array.shape[axis], patch_size[axis], strides[axis]) for axis in range(3)
        )
        half = tuple(size // 2 for size in patch_size)

        patches = []
        coordinates = []
        for z0 in z_starts:
            for y0 in y_starts:
                for x0 in x_starts:
                    patch = tomogram_array[
                        z0 : z0 + patch_size[0],
                        y0 : y0 + patch_size[1],
                        x0 : x0 + patch_size[2],
                    ].copy()

                    if normalize:
                        patch = (patch - np.mean(patch)) / (np.std(patch) + 1e-6)

                    patches.append(patch)
                    coordinates.append((z0 + half[0], y0 + half[1], x0 + half[2]))

        print(f"Extracted {len(patches)} patches of size {patch_size} with {overlap:.2f} overlap")
        return patches, coordinates

    except Exception as e:
        print(f"Error extracting grid patches: {str(e)}")
        raise

1208 

def extract_from_region(self, x_range, y_range, z_range, tomo_type="raw", run_index=0):
    """Extract a specific region from a tomogram.

    Args:
        x_range: (min_x, max_x) in voxel space (tuple or list of length 2).
        y_range: (min_y, max_y) in voxel space (tuple or list of length 2).
        z_range: (min_z, max_z) in voxel space (tuple or list of length 2).
        tomo_type: 'raw' or 'filtered'. 'filtered' falls back to raw, with a
            warning, when the tomogram has no usable ``filtered`` attribute.
        run_index: Index into ``root.runs`` of the run to read. The default of
            0 keeps the previous behavior.

    Returns:
        A numpy array indexed as [z, y, x]. Bounds are truncated to ints and
        clamped into the tomogram. The result is a view into the loaded
        tomogram array.

    Raises:
        ValueError: on malformed ranges, a missing run/tomogram, or an empty
            resulting region. Any error is also printed before being re-raised.
    """
    ranges = (x_range, y_range, z_range)
    if not all(isinstance(r, (tuple, list)) and len(r) == 2 for r in ranges):
        raise ValueError("Range parameters must be tuples of (min, max)")

    try:
        root = copick.from_file(self.config_path)
        if not root.runs:
            raise ValueError("No runs found in the copick project")

        if run_index >= len(root.runs):
            raise ValueError(f"Run index {run_index} out of range. Only {len(root.runs)} runs available.")

        run = root.runs[run_index]

        voxel_spacing = run.get_voxel_spacing(self.voxel_spacing)
        if voxel_spacing is None or not voxel_spacing.tomograms:
            raise ValueError(
                f"No tomogram available at voxel spacing {self.voxel_spacing} for run index {run_index}",
            )
        tomogram = voxel_spacing.tomograms[0]

        if tomo_type == "raw":
            tomogram_array = tomogram.numpy()
        elif tomo_type == "filtered":
            if hasattr(tomogram, "filtered") and tomogram.filtered is not None:
                tomogram_array = tomogram.filtered.numpy()
            else:
                print("Warning: Filtered tomogram not available, using raw tomogram instead")
                tomogram_array = tomogram.numpy()
        else:
            raise ValueError(f"Invalid tomogram type: {tomo_type}. Must be 'raw' or 'filtered'")

        def _clamped(bounds, size):
            # Clamp both ends into [0, size]. An unclamped negative upper bound
            # would otherwise become a from-the-end slice index.
            return tuple(min(max(0, int(b)), size) for b in bounds)

        min_z, max_z = _clamped(z_range, tomogram_array.shape[0])
        min_y, max_y = _clamped(y_range, tomogram_array.shape[1])
        min_x, max_x = _clamped(x_range, tomogram_array.shape[2])

        region = tomogram_array[min_z:max_z, min_y:max_y, min_x:max_x]

        if region.size == 0:
            raise ValueError("Extracted region is empty. Check range parameters.")

        return region

    except Exception as e:
        print(f"Error extracting region from tomogram: {str(e)}")
        raise