Coverage for copick_torch/augmentations.py: 35%
89 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-16 16:14 -0700
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-16 16:14 -0700
1"""
2Augmentations for 3D volumes based on MONAI transform interface.
4This module provides MONAI-based implementations of augmentations for 3D tomographic data.
5"""
7from typing import Optional, Sequence, Tuple, Union
9import numpy as np
10import torch
11from monai.config.type_definitions import NdarrayOrTensor
12from monai.transforms import (
13 Fourier,
14 MapTransform,
15 RandomizableTrait,
16 RandomizableTransform,
17 Transform,
18)
19from monai.transforms.utils import Fourier as FourierUtils
20from monai.utils import convert_data_type, convert_to_dst_type, convert_to_tensor
class MixupTransform(RandomizableTransform):
    """
    Implements Mixup augmentation for 3D volumes based on MONAI transform interface.

    Mixup is a data augmentation technique that creates virtual training examples
    by mixing pairs of inputs and their labels with a random proportion.

    All random draws (apply/skip decision, mixing coefficient, permutation) come
    from the instance random state ``self.R`` so that ``set_random_state`` makes
    the transform reproducible.

    Reference: Zhang et al., "mixup: Beyond Empirical Risk Minimization", ICLR 2018
    https://arxiv.org/abs/1710.09412
    """

    def __init__(self, alpha: float = 0.2, prob: float = 1.0):
        """
        Initialize the Mixup augmentation.

        Args:
            alpha: Parameter for Beta distribution. Higher values result in more mixing.
                Values <= 0 disable mixing (the coefficient is fixed to 1.0).
            prob: Probability of applying the transform.
        """
        RandomizableTransform.__init__(self, prob)
        self.alpha = alpha
        # Mixing coefficient used by the most recent randomization.
        self.lam = 1.0
        # Permutation used by the most recent call; None until the first call.
        self.index = None

    def randomize(self, data=None) -> None:
        """
        Randomize the transform parameters.

        Re-rolls the apply/skip decision and, when the transform is to be
        applied, draws a new mixing coefficient ``self.lam``.

        Args:
            data: Unused; accepted for compatibility with the randomize interface.
        """
        super().randomize(None)
        if not self._do_transform:
            return None

        if self.alpha > 0:
            # Draw from the instance RandomState rather than the global numpy RNG.
            lam = float(self.R.beta(self.alpha, self.alpha))
        else:
            lam = 1.0

        # Defensive clamp: keep the coefficient inside [0, 1].
        self.lam = min(max(lam, 0.0), 1.0)

    def __call__(
        self,
        img: torch.Tensor,
        randomize: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, float]:
        """
        Apply mixup augmentation to a batch of images.

        Only ``img`` is passed in, so the second and third return values are the
        original and the permuted *inputs*. To mix labels, index them with
        ``self.index`` after the call and combine the losses with
        :meth:`mixup_criterion`.

        Args:
            img: Tensor of shape [batch_size, channels, depth, height, width]
            randomize: Whether to execute randomize function first, default to True.

        Returns:
            Tuple of (mixed_images, img_a, img_b, lam) where:
                - mixed_images: ``lam * img + (1 - lam) * img[index]``
                - img_a: The input batch
                - img_b: The input batch permuted by ``self.index``
                - lam: Mixing coefficient
            When the transform is skipped, ``(img, img, img, 1.0)`` is returned.
        """
        if randomize:
            self.randomize()

        if not self._do_transform:
            return img, img, img, 1.0

        img = convert_to_tensor(img)
        batch_size = img.shape[0]

        # Permutation of the batch, drawn from the instance RandomState.
        permutation = self.R.permutation(batch_size)
        self.index = torch.as_tensor(permutation, dtype=torch.long, device=img.device)

        partner = img[self.index]
        mixed_images = self.lam * img + (1 - self.lam) * partner

        return mixed_images, img, partner, self.lam

    @staticmethod
    def mixup_criterion(criterion, pred, y_a, y_b, lam):
        """
        Apply mixup to the loss calculation.

        Args:
            criterion: Loss function
            pred: Model predictions
            y_a: First labels
            y_b: Second (mixed-in) labels
            lam: Mixing coefficient

        Returns:
            Mixed loss: ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``
        """
        return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)
