Coverage for src / autoencodix / data / _sampler.py: 18%
38 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1import torch
2from torch.utils.data import Sampler
3from typing import Sized, Iterator, List
class BalancedBatchSampler(Sampler[List[int]]):
    """Batch sampler that avoids ending an epoch with a batch of size 1.

    The sampler yields consecutive batches of ``batch_size`` indices, like a
    standard ``BatchSampler``, except when the final batch would contain a
    single index. In that case the last ``batch_size + 1`` indices are
    redistributed:

    * ``batch_size >= 3``: they are split into two batches of roughly equal
      size. For example, 129 samples with a batch size of 128 yield batches
      of sizes [64, 65] instead of [128, 1].
    * ``batch_size == 2``: three indices cannot be split into two batches
      that are both larger than 1, so they are yielded as one batch of 3.

    This is useful for layers such as BatchNorm that require batch sizes
    greater than 1, without dropping data (``drop_last=True``).

    Batches of size 1 are still produced when they cannot be avoided without
    dropping data: when ``batch_size`` is 1, or when the dataset holds a
    single sample.

    Args:
        data_source: The dataset to sample from; only its length is used.
        batch_size: The target number of samples in each batch.
        shuffle: If True, indices are randomly permuted at the start of each
            epoch (each call to ``__iter__``).
    """

    def __init__(self, data_source: Sized, batch_size: int, shuffle: bool = True):
        """Initializes the BalancedBatchSampler.

        Args:
            data_source: The dataset to sample from; only its length is used.
            batch_size: The target number of samples in each batch.
            shuffle: If True, indices are randomly permuted at the start of
                each epoch.

        Raises:
            ValueError: If ``batch_size`` is not a positive integer.
        """
        if not isinstance(batch_size, int) or batch_size <= 0:
            raise ValueError(
                f"batch_size should be a positive integer, but got {batch_size}"
            )

        self.data_source = data_source
        self.batch_size = batch_size
        self.shuffle = shuffle

    def _has_singleton_tail(self, n_samples: int) -> bool:
        """Return True if plain batching would end with a batch of size 1.

        Only applies when there is more than one batch to begin with, so that
        a preceding full batch exists to redistribute with.
        """
        return n_samples > self.batch_size and n_samples % self.batch_size == 1

    def __iter__(self) -> Iterator[List[int]]:
        """Yield the batches of indices for one epoch.

        Yields:
            Lists of dataset indices, one list per batch.
        """
        n_samples = len(self.data_source)
        if n_samples == 0:
            return

        indices = (
            torch.randperm(n_samples) if self.shuffle else torch.arange(n_samples)
        )

        if not self._has_singleton_tail(n_samples):
            # Standard behavior: the last batch has size > 1 (or there is no
            # remainder, or a singleton is unavoidable). Slicing past the end
            # is clamped, so the final batch is simply shorter.
            for start in range(0, n_samples, self.batch_size):
                yield indices[start:start + self.batch_size].tolist()
            return

        # The last full batch plus the lone leftover index form the tail.
        tail_start = n_samples - (self.batch_size + 1)
        for start in range(0, tail_start, self.batch_size):
            yield indices[start:start + self.batch_size].tolist()

        tail = indices[tail_start:]
        if self.batch_size < 3:
            # A tail of 3 cannot be split without producing a batch of size 1,
            # so keep it together as one slightly larger batch.
            yield tail.tolist()
            return

        # Split the tail (batch_size + 1 indices) into two roughly equal halves.
        split_point = (self.batch_size + 1) // 2
        yield tail[:split_point].tolist()
        yield tail[split_point:].tolist()

    def __len__(self) -> int:
        """Return the number of batches ``__iter__`` yields in one epoch."""
        n_samples = len(self.data_source)
        if n_samples == 0:
            return 0

        if self._has_singleton_tail(n_samples):
            full_batches = n_samples // self.batch_size
            # Splitting the tail in two keeps the same count as plain batching
            # (floor + 1); merging it into one batch removes the extra batch.
            return full_batches + 1 if self.batch_size >= 3 else full_batches

        # Standard ceiling division to calculate the number of batches.
        return (n_samples + self.batch_size - 1) // self.batch_size