Coverage for src / autoencodix / data / _stackix_dataset.py: 34%

29 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1from typing import Any, Dict, Optional, Tuple, Union 

2import torch 

3from autoencodix.base._base_dataset import BaseDataset 

4from autoencodix.data._numeric_dataset import NumericDataset 

5from autoencodix.configs.default_config import DefaultConfig 

6 

7 

class StackixDataset(NumericDataset):
    """Dataset for handling multiple modalities in Stackix models.

    This dataset holds individual BaseDataset objects for multiple data
    modalities and provides a consistent interface for accessing them during
    training. It's designed to work specifically with StackixTrainer.

    The ``data`` tensor passed to the NumericDataset base class is the
    concatenation along ``dim=1`` of every modality's ``data`` attribute.
    Sample ids, split indices and metadata are taken from the first modality
    in ``dataset_dict``.

    Attributes:
        dataset_dict: Dictionary mapping modality names to dataset objects.
        modality_keys: List of modality names, in ``dataset_dict`` order.
    """

    def __init__(
        self,
        dataset_dict: Dict[str, BaseDataset],
        config: DefaultConfig,
    ):
        """Initialize a StackixDataset instance.

        Args:
            dataset_dict: Dictionary mapping modality names to dataset objects.
            config: Configuration object.

        Raises:
            ValueError: If ``dataset_dict`` is empty or if the modality
                datasets have different numbers of samples.
            NotImplementedError: If the modalities' ``data`` tensors cannot
                be concatenated along dim 1.
        """
        if not dataset_dict:
            raise ValueError("dataset_dict cannot be empty")

        # Validate sample counts *before* concatenating. With mismatched row
        # counts, torch.cat(dim=1) would fail first and mask this error behind
        # the NotImplementedError below, so the documented ValueError would
        # never be raised.
        sample_counts = [len(dataset) for dataset in dataset_dict.values()]
        if any(count != sample_counts[0] for count in sample_counts):
            raise ValueError(
                "All modality datasets must have the same number of samples"
            )

        # The first modality supplies ids/splits/metadata for the base class.
        first_modality = next(iter(dataset_dict.values()))
        try:
            data = torch.cat(
                [v.data for v in dataset_dict.values() if hasattr(v, "data")],
                dim=1,
            )
        except Exception as err:
            # Chain the original torch error so the real cause stays visible.
            raise NotImplementedError(
                "Data modalities have different shapes, set requires_paired=True in config"
            ) from err

        super().__init__(
            data=data,
            sample_ids=first_modality.sample_ids,
            config=config,
            split_indices=first_modality.split_indices,
            metadata=first_modality.metadata,
            # NOTE(review): this is a list with one entry per modality (not a
            # flat list matching the concatenated columns) — confirm this is
            # what NumericDataset expects for feature_ids.
            feature_ids=[
                v.feature_ids
                for v in dataset_dict.values()
                if hasattr(v, "feature_ids")
            ],
        )

        self.dataset_dict = dataset_dict
        self.modality_keys = list(dataset_dict.keys())

    def __len__(self) -> int:
        """Return the number of samples (length of the first modality)."""
        return len(next(iter(self.dataset_dict.values())))

    def __getitem__(
        self, index: int
    ) -> Union[Tuple[torch.Tensor, Any], Dict[str, Tuple[torch.Tensor, Any]]]:
        """Get a single sample from every modality.

        Args:
            index: Index of the sample to retrieve.

        Returns:
            Dictionary mapping each modality name to whatever that modality's
            dataset returns for ``index`` (per the original docs, a
            ``(data tensor, label)`` pair).
        """
        return {key: dataset[index] for key, dataset in self.dataset_dict.items()}

    def get_modality_item(self, modality: str, index: int) -> Tuple[torch.Tensor, Any]:
        """Get a sample for a specific modality.

        Args:
            modality: The modality name to retrieve data from.
            index: Index of the sample to retrieve.

        Returns:
            Whatever the modality's dataset returns for ``index`` (per the
            original docs, a ``(data tensor, label)`` pair).

        Raises:
            KeyError: If the requested modality doesn't exist in the dataset.
        """
        if modality not in self.dataset_dict:
            raise KeyError(f"Modality '{modality}' not found in dataset")

        return self.dataset_dict[modality][index]