Coverage for copick_torch/dataset.py: 7%

583 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-16 16:14 -0700

1import logging 

2import os 

3import pickle 

4import random 

5from collections import Counter 

6from datetime import datetime 

7from pathlib import Path 

8from typing import Any, Dict, List, Optional, Tuple, Union 

9 

10import copick 

11import numpy as np 

12import pandas as pd 

13import torch 

14import zarr 

15 

16# Import these at module level to avoid pickling issues 

17from scipy.ndimage import binary_dilation, gaussian_filter 

18from skimage import measure 

19from skimage.transform import resize 

20from torch.utils.data import ConcatDataset, Dataset, Subset 

21 

22from .augmentations import FourierAugment3D 

23 

24 

class SimpleDatasetMixin:
    """
    Mixin that makes a dataset yield plain ``(image, label)`` tuples.

    The host class must provide ``_subvolumes``, ``_molecule_ids``, an
    ``augment`` flag and an ``_augment_subvolume`` method; this mixin only
    contributes ``__getitem__``.
    """

    def __getitem__(self, idx):
        """
        Fetch sample ``idx`` as a ``(subvolume, label)`` pair.

        The stored subvolume is copied, optionally augmented, standardized
        to zero mean / unit variance and given a leading channel axis.

        Returns:
            tuple: (subvolume, label) where ``subvolume`` is a float32 tensor
            and ``label`` is the entry of ``_molecule_ids`` at ``idx``.
        """
        # Copy first so nothing below can modify the stored sample
        volume = self._subvolumes[idx].copy()
        label = self._molecule_ids[idx]

        # Augmentation is opt-in
        if self.augment:
            volume = self._augment_subvolume(volume, idx)

        # Standardize; the epsilon keeps constant volumes from dividing by zero
        centered = volume - np.mean(volume)
        scale = np.std(volume) + 1e-6
        volume = centered / scale

        # Prepend the channel axis and hand back a float32 tensor
        tensor = torch.as_tensor(volume[None, ...], dtype=torch.float32)
        return tensor, label

60 

61 

62class SimpleCopickDataset(SimpleDatasetMixin, Dataset): 

63 """ 

64 A simplified PyTorch dataset for working with copick data that returns (image, label) pairs. 

65 

66 This implementation is a wrapper around the original CopickDataset that modifies the 

67 __getitem__ method to return a simpler format suitable for standard training pipelines. 

68 """ 

69 

70 def __init__( 

71 self, 

72 config_path: Union[str, Any] = None, 

73 copick_root: Optional[Any] = None, 

74 boxsize: Tuple[int, int, int] = (32, 32, 32), 

75 augment: bool = False, 

76 cache_dir: Optional[str] = None, 

77 cache_format: str = "parquet", 

78 seed: Optional[int] = 1717, 

79 max_samples: Optional[int] = None, 

80 voxel_spacing: float = 10.0, 

81 include_background: bool = False, 

82 background_ratio: float = 0.2, 

83 min_background_distance: Optional[float] = None, 

84 patch_strategy: str = "centered", 

85 debug_mode: bool = False, 

86 dataset_id: Optional[int] = None, 

87 overlay_root: str = "/tmp/test/", 

88 ): 

89 """ 

90 Initialize a SimpleCopickDataset. 

91 

92 Args: 

93 config_path: Path to the copick config file or CopickConfig object 

94 copick_root: Copick root object (alternative to config_path) 

95 boxsize: Size of the subvolumes to extract (z, y, x) 

96 augment: Whether to apply data augmentation 

97 cache_dir: Directory to cache extracted subvolumes 

98 cache_format: Format for caching ('pickle' or 'parquet') 

99 seed: Random seed for reproducibility 

100 max_samples: Maximum number of samples to use 

101 voxel_spacing: Voxel spacing to use for extraction 

102 include_background: Whether to include background samples 

103 background_ratio: Ratio of background to particle samples 

104 min_background_distance: Minimum distance from particles for background samples 

105 patch_strategy: Strategy for extracting patches ('centered', 'random', or 'jittered') 

106 debug_mode: Whether to enable debug mode 

107 """ 

108 # Validate input: either config_path, copick_root, or dataset_id must be provided 

109 if config_path is None and copick_root is None and dataset_id is None: 

110 raise ValueError("Either config_path, copick_root, or dataset_id must be provided") 

111 

112 self.config_path = config_path 

113 self.copick_root = copick_root 

114 self.dataset_id = dataset_id 

115 self.overlay_root = overlay_root 

116 

117 # If dataset_id is provided but not copick_root, create it here 

118 if self.dataset_id is not None and self.copick_root is None: 

119 try: 

120 import copick 

121 

122 self.copick_root = copick.from_czcdp_datasets([self.dataset_id], overlay_root=self.overlay_root) 

123 print(f"Created copick root from dataset ID: {self.dataset_id}") 

124 except Exception as e: 

125 print(f"Error creating copick root from dataset ID: {e}") 

126 raise 

127 self.boxsize = boxsize 

128 self.augment = augment 

129 self.cache_dir = cache_dir 

130 self.cache_format = cache_format.lower() 

131 self.seed = seed 

132 self.max_samples = max_samples 

133 self.voxel_spacing = voxel_spacing 

134 self.include_background = include_background 

135 self.background_ratio = background_ratio 

136 self.min_background_distance = min_background_distance or max(boxsize) 

137 self.patch_strategy = patch_strategy 

138 self.debug_mode = debug_mode 

139 

140 # Initialize dataset 

141 self._set_random_seed() 

142 self._subvolumes = [] 

143 self._molecule_ids = [] 

144 self._keys = [] 

145 self._is_background = [] 

146 self._load_or_process_data() 

147 self._compute_sample_weights() 

148 

149 def _set_random_seed(self): 

150 """Set random seeds for reproducibility.""" 

151 if self.seed is not None: 

152 random.seed(self.seed) 

153 np.random.seed(self.seed) 

154 torch.manual_seed(self.seed) 

155 if torch.cuda.is_available(): 

156 torch.cuda.manual_seed(self.seed) 

157 

158 def _compute_sample_weights(self): 

159 """Compute sample weights based on class frequency for balancing.""" 

160 # Include special handling for background class if it exists 

161 class_counts = Counter(self._molecule_ids) 