class FourierAugment3D(RandomizableTransform, Fourier):
    """
    Implements Fourier-based augmentation for 3D volumes based on MONAI transform interface.

    This augmentation performs operations in the frequency domain, including
    random frequency dropout (masking), phase noise injection, and intensity scaling.

    It can help the model become more robust to various frequency distortions that
    may occur in tomographic data.

    All random draws come from the instance random state ``self.R`` so that
    ``set_random_state`` makes the transform reproducible.
    """

    def __init__(
        self,
        freq_mask_prob: float = 0.3,
        phase_noise_std: float = 0.1,
        intensity_scaling_range: Tuple[float, float] = (0.8, 1.2),
        prob: float = 1.0,
    ) -> None:
        """
        Initialize the Fourier domain augmentation.

        Args:
            freq_mask_prob: Probability that a mask is used at all and, when it is,
                the probability of dropping each individual frequency component
            phase_noise_std: Standard deviation of Gaussian noise added to the phase
            intensity_scaling_range: Range for random intensity scaling (min, max)
            prob: Probability of applying the transform
        """
        RandomizableTransform.__init__(self, prob)
        self.freq_mask_prob = freq_mask_prob
        self.phase_noise_std = phase_noise_std
        self.intensity_scaling_range = intensity_scaling_range

        # Randomized parameters (populated by randomize / _sample_parameters)
        self._mask = None
        self._phase_noise = None
        self._intensity_scale = None

    def randomize(self, spatial_shape=None) -> None:
        """
        Randomize the transform parameters.

        Re-rolls the apply/skip decision; when the transform is to be applied and
        a shape is given, new mask / phase-noise / scale parameters are drawn.

        Args:
            spatial_shape: Shape of the volume the parameters will be applied to.
                If None, only the apply/skip decision is re-rolled.
        """
        super().randomize(None)
        if not self._do_transform or spatial_shape is None:
            return None

        self._sample_parameters(spatial_shape)

    def _sample_parameters(self, spatial_shape) -> None:
        """Draw mask, phase noise and intensity scale without touching ``_do_transform``."""
        shape = tuple(int(dim) for dim in spatial_shape)

        # Frequency dropout: first decide whether to mask at all, then which
        # components to keep (True = keep).
        if self.R.rand() < self.freq_mask_prob:
            keep = self.R.random_sample(shape) > self.freq_mask_prob
            self._mask = torch.as_tensor(keep)
        else:
            self._mask = None

        # Additive Gaussian phase noise
        noise = self.R.standard_normal(size=shape)
        self._phase_noise = torch.as_tensor(noise, dtype=torch.float32) * self.phase_noise_std

        # Global intensity scaling factor
        self._intensity_scale = float(
            self.R.uniform(
                low=self.intensity_scaling_range[0],
                high=self.intensity_scaling_range[1],
            )
        )

    def __call__(self, volume: torch.Tensor, randomize: bool = True) -> torch.Tensor:
        """
        Apply Fourier domain augmentation to a volume.

        Args:
            volume: Tensor of shape [depth, height, width] or [channels, depth, height, width]
            randomize: Whether to execute randomize function first, default to True.
                When False, the parameters from the previous randomization are reused.

        Returns:
            Augmented volume with same shape as input. The input is returned
            unchanged when the transform is skipped.

        Raises:
            RuntimeError: If ``randomize`` is False and no parameters have been
                drawn yet.
        """
        if randomize:
            # Get input shape for randomization
            input_shape = volume.shape
            spatial_shape = input_shape if len(input_shape) == 3 else input_shape[1:]
            self.randomize(spatial_shape)

        if not self._do_transform:
            return volume

        # Ensure volume is a torch tensor
        volume = convert_to_tensor(volume)

        if len(volume.shape) != 4:
            # Process 3D volume directly
            return self._apply_fourier_aug(volume)

        # Channel-first input: augment each channel with its own parameters.
        output = []
        for channel in range(volume.shape[0]):
            channel_volume = volume[channel]
            # Channel 0 uses the parameters drawn by randomize() above. Later
            # channels draw fresh ones via the helper so the apply/skip
            # decision is not re-rolled in the middle of a call.
            if randomize and channel > 0:
                self._sample_parameters(channel_volume.shape)
            output.append(self._apply_fourier_aug(channel_volume))
        return torch.stack(output)

    def _apply_fourier_aug(self, vol_tensor: torch.Tensor) -> torch.Tensor:
        """
        Apply Fourier augmentation to a single tensor (no channels).

        Args:
            vol_tensor: 3D tensor of shape [depth, height, width]

        Returns:
            Augmented tensor of same shape

        Raises:
            RuntimeError: If no randomized parameters are available.
        """
        if self._phase_noise is None or self._intensity_scale is None:
            raise RuntimeError(
                "FourierAugment3D parameters have not been randomized; "
                "call with randomize=True or call randomize(spatial_shape) first."
            )

        device = vol_tensor.device

        # Move randomized parameters to the same device
        mask = self._mask.to(device) if self._mask is not None else None
        phase_noise = self._phase_noise.to(device)

        # FFT, shifted so the zero frequency sits at the centre
        f_shifted = torch.fft.fftshift(torch.fft.fftn(vol_tensor))

        # Magnitude and phase
        magnitude = torch.abs(f_shifted)
        phase = torch.angle(f_shifted)

        # 1. Random frequency dropout (mask)
        if mask is not None:
            magnitude = magnitude * mask

        # 2. Random phase noise
        phase = phase + phase_noise

        # 3. Combine magnitude and noisy phase
        real = magnitude * torch.cos(phase)
        imag = magnitude * torch.sin(phase)
        f_augmented = torch.complex(real, imag)

        # IFFT; the imaginary part is discarded
        f_ishifted = torch.fft.ifftshift(f_augmented)
        augmented_volume = torch.fft.ifftn(f_ishifted).real

        # 4. Intensity scaling
        return augmented_volume * self._intensity_scale