Coverage for tests/test_simple_dataset.py: 0%

116 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-16 16:14 -0700

1import os 

2import shutil 

3import tempfile 

4import unittest 

5from unittest.mock import MagicMock, patch 

6 

7import numpy as np 

8import torch 

9 

10from copick_torch import SimpleCopickDataset, SimpleDatasetMixin 

11 

12 

class TestSimpleDatasetMixin(unittest.TestCase):
    """Test the SimpleDatasetMixin functionality.

    The mixin is attached to a throwaway class and given only the attributes
    these tests set by hand (``_subvolumes``, ``_molecule_ids``, ``augment``
    and ``_augment_subvolume``).
    """

    def setUp(self):
        """Build a one-sample stand-in dataset that carries just the mixin."""
        self.test_dataset = type("TestDataset", (SimpleDatasetMixin, object), {})()

        # Minimal state for the mixin: a single 16x16x16 volume of ones
        # labelled with molecule id 0, augmentation switched off by default.
        self.test_dataset._subvolumes = [np.ones((16, 16, 16))]
        self.test_dataset._molecule_ids = [0]
        self.test_dataset.augment = False

    def test_getitem(self):
        """Indexing returns a ``(tensor, molecule_id)`` pair with a channel axis."""
        # Identity stand-in for the augmentation hook.
        self.test_dataset._augment_subvolume = lambda subvol, idx: subvol

        subvolume, molecule_idx = self.test_dataset[0]

        self.assertIsInstance(subvolume, torch.Tensor)
        self.assertEqual(molecule_idx, 0)

        # The 16^3 input is expected to gain a leading channel dimension.
        self.assertEqual(subvolume.shape, (1, 16, 16, 16))

    def test_getitem_with_augmentation(self):
        """With ``augment=True`` the augmentation hook is actually invoked."""
        self.test_dataset.augment = True

        # A MagicMock (rather than a bare lambda) records the call, so the
        # test can prove the hook ran. Output values are not compared: per
        # the original author's note, normalization rescales them afterwards.
        augment_hook = MagicMock(side_effect=lambda subvol, idx: subvol * 2)
        self.test_dataset._augment_subvolume = augment_hook

        subvolume, molecule_idx = self.test_dataset[0]

        augment_hook.assert_called_once()
        self.assertIsInstance(subvolume, torch.Tensor)
        # The stand-in hook preserves shape, so the result must match the
        # non-augmented case.
        self.assertEqual(subvolume.shape, (1, 16, 16, 16))
        self.assertEqual(molecule_idx, 0)

54 

55 