162 total_samples = len(self._molecule_ids) 

163 

164 # Assign weights inversely proportional to class frequency 

165 class_weights = {cls: total_samples / count for cls, count in class_counts.items()} 

166 

167 # Compute weights for each sample 

168 self.sample_weights = [class_weights[mol_id] for mol_id in self._molecule_ids] 

169 

170 def _get_cache_path(self): 

171 """Get the appropriate cache file path based on format.""" 

172 # If we have a copick_root but no config_path, use dataset IDs 

173 cache_key = self.config_path 

174 if cache_key is None and self.copick_root is not None: 

175 # Try to get dataset IDs from the datasets attribute 

176 try: 

177 dataset_ids = [] 

178 for dataset in self.copick_root.datasets: 

179 if hasattr(dataset, "id"): 

180 dataset_ids.append(str(dataset.id)) 

181 

182 if dataset_ids: 

183 # Use the dataset IDs in order as the cache key 

184 dataset_ids_str = "_".join(dataset_ids) 

185 cache_key = f"datasets_{dataset_ids_str}" 

186 else: 

187 # Fallback if no dataset IDs found 

188 cache_key = "copick_root_unknown" 

189 except (AttributeError, IndexError): 

190 # Fallback if datasets attribute doesn't exist 

191 if hasattr(self.copick_root, "dataset_ids"): 

192 dataset_ids = [str(did) for did in self.copick_root.dataset_ids] 

193 cache_key = f"datasets_{'_'.join(dataset_ids)}" 

194 else: 

195 # Last resort fallback 

196 cache_key = f"copick_root_{hash(str(self.copick_root))}" 

197 

198 if self.cache_format == "pickle": 

199 return os.path.join( 

200 self.cache_dir, 

201 f"{cache_key}_{self.boxsize[0]}x{self.boxsize[1]}x{self.boxsize[2]}" 

202 f"_{self.voxel_spacing}" 

203 f"{'_with_bg' if self.include_background else ''}.pkl", 

204 ) 

205 else: # parquet 

206 return os.path.join( 

207 self.cache_dir, 

208 f"{cache_key}_{self.boxsize[0]}x{self.boxsize[1]}x{self.boxsize[2]}" 

209 f"_{self.voxel_spacing}" 

210 f"{'_with_bg' if self.include_background else ''}.parquet", 

211 ) 

212 

213 def _load_or_process_data(self): 

214 """Load data from cache or process it directly.""" 

215 # If cache_dir is None, process data directly without caching 

216 if self.cache_dir is None: 

217 print("Cache directory not specified. Processing data without caching...") 

218 self._load_data() 

219 return 

220 

221 # If cache_dir is specified, use caching logic 

222 os.makedirs(self.cache_dir, exist_ok=True) 

223 cache_file = self._get_cache_path() 

224 

225 if os.path.exists(cache_file): 

226 print(f"Loading cached data from {cache_file}") 

227 

228 if self.cache_format == "pickle": 

229 self._load_from_pickle(cache_file) 

230 else: # parquet 

231 self._load_from_parquet(cache_file) 

232 

233 # Apply max_samples limit if specified 

234 if self.max_samples is not None and len(self._subvolumes) > self.max_samples: 

235 indices = np.random.choice(len(self._subvolumes), self.max_samples, replace=False) 

236 self._subvolumes = np.array(self._subvolumes)[indices] 

237 self._molecule_ids = np.array(self._molecule_ids)[indices] 

238 if self._is_background: 

239 self._is_background = np.array(self._is_background)[indices] 

240 else: 

241 print("Processing data and creating cache...") 

242 self._load_data() 

243 

244 # Only save to cache if we actually loaded some data 

245 if len(self._subvolumes) > 0: 

246 if self.cache_format == "pickle": 

247 self._save_to_pickle(cache_file) 

248 else: # parquet 

249 self._save_to_parquet(cache_file) 

250 print(f"Cached data saved to {cache_file}") 

251 else: 

252 print("No data loaded, skipping cache creation") 

253 

254 def _load_from_pickle(self, cache_file): 

255 """Load dataset from pickle cache.""" 

256 with open(cache_file, "rb") as f: 

257 cached_data = pickle.load(f) 

258 self._subvolumes = cached_data.get("subvolumes", []) 

259 self._molecule_ids = cached_data.get("molecule_ids", []) 

260 self._keys = cached_data.get("keys", []) 

261 self._is_background = cached_data.get("is_background", []) 

262 

263 # Handle case where background flag wasn't saved 

264 if not self._is_background and self.include_background: 

265 # Initialize all as non-background 

266 self._is_background = [False] * len(self._subvolumes) 

267 

268 def _save_to_pickle(self, cache_file): 

269 """Save dataset to pickle cache.""" 

270 with open(cache_file, "wb") as f: 

271 pickle.dump( 

272 { 

273 "subvolumes": self._subvolumes, 

274 "molecule_ids": self._molecule_ids, 

275 "keys": self._keys, 

276 "is_background": self._is_background, 

277 }, 

278 f, 

279 ) 

280 

281 def _load_from_parquet(self, cache_file): 

282 """Load dataset from parquet cache.""" 

283 try: 

284 df = pd.read_parquet(cache_file) 

285 

286 # Process subvolumes from bytes back to numpy arrays 

287 self._subvolumes = [] 

288 for idx, row in df.iterrows(): 

289 if isinstance(row["subvolume"], bytes): 

290 # Reconstruct numpy array from bytes 

291 subvol = np.frombuffer(row["subvolume"], dtype=np.float32) 

292 shape = row["shape"] 

293 if isinstance(shape, list): 

294 shape = tuple(shape) 

295 subvol = subvol.reshape(shape) 

296 self._subvolumes.append(subvol) 

297 else: 

298 raise ValueError(f"Invalid subvolume format: {type(row['subvolume'])}") 

299 

300 # Convert to numpy array 

301 self._subvolumes = np.array(self._subvolumes) 

302 

303 # Extract other fields 

304 self._molecule_ids = df["molecule_id"].tolist() 

305 self._keys = df["key"].tolist() if "key" in df.columns else [] 

306 

307 # Load or initialize background flags 

308 if "is_background" in df.columns: 

309 self._is_background = df["is_background"].tolist() 

310 else: 

