Coverage for src / autoencodix / data / _multimodal_dataset.py: 12%
187 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1import torch
2import warnings
3import pandas as pd
4from typing import List, Dict, Any, Optional, Union
5from autoencodix.base._base_dataset import BaseDataset
6from autoencodix.configs.default_config import DefaultConfig
8import numpy as np
class MultiModalDataset(BaseDataset, torch.utils.data.Dataset):  # type: ignore
    """Handles multiple datasets of different modalities.

    Attributes:
        datasets: Dictionary of datasets for each modality.
        modalities: Modality names, in the iteration order of ``datasets``.
        n_modalities: Number of modalities.
        sample_to_modalities: Mapping from sample IDs to available modalities.
        sample_ids: List of all unique sample IDs across modalities.
        config: Configuration object.
        data: Data from the first modality (for compatibility).
        feature_ids: Feature IDs (currently None, to be implemented).
        _id_to_idx: Reverse lookup tables for sample IDs to indices per modality.
        paired_sample_ids: List of sample IDs that have data in all modalities.
        unpaired_sample_ids: List of sample IDs that do not have data in all modalities.
    """

    def __init__(self, datasets: Dict[str, BaseDataset], config: DefaultConfig):
        """Initialize the MultiModalDataset.

        Args:
            datasets: Dictionary of datasets for each modality.
            config: Configuration object.

        Raises:
            ValueError: If ``datasets`` is empty or if any dataset has
                ``sample_ids`` set to None.
        """
        if not datasets:
            raise ValueError("At least one modality dataset is required")
        # Validate up front: every step below iterates over sample_ids, so a
        # None value must be reported here rather than surfacing later as an
        # unrelated TypeError.
        for ds_name, ds in datasets.items():
            if ds.sample_ids is None:
                raise ValueError(f"There are no sample_ids for {ds_name}")

        self.datasets = datasets
        self.modalities = list(datasets.keys())
        self.n_modalities = len(self.datasets)
        self.sample_to_modalities = self._build_sample_map()
        self.sample_ids: List[Any] = list(self.sample_to_modalities.keys())
        self.config = config
        self.data = next(iter(self.datasets.values())).data
        self.feature_ids = None  # TODO

        # Build reverse lookup tables once
        self._id_to_idx = {
            mod: {sid: idx for idx, sid in enumerate(ds.sample_ids)}  # type: ignore
            for mod, ds in self.datasets.items()
        }
        self.paired_sample_ids = self._get_paired_sample_ids()
        # Keep first-seen order instead of a set difference so the list is
        # identical from run to run.
        paired_lookup = set(self.paired_sample_ids)
        self.unpaired_sample_ids = [
            sid for sid in self.sample_ids if sid not in paired_lookup
        ]

    def _to_df(self, modality: Optional[str] = None) -> pd.DataFrame:
        """Convert the dataset to a pandas DataFrame.

        Each modality becomes a frame indexed by its sample IDs whose columns
        are prefixed with ``"<modality>_"``. Frames of several modalities are
        joined column-wise on the samples they have in common (inner join).

        Args:
            modality: Name of a single modality to convert. If None, all
                modalities are converted and joined.

        Returns:
            DataFrame representation of the dataset

        Raises:
            ValueError: If a requested modality is unknown.
            TypeError: If a modality's data is neither a torch.Tensor nor a
                list of torch.Tensor.
        """
        requested = list(self.datasets.keys()) if modality is None else [modality]

        frames = []
        for mod_name in requested:
            if mod_name not in self.datasets:
                raise ValueError(f"Unknown modality: {mod_name}")

            ds = self.datasets[mod_name]
            if isinstance(ds.data, torch.Tensor):
                df = pd.DataFrame(
                    # detach/cpu so tensors on an accelerator or tracked by
                    # autograd can be converted as well.
                    ds.data.detach().cpu().numpy(),
                    columns=ds.feature_ids,
                    index=ds.sample_ids,
                )
            elif isinstance(ds.data, list):
                # Handle image modality: one flattened row per image.
                tensor_list = ds.data
                if not isinstance(tensor_list[0], torch.Tensor):
                    raise TypeError(
                        f" Image List is not a List[torch.Tensor], but a {type(tensor_list[0])} and cannot be converted to DataFrame."
                    )

                rows = [
                    (
                        t.detach().flatten().cpu().numpy()
                        if isinstance(t, torch.Tensor)
                        else t.flatten()
                    )
                    for t in tensor_list
                ]

                df = pd.DataFrame(
                    rows,
                    index=ds.sample_ids,
                    columns=["Pixel_" + str(i) for i in range(len(rows[0]))],
                )
            else:
                raise TypeError(
                    f"Data is not a torch.Tensor or image data, but a {type(ds.data)} and cannot be converted to DataFrame."
                )

            frames.append(df.add_prefix(f"{mod_name}_"))

        if not frames:
            return pd.DataFrame()
        if len(frames) == 1:
            return frames[0]
        return pd.concat(frames, axis=1, join="inner")

    def _build_sample_map(self):
        """Map every sample ID to the set of modalities that contain it."""
        sample_to_mods = {}
        for modality, dataset in self.datasets.items():
            for sid in dataset.sample_ids:
                sample_to_mods.setdefault(sid, set()).add(modality)
        return sample_to_mods

    def _get_paired_sample_ids(self):
        """Return the sample IDs that are present in every modality."""
        return [
            sid
            for sid, mods in self.sample_to_modalities.items()
            if all(mod in mods for mod in self.datasets)
        ]

    def __len__(self):
        """Return the number of fully paired samples."""
        return len(self.paired_sample_ids)

    def __getitem__(self, idx: Union[int, str]):
        """Fetch one sample across all modalities.

        Args:
            idx: An ``int`` is treated as a position in ``paired_sample_ids``;
                any other value is used directly as a sample ID.

        Returns:
            Dict with key ``"sample_id"`` plus one key per modality. The value
            is the second element of the item returned by that modality's
            dataset, or None if the sample is missing from the modality.
        """
        sid = self.paired_sample_ids[idx] if isinstance(idx, int) else idx
        out = {"sample_id": sid}
        for mod in self.modalities:
            position = self._id_to_idx[mod].get(sid)
            if position is None:  # missing modality
                out[mod] = None
                continue
            _, data, _ = self.datasets[mod][position]
            out[mod] = data
        return out

    @property
    def is_fully_paired(self) -> bool:
        """Returns True if all samples are fully paired across all modalities (no unpaired samples)."""
        return len(self.unpaired_sample_ids) == 0
