Coverage for tests/test_minimal_dataset.py: 0%

99 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-16 16:14 -0700

1import os 

2import shutil 

3import tempfile 

4import unittest 

5from unittest.mock import MagicMock, patch 

6 

7import numpy as np 

8import torch 

9 

10from copick_torch import MinimalCopickDataset 

11 

12 

class TestMinimalCopickDataset(unittest.TestCase):
    """
    Test the MinimalCopickDataset class.

    ``copick.from_czcdp_datasets`` and ``zarr.open`` are patched so that no
    network access or real tomograms are needed: the dataset sees a single
    fake run with one tomogram and a configurable set of picks.
    """

    def setUp(self):
        """Give each test its own throwaway overlay directory."""
        self.overlay_root = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, self.overlay_root, ignore_errors=True)

    def _configure_copick_mocks(self, mock_from_czcdp, mock_zarr_open, object_names, pick_coords, volume):
        """Wire the patched copick/zarr entry points to a single fake run.

        Args:
            mock_from_czcdp: The patched ``copick.from_czcdp_datasets``.
            mock_zarr_open: The patched ``zarr.open``.
            object_names: Names of the fake pickable objects, in order.
            pick_coords: One coordinate array per object (same order as
                ``object_names``), returned from that object's ``numpy()``.
            volume: Array returned for any key looked up on the opened zarr.
        """
        self.assertEqual(len(object_names), len(pick_coords))

        mock_copick_root = MagicMock()
        mock_run = MagicMock()
        mock_vs = MagicMock()
        mock_tomogram = MagicMock()

        mock_from_czcdp.return_value = mock_copick_root
        mock_copick_root.runs = [mock_run]

        # MagicMock(name=...) only sets the mock's repr; the ``.name``
        # attribute must be assigned after construction to be a real string.
        pickable_objects = []
        for name in object_names:
            po = MagicMock()
            po.name = name
            pickable_objects.append(po)
        mock_copick_root.pickable_objects = pickable_objects

        mock_run.name = "test_run"
        mock_run.get_voxel_spacing.return_value = mock_vs
        mock_vs.tomograms = [mock_tomogram]
        mock_tomogram.tomo_type = "wbp-denoised"
        mock_tomogram.zarr.return_value = "dummy_zarr_path"

        picks = []
        for po, coords in zip(pickable_objects, pick_coords):
            pick = MagicMock()
            pick.from_tool = True
            pick.pickable_object_name = po.name
            pick.numpy.return_value = (coords, None)
            picks.append(pick)
        mock_run.get_picks.return_value = picks

        # Any ``zarr_root[key]`` lookup yields the supplied volume.
        mock_zarr_root = MagicMock()
        mock_zarr_root.__getitem__.return_value = volume
        mock_zarr_open.return_value = mock_zarr_root

    @patch("zarr.open")
    @patch("copick.from_czcdp_datasets")
    def test_dataset_initialization(self, mock_from_czcdp, mock_zarr_open):
        """Test that the dataset can be initialized and returns correct data."""
        object_names = ["object1", "object2"]
        # Five points per object along the volume diagonal (a fresh array per object).
        self._configure_copick_mocks(
            mock_from_czcdp,
            mock_zarr_open,
            object_names,
            pick_coords=[np.array([[c, c, c] for c in (100, 200, 300, 400, 500)]) for _ in object_names],
            volume=np.random.default_rng(0).standard_normal((100, 100, 100)),
        )

        dataset = MinimalCopickDataset(
            dataset_id=10440,
            overlay_root=self.overlay_root,
            boxsize=(32, 32, 32),
            voxel_spacing=10.0,
            include_background=True,
            background_ratio=0.2,
        )

        # Test the dataset properties
        self.assertIsNotNone(dataset)
        self.assertEqual(dataset.dataset_id, 10440)
        self.assertEqual(dataset.boxsize, (32, 32, 32))
        self.assertEqual(dataset.voxel_spacing, 10.0)
        self.assertTrue(dataset.include_background)

        # Background is enabled, so it must appear alongside the object classes.
        expected_class_names = [*object_names, "background"]
        self.assertEqual(set(dataset.keys()), set(expected_class_names))

        # 2 objects x 5 points = 10 object samples; with background_ratio=0.2
        # roughly 2 background samples are added, but background sampling is
        # random, so only the lower bound is checked.
        self.assertGreaterEqual(len(dataset), 10)

        # Test getting an item
        volume, label = dataset[0]
        self.assertEqual(volume.shape, (1, 32, 32, 32))  # [C, D, H, W]
        self.assertIsInstance(volume, torch.Tensor)
        self.assertIsInstance(label, int)
        self.assertIn(label, [-1, 0, 1])  # -1 for background, 0-1 for objects

        # Each object contributes exactly its 5 picks.
        distribution = dataset.get_class_distribution()
        for obj_name in object_names:
            self.assertEqual(distribution.get(obj_name, 0), 5)

        # One weight per sample.
        weights = dataset.get_sample_weights()
        self.assertEqual(len(weights), len(dataset))

    @patch("zarr.open")
    @patch("copick.from_czcdp_datasets")
    def test_class_to_label_consistency(self, mock_from_czcdp, mock_zarr_open):
        """Test that class names and labels are consistent."""
        object_names = ["object1", "object2", "object3"]
        # One point per object, at distinct positions inside the volume.
        self._configure_copick_mocks(
            mock_from_czcdp,
            mock_zarr_open,
            object_names,
            pick_coords=[np.array([[100 * (i + 1)] * 3]) for i in range(len(object_names))],
            volume=np.zeros((100, 100, 100)),
        )

        # No background, so every sample corresponds to exactly one pick.
        dataset = MinimalCopickDataset(
            dataset_id=10440,
            overlay_root=self.overlay_root,
            boxsize=(32, 32, 32),
            voxel_spacing=10.0,
            include_background=False,
        )

        self.assertEqual(len(dataset), 3)

        class_names = dataset.keys()
        self.assertEqual(len(class_names), 3)

        seen_classes = set()
        for i in range(len(dataset)):
            _, label = dataset[i]
            # The label must index into the class-name list...
            self.assertTrue(0 <= label < len(class_names))
            class_name = class_names[label]
            # ...and resolve to one of the configured objects.
            self.assertIn(class_name, object_names)
            seen_classes.add(class_name)

        # One pick per object, so every class must be produced by some sample.
        self.assertEqual(seen_classes, set(object_names))

        distribution = dataset.get_class_distribution()
        for name in object_names:
            self.assertEqual(distribution.get(name, 0), 1)

191 

192 

if __name__ == "__main__":
    # Run every TestCase defined in this module when executed as a script.
    unittest.main(module="__main__")