311 self._is_background = [False] * len(self._subvolumes) 

312 

313 # Reconstruct unique keys if not available 

314 if not self._keys: 

315 unique_ids = set() 

316 for mol_id in self._molecule_ids: 

317 if mol_id != -1 and not df.loc[df["molecule_id"] == mol_id, "is_background"].any(): 

318 unique_ids.add(mol_id) 

319 self._keys = sorted(unique_ids) 

320 

321 except Exception as e: 

322 print(f"Error loading from parquet: {str(e)}") 

323 raise 

324 

325 def _save_to_parquet(self, cache_file): 

326 """Save dataset to parquet cache.""" 

327 try: 

328 # Check if we have any data to save 

329 if len(self._subvolumes) == 0: 

330 print("No data to save to parquet") 

331 return 

332 

333 # Prepare records 

334 records = [] 

335 for subvol, mol_id, is_bg in zip(self._subvolumes, self._molecule_ids, self._is_background): 

336 record = { 

337 "subvolume": subvol.tobytes(), 

338 "shape": list(subvol.shape), 

339 "molecule_id": mol_id, 

340 "is_background": is_bg, 

341 } 

342 records.append(record) 

343 

344 # Add keys information 

345 key_mapping = [] 

346 for i, key in enumerate(self._keys): 

347 key_mapping.append({"key_index": i, "key": key}) 

348 

349 # Create and save main dataframe 

350 df = pd.DataFrame(records) 

351 

352 # Add keys as a column for each row 

353 df["key"] = df["molecule_id"].apply( 

354 lambda x: self._keys[x] if x != -1 and x < len(self._keys) else "background", 

355 ) 

356 

357 df.to_parquet(cache_file, index=False) 

358 

359 # Save additional metadata 

360 metadata_file = cache_file.replace(".parquet", "_metadata.parquet") 

361 metadata = { 

362 "creation_date": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), 

363 "total_samples": len(records), 

364 "unique_molecules": len(self._keys), 

365 "boxsize": self.boxsize, 

366 "include_background": self.include_background, 

367 "background_samples": sum(self._is_background), 

368 } 

369 pd.DataFrame([metadata]).to_parquet(metadata_file, index=False) 

370 

371 except Exception as e: 

372 print(f"Error saving to parquet: {str(e)}") 

373 raise 

374 

375 def _extract_subvolume_with_validation(self, tomogram_array, x, y, z): 

376 """Extract a subvolume with validation checks, applying the selected patch strategy.""" 

377 half_box = np.array(self.boxsize) // 2 

378 

379 # Apply patch strategy 

380 if self.patch_strategy == "centered": 

381 # Standard centered extraction 

382 offset_x, offset_y, offset_z = 0, 0, 0 

383 elif self.patch_strategy == "random": 

384 # Random offsets within half_box/4 