class CoverageEnsuringSampler(torch.utils.data.Sampler):  # type: ignore
    """
    Sampler that ensures all samples are seen at least once per epoch for each modality.

    Each iteration yields lists of sample IDs (one list per batch): first the
    coverage batches, then additional random batches.

    Attributes:
        dataset: The MultiModalDataset to sample from.
        paired_ids: List of sample IDs that have data in all modalities.
        unpaired_ids: List of sample IDs that do not have data in all modalities.
        batch_size: Number of samples per batch.
        paired_ratio: Ratio of paired samples in each batch.
        modality_samples: Dictionary mapping each modality to its list of sample IDs.
    """

    def __init__(
        self, multimodal_dataset: "MultiModalDataset", paired_ratio=0.5, batch_size=64
    ):
        """
        Initialize the sampler.

        Args:
            multimodal_dataset: The MultiModalDataset to sample from.
            paired_ratio: Ratio of paired samples in each batch.
            batch_size: Number of samples per batch.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            # A non-positive batch size can never make progress when building
            # coverage batches, so fail fast instead of looping forever.
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        self.dataset = multimodal_dataset
        self.paired_ids = multimodal_dataset.paired_sample_ids
        self.unpaired_ids = multimodal_dataset.unpaired_sample_ids
        self.batch_size = batch_size
        self.paired_ratio = paired_ratio

        total_paired = len(self.paired_ids)
        total_unpaired = len(self.unpaired_ids)

        if total_paired == 0:
            self.paired_ratio = 0.0
        elif total_unpaired == 0:
            self.paired_ratio = 1.0
        else:
            # Use requested ratio, but never more than the share of paired
            # samples that actually exists in the data.
            max_possible_paired = total_paired / (total_paired + total_unpaired)
            self.paired_ratio = min(paired_ratio, max_possible_paired)

        # Build modality-specific sample lists
        self.modality_samples = {
            modality: ds.sample_ids
            for modality, ds in multimodal_dataset.datasets.items()
        }

    def __iter__(self):
        """Yield batches of sample IDs; single-element batches are padded to two.

        Raises:
            ValueError: If a single-element batch cannot be padded because the
                dataset holds only one unique sample.
        """
        coverage_batches = self._generate_coverage_batches()
        random_batches = self._generate_random_batches(coverage_batches)
        for batch in coverage_batches + random_batches:
            if len(batch) > 1:
                yield batch
            elif len(batch) == 1:
                current_sample = batch[0]
                # Ordered candidate list (not a set) so the draw is
                # reproducible under a fixed NumPy seed.
                candidate_pool = [
                    s
                    for s in (*self.paired_ids, *self.unpaired_ids)
                    if s != current_sample
                ]

                if not candidate_pool:
                    raise ValueError(
                        "Cannot form a batch of size > 1 because the dataset contains "
                        "only a single unique sample. This should not happen; "
                        "check the size of your data."
                    )
                sample_to_add = np.random.choice(candidate_pool)
                batch.append(sample_to_add)
                warnings.warn(
                    "Your combination of batch_size and number of samples whil create a batch of len 1, this will fail all model with a BatchNorm Layer,chose another batch_size to avoid this. We handled this by adding random samples from your data to this 'problem' batch to the current batch. This is an extremely rare case, for our Custom Sampler for unpaired XModalix we don't support this."
                )
                yield batch

    def _generate_coverage_batches(self):
        """Generate batches that ensure all samples are covered

        Every modality contributes not-yet-covered samples to each batch until
        all of its sample IDs have been drawn once; the rest of each batch is
        filled with random paired/unpaired samples.

        Returns:
            List of batches ensuring coverage of all samples
        """
        coverage_batches = []
        n_modalities = len(self.modality_samples)

        # Pending (not yet covered) IDs per modality, de-duplicated in order so
        # repeated IDs in a modality cannot prevent the loop from finishing.
        pending = {
            mod: list(dict.fromkeys(ids)) for mod, ids in self.modality_samples.items()
        }

        while any(pending.values()):
            batch = []
            batch_set = set()  # Track unique samples in current batch

            for todo in pending.values():
                room = self.batch_size - len(batch)
                if room <= 0:
                    break

                # Select samples that aren't already in the batch
                available = [s for s in todo if s not in batch_set]
                if not available:
                    continue

                # Share the remaining room between modalities, but always take
                # at least one sample: a quota of 0 (batch_size smaller than
                # the number of modalities) would otherwise never cover
                # anything and the outer loop would not terminate.
                take = min(len(available), room, max(1, room // n_modalities))
                selected = np.random.choice(available, size=take, replace=False)
                batch.extend(selected)
                batch_set.update(selected)
                chosen = set(selected)
                todo[:] = [s for s in todo if s not in chosen]

            # Fill remaining batch slots with random samples, avoiding duplicates
            while len(batch) < self.batch_size:
                wants_paired = bool(self.paired_ids) and (
                    len(batch) < self.batch_size * self.paired_ratio
                )
                if wants_paired:
                    source = self.paired_ids
                elif self.unpaired_ids:
                    source = self.unpaired_ids
                else:
                    break

                # Prefer samples not yet in the batch; if none are left,
                # allow repeats from the same source.
                candidate_pool = [s for s in source if s not in batch_set] or source
                sample = np.random.choice(candidate_pool)
                batch.append(sample)
                batch_set.add(sample)

            if batch:
                coverage_batches.append(batch)

        return coverage_batches

    def _generate_random_batches(self, coverage_batches: List[Any]):
        """Generate additional random batches to fill the epoch

        Args:
            coverage_batches: Batches already generated to ensure coverage
        Returns:
            List of additional random batches
        """
        total_samples = len(self.paired_ids) + len(self.unpaired_ids)
        covered_samples = sum(len(batch) for batch in coverage_batches)
        remaining_samples = max(0, total_samples - covered_samples)

        random_batches = []
        num_random_batches = remaining_samples // self.batch_size
        paired_needed = int(self.batch_size * self.paired_ratio)

        for _ in range(num_random_batches):
            batch = []

            # Add paired samples
            if paired_needed > 0 and self.paired_ids:
                paired_samples = np.random.choice(
                    self.paired_ids,
                    size=min(paired_needed, len(self.paired_ids)),
                    replace=True,
                )
                batch.extend(paired_samples)

            # Add unpaired samples
            unpaired_needed = self.batch_size - len(batch)
            if unpaired_needed > 0 and self.unpaired_ids:
                unpaired_samples = np.random.choice(
                    self.unpaired_ids,
                    size=min(unpaired_needed, len(self.unpaired_ids)),
                    replace=True,
                )
                batch.extend(unpaired_samples)

            if batch:
                random_batches.append(batch)

        return random_batches

    def __len__(self):
        """Return a batch-count estimate.

        The value is computed from the sample counts only; it is not derived
        from the batches that ``__iter__`` actually generates.
        """
        total_samples = len(self.paired_ids) + len(self.unpaired_ids)
        return max(total_samples // self.batch_size, len(self.modality_samples))
def create_multimodal_collate_fn(multimodal_dataset: "MultiModalDataset"):
    """
    Factory function to create a collate function with access to the dataset.
    This allows us to get metadata and original indices.

    Args:
        multimodal_dataset: The multimodal dataset
    Returns:
        A collate function for DataLoader
    """

    def collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
        """Group a list of per-sample dicts into one entry per modality.

        Args:
            batch: Items with a ``"sample_id"`` key and one key per modality
                (value None when the sample lacks that modality).

        Returns:
            Empty dict for an empty batch; otherwise, per modality, a dict
            with ``"data"`` (stacked tensor), ``"sample_ids"``,
            ``"sampled_index"`` and ``"class_labels"``.

        Raises:
            ValueError: If no sample in the batch has data for a modality.
        """
        if not batch:
            return {}

        result = {}
        class_col = multimodal_dataset.config.class_param
        for modality in multimodal_dataset.modalities:
            dataset = multimodal_dataset.datasets[modality]
            # The attribute may exist but be None; only read labels when a
            # class column is configured and metadata is really there.
            metadata = getattr(dataset, "metadata", None)
            has_metadata = bool(class_col) and metadata is not None

            # Collect only for samples with this modality
            relevant_samples = [s for s in batch if s.get(modality) is not None]
            if not relevant_samples:
                raise ValueError(f"Modality {modality} has no data in batch")

            sample_ids = [s["sample_id"] for s in relevant_samples]
            id_lookup = multimodal_dataset._id_to_idx[modality]
            sampled_index = [id_lookup.get(sid, None) for sid in sample_ids]

            if has_metadata:
                class_labels: List[Any] = [
                    metadata.at[sid, class_col] for sid in sample_ids
                ]
            else:
                class_labels = [None] * len(relevant_samples)

            result[modality] = {
                "data": torch.stack([s[modality] for s in relevant_samples]),
                "sample_ids": sample_ids,
                "sampled_index": sampled_index,
                "class_labels": class_labels,  # List; convert to tensor if needed for loss
            }
        return result

    return collate_fn