# Coverage for tests/test_simple_dataset.py: 0% (116 statements)
# Report generated by coverage.py v7.9.2, created at 2025-07-16 16:14 -0700
import os
import shutil
import tempfile
import unittest
from unittest.mock import MagicMock, patch

import numpy as np
import torch

from copick_torch import SimpleCopickDataset, SimpleDatasetMixin
class TestSimpleDatasetMixin(unittest.TestCase):
    """Test the SimpleDatasetMixin functionality."""

    def setUp(self):
        """Create a bare object that mixes in SimpleDatasetMixin with one sample."""
        # Create a simple dataset with the mixin for testing
        self.test_dataset = type("TestDataset", (SimpleDatasetMixin, object), {})()

        # Add required attributes for the mixin: a single 16x16x16 volume of
        # ones labelled with molecule id 0, augmentation off by default.
        self.test_dataset._subvolumes = [np.ones((16, 16, 16))]
        self.test_dataset._molecule_ids = [0]
        self.test_dataset.augment = False

    def test_getitem(self):
        """Test the __getitem__ method of SimpleDatasetMixin."""
        # Stub the augmentation hook with an identity function
        self.test_dataset._augment_subvolume = lambda subvol, idx: subvol

        # Get an item (indexing dispatches to the mixin's __getitem__)
        subvolume, molecule_idx = self.test_dataset[0]

        # Check that the result is a tuple with the right types
        self.assertIsInstance(subvolume, torch.Tensor)
        self.assertEqual(molecule_idx, 0)

        # Check the shape of the subvolume (should have channel dimension)
        self.assertEqual(subvolume.shape, (1, 16, 16, 16))

    def test_getitem_with_augmentation(self):
        """Test __getitem__ with augmentation enabled."""
        # Enable augmentation
        self.test_dataset.augment = True

        # Stub augmentation to return a scaled subvolume
        self.test_dataset._augment_subvolume = lambda subvol, idx: subvol * 2

        # Get an item with augmentation
        subvolume, molecule_idx = self.test_dataset[0]

        # Only the return type is asserted here. The scaled values are not
        # checked: per the original author's note, normalization brings them
        # back to a similar range, so they would not be distinguishable.
        self.assertIsInstance(subvolume, torch.Tensor)
