Coverage for tests/test_samplers.py: 0%
44 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-16 16:14 -0700
1import unittest
2from collections import Counter
4import numpy as np
5import torch
7from copick_torch import ClassBalancedSampler
class TestClassBalancedSampler(unittest.TestCase):
    """Unit tests for ``copick_torch.ClassBalancedSampler``.

    The fixture is a deliberately imbalanced label list (four samples of
    class 0, two of class 1, one of class 2) so that inverse-frequency
    weighting and rebalanced sampling are observable.
    """

    def setUp(self):
        # Create unbalanced data for testing; class 0 is overrepresented.
        self.labels = [0, 0, 0, 0, 1, 1, 2]

    def test_init(self):
        """Test sampler initialization and weight calculation."""
        sampler = ClassBalancedSampler(self.labels)

        # Check initialization defaults.
        np.testing.assert_array_equal(sampler.labels, np.array(self.labels))
        self.assertEqual(sampler.num_samples, len(self.labels))
        self.assertTrue(sampler.replacement)

        # Check class counts.
        expected_counts = {0: 4, 1: 2, 2: 1}
        self.assertEqual(sampler.class_counts, expected_counts)

        # Each sample's weight should be inversely proportional to its
        # class count ...
        expected_weights = np.array(
            [1 / 4] * 4  # class 0 (4 samples)
            + [1 / 2] * 2  # class 1 (2 samples)
            + [1]  # class 2 (1 sample)
        )
        # ... and normalized to sum to 1.
        expected_weights = expected_weights / expected_weights.sum()

        # Compare with a tolerance to absorb floating-point rounding.
        np.testing.assert_allclose(sampler.weights, expected_weights)

    def test_init_with_custom_samples(self):
        """Test initialization with custom number of samples."""
        custom_samples = 20
        sampler = ClassBalancedSampler(self.labels, num_samples=custom_samples)

        self.assertEqual(sampler.num_samples, custom_samples)
        self.assertEqual(len(sampler), custom_samples)

    def test_init_with_no_replacement(self):
        """Test initialization with replacement=False."""
        sampler = ClassBalancedSampler(self.labels, replacement=False)

        self.assertFalse(sampler.replacement)

    def test_iter(self):
        """Test __iter__ yields in-range indices that rebalance the classes."""
        # Seed both NumPy and torch. The sampler's random source is not
        # visible from this test, and seeding only NumPy would leave a
        # torch-based draw unseeded, making the statistical checks flaky.
        np.random.seed(42)
        torch.manual_seed(42)

        num_draws = 50
        sampler = ClassBalancedSampler(self.labels, num_samples=num_draws)
        indices = list(iter(sampler))

        # Check length.
        self.assertEqual(len(indices), num_draws)

        # Check all indices are within the valid range.
        self.assertTrue(all(0 <= idx < len(self.labels) for idx in indices))

        # Counter returns 0 for classes that were never drawn, so a missing
        # class fails the assertion cleanly instead of raising KeyError.
        sampled_counts = Counter(self.labels[idx] for idx in indices)
        original_counts = Counter(self.labels)
        num_labels = len(self.labels)

        # The minority class (2) should be drawn more often than its
        # original share (1/7) ...
        self.assertGreater(
            sampled_counts[2] / num_draws, original_counts[2] / num_labels
        )
        # ... and the overrepresented class (0) less often than its
        # original share (4/7).
        self.assertLess(
            sampled_counts[0] / num_draws, original_counts[0] / num_labels
        )

    def test_len(self):
        """Test __len__ method returns correct number of samples."""
        sampler = ClassBalancedSampler(self.labels)
        self.assertEqual(len(sampler), len(self.labels))

        # Test with custom number of samples.
        custom_samples = 20
        sampler = ClassBalancedSampler(self.labels, num_samples=custom_samples)
        self.assertEqual(len(sampler), custom_samples)
if __name__ == "__main__":
    # When executed as a script, hand this module to unittest's CLI runner;
    # inside this guard __name__ is "__main__", the runner's default module.
    unittest.main(module=__name__)