Coverage for copick_torch/dataset.py: 7%
583 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-16 16:14 -0700
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-16 16:14 -0700
1import logging
2import os
3import pickle
4import random
5from collections import Counter
6from datetime import datetime
7from pathlib import Path
8from typing import Any, Dict, List, Optional, Tuple, Union
10import copick
11import numpy as np
12import pandas as pd
13import torch
14import zarr
16# Import these at module level to avoid pickling issues
17from scipy.ndimage import binary_dilation, gaussian_filter
18from skimage import measure
19from skimage.transform import resize
20from torch.utils.data import ConcatDataset, Dataset, Subset
22from .augmentations import FourierAugment3D
class SimpleDatasetMixin:
    """
    Mixin that makes a dataset yield plain ``(image, label)`` pairs.

    It supplies a ``__getitem__`` that returns ``(subvolume, label_index)`` instead
    of a richer dictionary format. The host class must provide ``_subvolumes``,
    ``_molecule_ids``, ``augment`` and ``_augment_subvolume``.
    """

    def __getitem__(self, idx):
        """
        Fetch one sample as an ``(image, label)`` tuple.

        Args:
            idx: Position of the sample in ``_subvolumes`` / ``_molecule_ids``.

        Returns:
            tuple: ``(subvolume, label)`` where ``subvolume`` is a float32 tensor
            with a leading channel axis and ``label`` is the stored class id.
        """
        # Work on a copy so the stored volume is never modified in place
        volume = self._subvolumes[idx].copy()
        label = self._molecule_ids[idx]

        # Augmentation is optional and delegated to the host class
        if self.augment:
            volume = self._augment_subvolume(volume, idx)

        # Standardise to zero mean / unit variance; the epsilon guards a zero std
        volume = (volume - np.mean(volume)) / (np.std(volume) + 1e-6)

        # Prepend the channel axis, then hand back a float32 tensor
        image = torch.as_tensor(volume[None, ...], dtype=torch.float32)
        return image, label
62class SimpleCopickDataset(SimpleDatasetMixin, Dataset):
63 """
64 A simplified PyTorch dataset for working with copick data that returns (image, label) pairs.
66 This implementation is a wrapper around the original CopickDataset that modifies the
67 __getitem__ method to return a simpler format suitable for standard training pipelines.
68 """
def __init__(
    self,
    config_path: Union[str, Any] = None,
    copick_root: Optional[Any] = None,
    boxsize: Tuple[int, int, int] = (32, 32, 32),
    augment: bool = False,
    cache_dir: Optional[str] = None,
    cache_format: str = "parquet",
    seed: Optional[int] = 1717,
    max_samples: Optional[int] = None,
    voxel_spacing: float = 10.0,
    include_background: bool = False,
    background_ratio: float = 0.2,
    min_background_distance: Optional[float] = None,
    patch_strategy: str = "centered",
    debug_mode: bool = False,
    dataset_id: Optional[int] = None,
    overlay_root: str = "/tmp/test/",
):
    """
    Initialize a SimpleCopickDataset.

    Args:
        config_path: Path to the copick config file or CopickConfig object
        copick_root: Copick root object (alternative to config_path)
        boxsize: Size of the subvolumes to extract (z, y, x)
        augment: Whether to apply data augmentation
        cache_dir: Directory to cache extracted subvolumes
        cache_format: Format for caching ('pickle' or 'parquet'); any value other
            than 'pickle' is treated as parquet by the cache helpers
        seed: Random seed for reproducibility
        max_samples: Maximum number of samples to use
        voxel_spacing: Voxel spacing to use for extraction
        include_background: Whether to include background samples
        background_ratio: Ratio of background to particle samples
        min_background_distance: Minimum distance from particles for background samples;
            ``None`` means "use the largest box dimension"
        patch_strategy: Strategy for extracting patches ('centered', 'random', or 'jittered')
        debug_mode: Whether to enable debug mode
        dataset_id: Dataset ID used to build a copick root when ``copick_root`` is not given
        overlay_root: Overlay directory passed to copick when building a root from ``dataset_id``

    Raises:
        ValueError: If none of ``config_path``, ``copick_root`` or ``dataset_id`` is given.
    """
    # Validate input: either config_path, copick_root, or dataset_id must be provided
    if config_path is None and copick_root is None and dataset_id is None:
        raise ValueError("Either config_path, copick_root, or dataset_id must be provided")

    self.config_path = config_path
    self.copick_root = copick_root
    self.dataset_id = dataset_id
    self.overlay_root = overlay_root

    # If dataset_id is provided but not copick_root, create it here
    if self.dataset_id is not None and self.copick_root is None:
        try:
            import copick

            self.copick_root = copick.from_czcdp_datasets([self.dataset_id], overlay_root=self.overlay_root)
            print(f"Created copick root from dataset ID: {self.dataset_id}")
        except Exception as e:
            print(f"Error creating copick root from dataset ID: {e}")
            raise

    self.boxsize = boxsize
    self.augment = augment
    self.cache_dir = cache_dir
    self.cache_format = cache_format.lower()
    self.seed = seed
    self.max_samples = max_samples
    self.voxel_spacing = voxel_spacing
    self.include_background = include_background
    self.background_ratio = background_ratio
    # Compare against None explicitly: an explicit distance of 0 is a legitimate
    # request and must not be replaced by the default.
    if min_background_distance is None:
        self.min_background_distance = max(boxsize)
    else:
        self.min_background_distance = min_background_distance
    self.patch_strategy = patch_strategy
    self.debug_mode = debug_mode

    # Initialize dataset: seed first so that sampling during loading is reproducible
    self._set_random_seed()
    self._subvolumes = []
    self._molecule_ids = []
    self._keys = []
    self._is_background = []
    self._load_or_process_data()
    self._compute_sample_weights()
