Coverage for copick_torch/samplers.py: 42%

19 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-16 16:14 -0700

1from collections import Counter 

2 

3import numpy as np 

4import torch 

5from torch.utils.data import Sampler 

6 

7 

class ClassBalancedSampler(Sampler):
    """
    A sampler that balances class distributions during training.

    Each dataset index is drawn with probability inversely proportional to the
    size of its class, so every class carries the same total probability mass
    per draw. This counteracts class imbalance in the produced mini-batches.

    Attributes:
        labels: 1-D NumPy array of class labels, one per dataset index.
        num_samples: Number of indices yielded per iteration.
        replacement: Whether indices are drawn with replacement.
        class_counts: ``Counter`` mapping each class label to its frequency.
        weights: Per-index sampling probabilities (sum to 1 when non-empty).

    Note:
        Indices are drawn with NumPy's global RNG (``np.random.choice``), so
        use ``np.random.seed`` for reproducibility; ``torch.manual_seed`` does
        not affect this sampler.
    """

    def __init__(self, labels, num_samples=None, replacement=True):
        """
        Initialize the class-balanced sampler.

        Args:
            labels: List, NumPy array, or tensor (on any device) of class
                labels, one per sample in the dataset.
            num_samples: Number of indices to draw per iteration
                (default: len(labels)).
            replacement: Whether to sample with replacement (default: True).

        Raises:
            TypeError: If ``num_samples`` is not an integer.
            ValueError: If ``labels`` is not one-dimensional, if
                ``num_samples`` is negative, if samples are requested from an
                empty label set, or if ``replacement`` is False and
                ``num_samples`` exceeds the number of labels.
        """
        import operator

        # np.array() cannot read CUDA tensors directly; bring them to the host.
        if isinstance(labels, torch.Tensor):
            labels = labels.detach().cpu().numpy()
        self.labels = np.array(labels)
        if self.labels.ndim != 1:
            raise ValueError(f"labels must be one-dimensional, got shape {self.labels.shape}")
        num_labels = len(self.labels)

        if num_samples is None:
            num_samples = num_labels
        # operator.index accepts int-like values (int, np.integer, 0-d int
        # tensors) and rejects floats, which would otherwise break __len__.
        self.num_samples = operator.index(num_samples)
        if self.num_samples < 0:
            raise ValueError(f"num_samples must be non-negative, got {self.num_samples}")
        if self.num_samples > 0 and num_labels == 0:
            raise ValueError("cannot draw samples from an empty set of labels")
        if not replacement and self.num_samples > num_labels:
            raise ValueError(
                f"num_samples ({self.num_samples}) cannot exceed the number of "
                f"labels ({num_labels}) when replacement=False"
            )
        self.replacement = replacement

        # Count occurrences of each class
        self.class_counts = Counter(self.labels)

        # Weight each sample by 1 / (size of its class) so that every class
        # contributes the same total probability mass.
        weight_per_class = {label: 1.0 / count for label, count in self.class_counts.items()}
        weights = np.array([weight_per_class[label] for label in self.labels], dtype=np.float64)

        # Normalize to a probability vector (np.random.choice requires sum == 1);
        # guard the empty case to avoid a 0 / 0 division.
        total = weights.sum()
        self.weights = weights / total if total > 0 else weights

    def __iter__(self):
        """
        Generate a random sequence of indices based on weighted sampling.

        Returns:
            Iterator over ``num_samples`` Python ``int`` indices into ``labels``.
        """
        if self.num_samples == 0:
            # np.random.choice rejects an empty probability vector; nothing to draw.
            return iter([])
        # Generate random indices weighted by class distribution
        indices = np.random.choice(
            len(self.labels),
            size=self.num_samples,
            replace=self.replacement,
            p=self.weights,
        )
        return iter(indices.tolist())

    def __len__(self):
        """
        Return the number of samples in the sampler.

        Returns:
            Number of indices yielded per iteration.
        """
        return self.num_samples