Coverage for tests/test_dataset_extraction.py: 0%
67 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-16 16:14 -0700
1import os
2import shutil
3import tempfile
4import unittest
5from unittest.mock import patch
7import numpy as np
9from copick_torch import SimpleCopickDataset
class TestDatasetExtraction(unittest.TestCase):
    """Test the subvolume extraction functionality of SimpleCopickDataset.

    Every test patches ``SimpleCopickDataset._load_data`` (presumably so the
    dataset can be constructed from a config path that does not exist --
    confirm against the dataset implementation) and then calls
    ``_extract_subvolume_with_validation`` directly on a synthetic tomogram.
    """

    def setUp(self):
        """Create a scratch directory and a synthetic 32x32x32 tomogram."""
        # Create a temporary directory for testing. Its removal is registered
        # immediately: cleanups still run when the rest of setUp raises,
        # whereas tearDown would be skipped in that case.
        self.test_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, self.test_dir)
        self.mock_config_path = os.path.join(self.test_dir, "mock_config.json")

        # Define box size for testing
        self.boxsize = (16, 16, 16)

        # Create a test tomogram: a linear gradient so the volume is not
        # uniform, offset by 1 so that every voxel is >= 1. With a strictly
        # positive source, a zero in an extracted patch can only come from
        # padding. Without the offset the voxel at index (0, 0, 0) is exactly
        # 0, and the "contains zeros" checks below could pass on source data
        # alone.
        axis = np.linspace(0, 1, 32)
        x, y, z = np.meshgrid(axis, axis, axis)
        self.tomogram_array = 1.0 + x + y + z

    def _make_dataset(self, patch_strategy):
        """Build a SimpleCopickDataset on the mock config with the given strategy."""
        return SimpleCopickDataset(
            config_path=self.mock_config_path,
            boxsize=self.boxsize,
            patch_strategy=patch_strategy,
        )

    @patch("copick_torch.dataset.SimpleCopickDataset._load_data")
    def test_extract_center_valid(self, mock_load_data):
        """Test extracting a valid subvolume from the center of the tomogram."""
        dataset = self._make_dataset("centered")

        # Extract from center of tomogram (should be valid)
        subvolume, is_valid, status = dataset._extract_subvolume_with_validation(self.tomogram_array, 16, 16, 16)

        # Check results
        self.assertTrue(is_valid)
        self.assertEqual(status, "valid")
        self.assertEqual(subvolume.shape, self.boxsize)

        # Since we extracted from the center, values should match the source tomogram
        center_slice = self.tomogram_array[8:24, 8:24, 8:24]
        np.testing.assert_array_equal(subvolume, center_slice)

    @patch("copick_torch.dataset.SimpleCopickDataset._load_data")
    def test_extract_edge_padded(self, mock_load_data):
        """Test extracting a subvolume near the edge of the tomogram (requires padding)."""
        dataset = self._make_dataset("centered")

        # Extract from near edge of tomogram (should require padding)
        subvolume, is_valid, status = dataset._extract_subvolume_with_validation(self.tomogram_array, 3, 16, 16)

        # Check results
        self.assertTrue(is_valid)
        self.assertEqual(status, "padded")
        self.assertEqual(subvolume.shape, self.boxsize)

        # The source tomogram is strictly positive, so any zero is padding
        self.assertTrue(np.any(subvolume == 0))

    @patch("copick_torch.dataset.SimpleCopickDataset._load_data")
    def test_extract_near_edge(self, mock_load_data):
        """Test extracting a subvolume very close to the edge of the tomogram."""
        dataset = self._make_dataset("centered")

        # Try to extract from positions that are technically valid but will need padding
        subvolume, is_valid, status = dataset._extract_subvolume_with_validation(
            self.tomogram_array,
            2,
            2,
            2,  # Very close to the edge (0,0,0)
        )

        # The actual implementation pads rather than invalidates, so check for padding
        self.assertTrue(is_valid)
        self.assertEqual(status, "padded")
        self.assertEqual(subvolume.shape, self.boxsize)

        # The source tomogram is strictly positive, so any zero is padding
        self.assertTrue(np.any(subvolume == 0))

    @patch("copick_torch.dataset.SimpleCopickDataset._load_data")
    def test_random_strategy(self, mock_load_data):
        """Test the random patch extraction strategy."""
        dataset = self._make_dataset("random")

        # Set a fixed seed for reproducibility
        np.random.seed(42)

        # Extract with random strategy
        subvolume, is_valid, _ = dataset._extract_subvolume_with_validation(self.tomogram_array, 16, 16, 16)

        # Check results
        self.assertTrue(is_valid)
        self.assertEqual(subvolume.shape, self.boxsize)

        # Extract again with same seed
        np.random.seed(42)
        subvolume2, _, _ = dataset._extract_subvolume_with_validation(self.tomogram_array, 16, 16, 16)

        # Both extractions should be identical with the same seed
        np.testing.assert_array_equal(subvolume, subvolume2)

    @patch("copick_torch.dataset.SimpleCopickDataset._load_data")
    def test_jittered_strategy(self, mock_load_data):
        """Test the jittered patch extraction strategy."""
        dataset = self._make_dataset("jittered")

        # Set a fixed seed for reproducibility
        np.random.seed(42)

        # Extract with jittered strategy
        subvolume, is_valid, _ = dataset._extract_subvolume_with_validation(self.tomogram_array, 16, 16, 16)

        # Check results
        self.assertTrue(is_valid)
        self.assertEqual(subvolume.shape, self.boxsize)

        # Extract with centered strategy for comparison
        centered_dataset = self._make_dataset("centered")
        centered_subvolume, _, _ = centered_dataset._extract_subvolume_with_validation(self.tomogram_array, 16, 16, 16)

        # The jittered patch must be directly comparable to the centered one.
        self.assertEqual(centered_subvolume.shape, subvolume.shape)

        # Whether the voxel values differ is deliberately not asserted: a jitter
        # of (0, 0, 0) is a legitimate draw and yields an identical patch. The
        # earlier try/except around this comparison could never fail, because
        # its fallback ran only when the two patches were already identical.
# Script entry point: run every test case in this module when the file is
# executed directly instead of being imported by a test runner.
if __name__ == "__main__":
    unittest.main()