Coverage for tests/test_minimal_dataset.py: 0%
99 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-16 16:14 -0700
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-16 16:14 -0700
1import os
2import shutil
3import tempfile
4import unittest
5from unittest.mock import MagicMock, patch
7import numpy as np
8import torch
10from copick_torch import MinimalCopickDataset
class TestMinimalCopickDataset(unittest.TestCase):
    """
    Test the MinimalCopickDataset class.

    Every test patches ``copick.from_czcdp_datasets`` and ``zarr.open``, so the
    dataset is built against a mocked project with a single run holding a single
    ``wbp-denoised`` tomogram.
    """

    def setUp(self):
        # Use a private per-test overlay directory instead of the shared,
        # hard-coded "/tmp/test/", so repeated or parallel runs cannot collide;
        # it is removed after each test. os.path.join(..., "") keeps the trailing
        # separator that the original path literal had.
        tmp_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, tmp_dir, ignore_errors=True)
        self.overlay_root = os.path.join(tmp_dir, "")

    @staticmethod
    def _make_pickable_objects(names):
        """Return one mock pickable object per entry of *names*, with ``.name`` set.

        ``MagicMock(name=...)`` only sets the mock's repr name; reading ``.name``
        afterwards yields an auto-created child mock, not the string. The
        attribute therefore has to be assigned after construction.
        """
        objects = []
        for object_name in names:
            pickable = MagicMock()
            pickable.name = object_name
            objects.append(pickable)
        return objects

    @classmethod
    def _configure_copick_mocks(cls, mock_from_czcdp, mock_zarr_open, points_by_object, volume):
        """Wire the patched copick/zarr entry points to a one-run, one-tomogram project.

        Args:
            mock_from_czcdp: The patched ``copick.from_czcdp_datasets``.
            mock_zarr_open: The patched ``zarr.open``.
            points_by_object: Mapping of pickable-object name to an (N, 3) array of
                pick coordinates. One pickable object and one pick set are created
                per entry, in mapping order.
            volume: Array returned for any key lookup on the opened zarr root.
        """
        mock_copick_root = MagicMock()
        mock_run = MagicMock()
        mock_vs = MagicMock()
        mock_tomogram = MagicMock()

        mock_from_czcdp.return_value = mock_copick_root
        mock_copick_root.runs = [mock_run]
        mock_copick_root.pickable_objects = cls._make_pickable_objects(list(points_by_object))

        mock_run.name = "test_run"
        mock_run.get_voxel_spacing.return_value = mock_vs
        mock_vs.tomograms = [mock_tomogram]
        mock_tomogram.tomo_type = "wbp-denoised"
        mock_tomogram.zarr.return_value = "dummy_zarr_path"

        # One tool-generated pick set per object; numpy() returns (points, None).
        picks = []
        for pickable in mock_copick_root.pickable_objects:
            mock_pick = MagicMock()
            mock_pick.from_tool = True
            mock_pick.pickable_object_name = pickable.name
            mock_pick.numpy.return_value = (points_by_object[pickable.name], None)
            picks.append(mock_pick)
        mock_run.get_picks.return_value = picks

        # Any root[...] lookup on the opened zarr store yields the dummy volume.
        mock_zarr_root = MagicMock()
        mock_zarr_root.__getitem__.return_value = volume
        mock_zarr_open.return_value = mock_zarr_root

    @patch("zarr.open")
    @patch("copick.from_czcdp_datasets")
    def test_dataset_initialization(self, mock_from_czcdp, mock_zarr_open):
        """Test that the dataset can be initialized and returns correct data."""
        object_names = ["object1", "object2"]
        # Five picks per object, on the diagonal at 100, 200, ..., 500.
        points = np.array([[c, c, c] for c in range(100, 600, 100)])
        # Seeded so the dummy tomogram is identical on every run.
        volume = np.random.default_rng(0).standard_normal((100, 100, 100))
        self._configure_copick_mocks(
            mock_from_czcdp,
            mock_zarr_open,
            {name: points.copy() for name in object_names},
            volume,
        )

        dataset = MinimalCopickDataset(
            dataset_id=10440,
            overlay_root=self.overlay_root,
            boxsize=(32, 32, 32),
            voxel_spacing=10.0,
            include_background=True,
            background_ratio=0.2,
        )

        # Check that constructor arguments are reflected on the dataset.
        self.assertIsNotNone(dataset)
        self.assertEqual(dataset.dataset_id, 10440)
        self.assertEqual(dataset.boxsize, (32, 32, 32))
        self.assertEqual(dataset.voxel_spacing, 10.0)
        self.assertTrue(dataset.include_background)

        # Class names: the object names, plus "background" when enabled.
        expected_class_names = list(object_names)
        if dataset.include_background:
            expected_class_names.append("background")
        self.assertEqual(set(dataset.keys()), set(expected_class_names))

        # 5 points per object x 2 objects = 10 object samples. background_ratio=0.2
        # adds roughly 2 more, but background sampling is random, so only the
        # lower bound is checked.
        self.assertGreaterEqual(len(dataset), 10)

        # Fetch one sample and check its type and shape.
        volume_patch, label = dataset[0]
        self.assertIsInstance(volume_patch, torch.Tensor)
        self.assertEqual(volume_patch.shape, (1, 32, 32, 32))  # [C, D, H, W]
        self.assertIsInstance(label, int)
        # -1 for background, 0-1 for objects
        self.assertIn(label, [-1, 0, 1])

        # Each object contributes exactly its 5 picks.
        distribution = dataset.get_class_distribution()
        for obj_name in object_names:
            self.assertEqual(distribution.get(obj_name, 0), 5)

        # One weight per sample.
        weights = dataset.get_sample_weights()
        self.assertEqual(len(weights), len(dataset))

    @patch("zarr.open")
    @patch("copick.from_czcdp_datasets")
    def test_class_to_label_consistency(self, mock_from_czcdp, mock_zarr_open):
        """Test that class names and labels are consistent."""
        object_names = ["object1", "object2", "object3"]
        # One pick per object: at 100, 200 and 300 on every axis respectively.
        points_by_object = {
            name: np.array([[100 * (i + 1)] * 3]) for i, name in enumerate(object_names)
        }
        self._configure_copick_mocks(
            mock_from_czcdp, mock_zarr_open, points_by_object, np.zeros((500, 500, 500))
        )

        # Without background samples, every sample belongs to an object class.
        dataset = MinimalCopickDataset(
            dataset_id=10440,
            overlay_root=self.overlay_root,
            boxsize=(32, 32, 32),
            voxel_spacing=10.0,
            include_background=False,
        )

        # Exactly one sample per object.
        self.assertEqual(len(dataset), 3)

        class_names = dataset.keys()
        self.assertEqual(len(class_names), 3)

        # Every label must be a valid index into class_names and resolve to one
        # of the configured objects.
        for i in range(len(dataset)):
            with self.subTest(index=i):
                _, label = dataset[i]
                self.assertGreaterEqual(label, 0)
                self.assertLess(label, len(class_names))
                self.assertIn(class_names[label], object_names)

        distribution = dataset.get_class_distribution()
        for name in object_names:
            self.assertEqual(distribution.get(name, 0), 1)
if __name__ == "__main__":
    # Running this file directly (rather than via a test runner) executes
    # every TestCase defined in it through unittest's CLI entry point.
    unittest.main()