class TestSimpleCopickDataset(unittest.TestCase):
    """Test the SimpleCopickDataset class.

    Every test patches ``SimpleCopickDataset._load_data`` so that constructing
    a dataset does not run the real loading step.
    """

    def setUp(self):
        """Create a temporary cache directory and the shared test parameters."""
        # Create a temporary directory for caching. Cleanup is registered
        # immediately so the directory is removed even if a later setUp step
        # fails (tearDown is not run when setUp raises).
        self.test_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, self.test_dir, ignore_errors=True)

        self.cache_dir = os.path.join(self.test_dir, "cache")
        os.makedirs(self.cache_dir, exist_ok=True)

        # Mock config path (the file itself is never created)
        self.mock_config_path = os.path.join(self.test_dir, "mock_config.json")

        # Parameters for testing
        self.boxsize = (16, 16, 16)
        self.voxel_spacing = 10.0

    @patch("copick_torch.dataset.SimpleCopickDataset._load_data")
    def test_init_basic(self, mock_load_data):
        """Test basic initialization of SimpleCopickDataset."""
        # Initialize with minimal parameters
        dataset = SimpleCopickDataset(
            config_path=self.mock_config_path,
            boxsize=self.boxsize,
            cache_dir=None,  # Don't use caching
        )

        # Verify initialization
        self.assertEqual(dataset.config_path, self.mock_config_path)
        self.assertEqual(dataset.boxsize, self.boxsize)
        self.assertFalse(dataset.augment)
        self.assertIsNone(dataset.cache_dir)

        # Verify _load_data was called
        mock_load_data.assert_called_once()

    @patch("copick_torch.dataset.SimpleCopickDataset._load_data")
    def test_init_with_options(self, mock_load_data):
        """Test initialization with various options."""
        dataset = SimpleCopickDataset(
            config_path=self.mock_config_path,
            boxsize=self.boxsize,
            augment=True,
            cache_dir=self.cache_dir,
            cache_format="parquet",
            seed=42,
            max_samples=100,
            voxel_spacing=5.0,
            include_background=True,
            background_ratio=0.3,
            min_background_distance=20.0,
            patch_strategy="random",
            debug_mode=True,
        )

        # Verify all parameters were set correctly
        self.assertEqual(dataset.config_path, self.mock_config_path)
        self.assertEqual(dataset.boxsize, self.boxsize)
        self.assertTrue(dataset.augment)
        self.assertEqual(dataset.cache_dir, self.cache_dir)
        self.assertEqual(dataset.cache_format, "parquet")
        self.assertEqual(dataset.seed, 42)
        self.assertEqual(dataset.max_samples, 100)
        self.assertEqual(dataset.voxel_spacing, 5.0)
        self.assertTrue(dataset.include_background)
        self.assertEqual(dataset.background_ratio, 0.3)
        self.assertEqual(dataset.min_background_distance, 20.0)
        self.assertEqual(dataset.patch_strategy, "random")
        self.assertTrue(dataset.debug_mode)

    @patch("copick_torch.dataset.SimpleCopickDataset._load_data")
    def test_dataset_empty(self, mock_load_data):
        """Test behavior with empty dataset."""
        dataset = SimpleCopickDataset(config_path=self.mock_config_path, boxsize=self.boxsize)

        # Mock empty dataset
        dataset._subvolumes = np.array([])
        dataset._molecule_ids = np.array([])
        dataset._is_background = np.array([])
        dataset._keys = []

        # Test length
        self.assertEqual(len(dataset), 0)

        # Test get_class_distribution with empty dataset
        distribution = dataset.get_class_distribution()
        self.assertEqual(distribution, {})

    @patch("copick_torch.dataset.SimpleCopickDataset._load_data")
    def test_compute_sample_weights(self, mock_load_data):
        """Test the _compute_sample_weights method."""
        dataset = SimpleCopickDataset(config_path=self.mock_config_path, boxsize=self.boxsize)

        # Create an unbalanced dataset
        dataset._molecule_ids = [0, 0, 0, 1, 1, 2]

        # Compute sample weights
        dataset._compute_sample_weights()

        # Check weights are inversely proportional to class frequency
        expected_weights = [6 / 3, 6 / 3, 6 / 3, 6 / 2, 6 / 2, 6 / 1]  # total_samples / count_per_class
        np.testing.assert_array_almost_equal(dataset.sample_weights, expected_weights)

    @patch("copick_torch.dataset.SimpleCopickDataset._load_data")
    def test_get_sample_weights(self, mock_load_data):
        """Test the get_sample_weights method."""
        dataset = SimpleCopickDataset(config_path=self.mock_config_path, boxsize=self.boxsize)

        # Set sample weights
        dataset.sample_weights = [1.0, 2.0, 3.0]

        # Get sample weights
        weights = dataset.get_sample_weights()

        # Check weights are returned correctly
        self.assertEqual(weights, [1.0, 2.0, 3.0])

    @patch("copick_torch.dataset.SimpleCopickDataset._load_data")
    def test_keys(self, mock_load_data):
        """Test the keys method."""
        dataset = SimpleCopickDataset(config_path=self.mock_config_path, boxsize=self.boxsize)

        # Set keys
        dataset._keys = ["class1", "class2", "class3"]

        # Get keys
        keys = dataset.keys()

        # Check keys are returned correctly
        self.assertEqual(keys, ["class1", "class2", "class3"])

    @patch("copick_torch.dataset.SimpleCopickDataset._load_data")
    def test_get_class_distribution(self, mock_load_data):
        """Test the get_class_distribution method."""
        dataset = SimpleCopickDataset(config_path=self.mock_config_path, boxsize=self.boxsize)

        # Create a test dataset with class distribution
        dataset._keys = ["class1", "class2", "class3"]
        dataset._molecule_ids = [0, 0, 0, 1, 1, 2, -1, -1]  # -1 is background
        dataset._is_background = [False, False, False, False, False, False, True, True]

        # Get class distribution
        distribution = dataset.get_class_distribution()

        # Check distribution is correct
        expected_distribution = {"class1": 3, "class2": 2, "class3": 1, "background": 2}
        self.assertEqual(distribution, expected_distribution)

    @patch("copick_torch.dataset.SimpleCopickDataset._load_data")
    def test_validation_logic_missing_config(self, mock_load_data):
        """Test validation logic when both config_path and copick_root are missing."""
        # Should raise ValueError when both config_path and copick_root are None
        with self.assertRaises(ValueError):
            SimpleCopickDataset(config_path=None, copick_root=None, boxsize=self.boxsize)

    @patch("copick_torch.dataset.SimpleCopickDataset._extract_subvolume_with_validation")
    @patch("copick_torch.dataset.SimpleCopickDataset._load_data")
    def test_extract_subvolume_strategies(self, mock_load_data, mock_extract):
        """Test different patch extraction strategies."""
        # NOTE(review): the method invoked below is itself patched, so this
        # only checks that a dataset can be constructed with each strategy and
        # that the call reaches the mock; no real extraction logic is run.
        tomogram_array = np.zeros((32, 32, 32))

        for strategy in ("centered", "random", "jittered"):
            # subTest reports each strategy separately instead of stopping at
            # the first failure.
            with self.subTest(patch_strategy=strategy):
                # Start every strategy from a clean call record
                mock_extract.reset_mock()

                dataset = SimpleCopickDataset(
                    config_path=self.mock_config_path,
                    boxsize=self.boxsize,
                    patch_strategy=strategy,
                )
                dataset._extract_subvolume_with_validation(tomogram_array, 16, 16, 16)
                mock_extract.assert_called_once()
if __name__ == "__main__":
    # Run the test suite when this module is executed directly as a script.
    unittest.main()