class TestSimpleCopickDataset(unittest.TestCase):
    """Test the SimpleCopickDataset class.

    Every test replaces ``SimpleCopickDataset._load_data`` with a mock and
    then sets the private attributes it needs by hand.
    """

    def setUp(self):
        """Create a scratch directory tree and the shared constructor inputs."""
        self.test_dir = tempfile.mkdtemp()
        # Register removal right away: unlike tearDown, cleanups still run
        # when a later statement in setUp raises, so the directory never leaks.
        self.addCleanup(shutil.rmtree, self.test_dir)

        self.cache_dir = os.path.join(self.test_dir, "cache")
        os.makedirs(self.cache_dir, exist_ok=True)

        # Path only; no file is written here.
        self.mock_config_path = os.path.join(self.test_dir, "mock_config.json")

        # Parameters for testing
        self.boxsize = (16, 16, 16)
        self.voxel_spacing = 10.0

    def _make_dataset(self, **options):
        """Return a SimpleCopickDataset built from the shared path and box size.

        Args:
            **options: Extra keyword arguments forwarded to the constructor.
        """
        return SimpleCopickDataset(
            config_path=self.mock_config_path,
            boxsize=self.boxsize,
            **options,
        )

    @patch("copick_torch.dataset.SimpleCopickDataset._load_data")
    def test_init_basic(self, mock_load_data):
        """Test basic initialization of SimpleCopickDataset."""
        dataset = self._make_dataset(cache_dir=None)  # Don't use caching

        self.assertEqual(dataset.config_path, self.mock_config_path)
        self.assertEqual(dataset.boxsize, self.boxsize)
        self.assertFalse(dataset.augment)
        self.assertIsNone(dataset.cache_dir)

        # Construction is expected to trigger exactly one data load.
        mock_load_data.assert_called_once()

    @patch("copick_torch.dataset.SimpleCopickDataset._load_data")
    def test_init_with_options(self, mock_load_data):
        """Test that every constructor option is stored on the instance."""
        # One source of truth for both the call and the checks, so an option
        # cannot be passed without also being verified.
        options = {
            "augment": True,
            "cache_dir": self.cache_dir,
            "cache_format": "parquet",
            "seed": 42,
            "max_samples": 100,
            "voxel_spacing": 5.0,
            "include_background": True,
            "background_ratio": 0.3,
            "min_background_distance": 20.0,
            "patch_strategy": "random",
            "debug_mode": True,
        }
        dataset = self._make_dataset(**options)

        self.assertEqual(dataset.config_path, self.mock_config_path)
        self.assertEqual(dataset.boxsize, self.boxsize)
        for name, expected in options.items():
            with self.subTest(option=name):
                self.assertEqual(getattr(dataset, name), expected)

    @patch("copick_torch.dataset.SimpleCopickDataset._load_data")
    def test_dataset_empty(self, mock_load_data):
        """Test behavior with empty dataset."""
        dataset = self._make_dataset()

        # Simulate a load that produced no samples.
        dataset._subvolumes = np.array([])
        dataset._molecule_ids = np.array([])
        dataset._is_background = np.array([])
        dataset._keys = []

        self.assertEqual(len(dataset), 0)
        self.assertEqual(dataset.get_class_distribution(), {})

    @patch("copick_torch.dataset.SimpleCopickDataset._load_data")
    def test_compute_sample_weights(self, mock_load_data):
        """Test the _compute_sample_weights method."""
        dataset = self._make_dataset()

        # Unbalanced labels: three of class 0, two of class 1, one of class 2.
        dataset._molecule_ids = [0, 0, 0, 1, 1, 2]

        dataset._compute_sample_weights()

        # Expected weight per sample is total_samples / count_of_its_class.
        expected_weights = [6 / 3, 6 / 3, 6 / 3, 6 / 2, 6 / 2, 6 / 1]
        np.testing.assert_array_almost_equal(dataset.sample_weights, expected_weights)

    @patch("copick_torch.dataset.SimpleCopickDataset._load_data")
    def test_get_sample_weights(self, mock_load_data):
        """Test the get_sample_weights method."""
        dataset = self._make_dataset()
        dataset.sample_weights = [1.0, 2.0, 3.0]

        self.assertEqual(dataset.get_sample_weights(), [1.0, 2.0, 3.0])

    @patch("copick_torch.dataset.SimpleCopickDataset._load_data")
    def test_keys(self, mock_load_data):
        """Test the keys method."""
        dataset = self._make_dataset()
        dataset._keys = ["class1", "class2", "class3"]

        self.assertEqual(dataset.keys(), ["class1", "class2", "class3"])

    @patch("copick_torch.dataset.SimpleCopickDataset._load_data")
    def test_get_class_distribution(self, mock_load_data):
        """Test the get_class_distribution method."""
        dataset = self._make_dataset()

        dataset._keys = ["class1", "class2", "class3"]
        dataset._molecule_ids = [0, 0, 0, 1, 1, 2, -1, -1]  # -1 is background
        dataset._is_background = [False, False, False, False, False, False, True, True]

        expected_distribution = {"class1": 3, "class2": 2, "class3": 1, "background": 2}
        self.assertEqual(dataset.get_class_distribution(), expected_distribution)

    @patch("copick_torch.dataset.SimpleCopickDataset._load_data")
    def test_validation_logic_missing_config(self, mock_load_data):
        """Test validation logic when both config_path and copick_root are missing."""
        with self.assertRaises(ValueError):
            SimpleCopickDataset(config_path=None, copick_root=None, boxsize=self.boxsize)

    @patch("copick_torch.dataset.SimpleCopickDataset._extract_subvolume_with_validation")
    @patch("copick_torch.dataset.SimpleCopickDataset._load_data")
    def test_extract_subvolume_strategies(self, mock_load_data, mock_extract):
        """Test that a dataset can be built and queried for each patch strategy.

        NOTE(review): ``_extract_subvolume_with_validation`` is itself patched,
        so the call below only reaches the mock. This confirms construction
        succeeds for each strategy but does not exercise the real extraction
        code - consider calling the unpatched method instead.
        """
        tomogram_array = np.zeros((32, 32, 32))

        for strategy in ("centered", "random", "jittered"):
            with self.subTest(patch_strategy=strategy):
                # Start each strategy from a clean call record.
                mock_extract.reset_mock()

                dataset = self._make_dataset(patch_strategy=strategy)
                self.assertEqual(dataset.patch_strategy, strategy)

                dataset._extract_subvolume_with_validation(tomogram_array, 16, 16, 16)
                mock_extract.assert_called_once()

251 

252 

if __name__ == "__main__":
    # Executing this file directly (rather than importing it) hands control
    # to unittest's command-line runner.
    unittest.main()