Coverage for src / autoencodix / data / _stackix_dataset.py: 34%
29 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1from typing import Any, Dict, Optional, Tuple, Union
2import torch
3from autoencodix.base._base_dataset import BaseDataset
4from autoencodix.data._numeric_dataset import NumericDataset
5from autoencodix.configs.default_config import DefaultConfig
class StackixDataset(NumericDataset):
    """
    Dataset for handling multiple modalities in Stackix models.

    This dataset holds individual BaseDataset objects for multiple data modalities
    and provides a consistent interface for accessing them during training.
    It's designed to work specifically with StackixTrainer.

    In addition to keeping the per-modality datasets, the ``data`` tensors of all
    modalities that expose one are concatenated along dim 1 and handed to the
    NumericDataset base class, together with the sample ids, split indices and
    metadata of the first modality.

    Attributes:
        dataset_dict: Dictionary mapping modality names to dataset objects
        modality_keys: List of modality names, in the insertion order of dataset_dict
    """

    def __init__(
        self,
        dataset_dict: Dict[str, BaseDataset],
        config: DefaultConfig,
    ):
        """
        Initialize a StackixDataset instance.

        Args:
            dataset_dict: Dictionary mapping modality names to dataset objects.
                The first entry (in insertion order) supplies ``sample_ids``,
                ``split_indices`` and ``metadata`` for the base class.
            config: Configuration object

        Raises:
            ValueError: If the datasets dictionary is empty or if modality datasets
                have different numbers of samples
            NotImplementedError: If the modality ``data`` tensors cannot be
                concatenated along dim 1 (e.g. incompatible shapes); the original
                torch error is chained as ``__cause__``
        """
        if not dataset_dict:
            raise ValueError("dataset_dict cannot be empty")

        # Validate sample counts *before* concatenating. torch.cat along dim=1
        # fails on mismatched row counts, so checking afterwards made this
        # documented ValueError unreachable and surfaced a misleading
        # NotImplementedError instead.
        sample_counts = [len(dataset) for dataset in dataset_dict.values()]
        if len(set(sample_counts)) > 1:
            raise ValueError(
                "All modality datasets must have the same number of samples"
            )

        # The first modality provides ids, splits and metadata for the base class.
        first_modality = next(iter(dataset_dict.values()))

        # Only modalities exposing a ``data`` attribute take part in the concat.
        modality_tensors = [
            v.data for v in dataset_dict.values() if hasattr(v, "data")
        ]
        try:
            data = torch.cat(modality_tensors, dim=1)
        except Exception as err:
            # Chain the original error so the underlying torch message
            # (e.g. the offending sizes) is preserved in the traceback.
            raise NotImplementedError(
                "Data modalities have different shapes, set requires_paired=True in config"
            ) from err

        super().__init__(
            data=data,
            sample_ids=first_modality.sample_ids,
            config=config,
            split_indices=first_modality.split_indices,
            metadata=first_modality.metadata,
            # NOTE(review): this is one entry per modality (each modality's own
            # feature_ids), not a flat list aligned with the concatenated columns
            # of ``data`` — confirm this is what NumericDataset expects.
            feature_ids=[
                v.feature_ids
                for v in dataset_dict.values()
                if hasattr(v, "feature_ids")
            ],
        )

        self.dataset_dict = dataset_dict
        self.modality_keys = list(dataset_dict.keys())

    def __len__(self) -> int:
        """Return the number of samples in the dataset.

        Taken from the first modality; ``__init__`` rejects modalities whose
        lengths differ.
        """
        return len(next(iter(self.dataset_dict.values())))

    def __getitem__(
        self, index: int
    ) -> Union[Tuple[torch.Tensor, Any], Dict[str, Tuple[torch.Tensor, Any]]]:
        """
        Get the sample at ``index`` from every modality.

        Args:
            index: Index of the sample to retrieve

        Returns:
            Dictionary mapping each modality name to that modality dataset's item
            at ``index`` (annotated as a (data tensor, label) pair).
        """
        return {key: dataset[index] for key, dataset in self.dataset_dict.items()}

    def get_modality_item(self, modality: str, index: int) -> Tuple[torch.Tensor, Any]:
        """
        Get a sample for a specific modality.

        Args:
            modality: The modality name to retrieve data from
            index: Index of the sample to retrieve

        Returns:
            Tuple of (data tensor, label) for the specified modality and sample index

        Raises:
            KeyError: If the requested modality doesn't exist in the dataset
        """
        if modality not in self.dataset_dict:
            raise KeyError(f"Modality '{modality}' not found in dataset")

        return self.dataset_dict[modality][index]