385 max_offset = [size // 4 for size in half_box] 

386 offset_x = np.random.randint(-max_offset[0], max_offset[0] + 1) 

387 offset_y = np.random.randint(-max_offset[1], max_offset[1] + 1) 

388 offset_z = np.random.randint(-max_offset[2], max_offset[2] + 1) 

389 elif self.patch_strategy == "jittered": 

390 # Small random jitter for data augmentation 

391 offset_x = np.random.randint(-2, 3) 

392 offset_y = np.random.randint(-2, 3) 

393 offset_z = np.random.randint(-2, 3) 

394 

395 # Apply offsets to coordinates 

396 x_adj = x + offset_x 

397 y_adj = y + offset_y 

398 z_adj = z + offset_z 

399 

400 # Calculate slice indices 

401 x_start = max(0, int(x_adj - half_box[0])) 

402 x_end = min(tomogram_array.shape[2], int(x_adj + half_box[0])) 

403 y_start = max(0, int(y_adj - half_box[1])) 

404 y_end = min(tomogram_array.shape[1], int(y_adj + half_box[1])) 

405 z_start = max(0, int(z_adj - half_box[2])) 

406 z_end = min(tomogram_array.shape[0], int(z_adj + half_box[2])) 

407 

408 # Validate slice ranges 

409 if x_end <= x_start or y_end <= y_start or z_end <= z_start: 

410 return None, False, "Invalid slice range" 

411 

412 # Extract subvolume 

413 subvolume = tomogram_array[z_start:z_end, y_start:y_end, x_start:x_end] 

414 

415 # Check if extracted shape matches requested size 

416 if subvolume.shape != self.boxsize: 

417 # Need to pad the subvolume 

418 padded = np.zeros(self.boxsize, dtype=subvolume.dtype) 

419 

420 # Calculate padding amounts 

421 z_pad_start = max(0, half_box[2] - int(z)) 

422 y_pad_start = max(0, half_box[1] - int(y)) 

423 x_pad_start = max(0, half_box[0] - int(x)) 

424 

425 # Calculate end indices 

426 z_pad_end = min(z_pad_start + (z_end - z_start), self.boxsize[0]) 

427 y_pad_end = min(y_pad_start + (y_end - y_start), self.boxsize[1]) 

428 x_pad_end = min(x_pad_start + (x_end - x_start), self.boxsize[2]) 

429 

430 # Copy data 

431 padded[z_pad_start:z_pad_end, y_pad_start:y_pad_end, x_pad_start:x_pad_end] = subvolume 

432 return padded, True, "padded" 

433 

434 return subvolume, True, "valid" 

435 

436 def _load_data(self): 

437 """Load particle picks data from copick project.""" 

438 # Determine which root to use 

439 if self.copick_root is not None: 

440 root = self.copick_root 

441 print("Using provided copick root object") 

442 else: 

443 try: 

444 root = copick.from_file(self.config_path) 

445 print(f"Loading data from {self.config_path}") 

446 except Exception as e: 

447 print(f"Failed to load copick root: {str(e)}") 

448 return 

449 

450 # Store all particle coordinates for background sampling 

451 all_particle_coords = [] 

452 

453 for run in root.runs: 

454 print(f"Processing run: {run.name}") 

455 

456 # Try to load tomogram 

457 try: 

458 voxel_spacing_obj = run.get_voxel_spacing(self.voxel_spacing) 

459 if ( 

460 voxel_spacing_obj is None 

461 or not hasattr(voxel_spacing_obj, "tomograms") 

462 or not voxel_spacing_obj.tomograms 

463 ): 

464 print(f"No tomograms found for run {run.name} at voxel spacing {self.voxel_spacing}") 

465 continue 

466 

467 tomogram = voxel_spacing_obj.tomograms[0] 

468 tomogram_array = tomogram.numpy() 

469 except Exception as e: 

470 print(f"Error loading tomogram for run {run.name}: {str(e)}") 

471 continue 

472 

473 # Process picks 

474 run_particle_coords = [] # Store coordinates for this run 

475 

476 for picks in run.get_picks(): 

477 if not picks.from_tool: 

478 continue 

479 

480 object_name = picks.pickable_object_name 

481 

482 try: 

483 points, _ = picks.numpy() 

484 points = points / self.voxel_spacing 

485 

486 for point in points: 

487 try: 

488 x, y, z = point 

489 

490 # Save for background sampling 

491 run_particle_coords.append((x, y, z)) 

492 

493 # Extract subvolume 

494 subvolume, is_valid, _ = self._extract_subvolume_with_validation(tomogram_array, x, y, z) 

495 

496 if is_valid: 

497 self._subvolumes.append(subvolume) 

498 

499 if object_name not in self._keys: 

500 self._keys.append(object_name) 

501 

502 self._molecule_ids.append(self._keys.index(object_name)) 

503 self._is_background.append(False) 

504 except Exception as e: 

505 print(f"Error extracting subvolume: {str(e)}") 

506 except Exception as e: 

507 print(f"Error processing picks for {object_name}: {str(e)}") 

508 

509 # Sample background points for this run if needed 

510 if self.include_background and run_particle_coords: 

511 all_particle_coords.extend(run_particle_coords) 

512 self._sample_background_points(tomogram_array, run_particle_coords) 

513 

514 self._subvolumes = np.array(self._subvolumes) 

515 self._molecule_ids = np.array(self._molecule_ids) 

516 self._is_background = np.array(self._is_background) 

517 

518 # Apply max_samples limit if specified 

519 if self.max_samples is not None and len(self._subvolumes) > self.max_samples: 

520 indices = np.random.choice(len(self._subvolumes), self.max_samples, replace=False) 

521 self._subvolumes = self._subvolumes[indices] 

522 self._molecule_ids = self._molecule_ids[indices] 

523 self._is_background = self._is_background[indices] 

524 

525 print(f"Loaded {len(self._subvolumes)} subvolumes with {len(self._keys)} classes") 

526 print(f"Background samples: {sum(self._is_background)}") 

527 

528 def _sample_background_points(self, tomogram_array, particle_coords): 

529 """Sample background points away from particles.""" 

530 if not particle_coords: 

531 return 

532 

533 # Convert to numpy array for distance calculations 

534 particle_coords = np.array(particle_coords) 

535 

536 # Calculate number of background samples based on ratio 

537 num_particles = len(particle_coords) 

538 num_background = int(num_particles * self.background_ratio) 

539 

540 # Limit attempts to avoid infinite loop 

541 max_attempts = num_background * 10 

542 attempts = 0 

543 bg_points_found = 0 

544 

545 half_box = np.array(self.boxsize) // 2 

546 

547 while bg_points_found < num_background and attempts < max_attempts: 

548 # Generate random point within tomogram bounds with margin for box extraction 

549 random_point = np.array( 

550 [ 

551 np.random.randint(half_box[0], tomogram_array.shape[2] - half_box[0]), 

552 np.random.randint(half_box[1], tomogram_array.shape[1] - half_box[1]), 

553 np.random.randint(half_box[2], tomogram_array.shape[0] - half_box[2]), 

554 ], 

555 ) 

556 

557 # Calculate distances to all particles 

558 distances = np.linalg.norm(particle_coords - random_point, axis=1) 

559 

560 # Check if point is far enough from all particles 

561 if np.min(distances) >= self.min_background_distance: 

562 # Extract subvolume 

563 x, y, z = random_point 

564 subvolume, is_valid, _ = self._extract_subvolume_with_validation(tomogram_array, x, y, z) 

565 

566 if is_valid: 

567 self._subvolumes.append(subvolume) 

568 self._molecule_ids.append(-1) # Use -1 to indicate background 

569 self._is_background.append(True) 

570 bg_points_found += 1 

571 

572 attempts += 1 

573 

574 print(f"Added {bg_points_found} background points after {attempts} attempts") 

575 

576 def _augment_subvolume(self, subvolume, idx=None): 

577 """Apply data augmentation to a subvolume. 

578 

579 This simplified version applies basic augmentations only (no mixup). 

580 

581 Args: 

582 subvolume: The 3D volume to augment 

583 idx: Optional index for mixup augmentation (not used in this version) 

584 

585 Returns: 

586 Augmented subvolume 

587 """ 

588 # Apply random brightness adjustment 

589 if random.random() < 0.3: 

590 delta = np.random.uniform(-0.5, 0.5) 

591 subvolume = subvolume + delta 

592 

593 # Apply random Gaussian blur 

594 if random.random() < 0.2: 

595 sigma = np.random.uniform(0.5, 1.5) 

596 subvolume = gaussian_filter(subvolume, sigma=sigma) 

597 

598 # Apply random intensity scaling 

599 if random.random() < 0.2: 

600 factor = np.random.uniform(0.5, 1.5) 

601 subvolume = subvolume * factor 

602 

603 # Apply random flip 

604 if random.random() < 0.2: 

605 axis = random.randint(0, 2) 

606 subvolume = np.flip(subvolume, axis=axis) 

607 

608 # Apply random rotation 

609 if random.random() < 0.2: 

610 k = random.randint(1, 3) # 90, 180, or 270 degrees 

611 axes = tuple(random.sample([0, 1, 2], 2)) # Select 2 random axes 

612 subvolume = np.rot90(subvolume, k=k, axes=axes) 

613 

614 # Apply Fourier domain augmentation 

615 if random.random() < 0.3: # 30% chance to apply Fourier augmentation 

616 fourier_aug = FourierAugment3D(freq_mask_prob=0.3, phase_noise_std=0.1, intensity_scaling_range=(0.8, 1.2)) 

617 subvolume = fourier_aug(subvolume) 

618 

619 return subvolume 

620 

621 def __len__(self): 

622 """Get the total number of items in the dataset.""" 

623 return len(self._subvolumes) 

624 

625 def get_sample_weights(self): 

626 """Return sample weights for use in a WeightedRandomSampler.""" 

627 return self.sample_weights 

628 

629 def keys(self): 

630 """Get pickable object keys.""" 

631 return self._keys 

632 

633 def get_class_distribution(self): 

634 """Get distribution of classes in the dataset.""" 

635 class_counts = Counter(self._molecule_ids) 

636 

637 # Create a readable distribution 

638 distribution = {} 

639 

640 # Count background samples if any 

641 if -1 in class_counts: 

642 distribution["background"] = class_counts[-1] 

643 del class_counts[-1] 

644 

645 # Count regular classes 

646 for cls_idx, count in class_counts.items(): 

647 if 0 <= cls_idx < len(self._keys): 

648 distribution[self._keys[cls_idx]] = count 

649 

650 return distribution 

651 

652 

653class SplicedMixupDataset(SimpleCopickDataset): 

654 """ 

655 A dataset that loads zarr arrays into memory and performs balanced sampling with mixup splicing. 

656 

657 This dataset extends SimpleCopickDataset to add experimental-synthetic data splicing capabilities, 

658 keeping zarr arrays in memory for faster loading and using balanced sampling by default. 

659 """ 

660 

661 def __init__( 

662 self, 

663 exp_dataset_id: int, 

664 synth_dataset_id: int, 

665 synth_run_id: str = "16487", 

666 overlay_root: str = "/tmp/test/", 

667 boxsize: Tuple[int, int, int] = (48, 48, 48), 

668 augment: bool = True, 

669 cache_dir: Optional[str] = None, 

670 cache_format: str = "parquet", 

671 seed: Optional[int] = 1717, 

672 max_samples: Optional[int] = None, 

673 voxel_spacing: float = 10.0, 

674 include_background: bool = False, 

675 background_ratio: float = 0.2, 

676 min_background_distance: Optional[float] = None, 

677 blend_sigma: float = 2.0, # Controls the standard deviation of Gaussian blending at boundaries 

678 mixup_alpha: float = 0.2, 

679 debug_mode: bool = False, 

680 ): 

681 """ 

682 Initialize the SplicedMixupDataset. 

683 

684 Args: 

685 exp_dataset_id: Dataset ID for the experimental dataset 

686 synth_dataset_id: Dataset ID for the synthetic dataset 

687 synth_run_id: Run ID for the synthetic dataset (default: "16487") 

688 overlay_root: Root directory for the overlay storage (default: "/tmp/test/") 

689 boxsize: Size of the subvolumes to extract (z, y, x) 

690 augment: Whether to apply data augmentation 

691 cache_dir: Directory to cache extracted subvolumes 

692 cache_format: Format for caching ('pickle' or 'parquet') 

693 seed: Random seed for reproducibility 

694 max_samples: Maximum number of samples to use 

695 voxel_spacing: Voxel spacing to use for extraction 

696 include_background: Whether to include background samples 

697 background_ratio: Ratio of background to particle samples 

698 min_background_distance: Minimum distance from particles for background samples 

699 blend_sigma: Controls the standard deviation of Gaussian blending at boundaries 

700 mixup_alpha: Alpha parameter for mixup augmentation 

701 debug_mode: Whether to enable debug mode 

702 """ 

703 # Save specific parameters 

704 self.exp_dataset_id = exp_dataset_id 

705 self.synth_dataset_id = synth_dataset_id 

706 self.synth_run_id = synth_run_id 

707 self.overlay_root = overlay_root 

708 self.blend_sigma = blend_sigma 

709 self.mixup_alpha = mixup_alpha 

710 

711 # Initialize load flags and storage for zarr arrays 

712 self._zarr_loaded = False 

713 self._exp_zarr_data = None 

714 self._synth_zarr_data = None 

715 self._synth_mask_data = {} 

716 

717 # Load copick roots 

718 self._load_copick_roots() 

719 

720 # Initialize the parent class (SimpleCopickDataset) 

721 # We'll override certain methods to use our in-memory zarr arrays 

722 super().__init__( 

723 copick_root=self.exp_root if hasattr(self, "exp_root") else None, # Use experimental data as the base 

724 boxsize=boxsize, 

725 augment=augment, 

726 cache_dir=cache_dir, 

727 cache_format=cache_format, 

728 seed=seed, 

729 max_samples=max_samples, 

730 voxel_spacing=voxel_spacing, 

731 include_background=include_background, 

732 background_ratio=background_ratio, 

733 min_background_distance=min_background_distance, 

734 patch_strategy="centered", # Always use centered for splicing 

735 debug_mode=debug_mode, 

736 ) 

737 

738 # Load zarr arrays into memory if not already loaded 

739 self._ensure_zarr_loaded() 

740 

741 # Initialize dataset with a small number of samples to make sure the parent initialization works 

742 # We will create our own samples from zarr arrays after parent initialization 

743 

744 # Generate synthetic samples directly from zarr arrays 

745 self._generate_synthetic_samples() 

746 

747 def _generate_synthetic_samples(self): 

748 """Generate synthetic samples directly from zarr arrays.""" 

749 print("Generating synthetic samples from zarr arrays...") 

750 # Clear any existing samples 

751 self._subvolumes = [] 

752 self._molecule_ids = [] 

753 self._keys = [] 

754 self._is_background = [] 

755 

756 num_samples = 100 # Default number of samples 

757 if self.max_samples is not None: 

758 num_samples = self.max_samples 

759 

760 # Generate samples (half from experimental data, half from synthetic+experimental splice) 

761 num_exp_samples = num_samples // 2 

762 num_synth_samples = num_samples - num_exp_samples 

763 

764 # Generate experimental samples 

765 for _ in range(num_exp_samples): 

766 # Extract a random crop from experimental data 

767 exp_crop = self._extract_random_crop(self._exp_zarr_data, self.boxsize) 

768 self._subvolumes.append(exp_crop) 

769 self._molecule_ids.append(-1) # Background class 

770 self._is_background.append(True) 

771 

772 # Generate synthetic+experimental spliced samples 

773 for _ in range(num_synth_samples): 

774 # Get a random mask name 

775 mask_names = list(self._synth_mask_data.keys()) 

776 mask_name = random.choice(mask_names) 

777 

778 # Extract a bounding box 

779 bbox_info = self._extract_bounding_box(self._synth_mask_data[mask_name], mask_name) 

780 

781 if bbox_info is not None: 

782 # Extract a random crop from experimental data 

783 exp_crop = self._extract_random_crop(self._exp_zarr_data, self.boxsize) 

784 

785 # Splice the volumes 

786 spliced_volume = self._splice_volumes(bbox_info["synth_region"], bbox_info["region_mask"], exp_crop) 

787 

788 # Add to dataset 

789 self._subvolumes.append(spliced_volume) 

790 

791 # Get or create molecule index 

792 if bbox_info["object_name"] not in self._keys: 

793 self._keys.append(bbox_info["object_name"]) 

794 molecule_idx = self._keys.index(bbox_info["object_name"]) 

795 

796 self._molecule_ids.append(molecule_idx) 

797 self._is_background.append(False) 

798 

799 # Convert to numpy arrays 

800 self._subvolumes = np.array(self._subvolumes) 

801 self._molecule_ids = np.array(self._molecule_ids) 

802 self._is_background = np.array(self._is_background) 

803 

804 # Compute sample weights 

805 self._compute_sample_weights() 

806 

807 print(f"Generated {len(self._subvolumes)} samples with {len(self._keys)} classes") 

808 print(f"Background samples: {sum(self._is_background)}") 

809 print(f"Class distribution: {self.get_class_distribution()}") 

810 

811 def _load_copick_roots(self): 

812 """Load the experimental and synthetic copick roots.""" 

813 try: 

814 print(f"Loading experimental dataset {self.exp_dataset_id} and synthetic dataset {self.synth_dataset_id}") 

815 self.exp_root = copick.from_czcdp_datasets([self.exp_dataset_id], overlay_root=self.overlay_root) 

816 self.synth_root = copick.from_czcdp_datasets([self.synth_dataset_id], overlay_root=self.overlay_root) 

817 

818 print(f"Experimental dataset: {len(self.exp_root.runs)} runs") 

819 print(f"Synthetic dataset: {len(self.synth_root.runs)} runs") 

820 

821 # Filter synthetic dataset to only include the specified run 

822 if self.synth_run_id: 

823 print(f"Filtering synthetic dataset to only use run {self.synth_run_id}") 

824 filtered_runs = [run for run in self.synth_root.runs if run.meta.name == self.synth_run_id] 

825 if filtered_runs: 

826 print(f"Found run {self.synth_run_id}. Using only this run.") 

827 # Store the filtered run for use in loading zarr data 

828 self.synth_root._filtered_run = filtered_runs[0] 

829 else: 

830 print(f"Run {self.synth_run_id} not found in synthetic dataset. Using all available runs.") 

831 self.synth_root._filtered_run = None 

832 except Exception as e: 

833 print(f"Error loading CoPick roots: {str(e)}") 

834 raise 

835 

836 def _ensure_zarr_loaded(self): 

837 """Load zarr arrays into memory if not already loaded.""" 

838 if not self._zarr_loaded: 

839 self._load_experimental_zarr() 

840 self._load_synthetic_zarr() 

841 self._load_segmentation_masks() 

842 self._zarr_loaded = True 

843 print("Zarr arrays loaded into memory.") 

844 

def _load_experimental_zarr(self):
    """Load the first available experimental tomogram and standardize it.

    The tomograms are looked up with ``_get_available_tomograms`` on
    ``self.exp_root`` for ``self.voxel_spacing``. The array stored under key
    ``"0"`` of the first tomogram is read fully into memory, shifted to zero
    mean and scaled by its standard deviation, and stored in
    ``self._exp_zarr_data``.

    Raises:
        ValueError: If no experimental tomogram is available.
        Exception: Any other failure is printed and re-raised unchanged.
    """
    try:
        # Get available tomograms from experimental dataset
        exp_tomograms = self._get_available_tomograms(self.exp_root, self.voxel_spacing)

        if not exp_tomograms:
            raise ValueError(f"No experimental tomograms found with voxel spacing {self.voxel_spacing}")

        # Only the first tomogram is used
        exp_zarr = zarr.open(exp_tomograms[0].zarr(), "r")
        raw = exp_zarr["0"][:]

        # A constant volume has zero spread; dividing by it would fill the
        # tomogram with NaN, so fall back to a unit divisor (result: all zeros).
        spread = np.std(raw)
        if spread == 0:
            spread = 1.0

        self._exp_zarr_data = (raw - np.mean(raw)) / spread
        print(f"Loaded experimental zarr with shape {self._exp_zarr_data.shape}")
    except Exception as e:
        print(f"Error loading experimental zarr: {str(e)}")
        raise

865 

def _load_synthetic_zarr(self):
    """Load the first available synthetic tomogram and standardize it.

    The tomograms are looked up with ``_get_available_tomograms`` on
    ``self.synth_root`` for ``self.voxel_spacing``. The array stored under
    key ``"0"`` of the first tomogram is read fully into memory, shifted to
    zero mean and scaled by its standard deviation, and stored in
    ``self._synth_zarr_data``.

    Raises:
        ValueError: If no synthetic tomogram is available.
        Exception: Any other failure is printed and re-raised unchanged.
    """
    try:
        # Get available tomograms from synthetic dataset
        synth_tomograms = self._get_available_tomograms(self.synth_root, self.voxel_spacing)

        if not synth_tomograms:
            raise ValueError(f"No synthetic tomograms found with voxel spacing {self.voxel_spacing}")

        # Only the first tomogram is used
        synth_zarr = zarr.open(synth_tomograms[0].zarr(), "r")
        raw = synth_zarr["0"][:]

        # A constant volume has zero spread; dividing by it would fill the
        # tomogram with NaN, so fall back to a unit divisor (result: all zeros).
        spread = np.std(raw)
        if spread == 0:
            spread = 1.0

        self._synth_zarr_data = (raw - np.mean(raw)) / spread
        print(f"Loaded synthetic zarr with shape {self._synth_zarr_data.shape}")
    except Exception as e:
        print(f"Error loading synthetic zarr: {str(e)}")
        raise

888 

889 def _get_available_tomograms(self, root, voxel_spacing, tomo_type="wbp"): 

890 """Get available tomograms from a CoPick dataset for a specific voxel spacing.""" 

891 available_tomograms = [] 

892 

893 # If a filtered run is specified, only use that run 

894 if hasattr(root, "_filtered_run") and root._filtered_run is not None: 

895 runs = [root._filtered_run] 

896 print(f"Using only filtered run: {root._filtered_run.meta.name}") 

897 else: 

898 runs = root.runs 

899 

900 for run in runs: 

901 # Get the closest voxel spacing to the target 

902 closest_vs = None 

903 min_diff = float("inf") 

904 

905 for vs in run.voxel_spacings: 

906 diff = abs(vs.meta.voxel_size - voxel_spacing) 

907 if diff < min_diff: 

908 min_diff = diff 

909 closest_vs = vs 

910 

911 if closest_vs: 

912 tomograms = closest_vs.get_tomograms(tomo_type) 

913 if tomograms: 

914 available_tomograms.extend(tomograms) 

915 print( 

916 f"Found {len(tomograms)} tomograms in run {run.meta.name} with voxel spacing {closest_vs.meta.voxel_size}", 

917 ) 

918 

919 return available_tomograms 

920 

def _load_segmentation_masks(self):
    """Read the synthetic segmentation masks into ``self._synth_mask_data``.

    Masks are looked up with ``_get_segmentation_masks`` on
    ``self.synth_root``. Every mask except one named "membrane"
    (case-insensitive) is read fully into memory, from the ``"data"`` array
    when the zarr group has one and from ``"0"`` otherwise, and stored under
    its mask name.

    Raises:
        ValueError: If the lookup returns no masks at all.
        Exception: Any other failure is printed and re-raised unchanged.
    """
    try:
        found_masks = self._get_segmentation_masks(self.synth_root, self.voxel_spacing)

        if not found_masks:
            raise ValueError(f"No segmentation masks found with voxel spacing {self.voxel_spacing}")

        for mask_name, mask_obj in found_masks.items():
            if mask_name.lower() != "membrane":
                group = zarr.open(mask_obj.zarr(), "r")
                array_key = "data" if "data" in group else "0"
                mask_data = group[array_key][:]

                self._synth_mask_data[mask_name] = mask_data
                print(f"Loaded mask '{mask_name}' with shape {mask_data.shape}")
            else:
                # Membrane masks are deliberately left out
                print(f"Skipping membrane segmentation mask: {mask_name}")

        print(f"Loaded {len(self._synth_mask_data)} segmentation masks")
    except Exception as e:
        print(f"Error loading segmentation masks: {str(e)}")
        raise

949 

950 def _get_segmentation_masks(self, root, voxel_spacing, pickable_objects=None): 

951 """Get segmentation masks from a CoPick dataset for a specific voxel spacing.""" 

952 segmentation_masks = {} 

953 

954 # If a filtered run is specified, only use that run 

955 if hasattr(root, "_filtered_run") and root._filtered_run is not None: 

956 runs = [root._filtered_run] 

957 print(f"Using only filtered run: {root._filtered_run.meta.name} for segmentation masks") 

958 else: 

959 runs = root.runs 

960 

961 for run in runs: 

962 # Get the closest voxel spacing to the target 

963 closest_vs = None 

964 min_diff = float("inf") 

965 

966 for vs in run.voxel_spacings: 

967 diff = abs(vs.meta.voxel_size - voxel_spacing) 

968 if diff < min_diff: 

969 min_diff = diff 

970 closest_vs = vs 

971 

972 if closest_vs: 

973 segmentations = run.get_segmentations(voxel_size=closest_vs.meta.voxel_size) 

974 

975 for seg in segmentations: 

976 # Only include segmentations matching requested pickable objects 

977 if pickable_objects is None or seg.meta.name in pickable_objects: 

978 segmentation_masks[seg.meta.name] = seg 

979 print(f"Found segmentation mask for '{seg.meta.name}' in run {run.meta.name}") 

980 

981 return segmentation_masks 

982 

983 def _extract_random_crop(self, tomogram_data, crop_size): 

984 """Extract a random crop from a tomogram.""" 

985 depth, height, width = tomogram_data.shape 

986 

987 # Ensure crop sizes don't exceed tomogram dimensions 

988 crop_depth = min(crop_size[0], depth) 

989 crop_height = min(crop_size[1], height) 

990 crop_width = min(crop_size[2], width) 

991 

992 # Calculate valid ranges for the random crop 

993 max_z = depth - crop_depth 

994 max_y = height - crop_height 

995 max_x = width - crop_width 

996 

997 if max_z <= 0 or max_y <= 0 or max_x <= 0: 

998 # Tomogram is smaller than crop size in at least one dimension 

999 return resize(tomogram_data, crop_size, mode="reflect", anti_aliasing=True) 

1000 

1001 # Get random start coordinates 

1002 z_start = random.randint(0, max_z) 

1003 y_start = random.randint(0, max_y) 

1004 x_start = random.randint(0, max_x) 

1005 

1006 # Extract the crop 

1007 crop = tomogram_data[ 

1008 z_start : z_start + crop_depth, 

1009 y_start : y_start + crop_height, 

1010 x_start : x_start + crop_width, 

1011 ] 

1012 

1013 return crop 

1014 

def _extract_bounding_box(self, mask_data, mask_name):
    """Cut a fixed-size box around one randomly chosen component of a mask.

    The positive voxels of ``mask_data`` are split into connected components
    and one component is picked at random. A cube of side ``self.boxsize[0]``
    is centred on its centroid and shifted inwards where it would leave the
    volume.

    Args:
        mask_data: 3-D segmentation mask; voxels ``> 0`` are foreground.
        mask_name: Name reported back as ``"object_name"``.

    Returns:
        dict | None: ``None`` when the mask has no component, the volume is
        too small for a full box, or the synthetic tomogram cannot be read.
        Otherwise a dict with ``"bbox"`` (z_min, y_min, x_min, z_max, y_max,
        x_max), ``"region_mask"`` (boolean box of the chosen component,
        dilated by two iterations), ``"synth_region"`` (the same box taken
        from ``self._synth_zarr_data``), ``"object_name"`` and ``"center"``
        (the component centroid).
    """
    # Label connected components
    labels = measure.label(mask_data > 0)
    regions = measure.regionprops(labels)

    if not regions:
        return None

    # Select a random region to extract
    region = random.choice(regions)
    z_center, y_center, x_center = region.centroid

    box_size = self.boxsize[0]  # Assume cubic box
    half_size = box_size // 2
    shape = mask_data.shape

    # Box start per axis: centred on the particle, pulled back inside the volume
    starts = []
    for center, extent in zip((z_center, y_center, x_center), shape):
        start = max(0, int(center - half_size))
        if start + box_size > extent:
            start = max(0, extent - box_size)
        starts.append(start)
    stops = [min(extent, start + box_size) for start, extent in zip(starts, shape)]

    # Check if we can extract a full box
    if any(stop - start != box_size for start, stop in zip(starts, stops)):
        print(f"Cannot extract a full {box_size}^3 box at the edge of the volume.")
        return None

    # Dilate only the box plus a halo instead of the whole volume. Two
    # iterations can pull in foreground from at most two voxels away, so a
    # halo of that width gives the same values inside the box as dilating the
    # full mask, at a fraction of the cost.
    dilation_iterations = 2
    halo_starts = [max(0, start - dilation_iterations) for start in starts]
    halo_stops = [min(extent, stop + dilation_iterations) for stop, extent in zip(stops, shape)]
    halo = tuple(slice(lo, hi) for lo, hi in zip(halo_starts, halo_stops))

    region_window = labels[halo] == region.label
    dilated_window = binary_dilation(region_window, iterations=dilation_iterations)

    # Position of the box inside the halo window
    inner = tuple(
        slice(start - lo, start - lo + box_size) for start, lo in zip(starts, halo_starts)
    )
    box_mask = dilated_window[inner].copy()

    # Verify box mask has expected dimensions
    if box_mask.shape != (box_size, box_size, box_size):
        print(f"Box mask has unexpected shape: {box_mask.shape}")
        return None

    # Extract corresponding region from synthetic tomogram
    box = tuple(slice(start, stop) for start, stop in zip(starts, stops))
    try:
        synth_region = self._synth_zarr_data[box].copy()
        if synth_region.shape != box_mask.shape:
            # Tomogram and mask extents disagree; stretch to the box shape
            synth_region = resize(synth_region, box_mask.shape, mode="reflect", anti_aliasing=True)
    except Exception as e:
        print(f"Error extracting synthetic region: {e}")
        return None

    return {
        "bbox": (*starts, *stops),
        "region_mask": box_mask,
        "synth_region": synth_region,
        "object_name": mask_name,
        "center": region.centroid,
    }

1088 

1089 def _splice_volumes(self, synthetic_region, region_mask, exp_crop): 

1090 """Splice a synthetic structure into an experimental tomogram using Gaussian blending at the edges.""" 

1091 # Create a spliced volume by starting with the experimental crop 

1092 spliced_volume = exp_crop.copy() 

1093 

1094 # For Gaussian blending, create a weight map that transitions smoothly from 1 to 0 

1095 if self.blend_sigma > 0: 

1096 try: 

1097 # Start with the region mask 

1098 mask_float = region_mask.astype(np.float32) 

1099 

1100 # Apply Gaussian blur to the mask to create a smooth transition at the boundaries 

1101 # This creates a weight map that goes from 1 (inside) to 0 (outside) with smooth transitions 

1102 weight_map = gaussian_filter(mask_float, sigma=self.blend_sigma) 

1103 

1104 # Normalize weight map to ensure it's between 0 and 1 

1105 weight_map = np.clip(weight_map, 0, 1) 

1106 

1107 # Apply weighted blending: synthetic * weight + experimental * (1-weight) 

1108 spliced_volume = synthetic_region * weight_map + exp_crop * (1 - weight_map) 

1109 

1110 except Exception as e: 

1111 print(f"Error during Gaussian blending at boundaries: {e}") 

1112 # Fall back to simple mask-based splicing 

1113 spliced_volume[region_mask] = synthetic_region[region_mask] 

1114 else: 

1115 # If blend_sigma is 0, just do simple mask-based splicing without blending 

1116 spliced_volume[region_mask] = synthetic_region[region_mask] 

1117 

1118 return spliced_volume 

1119 

1120 def __getitem__(self, idx): 

1121 """Get an item with spliced mixup augmentation.""" 

1122 # Ensure zarr data is loaded 

1123 self._ensure_zarr_loaded() 

1124 

1125 # Get the base subvolume using the parent method (without normalizing and tensor conversion) 

1126 subvolume = self._subvolumes[idx].copy() 

1127 molecule_idx = self._molecule_ids[idx] 

1128 

1129 # Get a random mask and extract a bounding box 

1130 mask_name = random.choice(list(self._synth_mask_data.keys())) 

1131 bbox_info = self._extract_bounding_box(self._synth_mask_data[mask_name], mask_name) 

1132 

1133 if bbox_info is not None: 

1134 # Extract a random crop from the experimental tomogram 

1135 exp_crop = self._extract_random_crop(self._exp_zarr_data, self.boxsize) 

1136 

1137 # Splice the volumes 

1138 spliced_volume = self._splice_volumes(bbox_info["synth_region"], bbox_info["region_mask"], exp_crop) 

1139 

1140 # Decide whether to use the spliced volume based on a random chance 

1141 if random.random() < 0.5: 

1142 subvolume = spliced_volume 

1143 # Get the molecule_idx for the synthetic object 

1144 if bbox_info["object_name"] not in self._keys: 

1145 self._keys.append(bbox_info["object_name"]) 

1146 molecule_idx = self._keys.index(bbox_info["object_name"]) 

1147 

1148 # Apply augmentations if enabled 

1149 if self.augment: 

1150 subvolume = self._augment_subvolume(subvolume) 

1151 

1152 # Normalize subvolume 

1153 subvolume = (subvolume - np.mean(subvolume)) / (np.std(subvolume) + 1e-6) 

1154 

1155 # Add channel dimension and convert to tensor 

1156 subvolume = torch.as_tensor(subvolume[None, ...], dtype=torch.float32) 

1157 

1158 # Return the subvolume and class index as a simple tuple 

1159 return subvolume, molecule_idx