Coverage for tests/test_copick_dataset.py: 0%
88 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-16 16:14 -0700
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-16 16:14 -0700
1import os
2import shutil
3import tempfile
4import unittest
5from unittest.mock import MagicMock, patch
7import numpy as np
8import torch
10from copick_torch import CopickDataset
class TestCopickDataset(unittest.TestCase):
    """Unit tests for ``CopickDataset``.

    Every test stubs out ``CopickDataset._load_data`` so no copick project or
    tomogram data is loaded; tests that need samples populate the dataset's
    private arrays (``_subvolumes``, ``_molecule_ids``, ``_is_background``,
    ``_keys``) by hand.
    """

    # Single patch target shared by every test so they all stub the same method.
    _LOAD_DATA = "copick_torch.copick.CopickDataset._load_data"

    def setUp(self):
        """Create a scratch directory, test parameters, and seed all RNGs."""
        import random  # local import: only needed to seed Python's own RNG

        # Temporary workspace. addCleanup (unlike tearDown) also runs when a
        # later step of setUp raises, so the directory is never leaked.
        self.test_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, self.test_dir, ignore_errors=True)
        self.cache_dir = os.path.join(self.test_dir, "cache")
        os.makedirs(self.cache_dir, exist_ok=True)

        # Mock config path. The file is never written; this presumably relies
        # on _load_data (patched in every test) being what reads it — confirm.
        self.mock_config_path = os.path.join(self.test_dir, "mock_config.json")

        # Parameters for testing
        self.boxsize = (16, 16, 16)
        self.voxel_spacing = 10.0

        # Seed every RNG an augmentation or split might draw from, so the
        # randomized assertions below are reproducible instead of flaky.
        random.seed(0)
        np.random.seed(0)
        torch.manual_seed(0)

    @patch(_LOAD_DATA)
    def test_init_basic(self, mock_load_data):
        """Test basic initialization of CopickDataset."""
        # Initialize with minimal parameters
        dataset = CopickDataset(
            config_path=self.mock_config_path,
            boxsize=self.boxsize,
            cache_dir=None,  # Don't use caching
        )

        # Verify initialization
        self.assertEqual(dataset.config_path, self.mock_config_path)
        self.assertEqual(dataset.boxsize, self.boxsize)
        self.assertFalse(dataset.augment)
        self.assertIsNone(dataset.cache_dir)

        # Verify _load_data was called exactly once by the constructor
        mock_load_data.assert_called_once()

    @patch(_LOAD_DATA)
    def test_dataset_empty(self, mock_load_data):
        """Test behavior with empty dataset."""
        dataset = CopickDataset(config_path=self.mock_config_path, boxsize=self.boxsize)

        # Mock empty dataset
        dataset._subvolumes = np.array([])
        dataset._molecule_ids = np.array([])
        dataset._is_background = np.array([])
        dataset._keys = []

        # Test length
        self.assertEqual(len(dataset), 0)

        # Test get_class_distribution with empty dataset
        distribution = dataset.get_class_distribution()
        self.assertEqual(distribution, {})

    def test_augmentations(self):
        """Test data augmentation functions."""
        # Create a dataset with mocked _load_data
        with patch(self._LOAD_DATA):
            dataset = CopickDataset(config_path=self.mock_config_path, boxsize=self.boxsize, augment=True)

        # Create a test volume
        test_volume = np.ones(self.boxsize)

        # Test brightness augmentation. Deterministic because setUp seeds the
        # RNGs; the sum must change if a brightness shift was applied.
        augmented = dataset._brightness(test_volume)
        self.assertEqual(augmented.shape, test_volume.shape)
        self.assertNotEqual(np.sum(augmented), np.sum(test_volume))

        # Test intensity scaling
        augmented = dataset._intensity_scaling(test_volume)
        self.assertEqual(augmented.shape, test_volume.shape)

        # Test flip
        augmented = dataset._flip(test_volume)
        self.assertEqual(augmented.shape, test_volume.shape)

        # Test rotate
        augmented = dataset._rotate(test_volume)
        self.assertEqual(augmented.shape, test_volume.shape)

    @patch(_LOAD_DATA)
    def test_getitem_no_augment(self, mock_load_data):
        """Test __getitem__ without augmentation."""
        dataset = CopickDataset(config_path=self.mock_config_path, boxsize=self.boxsize, augment=False)

        # Create a mock dataset with one item
        test_volume = np.ones(self.boxsize)
        dataset._subvolumes = np.array([test_volume])
        dataset._molecule_ids = np.array([0])
        dataset._is_background = np.array([False])
        dataset._keys = ["test_class"]

        # Get the item
        volume, label = dataset[0]

        # Check shapes and types
        self.assertEqual(volume.shape, (1, *self.boxsize))  # Check channel dimension added
        self.assertIsInstance(volume, torch.Tensor)
        self.assertIsInstance(label, dict)  # Verify label is a dictionary
        self.assertEqual(label["class_idx"], 0)  # Check if class_idx is correct

    @patch(_LOAD_DATA)
    def test_stratified_split(self, mock_load_data):
        """Test stratified_split method."""
        dataset = CopickDataset(config_path=self.mock_config_path, boxsize=self.boxsize)

        # Create a mock dataset with balanced classes: class c repeated
        # n_samples_per_class times, for c in 0..n_classes-1.
        n_classes = 3
        n_samples_per_class = 10
        sample_labels = [c for c in range(n_classes) for _ in range(n_samples_per_class)]

        dataset._subvolumes = np.array([np.ones(self.boxsize) for _ in sample_labels])
        dataset._molecule_ids = np.array(sample_labels)
        dataset._is_background = np.array([False] * len(sample_labels))
        dataset._keys = [f"class_{i}" for i in range(n_classes)]

        # Split the dataset
        train_ds, val_ds, test_ds = dataset.stratified_split(train_ratio=0.6, val_ratio=0.2, test_ratio=0.2)

        # Check split sizes
        self.assertEqual(len(train_ds), int(0.6 * len(dataset)))
        self.assertEqual(len(val_ds), int(0.2 * len(dataset)))
        self.assertEqual(len(test_ds), len(dataset) - len(train_ds) - len(val_ds))

        # Normalize indices to plain ints (they may be numpy/torch scalars,
        # and torch scalars do not hash by value).
        splits = {
            "train": [int(i) for i in train_ds.indices],
            "val": [int(i) for i in val_ds.indices],
            "test": [int(i) for i in test_ds.indices],
        }

        # The splits must partition the dataset: every sample used exactly once.
        all_indices = set().union(*splits.values())
        self.assertEqual(len(all_indices), len(dataset))

        # Check that each split contains samples from all classes
        for split_name, indices in splits.items():
            with self.subTest(split=split_name):
                present = {int(dataset._molecule_ids[i]) for i in indices}
                for class_idx in range(n_classes):
                    self.assertIn(class_idx, present)
if __name__ == "__main__":
    # Executed only when this file is run directly (not when imported by a
    # test runner): hand control to unittest's CLI for the tests in this module.
    unittest.main(module="__main__")