Coverage for copick_torch/copick.py: 5%
671 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-16 16:14 -0700
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-16 16:14 -0700
1import logging
2import os
3import pickle
4import random
5from collections import Counter
6from datetime import datetime
7from pathlib import Path
8from typing import Any, Dict, List, Optional, Tuple, Union
10import copick
11import numpy as np
12import pandas as pd
13import torch
14import zarr
15from scipy.ndimage import gaussian_filter
16from torch.utils.data import ConcatDataset, Dataset, Subset
19class CopickDataset(Dataset):
20 """
21 A PyTorch dataset for working with copick data for particle picking tasks.
23 This implementation focuses on extracting subvolumes around pick coordinates
24 with support for data augmentation, caching, and class balancing.
25 """
27 def __init__(
28 self,
29 config_path: Union[str, Any] = None,
30 copick_root: Optional[Any] = None,
31 boxsize: Tuple[int, int, int] = (32, 32, 32),
32 augment: bool = False,
33 cache_dir: Optional[str] = None,
34 cache_format: str = "parquet", # Can be "pickle" or "parquet"
35 seed: Optional[int] = 1717,
36 max_samples: Optional[int] = None,
37 voxel_spacing: float = 10.0,
38 include_background: bool = False,
39 background_ratio: float = 0.2, # Background samples as proportion of particle samples
40 min_background_distance: Optional[float] = None, # Min distance in voxels from particles
41 patch_strategy: str = "centered", # Can be "centered", "random", or "jittered"
42 augmentations: Optional[List[str]] = None, # List of augmentation types to apply
43 augmentation_prob: float = 0.2, # Probability of applying each augmentation
44 mixup_alpha: Optional[float] = None, # Alpha parameter for mixup augmentation
45 rotate_axes: Tuple[int, int, int] = (1, 1, 1), # Enable/disable rotation around each axis (x, y, z)
46 debug_mode: bool = False,
47 ):
48 # Validate input: either config_path or copick_root must be provided
49 if config_path is None and copick_root is None:
50 raise ValueError("Either config_path or copick_root must be provided")
52 self.config_path = config_path
53 self.copick_root = copick_root
54 self.boxsize = boxsize
55 self.augment = augment
56 self.cache_dir = cache_dir
57 self.cache_format = cache_format.lower()
58 self.seed = seed
59 self.max_samples = max_samples
60 self.voxel_spacing = voxel_spacing
61 self.include_background = include_background
62 self.background_ratio = background_ratio
63 self.min_background_distance = min_background_distance or max(boxsize)
64 self.patch_strategy = patch_strategy
65 self.debug_mode = debug_mode
67 # Augmentation settings
68 self.augmentation_prob = augmentation_prob
69 self.mixup_alpha = mixup_alpha
70 self.rotate_axes = rotate_axes
72 # Default augmentations if not specified
73 self.default_augmentations = ["brightness", "blur", "intensity", "flip", "rotate"]
74 self.augmentations = augmentations or self.default_augmentations
76 # Special augmentations that need additional handling
77 self.special_augmentations = ["mixup", "rotate_z"]
79 # Validate augmentations
80 valid_augmentations = self.default_augmentations + self.special_augmentations
81 for aug in self.augmentations:
82 if aug not in valid_augmentations:
83 raise ValueError(f"Unknown augmentation type: {aug}. Valid options are: {valid_augmentations}")
85 # Validate parameters
86 if self.cache_format not in ["pickle", "parquet"]:
87 raise ValueError("cache_format must be either 'pickle' or 'parquet'")
89 if self.patch_strategy not in ["centered", "random", "jittered"]:
90 raise ValueError("patch_strategy must be one of 'centered', 'random', or 'jittered'")
92 # Initialize dataset
93 self._set_random_seed()
94 self._subvolumes = []
95 self._molecule_ids = []
96 self._keys = []
97 self._is_background = []
98 self._load_or_process_data()
99 self._compute_sample_weights()
101 def _set_random_seed(self):
102 """Set random seeds for reproducibility."""
103 if self.seed is not None:
104 random.seed(self.seed)
105 np.random.seed(self.seed)
106 torch.manual_seed(self.seed)
107 if torch.cuda.is_available():
108 torch.cuda.manual_seed(self.seed)
110 def _compute_sample_weights(self):
111 """Compute sample weights based on class frequency for balancing."""
112 # Include special handling for background class if it exists
113 class_counts = Counter(self._molecule_ids)
114 total_samples = len(self._molecule_ids)
116 # Assign weights inversely proportional to class frequency
117 class_weights = {cls: total_samples / count for cls, count in class_counts.items()}
119 # Compute weights for each sample
120 self.sample_weights = [class_weights[mol_id] for mol_id in self._molecule_ids]
122 def _get_cache_path(self):
123 """Get the appropriate cache file path based on format."""
124 # If we have a copick_root but no config_path, use dataset IDs
125 cache_key = self.config_path
126 if cache_key is None and self.copick_root is not None:
127 # Try to get dataset IDs from the datasets attribute
128 try:
129 dataset_ids = []
130 for dataset in self.copick_root.datasets:
131 if hasattr(dataset, "id"):
132 dataset_ids.append(str(dataset.id))
134 if dataset_ids:
135 # Use the dataset IDs in order as the cache key
136 dataset_ids_str = "_".join(dataset_ids)
137 cache_key = f"datasets_{dataset_ids_str}"
138 else:
139 # Fallback if no dataset IDs found
140 cache_key = "copick_root_unknown"
141 except (AttributeError, IndexError):
142 # Fallback if datasets attribute doesn't exist
143 if hasattr(self.copick_root, "dataset_ids"):
144 dataset_ids = [str(did) for did in self.copick_root.dataset_ids]
145 cache_key = f"datasets_{'_'.join(dataset_ids)}"
146 else:
147 # Last resort fallback
148 cache_key = f"copick_root_{hash(str(self.copick_root))}"
150 if self.cache_format == "pickle":
151 return os.path.join(
152 self.cache_dir,
153 f"{cache_key}_{self.boxsize[0]}x{self.boxsize[1]}x{self.boxsize[2]}"
154 f"_{self.voxel_spacing}"
155 f"{'_with_bg' if self.include_background else ''}.pkl",
156 )
157 else: # parquet
158 return os.path.join(
159 self.cache_dir,
160 f"{cache_key}_{self.boxsize[0]}x{self.boxsize[1]}x{self.boxsize[2]}"
161 f"_{self.voxel_spacing}"
162 f"{'_with_bg' if self.include_background else ''}.parquet",
163 )
165 def _load_or_process_data(self):
166 """Load data from cache or process it directly."""
167 # If cache_dir is None, process data directly without caching
168 if self.cache_dir is None:
169 print("Cache directory not specified. Processing data without caching...")
170 self._load_data()
171 return
173 # If cache_dir is specified, use caching logic
174 os.makedirs(self.cache_dir, exist_ok=True)
175 cache_file = self._get_cache_path()
177 if os.path.exists(cache_file):
178 print(f"Loading cached data from {cache_file}")
180 if self.cache_format == "pickle":
181 self._load_from_pickle(cache_file)
182 else: # parquet
183 self._load_from_parquet(cache_file)
185 # Apply max_samples limit if specified
186 if self.max_samples is not None and len(self._subvolumes) > self.max_samples:
187 indices = np.random.choice(len(self._subvolumes), self.max_samples, replace=False)
188 self._subvolumes = np.array(self._subvolumes)[indices]
189 self._molecule_ids = np.array(self._molecule_ids)[indices]
190 if self._is_background:
191 self._is_background = np.array(self._is_background)[indices]
192 else:
193 print("Processing data and creating cache...")
194 self._load_data()
196 # Only save to cache if we actually loaded some data
197 if len(self._subvolumes) > 0:
198 if self.cache_format == "pickle":
199 self._save_to_pickle(cache_file)
200 else: # parquet
201 self._save_to_parquet(cache_file)
202 print(f"Cached data saved to {cache_file}")
203 else:
204 print("No data loaded, skipping cache creation")
206 def _load_from_pickle(self, cache_file):
207 """Load dataset from pickle cache."""
208 with open(cache_file, "rb") as f:
209 cached_data = pickle.load(f)
210 self._subvolumes = cached_data.get("subvolumes", [])
211 self._molecule_ids = cached_data.get("molecule_ids", [])
212 self._keys = cached_data.get("keys", [])
213 self._is_background = cached_data.get("is_background", [])
215 # Handle case where background flag wasn't saved
216 if not self._is_background and self.include_background:
217 # Initialize all as non-background
218 self._is_background = [False] * len(self._subvolumes)
220 def _save_to_pickle(self, cache_file):
221 """Save dataset to pickle cache."""
222 with open(cache_file, "wb") as f:
223 pickle.dump(
224 {
225 "subvolumes": self._subvolumes,
226 "molecule_ids": self._molecule_ids,
227 "keys": self._keys,
228 "is_background": self._is_background,
229 },
230 f,
231 )
233 def _load_from_parquet(self, cache_file):
234 """Load dataset from parquet cache."""
235 try:
236 df = pd.read_parquet(cache_file)
238 # Process subvolumes from bytes back to numpy arrays
239 self._subvolumes = []
240 for _, row in df.iterrows():
241 if isinstance(row["subvolume"], bytes):
242 # Reconstruct numpy array from bytes
243 subvol = np.frombuffer(row["subvolume"], dtype=np.float32)
244 shape = row["shape"]
245 if isinstance(shape, list):
246 shape = tuple(shape)
247 subvol = subvol.reshape(shape)
248 self._subvolumes.append(subvol)
249 else:
250 raise ValueError(f"Invalid subvolume format: {type(row['subvolume'])}")
252 # Convert to numpy array
253 self._subvolumes = np.array(self._subvolumes)
255 # Extract other fields
256 self._molecule_ids = df["molecule_id"].tolist()
257 self._keys = df["key"].tolist() if "key" in df.columns else []
259 # Load or initialize background flags
260 if "is_background" in df.columns:
261 self._is_background = df["is_background"].tolist()
262 else:
263 self._is_background = [False] * len(self._subvolumes)
265 # Reconstruct unique keys if not available
266 if not self._keys:
267 unique_ids = set()
268 for mol_id in self._molecule_ids:
269 if mol_id != -1 and not df.loc[df["molecule_id"] == mol_id, "is_background"].any():
270 unique_ids.add(mol_id)
271 self._keys = sorted(unique_ids)
273 except Exception as e:
274 print(f"Error loading from parquet: {str(e)}")
275 raise
277 def _save_to_parquet(self, cache_file):
278 """Save dataset to parquet cache."""
279 try:
280 # Check if we have any data to save
281 if len(self._subvolumes) == 0:
282 print("No data to save to parquet")
283 return
285 # Prepare records
286 records = []
287 for subvol, mol_id, is_bg in zip(self._subvolumes, self._molecule_ids, self._is_background):
288 record = {
289 "subvolume": subvol.tobytes(),
290 "shape": list(subvol.shape),
291 "molecule_id": mol_id,
292 "is_background": is_bg,
293 }
294 records.append(record)
296 # Add keys information
297 key_mapping = []
298 for i, key in enumerate(self._keys):
299 key_mapping.append({"key_index": i, "key": key})
301 # Create and save main dataframe
302 df = pd.DataFrame(records)
304 # Add keys as a column for each row
305 df["key"] = df["molecule_id"].apply(
306 lambda x: self._keys[x] if x != -1 and x < len(self._keys) else "background",
307 )
309 df.to_parquet(cache_file, index=False)
311 # Save additional metadata
312 metadata_file = cache_file.replace(".parquet", "_metadata.parquet")
313 metadata = {
314 "creation_date": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
315 "total_samples": len(records),
316 "unique_molecules": len(self._keys),
317 "boxsize": self.boxsize,
318 "include_background": self.include_background,
319 "background_samples": sum(self._is_background),
320 }
321 pd.DataFrame([metadata]).to_parquet(metadata_file, index=False)
323 except Exception as e:
324 print(f"Error saving to parquet: {str(e)}")
325 raise
327 def _extract_subvolume_with_validation(self, tomogram_array, x, y, z):
328 """Extract a subvolume with validation checks, applying the selected patch strategy."""
329 half_box = np.array(self.boxsize) // 2
331 # Apply patch strategy
332 if self.patch_strategy == "centered":
333 # Standard centered extraction
334 offset_x, offset_y, offset_z = 0, 0, 0
335 elif self.patch_strategy == "random":
336 # Random offsets within half_box/4
337 max_offset = [size // 4 for size in half_box]
338 offset_x = np.random.randint(-max_offset[0], max_offset[0] + 1)
339 offset_y = np.random.randint(-max_offset[1], max_offset[1] + 1)
340 offset_z = np.random.randint(-max_offset[2], max_offset[2] + 1)
341 elif self.patch_strategy == "jittered":
342 # Small random jitter for data augmentation
343 offset_x = np.random.randint(-2, 3)
344 offset_y = np.random.randint(-2, 3)
345 offset_z = np.random.randint(-2, 3)
347 # Apply offsets to coordinates
348 x_adj = x + offset_x
349 y_adj = y + offset_y
350 z_adj = z + offset_z
352 # Calculate slice indices
353 x_start = max(0, int(x_adj - half_box[0]))
354 x_end = min(tomogram_array.shape[2], int(x_adj + half_box[0]))
355 y_start = max(0, int(y_adj - half_box[1]))
356 y_end = min(tomogram_array.shape[1], int(y_adj + half_box[1]))
357 z_start = max(0, int(z_adj - half_box[2]))
358 z_end = min(tomogram_array.shape[0], int(z_adj + half_box[2]))
360 # Validate slice ranges
361 if x_end <= x_start or y_end <= y_start or z_end <= z_start:
362 return None, False, "Invalid slice range"
364 # Extract subvolume
365 subvolume = tomogram_array[z_start:z_end, y_start:y_end, x_start:x_end]
367 # Check if extracted shape matches requested size
368 if subvolume.shape != self.boxsize:
369 # Need to pad the subvolume
370 padded = np.zeros(self.boxsize, dtype=subvolume.dtype)
372 # Calculate padding amounts
373 z_pad_start = max(0, half_box[2] - int(z))
374 y_pad_start = max(0, half_box[1] - int(y))
375 x_pad_start = max(0, half_box[0] - int(x))
377 # Calculate end indices
378 z_pad_end = min(z_pad_start + (z_end - z_start), self.boxsize[0])
379 y_pad_end = min(y_pad_start + (y_end - y_start), self.boxsize[1])
380 x_pad_end = min(x_pad_start + (x_end - x_start), self.boxsize[2])
382 # Copy data
383 padded[z_pad_start:z_pad_end, y_pad_start:y_pad_end, x_pad_start:x_pad_end] = subvolume
384 return padded, True, "padded"
386 return subvolume, True, "valid"
388 def _load_data(self):
389 """Load particle picks data from copick project."""
390 # Determine which root to use
391 if self.copick_root is not None:
392 root = self.copick_root
393 print("Using provided copick root object")
394 else:
395 try:
396 root = copick.from_file(self.config_path)
397 print(f"Loading data from {self.config_path}")
398 except Exception as e:
399 print(f"Failed to load copick root: {str(e)}")
400 return
402 # Store all particle coordinates for background sampling
403 all_particle_coords = []
405 for run in root.runs:
406 print(f"Processing run: {run.name}")
408 # Try to load tomogram
409 try:
410 voxel_spacing_obj = run.get_voxel_spacing(self.voxel_spacing)
411 if (
412 voxel_spacing_obj is None
413 or not hasattr(voxel_spacing_obj, "tomograms")
414 or not voxel_spacing_obj.tomograms
415 ):
416 print(f"No tomograms found for run {run.name} at voxel spacing {self.voxel_spacing}")
417 continue
419 tomogram = voxel_spacing_obj.tomograms[0]
420 tomogram_array = tomogram.numpy()
421 except Exception as e:
422 print(f"Error loading tomogram for run {run.name}: {str(e)}")
423 continue
425 # Process picks
426 run_particle_coords = [] # Store coordinates for this run
428 for picks in run.get_picks():
429 if not picks.from_tool:
430 continue
432 object_name = picks.pickable_object_name
434 try:
435 points, _ = picks.numpy()
436 points = points / self.voxel_spacing
438 for point in points:
439 try:
440 x, y, z = point
442 # Save for background sampling
443 run_particle_coords.append((x, y, z))
445 # Extract subvolume
446 subvolume, is_valid, _ = self._extract_subvolume_with_validation(tomogram_array, x, y, z)
448 if is_valid:
449 self._subvolumes.append(subvolume)
451 if object_name not in self._keys:
452 self._keys.append(object_name)
454 self._molecule_ids.append(self._keys.index(object_name))
455 self._is_background.append(False)
456 except Exception as e:
457 print(f"Error extracting subvolume: {str(e)}")
458 except Exception as e:
459 print(f"Error processing picks for {object_name}: {str(e)}")
461 # Sample background points for this run if needed
462 if self.include_background and run_particle_coords:
463 all_particle_coords.extend(run_particle_coords)
464 self._sample_background_points(tomogram_array, run_particle_coords)
466 self._subvolumes = np.array(self._subvolumes)
467 self._molecule_ids = np.array(self._molecule_ids)
468 self._is_background = np.array(self._is_background)
470 # Apply max_samples limit if specified
471 if self.max_samples is not None and len(self._subvolumes) > self.max_samples:
472 indices = np.random.choice(len(self._subvolumes), self.max_samples, replace=False)
473 self._subvolumes = self._subvolumes[indices]
474 self._molecule_ids = self._molecule_ids[indices]
475 self._is_background = self._is_background[indices]
477 print(f"Loaded {len(self._subvolumes)} subvolumes with {len(self._keys)} classes")
478 print(f"Background samples: {sum(self._is_background)}")
480 def _sample_background_points(self, tomogram_array, particle_coords):
481 """Sample background points away from particles."""
482 if not particle_coords:
483 return
485 # Convert to numpy array for distance calculations
486 particle_coords = np.array(particle_coords)
488 # Calculate number of background samples based on ratio
489 num_particles = len(particle_coords)
490 num_background = int(num_particles * self.background_ratio)
492 # Limit attempts to avoid infinite loop
493 max_attempts = num_background * 10
494 attempts = 0
495 bg_points_found = 0
497 half_box = np.array(self.boxsize) // 2
499 while bg_points_found < num_background and attempts < max_attempts:
500 # Generate random point within tomogram bounds with margin for box extraction
501 random_point = np.array(
502 [
503 np.random.randint(half_box[0], tomogram_array.shape[2] - half_box[0]),
504 np.random.randint(half_box[1], tomogram_array.shape[1] - half_box[1]),
505 np.random.randint(half_box[2], tomogram_array.shape[0] - half_box[2]),
506 ],
507 )
509 # Calculate distances to all particles
510 distances = np.linalg.norm(particle_coords - random_point, axis=1)
512 # Check if point is far enough from all particles
513 if np.min(distances) >= self.min_background_distance:
514 # Extract subvolume
515 x, y, z = random_point
516 subvolume, is_valid, _ = self._extract_subvolume_with_validation(tomogram_array, x, y, z)
518 if is_valid:
519 self._subvolumes.append(subvolume)
520 self._molecule_ids.append(-1) # Use -1 to indicate background
521 self._is_background.append(True)
522 bg_points_found += 1
524 attempts += 1
526 print(f"Added {bg_points_found} background points after {attempts} attempts")
528 def _augment_subvolume(self, subvolume, idx=None):
529 """Apply data augmentation to a subvolume.
531 Args:
532 subvolume: The 3D volume to augment
533 idx: Optional index for mixup augmentation
535 Returns:
536 Augmented subvolume and list of applied augmentations if debug_mode is True
537 """
538 # Track applied augmentations if debug mode is enabled
539 applied_augmentations = []
540 mixup_info = None
542 # Apply standard augmentations with probability
543 for aug in self.augmentations:
544 if random.random() < self.augmentation_prob:
545 if aug == "brightness":
546 delta = np.random.uniform(-0.5, 0.5)
547 subvolume = self._brightness(subvolume, max_delta=0.5)
548 if self.debug_mode:
549 applied_augmentations.append({"type": "brightness", "delta": float(delta)})
551 elif aug == "blur":
552 sigma = np.random.uniform(0.5, 1.5)
553 subvolume = self._gaussian_blur(subvolume, sigma_range=(0.5, 1.5))
554 if self.debug_mode:
555 applied_augmentations.append({"type": "blur", "sigma": float(sigma)})
557 elif aug == "intensity":
558 factor = np.random.uniform(0.5, 1.5)
559 subvolume = self._intensity_scaling(subvolume, intensity_range=(0.5, 1.5))
560 if self.debug_mode:
561 applied_augmentations.append({"type": "intensity", "factor": float(factor)})
563 elif aug == "flip":
564 axis = random.randint(0, 2)
565 subvolume = self._flip(subvolume, axis=axis)
566 if self.debug_mode:
567 applied_augmentations.append({"type": "flip", "axis": int(axis)})
569 elif aug == "rotate":
570 # Filter available axes based on rotate_axes setting
571 available_axes = [i for i, allowed in enumerate(self.rotate_axes) if allowed]
572 if len(available_axes) < 2:
573 axis1, axis2 = 0, 1
574 else:
575 axis1, axis2 = random.sample(available_axes, 2)
576 k = random.randint(1, 3) # 90, 180, or 270 degrees
577 subvolume = self._rotate(subvolume, axes=(axis1, axis2), k=k)
578 if self.debug_mode:
579 applied_augmentations.append({"type": "rotate", "axes": (int(axis1), int(axis2)), "k": int(k)})
581 elif aug == "rotate_z" and "rotate_z" in self.augmentations:
582 angle = np.random.uniform(0, 360)
583 subvolume = self._rotate_z(subvolume, angle=angle)
584 if self.debug_mode:
585 applied_augmentations.append({"type": "rotate_z", "angle": float(angle)})
587 # Apply mixup if enabled and we have an index to work with
588 if "mixup" in self.augmentations and self.mixup_alpha is not None and idx is not None: # noqa: SIM102
589 if random.random() < self.augmentation_prob:
590 subvolume, mixup_other_idx, mixup_lambda = self._apply_mixup(subvolume, idx)
591 if self.debug_mode and mixup_other_idx is not None:
592 mixup_info = {"other_idx": int(mixup_other_idx), "lambda": float(mixup_lambda)}
594 if self.debug_mode:
595 return subvolume, applied_augmentations, mixup_info
596 else:
597 return subvolume
599 def _brightness(self, volume, max_delta=0.5):
600 """Adjust brightness of a volume."""
601 delta = np.random.uniform(-max_delta, max_delta)
602 return volume + delta
604 def _gaussian_blur(self, volume, sigma_range=(0.5, 1.5)):
605 """Apply Gaussian blur to a volume."""
606 sigma = np.random.uniform(*sigma_range)
607 return gaussian_filter(volume, sigma=sigma)
609 def _intensity_scaling(self, volume, intensity_range=(0.5, 1.5)):
610 """Scale the intensity of a volume."""
611 intensity_factor = np.random.uniform(*intensity_range)
612 return volume * intensity_factor
614 def _flip(self, volume, axis=None):
615 """Flip the volume along specified axis or a random axis if not specified."""
616 if axis is None:
617 axis = random.randint(0, 2)
618 return np.flip(volume, axis=axis)
620 def _rotate(self, volume, axes=None, k=None):
621 """Rotate the volume 90, 180, or 270 degrees around specified or allowed axes.
623 Args:
624 volume: The 3D volume to rotate
625 axes: Optional tuple of (axis1, axis2) to rotate around
626 k: Optional number of 90-degree rotations (1, 2, or 3)
628 Returns:
629 Rotated volume
630 """
631 # If axes not specified, select from available axes
632 if axes is None:
633 # Filter available axes based on rotate_axes setting
634 available_axes = [i for i, allowed in enumerate(self.rotate_axes) if allowed]
636 if len(available_axes) < 2:
637 # Need at least 2 axes for rotation, if not enough are enabled,
638 # default to standard x-y rotation
639 axes = (0, 1)
640 else:
641 # Select two random axes from the available ones
642 axes = tuple(random.sample(available_axes, 2))
644 # If k not specified, choose random rotation
645 if k is None:
646 k = random.randint(1, 3) # 90, 180, or 270 degrees
648 return np.rot90(volume, k=k, axes=axes)
650 def _rotate_z(self, volume, angle=None):
651 """Apply rotation specifically around z-axis.
653 Args:
654 volume: The 3D volume to rotate
655 angle: Optional rotation angle in degrees (0-360)
657 Returns:
658 Rotated volume
659 """
660 # For z-rotation, we'll rotate around the first dimension (z-axis)
661 # using alternative methods that allow arbitrary angles
662 if angle is None:
663 angle = np.random.uniform(0, 360) # Random angle in degrees
665 # Get center coordinates
666 center_z, center_y, center_x = np.array(volume.shape) // 2
668 # Create rotation matrix for z-axis rotation
669 theta = np.radians(angle)
670 c, s = np.cos(theta), np.sin(theta)
671 rotation_matrix = np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]])
673 # Create coordinates grid
674 z, y, x = np.meshgrid(
675 np.arange(volume.shape[0]),
676 np.arange(volume.shape[1]),
677 np.arange(volume.shape[2]),
678 indexing="ij",
679 )
681 # Adjust coordinates to be relative to center
682 z -= center_z
683 y -= center_y
684 x -= center_x
686 # Stack coordinates and reshape
687 coords = np.stack([z.flatten(), y.flatten(), x.flatten()])
689 # Apply rotation
690 rotated_coords = np.dot(rotation_matrix, coords)
692 # Reshape back and adjust to original coordinate system
693 z_rot = rotated_coords[0].reshape(volume.shape) + center_z
694 y_rot = rotated_coords[1].reshape(volume.shape) + center_y
695 x_rot = rotated_coords[2].reshape(volume.shape) + center_x
697 # Interpolate using scipy map_coordinates
698 from scipy.ndimage import map_coordinates
700 rotated_volume = map_coordinates(volume, [z_rot, y_rot, x_rot], order=1, mode="constant")
702 return rotated_volume
704 def _apply_mixup(self, subvolume, idx):
705 """Apply mixup augmentation by blending with another random sample.
707 Mixup is a data augmentation technique that creates virtual training examples
708 by mixing pairs of inputs and their labels with random proportions.
710 Args:
711 subvolume: The current subvolume being processed
712 idx: Index of the current subvolume to avoid mixing with itself
714 Returns:
715 Tuple of (mixed subvolume, other_idx, lambda) for complete mixup
716 """
717 if len(self._subvolumes) <= 1:
718 return subvolume, None, 1.0
720 # Select a different index at random
721 other_idx = idx
722 while other_idx == idx:
723 other_idx = random.randint(0, len(self._subvolumes) - 1)
725 # Get the other subvolume
726 other_subvolume = self._subvolumes[other_idx].copy()
728 # Sample lambda from beta distribution
729 if self.mixup_alpha > 0:
730 lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
731 else:
732 lam = 0.5 # Equal mix if alpha not provided
734 # Mix the subvolumes
735 mixed_subvolume = lam * subvolume + (1 - lam) * other_subvolume
737 # Return the mixed subvolume, the other sample's index, and the lambda value
738 # This allows the __getitem__ method to properly handle the labels
739 return mixed_subvolume, other_idx, lam
741 def __len__(self):
742 """Get the total number of items in the dataset."""
743 return len(self._subvolumes)
745 def __getitem__(self, idx):
746 """Get an item from the dataset with proper mixup handling and augmentation tracking.
748 Returns:
749 tuple: (subvolume, label_dict)
751 Where label_dict contains:
752 - 'class_idx': Original class index (or primary class index if mixed)
753 - 'is_mixed': Boolean indicating if mixup was applied
754 - 'mix_lambda': Lambda value for mixup (1.0 if no mixup)
755 - 'mix_class_idx': Secondary class index for mixup (None if no mixup)
756 - 'applied_augmentations': List of applied augmentations (if debug_mode=True)
757 """
758 subvolume = self._subvolumes[idx].copy()
759 molecule_idx = self._molecule_ids[idx]
761 # Initialize label dictionary with default values
762 label_dict = {"class_idx": molecule_idx, "is_mixed": False, "mix_lambda": 1.0, "mix_class_idx": None}
764 # Track augmentations if debug mode is enabled
765 if self.debug_mode:
766 label_dict["applied_augmentations"] = []
768 if self.augment:
769 if self.debug_mode:
770 # Apply augmentations with tracking in debug mode
771 subvolume, applied_augmentations, mixup_info = self._augment_subvolume(subvolume, idx)
773 # Store augmentation information in label dictionary
774 label_dict["applied_augmentations"] = applied_augmentations
776 # Update mixup information if applicable
777 if mixup_info is not None:
778 mixup_other_idx = mixup_info["other_idx"]
779 mixup_lambda = mixup_info["lambda"]
780 other_molecule_idx = self._molecule_ids[mixup_other_idx]
781 label_dict.update(
782 {"is_mixed": True, "mix_lambda": mixup_lambda, "mix_class_idx": other_molecule_idx},
783 )
784 else:
785 # Standard approach without tracking
786 for aug in self.augmentations:
787 if random.random() < self.augmentation_prob:
788 if aug == "brightness":
789 subvolume = self._brightness(subvolume)
790 elif aug == "blur":
791 subvolume = self._gaussian_blur(subvolume)
792 elif aug == "intensity":
793 subvolume = self._intensity_scaling(subvolume)
794 elif aug == "flip":
795 subvolume = self._flip(subvolume)
796 elif aug == "rotate":
797 subvolume = self._rotate(subvolume)
798 elif aug == "rotate_z" and "rotate_z" in self.augmentations:
799 subvolume = self._rotate_z(subvolume)
801 # Apply mixup separately to capture its metadata
802 if "mixup" in self.augmentations and self.mixup_alpha is not None:
803 if random.random() < self.augmentation_prob:
804 subvolume, mixup_other_idx, mixup_lambda = self._apply_mixup(subvolume, idx)
806 # Update label dictionary with mixup information if applicable
807 if mixup_other_idx is not None:
808 other_molecule_idx = self._molecule_ids[mixup_other_idx]
809 label_dict.update(
810 {"is_mixed": True, "mix_lambda": mixup_lambda, "mix_class_idx": other_molecule_idx},
811 )
813 # Normalize
814 subvolume = (subvolume - np.mean(subvolume)) / (np.std(subvolume) + 1e-6)
816 # Add channel dimension and convert to tensor
817 subvolume = torch.as_tensor(subvolume[None, ...], dtype=torch.float32)
819 return subvolume, label_dict
821 def get_sample_weights(self):
822 """Return sample weights for use in a WeightedRandomSampler."""
823 return self.sample_weights
825 def keys(self):
826 """Get pickable object keys."""
827 return self._keys
829 def examples(self):
830 """Get example volumes for each class."""
831 # Check if dataset is empty
832 if len(self._subvolumes) == 0 or len(self._molecule_ids) == 0:
833 return None, []
835 class_examples = {}
836 example_tensors = []
837 example_labels = []
839 # Get examples for regular classes
840 for cls in range(len(self._keys)):
841 # Find first index for this class
842 for i, mol_id in enumerate(self._molecule_ids):
843 if (
844 mol_id == cls
845 and cls not in class_examples
846 and (not self._is_background or not self._is_background[i])
847 ):
848 try:
849 volume, _ = self[i]
850 example_tensors.append(volume)
851 example_labels.append(cls)
852 class_examples[cls] = i
853 break
854 except Exception as e:
855 print(f"Error extracting example for class {cls}: {str(e)}")
856 continue
858 # Add background example if present
859 if self.include_background and self._is_background and any(self._is_background):
860 # Find first background sample
861 for i, is_bg in enumerate(self._is_background):
862 if is_bg:
863 try:
864 volume, _ = self[i]
865 example_tensors.append(volume)
866 example_labels.append(-1) # Use -1 for background
867 break
868 except Exception as e:
869 print(f"Error extracting background example: {str(e)}")
870 continue
872 if example_tensors:
873 return torch.stack(example_tensors), [
874 "background" if label == -1 else self._keys[label] for label in example_labels
875 ]
876 return None, []
878 def get_class_distribution(self):
879 """Get distribution of classes in the dataset."""
880 class_counts = Counter(self._molecule_ids)
882 # Create a readable distribution
883 distribution = {}
885 # Count background samples if any
886 if -1 in class_counts:
887 distribution["background"] = class_counts[-1]
888 del class_counts[-1]
890 # Count regular classes
891 for cls_idx, count in class_counts.items():
892 if 0 <= cls_idx < len(self._keys):
893 distribution[self._keys[cls_idx]] = count
895 return distribution
def stratified_split(self, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15, seed=None):
    """Split the dataset into train, validation, and test sets while preserving class distributions.

    Each class (including background, id -1, when present) is shuffled and
    split independently, so every subset keeps roughly the original class
    proportions. Per class, ``int(n * train_ratio)`` samples go to train and
    ``int(n * val_ratio)`` to validation; the remainder, including all
    rounding leftovers, goes to test, so very small classes may appear only
    in the test subset.

    Args:
        train_ratio: Proportion of data to use for training
        val_ratio: Proportion of data to use for validation
        test_ratio: Proportion of data to use for testing
        seed: Random seed for reproducibility. A private generator is used, so
            NumPy's global random state is not modified.

    Returns:
        Tuple of (train_dataset, val_dataset, test_dataset) as Subset objects

    Raises:
        ValueError: If the ratios do not sum to 1.0 or any ratio is negative.
    """
    if not np.isclose(train_ratio + val_ratio + test_ratio, 1.0):
        raise ValueError("Ratios must sum to 1.0")
    # Negative ratios can still sum to 1 but would produce negative slice sizes.
    if min(train_ratio, val_ratio, test_ratio) < 0:
        raise ValueError("Ratios must be non-negative")

    # RandomState(seed) produces exactly the stream np.random.seed(seed) would
    # give the global generator, so seeded splits are unchanged, but the
    # caller's global RNG state is no longer reset as a side effect.
    # Unseeded calls keep using the global generator.
    rng = np.random.RandomState(seed) if seed is not None else np.random

    # Group sample indices by class, preserving first-appearance order.
    class_indices = {}
    for i, mol_id in enumerate(self._molecule_ids):
        class_indices.setdefault(mol_id, []).append(i)

    train_indices = []
    val_indices = []
    test_indices = []

    for indices in class_indices.values():
        rng.shuffle(indices)
        n_samples = len(indices)
        n_train = int(n_samples * train_ratio)
        n_val = int(n_samples * val_ratio)

        train_indices.extend(indices[:n_train])
        val_indices.extend(indices[n_train : n_train + n_val])
        test_indices.extend(indices[n_train + n_val :])

    # Interleave classes within each subset instead of leaving them grouped.
    rng.shuffle(train_indices)
    rng.shuffle(val_indices)
    rng.shuffle(test_indices)

    train_dataset = Subset(self, train_indices)
    val_dataset = Subset(self, val_indices)
    test_dataset = Subset(self, test_indices)

    print(
        f"Dataset split: {len(train_dataset)} train, {len(val_dataset)} validation, {len(test_dataset)} test samples",
    )

    return train_dataset, val_dataset, test_dataset
