Coverage for copick_torch/augmentations.py: 35%

89 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-16 16:14 -0700

1""" 

2Augmentations for 3D volumes based on MONAI transform interface. 

3 

4This module provides MONAI-based implementations of augmentations for 3D tomographic data. 

5""" 

6 

7from typing import Optional, Sequence, Tuple, Union 

8 

9import numpy as np 

10import torch 

11from monai.config.type_definitions import NdarrayOrTensor 

12from monai.transforms import ( 

13 Fourier, 

14 MapTransform, 

15 RandomizableTrait, 

16 RandomizableTransform, 

17 Transform, 

18) 

19from monai.transforms.utils import Fourier as FourierUtils 

20from monai.utils import convert_data_type, convert_to_dst_type, convert_to_tensor 

21 

22 

class MixupTransform(RandomizableTransform):
    """
    Mixup augmentation for batches of 3D volumes, following the MONAI transform interface.

    Mixup creates virtual training examples by blending each sample in a batch
    with another sample of the same batch, using a mixing coefficient drawn
    from a Beta(alpha, alpha) distribution.

    All random draws (apply/skip decision, mixing coefficient, batch
    permutation) come from the transform's own random state ``self.R``, so
    ``set_random_state(seed=...)`` makes the transform reproducible.

    Reference: Zhang et al., "mixup: Beyond Empirical Risk Minimization", ICLR 2018
    https://arxiv.org/abs/1710.09412
    """

    def __init__(self, alpha: float = 0.2, prob: float = 1.0):
        """
        Initialize the Mixup augmentation.

        Args:
            alpha: Parameter of the Beta(alpha, alpha) distribution. Higher values
                result in more mixing. Values <= 0 disable mixing (lam is fixed to 1.0).
            prob: Probability of applying the transform.
        """
        RandomizableTransform.__init__(self, prob)
        self.alpha = alpha
        # Mixing coefficient used by the most recent call (1.0 means "no mixing").
        self.lam = 1.0
        # Batch permutation used by the most recent call; None until first applied.
        self.index = None

    def randomize(self, data=None) -> None:
        """
        Decide whether to apply the transform and, if so, draw the mixing coefficient.

        Args:
            data: Unused; accepted for compatibility with the MONAI ``randomize`` signature.
        """
        super().randomize(None)
        if not self._do_transform:
            return None

        if self.alpha > 0:
            # Use the transform's own RandomState (not the global numpy RNG) so
            # that set_random_state() also controls the mixing coefficient.
            self.lam = float(self.R.beta(self.alpha, self.alpha))
        else:
            self.lam = 1.0

        # Defensive clamp: keep lambda inside [0, 1].
        self.lam = min(max(self.lam, 0.0), 1.0)

    def __call__(
        self,
        img: torch.Tensor,
        randomize: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, float]:
        """
        Apply mixup augmentation to a batch of images.

        Args:
            img: Tensor of shape [batch_size, channels, depth, height, width]
            randomize: Whether to execute randomize function first, default to True.
                A new batch permutation is drawn on every applied call regardless
                of this flag; the flag only controls whether ``lam`` is re-drawn.

        Returns:
            Tuple of (mixed_images, img_a, img_b, lam) where:
                - mixed_images: ``lam * img + (1 - lam) * img[index]``
                - img_a: The input batch in its original order
                - img_b: The input batch permuted by ``self.index`` (the mixed-in samples)
                - lam: Mixing coefficient used for this call

            The permutation is stored in ``self.index`` so callers can permute
            their labels consistently (``labels[transform.index]``).
            When the transform is skipped, ``(img, img, img, 1.0)`` is returned.
        """
        if randomize:
            self.randomize()

        if not self._do_transform:
            return img, img, img, 1.0

        img = convert_to_tensor(img)
        batch_size = img.shape[0]

        # Draw the permutation from self.R as well, so the whole transform is
        # reproducible through set_random_state().
        self.index = torch.as_tensor(
            self.R.permutation(batch_size),
            dtype=torch.long,
            device=img.device,
        )

        # Index once and reuse for both the blend and the returned partner batch.
        partner = img[self.index]
        mixed_images = self.lam * img + (1 - self.lam) * partner

        return mixed_images, img, partner, self.lam

    @staticmethod
    def mixup_criterion(criterion, pred, y_a, y_b, lam):
        """
        Apply mixup to the loss calculation.

        Args:
            criterion: Loss function called as ``criterion(pred, target)``
            pred: Model predictions
            y_a: First labels
            y_b: Second (mixed-in) labels
            lam: Mixing coefficient

        Returns:
            Mixed loss: ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``
        """
        return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

117 

118 

