Coverage for tests/test_background_sampling.py: 0%
86 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-16 16:14 -0700
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-16 16:14 -0700
1import os
2import shutil
3import tempfile
4import unittest
5from unittest.mock import MagicMock, patch
7import numpy as np
9from copick_torch import SimpleCopickDataset
class TestBackgroundSampling(unittest.TestCase):
    """Test the background sampling functionality of SimpleCopickDataset.

    Most tests patch ``SimpleCopickDataset._load_data`` so that building the
    dataset does not trigger loading, then call ``_sample_background_points``
    directly on a synthetic tomogram and inspect the ``_subvolumes``,
    ``_molecule_ids`` and ``_is_background`` lists afterwards.
    """

    def setUp(self):
        """Create a scratch directory, a blank tomogram and particle coordinates."""
        # Create a temporary directory for testing
        self.test_dir = tempfile.mkdtemp()
        # Only the path is handed to the dataset; this class never writes the file.
        self.mock_config_path = os.path.join(self.test_dir, "mock_config.json")

        # Define box size for testing
        self.boxsize = (16, 16, 16)

        # Create a test tomogram (all zeros)
        self.tomogram_array = np.zeros((64, 64, 64))

        # Add some "particles" to make background sampling more realistic.
        # All three lie strictly inside the 64^3 volume.
        self.particle_coords = [
            (16, 16, 16),
            (40, 40, 40),
            (16, 40, 40),
        ]

    def tearDown(self):
        """Remove the scratch directory created in ``setUp``."""
        shutil.rmtree(self.test_dir)

    def _make_dataset(self, **kwargs):
        """Build a background-enabled dataset whose sample lists start empty.

        Intended to be called from tests that have ``_load_data`` patched.

        Args:
            **kwargs: Extra keyword arguments forwarded unchanged to
                ``SimpleCopickDataset`` (e.g. ``background_ratio``).

        Returns:
            The constructed dataset with ``_subvolumes``, ``_molecule_ids`` and
            ``_is_background`` reset to empty lists.
        """
        dataset = SimpleCopickDataset(
            config_path=self.mock_config_path,
            boxsize=self.boxsize,
            include_background=True,
            **kwargs,
        )

        # Initialize dataset properties so every assertion starts from zero.
        dataset._subvolumes = []
        dataset._molecule_ids = []
        dataset._is_background = []
        return dataset

    def _assert_no_samples(self, dataset):
        """Assert that none of the three per-sample lists received an entry."""
        self.assertEqual(len(dataset._subvolumes), 0)
        self.assertEqual(len(dataset._molecule_ids), 0)
        self.assertEqual(len(dataset._is_background), 0)

    @patch("copick_torch.dataset.SimpleCopickDataset._load_data")
    def test_sample_background_points(self, mock_load_data):
        """Test the _sample_background_points method."""
        # Create dataset with background sampling enabled
        dataset = self._make_dataset(
            background_ratio=0.5,  # One background sample for every two particles
            min_background_distance=20.0,  # Stay at least 20 units away from particles
            patch_strategy="centered",
        )

        # Mock the _extract_subvolume_with_validation method to return valid subvolumes
        def mock_extract(*args, **kwargs):
            return np.zeros(self.boxsize), True, "valid"

        dataset._extract_subvolume_with_validation = mock_extract

        # Set a fixed seed for reproducibility
        np.random.seed(42)

        # Sample background points
        dataset._sample_background_points(self.tomogram_array, self.particle_coords)

        # Check that background samples were added
        # Should add background_ratio * len(particle_coords) samples = 0.5 * 3 = 1 or 2
        self.assertGreater(len(dataset._subvolumes), 0)
        self.assertGreater(len(dataset._molecule_ids), 0)
        self.assertGreater(len(dataset._is_background), 0)

        # Check that all added samples are marked as background
        self.assertTrue(all(dataset._is_background))

        # Check that all added samples have molecule_id = -1
        self.assertTrue(all(mol_id == -1 for mol_id in dataset._molecule_ids))

    @patch("copick_torch.dataset.SimpleCopickDataset._load_data")
    def test_sample_background_points_no_particles(self, mock_load_data):
        """Test the _sample_background_points method with no particles."""
        # Create dataset with background sampling enabled
        dataset = self._make_dataset(background_ratio=0.5)

        # Sample background points with no particles
        dataset._sample_background_points(self.tomogram_array, [])

        # Check that no background samples were added
        self._assert_no_samples(dataset)

    @patch("copick_torch.dataset.SimpleCopickDataset._load_data")
    def test_sample_background_fails_validation(self, mock_load_data):
        """Test when background samples fail validation."""
        # Create dataset with background sampling enabled
        dataset = self._make_dataset(
            background_ratio=1.0,  # One background sample for each particle
            min_background_distance=20.0,
        )

        # Mock the _extract_subvolume_with_validation method to always fail validation
        def mock_extract(*args, **kwargs):
            return None, False, "Invalid slice range"

        dataset._extract_subvolume_with_validation = mock_extract

        # Sample background points
        dataset._sample_background_points(self.tomogram_array, self.particle_coords)

        # Check that no background samples were added
        self._assert_no_samples(dataset)

    @patch("copick_torch.dataset.SimpleCopickDataset._load_data")
    def test_sample_background_distance_constraint(self, mock_load_data):
        """Test that background samples respect the minimum distance constraint."""
        # Create dataset with background sampling enabled and a large min distance
        min_distance = 100.0  # Very large distance requirement
        dataset = self._make_dataset(
            background_ratio=1.0,
            min_background_distance=min_distance,  # Very strict distance requirement
        )

        # Mock extraction to return valid subvolumes, so the distance check is
        # the only thing that can reject a candidate.
        def mock_extract(*args, **kwargs):
            return np.zeros(self.boxsize), True, "valid"

        dataset._extract_subvolume_with_validation = mock_extract

        # Set a fixed seed for reproducibility
        np.random.seed(42)

        # Sample background points with a very strict distance constraint.
        # Returning at all shows the sampler gives up instead of looping forever
        # (presumably via an attempt limit -- not visible from this test).
        dataset._sample_background_points(self.tomogram_array, self.particle_coords)

        # The constraint is unsatisfiable here: the voxel of the 64^3 volume that
        # is farthest from the particle at (16, 16, 16) is (63, 63, 63), about
        # 81.4 away, which is below min_distance. A sampler that respects the
        # constraint therefore cannot add anything.
        # NOTE(review): assumes Euclidean distance in voxel units and candidates
        # drawn inside the tomogram -- confirm against _sample_background_points.
        self._assert_no_samples(dataset)

    def test_include_background_in_load_data(self):
        """Test that _load_data calls _sample_background_points when include_background=True."""
        # Create a mock copick root
        mock_root = MagicMock()
        mock_run = MagicMock()
        mock_voxel_spacing = MagicMock()
        mock_tomogram = MagicMock()

        # Configure the mocks
        mock_root.runs = [mock_run]
        mock_run.name = "mock_run"
        mock_run.get_voxel_spacing.return_value = mock_voxel_spacing
        mock_voxel_spacing.tomograms = [mock_tomogram]
        mock_tomogram.numpy.return_value = self.tomogram_array

        # Create mock picks
        mock_picks = MagicMock()
        mock_picks.from_tool = True
        mock_picks.pickable_object_name = "test_object"
        mock_picks.numpy.return_value = (np.array([[16, 16, 16]]), None)
        mock_run.get_picks.return_value = [mock_picks]

        # Use patch within the test function
        with patch("copick_torch.dataset.SimpleCopickDataset._sample_background_points") as mock_sample_bg:
            # Create dataset with include_background=True
            _ = SimpleCopickDataset(
                config_path=None,
                copick_root=mock_root,
                boxsize=self.boxsize,
                include_background=True,
                background_ratio=0.5,
                cache_dir=None,  # Disable caching to ensure _load_data runs
            )

            # The _load_data method should have been called during initialization
            # and _sample_background_points should be called from within it
            mock_sample_bg.assert_called()
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()