def balance_classes(self, method="oversample", target_ratio=1.0, exclude_background=False):
    """Balance class distribution in the dataset.

    Args:
        method: ``"oversample"`` grows every class toward the largest one by
            appending randomly chosen, lightly augmented copies of its samples.
            ``"undersample"`` shrinks every class toward the smallest one by
            random selection without replacement.
        target_ratio: Fraction of each class's gap to the largest (oversample)
            or smallest (undersample) class that is closed, in (0, 1].
            1.0 yields equal class counts.
        exclude_background: If True, background samples (molecule id -1) take
            no part in balancing and are copied into the new dataset unchanged.

    Returns:
        A new CopickDataset instance whose ``_subvolumes``, ``_molecule_ids``
        and ``_is_background`` are replaced by the balanced arrays.

    Raises:
        ValueError: If ``method`` or ``target_ratio`` is invalid, or if there
            are no samples left to balance.
    """
    if method not in ["oversample", "undersample"]:
        raise ValueError("method must be either 'oversample' or 'undersample'")

    if target_ratio <= 0 or target_ratio > 1.0:
        raise ValueError("target_ratio must be between 0 and 1.0")

    # Group sample indices per class. Excluded background is set aside so it
    # can be carried over unchanged instead of being silently dropped.
    class_indices = {}
    held_out_background = []
    for i, mol_id in enumerate(self._molecule_ids):
        if exclude_background and mol_id == -1:
            held_out_background.append(i)
            continue
        class_indices.setdefault(mol_id, []).append(i)

    if not class_indices:
        raise ValueError("No samples available to balance")

    class_counts = {mol_id: len(indices) for mol_id, indices in class_indices.items()}
    print("Original class distribution:")
    for mol_id, count in class_counts.items():
        class_name = "background" if mol_id == -1 else self._keys[mol_id]
        print(f" {class_name}: {count} samples")
    if held_out_background:
        print(f" background: {len(held_out_background)} samples (excluded from balancing)")

    # Target count per class: close target_ratio of the gap to the
    # largest (oversample) or smallest (undersample) class.
    if method == "oversample":
        max_count = max(class_counts.values())
        target_counts = {
            mol_id: count + int((max_count - count) * target_ratio) for mol_id, count in class_counts.items()
        }
    else:
        min_count = min(class_counts.values())
        target_counts = {
            mol_id: count - int((count - min_count) * target_ratio) for mol_id, count in class_counts.items()
        }

    new_subvolumes = []
    new_molecule_ids = []
    new_is_background = []

    def _append(idx, subvolume):
        # Record one output sample whose labels come from source index idx.
        new_subvolumes.append(subvolume)
        new_molecule_ids.append(self._molecule_ids[idx])
        new_is_background.append(self._is_background[idx])

    for mol_id, indices in class_indices.items():
        current_count = len(indices)
        target_count = target_counts[mol_id]

        if target_count <= current_count:
            # Random subset without replacement (a permutation when equal).
            for idx in np.random.choice(indices, target_count, replace=False):
                _append(idx, self._subvolumes[idx].copy())
            continue

        # Keep every original sample, then pad with augmented duplicates.
        for idx in indices:
            _append(idx, self._subvolumes[idx].copy())

        n_duplicates = target_count - current_count
        for idx in np.random.choice(indices, n_duplicates, replace=True):
            augmented = self._subvolumes[idx].copy()
            if self.debug_mode:
                augmented, _, _ = self._augment_subvolume(augmented)
            else:
                # Light augmentation so duplicates are not exact copies.
                augmented = self._flip(augmented)
                if random.random() < 0.5:
                    augmented = self._brightness(augmented)
                if random.random() < 0.5:
                    augmented = self._intensity_scaling(augmented)
            _append(idx, augmented)

    # Background excluded from balancing is carried over untouched.
    for idx in held_out_background:
        _append(idx, self._subvolumes[idx].copy())

    new_subvolumes = np.array(new_subvolumes)
    new_molecule_ids = np.array(new_molecule_ids)
    new_is_background = np.array(new_is_background)

    # NOTE(review): only a subset of this dataset's settings is forwarded to
    # the constructor; whatever samples it loads are overwritten just below.
    balanced_dataset = CopickDataset(
        config_path=self.config_path,
        boxsize=self.boxsize,
        augment=self.augment,
        cache_dir=None,  # Don't use caching for the balanced dataset
        seed=self.seed,
        voxel_spacing=self.voxel_spacing,
        include_background=self.include_background,
        patch_strategy=self.patch_strategy,
        debug_mode=self.debug_mode,
    )

    balanced_dataset._subvolumes = new_subvolumes
    balanced_dataset._molecule_ids = new_molecule_ids
    balanced_dataset._is_background = new_is_background
    balanced_dataset._keys = self._keys.copy()

    balanced_dataset._compute_sample_weights()

    balanced_dist = balanced_dataset.get_class_distribution()
    print("Balanced class distribution:")
    for class_name, count in balanced_dist.items():
        print(f" {class_name}: {count} samples")

    return balanced_dataset