class FourierAugment3D(RandomizableTransform, Fourier):
    """
    Fourier-based augmentation for 3D volumes, following the MONAI transform interface.

    This augmentation performs operations in the frequency domain, including
    random frequency dropout (masking), phase noise injection, and intensity scaling.

    It can help the model become more robust to various frequency distortions that
    may occur in tomographic data.

    All random draws come from the transform's own random state ``self.R``, so
    ``set_random_state(seed=...)`` makes the transform reproducible.
    """

    def __init__(
        self,
        freq_mask_prob: float = 0.3,
        phase_noise_std: float = 0.1,
        intensity_scaling_range: Tuple[float, float] = (0.8, 1.2),
        prob: float = 1.0,
    ) -> None:
        """
        Initialize the Fourier domain augmentation.

        Args:
            freq_mask_prob: Used twice: as the probability that a frequency mask is
                generated at all, and as the per-component dropout probability
                inside that mask
            phase_noise_std: Standard deviation of Gaussian noise added to the phase
            intensity_scaling_range: Range for random intensity scaling (min, max)
            prob: Probability of applying the transform
        """
        RandomizableTransform.__init__(self, prob)
        self.freq_mask_prob = freq_mask_prob
        self.phase_noise_std = phase_noise_std
        self.intensity_scaling_range = intensity_scaling_range

        # Randomized parameters; None means "not drawn" and is treated as identity.
        self._mask = None
        self._phase_noise = None
        self._intensity_scale = None

    def randomize(self, spatial_shape=None) -> None:
        """
        Decide whether to apply the transform and, if so, draw its parameters.

        Args:
            spatial_shape: Shape of the volume the parameters will be applied to.
                If None, only the apply/skip decision is made and the stored
                parameters are left untouched.
        """
        super().randomize(None)
        if not self._do_transform or spatial_shape is None:
            return None

        self._sample_params(spatial_shape)

    def _sample_params(self, spatial_shape) -> None:
        """
        Draw mask, phase noise and intensity scale for a volume of ``spatial_shape``.

        Unlike ``randomize``, this never touches ``_do_transform``, so it is safe
        to call repeatedly (e.g. once per channel) after the apply/skip decision
        has been made.
        """
        shape = tuple(spatial_shape)

        # Frequency mask: True keeps a component, False drops it.
        if self.R.rand() < self.freq_mask_prob:
            self._mask = torch.as_tensor(self.R.random_sample(shape) > self.freq_mask_prob)
        else:
            self._mask = None

        # Additive Gaussian phase noise.
        self._phase_noise = (
            torch.as_tensor(self.R.standard_normal(shape), dtype=torch.float32) * self.phase_noise_std
        )

        # Global intensity scaling factor.
        self._intensity_scale = self.R.uniform(
            low=self.intensity_scaling_range[0],
            high=self.intensity_scaling_range[1],
        )

    def __call__(self, volume: torch.Tensor, randomize: bool = True) -> torch.Tensor:
        """
        Apply Fourier domain augmentation to a volume.

        Args:
            volume: Tensor of shape [depth, height, width] or [channels, depth, height, width]
            randomize: Whether to execute randomize function first, default to True.
                When False, the previously drawn parameters are reused (and shared
                by all channels).

        Returns:
            Augmented volume with same shape as input. The input is returned
            unchanged when the transform is skipped.
        """
        is_channel_first = len(volume.shape) == 4

        if randomize:
            if is_channel_first:
                # Only make the apply/skip decision here; parameters are drawn
                # per channel below, so sampling them now would be wasted work.
                self.randomize(None)
            else:
                input_shape = volume.shape
                spatial_shape = input_shape if len(input_shape) == 3 else input_shape[1:]
                self.randomize(spatial_shape)

        if not self._do_transform:
            return volume

        # Ensure volume is a torch tensor
        volume = convert_to_tensor(volume)

        if not is_channel_first:
            # Process 3D volume directly
            return self._apply_fourier_aug(volume)

        # Process each channel independently with its own random parameters.
        # _sample_params (not randomize) is used on purpose: re-running
        # randomize would re-roll _do_transform per channel, which could leave
        # a channel augmented with stale parameters from the previous one.
        output = []
        for channel_volume in volume:
            if randomize:
                self._sample_params(channel_volume.shape)
            output.append(self._apply_fourier_aug(channel_volume))
        return torch.stack(output)

    def _apply_fourier_aug(self, vol_tensor: torch.Tensor) -> torch.Tensor:
        """
        Apply Fourier augmentation to a single tensor (no channels).

        Any parameter that has not been drawn yet (still None) is treated as
        identity, so calling this before ``randomize`` does not raise.

        Args:
            vol_tensor: 3D tensor of shape [depth, height, width]

        Returns:
            Augmented tensor of same shape
        """
        device = vol_tensor.device

        # Move randomized parameters to the same device as the data
        mask = self._mask.to(device) if self._mask is not None else None
        phase_noise = self._phase_noise.to(device) if self._phase_noise is not None else None

        # FFT, with the zero-frequency component shifted to the center
        f_shifted = torch.fft.fftshift(torch.fft.fftn(vol_tensor))

        # Magnitude and phase
        magnitude = torch.abs(f_shifted)
        phase = torch.angle(f_shifted)

        # 1. Random frequency dropout (mask)
        if mask is not None:
            magnitude = magnitude * mask

        # 2. Random phase noise
        if phase_noise is not None:
            phase = phase + phase_noise

        # 3. Combine magnitude and noisy phase
        f_augmented = torch.complex(magnitude * torch.cos(phase), magnitude * torch.sin(phase))

        # IFFT; the perturbed spectrum is generally not Hermitian, so the
        # imaginary part of the result is discarded.
        augmented_volume = torch.fft.ifftn(torch.fft.ifftshift(f_augmented)).real

        # 4. Intensity scaling
        if self._intensity_scale is not None:
            augmented_volume = augmented_volume * self._intensity_scale

        return augmented_volume