Coverage for src / autoencodix / data / _multimodal_dataset.py: 12%

187 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1import torch 

2import warnings 

3import pandas as pd 

4from typing import List, Dict, Any, Optional, Union 

5from autoencodix.base._base_dataset import BaseDataset 

6from autoencodix.configs.default_config import DefaultConfig 

7 

8import numpy as np 

9 

10 

class MultiModalDataset(BaseDataset, torch.utils.data.Dataset):  # type: ignore
    """Handles multiple datasets of different modalities.

    Attributes:
        datasets: Dictionary of datasets for each modality.
        modalities: Modality names, in the iteration order of ``datasets``.
        n_modalities: Number of modalities.
        sample_to_modalities: Mapping from sample IDs to the set of
            modalities that contain the sample.
        sample_ids: List of all unique sample IDs across modalities.
        config: Configuration object.
        data: Data from the first modality (for compatibility).
        feature_ids: Feature IDs (currently None, to be implemented).
        _id_to_idx: Reverse lookup tables for sample IDs to indices per modality.
        paired_sample_ids: List of sample IDs that have data in all modalities.
        unpaired_sample_ids: List of sample IDs that do not have data in all
            modalities, in the same relative order as ``sample_ids``.
    """

    def __init__(self, datasets: Dict[str, BaseDataset], config: DefaultConfig):
        """
        Initialize the MultiModalDataset.

        Args:
            datasets: Dictionary of datasets for each modality.
            config: Configuration object.

        Raises:
            ValueError: If ``datasets`` is empty or any dataset has
                ``sample_ids`` set to None.
        """
        if not datasets:
            raise ValueError("At least one modality dataset is required")
        # Validate up front: every step below iterates over ``sample_ids``,
        # so a missing list must be reported before the sample map is built.
        for ds_name, ds in datasets.items():
            if ds.sample_ids is None:
                raise ValueError(f"There are no sample_ids for {ds_name}")

        self.datasets = datasets
        self.modalities = list(datasets.keys())
        self.n_modalities = len(datasets)
        self.sample_to_modalities = self._build_sample_map()
        self.sample_ids: List[Any] = list(self.sample_to_modalities.keys())
        self.config = config
        self.data = next(iter(self.datasets.values())).data
        self.feature_ids = None  # TODO

        # Build reverse lookup tables once
        self._id_to_idx = {
            mod: {sid: idx for idx, sid in enumerate(ds.sample_ids)}  # type: ignore
            for mod, ds in self.datasets.items()
        }
        self.paired_sample_ids = self._get_paired_sample_ids()
        # Filter the ordered ID list instead of taking a set difference so
        # the result has a reproducible order.
        paired_lookup = set(self.paired_sample_ids)
        self.unpaired_sample_ids = [
            sid for sid in self.sample_ids if sid not in paired_lookup
        ]

    def _to_df(self, modality: Optional[str] = None) -> pd.DataFrame:
        """Convert the dataset to a pandas DataFrame.

        Each modality becomes a block of columns prefixed with
        ``"<modality>_"``. Several modalities are joined column-wise on the
        sample IDs they share (inner join).

        Args:
            modality: Name of a single modality to convert. When None, all
                modalities are converted and joined.

        Returns:
            DataFrame representation of the dataset

        Raises:
            ValueError: If ``modality`` is not a known modality.
            TypeError: If a modality's data is neither a ``torch.Tensor`` nor
                a list whose first element is a ``torch.Tensor``.
        """
        if modality is None:
            selected = list(self.datasets.keys())
        else:
            selected = [modality]

        frames: List[pd.DataFrame] = []
        for mod_name in selected:
            if mod_name not in self.datasets:
                raise ValueError(f"Unknown modality: {mod_name}")

            ds = self.datasets[mod_name]
            if isinstance(ds.data, torch.Tensor):
                # detach + cpu so tensors that track gradients or live on an
                # accelerator can be converted as well.
                df = pd.DataFrame(
                    ds.data.detach().cpu().numpy(),
                    columns=ds.feature_ids,
                    index=ds.sample_ids,
                )
            elif isinstance(ds.data, list):
                # Handle image modality: one flattened row per list entry.
                tensor_list = ds.data
                if not isinstance(tensor_list[0], torch.Tensor):
                    raise TypeError(
                        f" Image List is not a List[torch.Tensor], but a {type(tensor_list[0])} and cannot be converted to DataFrame."
                    )

                rows = [
                    (
                        t.detach().flatten().cpu().numpy()
                        if isinstance(t, torch.Tensor)
                        else t.flatten()
                    )
                    for t in tensor_list
                ]

                df = pd.DataFrame(
                    rows,
                    index=ds.sample_ids,
                    columns=["Pixel_" + str(i) for i in range(len(rows[0]))],
                )
            else:
                raise TypeError(
                    f"Data is not a torch.Tensor or image data, but a {type(ds.data)} and cannot be converted to DataFrame."
                )

            frames.append(df.add_prefix(f"{mod_name}_"))

        if not frames:
            return pd.DataFrame()
        if len(frames) == 1:
            return frames[0]
        # One concat over all frames keeps only the shared sample IDs.
        return pd.concat(frames, axis=1, join="inner")

    def _build_sample_map(self) -> Dict[Any, set]:
        """Map every sample ID to the set of modalities that contain it."""
        sample_to_mods: Dict[Any, set] = {}
        for modality, dataset in self.datasets.items():
            for sid in dataset.sample_ids:
                sample_to_mods.setdefault(sid, set()).add(modality)
        return sample_to_mods

    def _get_paired_sample_ids(self) -> List[Any]:
        """Return the sample IDs that are present in every modality."""
        return [
            sid
            for sid, mods in self.sample_to_modalities.items()
            if all(mod in mods for mod in self.datasets)
        ]

    def __len__(self) -> int:
        """Return the number of paired samples."""
        return len(self.paired_sample_ids)

    def __getitem__(self, idx: Union[int, str]) -> Dict[str, Any]:
        """Fetch one sample across all modalities.

        Args:
            idx: A Python ``int`` is treated as a position in
                ``paired_sample_ids``; any other value is treated as a
                sample ID.

        Returns:
            Dictionary with the key ``"sample_id"`` plus one key per
            modality. A modality that does not contain the sample maps to
            None.
        """
        # NOTE: the check is deliberately limited to the builtin ``int``;
        # other values (including numpy scalars) are looked up as sample IDs.
        sid = self.paired_sample_ids[idx] if isinstance(idx, int) else idx
        out = {"sample_id": sid}
        for mod in self.modalities:
            if sid not in self._id_to_idx[mod]:  # missing modality
                out[mod] = None
                continue
            # The per-modality dataset yields a 3-tuple; only the middle
            # element is kept.
            _, data, _ = self.datasets[mod][self._id_to_idx[mod][sid]]
            out[mod] = data
        return out

    @property
    def is_fully_paired(self) -> bool:
        """Returns True if all samples are fully paired across all modalities (no unpaired samples)."""
        return len(self.unpaired_sample_ids) == 0

