Coverage for tests/test_copick_dataset.py: 0%

88 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-16 16:14 -0700

import os
import random
import shutil
import tempfile
import unittest
from unittest.mock import MagicMock, patch

import numpy as np
import torch

from copick_torch import CopickDataset



class TestCopickDataset(unittest.TestCase):
    """Unit tests for ``copick_torch.CopickDataset``.

    Every test patches ``CopickDataset._load_data`` so no copick project is
    read; tests that need data fill the dataset's private arrays directly
    via ``_populate``.
    """

    def setUp(self):
        """Create a scratch directory and the parameters shared by the tests."""
        self.test_dir = tempfile.mkdtemp()
        # Register removal right away: cleanups run even if a later statement
        # in setUp raises, whereas tearDown would be skipped and leak the dir.
        self.addCleanup(shutil.rmtree, self.test_dir)

        # Not currently passed to any dataset below.
        self.cache_dir = os.path.join(self.test_dir, "cache")
        os.makedirs(self.cache_dir, exist_ok=True)

        # Mock config path; the file is never created because _load_data is patched.
        self.mock_config_path = os.path.join(self.test_dir, "mock_config.json")

        # Parameters for testing (voxel_spacing is not currently used below).
        self.boxsize = (16, 16, 16)
        self.voxel_spacing = 10.0

    @staticmethod
    def _populate(dataset, subvolumes, molecule_ids, keys):
        """Overwrite a dataset's internal storage with in-memory fixtures.

        Args:
            dataset: The ``CopickDataset`` instance to fill.
            subvolumes: Sequence of sub-volume arrays (stacked into one array).
            molecule_ids: Per-sample class indices, same length as ``subvolumes``.
            keys: Class names, indexed by class index.

        Every sample is marked as non-background with a bool array, so the
        dtype is the same for empty and non-empty fixtures.
        """
        dataset._subvolumes = np.array(subvolumes)
        dataset._molecule_ids = np.array(molecule_ids)
        dataset._is_background = np.zeros(len(molecule_ids), dtype=bool)
        dataset._keys = list(keys)

    @patch("copick_torch.copick.CopickDataset._load_data")
    def test_init_basic(self, mock_load_data):
        """Test basic initialization of CopickDataset."""
        dataset = CopickDataset(
            config_path=self.mock_config_path,
            boxsize=self.boxsize,
            cache_dir=None,  # Don't use caching
        )

        self.assertEqual(dataset.config_path, self.mock_config_path)
        self.assertEqual(dataset.boxsize, self.boxsize)
        self.assertFalse(dataset.augment)
        self.assertIsNone(dataset.cache_dir)

        # Construction must trigger exactly one data load.
        mock_load_data.assert_called_once()

    @patch("copick_torch.copick.CopickDataset._load_data")
    def test_dataset_empty(self, mock_load_data):
        """An empty dataset has zero length and an empty class distribution."""
        dataset = CopickDataset(config_path=self.mock_config_path, boxsize=self.boxsize)
        self._populate(dataset, subvolumes=[], molecule_ids=[], keys=[])

        self.assertEqual(len(dataset), 0)
        self.assertEqual(dataset.get_class_distribution(), {})

    @patch("copick_torch.copick.CopickDataset._load_data")
    def test_augmentations(self, mock_load_data):
        """Augmentations preserve the volume shape; brightness changes the values."""
        # Seed every generator the augmentations might draw from, so the
        # value-changing assertion below is reproducible instead of flaky.
        # NOTE(review): which RNG CopickDataset actually uses is not visible here.
        random.seed(0)
        np.random.seed(0)
        torch.manual_seed(0)

        dataset = CopickDataset(config_path=self.mock_config_path, boxsize=self.boxsize, augment=True)
        test_volume = np.ones(self.boxsize)

        augmented = dataset._brightness(test_volume)
        self.assertEqual(augmented.shape, test_volume.shape)
        self.assertNotEqual(np.sum(augmented), np.sum(test_volume))

        # Shape-preserving augmentations, exercised in the original order.
        for augment_fn in (dataset._intensity_scaling, dataset._flip, dataset._rotate):
            with self.subTest(augmentation=augment_fn.__name__):
                augmented = augment_fn(test_volume)
                self.assertEqual(augmented.shape, test_volume.shape)

    @patch("copick_torch.copick.CopickDataset._load_data")
    def test_getitem_no_augment(self, mock_load_data):
        """__getitem__ without augmentation returns a channel-first tensor and a label dict."""
        dataset = CopickDataset(config_path=self.mock_config_path, boxsize=self.boxsize, augment=False)
        self._populate(
            dataset,
            subvolumes=[np.ones(self.boxsize)],
            molecule_ids=[0],
            keys=["test_class"],
        )

        volume, label = dataset[0]

        # Check the type first so a non-tensor return fails with a clear message
        # rather than an AttributeError on .shape.
        self.assertIsInstance(volume, torch.Tensor)
        self.assertEqual(volume.shape, (1, *self.boxsize))  # channel dimension added
        self.assertIsInstance(label, dict)
        self.assertEqual(label["class_idx"], 0)

    @patch("copick_torch.copick.CopickDataset._load_data")
    def test_stratified_split(self, mock_load_data):
        """stratified_split yields disjoint 60/20/20 splits that each contain every class."""
        dataset = CopickDataset(config_path=self.mock_config_path, boxsize=self.boxsize)

        # Balanced fixture: class 0 x10, then class 1 x10, then class 2 x10.
        n_classes = 3
        n_samples_per_class = 10
        class_labels = np.repeat(np.arange(n_classes), n_samples_per_class)
        self._populate(
            dataset,
            subvolumes=np.ones((len(class_labels), *self.boxsize)),
            molecule_ids=class_labels,
            keys=[f"class_{i}" for i in range(n_classes)],
        )

        train_ds, val_ds, test_ds = dataset.stratified_split(train_ratio=0.6, val_ratio=0.2, test_ratio=0.2)

        n_total = len(dataset)
        self.assertEqual(len(train_ds), int(0.6 * n_total))
        self.assertEqual(len(val_ds), int(0.2 * n_total))
        self.assertEqual(len(test_ds), n_total - len(train_ds) - len(val_ds))

        splits = {"train": list(train_ds.indices), "val": list(val_ds.indices), "test": list(test_ds.indices)}

        # The splits must partition the dataset: no sample shared, none dropped.
        all_indices = sorted(int(i) for indices in splits.values() for i in indices)
        self.assertEqual(all_indices, list(range(n_total)))

        # Stratification: every class appears in every split.
        for split_name, indices in splits.items():
            with self.subTest(split=split_name):
                split_labels = {int(dataset._molecule_ids[i]) for i in indices}
                for class_idx in range(n_classes):
                    self.assertIn(class_idx, split_labels)


if __name__ == "__main__":
    # Executing this file directly (rather than via a test runner) runs the
    # whole suite defined in this module.
    unittest.main(module="__main__")