Coverage for tests/test_background_sampling.py: 0%

86 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-16 16:14 -0700

1import os 

2import shutil 

3import tempfile 

4import unittest 

5from unittest.mock import MagicMock, patch 

6 

7import numpy as np 

8 

9from copick_torch import SimpleCopickDataset 

10 

11 

class TestBackgroundSampling(unittest.TestCase):
    """Test the background sampling functionality of SimpleCopickDataset.

    Most tests patch ``SimpleCopickDataset._load_data`` so a dataset object can
    be constructed from a config path that does not exist on disk, and then
    call ``_sample_background_points`` directly on synthetic inputs.
    """

    def setUp(self):
        """Create a scratch directory plus the synthetic tomogram and picks."""
        # Scratch directory; only used to build a config path (no file is written).
        self.test_dir = tempfile.mkdtemp()
        self.mock_config_path = os.path.join(self.test_dir, "mock_config.json")

        # Box size shared by every dataset built in these tests.
        self.boxsize = (16, 16, 16)

        # Empty test tomogram.
        self.tomogram_array = np.zeros((64, 64, 64))

        # A few "particle" positions so distance checks have something to avoid.
        self.particle_coords = [
            (16, 16, 16),
            (40, 40, 40),
            (16, 40, 40),
        ]

    def tearDown(self):
        """Remove the scratch directory created in ``setUp``."""
        # ignore_errors: cleanup must not turn a passing test into an error
        # if the directory has already been removed.
        shutil.rmtree(self.test_dir, ignore_errors=True)

    def _make_dataset(self, **kwargs):
        """Build a background-enabled dataset with empty sample bookkeeping.

        Must be called while ``_load_data`` is patched, because the config
        path handed to the constructor does not point at a real file.

        Args:
            **kwargs: Extra keyword arguments forwarded to ``SimpleCopickDataset``.

        Returns:
            The dataset, with ``_subvolumes``, ``_molecule_ids`` and
            ``_is_background`` reset to empty lists.
        """
        dataset = SimpleCopickDataset(
            config_path=self.mock_config_path,
            boxsize=self.boxsize,
            include_background=True,
            **kwargs,
        )
        dataset._subvolumes = []
        dataset._molecule_ids = []
        dataset._is_background = []
        return dataset

    def _stub_extraction(self, dataset, *, valid):
        """Replace ``_extract_subvolume_with_validation`` on *dataset* with a stub.

        Args:
            dataset: The dataset instance to modify.
            valid: When true the stub returns a fresh zero subvolume flagged as
                valid; otherwise it reports a validation failure.
        """

        def fake_extract(*args, **kwargs):
            if valid:
                return np.zeros(self.boxsize), True, "valid"
            return None, False, "Invalid slice range"

        dataset._extract_subvolume_with_validation = fake_extract

    @patch("copick_torch.dataset.SimpleCopickDataset._load_data")
    def test_sample_background_points(self, mock_load_data):
        """Test the _sample_background_points method."""
        dataset = self._make_dataset(
            background_ratio=0.5,  # One background sample for every two particles
            min_background_distance=20.0,  # Stay at least 20 units away from particles
            patch_strategy="centered",
        )
        self._stub_extraction(dataset, valid=True)

        # Set a fixed seed for reproducibility
        np.random.seed(42)

        dataset._sample_background_points(self.tomogram_array, self.particle_coords)

        # Check that background samples were added
        # Should add background_ratio * len(particle_coords) samples = 0.5 * 3 = 1 or 2
        self.assertGreater(len(dataset._subvolumes), 0)
        self.assertGreater(len(dataset._molecule_ids), 0)
        self.assertGreater(len(dataset._is_background), 0)

        # Check that all added samples are marked as background
        self.assertTrue(all(dataset._is_background))

        # Check that all added samples have molecule_id = -1
        self.assertTrue(all(mol_id == -1 for mol_id in dataset._molecule_ids))

    @patch("copick_torch.dataset.SimpleCopickDataset._load_data")
    def test_sample_background_points_no_particles(self, mock_load_data):
        """Test the _sample_background_points method with no particles."""
        dataset = self._make_dataset(background_ratio=0.5)

        # Sample background points with no particles
        dataset._sample_background_points(self.tomogram_array, [])

        # Check that no background samples were added
        self.assertEqual(len(dataset._subvolumes), 0)
        self.assertEqual(len(dataset._molecule_ids), 0)
        self.assertEqual(len(dataset._is_background), 0)

    @patch("copick_torch.dataset.SimpleCopickDataset._load_data")
    def test_sample_background_fails_validation(self, mock_load_data):
        """Test when background samples fail validation."""
        dataset = self._make_dataset(
            background_ratio=1.0,  # One background sample for each particle
            min_background_distance=20.0,
        )
        # Every extraction attempt reports a validation failure.
        self._stub_extraction(dataset, valid=False)

        dataset._sample_background_points(self.tomogram_array, self.particle_coords)

        # Check that no background samples were added
        self.assertEqual(len(dataset._subvolumes), 0)
        self.assertEqual(len(dataset._molecule_ids), 0)
        self.assertEqual(len(dataset._is_background), 0)

    @patch("copick_torch.dataset.SimpleCopickDataset._load_data")
    def test_sample_background_distance_constraint(self, mock_load_data):
        """Test that sampling terminates under a very strict distance constraint."""
        min_distance = 100.0  # Very large distance requirement
        dataset = self._make_dataset(
            background_ratio=1.0,
            min_background_distance=min_distance,  # Very strict distance requirement
        )
        self._stub_extraction(dataset, valid=True)

        # Set a fixed seed for reproducibility
        np.random.seed(42)

        # With such a strict constraint the sampler is expected to give up
        # (presumably via an attempt limit) rather than loop forever.
        dataset._sample_background_points(self.tomogram_array, self.particle_coords)

        # Reaching this point means sampling terminated. Few or no samples are
        # expected; whatever was added must keep the bookkeeping lists in sync
        # and be labelled as background.
        n_samples = len(dataset._subvolumes)
        self.assertEqual(len(dataset._molecule_ids), n_samples)
        self.assertEqual(len(dataset._is_background), n_samples)
        self.assertTrue(all(dataset._is_background))
        self.assertTrue(all(mol_id == -1 for mol_id in dataset._molecule_ids))

    def test_include_background_in_load_data(self):
        """Test that _load_data calls _sample_background_points when include_background=True."""
        # Create a mock copick root
        mock_root = MagicMock()
        mock_run = MagicMock()
        mock_voxel_spacing = MagicMock()
        mock_tomogram = MagicMock()

        # Configure the mocks
        mock_root.runs = [mock_run]
        mock_run.name = "mock_run"
        mock_run.get_voxel_spacing.return_value = mock_voxel_spacing
        mock_voxel_spacing.tomograms = [mock_tomogram]
        mock_tomogram.numpy.return_value = self.tomogram_array

        # Create mock picks
        mock_picks = MagicMock()
        mock_picks.from_tool = True
        mock_picks.pickable_object_name = "test_object"
        mock_picks.numpy.return_value = (np.array([[16, 16, 16]]), None)
        mock_run.get_picks.return_value = [mock_picks]

        # Patch only the sampler so the real _load_data runs during construction.
        with patch("copick_torch.dataset.SimpleCopickDataset._sample_background_points") as mock_sample_bg:
            _ = SimpleCopickDataset(
                config_path=None,
                copick_root=mock_root,
                boxsize=self.boxsize,
                include_background=True,
                background_ratio=0.5,
                cache_dir=None,  # Disable caching to ensure _load_data runs
            )

            # The _load_data method should have been called during initialization
            # and _sample_background_points should be called from within it
            mock_sample_bg.assert_called()

206 

207 

# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()