def _set_random_seed(self):
    """Seed the Python, NumPy and PyTorch RNGs from ``self.seed``; do nothing when it is None."""
    seed = self.seed
    if seed is None:
        return

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # The CUDA generator is only touched when a GPU is actually present
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.manual_seed(seed)
def _compute_sample_weights(self):
    """Store per-sample weights that are inversely proportional to class frequency.

    The result is written to ``self.sample_weights`` (one float per entry of
    ``self._molecule_ids``). A background label, if present, is treated like
    any other class.
    """
    labels = self._molecule_ids
    n_total = len(labels)
    frequency = Counter(labels)

    # weight(class) = N / count(class), so rare classes weigh more
    weight_of = {}
    for label, n_label in frequency.items():
        weight_of[label] = n_total / n_label

    self.sample_weights = [weight_of[label] for label in labels]
def _get_cache_path(self):
    """Build the cache file path for the current configuration.

    The file name combines a cache key, the box size, the voxel spacing and a
    ``_with_bg`` marker, and ends in ``.pkl`` for the pickle format or
    ``.parquet`` otherwise.

    The cache key is ``config_path`` when set. Otherwise it is derived from the
    copick root: its dataset ids when available, else a digest of ``str(root)``.

    Returns:
        str: Path of the cache file inside ``self.cache_dir``.
    """
    cache_key = self.config_path
    if cache_key is None and self.copick_root is not None:
        try:
            # Prefer the ids of the datasets attached to the root
            dataset_ids = [
                str(dataset.id) for dataset in self.copick_root.datasets if hasattr(dataset, "id")
            ]
            if dataset_ids:
                cache_key = f"datasets_{'_'.join(dataset_ids)}"
            else:
                cache_key = "copick_root_unknown"
        except (AttributeError, IndexError):
            # The root has no usable ``datasets`` attribute
            if hasattr(self.copick_root, "dataset_ids"):
                dataset_ids = [str(did) for did in self.copick_root.dataset_ids]
                cache_key = f"datasets_{'_'.join(dataset_ids)}"
            else:
                # Last resort. The builtin hash() of a str is salted per process,
                # so it can never name a cache that a later run will find; use a
                # content digest instead.
                # NOTE(review): this is only stable if str(copick_root) itself is
                # stable across runs (i.e. not a default repr with an address).
                import hashlib

                digest = hashlib.sha256(str(self.copick_root).encode("utf-8")).hexdigest()[:16]
                cache_key = f"copick_root_{digest}"

    extension = "pkl" if self.cache_format == "pickle" else "parquet"
    background_tag = "_with_bg" if self.include_background else ""
    file_name = (
        f"{cache_key}_{self.boxsize[0]}x{self.boxsize[1]}x{self.boxsize[2]}"
        f"_{self.voxel_spacing}"
        f"{background_tag}.{extension}"
    )
    return os.path.join(self.cache_dir, file_name)
def _load_or_process_data(self):
    """Load data from cache or process it directly.

    Without a ``cache_dir`` the data is extracted via ``_load_data`` and nothing
    is written. With a ``cache_dir`` an existing cache file is loaded (and
    subsampled to ``max_samples`` if needed); otherwise the data is extracted
    and, if any samples were produced, written to the cache.
    """
    # If cache_dir is None, process data directly without caching
    if self.cache_dir is None:
        print("Cache directory not specified. Processing data without caching...")
        self._load_data()
        return

    os.makedirs(self.cache_dir, exist_ok=True)
    cache_file = self._get_cache_path()
    use_pickle = self.cache_format == "pickle"

    if os.path.exists(cache_file):
        print(f"Loading cached data from {cache_file}")
        if use_pickle:
            self._load_from_pickle(cache_file)
        else:  # parquet
            self._load_from_parquet(cache_file)

        # Apply max_samples limit if specified
        if self.max_samples is not None and len(self._subvolumes) > self.max_samples:
            indices = np.random.choice(len(self._subvolumes), self.max_samples, replace=False)
            self._subvolumes = np.array(self._subvolumes)[indices]
            self._molecule_ids = np.array(self._molecule_ids)[indices]
            # Test the length, not the truth value: the cached flags may be a
            # NumPy array, whose truth value is ambiguous and raises ValueError.
            if len(self._is_background) > 0:
                self._is_background = np.array(self._is_background)[indices]
        return

    print("Processing data and creating cache...")
    self._load_data()

    # Only save to cache if we actually loaded some data
    if len(self._subvolumes) == 0:
        print("No data loaded, skipping cache creation")
        return

    if use_pickle:
        self._save_to_pickle(cache_file)
    else:  # parquet
        self._save_to_parquet(cache_file)
    print(f"Cached data saved to {cache_file}")
def _load_from_pickle(self, cache_file):
    """Load dataset from pickle cache.

    Reads the ``subvolumes``, ``molecule_ids``, ``keys`` and ``is_background``
    entries of the pickled dictionary; missing entries default to empty lists.

    Args:
        cache_file: Path of the pickle file to read.
    """
    # NOTE(review): pickle.load can execute arbitrary code embedded in the file;
    # only point this at cache files written by this dataset.
    with open(cache_file, "rb") as f:
        cached_data = pickle.load(f)

    self._subvolumes = cached_data.get("subvolumes", [])
    self._molecule_ids = cached_data.get("molecule_ids", [])
    self._keys = cached_data.get("keys", [])
    self._is_background = cached_data.get("is_background", [])

    # Handle case where background flag wasn't saved. Test the length instead of
    # the truth value: the cached flags may be a NumPy array, for which
    # ``not array`` raises ValueError once it has more than one element.
    if len(self._is_background) == 0 and self.include_background:
        # Initialize all as non-background
        self._is_background = [False] * len(self._subvolumes)
def _save_to_pickle(self, cache_file):
    """Write the sample lists to ``cache_file`` as one pickled dictionary.

    The dictionary has the keys ``subvolumes``, ``molecule_ids``, ``keys`` and
    ``is_background``.

    Args:
        cache_file: Destination path; an existing file is overwritten.
    """
    with open(cache_file, "wb") as handle:
        payload = dict(
            subvolumes=self._subvolumes,
            molecule_ids=self._molecule_ids,
            keys=self._keys,
            is_background=self._is_background,
        )
        pickle.dump(payload, handle)
