Coverage for src / autoencodix / data / _sampler.py: 18%

38 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1import torch 

2from torch.utils.data import Sampler 

3from typing import Sized, Iterator, List 

4 

5 

class BalancedBatchSampler(Sampler[List[int]]):
    """Batch sampler that never ends an epoch with a batch of size 1.

    Batches are produced like a plain sequential/shuffled batch sampler
    (chunks of ``batch_size`` indices, last chunk possibly shorter), except
    when the last chunk would contain exactly one index. In that case the
    final ``batch_size + 1`` indices are redistributed:

    * ``batch_size >= 3``: they are split into two batches of roughly equal
      size. For example, 129 samples with ``batch_size=128`` yield batches
      of sizes ``[64, 65]`` instead of ``[128, 1]``.
    * ``batch_size == 2``: three indices cannot be split into two batches
      that both have more than one element, so they are yielded as a single
      batch of size 3 (one larger than ``batch_size``).

    This avoids singleton batches (problematic for layers such as
    BatchNorm) without dropping data.

    Note:
        With ``batch_size == 1`` every batch has size 1 by definition; no
        redistribution is possible or attempted. A dataset with a single
        sample likewise yields one batch of size 1.

    Args:
        data_source: The dataset to sample from; only ``len()`` is used.
        batch_size: The target number of samples in each batch.
        shuffle: If True, indices are drawn from a fresh random permutation
            on every iteration; otherwise they are yielded in order.

    Raises:
        ValueError: If ``batch_size`` is not a positive integer.
    """

    def __init__(self, data_source: Sized, batch_size: int, shuffle: bool = True):
        """Validate and store the sampler configuration.

        Args:
            data_source: The dataset to sample from; only ``len()`` is used.
            batch_size: The target number of samples in each batch.
            shuffle: If True, indices are reshuffled on every iteration.

        Raises:
            ValueError: If ``batch_size`` is not a positive integer.
        """
        if not isinstance(batch_size, int) or batch_size <= 0:
            raise ValueError(
                f"batch_size should be a positive integer, but got {batch_size}"
            )

        self.data_source = data_source
        self.batch_size = batch_size
        self.shuffle = shuffle

    def _has_singleton_tail(self, n_samples: int) -> bool:
        """Return True if plain chunking would end with a batch of size 1."""
        # Requires at least one full batch before the leftover sample;
        # for batch_size == 1 the remainder is always 0, so this is False.
        return n_samples > self.batch_size and n_samples % self.batch_size == 1

    def __iter__(self) -> Iterator[List[int]]:
        """Yield lists of dataset indices, one list per batch.

        Yields:
            Lists of integer indices. Every index in
            ``range(len(data_source))`` appears in exactly one batch.
        """
        n_samples = len(self.data_source)
        if n_samples == 0:
            return

        # Materialise the index order once as a Python list so that each
        # batch is a cheap list slice.
        if self.shuffle:
            indices = torch.randperm(n_samples).tolist()
        else:
            indices = torch.arange(n_samples).tolist()

        if not self._has_singleton_tail(n_samples):
            # Standard behaviour: the last batch is either full or has more
            # than one element (or a singleton is unavoidable).
            for start in range(0, n_samples, self.batch_size):
                yield indices[start:start + self.batch_size]
            return

        # Everything before the last full batch is emitted unchanged; the
        # last full batch plus the leftover sample form the tail.
        tail_start = (n_samples // self.batch_size - 1) * self.batch_size
        for start in range(0, tail_start, self.batch_size):
            yield indices[start:start + self.batch_size]

        tail = indices[tail_start:]  # exactly batch_size + 1 indices
        if self.batch_size == 2:
            # Splitting 3 indices would recreate a singleton batch, so keep
            # them together.
            yield tail
            return

        split_point = (self.batch_size + 1) // 2
        yield tail[:split_point]
        yield tail[split_point:]

    def __len__(self) -> int:
        """Return the number of batches ``__iter__`` yields per epoch."""
        n_samples = len(self.data_source)
        if n_samples == 0:
            return 0

        full_batches = n_samples // self.batch_size
        if self._has_singleton_tail(n_samples):
            # batch_size == 2 merges the tail into one batch; otherwise the
            # tail is split in two, i.e. one batch more than the full count.
            return full_batches if self.batch_size == 2 else full_batches + 1

        # Ceiling division for the standard chunking.
        return (n_samples + self.batch_size - 1) // self.batch_size