Coverage for tests/test_augmentations.py: 35%

65 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-16 16:14 -0700

1"""Tests for the augmentations in copick-torch.""" 

2 

3import numpy as np 

4import pytest 

5import torch 

6 

7from copick_torch.augmentations import FourierAugment3D, MixupTransform 

8 

9 

def test_mixup_transform():
    """Test that MixupTransform produces expected outputs.

    Covers three behaviors:

    1. A randomized call returns ``(mixed_x, orig_x, mixed_idx_x, lam)`` with
       batch-shaped tensors and a float ``lam``.
    2. With ``lam`` and the permutation forced and ``randomize=False``, every
       voxel of ``mixed_x`` equals ``lam * x + (1 - lam) * x[index]``.
    3. ``MixupTransform.mixup_criterion`` weights two losses as
       ``lam * loss_a + (1 - lam) * loss_b``.
    """
    batch_size = 4
    volume_shape = (3, 8, 8, 8)  # (channels, depth, height, width)

    # Fill sample i with the constant value i so every sample is
    # distinguishable after mixing.
    x = torch.zeros((batch_size,) + volume_shape)
    for i in range(batch_size):
        x[i, :, :, :, :] = i

    mixup = MixupTransform(alpha=0.2, prob=1.0)

    # Randomized call: only shapes and the type of lambda are deterministic.
    mixed_x, orig_x, mixed_idx_x, lam = mixup(x)
    assert mixed_x.shape == x.shape
    assert orig_x.shape == x.shape
    assert mixed_idx_x.shape == x.shape
    assert isinstance(lam, float)

    # Deterministic call: force lambda and the permutation, then ask the
    # transform to reuse them instead of drawing new ones.
    forced_lam = 0.7
    forced_index = torch.tensor([1, 0, 3, 2])
    mixup.lam = forced_lam
    mixup.index = forced_index

    mixed_x, _, _, _ = mixup(x, randomize=False)

    # Expected per-sample values:
    #   sample 0 -> 0.7*0 + 0.3*1 = 0.3    sample 1 -> 0.7*1 + 0.3*0 = 0.7
    #   sample 2 -> 0.7*2 + 0.3*3 = 2.3    sample 3 -> 0.7*3 + 0.3*2 = 2.7
    # Compare whole volumes rather than a single voxel so a partial or
    # wrongly-permuted mix is caught.
    expected = forced_lam * x + (1 - forced_lam) * x[forced_index]
    assert torch.allclose(mixed_x, expected, atol=1e-6)
    assert torch.allclose(mixed_x[0], torch.full(volume_shape, 0.3), atol=1e-6)
    assert torch.allclose(mixed_x[1], torch.full(volume_shape, 0.7), atol=1e-6)

    # mixup_criterion: a toy L1 criterion with hand-computable losses.
    def dummy_criterion(pred, target):
        return torch.abs(pred - target).mean()

    pred = torch.ones((batch_size,))
    y_a = torch.zeros((batch_size,))
    y_b = torch.ones((batch_size,)) * 2

    # Expected loss: 0.7 * |1 - 0| + 0.3 * |1 - 2| = 0.7 + 0.3 = 1.0
    mixed_loss = MixupTransform.mixup_criterion(dummy_criterion, pred, y_a, y_b, forced_lam)
    # float() accepts both a 0-dim tensor and a plain Python float.
    assert abs(float(mixed_loss) - 1.0) < 1e-6

62 

63 

def test_fourier_augment3d():
    """Test that FourierAugment3D produces expected outputs.

    First, with masking, phase noise and intensity scaling enabled, the output
    must keep the input shape and differ from the input. Second, with all
    three disabled and the cached random parameters forced to neutral values,
    a non-randomized call must reduce to the identity.
    """
    volume = torch.ones((16, 16, 16))

    # Seed both RNGs, as the channel-first test does, so the random draws
    # behind the "not identity" assertion are reproducible from run to run
    # instead of occasionally landing on near-neutral parameters.
    torch.manual_seed(42)
    np.random.seed(42)

    aug = FourierAugment3D(freq_mask_prob=0.3, phase_noise_std=0.1, intensity_scaling_range=(0.8, 1.2), prob=1.0)
    augmented = aug(volume)

    assert augmented.shape == volume.shape
    # The augmentation must actually change the volume.
    assert not torch.allclose(augmented, volume, rtol=1e-3, atol=1e-3)

    # Disable every random component; the result should then be (close to)
    # the identity.
    aug = FourierAugment3D(
        freq_mask_prob=0.0,  # No masking
        phase_noise_std=0.0,  # No phase noise
        intensity_scaling_range=(1.0, 1.0),  # No intensity scaling
        prob=1.0,
    )

    # Force the cached per-call parameters to neutral values.
    # NOTE(review): these private attribute names are assumed to match
    # FourierAugment3D's internals; confirm if the implementation changes.
    aug._mask = None
    aug._phase_noise = torch.zeros_like(volume)
    aug._intensity_scale = 1.0

    # Reuse the forced parameters instead of drawing new ones.
    augmented = aug(volume, randomize=False)

    assert torch.allclose(augmented, volume, rtol=1e-3, atol=1e-3)

100 

101 

def test_fourier_augment3d_channel_first():
    """FourierAugment3D must accept channel-first (C, D, H, W) volumes.

    The output has to keep the input's shape, and at least one pair of
    channels must end up noticeably different (mean absolute difference
    above 0.01). The check does not depend on specific random values.
    """
    volume = torch.ones((3, 16, 16, 16))
    n_channels = volume.shape[0]

    # Pin both RNGs so the random draws are reproducible.
    torch.manual_seed(42)
    np.random.seed(42)

    aug = FourierAugment3D(freq_mask_prob=0.3, phase_noise_std=0.1, intensity_scaling_range=(0.8, 1.2), prob=1.0)
    augmented = aug(volume, randomize=True)

    assert augmented.shape == volume.shape

    # Mean absolute difference for every unordered pair of channels (a < b).
    pair_gaps = [
        torch.abs(augmented[a] - augmented[b]).mean().item()
        for a in range(n_channels - 1)
        for b in range(a + 1, n_channels)
    ]

    # The largest gap must be clearly above zero, i.e. channels differ.
    assert max(pair_gaps) > 0.01

132 

133 

def test_zero_probability():
    """Test that transforms with zero probability leave inputs unchanged.

    Covers FourierAugment3D (output equals input) and MixupTransform
    (``lam == 1.0`` and the mixed batch equals the input batch).
    """
    volume = torch.ones((16, 16, 16))

    # FourierAugment3D with zero probability must be a no-op.
    aug = FourierAugment3D(
        freq_mask_prob=0.3,
        phase_noise_std=0.1,
        intensity_scaling_range=(0.8, 1.2),
        prob=0.0,  # Zero probability of applying transform
    )
    augmented = aug(volume)
    assert torch.allclose(augmented, volume)

    # MixupTransform with zero probability must not mix.
    mixup = MixupTransform(alpha=0.2, prob=0.0)

    # Give each sample a distinct constant value (sample i == i). With a
    # constant batch, mixing identical samples would return the same tensor
    # for any lambda or permutation, so the final check could never fail.
    batch_size = 4
    batch = torch.zeros((batch_size, 3, 8, 8, 8))
    for i in range(batch_size):
        batch[i] = i

    mixed_x, orig_x, mixed_idx_x, lam = mixup(batch)

    # Lambda should be 1.0 (no mixing).
    assert lam == 1.0

    # The output must equal the input batch.
    assert torch.allclose(mixed_x, batch)