Coverage for tests/test_copick_data_portal_distribution.py: 0%

78 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-16 16:14 -0700

1import json 

2import os 

3import shutil 

4import tempfile 

5import unittest 

6from collections import Counter 

7from unittest.mock import MagicMock, patch 

8 

9import copick 

10import numpy as np 

11 

12from copick_torch import SimpleCopickDataset 

13 

14 

class TestCopickDataPortalDistribution(unittest.TestCase):
    """
    Test that verifies the SimpleCopickDataset correctly preserves the
    distribution of pickable objects from the CryoET Data Portal.
    """

    @classmethod
    def setUpClass(cls):
        """Create a temporary workspace and write the Copick config file into it."""
        # Temporary directory that holds the config file(s) and a cache folder.
        cls.temp_dir = tempfile.mkdtemp()
        cls.cache_dir = os.path.join(cls.temp_dir, "cache")
        os.makedirs(cls.cache_dir, exist_ok=True)

        # Test dataset ID from the CryoET Data Portal.
        cls.dataset_id = 10440
        cls.overlay_root = "./overlay"

        cls.config_path = os.path.join(cls.temp_dir, "test_config.json")

        # (name, go_id, label, color) for every pickable object. All entries
        # share is_particle=True and radius=50.0, so only the varying fields
        # are tabulated here.
        object_specs = [
            ("cytosolic-ribosome", "GO:0022626", 1, [0, 255, 0, 255]),
            ("beta-amylase", "UniProtKB:P10537", 2, [255, 0, 255, 255]),
            ("thyroglobulin", "UniProtKB:P01267", 3, [0, 127, 255, 255]),
            ("virus-like-capsid", "GO:0170047", 4, [255, 127, 0, 255]),
            ("ferritin-complex", "GO:0070288", 5, [127, 191, 127, 255]),
            ("beta-galactosidase", "UniProtKB:P00722", 6, [94, 6, 164, 255]),
        ]

        config = {
            "config_type": "cryoet_data_portal",
            "name": "Test Dataset",
            "description": "Test Dataset for distribution verification",
            "version": "1.0.0",
            "overlay_root": cls.overlay_root,
            "overlay_fs_args": {"auto_mkdir": True},
            "dataset_ids": [cls.dataset_id],
            "pickable_objects": [
                {
                    "name": name,
                    "go_id": go_id,
                    "is_particle": True,
                    "label": label,
                    "color": color,
                    "radius": 50.0,
                }
                for name, go_id, label, color in object_specs
            ],
        }

        with open(cls.config_path, "w") as f:
            json.dump(config, f)

    @classmethod
    def tearDownClass(cls):
        """Remove the temporary workspace created by setUpClass."""
        # Best-effort cleanup: a leftover temp file should not turn a passing
        # test class into an error.
        shutil.rmtree(cls.temp_dir, ignore_errors=True)

    @staticmethod
    def _first_run_with_picks(runs):
        """Return the first run that has at least one tool-generated pick set, or None."""
        for candidate in runs:
            if any(picks.from_tool for picks in candidate.get_picks()):
                return candidate
        return None

    @staticmethod
    def _count_tool_picks(run):
        """Return a Counter mapping object name to the number of tool-picked points in *run*."""
        counts = Counter()
        for picks in run.get_picks():
            if picks.from_tool:
                points, _ = picks.numpy()
                counts[picks.pickable_object_name] += len(points)
        return counts

    def _write_single_run_config(self, run):
        """Write a copy of the base config restricted to *run* and return its path."""
        # Split on the real extension instead of str.replace(), which would
        # rewrite every ".json" occurrence anywhere in the path.
        root, ext = os.path.splitext(self.config_path)
        modified_config = f"{root}_{run.name}{ext}"

        with open(self.config_path, "r") as f:
            config_data = json.load(f)

        # Custom run filter intended to select only this run.
        config_data["run_filter"] = [{"name": run.name}]

        with open(modified_config, "w") as f:
            json.dump(config_data, f)

        return modified_config

    def _verify_distribution_for_first_run(self):
        """Compare pick proportions from Copick with those reported by the dataset.

        Skips the test when the project has no runs, no run has tool-generated
        picks, or the selected run contains no picked points.
        """
        # Load the Copick project with real data.
        project = copick.from_file(self.config_path)
        runs = project.runs
        if not runs:
            self.skipTest("No runs found in the dataset.")

        # Use only one run to speed up the test.
        run = self._first_run_with_picks(runs)
        if run is None:
            self.skipTest("No runs with picks found in the dataset.")

        print(f"\nUsing run: {run.name} for testing")

        # Reference counts taken directly from the picks of the selected run.
        copick_object_counts = self._count_tool_picks(run)
        copick_total = sum(copick_object_counts.values())
        # Checked before building the dataset, and on the total rather than on
        # emptiness: a run whose pick sets are all empty yields zero-valued
        # entries that would otherwise cause a division by zero below.
        if copick_total == 0:
            self.skipTest("No pickable objects found in the selected run.")

        dataset = SimpleCopickDataset(
            config_path=self._write_single_run_config(run),
            boxsize=(32, 32, 32),
            voxel_spacing=10.012,
            cache_dir=None,  # Don't use caching for this test
        )

        # Presumably a mapping of object name -> sample count; it is only used
        # through `in`, indexing, .values() and .items() here.
        dataset_distribution = dataset.get_class_distribution()
        dataset_total = sum(dataset_distribution.values())

        # Print both distributions before asserting so they are available in
        # the captured output when an assertion fails.
        print("\nCopick Object Counts:")
        for obj, count in copick_object_counts.items():
            print(f"  {obj}: {count}")

        print("\nDataset Distribution:")
        for obj, count in dataset_distribution.items():
            print(f"  {obj}: {count}")

        for object_name, count in copick_object_counts.items():
            # Skip beta-amylase check as it may not be in the dataset.
            if object_name == "beta-amylase" and object_name not in dataset_distribution:
                print("Warning: beta-amylase not found in dataset distribution, skipping check")
                continue

            self.assertIn(
                object_name,
                dataset_distribution,
                f"Object {object_name} is missing from the dataset",
            )

            copick_proportion = count / copick_total
            dataset_proportion = dataset_distribution[object_name] / dataset_total

            # Proportions must agree within a 5% absolute margin.
            diff = abs(copick_proportion - dataset_proportion)
            self.assertLess(
                diff,
                0.05,
                f"Proportion mismatch for {object_name}: "
                f"Copick: {copick_proportion:.3f}, Dataset: {dataset_proportion:.3f}",
            )

    @patch("copick_torch.dataset.SimpleCopickDataset._extract_subvolume_with_validation")
    def test_pickable_object_distribution(self, mock_extract_subvolume):
        """
        Test that the distribution of pickable objects in SimpleCopickDataset
        matches the distribution in the CryoET Data Portal.

        This test only mocks the subvolume extraction to avoid the slow zarr loading
        while still using real picks data for testing the distribution consistency.
        """
        # Mock the extract_subvolume method to always return a valid subvolume.
        mock_extract_subvolume.return_value = (np.zeros((32, 32, 32)), True, "valid")

        try:
            self._verify_distribution_for_first_run()
        except (unittest.SkipTest, self.failureException):
            # skipTest() and the assert* methods signal by raising. Both are
            # Exception subclasses, so they must be re-raised untouched or a
            # skip would be reported as a failure by the handler below.
            raise
        except Exception as e:
            # Any other error is reported as a test failure with its message.
            self.fail(f"Test failed with error: {str(e)}")

223 

224 

if __name__ == "__main__":
    # Run the tests in this module when the file is executed as a script.
    unittest.main()