Coverage for copick_torch/samplers.py: 42%
19 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-16 16:14 -0700
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-16 16:14 -0700
1from collections import Counter
3import numpy as np
4import torch
5from torch.utils.data import Sampler
class ClassBalancedSampler(Sampler):
    """
    A sampler that balances class distributions during training.

    Each sample index is drawn with probability inversely proportional to the
    frequency of its class. As a result, every distinct class receives the same
    total probability mass (1 / number_of_classes) on each draw. Rare classes
    therefore show up about as often as common ones in the mini-batches, which
    addresses the class imbalance problem.

    Attributes:
        labels: 1-D numpy array copy of the labels passed to the constructor.
        num_samples: Number of indices yielded per iteration (also ``len(self)``).
        replacement: Whether indices are drawn with replacement.
        rng: Optional numpy random generator used for drawing. ``None`` means the
            global ``np.random`` state is used.
        class_counts: ``Counter`` mapping each label value to its frequency.
        weights: Per-sample drawing probabilities aligned with ``labels``; they
            sum to 1.
    """

    def __init__(self, labels, num_samples=None, replacement=True, *, rng=None):
        """
        Initialize the class-balanced sampler.

        Args:
            labels: List, array, or tensor (anything ``np.array`` can convert to
                a 1-D array) of class labels, one per sample.
            num_samples: Number of samples to draw per iteration
                (default: len(labels)).
            replacement: Whether to sample with replacement (default: True).
            rng: Optional ``np.random.Generator`` (or ``np.random.RandomState``)
                used for sampling. Pass a seeded generator to get reproducible
                index sequences without touching global random state.
                Defaults to the global ``np.random`` module.

        Raises:
            ValueError: If ``labels`` is empty or not one-dimensional, if
                ``num_samples`` is not a non-negative integer, or if
                ``replacement`` is False and ``num_samples`` exceeds the number
                of labels (sampling without replacement cannot draw more
                indices than exist).
        """
        self.labels = np.array(labels)
        if self.labels.ndim != 1:
            raise ValueError(f"labels must be one-dimensional, got shape {self.labels.shape}")
        if self.labels.size == 0:
            raise ValueError("labels must not be empty")

        population = len(self.labels)
        if num_samples is None:
            num_samples = population
        elif isinstance(num_samples, bool) or not isinstance(num_samples, (int, np.integer)):
            raise ValueError(f"num_samples must be an integer, got {num_samples!r}")
        elif num_samples < 0:
            raise ValueError(f"num_samples must be non-negative, got {num_samples}")

        # Fail at construction time rather than on first iteration, where numpy
        # would raise from inside the DataLoader loop.
        if not replacement and num_samples > population:
            raise ValueError(
                f"cannot draw num_samples={num_samples} without replacement from {population} labels"
            )

        self.num_samples = int(num_samples)
        self.replacement = replacement
        self.rng = rng

        # Count occurrences of each class (kept as a public Counter attribute).
        self.class_counts = Counter(self.labels)

        # Per-sample weight = 1 / (size of that sample's class), vectorized:
        # np.unique maps each label to its class slot and returns per-class counts.
        _, class_slot, counts = np.unique(self.labels, return_inverse=True, return_counts=True)
        weights = 1.0 / counts[class_slot]

        # Normalize weights to sum to 1. Each class's weights then sum to
        # 1 / number_of_classes.
        self.weights = weights / weights.sum()

    def __iter__(self):
        """
        Generate a random sequence of indices based on weighted sampling.

        Without replacement, the class balance is only approximate: indices
        already drawn are removed from the pool for later draws.

        Returns:
            Iterator over ``num_samples`` integer indices into ``labels``.
        """
        # getattr keeps instances created without ``rng`` (e.g. unpickled from
        # an older version of this class) working with the global numpy state.
        source = getattr(self, "rng", None)
        if source is None:
            source = np.random

        # Generate random indices weighted by class distribution
        indices = source.choice(len(self.labels), size=self.num_samples, replace=self.replacement, p=self.weights)

        # tolist() yields plain Python ints rather than numpy integer scalars.
        return iter(indices.tolist())

    def __len__(self):
        """
        Return the number of samples in the sampler.

        Returns:
            Number of indices produced by one iteration.
        """
        return self.num_samples