281 def _load_from_parquet(self, cache_file):
282 """Load dataset from parquet cache."""
283 try:
284 df = pd.read_parquet(cache_file)
286 # Process subvolumes from bytes back to numpy arrays
287 self._subvolumes = []
288 for idx, row in df.iterrows():
289 if isinstance(row["subvolume"], bytes):
290 # Reconstruct numpy array from bytes
291 subvol = np.frombuffer(row["subvolume"], dtype=np.float32)
292 shape = row["shape"]
293 if isinstance(shape, list):
294 shape = tuple(shape)
295 subvol = subvol.reshape(shape)
296 self._subvolumes.append(subvol)
297 else:
298 raise ValueError(f"Invalid subvolume format: {type(row['subvolume'])}")
300 # Convert to numpy array
301 self._subvolumes = np.array(self._subvolumes)
303 # Extract other fields
304 self._molecule_ids = df["molecule_id"].tolist()
305 self._keys = df["key"].tolist() if "key" in df.columns else []
307 # Load or initialize background flags
308 if "is_background" in df.columns:
309 self._is_background = df["is_background"].tolist()
310 else:
311 self._is_background = [False] * len(self._subvolumes)
313 # Reconstruct unique keys if not available
314 if not self._keys:
315 unique_ids = set()
316 for mol_id in self._molecule_ids:
317 if mol_id != -1 and not df.loc[df["molecule_id"] == mol_id, "is_background"].any():
318 unique_ids.add(mol_id)
319 self._keys = sorted(unique_ids)
321 except Exception as e:
322 print(f"Error loading from parquet: {str(e)}")
323 raise
325 def _save_to_parquet(self, cache_file):
326 """Save dataset to parquet cache."""
327 try:
328 # Check if we have any data to save
329 if len(self._subvolumes) == 0:
330 print("No data to save to parquet")
331 return
333 # Prepare records
334 records = []
335 for subvol, mol_id, is_bg in zip(self._subvolumes, self._molecule_ids, self._is_background):
336 record = {
337 "subvolume": subvol.tobytes(),
338 "shape": list(subvol.shape),
339 "molecule_id": mol_id,
340 "is_background": is_bg,
341 }
342 records.append(record)
344 # Add keys information
345 key_mapping = []
346 for i, key in enumerate(self._keys):
347 key_mapping.append({"key_index": i, "key": key})
349 # Create and save main dataframe
350 df = pd.DataFrame(records)
352 # Add keys as a column for each row
353 df["key"] = df["molecule_id"].apply(
354 lambda x: self._keys[x] if x != -1 and x < len(self._keys) else "background",
355 )
357 df.to_parquet(cache_file, index=False)
359 # Save additional metadata
360 metadata_file = cache_file.replace(".parquet", "_metadata.parquet")
361 metadata = {
362 "creation_date": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
363 "total_samples": len(records),
364 "unique_molecules": len(self._keys),
365 "boxsize": self.boxsize,
366 "include_background": self.include_background,
367 "background_samples": sum(self._is_background),
368 }
369 pd.DataFrame([metadata]).to_parquet(metadata_file, index=False)
371 except Exception as e:
372 print(f"Error saving to parquet: {str(e)}")
373 raise
def _extract_subvolume_with_validation(self, tomogram_array, x, y, z):
    """Extract a subvolume with validation checks, applying the selected patch strategy.

    ``self.boxsize`` is interpreted as (z, y, x), matching the (z, y, x) axis
    order used to index ``tomogram_array``. Regions that stick out of the
    tomogram are zero-padded so the result always has shape ``boxsize``.

    Args:
        tomogram_array: 3D array indexed as [z, y, x].
        x, y, z: Centre of the patch in voxel coordinates.

    Returns:
        tuple: ``(subvolume, is_valid, status)``. ``status`` is ``"valid"`` for a
        patch fully inside the tomogram (a view of it), ``"padded"`` for a
        zero-padded copy, and ``"Invalid slice range"`` (with ``None, False``)
        when the patch does not overlap the tomogram at all.

    Raises:
        ValueError: If ``self.patch_strategy`` is not one of 'centered',
            'random' or 'jittered'.
    """
    box = tuple(int(n) for n in self.boxsize)  # (z, y, x)
    half_z, half_y, half_x = (n // 2 for n in box)

    # Apply patch strategy (RNG draws are made in x, y, z order)
    if self.patch_strategy == "centered":
        offset_x = offset_y = offset_z = 0
    elif self.patch_strategy == "random":
        # Random offsets within a quarter of the half box along each axis
        offset_x = np.random.randint(-(half_x // 4), half_x // 4 + 1)
        offset_y = np.random.randint(-(half_y // 4), half_y // 4 + 1)
        offset_z = np.random.randint(-(half_z // 4), half_z // 4 + 1)
    elif self.patch_strategy == "jittered":
        # Small random jitter for data augmentation
        offset_x = np.random.randint(-2, 3)
        offset_y = np.random.randint(-2, 3)
        offset_z = np.random.randint(-2, 3)
    else:
        raise ValueError(f"Unknown patch_strategy: {self.patch_strategy!r}")

    centres = (z + offset_z, y + offset_y, x + offset_x)
    halves = (half_z, half_y, half_x)

    source = []  # slices into the tomogram
    target = []  # matching slices into the padded output
    for axis in range(3):
        # Window requested along this axis, before clipping to the tomogram.
        # Using start + size (not centre + half) keeps odd box sizes exact.
        wanted_start = int(np.floor(centres[axis] - halves[axis]))
        wanted_end = wanted_start + box[axis]
        start = max(0, wanted_start)
        end = min(tomogram_array.shape[axis], wanted_end)

        # Validate slice ranges
        if end <= start:
            return None, False, "Invalid slice range"

        source.append(slice(start, end))
        # The clipped data sits at the same offset inside the requested window;
        # this uses the adjusted centre, so jittered patches pad correctly too.
        target.append(slice(start - wanted_start, end - wanted_start))

    subvolume = tomogram_array[tuple(source)]
    if subvolume.shape == box:
        return subvolume, True, "valid"

    # Part of the window fell outside the tomogram: zero-pad to the full box
    padded = np.zeros(box, dtype=subvolume.dtype)
    padded[tuple(target)] = subvolume
    return padded, True, "padded"
def _load_data(self):
    """Fill the sample lists from the picks of every run in the copick project.

    For each run the first tomogram at ``self.voxel_spacing`` is read, and a
    subvolume is extracted around every tool-generated pick. Runs or picks that
    fail are reported with ``print`` and skipped. Afterwards the lists are
    converted to NumPy arrays and optionally subsampled to ``max_samples``.
    """
    # Pick the project root: an explicit root wins over the config file
    if self.copick_root is None:
        try:
            root = copick.from_file(self.config_path)
            print(f"Loading data from {self.config_path}")
        except Exception as e:
            print(f"Failed to load copick root: {str(e)}")
            return
    else:
        root = self.copick_root
        print("Using provided copick root object")

    # Coordinates of every particle seen so far (collected for background sampling)
    all_particle_coords = []

    for run in root.runs:
        print(f"Processing run: {run.name}")

        # Read the tomogram; a run without one is skipped
        try:
            spacing = run.get_voxel_spacing(self.voxel_spacing)
            usable = spacing is not None and hasattr(spacing, "tomograms") and spacing.tomograms
            if not usable:
                print(f"No tomograms found for run {run.name} at voxel spacing {self.voxel_spacing}")
                continue

            first_tomogram = spacing.tomograms[0]
            tomogram_array = first_tomogram.numpy()
        except Exception as e:
            print(f"Error loading tomogram for run {run.name}: {str(e)}")
            continue

        # Coordinates of the particles found in this run only
        run_particle_coords = []

        for picks in run.get_picks():
            if not picks.from_tool:
                continue

            object_name = picks.pickable_object_name

            try:
                points, _ = picks.numpy()
                # Physical coordinates -> voxel coordinates
                points = points / self.voxel_spacing

                for point in points:
                    try:
                        x, y, z = point

                        # Remember the location even if extraction fails below
                        run_particle_coords.append((x, y, z))

                        subvolume, is_valid, _ = self._extract_subvolume_with_validation(tomogram_array, x, y, z)
                        if not is_valid:
                            continue

                        self._subvolumes.append(subvolume)
                        # Class index = position of the name in the key list
                        if object_name not in self._keys:
                            self._keys.append(object_name)
                        self._molecule_ids.append(self._keys.index(object_name))
                        self._is_background.append(False)
                    except Exception as e:
                        print(f"Error extracting subvolume: {str(e)}")
            except Exception as e:
                print(f"Error processing picks for {object_name}: {str(e)}")

        # Background is sampled per run, and only when the run had particles
        if self.include_background and run_particle_coords:
            all_particle_coords.extend(run_particle_coords)
            self._sample_background_points(tomogram_array, run_particle_coords)

    self._subvolumes = np.array(self._subvolumes)
    self._molecule_ids = np.array(self._molecule_ids)
    self._is_background = np.array(self._is_background)

    # Randomly keep at most max_samples entries
    n_loaded = len(self._subvolumes)
    if self.max_samples is not None and n_loaded > self.max_samples:
        indices = np.random.choice(n_loaded, self.max_samples, replace=False)
        self._subvolumes = self._subvolumes[indices]
        self._molecule_ids = self._molecule_ids[indices]
        self._is_background = self._is_background[indices]

    print(f"Loaded {len(self._subvolumes)} subvolumes with {len(self._keys)} classes")
    print(f"Background samples: {sum(self._is_background)}")
def _sample_background_points(self, tomogram_array, particle_coords):
    """Sample background points away from particles.

    Random box centres are drawn inside the tomogram (leaving half a box of
    margin on every side) and kept when their distance to every particle is at
    least ``self.min_background_distance``. Accepted patches are appended to the
    sample lists with molecule id ``-1``.

    Args:
        tomogram_array: 3D array indexed as [z, y, x].
        particle_coords: Sequence of (x, y, z) particle positions in voxels.
    """
    # len() works for lists and arrays alike; ``not array`` would be ambiguous
    if len(particle_coords) == 0:
        return

    # Convert to numpy array for distance calculations
    particle_coords = np.asarray(particle_coords)

    # Calculate number of background samples based on ratio
    num_background = int(len(particle_coords) * self.background_ratio)

    # Limit attempts to avoid infinite loop
    max_attempts = num_background * 10

    # boxsize is (z, y, x); pair each half extent with the matching array axis
    half_z, half_y, half_x = (int(n) // 2 for n in self.boxsize)
    depth, height, width = tomogram_array.shape[:3]
    bounds = (
        (half_x, width - half_x),
        (half_y, height - half_y),
        (half_z, depth - half_z),
    )
    # np.random.randint raises on an empty range, which happens when the
    # tomogram is not larger than the box along some axis.
    if any(high <= low for low, high in bounds):
        print("Tomogram is too small for the requested box size; skipping background sampling")
        return

    attempts = 0
    bg_points_found = 0
    while bg_points_found < num_background and attempts < max_attempts:
        # Generate random (x, y, z) point within tomogram bounds with margin for box extraction
        random_point = np.array([np.random.randint(low, high) for low, high in bounds])

        # Calculate distances to all particles
        distances = np.linalg.norm(particle_coords - random_point, axis=1)

        # Check if point is far enough from all particles
        if np.min(distances) >= self.min_background_distance:
            x, y, z = random_point
            subvolume, is_valid, _ = self._extract_subvolume_with_validation(tomogram_array, x, y, z)

            if is_valid:
                self._subvolumes.append(subvolume)
                self._molecule_ids.append(-1)  # Use -1 to indicate background
                self._is_background.append(True)
                bg_points_found += 1

        attempts += 1

    print(f"Added {bg_points_found} background points after {attempts} attempts")
def _augment_subvolume(self, subvolume, idx=None):
    """Randomly perturb a 3D subvolume.

    Each augmentation is applied independently with a fixed probability, in this
    order: brightness shift, Gaussian blur, intensity scaling, axis flip,
    90-degree rotation, Fourier-domain augmentation. No mixup is performed.

    Args:
        subvolume: The 3D volume to augment
        idx: Accepted for interface compatibility; not used here

    Returns:
        The augmented subvolume (the input itself if no augmentation fired)
    """
    # Additive brightness shift (p = 0.3)
    if random.random() < 0.3:
        subvolume = subvolume + np.random.uniform(-0.5, 0.5)

    # Gaussian blur with a random width (p = 0.2)
    if random.random() < 0.2:
        subvolume = gaussian_filter(subvolume, sigma=np.random.uniform(0.5, 1.5))

    # Multiplicative intensity scaling (p = 0.2)
    if random.random() < 0.2:
        subvolume = subvolume * np.random.uniform(0.5, 1.5)

    # Mirror along one random axis (p = 0.2)
    if random.random() < 0.2:
        subvolume = np.flip(subvolume, axis=random.randint(0, 2))

    # Rotate by 90, 180 or 270 degrees in a random plane (p = 0.2)
    if random.random() < 0.2:
        quarter_turns = random.randint(1, 3)
        plane = tuple(random.sample([0, 1, 2], 2))
        subvolume = np.rot90(subvolume, k=quarter_turns, axes=plane)

    # Fourier-domain augmentation (p = 0.3)
    if random.random() < 0.3:
        transform = FourierAugment3D(freq_mask_prob=0.3, phase_noise_std=0.1, intensity_scaling_range=(0.8, 1.2))
        subvolume = transform(subvolume)

    return subvolume
def __len__(self):
    """Return how many subvolumes the dataset currently holds."""
    sample_count = len(self._subvolumes)
    return sample_count
def get_sample_weights(self):
    """Return the stored per-sample weights, e.g. for a WeightedRandomSampler."""
    weights = self.sample_weights
    return weights
def keys(self):
    """Return the list of pickable object names (the list itself, not a copy)."""
    object_names = self._keys
    return object_names
def get_class_distribution(self):
    """Count samples per class name.

    Returns:
        dict: Maps ``"background"`` (for id ``-1``, when present) and each class
        name from ``self._keys`` to its sample count. Ids outside the key list
        are left out.
    """
    tally = Counter(self._molecule_ids)
    distribution = {}

    # Background is reported first, under a fixed name
    background_count = tally.pop(-1, None)
    if background_count is not None:
        distribution["background"] = background_count

    # Remaining ids are translated through the key list
    n_keys = len(self._keys)
    distribution.update(
        {self._keys[cls_idx]: count for cls_idx, count in tally.items() if 0 <= cls_idx < n_keys},
    )
    return distribution
653class SplicedMixupDataset(SimpleCopickDataset):
654 """
655 A dataset that loads zarr arrays into memory and performs balanced sampling with mixup splicing.
657 This dataset extends SimpleCopickDataset to add experimental-synthetic data splicing capabilities,
658 keeping zarr arrays in memory for faster loading and using balanced sampling by default.
659 """
def __init__(
    self,
    exp_dataset_id: int,
    synth_dataset_id: int,
    synth_run_id: str = "16487",
    overlay_root: str = "/tmp/test/",
    boxsize: Tuple[int, int, int] = (48, 48, 48),
    augment: bool = True,
    cache_dir: Optional[str] = None,
    cache_format: str = "parquet",
    seed: Optional[int] = 1717,
    max_samples: Optional[int] = None,
    voxel_spacing: float = 10.0,
    include_background: bool = False,
    background_ratio: float = 0.2,
    min_background_distance: Optional[float] = None,
    blend_sigma: float = 2.0,  # Controls the standard deviation of Gaussian blending at boundaries
    mixup_alpha: float = 0.2,
    debug_mode: bool = False,
):
    """
    Initialize the SplicedMixupDataset.

    The experimental and synthetic copick roots are loaded first, then the parent
    dataset is initialised on the experimental root, and finally the samples are
    replaced by ones generated from the in-memory zarr arrays.

    Args:
        exp_dataset_id: Dataset ID for the experimental dataset
        synth_dataset_id: Dataset ID for the synthetic dataset
        synth_run_id: Run ID for the synthetic dataset (default: "16487")
        overlay_root: Root directory for the overlay storage (default: "/tmp/test/")
        boxsize: Size of the subvolumes to extract (z, y, x)
        augment: Whether to apply data augmentation
        cache_dir: Directory to cache extracted subvolumes
        cache_format: Format for caching ('pickle' or 'parquet')
        seed: Random seed for reproducibility
        max_samples: Maximum number of samples to use
        voxel_spacing: Voxel spacing to use for extraction
        include_background: Whether to include background samples
        background_ratio: Ratio of background to particle samples
        min_background_distance: Minimum distance from particles for background samples
        blend_sigma: Controls the standard deviation of Gaussian blending at boundaries
        mixup_alpha: Alpha parameter for mixup augmentation
        debug_mode: Whether to enable debug mode
    """
    # Save specific parameters
    self.exp_dataset_id = exp_dataset_id
    self.synth_dataset_id = synth_dataset_id
    self.synth_run_id = synth_run_id
    self.overlay_root = overlay_root
    self.blend_sigma = blend_sigma
    self.mixup_alpha = mixup_alpha

    # Initialize load flags and storage for zarr arrays
    self._zarr_loaded = False
    self._exp_zarr_data = None
    self._synth_zarr_data = None
    self._synth_mask_data = {}

    # Load copick roots
    self._load_copick_roots()

    # Initialize the parent class (SimpleCopickDataset) on the experimental data.
    # overlay_root is forwarded explicitly: the parent assigns self.overlay_root
    # from its own argument, so omitting it would silently reset the attribute
    # to the parent's default.
    super().__init__(
        copick_root=getattr(self, "exp_root", None),  # Use experimental data as the base
        boxsize=boxsize,
        augment=augment,
        cache_dir=cache_dir,
        cache_format=cache_format,
        seed=seed,
        max_samples=max_samples,
        voxel_spacing=voxel_spacing,
        include_background=include_background,
        background_ratio=background_ratio,
        min_background_distance=min_background_distance,
        patch_strategy="centered",  # Always use centered for splicing
        debug_mode=debug_mode,
        overlay_root=overlay_root,
    )

    # Load zarr arrays into memory if not already loaded
    self._ensure_zarr_loaded()

    # Replace whatever the parent loaded with samples generated from the zarr arrays
    self._generate_synthetic_samples()
def _generate_synthetic_samples(self):
    """Rebuild the sample lists from the in-memory zarr arrays.

    Half of the requested samples (100 unless ``max_samples`` is set) are random
    experimental crops labelled as background; the rest are synthetic regions
    spliced into experimental crops. A synthetic sample is skipped when no
    bounding box could be extracted for the chosen mask.
    """
    print("Generating synthetic samples from zarr arrays...")
    # Start from a clean slate
    self._subvolumes, self._molecule_ids = [], []
    self._keys, self._is_background = [], []

    num_samples = 100 if self.max_samples is None else self.max_samples

    # Split: first half experimental, remainder spliced
    num_exp_samples = num_samples // 2
    num_synth_samples = num_samples - num_exp_samples

    # Pure experimental crops count as background
    for _ in range(num_exp_samples):
        crop = self._extract_random_crop(self._exp_zarr_data, self.boxsize)
        self._subvolumes.append(crop)
        self._molecule_ids.append(-1)  # Background class
        self._is_background.append(True)

    # Synthetic object spliced into an experimental crop
    for _ in range(num_synth_samples):
        mask_name = random.choice(list(self._synth_mask_data.keys()))
        bbox_info = self._extract_bounding_box(self._synth_mask_data[mask_name], mask_name)
        if bbox_info is None:
            continue

        exp_crop = self._extract_random_crop(self._exp_zarr_data, self.boxsize)
        spliced_volume = self._splice_volumes(bbox_info["synth_region"], bbox_info["region_mask"], exp_crop)
        self._subvolumes.append(spliced_volume)

        # Class index = position of the object name in the key list
        object_name = bbox_info["object_name"]
        if object_name not in self._keys:
            self._keys.append(object_name)
        self._molecule_ids.append(self._keys.index(object_name))
        self._is_background.append(False)

    # Convert to numpy arrays
    self._subvolumes = np.array(self._subvolumes)
    self._molecule_ids = np.array(self._molecule_ids)
    self._is_background = np.array(self._is_background)

    # Weights must be recomputed for the new label set
    self._compute_sample_weights()

    print(f"Generated {len(self._subvolumes)} samples with {len(self._keys)} classes")
    print(f"Background samples: {sum(self._is_background)}")
    print(f"Class distribution: {self.get_class_distribution()}")
def _load_copick_roots(self):
    """Create the experimental and synthetic copick roots from their dataset ids.

    Sets ``self.exp_root`` and ``self.synth_root``. When ``self.synth_run_id`` is
    set, the matching synthetic run (or ``None`` if there is none) is stored on
    the synthetic root as ``_filtered_run``. Any failure is printed and re-raised.
    """
    try:
        print(f"Loading experimental dataset {self.exp_dataset_id} and synthetic dataset {self.synth_dataset_id}")
        self.exp_root = copick.from_czcdp_datasets([self.exp_dataset_id], overlay_root=self.overlay_root)
        self.synth_root = copick.from_czcdp_datasets([self.synth_dataset_id], overlay_root=self.overlay_root)

        print(f"Experimental dataset: {len(self.exp_root.runs)} runs")
        print(f"Synthetic dataset: {len(self.synth_root.runs)} runs")

        # Without a run id there is nothing to filter
        if not self.synth_run_id:
            return

        print(f"Filtering synthetic dataset to only use run {self.synth_run_id}")
        matches = [run for run in self.synth_root.runs if run.meta.name == self.synth_run_id]
        if not matches:
            print(f"Run {self.synth_run_id} not found in synthetic dataset. Using all available runs.")
            self.synth_root._filtered_run = None
        else:
            print(f"Found run {self.synth_run_id}. Using only this run.")
            # Remember the run so later zarr loading can restrict itself to it
            self.synth_root._filtered_run = matches[0]
    except Exception as e:
        print(f"Error loading CoPick roots: {str(e)}")
        raise
def _ensure_zarr_loaded(self):
    """Load tomograms and masks into memory once; later calls do nothing."""
    if self._zarr_loaded:
        return

    # Order: experimental tomogram, synthetic tomogram, then the masks
    self._load_experimental_zarr()
    self._load_synthetic_zarr()
    self._load_segmentation_masks()

    self._zarr_loaded = True
    print("Zarr arrays loaded into memory.")
def _load_experimental_zarr(self):
    """Read the first available experimental tomogram into ``self._exp_zarr_data``.

    The array stored under key ``"0"`` of the tomogram's zarr group is loaded
    fully and standardised with its own mean and standard deviation.

    Raises:
        ValueError: If no experimental tomogram is found (printed, then re-raised
            like any other error).
    """
    try:
        candidates = self._get_available_tomograms(self.exp_root, self.voxel_spacing)
        if not candidates:
            raise ValueError(f"No experimental tomograms found with voxel spacing {self.voxel_spacing}")

        # Only the first tomogram is used
        store = zarr.open(candidates[0].zarr(), "r")
        self._exp_zarr_data = store["0"][:]

        # Zero mean, unit variance (no epsilon: a constant volume yields NaN/inf)
        raw = self._exp_zarr_data
        mean, std = np.mean(raw), np.std(raw)
        self._exp_zarr_data = (raw - mean) / std
        print(f"Loaded experimental zarr with shape {self._exp_zarr_data.shape}")
    except Exception as e:
        print(f"Error loading experimental zarr: {str(e)}")
        raise
def _load_synthetic_zarr(self):
    """Read the first available synthetic tomogram into ``self._synth_zarr_data``.

    The array stored under key ``"0"`` of the tomogram's zarr group is loaded
    fully and standardised with its own mean and standard deviation.

    Raises:
        ValueError: If no synthetic tomogram is found (printed, then re-raised
            like any other error).
    """
    try:
        candidates = self._get_available_tomograms(self.synth_root, self.voxel_spacing)
        if not candidates:
            raise ValueError(f"No synthetic tomograms found with voxel spacing {self.voxel_spacing}")

        # Only the first tomogram is used
        store = zarr.open(candidates[0].zarr(), "r")
        self._synth_zarr_data = store["0"][:]

        # Zero mean, unit variance (no epsilon: a constant volume yields NaN/inf)
        raw = self._synth_zarr_data
        mean, std = np.mean(raw), np.std(raw)
        self._synth_zarr_data = (raw - mean) / std
        print(f"Loaded synthetic zarr with shape {self._synth_zarr_data.shape}")
    except Exception as e:
        print(f"Error loading synthetic zarr: {str(e)}")
        raise
def _get_available_tomograms(self, root, voxel_spacing, tomo_type="wbp"):
    """Collect tomograms of ``tomo_type`` at the spacing closest to ``voxel_spacing``.

    Args:
        root: Copick root; if it carries a non-None ``_filtered_run`` only that
            run is searched, otherwise all of ``root.runs``.
        voxel_spacing: Target voxel size; per run, the voxel spacing with the
            smallest absolute difference is used (first one wins on ties).
        tomo_type: Tomogram type passed to ``get_tomograms``.

    Returns:
        list: Tomogram objects from all searched runs, in run order.
    """
    filtered_run = getattr(root, "_filtered_run", None)
    if filtered_run is not None:
        runs = [filtered_run]
        print(f"Using only filtered run: {filtered_run.meta.name}")
    else:
        runs = root.runs

    found = []
    for run in runs:
        # Linear scan for the nearest voxel spacing of this run
        closest_vs, smallest_gap = None, float("inf")
        for candidate in run.voxel_spacings:
            gap = abs(candidate.meta.voxel_size - voxel_spacing)
            if gap < smallest_gap:
                closest_vs, smallest_gap = candidate, gap

        if not closest_vs:
            continue

        tomograms = closest_vs.get_tomograms(tomo_type)
        if not tomograms:
            continue

        found.extend(tomograms)
        print(
            f"Found {len(tomograms)} tomograms in run {run.meta.name} with voxel spacing {closest_vs.meta.voxel_size}",
        )

    return found
921 def _load_segmentation_masks(self):
922 """Load segmentation masks from synthetic dataset into memory."""
923 try:
924 # Get segmentation masks from synthetic dataset
925 segmentation_masks = self._get_segmentation_masks(self.synth_root, self.voxel_spacing)
927 if not segmentation_masks:
928 raise ValueError(f"No segmentation masks found with voxel spacing {self.voxel_spacing}")
930 # Load each mask into memory
931 for mask_name, mask_obj in segmentation_masks.items():
932 # Skip membrane segmentation masks
933 if mask_name.lower() == "membrane":
934 print(f"Skipping membrane segmentation mask: {mask_name}")
935 continue
937 # Access the mask data
938 mask_zarr = zarr.open(mask_obj.zarr(), "r")
939 mask_data = mask_zarr["data" if "data" in mask_zarr else "0"][:]
941 # Store the mask data
942 self._synth_mask_data[mask_name] = mask_data
943 print(f"Loaded mask '{mask_name}' with shape {mask_data.shape}")
945 print(f"Loaded {len(self._synth_mask_data)} segmentation masks")
946 except Exception as e:
947 print(f"Error loading segmentation masks: {str(e)}")
948 raise
950 def _get_segmentation_masks(self, root, voxel_spacing, pickable_objects=None):
951 """Get segmentation masks from a CoPick dataset for a specific voxel spacing."""
952 segmentation_masks = {}
954 # If a filtered run is specified, only use that run
955 if hasattr(root, "_filtered_run") and root._filtered_run is not None:
956 runs = [root._filtered_run]
957 print(f"Using only filtered run: {root._filtered_run.meta.name} for segmentation masks")
958 else:
959 runs = root.runs
961 for run in runs:
962 # Get the closest voxel spacing to the target
963 closest_vs = None
964 min_diff = float("inf")
966 for vs in run.voxel_spacings:
967 diff = abs(vs.meta.voxel_size - voxel_spacing)
968 if diff < min_diff:
969 min_diff = diff
970 closest_vs = vs
972 if closest_vs:
973 segmentations = run.get_segmentations(voxel_size=closest_vs.meta.voxel_size)
975 for seg in segmentations:
976 # Only include segmentations matching requested pickable objects
977 if pickable_objects is None or seg.meta.name in pickable_objects:
978 segmentation_masks[seg.meta.name] = seg
979 print(f"Found segmentation mask for '{seg.meta.name}' in run {run.meta.name}")
981 return segmentation_masks
983 def _extract_random_crop(self, tomogram_data, crop_size):
984 """Extract a random crop from a tomogram."""
985 depth, height, width = tomogram_data.shape
987 # Ensure crop sizes don't exceed tomogram dimensions
988 crop_depth = min(crop_size[0], depth)
989 crop_height = min(crop_size[1], height)
990 crop_width = min(crop_size[2], width)
992 # Calculate valid ranges for the random crop
993 max_z = depth - crop_depth
994 max_y = height - crop_height
995 max_x = width - crop_width
997 if max_z <= 0 or max_y <= 0 or max_x <= 0:
998 # Tomogram is smaller than crop size in at least one dimension
999 return resize(tomogram_data, crop_size, mode="reflect", anti_aliasing=True)
1001 # Get random start coordinates
1002 z_start = random.randint(0, max_z)
1003 y_start = random.randint(0, max_y)
1004 x_start = random.randint(0, max_x)
1006 # Extract the crop
1007 crop = tomogram_data[
1008 z_start : z_start + crop_depth,
1009 y_start : y_start + crop_height,
1010 x_start : x_start + crop_width,
1011 ]
1013 return crop
def _extract_bounding_box(self, mask_data, mask_name):
    """Pick a random connected component of a mask and cut a fixed-size box around it.

    Args:
        mask_data: 3D segmentation volume; voxels ``> 0`` are foreground.
        mask_name: Name reported back as ``"object_name"`` in the result.

    Returns:
        ``None`` if the mask has no foreground component, if the volume is
        smaller than the box along any axis, or if the synthetic tomogram
        cannot be sliced. Otherwise a dict with:

        * ``"bbox"``: ``(z_min, y_min, x_min, z_max, y_max, x_max)``
        * ``"region_mask"``: boolean box containing the chosen component,
          dilated by two iterations
        * ``"synth_region"``: the same box cut from ``self._synth_zarr_data``
        * ``"object_name"``: ``mask_name``
        * ``"center"``: centroid of the chosen component
    """
    labels = measure.label(mask_data > 0)
    regions = measure.regionprops(labels)

    if not regions:
        return None

    region = random.choice(regions)
    z_center, y_center, x_center = region.centroid

    box_size = self.boxsize[0]  # Assume cubic box
    half_size = box_size // 2
    volume_shape = mask_data.shape[:3]

    # Centre the box on the component, then shift it back inside the volume
    # if it would overrun the far edge.
    starts = []
    for center, extent in zip((z_center, y_center, x_center), volume_shape):
        start = max(0, int(center - half_size))
        if start + box_size > extent:
            start = max(0, extent - box_size)
        starts.append(start)
    stops = [min(extent, start + box_size) for start, extent in zip(starts, volume_shape)]

    # After the clamping above, a short box can only mean the volume itself
    # is smaller than the box along that axis.
    if any(stop - start != box_size for start, stop in zip(starts, stops)):
        print(f"Cannot extract a full {box_size}^3 box at the edge of the volume.")
        return None

    # Dilate only a window around the box instead of the whole volume. Each
    # dilation iteration reaches one voxel further, so padding the window by
    # the iteration count gives exactly the values a full-volume dilation
    # would produce inside the box.
    dilation_iterations = 2
    lows = [max(0, start - dilation_iterations) for start in starts]
    highs = [min(extent, stop + dilation_iterations) for stop, extent in zip(stops, volume_shape)]
    window = tuple(slice(low, high) for low, high in zip(lows, highs))
    dilated = binary_dilation(labels[window] == region.label, iterations=dilation_iterations)

    inner = tuple(slice(start - low, start - low + box_size) for start, low in zip(starts, lows))
    box_mask = dilated[inner].copy()

    # Defensive check; the slicing above should always yield a full box.
    if box_mask.shape != (box_size, box_size, box_size):
        print(f"Box mask has unexpected shape: {box_mask.shape}")
        return None

    z_min, y_min, x_min = starts
    z_max, y_max, x_max = stops

    # Cut the matching region from the synthetic tomogram; resample it if its
    # shape differs from the mask box.
    try:
        synth_region = self._synth_zarr_data[z_min:z_max, y_min:y_max, x_min:x_max].copy()
        if synth_region.shape != box_mask.shape:
            synth_region = resize(synth_region, box_mask.shape, mode="reflect", anti_aliasing=True)
    except Exception as e:
        print(f"Error extracting synthetic region: {e}")
        return None

    return {
        "bbox": (z_min, y_min, x_min, z_max, y_max, x_max),
        "region_mask": box_mask,
        "synth_region": synth_region,
        "object_name": mask_name,
        "center": region.centroid,
    }
1089 def _splice_volumes(self, synthetic_region, region_mask, exp_crop):
1090 """Splice a synthetic structure into an experimental tomogram using Gaussian blending at the edges."""
1091 # Create a spliced volume by starting with the experimental crop
1092 spliced_volume = exp_crop.copy()
1094 # For Gaussian blending, create a weight map that transitions smoothly from 1 to 0
1095 if self.blend_sigma > 0:
1096 try:
1097 # Start with the region mask
1098 mask_float = region_mask.astype(np.float32)
1100 # Apply Gaussian blur to the mask to create a smooth transition at the boundaries
1101 # This creates a weight map that goes from 1 (inside) to 0 (outside) with smooth transitions
1102 weight_map = gaussian_filter(mask_float, sigma=self.blend_sigma)
1104 # Normalize weight map to ensure it's between 0 and 1
1105 weight_map = np.clip(weight_map, 0, 1)
1107 # Apply weighted blending: synthetic * weight + experimental * (1-weight)
1108 spliced_volume = synthetic_region * weight_map + exp_crop * (1 - weight_map)
1110 except Exception as e:
1111 print(f"Error during Gaussian blending at boundaries: {e}")
1112 # Fall back to simple mask-based splicing
1113 spliced_volume[region_mask] = synthetic_region[region_mask]
1114 else:
1115 # If blend_sigma is 0, just do simple mask-based splicing without blending
1116 spliced_volume[region_mask] = synthetic_region[region_mask]
1118 return spliced_volume
1120 def __getitem__(self, idx):
1121 """Get an item with spliced mixup augmentation."""
1122 # Ensure zarr data is loaded
1123 self._ensure_zarr_loaded()
1125 # Get the base subvolume using the parent method (without normalizing and tensor conversion)
1126 subvolume = self._subvolumes[idx].copy()
1127 molecule_idx = self._molecule_ids[idx]
1129 # Get a random mask and extract a bounding box
1130 mask_name = random.choice(list(self._synth_mask_data.keys()))
1131 bbox_info = self._extract_bounding_box(self._synth_mask_data[mask_name], mask_name)
1133 if bbox_info is not None:
1134 # Extract a random crop from the experimental tomogram
1135 exp_crop = self._extract_random_crop(self._exp_zarr_data, self.boxsize)
1137 # Splice the volumes
1138 spliced_volume = self._splice_volumes(bbox_info["synth_region"], bbox_info["region_mask"], exp_crop)
1140 # Decide whether to use the spliced volume based on a random chance
1141 if random.random() < 0.5:
1142 subvolume = spliced_volume
1143 # Get the molecule_idx for the synthetic object
1144 if bbox_info["object_name"] not in self._keys:
1145 self._keys.append(bbox_info["object_name"])
1146 molecule_idx = self._keys.index(bbox_info["object_name"])
1148 # Apply augmentations if enabled
1149 if self.augment:
1150 subvolume = self._augment_subvolume(subvolume)
1152 # Normalize subvolume
1153 subvolume = (subvolume - np.mean(subvolume)) / (np.std(subvolume) + 1e-6)
1155 # Add channel dimension and convert to tensor
1156 subvolume = torch.as_tensor(subvolume[None, ...], dtype=torch.float32)
1158 # Return the subvolume and class index as a simple tuple
1159 return subvolume, molecule_idx