Coverage for copick_torch/minimal_dataset.py: 6%

382 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-16 16:14 -0700

1""" 

2A minimal CopickDataset implementation without caching or augmentation. 

3""" 

4 

5import json 

6import logging 

7import os 

8from collections import Counter 

9from types import SimpleNamespace 

10 

11import copick 

12import numpy as np 

13import torch 

14import zarr 

15from torch.utils.data import Dataset 

16from tqdm import tqdm 

17 

18logger = logging.getLogger(__name__) 

19 

20 

class MinimalCopickDataset(Dataset):
    """
    A minimal PyTorch dataset for working with copick data that returns (image, label) pairs.

    Unlike the SimpleCopickDataset, this implementation:
    1. Does not use an on-disk cache (subvolumes are either preloaded into memory
       when ``preload=True`` or extracted on-the-fly from the loaded tomograms)
    2. Does not include augmentation
    3. Has minimal dependencies
    4. Focuses on correct label mapping

    Labels are the ``label`` values of the project's pickable objects; background
    samples use the label ``-1``.

    This dataset can be saved to disk and loaded later for reproducibility.
    """

    def __init__(
        self,
        proj=None,
        dataset_id=None,
        overlay_root=None,
        boxsize=(48, 48, 48),
        voxel_spacing=10.012,
        include_background=False,
        background_ratio=0.2,
        min_background_distance=None,
        preload=True,
    ):
        """
        Initialize a MinimalCopickDataset.

        Args:
            proj: A copick project object. If provided, dataset_id and overlay_root are ignored.
            dataset_id: Dataset ID from the CZ cryoET Data Portal. Only used if proj is None.
            overlay_root: Root directory for the overlay storage. Only used if proj is None.
            boxsize: Size of the subvolumes to extract (z, y, x), in voxels
            voxel_spacing: Voxel spacing to use for extraction
            include_background: Whether to include background samples
            background_ratio: Ratio of background to particle samples
            min_background_distance: Minimum distance (in voxels) between a background
                sample and any particle. Defaults to ``max(boxsize)`` when None.
            preload: Whether to preload all subvolumes into memory (faster but more memory intensive)

        If neither ``proj`` nor both ``dataset_id`` and ``overlay_root`` are given,
        the dataset is created empty.
        """
        self.dataset_id = dataset_id
        self.overlay_root = overlay_root
        # Normalised to a tuple so it can be compared against ndarray.shape.
        self.boxsize = tuple(boxsize)
        self.voxel_spacing = voxel_spacing
        self.include_background = include_background
        self.background_ratio = background_ratio
        # Only None selects the default; an explicit 0 means "no exclusion radius".
        if min_background_distance is None:
            min_background_distance = max(self.boxsize)
        self.min_background_distance = min_background_distance
        self.preload = preload

        # Initialize data structures
        self._points = []  # List of (x, y, z) coordinates
        self._labels = []  # List of class indices
        self._is_background = []  # List of booleans indicating if a sample is background
        self._tomogram_indices = []  # Index into _tomogram_data for every sample
        self._tomogram_data = []  # List of tomogram arrays
        self._name_to_label = {}  # Mapping from object names to labels

        # Storage for preloaded data
        self._subvolumes = None

        # Set copick project
        self.copick_root = proj

        # Load the data
        if self.copick_root is not None:
            self._load_data()
        elif dataset_id is not None and overlay_root is not None:
            # Create project from dataset_id and overlay_root
            self.copick_root = copick.from_czcdp_datasets([dataset_id], overlay_root=overlay_root)
            self._load_data()

    def _extract_name_to_label(self):
        """Extract name to label mapping from pickable objects."""
        self._name_to_label = {obj.name: obj.label for obj in self.copick_root.pickable_objects}

        # Ensure we have a consistent list of object names
        self._object_names = list(self._name_to_label)

        logger.info(f"Name to label mapping: {self._name_to_label}")

    @staticmethod
    def _select_tomogram(run, voxel_spacing):
        """Return the run's preferred tomogram at ``voxel_spacing`` or None.

        A tomogram whose ``tomo_type`` contains "wbp-denoised" is preferred;
        otherwise the first available tomogram is used.
        """
        voxel_spacing_obj = run.get_voxel_spacing(voxel_spacing)
        if not voxel_spacing_obj or not voxel_spacing_obj.tomograms:
            return None

        tomograms = voxel_spacing_obj.tomograms
        denoised = [t for t in tomograms if "wbp-denoised" in t.tomo_type]
        return denoised[0] if denoised else tomograms[0]

    @staticmethod
    def _to_normalized_tensor(subvolume):
        """Standardise a subvolume (if it has variance) and return a (1, z, y, x) float32 tensor."""
        std = np.std(subvolume)
        if std > 0:
            subvolume = (subvolume - np.mean(subvolume)) / std

        # Add channel dimension and convert to tensor
        return torch.as_tensor(subvolume[None, ...], dtype=torch.float32)

    def _tomogram_index(self, idx):
        """Return the tomogram index of sample ``idx`` (0 if no indices are stored)."""
        indices = getattr(self, "_tomogram_indices", None)
        return 0 if indices is None else indices[idx]

    def _load_data(self):
        """Load tomograms, picks and (optionally) background samples from the copick project.

        Runs without a tomogram at the requested voxel spacing are skipped. Errors
        while processing a single run or a single set of picks are logged and the
        remaining data is still loaded; any other error is logged and re-raised.
        """
        try:
            # Extract name to label mapping
            self._extract_name_to_label()

            all_points = []
            all_labels = []
            all_is_background = []
            all_tomogram_indices = []

            # Start from a clean state. The previous `hasattr` guard never fired
            # because __init__ sets `_subvolumes = None`, so preloading always
            # failed on the first append.
            self._tomogram_data = []
            self._subvolumes = [] if self.preload else None

            def add_sample(point, label, is_background, tomogram_idx):
                # Extract first so a failure cannot leave the per-sample lists
                # and the preloaded tensors with different lengths.
                tensor = None
                if self.preload:
                    tensor = self._to_normalized_tensor(self.extract_subvolume(point, tomogram_idx))

                all_points.append(point)
                all_labels.append(label)
                all_is_background.append(is_background)
                all_tomogram_indices.append(tomogram_idx)
                if tensor is not None:
                    self._subvolumes.append((tensor, label))

            for run in self.copick_root.runs:
                logger.info(f"Processing run: {run.name}")

                try:
                    tomogram = self._select_tomogram(run, self.voxel_spacing)
                    if tomogram is None:
                        logger.warning(f"No tomograms found for run {run.name} at voxel spacing {self.voxel_spacing}")
                        continue

                    # Open zarr array and load it fully into memory
                    tomogram_zarr = zarr.open(tomogram.zarr())["0"]
                    tomogram_data = np.array(tomogram_zarr[:])
                    self._tomogram_data.append(tomogram_data)
                    # Index of the array just stored. This is NOT the run index:
                    # runs without tomograms are skipped and store nothing.
                    tomogram_idx = len(self._tomogram_data) - 1
                    logger.info(f"Loaded tomogram with shape {tomogram_data.shape} into memory")

                    # Particle coordinates of this run, used for background sampling
                    particle_coords = []

                    # Process picks for each object type
                    for picks in run.get_picks():
                        if not picks.from_tool:
                            continue

                        object_name = picks.pickable_object_name

                        # Skip objects not in our mapping
                        if object_name not in self._name_to_label:
                            logger.warning(f"Object {object_name} not in pickable objects, skipping")
                            continue

                        class_idx = self._name_to_label[object_name]

                        try:
                            points, _ = picks.numpy()
                            if len(points) == 0:
                                logger.warning(f"No points found for {object_name}")
                                continue

                            logger.info(f"Found {len(points)} points for {object_name}")

                            for point in points:
                                add_sample(point, class_idx, False, tomogram_idx)
                                particle_coords.append(point)
                        except Exception as e:
                            # Best effort: one broken pick set must not abort the run.
                            logger.error(f"Error processing picks for {object_name}: {e}")

                    # Sample background points if requested
                    if self.include_background and particle_coords:
                        num_background = int(len(particle_coords) * self.background_ratio)

                        logger.info(f"Sampling {num_background} background points")

                        bg_points = self._sample_background_points(
                            tomogram_data.shape,
                            particle_coords,
                            num_background,
                            self.min_background_distance,
                        )

                        for point in bg_points:
                            add_sample(point, -1, True, tomogram_idx)  # -1 indicates background

                except Exception as e:
                    # Best effort: one broken run must not abort the whole load.
                    logger.error(f"Error processing tomogram for run {run.name}: {e}")
                    continue

            # Store the processed data
            self._points = all_points
            self._labels = all_labels
            self._is_background = all_is_background
            self._tomogram_indices = all_tomogram_indices

            logger.info(f"Dataset loaded with {len(self._points)} samples")

            # Print class distribution
            self._print_class_distribution()

        except Exception as e:
            logger.error(f"Error loading data: {e}")
            raise

    def _preload_data(self):
        """Preload all subvolumes into memory.

        Kept for backward compatibility and used by ``load`` for datasets saved
        without tensors; ``_load_data`` already preloads while reading picks.
        Does nothing when preloaded subvolumes already exist.
        """
        logger.info(f"Preloading {len(self._points)} subvolumes into memory...")

        existing = getattr(self, "_subvolumes", None)
        if existing:
            logger.info(f"Subvolumes are already preloaded ({len(existing)} subvolumes)")
            return

        # Extract and store all subvolumes
        self._subvolumes = []
        for idx in tqdm(range(len(self._points))):
            subvolume = self.extract_subvolume(self._points[idx], self._tomogram_index(idx))
            self._subvolumes.append((self._to_normalized_tensor(subvolume), self._labels[idx]))

        logger.info(f"Preloaded {len(self._subvolumes)} subvolumes")

    def _print_class_distribution(self):
        """Log the distribution of classes in the dataset and return it as a dict.

        The background count, if any, is listed first.
        """
        counts = self.get_class_distribution()

        distribution = {}
        if "background" in counts:
            distribution["background"] = counts.pop("background")
        distribution.update(counts)

        logger.info("Class distribution:")
        for class_name, count in distribution.items():
            logger.info(f"  {class_name}: {count} samples")

        return distribution

    def _sample_background_points(self, tomogram_shape, particle_coords, num_points, min_distance):
        """
        Sample random background points away from particles.

        Candidates are drawn uniformly (via the global ``np.random`` state) from the
        region in which a full box fits inside the tomogram. At most
        ``10 * num_points`` candidates are tried, so fewer than ``num_points``
        points may be returned.

        Args:
            tomogram_shape: Shape of the tomogram (z, y, x) in voxels
            particle_coords: List of particle (x, y, z) coordinates in the same
                physical units as the picks (i.e. voxel index * voxel_spacing)
            num_points: Number of background points to sample
            min_distance: Minimum distance from particles, in voxels

        Returns:
            List of background points as (x, y, z) arrays in physical units, so
            they can be handled exactly like pick coordinates.
        """
        z_dim, y_dim, x_dim = tomogram_shape
        half_z, half_y, half_x = (int(size) // 2 for size in self.boxsize)

        # Valid voxel ranges in (x, y, z) order
        low = np.array([half_x, half_y, half_z], dtype=float)
        high = np.array([x_dim - half_x, y_dim - half_y, z_dim - half_z], dtype=float)
        if np.any(high < low):
            logger.warning(
                f"Tomogram shape {tuple(tomogram_shape)} is smaller than box size {tuple(self.boxsize)}; "
                "no background points sampled",
            )
            return []

        # Work in voxel units throughout: candidates are drawn in voxel space and
        # min_distance defaults to max(boxsize), which is a voxel count.
        particles_vox = None
        if len(particle_coords) > 0:
            particles_vox = np.asarray(particle_coords, dtype=float) / self.voxel_spacing

        bg_points = []
        max_attempts = num_points * 10
        attempts = 0

        while len(bg_points) < num_points and attempts < max_attempts:
            attempts += 1
            candidate_vox = np.random.uniform(low, high)

            if particles_vox is not None:
                distances = np.linalg.norm(particles_vox - candidate_vox, axis=1)
                if np.min(distances) < min_distance:
                    continue

            # Convert back to physical units; callers divide by voxel_spacing again.
            bg_points.append(candidate_vox * self.voxel_spacing)

        logger.info(f"Sampled {len(bg_points)} background points after {attempts} attempts")
        return bg_points

    def extract_subvolume(self, point, tomogram_idx=0):
        """
        Extract a subvolume of size ``boxsize`` centered around a point.

        The part of the box that falls outside the tomogram is filled with zeros;
        the cropped data is centered in the returned volume.

        Args:
            point: (x, y, z) coordinates in physical units (divided by voxel_spacing
                to obtain voxel indices)
            tomogram_idx: Index of the tomogram to use

        Returns:
            Extracted subvolume as a numpy array of shape ``boxsize``

        Raises:
            ValueError: If no tomogram is available at ``tomogram_idx``
        """
        # Check if tomogram exists
        if tomogram_idx >= len(self._tomogram_data) or self._tomogram_data[tomogram_idx] is None:
            raise ValueError(f"No tomogram found at index {tomogram_idx}")

        tomogram = self._tomogram_data[tomogram_idx]
        boxsize = tuple(int(size) for size in self.boxsize)

        # Points are (x, y, z) but arrays are indexed (z, y, x)
        centers = [int(point[axis] / self.voxel_spacing) for axis in (2, 1, 0)]

        window = []
        for center, size, dim in zip(centers, boxsize, tomogram.shape):
            # start + size (rather than center + size // 2) keeps odd box sizes whole
            start = center - size // 2
            # Clamp both ends into [0, dim] with lo <= hi so that points outside
            # the tomogram yield an empty slice instead of a wrapped-around one.
            lo = min(max(0, start), dim)
            hi = min(max(lo, start + size), dim)
            window.append(slice(lo, hi))

        subvolume = tomogram[tuple(window)].copy()

        if subvolume.shape == boxsize:
            return subvolume

        # Pad: center the cropped data inside a zero volume of the requested size
        padded = np.zeros(boxsize, dtype=subvolume.dtype)
        target = tuple(
            slice((size - extent) // 2, (size - extent) // 2 + extent)
            for size, extent in zip(boxsize, subvolume.shape)
        )
        padded[target] = subvolume
        return padded

    def __len__(self):
        """Get the length of the dataset."""
        return len(self._points)

    def __getitem__(self, idx):
        """
        Get an item from the dataset.

        Args:
            idx: Index

        Returns:
            Tuple of (subvolume, label) where subvolume is a normalized float32
            tensor with a leading channel dimension
        """
        # If data is preloaded, return from preloaded data
        preloaded = getattr(self, "_subvolumes", None)
        if self.preload and preloaded:
            return preloaded[idx]

        # Otherwise, extract on-the-fly
        subvolume = self.extract_subvolume(self._points[idx], self._tomogram_index(idx))
        return self._to_normalized_tensor(subvolume), self._labels[idx]

    def keys(self):
        """Get the list of class names (plus "background" if it is included)."""
        class_names = list(self._name_to_label)
        if self.include_background:
            class_names.append("background")
        return class_names

    def get_class_distribution(self):
        """Get the number of samples per class name as a dict.

        Samples labelled -1 are counted as "background"; labels without a matching
        pickable object name are ignored.
        """
        # Invert once instead of scanning the mapping per sample; the first name
        # wins if several names share a label.
        label_to_name = {}
        for name, label in self._name_to_label.items():
            label_to_name.setdefault(label, name)

        distribution = Counter()
        for label in self._labels:
            if label == -1:
                distribution["background"] += 1
            elif label in label_to_name:
                distribution[label_to_name[label]] += 1

        return dict(distribution)

    def get_sample_weights(self):
        """
        Compute sample weights for balanced sampling.

        Returns:
            List with one inverse-frequency weight (total / class count) per sample
        """
        class_counts = Counter(self._labels)
        total_samples = len(self._labels)
        return [total_samples / class_counts[label] for label in self._labels]

    def save(self, save_dir):
        """
        Save the dataset to disk for later reloading.

        Always writes ``metadata.json``. If preloaded subvolumes exist, they are
        written to ``subvolumes.pt`` / ``labels.pt``; otherwise the sample
        coordinates (``samples.json``) and tomogram shapes (``tomogram_info.json``)
        are written so the subvolumes can be re-extracted from a project.

        Args:
            save_dir: Directory to save the dataset
        """
        os.makedirs(save_dir, exist_ok=True)

        # Save metadata
        metadata = {
            "dataset_id": self.dataset_id,
            "boxsize": [int(size) for size in self.boxsize],
            "voxel_spacing": self.voxel_spacing,
            "include_background": self.include_background,
            "background_ratio": self.background_ratio,
            "min_background_distance": self.min_background_distance,
            "name_to_label": self._name_to_label,
            "preload": self.preload,
        }

        with open(os.path.join(save_dir, "metadata.json"), "w", encoding="utf-8") as f:
            json.dump(metadata, f)

        preloaded = getattr(self, "_subvolumes", None)
        if self.preload and preloaded:
            # If preloaded, save the actual tensors
            logger.info("Saving preloaded tensors...")

            volumes, labels = zip(*preloaded)

            torch.save(torch.stack(list(volumes)), os.path.join(save_dir, "subvolumes.pt"))
            torch.save(torch.tensor(list(labels)), os.path.join(save_dir, "labels.pt"))

            logger.info(f"Saved {len(volumes)} preloaded tensors")
        else:
            # Save sample information for on-the-fly loading
            logger.info("Saving sample information for on-the-fly loading...")

            sample_data = [
                {
                    "point": np.asarray(point).tolist(),
                    "label": int(label),
                    "is_background": bool(is_background),
                    "tomogram_idx": int(self._tomogram_index(i)),
                }
                for i, (point, label, is_background) in enumerate(
                    zip(self._points, self._labels, self._is_background),
                )
            ]

            with open(os.path.join(save_dir, "samples.json"), "w", encoding="utf-8") as f:
                json.dump(sample_data, f)

            # Save tomogram information (needed for on-the-fly loading).
            # Entries may be None placeholders after a load() without a full match.
            tomogram_info = [
                {
                    "index": idx,
                    "shape": None if tomogram is None else list(tomogram.shape),
                    "path": getattr(tomogram, "path", None),
                }
                for idx, tomogram in enumerate(self._tomogram_data)
            ]

            with open(os.path.join(save_dir, "tomogram_info.json"), "w", encoding="utf-8") as f:
                json.dump(tomogram_info, f)

        logger.info(f"Dataset saved to {save_dir}")

    @classmethod
    def load(cls, save_dir, proj=None):
        """
        Load a previously saved dataset.

        If ``subvolumes.pt`` and ``labels.pt`` exist, the preloaded tensors are
        restored (sample coordinates are then placeholders). Otherwise the sample
        list is restored and, when ``proj`` is given, tomograms are matched to the
        saved ones by shape, in run order.

        Args:
            save_dir: Directory where the dataset was saved
            proj: Optional copick project object. If provided, tomograms will be loaded from it.

        Returns:
            Loaded MinimalCopickDataset instance
        """
        # Load metadata
        with open(os.path.join(save_dir, "metadata.json"), "r", encoding="utf-8") as f:
            metadata = json.load(f)

        # Create a new dataset instance without loading data
        dataset = cls.__new__(cls)
        dataset.dataset_id = metadata.get("dataset_id")
        dataset.overlay_root = None
        # JSON stores the box size as a list; restore the tuple so it compares
        # equal to ndarray.shape during extraction.
        dataset.boxsize = tuple(metadata.get("boxsize", (48, 48, 48)))
        dataset.voxel_spacing = metadata.get("voxel_spacing", 10.012)
        dataset.include_background = metadata.get("include_background", False)
        dataset.background_ratio = metadata.get("background_ratio", 0.2)
        min_background_distance = metadata.get("min_background_distance")
        if min_background_distance is None:
            min_background_distance = max(dataset.boxsize)
        dataset.min_background_distance = min_background_distance
        dataset._name_to_label = metadata.get("name_to_label", {})
        dataset.preload = metadata.get("preload", True)
        dataset.copick_root = proj
        dataset._subvolumes = None
        dataset._tomogram_data = []

        # Check if we have preloaded tensors
        subvolumes_path = os.path.join(save_dir, "subvolumes.pt")
        labels_path = os.path.join(save_dir, "labels.pt")

        if os.path.exists(subvolumes_path) and os.path.exists(labels_path):
            logger.info("Loading preloaded tensors...")

            # NOTE(security): torch.load unpickles data; only load datasets from
            # trusted locations.
            subvolumes = torch.load(subvolumes_path)
            label_values = torch.load(labels_path).tolist()

            # Store in the dataset
            dataset._subvolumes = [(subvolumes[i], label) for i, label in enumerate(label_values)]

            # Create minimal point/label data for compatibility
            dataset._points = [np.zeros(3) for _ in label_values]
            dataset._labels = list(label_values)
            dataset._is_background = [label == -1 for label in label_values]
            dataset._tomogram_indices = [0 for _ in label_values]

            logger.info(f"Loaded dataset with {len(dataset._subvolumes)} preloaded subvolumes")
        else:
            # Load sample information
            with open(os.path.join(save_dir, "samples.json"), "r", encoding="utf-8") as f:
                sample_data = json.load(f)

            dataset._points = [np.array(s["point"]) for s in sample_data]
            dataset._labels = [s["label"] for s in sample_data]
            dataset._is_background = [s["is_background"] for s in sample_data]
            dataset._tomogram_indices = [s["tomogram_idx"] for s in sample_data]

            # Load tomogram information
            with open(os.path.join(save_dir, "tomogram_info.json"), "r", encoding="utf-8") as f:
                tomogram_info = json.load(f)

            # If a project is provided, attempt to load tomograms from it
            if proj is not None:
                logger.info("Loading tomograms from provided project")

                # Initialize tomogram list with placeholders
                dataset._tomogram_data = [None] * len(tomogram_info)

                for run in proj.runs:
                    tomogram = cls._select_tomogram(run, dataset.voxel_spacing)
                    if tomogram is None:
                        continue

                    tomo_zarr = zarr.open(tomogram.zarr())["0"]
                    shape = list(tomo_zarr.shape)

                    # Shape is only a heuristic. Fill the first still-empty slot
                    # with a matching shape so that same-shaped tomograms are
                    # assigned in run order instead of all resolving to the
                    # last run.
                    for tomo_info in tomogram_info:
                        idx = tomo_info["index"]
                        if dataset._tomogram_data[idx] is None and tomo_info["shape"] == shape:
                            dataset._tomogram_data[idx] = tomo_zarr
                            logger.info(f"Loaded tomogram at index {idx} with shape {tomo_zarr.shape}")
                            break
            else:
                logger.warning("No project provided. Tomograms must be loaded separately.")

            logger.info(f"Loaded dataset with {len(dataset._points)} samples")

            # If preload is True and we have tomogram data, preload the subvolumes
            if dataset.preload and dataset._tomogram_data and all(t is not None for t in dataset._tomogram_data):
                logger.info("Preloading subvolumes...")
                dataset._preload_data()

        # Print class distribution
        dataset._print_class_distribution()

        return dataset

775 

776 return dataset