Coverage for tests/test_dataset_extraction.py: 0%

67 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-16 16:14 -0700

1import os 

2import shutil 

3import tempfile 

4import unittest 

5from unittest.mock import patch 

6 

7import numpy as np 

8 

9from copick_torch import SimpleCopickDataset 

10 

11 

class TestDatasetExtraction(unittest.TestCase):
    """Test the subvolume extraction functionality of SimpleCopickDataset.

    Each test patches ``SimpleCopickDataset._load_data`` and then calls
    ``_extract_subvolume_with_validation`` directly on a synthetic
    32x32x32 tomogram. The config path handed to the dataset points at a
    file that is never written, so the tests presumably rely on the patch
    to keep the constructor from reading it (confirm against the dataset
    implementation).
    """

    def setUp(self):
        """Create a scratch directory and a synthetic gradient tomogram."""
        # Create a temporary directory for testing
        self.test_dir = tempfile.mkdtemp()
        # Only the path is used; no config file is created at this location.
        self.mock_config_path = os.path.join(self.test_dir, "mock_config.json")

        # Define box size for testing
        self.boxsize = (16, 16, 16)

        # Create a test tomogram with a gradient pattern so that it is not
        # uniform: every voxel holds x + y + z with each coordinate in [0, 1].
        # Note that voxel [0, 0, 0] is therefore exactly 0.
        axis = np.linspace(0, 1, 32)
        x, y, z = np.meshgrid(axis, axis, axis)
        self.tomogram_array = np.ones((32, 32, 32)) * (x + y + z)

    def tearDown(self):
        """Remove the scratch directory created in ``setUp``."""
        shutil.rmtree(self.test_dir)

    def _make_dataset(self, patch_strategy):
        """Build a dataset with the shared config path and box size.

        Must be called from inside a test method so that the ``_load_data``
        patch applied by the test's decorator is active.
        """
        return SimpleCopickDataset(
            config_path=self.mock_config_path,
            boxsize=self.boxsize,
            patch_strategy=patch_strategy,
        )

    @patch("copick_torch.dataset.SimpleCopickDataset._load_data")
    def test_extract_center_valid(self, mock_load_data):
        """Test extracting a valid subvolume from the center of the tomogram."""
        dataset = self._make_dataset("centered")

        # Extract from center of tomogram (should be valid)
        subvolume, is_valid, status = dataset._extract_subvolume_with_validation(self.tomogram_array, 16, 16, 16)

        # Check results
        self.assertTrue(is_valid)
        self.assertEqual(status, "valid")
        self.assertEqual(subvolume.shape, self.boxsize)

        # Since we extracted from the center, values should match the source tomogram
        center_slice = self.tomogram_array[8:24, 8:24, 8:24]
        np.testing.assert_array_equal(subvolume, center_slice)

    @patch("copick_torch.dataset.SimpleCopickDataset._load_data")
    def test_extract_edge_padded(self, mock_load_data):
        """Test extracting a subvolume near the edge of the tomogram (requires padding)."""
        dataset = self._make_dataset("centered")

        # Extract from near edge of tomogram (should require padding)
        subvolume, is_valid, status = dataset._extract_subvolume_with_validation(self.tomogram_array, 3, 16, 16)

        # Check results
        self.assertTrue(is_valid)
        self.assertEqual(status, "padded")
        self.assertEqual(subvolume.shape, self.boxsize)

        # Check that the extracted subvolume has some zeros (from padding)
        self.assertTrue(np.any(subvolume == 0))

    @patch("copick_torch.dataset.SimpleCopickDataset._load_data")
    def test_extract_near_edge(self, mock_load_data):
        """Test extracting a subvolume very close to the edge of the tomogram."""
        dataset = self._make_dataset("centered")

        # Try to extract from positions that are technically valid but will need padding
        subvolume, is_valid, status = dataset._extract_subvolume_with_validation(
            self.tomogram_array,
            2,
            2,
            2,  # Very close to the edge (0,0,0)
        )

        # The actual implementation pads rather than invalidates, so check for padding
        self.assertTrue(is_valid)
        self.assertEqual(status, "padded")
        self.assertEqual(subvolume.shape, self.boxsize)

        # Should contain zeros from padding
        # NOTE(review): tomogram voxel [0, 0, 0] is itself 0, so if the
        # extracted box covers that voxel this check passes even without any
        # padding; the "padded" status assertion above is the stronger check.
        self.assertTrue(np.any(subvolume == 0))

    @patch("copick_torch.dataset.SimpleCopickDataset._load_data")
    def test_random_strategy(self, mock_load_data):
        """Test the random patch extraction strategy."""
        dataset = self._make_dataset("random")

        # Set a fixed seed for reproducibility
        # (assumes the dataset draws from NumPy's global RNG -- TODO confirm)
        np.random.seed(42)

        # Extract with random strategy
        subvolume, is_valid, _ = dataset._extract_subvolume_with_validation(self.tomogram_array, 16, 16, 16)

        # Check results
        self.assertTrue(is_valid)
        self.assertEqual(subvolume.shape, self.boxsize)

        # Extract again with same seed
        np.random.seed(42)
        subvolume2, _, _ = dataset._extract_subvolume_with_validation(self.tomogram_array, 16, 16, 16)

        # Both extractions should be identical with the same seed
        np.testing.assert_array_equal(subvolume, subvolume2)

    @patch("copick_torch.dataset.SimpleCopickDataset._load_data")
    def test_jittered_strategy(self, mock_load_data):
        """Test the jittered patch extraction strategy."""
        dataset = self._make_dataset("jittered")

        # Set a fixed seed for reproducibility
        np.random.seed(42)

        # Extract with jittered strategy
        subvolume, is_valid, _ = dataset._extract_subvolume_with_validation(self.tomogram_array, 16, 16, 16)

        # Check results
        self.assertTrue(is_valid)
        self.assertEqual(subvolume.shape, self.boxsize)

        # Extract with centered strategy for comparison
        centered_dataset = self._make_dataset("centered")
        centered_subvolume, _, _ = centered_dataset._extract_subvolume_with_validation(
            self.tomogram_array, 16, 16, 16
        )

        # Jittered should be different from centered (small chance they're identical)
        # This might rarely fail if the random jitter is (0,0,0)
        # NOTE(review): the fallback below is only reached when the two arrays
        # compare equal, in which case the matching fraction is 1.0 and the
        # >90% check always passes. As written, this comparison cannot fail;
        # it only documents the intent that jitter should change the patch.
        try:
            np.testing.assert_raises(AssertionError, np.testing.assert_array_equal, subvolume, centered_subvolume)
        except AssertionError:
            # If the above fails, check that we had a very small jitter
            # by verifying most values are the same
            same_values = np.count_nonzero(subvolume == centered_subvolume)
            total_values = np.prod(self.boxsize)
            self.assertGreater(same_values / total_values, 0.9)  # >90% same

160 

161 

if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()