def extract_grid_patches(self, patch_size, overlap=0.25, normalize=True, run_index=0, tomo_type="raw"):
    """Extract a grid of patches from a tomogram.

    Patch origins start at 0 on every axis and advance by
    ``max(1, int(size * (1 - overlap)))`` voxels. Only patches lying fully
    inside the tomogram are produced. When the remaining extent along an axis
    is not a multiple of the stride, the last few voxels on that axis are not
    covered, and an axis shorter than the patch yields no patches at all.

    Args:
        patch_size: Int or sequence (z, y, x) of positive patch dimensions.
        overlap: Overlap ratio between adjacent patches, in [0, 1).
        normalize: If True, each patch is shifted to zero mean and divided by
            (std + 1e-6).
        run_index: Index into the project's runs; must be in
            ``[0, number of runs)``.
        tomo_type: ``"raw"`` or ``"filtered"``. ``"filtered"`` falls back to the
            raw tomogram (with a printed warning) when no filtered data exists.

    Returns:
        Tuple ``(patches, coordinates)``: a list of patch arrays and a list of
        their center coordinates ``(z, y, x)``, ordered z-outermost.

    Raises:
        ValueError: On invalid ``patch_size``, ``overlap``, ``run_index`` or
            ``tomo_type``, or when there are no runs or no tomogram at
            ``self.voxel_spacing``.
    """
    if isinstance(patch_size, (int, np.integer)):
        patch_size = (int(patch_size),) * 3
    elif len(patch_size) != 3:
        raise ValueError("patch_size must be an integer or tuple of 3 integers")
    # Zero-sized patches would make every mean/std NaN.
    if any(size <= 0 for size in patch_size):
        raise ValueError("patch_size values must be positive")

    if overlap < 0 or overlap >= 1:
        raise ValueError("overlap must be between 0 and 1")

    try:
        # Validate before loading any data.
        if tomo_type not in ("raw", "filtered"):
            raise ValueError(f"Invalid tomogram type: {tomo_type}. Must be 'raw' or 'filtered'")

        root = copick.from_file(self.config_path)
        if not root.runs:
            raise ValueError("No runs found in the copick project")

        # Negative indices are rejected too: they would silently wrap to the last runs.
        if not 0 <= run_index < len(root.runs):
            raise ValueError(f"Run index {run_index} out of range. Only {len(root.runs)} runs available.")
        run = root.runs[run_index]

        vs = run.get_voxel_spacing(self.voxel_spacing)
        if vs is None or not vs.tomograms:
            raise ValueError(f"No tomogram found at voxel spacing {self.voxel_spacing} in run {run_index}")
        tomogram = vs.tomograms[0]

        if tomo_type == "raw":
            tomogram_array = tomogram.numpy()
        elif hasattr(tomogram, "filtered") and tomogram.filtered is not None:
            tomogram_array = tomogram.filtered.numpy()
        else:
            print("Warning: Filtered tomogram not available, using raw tomogram instead")
            tomogram_array = tomogram.numpy()

        # Step between patch origins; at least 1 so a large overlap still advances.
        strides = [max(1, int(size * (1 - overlap))) for size in patch_size]

        # Origins 0, s, 2s, ... up to extent - size, i.e. every patch fits entirely.
        starts = [
            range(0, extent - size + 1, stride)
            for extent, size, stride in zip(tomogram_array.shape, patch_size, strides)
        ]
        pz, py, px = patch_size

        patches = []
        coordinates = []
        for z0 in starts[0]:
            for y0 in starts[1]:
                for x0 in starts[2]:
                    patch = tomogram_array[z0 : z0 + pz, y0 : y0 + py, x0 : x0 + px].copy()
                    if normalize:
                        patch = (patch - np.mean(patch)) / (np.std(patch) + 1e-6)
                    patches.append(patch)
                    coordinates.append((z0 + pz // 2, y0 + py // 2, x0 + px // 2))

        print(f"Extracted {len(patches)} patches of size {patch_size} with {overlap:.2f} overlap")
        return patches, coordinates

    except Exception as e:
        print(f"Error extracting grid patches: {str(e)}")
        raise
def extract_from_region(self, x_range, y_range, z_range, tomo_type="raw", run_index=0):
    """Extract a specific region from a tomogram.

    Bounds are half-open ``[min, max)`` voxel indices. They are truncated with
    ``int()`` and clipped into ``[0, extent]`` on both ends, so a range lying
    entirely outside the volume produces an empty region and raises an error.

    Args:
        x_range: (min_x, max_x) in voxel space, as a tuple or list.
        y_range: (min_y, max_y) in voxel space, as a tuple or list.
        z_range: (min_z, max_z) in voxel space, as a tuple or list.
        tomo_type: ``"raw"`` or ``"filtered"``. ``"filtered"`` falls back to the
            raw tomogram (with a printed warning) when no filtered data exists.
        run_index: Index into the project's runs (default 0, the first run).

    Returns:
        A numpy array sliced as ``[z, y, x]`` from the loaded tomogram array.

    Raises:
        ValueError: On malformed ranges, invalid ``tomo_type`` or ``run_index``,
            no runs or no tomogram at ``self.voxel_spacing``, or an empty
            region.
    """
    if not all(isinstance(r, (tuple, list)) and len(r) == 2 for r in (x_range, y_range, z_range)):
        raise ValueError("Range parameters must be tuples of (min, max)")

    try:
        # Validate before loading any data.
        if tomo_type not in ("raw", "filtered"):
            raise ValueError(f"Invalid tomogram type: {tomo_type}. Must be 'raw' or 'filtered'")

        root = copick.from_file(self.config_path)
        if not root.runs:
            raise ValueError("No runs found in the copick project")

        # Negative indices are rejected too: they would silently wrap to the last runs.
        if not 0 <= run_index < len(root.runs):
            raise ValueError(f"Run index {run_index} out of range. Only {len(root.runs)} runs available.")
        run = root.runs[run_index]

        vs = run.get_voxel_spacing(self.voxel_spacing)
        if vs is None or not vs.tomograms:
            raise ValueError(f"No tomogram found at voxel spacing {self.voxel_spacing} in run {run_index}")
        tomogram = vs.tomograms[0]

        if tomo_type == "raw":
            tomogram_array = tomogram.numpy()
        elif hasattr(tomogram, "filtered") and tomogram.filtered is not None:
            tomogram_array = tomogram.filtered.numpy()
        else:
            print("Warning: Filtered tomogram not available, using raw tomogram instead")
            tomogram_array = tomogram.numpy()

        def _clipped_slice(bounds, extent):
            # Clip BOTH ends into [0, extent]: an unclipped negative upper bound
            # would be read by numpy as "count from the end" and return most of
            # the axis instead of nothing.
            lo, hi = (min(max(0, int(b)), extent) for b in bounds)
            return slice(lo, hi)

        region = tomogram_array[
            _clipped_slice(z_range, tomogram_array.shape[0]),
            _clipped_slice(y_range, tomogram_array.shape[1]),
            _clipped_slice(x_range, tomogram_array.shape[2]),
        ]

        if region.size == 0:
            raise ValueError("Extracted region is empty. Check range parameters.")

        return region

    except Exception as e:
        print(f"Error extracting region from tomogram: {str(e)}")
        raise