Coverage for tests/test_samplers.py: 0%

44 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-16 16:14 -0700

1import unittest 

2from collections import Counter 

3 

4import numpy as np 

5import torch 

6 

7from copick_torch import ClassBalancedSampler 

8 

9 

class TestClassBalancedSampler(unittest.TestCase):
    """Unit tests for ``copick_torch.ClassBalancedSampler``."""

    def setUp(self):
        # Deliberately unbalanced labels: class 0 has 4 samples, class 1 has 2,
        # class 2 has 1, so inverse-frequency weighting has a visible effect.
        self.labels = [0, 0, 0, 0, 1, 1, 2]

    def test_init(self):
        """Test sampler initialization and weight calculation."""
        sampler = ClassBalancedSampler(self.labels)

        # Check initialization defaults
        np.testing.assert_array_equal(sampler.labels, np.array(self.labels))
        self.assertEqual(sampler.num_samples, len(self.labels))
        self.assertTrue(sampler.replacement)

        # Check class counts
        expected_counts = {0: 4, 1: 2, 2: 1}
        self.assertEqual(sampler.class_counts, expected_counts)

        # Each sample's weight is expected to be inversely proportional to the
        # size of its class: class 0 -> 1/4, class 1 -> 1/2, class 2 -> 1/1.
        expected_weights = np.array([1 / 4, 1 / 4, 1 / 4, 1 / 4, 1 / 2, 1 / 2, 1])

        # The sampler is expected to normalize the weights to sum to 1
        expected_weights = expected_weights / expected_weights.sum()

        # assert_allclose tolerates floating-point rounding differences
        np.testing.assert_allclose(sampler.weights, expected_weights)

    def test_init_with_custom_samples(self):
        """Test initialization with custom number of samples."""
        custom_samples = 20
        sampler = ClassBalancedSampler(self.labels, num_samples=custom_samples)

        self.assertEqual(sampler.num_samples, custom_samples)
        self.assertEqual(len(sampler), custom_samples)

    def test_init_with_no_replacement(self):
        """Test initialization with replacement=False."""
        sampler = ClassBalancedSampler(self.labels, replacement=False)

        self.assertFalse(sampler.replacement)

    def test_iter(self):
        """Test that __iter__ yields in-range indices that rebalance the classes."""
        num_samples = 50

        # Seed both RNGs: the sampler may draw with NumPy or with torch, and
        # seeding only one of them would leave this statistical test
        # nondeterministic.
        np.random.seed(42)
        torch.manual_seed(42)

        sampler = ClassBalancedSampler(self.labels, num_samples=num_samples)
        indices = list(iter(sampler))

        # Check length
        self.assertEqual(len(indices), num_samples)

        # Every index must address an element of the original label list
        self.assertTrue(all(0 <= idx < len(self.labels) for idx in indices))

        # Compare each class's share of the resampled indices with its share
        # of the original data (class 0: 4/7, class 1: 2/7, class 2: 1/7).
        original_counts = Counter(self.labels)
        sampled_counts = Counter(self.labels[idx] for idx in indices)

        original_share = {cls: count / len(self.labels) for cls, count in original_counts.items()}
        # Iterate the ORIGINAL classes and index the Counter, which returns 0
        # for a class that was never drawn. A missing class therefore fails
        # the assertions below instead of raising KeyError.
        sampled_share = {cls: sampled_counts[cls] / num_samples for cls in original_counts}

        # The minority class (2) should be over-sampled relative to its share...
        self.assertGreater(sampled_share[2], original_share[2])
        # ...and the majority class (0) under-sampled relative to its share.
        self.assertLess(sampled_share[0], original_share[0])

    def test_len(self):
        """Test __len__ method returns correct number of samples."""
        sampler = ClassBalancedSampler(self.labels)
        self.assertEqual(len(sampler), len(self.labels))

        # Test with custom number of samples
        custom_samples = 20
        sampler = ClassBalancedSampler(self.labels, num_samples=custom_samples)
        self.assertEqual(len(sampler), custom_samples)


if __name__ == "__main__":
    unittest.main()