147 

148 

class CoverageEnsuringSampler(torch.utils.data.Sampler):  # type: ignore
    """
    Sampler that ensures all samples are seen at least once per epoch for each modality.

    Batches are lists of sample IDs (not positional indices).

    Attributes:
        dataset: The MultiModalDataset to sample from.
        paired_ids: List of sample IDs that have data in all modalities.
        unpaired_ids: List of sample IDs that do not have data in all modalities.
        batch_size: Number of samples per batch.
        paired_ratio: Ratio of paired samples in each batch.
        modality_samples: Dictionary mapping each modality to its list of sample IDs.
    """

    def __init__(
        self,
        multimodal_dataset: "MultiModalDataset",
        paired_ratio=0.5,
        batch_size=64,
    ):
        """
        Initialize the sampler.

        Args:
            multimodal_dataset: The MultiModalDataset to sample from.
            paired_ratio: Ratio of paired samples in each batch.
            batch_size: Number of samples per batch.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            # A non-positive batch size can never be filled and would make
            # batch generation loop forever or divide by zero.
            raise ValueError(f"batch_size must be at least 1, got {batch_size}")

        self.dataset = multimodal_dataset
        self.paired_ids = multimodal_dataset.paired_sample_ids
        self.unpaired_ids = multimodal_dataset.unpaired_sample_ids
        self.batch_size = batch_size
        self.paired_ratio = paired_ratio

        total_paired = len(self.paired_ids)
        total_unpaired = len(self.unpaired_ids)

        if total_paired == 0:
            self.paired_ratio = 0.0
        elif total_unpaired == 0:
            self.paired_ratio = 1.0
        else:
            # Use requested ratio, but ensure we have enough samples
            max_possible_paired = total_paired / (total_paired + total_unpaired)
            self.paired_ratio = min(paired_ratio, max_possible_paired)

        # Build modality-specific sample lists
        self.modality_samples = {
            modality: dataset.sample_ids
            for modality, dataset in multimodal_dataset.datasets.items()
        }

    def __iter__(self):
        """Yield coverage batches followed by random batches.

        Empty batches are skipped. A batch holding a single sample is
        padded with one other randomly chosen sample and a warning is
        emitted.

        Raises:
            ValueError: If a single-sample batch cannot be padded because
                no other sample exists.
        """
        coverage_batches = self._generate_coverage_batches()
        random_batches = self._generate_random_batches(coverage_batches)
        all_batches = coverage_batches + random_batches
        for batch in all_batches:
            if len(batch) > 1:
                yield batch
            elif len(batch) == 1:
                current_sample = batch[0]
                candidate_pool = set(self.paired_ids) | set(self.unpaired_ids)
                candidate_pool.discard(current_sample)

                if not candidate_pool:
                    raise ValueError(
                        "Cannot form a batch of size > 1 because the dataset contains "
                        "only a single unique sample. To proceed, use a larger sample "
                        "Not this case should not happen, probably something is very odd with your data size "
                    )
                sample_to_add = np.random.choice(list(candidate_pool))
                batch.append(sample_to_add)
                warnings.warn(
                    "Your combination of batch_size and number of samples whil create a batch of len 1, this will fail all model with a BatchNorm Layer,chose another batch_size to avoid this. We handled this by adding random samples from your data to this 'problem' batch to the current batch. This is an extremely rare case, for our Custom Sampler for unpaired XModalix we don't support this."
                )
                yield batch

    def _generate_coverage_batches(self):
        """Generate batches that ensure all samples are covered

        Every pass over the modalities hands each modality a share of the
        free batch slots for samples it has not contributed yet. Remaining
        slots are filled with random paired/unpaired samples.

        Returns:
            List of batches ensuring coverage of all samples
        """
        coverage_batches = []

        # De-duplicate per modality so the termination test below compares
        # like with like (``covered`` holds unique IDs).
        targets = {
            mod: list(dict.fromkeys(ids))
            for mod, ids in self.modality_samples.items()
        }
        covered = {mod: set() for mod in targets}
        n_modalities = len(targets)

        while any(len(covered[mod]) < len(targets[mod]) for mod in targets):
            batch = []
            batch_set = set()  # Track unique samples in current batch

            for modality, samples in targets.items():
                remaining_slots = self.batch_size - len(batch)
                if remaining_slots <= 0:
                    break

                # Select samples that aren't covered and aren't already in the batch
                available = [
                    s
                    for s in samples
                    if s not in covered[modality] and s not in batch_set
                ]
                if not available:
                    continue

                # At least one slot per modality while space is left;
                # otherwise a batch_size below the number of modalities
                # gives every modality a quota of 0 and the loop never ends.
                quota = max(1, remaining_slots // n_modalities)
                take = min(quota, len(available))
                selected = np.random.choice(available, size=take, replace=False)
                batch.extend(selected)
                batch_set.update(selected)
                covered[modality].update(selected)

            # Fill remaining batch slots with random samples, avoiding duplicates
            while len(batch) < self.batch_size:
                wants_paired = bool(self.paired_ids) and (
                    len(batch) < self.batch_size * self.paired_ratio
                )
                source = self.paired_ids if wants_paired else self.unpaired_ids
                if not source:
                    break

                unused = [s for s in source if s not in batch_set]
                # If no unique candidates available, allow repeats
                candidate_pool = unused if unused else source

                sample = np.random.choice(candidate_pool)
                batch.append(sample)
                batch_set.add(sample)

            if batch:
                coverage_batches.append(batch)

        return coverage_batches

    def _generate_random_batches(self, coverage_batches: List[Any]):
        """Generate additional random batches to fill the epoch

        The number of extra batches is the count of samples not yet
        accounted for by ``coverage_batches`` divided by the batch size.

        Args:
            coverage_batches: Batches already generated to ensure coverage
        Returns:
            List of additional random batches
        """
        total_samples = len(self.paired_ids) + len(self.unpaired_ids)
        covered_samples = sum(len(batch) for batch in coverage_batches)
        remaining_samples = max(0, total_samples - covered_samples)

        random_batches = []
        num_random_batches = remaining_samples // self.batch_size

        for _ in range(num_random_batches):
            batch = []

            # Add paired samples
            paired_needed = int(self.batch_size * self.paired_ratio)
            if paired_needed > 0 and self.paired_ids:
                paired_samples = np.random.choice(
                    self.paired_ids,
                    size=min(paired_needed, len(self.paired_ids)),
                    replace=True,
                )
                batch.extend(paired_samples)

            # Add unpaired samples
            unpaired_needed = self.batch_size - len(batch)
            if unpaired_needed > 0 and self.unpaired_ids:
                unpaired_samples = np.random.choice(
                    self.unpaired_ids,
                    size=min(unpaired_needed, len(self.unpaired_ids)),
                    replace=True,
                )
                batch.extend(unpaired_samples)

            if batch:
                random_batches.append(batch)

        return random_batches

    def __len__(self):
        """Return a batch-count estimate derived from the sample totals.

        The value is computed from the number of samples and modalities,
        not from the batches that ``__iter__`` actually generates.
        """
        total_samples = len(self.paired_ids) + len(self.unpaired_ids)
        return max(total_samples // self.batch_size, len(self.modality_samples))

398 

399 

def create_multimodal_collate_fn(multimodal_dataset: "MultiModalDataset"):
    """
    Factory function to create a collate function with access to the dataset.
    This allows us to get metadata and original indices.

    Args:
        multimodal_dataset: The multimodal dataset
    Returns:
        A collate function for DataLoader
    """

    def collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
        """Group a list of per-sample dictionaries by modality.

        Args:
            batch: Items with a ``"sample_id"`` key and one key per
                modality; a modality value of None marks missing data.

        Returns:
            Empty dict for an empty batch. Otherwise one entry per
            modality with the keys ``"data"`` (stacked tensor),
            ``"sample_ids"``, ``"sampled_index"`` and ``"class_labels"``.

        Raises:
            ValueError: If a modality has no data in the whole batch.
        """
        if not batch:
            return {}

        result: Dict[str, Dict[str, Any]] = {}
        class_col = multimodal_dataset.config.class_param
        for modality in multimodal_dataset.modalities:
            dataset = multimodal_dataset.datasets[modality]
            # The attribute may exist but be None, so test the value and
            # not merely the attribute's presence.
            metadata = getattr(dataset, "metadata", None)
            has_metadata = bool(class_col) and metadata is not None

            # Collect only for samples with this modality
            relevant_samples = [s for s in batch if s.get(modality) is not None]
            if not relevant_samples:
                raise ValueError(f"Modality {modality} has no data in batch")

            sample_ids = [s["sample_id"] for s in relevant_samples]
            id_to_idx = multimodal_dataset._id_to_idx[modality]

            class_labels: List[Any]
            if has_metadata:
                class_labels = [metadata.at[sid, class_col] for sid in sample_ids]
            else:
                class_labels = [None] * len(relevant_samples)

            result[modality] = {
                "data": torch.stack([s[modality] for s in relevant_samples]),
                "sample_ids": sample_ids,
                # Position of each sample in the modality's own dataset
                # (None when the ID is unknown there).
                "sampled_index": [id_to_idx.get(sid) for sid in sample_ids],
                "class_labels": class_labels,  # List; convert to tensor if needed for loss
            }
        return result

    return collate_fn