Coverage for tests/test_copick_data_portal_distribution.py: 0%
78 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-16 16:14 -0700
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-16 16:14 -0700
1import json
2import os
3import shutil
4import tempfile
5import unittest
6from collections import Counter
7from unittest.mock import MagicMock, patch
9import copick
10import numpy as np
12from copick_torch import SimpleCopickDataset
class TestCopickDataPortalDistribution(unittest.TestCase):
    """
    Test that verifies the SimpleCopickDataset correctly preserves the
    distribution of pickable objects from the CryoET Data Portal.

    The class fixture writes a temporary ``cryoet_data_portal`` copick config
    for dataset 10440. The single test compares per-object pick counts read
    through copick with ``SimpleCopickDataset.get_class_distribution()`` for
    one run.
    """

    @classmethod
    def setUpClass(cls):
        """Set up test environment: a temp directory holding the copick config."""
        # Create a temporary directory for caching
        cls.temp_dir = tempfile.mkdtemp()
        cls.cache_dir = os.path.join(cls.temp_dir, "cache")
        os.makedirs(cls.cache_dir, exist_ok=True)

        # Define the test dataset ID from the CryoET Data Portal
        cls.dataset_id = 10440
        # NOTE(review): this is a relative path, so it is not located under
        # temp_dir and is not removed by tearDownClass -- confirm intended.
        cls.overlay_root = "./overlay"

        # Create a temporary config file for the dataset
        cls.config_path = os.path.join(cls.temp_dir, "test_config.json")

        # Create the config file content with multiple pickable objects
        config = {
            "config_type": "cryoet_data_portal",
            "name": "Test Dataset",
            "description": "Test Dataset for distribution verification",
            "version": "1.0.0",
            "overlay_root": cls.overlay_root,
            "overlay_fs_args": {"auto_mkdir": True},
            "dataset_ids": [cls.dataset_id],
            "pickable_objects": [
                {
                    "name": "cytosolic-ribosome",
                    "go_id": "GO:0022626",
                    "is_particle": True,
                    "label": 1,
                    "color": [0, 255, 0, 255],
                    "radius": 50.0,
                },
                {
                    "name": "beta-amylase",
                    "go_id": "UniProtKB:P10537",
                    "is_particle": True,
                    "label": 2,
                    "color": [255, 0, 255, 255],
                    "radius": 50.0,
                },
                {
                    "name": "thyroglobulin",
                    "go_id": "UniProtKB:P01267",
                    "is_particle": True,
                    "label": 3,
                    "color": [0, 127, 255, 255],
                    "radius": 50.0,
                },
                {
                    "name": "virus-like-capsid",
                    "go_id": "GO:0170047",
                    "is_particle": True,
                    "label": 4,
                    "color": [255, 127, 0, 255],
                    "radius": 50.0,
                },
                {
                    "name": "ferritin-complex",
                    "go_id": "GO:0070288",
                    "is_particle": True,
                    "label": 5,
                    "color": [127, 191, 127, 255],
                    "radius": 50.0,
                },
                {
                    "name": "beta-galactosidase",
                    "go_id": "UniProtKB:P00722",
                    "is_particle": True,
                    "label": 6,
                    "color": [94, 6, 164, 255],
                    "radius": 50.0,
                },
            ],
        }

        # Write the config to file
        with open(cls.config_path, "w") as f:
            json.dump(config, f)

    @classmethod
    def tearDownClass(cls):
        """Clean up test environment by removing the temporary directory."""
        shutil.rmtree(cls.temp_dir)

    def _first_run_with_tool_picks(self, project):
        """Return the first run of *project* that has tool picks, else skip the test."""
        runs = project.runs
        if not runs:
            self.skipTest("No runs found in the dataset.")

        for candidate in runs:
            if any(picks.from_tool for picks in candidate.get_picks()):
                return candidate

        # skipTest raises, so this helper never returns None.
        self.skipTest("No runs with picks found in the dataset.")

    @staticmethod
    def _count_tool_picks(run):
        """Return a Counter of point counts per pickable object over *run*'s tool picks."""
        counts = Counter()
        for picks in run.get_picks():
            if picks.from_tool:
                points, _ = picks.numpy()
                counts[picks.pickable_object_name] += len(points)
        return counts

    def _write_single_run_config(self, run_name):
        """Write a copy of the base config with a ``run_filter`` for *run_name*; return its path."""
        base, ext = os.path.splitext(self.config_path)
        modified_config = f"{base}_{run_name}{ext}"

        # Get the original config
        with open(self.config_path, "r") as f:
            config_data = json.load(f)

        # NOTE(review): presumably restricts the project to this run -- confirm
        # that copick honours a "run_filter" key.
        config_data["run_filter"] = [{"name": run_name}]

        # Write the modified config
        with open(modified_config, "w") as f:
            json.dump(config_data, f)

        return modified_config

    def _assert_matching_proportions(self, copick_counts, dataset_distribution, tolerance=0.05):
        """
        Assert that each object's share of *copick_counts* is within *tolerance*
        of its share of *dataset_distribution*.

        ``beta-amylase`` is tolerated as missing from the dataset distribution.
        The caller must ensure *copick_counts* has a non-zero total.
        """
        # Totals are loop-invariant, so compute them once.
        copick_total = sum(copick_counts.values())
        dataset_total = sum(dataset_distribution.values())

        for object_name, count in copick_counts.items():
            # Skip beta-amylase check as it may not be in the dataset
            if object_name == "beta-amylase" and object_name not in dataset_distribution:
                print("Warning: beta-amylase not found in dataset distribution, skipping check")
                continue

            self.assertIn(
                object_name,
                dataset_distribution,
                f"Object {object_name} is missing from the dataset",
            )
            # Fail with a clear message instead of a ZeroDivisionError below.
            self.assertGreater(
                dataset_total,
                0,
                "Dataset distribution has no samples",
            )

            # Calculate the proportion of each object in both distributions
            copick_proportion = count / copick_total
            dataset_proportion = dataset_distribution[object_name] / dataset_total

            # Assert that the proportion is similar (within the margin)
            diff = abs(copick_proportion - dataset_proportion)
            self.assertLess(
                diff,
                tolerance,
                f"Proportion mismatch for {object_name}: "
                f"Copick: {copick_proportion:.3f}, Dataset: {dataset_proportion:.3f}",
            )

    @patch("copick_torch.dataset.SimpleCopickDataset._extract_subvolume_with_validation")
    def test_pickable_object_distribution(self, mock_extract_subvolume):
        """
        Test that the distribution of pickable objects in SimpleCopickDataset
        matches the distribution in the CryoET Data Portal.

        This test only mocks the subvolume extraction to avoid the slow zarr loading
        while still using real picks data for testing the distribution consistency.

        No blanket ``except Exception`` is used here: ``unittest.SkipTest``
        derives from ``Exception``, so such a handler would turn every
        ``skipTest`` call into a failure and hide the traceback of real errors.
        """
        # Mock the extract_subvolume method to always return a valid subvolume
        mock_extract_subvolume.return_value = (np.zeros((32, 32, 32)), True, "valid")

        # Load the Copick project with real data
        project = copick.from_file(self.config_path)

        # Use only one run (the first with tool picks) to speed up the test
        run = self._first_run_with_tool_picks(project)
        print(f"\nUsing run: {run.name} for testing")

        # Count the pickable objects in this run using real picks data
        copick_object_counts = self._count_tool_picks(run)

        # Skip before building the dataset if there is nothing to compare.
        # A zero total also covers picks that exist but contain no points.
        if sum(copick_object_counts.values()) == 0:
            self.skipTest("No pickable objects found in the selected run.")

        # Create the SimpleCopickDataset from a config that only includes the selected run
        modified_config = self._write_single_run_config(run.name)
        dataset = SimpleCopickDataset(
            config_path=modified_config,
            boxsize=(32, 32, 32),
            voxel_spacing=10.012,
            cache_dir=None,  # Don't use caching for this test
        )

        # Get the distribution of classes in the dataset
        dataset_distribution = dataset.get_class_distribution()

        # Check that all pickable objects are represented in similar proportion
        self._assert_matching_proportions(copick_object_counts, dataset_distribution)

        # Print the distributions for logging purposes
        print("\nCopick Object Counts:")
        for obj, count in copick_object_counts.items():
            print(f"  {obj}: {count}")

        print("\nDataset Distribution:")
        for obj, count in dataset_distribution.items():
            print(f"  {obj}: {count}")
# Run the unittest runner when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()