Coverage for copick_torch/minimal_dataset.py: 6%
382 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-16 16:14 -0700
1"""
2A minimal CopickDataset implementation without caching or augmentation.
3"""
5import json
6import logging
7import os
8from collections import Counter
9from types import SimpleNamespace
11import copick
12import numpy as np
13import torch
14import zarr
15from torch.utils.data import Dataset
16from tqdm import tqdm
# Module-level logger (named after this module) used by MinimalCopickDataset
# for its progress, warning and error messages.
logger = logging.getLogger(name=__name__)
class MinimalCopickDataset(Dataset):
    """
    A minimal PyTorch dataset for working with copick data that returns (image, label) pairs.

    Unlike the SimpleCopickDataset, this implementation:
    1. Does not use an on-disk cache (subvolumes are cut from tomograms held in memory)
    2. Does not include augmentation
    3. Has minimal dependencies
    4. Focuses on correct label mapping

    This dataset can be saved to disk and loaded later for reproducibility.

    Conventions:
        * Sample points are (x, y, z) in the units of ``voxel_spacing``; dividing a
          coordinate by ``voxel_spacing`` and truncating gives the voxel index.
        * Tomogram arrays and ``boxsize`` are ordered (z, y, x).
        * Background samples carry the label ``-1``.
    """

    def __init__(
        self,
        proj=None,
        dataset_id=None,
        overlay_root=None,
        boxsize=(48, 48, 48),
        voxel_spacing=10.012,
        include_background=False,
        background_ratio=0.2,
        min_background_distance=None,
        preload=True,
    ):
        """
        Initialize a MinimalCopickDataset.

        Args:
            proj: A copick project object. If provided, dataset_id and overlay_root are ignored.
            dataset_id: Dataset ID from the CZ cryoET Data Portal. Only used if proj is None.
            overlay_root: Root directory for the overlay storage. Only used if proj is None.
            boxsize: Size of the subvolumes to extract (z, y, x)
            voxel_spacing: Voxel spacing to use for extraction
            include_background: Whether to include background samples
            background_ratio: Number of background samples per particle sample (per run)
            min_background_distance: Minimum distance, in voxels, between a background sample
                and any particle. Defaults to ``max(boxsize)``; an explicit 0 disables the check.
            preload: Whether to preload all subvolumes into memory (faster but more memory intensive)
        """
        self.dataset_id = dataset_id
        self.overlay_root = overlay_root
        self.boxsize = tuple(boxsize)
        self.voxel_spacing = voxel_spacing
        self.include_background = include_background
        self.background_ratio = background_ratio
        # ``is None`` rather than ``or`` so that an explicit 0 is honoured.
        self.min_background_distance = (
            max(self.boxsize) if min_background_distance is None else min_background_distance
        )
        self.preload = preload

        # Per-sample data (parallel lists)
        self._points = []  # (x, y, z) coordinates
        self._labels = []  # class indices, -1 for background
        self._is_background = []  # True for background samples
        self._tomogram_indices = []  # index into self._tomogram_data
        # Per-tomogram data (parallel lists)
        self._tomogram_data = []  # tomogram arrays, (z, y, x)
        self._tomogram_run_names = []  # run each tomogram came from
        self._name_to_label = {}  # object name -> label
        self._object_names = []
        # List of (tensor, label) once preloaded, otherwise None
        self._subvolumes = None

        self.copick_root = proj
        if self.copick_root is None and dataset_id is not None and overlay_root is not None:
            # Create project from dataset_id and overlay_root
            self.copick_root = copick.from_czcdp_datasets([dataset_id], overlay_root=overlay_root)
        if self.copick_root is not None:
            self._load_data()

    def _extract_name_to_label(self):
        """Build the object-name -> label mapping from the project's pickable objects."""
        self._name_to_label = {obj.name: obj.label for obj in self.copick_root.pickable_objects}
        self._object_names = list(self._name_to_label)
        logger.info(f"Name to label mapping: {self._name_to_label}")

    @staticmethod
    def _select_tomogram(voxel_spacing_obj):
        """Return the first "wbp-denoised" tomogram of ``voxel_spacing_obj``, else its first tomogram."""
        tomograms = voxel_spacing_obj.tomograms
        return next((t for t in tomograms if "wbp-denoised" in t.tomo_type), tomograms[0])

    def _collect_run_picks(self, run):
        """
        Gather the tool-generated picks of ``run`` whose object is in the label mapping.

        Returns:
            List of ``(point, class_idx)`` pairs. Pick sets that fail to load are logged and skipped.
        """
        collected = []
        for picks in run.get_picks():
            if not picks.from_tool:
                continue

            object_name = picks.pickable_object_name
            if object_name not in self._name_to_label:
                logger.warning(f"Object {object_name} not in pickable objects, skipping")
                continue
            class_idx = self._name_to_label[object_name]

            try:
                points, _ = picks.numpy()
            except Exception as e:
                logger.error(f"Error processing picks for {object_name}: {e}")
                continue

            if len(points) == 0:
                logger.warning(f"No points found for {object_name}")
                continue

            logger.info(f"Found {len(points)} points for {object_name}")
            collected.extend((point, class_idx) for point in points)
        return collected

    def _load_data(self):
        """
        Populate the dataset from ``self.copick_root``.

        For each run with a tomogram at ``self.voxel_spacing``:
        - the tomogram is read fully into memory;
        - its tool-generated picks become labelled samples;
        - if requested, background points are sampled.

        A failure inside one run is logged and that run is skipped. Any other failure
        is logged and re-raised. When ``self.preload`` is set, every subvolume is
        extracted once at the end.
        """
        try:
            self._extract_name_to_label()

            all_points = []
            all_labels = []
            all_is_background = []
            all_tomogram_indices = []

            for run in self.copick_root.runs:
                logger.info(f"Processing run: {run.name}")
                try:
                    voxel_spacing_obj = run.get_voxel_spacing(self.voxel_spacing)
                    if not voxel_spacing_obj or not voxel_spacing_obj.tomograms:
                        logger.warning(f"No tomograms found for run {run.name} at voxel spacing {self.voxel_spacing}")
                        continue

                    tomogram = self._select_tomogram(voxel_spacing_obj)
                    tomogram_data = np.array(zarr.open(tomogram.zarr())["0"][:])
                    # Samples must reference the slot in self._tomogram_data, not the run
                    # number: runs skipped above would otherwise shift every later index.
                    tomogram_idx = len(self._tomogram_data)
                    self._tomogram_data.append(tomogram_data)
                    self._tomogram_run_names.append(run.name)
                    logger.info(f"Loaded tomogram with shape {tomogram_data.shape} into memory")

                    particles = self._collect_run_picks(run)
                    for point, class_idx in particles:
                        all_points.append(point)
                        all_labels.append(class_idx)
                        all_is_background.append(False)
                        all_tomogram_indices.append(tomogram_idx)

                    if self.include_background and particles:
                        num_background = int(len(particles) * self.background_ratio)
                        logger.info(f"Sampling {num_background} background points")
                        bg_points = self._sample_background_points(
                            tomogram_data.shape,
                            [point for point, _ in particles],
                            num_background,
                            self.min_background_distance,
                        )
                        for point in bg_points:
                            all_points.append(point)
                            all_labels.append(-1)  # -1 indicates background
                            all_is_background.append(True)
                            all_tomogram_indices.append(tomogram_idx)
                except Exception as e:
                    logger.error(f"Error processing tomogram for run {run.name}: {e}")

            self._points = all_points
            self._labels = all_labels
            self._is_background = all_is_background
            self._tomogram_indices = all_tomogram_indices

            logger.info(f"Dataset loaded with {len(self._points)} samples")
            self._print_class_distribution()

            if self.preload and self._points:
                # Extract everything once, in sample order, so _subvolumes[i] matches sample i.
                self._subvolumes = None
                self._preload_data()
        except Exception as e:
            logger.error(f"Error loading data: {e}")
            raise

    def _tomogram_index(self, idx):
        """Return the tomogram index of sample ``idx`` (0 when no per-sample indices are recorded)."""
        indices = getattr(self, "_tomogram_indices", None)
        return indices[idx] if indices else 0

    @staticmethod
    def _to_tensor(subvolume):
        """Z-score normalize ``subvolume`` (unless it is constant) and return it as a float32 tensor with a leading channel axis."""
        std = np.std(subvolume)
        if std > 0:
            subvolume = (subvolume - np.mean(subvolume)) / std
        return torch.as_tensor(subvolume[None, ...], dtype=torch.float32)

    def _preload_data(self):
        """Extract, normalize and keep every sample as ``(tensor, label)`` in ``self._subvolumes``."""
        if getattr(self, "_subvolumes", None):
            logger.info(f"Subvolumes are already preloaded ({len(self._subvolumes)} subvolumes)")
            return

        logger.info(f"Preloading {len(self._points)} subvolumes into memory...")
        self._subvolumes = [
            (self._to_tensor(self.extract_subvolume(point, self._tomogram_index(idx))), label)
            for idx, (point, label) in enumerate(tqdm(zip(self._points, self._labels), total=len(self._points)))
        ]
        logger.info(f"Preloaded {len(self._subvolumes)} subvolumes")

    def _label_to_name(self):
        """Invert ``self._name_to_label``; when two names share a label, the first one wins."""
        label_to_name = {}
        for name, label in self._name_to_label.items():
            label_to_name.setdefault(label, name)
        return label_to_name

    def _print_class_distribution(self):
        """Log the per-class sample counts and return them as ``{class_name: count}``."""
        class_counts = Counter(self._labels)
        label_to_name = self._label_to_name()

        distribution = {}
        if -1 in class_counts:
            distribution["background"] = class_counts.pop(-1)
        for cls_idx, count in class_counts.items():
            if cls_idx in label_to_name:
                distribution[label_to_name[cls_idx]] = count

        logger.info("Class distribution:")
        for class_name, count in distribution.items():
            logger.info(f" {class_name}: {count} samples")

        return distribution

    def _sample_background_points(self, tomogram_shape, particle_coords, num_points, min_distance):
        """
        Sample random background points away from particles.

        Args:
            tomogram_shape: Shape of the tomogram in voxels (z, y, x)
            particle_coords: Particle coordinates (x, y, z) in the units of ``self.voxel_spacing``
            num_points: Number of background points to sample
            min_distance: Minimum distance from any particle, in voxels

        Returns:
            List of (x, y, z) numpy arrays in the same units as ``particle_coords``, ready
            for :meth:`extract_subvolume`. The list holds fewer than ``num_points`` entries
            if sampling gives up after ``10 * num_points`` attempts.
        """
        spacing = self.voxel_spacing
        # Work in voxel units: the sampling bounds come from the array shape and the
        # default min_distance is a box size, both of which are in voxels.
        particles_vox = np.asarray(particle_coords, dtype=float).reshape(-1, 3) / spacing

        z_dim, y_dim, x_dim = tomogram_shape
        half_z, half_y, half_x = (size // 2 for size in self.boxsize)
        # Bounds in (x, y, z) order that keep a full box inside the tomogram.
        low = np.array([half_x, half_y, half_z], dtype=float)
        high = np.array([x_dim - half_x, y_dim - half_y, z_dim - half_z], dtype=float)

        bg_points = []
        max_attempts = num_points * 10
        attempts = 0
        while len(bg_points) < num_points and attempts < max_attempts:
            attempts += 1
            candidate = np.random.uniform(low, high)
            if particles_vox.size and np.min(np.linalg.norm(particles_vox - candidate, axis=1)) < min_distance:
                continue
            # Back to the caller's units; extract_subvolume divides by voxel_spacing again.
            bg_points.append(candidate * spacing)

        logger.info(f"Sampled {len(bg_points)} background points after {attempts} attempts")
        return bg_points

    def extract_subvolume(self, point, tomogram_idx=0):
        """
        Extract a box of size ``self.boxsize`` centered around a point.

        Parts of the box outside the tomogram are zero-filled. The in-bounds data is
        written at its true position inside the box, so the requested point always
        sits at voxel ``boxsize // 2`` even near the tomogram edges.

        Args:
            point: (x, y, z) coordinates in the units of ``self.voxel_spacing``
            tomogram_idx: Index of the tomogram to use

        Returns:
            Extracted subvolume as a numpy array of shape ``boxsize`` (z, y, x)

        Raises:
            ValueError: if no tomogram is loaded at ``tomogram_idx``
        """
        if not 0 <= tomogram_idx < len(self._tomogram_data) or self._tomogram_data[tomogram_idx] is None:
            raise ValueError(f"No tomogram found at index {tomogram_idx}")

        tomogram = self._tomogram_data[tomogram_idx]
        box = tuple(int(size) for size in self.boxsize)

        # point is (x, y, z) while the array axes are (z, y, x).
        center = [int(point[axis] / self.voxel_spacing) for axis in (2, 1, 0)]
        starts = [c - size // 2 for c, size in zip(center, box)]

        src, dst = [], []
        for start, size, dim in zip(starts, box, tomogram.shape):
            lo, hi = max(0, start), min(dim, start + size)
            if hi <= lo:
                # The box lies completely outside the tomogram along this axis.
                return np.zeros(box, dtype=tomogram.dtype)
            src.append(slice(lo, hi))
            dst.append(slice(lo - start, hi - start))

        subvolume = np.asarray(tomogram[tuple(src)])
        if subvolume.shape == box:
            # Copy so callers never hold a view into the tomogram.
            return subvolume.copy()

        padded = np.zeros(box, dtype=subvolume.dtype)
        padded[tuple(dst)] = subvolume
        return padded

    def __len__(self):
        """Get the length of the dataset."""
        return len(self._points)

    def __getitem__(self, idx):
        """
        Get an item from the dataset.

        Args:
            idx: Index

        Returns:
            Tuple of (subvolume tensor with a leading channel axis, label)
        """
        if self.preload and getattr(self, "_subvolumes", None):
            return self._subvolumes[idx]

        subvolume = self.extract_subvolume(self._points[idx], self._tomogram_index(idx))
        return self._to_tensor(subvolume), self._labels[idx]

    def keys(self):
        """Get the list of class names (plus "background" when background samples are enabled)."""
        class_names = list(self._name_to_label.keys())
        if self.include_background:
            return class_names + ["background"]
        return class_names

    def get_class_distribution(self):
        """Return ``{class_name: count}``, counting label -1 as "background" and skipping unknown labels."""
        label_to_name = self._label_to_name()
        distribution = Counter()
        for label in self._labels:
            if label == -1:
                distribution["background"] += 1
            elif label in label_to_name:
                distribution[label_to_name[label]] += 1
        return dict(distribution)

    def get_sample_weights(self):
        """
        Compute sample weights for balanced sampling.

        Returns:
            List of inverse-class-frequency weights (``total / count(label)``), one per sample
        """
        class_counts = Counter(self._labels)
        total_samples = len(self._labels)
        return [total_samples / class_counts[label] for label in self._labels]

    def save(self, save_dir):
        """
        Save the dataset to disk for later reloading.

        Writes ``metadata.json`` and ``tomogram_info.json``. It also writes either
        ``subvolumes.pt`` + ``labels.pt`` (when preloaded) or ``samples.json``
        (for on-the-fly loading).

        Args:
            save_dir: Directory to save the dataset
        """
        os.makedirs(save_dir, exist_ok=True)

        metadata = {
            "dataset_id": self.dataset_id,
            "boxsize": list(self.boxsize),
            "voxel_spacing": self.voxel_spacing,
            "include_background": self.include_background,
            "background_ratio": self.background_ratio,
            "min_background_distance": self.min_background_distance,
            "name_to_label": self._name_to_label,
            "preload": self.preload,
        }
        with open(os.path.join(save_dir, "metadata.json"), "w") as f:
            json.dump(metadata, f)

        if self.preload and getattr(self, "_subvolumes", None):
            logger.info("Saving preloaded tensors...")
            subvolumes = [volume for volume, _ in self._subvolumes]
            labels = [label for _, label in self._subvolumes]
            torch.save(torch.stack(subvolumes), os.path.join(save_dir, "subvolumes.pt"))
            torch.save(torch.tensor(labels), os.path.join(save_dir, "labels.pt"))
            logger.info(f"Saved {len(subvolumes)} preloaded tensors")
        else:
            logger.info("Saving sample information for on-the-fly loading...")
            sample_data = [
                {
                    "point": point.tolist() if isinstance(point, np.ndarray) else point,
                    "label": int(label),
                    "is_background": bool(is_background),
                    "tomogram_idx": int(self._tomogram_index(i)),
                }
                for i, (point, label, is_background) in enumerate(
                    zip(self._points, self._labels, self._is_background),
                )
            ]
            with open(os.path.join(save_dir, "samples.json"), "w") as f:
                json.dump(sample_data, f)

        # Tomogram information (needed to re-attach tomograms for on-the-fly loading)
        run_names = getattr(self, "_tomogram_run_names", [])
        tomogram_info = [
            {
                "index": idx,
                "shape": list(tomogram.shape) if tomogram is not None else None,
                # In-memory numpy tomograms have no ``path``; str(array) would just dump values.
                "path": getattr(tomogram, "path", None),
                "run_name": run_names[idx] if idx < len(run_names) else None,
            }
            for idx, tomogram in enumerate(self._tomogram_data)
        ]
        with open(os.path.join(save_dir, "tomogram_info.json"), "w") as f:
            json.dump(tomogram_info, f)

        logger.info(f"Dataset saved to {save_dir}")

    def _load_tomograms_from_project(self, proj, tomogram_info):
        """
        Fill the ``self._tomogram_data`` slots described by ``tomogram_info`` with zarr arrays from ``proj``.

        Entries that record a ``run_name`` are matched by run name. Older saves without
        one fall back to matching on shape, and each run fills at most one still-empty slot,
        so runs with identical shapes no longer overwrite each other.
        """
        for run in proj.runs:
            voxel_spacing_obj = run.get_voxel_spacing(self.voxel_spacing)
            if not voxel_spacing_obj or not voxel_spacing_obj.tomograms:
                continue

            tomo_zarr = zarr.open(self._select_tomogram(voxel_spacing_obj).zarr())["0"]
            shape = list(tomo_zarr.shape)
            for tomo_info in tomogram_info:
                idx = tomo_info["index"]
                if self._tomogram_data[idx] is not None:
                    continue
                run_name = tomo_info.get("run_name")
                matches = run_name == run.name if run_name is not None else tomo_info["shape"] == shape
                if matches:
                    self._tomogram_data[idx] = tomo_zarr
                    logger.info(f"Loaded tomogram at index {idx} with shape {tomo_zarr.shape}")
                    break

    @classmethod
    def load(cls, save_dir, proj=None):
        """
        Load a previously saved dataset.

        Args:
            save_dir: Directory where the dataset was saved
            proj: Optional copick project object. If provided, tomograms will be loaded from it.

        Returns:
            Loaded MinimalCopickDataset instance
        """
        with open(os.path.join(save_dir, "metadata.json"), "r") as f:
            metadata = json.load(f)

        # Bypass __init__ so that no project data is loaded.
        dataset = cls.__new__(cls)
        dataset.dataset_id = metadata.get("dataset_id")
        dataset.overlay_root = None
        # JSON turns tuples into lists; restore the tuple so shape comparisons work.
        dataset.boxsize = tuple(metadata.get("boxsize", (48, 48, 48)))
        dataset.voxel_spacing = metadata.get("voxel_spacing", 10.012)
        dataset.include_background = metadata.get("include_background", False)
        dataset.background_ratio = metadata.get("background_ratio", 0.2)
        min_distance = metadata.get("min_background_distance")
        dataset.min_background_distance = max(dataset.boxsize) if min_distance is None else min_distance
        dataset._name_to_label = metadata.get("name_to_label", {})
        dataset._object_names = list(dataset._name_to_label)
        dataset.preload = metadata.get("preload", True)
        dataset.copick_root = proj
        dataset._subvolumes = None
        dataset._tomogram_data = []
        dataset._tomogram_run_names = []

        subvolumes_path = os.path.join(save_dir, "subvolumes.pt")
        labels_path = os.path.join(save_dir, "labels.pt")

        if os.path.exists(subvolumes_path) and os.path.exists(labels_path):
            logger.info("Loading preloaded tensors...")
            subvolumes = torch.load(subvolumes_path)
            labels = torch.load(labels_path).tolist()

            dataset._subvolumes = [(subvolumes[i], label) for i, label in enumerate(labels)]
            # Minimal point/label data for compatibility (coordinates are not saved in this mode)
            dataset._points = [np.zeros(3) for _ in labels]
            dataset._labels = list(labels)
            dataset._is_background = [label == -1 for label in labels]
            dataset._tomogram_indices = [0] * len(labels)

            logger.info(f"Loaded dataset with {len(dataset._subvolumes)} preloaded subvolumes")
        else:
            with open(os.path.join(save_dir, "samples.json"), "r") as f:
                sample_data = json.load(f)

            dataset._points = [np.array(s["point"]) for s in sample_data]
            dataset._labels = [s["label"] for s in sample_data]
            dataset._is_background = [s["is_background"] for s in sample_data]
            dataset._tomogram_indices = [s["tomogram_idx"] for s in sample_data]

            with open(os.path.join(save_dir, "tomogram_info.json"), "r") as f:
                tomogram_info = json.load(f)

            dataset._tomogram_run_names = [None] * len(tomogram_info)
            for tomo_info in tomogram_info:
                dataset._tomogram_run_names[tomo_info["index"]] = tomo_info.get("run_name")

            if proj is not None:
                logger.info("Loading tomograms from provided project")
                dataset._tomogram_data = [None] * len(tomogram_info)
                dataset._load_tomograms_from_project(proj, tomogram_info)
            else:
                logger.warning("No project provided. Tomograms must be loaded separately.")

        logger.info(f"Loaded dataset with {len(dataset._points)} samples")

        # Preload only when every tomogram could be re-attached.
        if dataset.preload and dataset._tomogram_data and all(t is not None for t in dataset._tomogram_data):
            logger.info("Preloading subvolumes...")
            dataset._preload_data()

        dataset._print_class_distribution()

        return dataset