Coverage for tests/test_augmentations.py: 35%
65 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-16 16:14 -0700
1"""Tests for the augmentations in copick-torch."""
3import numpy as np
4import pytest
5import torch
7from copick_torch.augmentations import FourierAugment3D, MixupTransform
def test_mixup_transform():
    """Test that MixupTransform produces expected outputs.

    Covers three things:

    * a randomized call returns tensors shaped like the input and a float lambda;
    * with a forced lambda and permutation (``randomize=False``) every sample is
      mixed as ``lam * x[i] + (1 - lam) * x[index[i]]``;
    * ``MixupTransform.mixup_criterion`` weights two losses by ``lam``.
    """
    # Create a batch of simple test volumes
    batch_size = 4
    volume_shape = (3, 8, 8, 8)  # (channels, depth, height, width)

    # Every voxel of sample i equals i, so each sample is a constant volume and
    # any cross-sample mixing produces a predictable constant value.
    x = torch.zeros((batch_size,) + volume_shape)
    for i in range(batch_size):
        x[i] = i

    mixup = MixupTransform(alpha=0.2, prob=1.0)

    # Randomized call: only shapes and the lambda type are deterministic.
    mixed_x, orig_x, mixed_idx_x, lam = mixup(x)
    assert mixed_x.shape == x.shape
    assert orig_x.shape == x.shape
    assert mixed_idx_x.shape == x.shape
    assert isinstance(lam, float)

    # Force lambda and the permutation, then apply without re-randomizing.
    forced_lam = 0.7
    perm = [1, 0, 3, 2]  # swaps samples 0<->1 and 2<->3
    mixup.lam = forced_lam
    mixup.index = torch.tensor(perm)
    mixed_x, _, _, _ = mixup(x, randomize=False)

    # Sample i must be forced_lam * i + (1 - forced_lam) * perm[i] everywhere,
    # i.e. 0.3, 0.7, 2.3, 2.7 for the four samples. Checking every sample (and
    # the whole volume) replaces a former assertion that compared two unrelated
    # constants with atol=1.0 and therefore could never fail.
    for i, j in enumerate(perm):
        expected = forced_lam * i + (1 - forced_lam) * j
        assert torch.allclose(mixed_x[i], torch.full_like(mixed_x[i], expected), atol=1e-6), (
            f"sample {i}: expected {expected}"
        )

    # Test mixup_criterion with a simple L1 criterion
    def dummy_criterion(pred, target):
        return torch.abs(pred - target).mean()

    pred = torch.ones((batch_size,))
    y_a = torch.zeros((batch_size,))
    y_b = torch.ones((batch_size,)) * 2
    lam = 0.7

    # Expected loss: 0.7 * |1-0| + 0.3 * |1-2| = 0.7 + 0.3 = 1.0
    mixed_loss = MixupTransform.mixup_criterion(dummy_criterion, pred, y_a, y_b, lam)
    assert torch.isclose(mixed_loss, torch.tensor(1.0), atol=1e-6)
def test_fourier_augment3d():
    """Test that FourierAugment3D produces expected outputs.

    First checks that a randomized augmentation keeps the shape and changes a
    constant volume. Then checks that, with masking, phase noise and intensity
    scaling all disabled and the internal parameters forced, the transform is
    (numerically) the identity.
    """
    volume = torch.ones((16, 16, 16))

    # Seed both RNGs so the "augmentation changed the volume" assertion below is
    # deterministic rather than depending on an unseeded random draw (matches
    # the seeding used by the channel-first test).
    torch.manual_seed(42)
    np.random.seed(42)

    aug = FourierAugment3D(freq_mask_prob=0.3, phase_noise_std=0.1, intensity_scaling_range=(0.8, 1.2), prob=1.0)
    augmented = aug(volume)

    # Check shape preservation
    assert augmented.shape == volume.shape

    # Make sure the augmentation changed the volume (not identity)
    assert not torch.allclose(augmented, volume, rtol=1e-3, atol=1e-3)

    # With no masking, no phase noise and unit intensity scaling the transform
    # should reduce to (approximately) the identity.
    aug = FourierAugment3D(
        freq_mask_prob=0.0,  # No masking
        phase_noise_std=0.0,  # No phase noise
        intensity_scaling_range=(1.0, 1.0),  # No intensity scaling
        prob=1.0,
    )

    # Force the internal parameters that randomize=False will reuse.
    aug._mask = None
    aug._phase_noise = torch.zeros_like(volume)
    aug._intensity_scale = 1.0

    augmented = aug(volume, randomize=False)

    # Should be identity transform
    assert torch.allclose(augmented, volume, rtol=1e-3, atol=1e-3)
def test_fourier_augment3d_channel_first():
    """Check FourierAugment3D on a channel-first ``(C, D, H, W)`` volume.

    With fixed seeds, the output must keep the input shape, and at least one
    pair of channels must differ by a mean absolute difference above 0.01.
    In other words, the channels must not all be transformed identically.
    """
    volume = torch.ones((3, 16, 16, 16))

    # Seed both RNGs before building the augmentation so the draw is reproducible.
    torch.manual_seed(42)
    np.random.seed(42)

    aug = FourierAugment3D(freq_mask_prob=0.3, phase_noise_std=0.1, intensity_scaling_range=(0.8, 1.2), prob=1.0)
    augmented = aug(volume, randomize=True)

    assert augmented.shape == volume.shape

    # Mean absolute difference for every unordered channel pair (a, b) with a < b.
    # This avoids asserting on specific random values.
    n_chan = volume.shape[0]
    pair_diffs = [
        torch.abs(augmented[a] - augmented[b]).mean().item()
        for a in range(n_chan)
        for b in range(a + 1, n_chan)
    ]

    assert max(pair_diffs) > 0.01
def test_zero_probability():
    """Test that transforms with zero probability leave inputs unchanged.

    Covers FourierAugment3D and MixupTransform constructed with ``prob=0.0``.
    """
    volume = torch.ones((16, 16, 16))

    aug = FourierAugment3D(
        freq_mask_prob=0.3,
        phase_noise_std=0.1,
        intensity_scaling_range=(0.8, 1.2),
        prob=0.0,  # Zero probability of applying transform
    )
    augmented = aug(volume)

    # Check the shape explicitly: torch.allclose broadcasts, so it alone would
    # accept an output with extra singleton dimensions.
    assert augmented.shape == volume.shape
    assert torch.allclose(augmented, volume)

    mixup = MixupTransform(alpha=0.2, prob=0.0)

    # Give every sample a distinct constant value. With an all-ones batch,
    # mixing any sample with any other still yields ones, so the identity check
    # below could never detect unwanted mixing.
    batch = torch.ones((4, 3, 8, 8, 8))
    for i in range(batch.shape[0]):
        batch[i] = i + 1
    # Snapshot the input in case the transform modifies its argument in place.
    expected = batch.clone()

    mixed_x, _, _, lam = mixup(batch)

    # Lambda should be 1.0 (no mixing)
    assert lam == 1.0
    assert mixed_x.shape == expected.shape
    assert torch.allclose(mixed_x, expected)