Coverage for src / autoencodix / visualize / _xmodal_visualizer.py: 11%

292 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1from dataclasses import field 

2import pandas as pd 

3import numpy as np 

4import seaborn as sns 

5import matplotlib.pyplot as plt 

6from umap import UMAP 

7import warnings 

8import torch 

9from sklearn.decomposition import PCA 

10from sklearn.manifold import TSNE 

11 

12from typing import Any, Dict, Optional, Union, List, no_type_check 

13from autoencodix.base._base_visualizer import BaseVisualizer 

14from autoencodix.utils._result import Result 

15from autoencodix.utils._utils import nested_dict, show_figure 

16from autoencodix.configs.default_config import DefaultConfig 

17from autoencodix.data._datasetcontainer import DatasetContainer 

18 

19 

class XModalVisualizer(BaseVisualizer):
    """Visualizer for X-Modalix (cross-modal translation) results.

    All figures are stored in ``self.plots``, a nested dictionary created with
    ``nested_dict`` and indexed by plot type, then by plot-specific keys such as
    epoch, split or clinical parameter.
    """

    # NOTE(review): this class is not decorated with @dataclass, so ``field`` is not
    # processed here; the actual per-instance dictionary is created in ``__init__``.
    plots: Dict[str, Any] = field(
        default_factory=nested_dict
    )  ## Nested dictionary of plots as figure handles

    def __init__(self):
        self.plots = nested_dict()

    def __setitem__(self, key, elem):
        """Store ``elem`` under ``key`` in ``self.plots``."""
        self.plots[key] = elem

    def visualize(self, result: Result, config: DefaultConfig) -> Result:
        """Create the absolute and relative loss plots for an X-Modalix run.

        Only the combined loss terms are plotted (as in the old autoencodix
        framework): terms whose name starts with a modality key of the train
        split belong to a single sub-VAE and are filtered out.

        Args:
            result: Result object holding losses and datasets.
            config: Run configuration passed on to the loss formatter.

        Returns:
            The same ``result`` object. It is returned unchanged, including when
            there is no loss data to plot.

        Raises:
            ValueError: If ``result.datasets`` has no train split or it is None.
        """
        # TODO: model-weight plots need to be adjusted for X-Modalix (one per sub-VAE).

        # Check for loss data first: without it the loss formatter has nothing to
        # work with, and the warning below explains the usual cause.
        if not result.losses._data:
            warnings.warn(
                "No loss data: This usually happens if you try to visualize after saving and loading the pipeline object with `save_all=False`. This memory-efficient saving mode does not retain past training loss data."
            )
            return result

        if not hasattr(result.datasets, "train"):
            raise ValueError("result.datasets has no attribute train")
        if result.datasets.train is None:
            raise ValueError("Train attribute of datasets is None")

        ## Make long format of losses
        loss_df_melt = self._make_loss_format(result=result, config=config)

        ## X-Modalix specific: drop loss terms of the individual modality VAEs
        modality_prefixes = tuple(result.datasets.train.datasets.keys())
        loss_df_melt = loss_df_melt[
            ~loss_df_melt["Loss Term"].str.startswith(modality_prefixes, na=False)
        ]

        ## Absolute and relative loss plots
        for scale in ("absolute", "relative"):
            self.plots[f"loss_{scale}"] = self._make_loss_plot(
                df_plot=loss_df_melt, plot_type=scale
            )

        return result

    def show_latent_space(
        self,
        result: Result,
        plot_type: str = "2D-scatter",
        labels: Optional[Union[list, pd.Series, None]] = None,
        param: Optional[Union[list, str]] = None,
        epoch: Optional[Union[int, None]] = None,
        split: str = "all",
    ) -> None:
        """Plot the latent space of all X-Modalix sub-VAEs.

        The latent spaces of every modality in the selected split(s) are stacked,
        so each sample can appear once per modality. One figure is created per
        entry of ``param`` and stored in
        ``self.plots[plot_type][epoch][split][param]``.

        Args:
            result: Result object holding datasets, models and latent spaces.
            plot_type: One of:
                - "2D-scatter": a UMAP embedding, or the raw coordinates for a
                  two-dimensional latent space.
                - "Ridgeline": density per latent dimension.
                - "Coverage-Correlation": not implemented yet.
            labels: Optional user labels, mutually exclusive with ``param``.
                - A ``pd.Series`` is matched to samples by its index; samples
                  missing from it are labelled "unknown".
                - A list is matched to the unique samples of the collected
                  metadata (in metadata row order) if its length fits.
                  Otherwise it is used per stacked latent row if that length
                  fits.
            param: List of metadata columns to colour by, or "all" for every
                metadata column. If neither ``labels`` nor ``param`` is given,
                the ``annotation_columns`` of the first sub-VAE config are used.
            epoch: Epoch of the latent space. Defaults to -1. The test split
                always uses epoch -1.
            split: "train", "valid", "test" or "all".

        Raises:
            ValueError: In any of these cases:
                - both ``labels`` and ``param`` are given;
                - ``param`` is a string other than "all";
                - no latent data is found;
                - for "Ridgeline", user labels cannot be matched to the samples.
        """
        plt.ioff()
        if plot_type == "Coverage-Correlation":
            # TODO: plot total correlation and coverage of the latent space per
            # checkpoint epoch for the train and valid splits.
            print("TODO: Implement Coverage-Correlation plot for X-Modalix")
            return
        if plot_type not in ("2D-scatter", "Ridgeline"):
            warnings.warn(
                f"Unknown plot_type '{plot_type}'. Supported types are "
                "'2D-scatter', 'Ridgeline' and 'Coverage-Correlation'."
            )
            return

        if epoch is None:
            epoch = -1

        latent_data, clin_data = self._collect_latent_and_metadata(
            result=result, epoch=epoch, split=split
        )
        if latent_data.empty:
            raise ValueError(f"No latent space data found for split '{split}'.")

        ## Label options
        if labels is not None and param is not None:
            raise ValueError(
                "Please provide either labels or param, not both. If you want to plot all parameters, set param to 'all' and labels to None."
            )
        if labels is None and param is None:
            # Only fall back to the configured annotation columns when the user
            # gave no labels; otherwise user labels would clash with this default.
            param = self._default_annotation_columns(result)
        if labels is None and param is None:
            labels = ["all"] * latent_data["sample_ids"].unique().shape[0]
        if labels is None and isinstance(param, str):
            if param != "all":
                raise ValueError(
                    "Please provide parameter to plot as a list not as string. If you want to plot all parameters, set param to 'all' and labels to None."
                )
            param = list(clin_data.columns)
        if labels is not None and param is None:
            if isinstance(labels, pd.Series) and labels.name is not None:
                param = [labels.name]
            else:
                param = ["user_label"]  # Default label name if none provided
        if not isinstance(param, list):
            raise ValueError(f"param: should be converted to list, got: {param}")

        # The 2D embedding does not depend on the colouring parameter, so it is
        # computed once and shared by all plots of this call.
        embedding = (
            self._embed_latent_2d(latent_data)
            if plot_type == "2D-scatter" and param
            else None
        )

        for p in param:
            row_labels = self._row_labels_for_param(
                p=p,
                labels=labels,
                latent_data=latent_data,
                clin_data=clin_data,
                strict=plot_type == "Ridgeline",
            )
            if plot_type == "2D-scatter":
                plot_df = embedding.copy()
                plot_df[p] = row_labels
                self.plots["2D-scatter"][epoch][split][p] = (
                    self._plot_translate_latent(
                        embedding=plot_df,
                        color_param=p,
                        style_param="modality",
                    )
                )
                plt.show()
            else:
                grid = self._plot_latent_ridge_multi(
                    lat_space=latent_data.drop(columns=["sample_ids"]),
                    labels=row_labels,
                    modality="modality",
                    param=p,
                )
                self.plots["Ridgeline"][epoch][split][p] = grid
                show_figure(grid.figure)
                plt.show()

    @staticmethod
    def _collect_latent_and_metadata(result: Result, epoch: int, split: str) -> tuple:
        """Stack latent spaces and metadata of all modalities in the requested split(s).

        Args:
            result: Result object with ``datasets`` and ``get_latent_df``.
            epoch: Epoch for the train/valid latent spaces. The test split always
                uses epoch -1.
            split: A split name, or "all" for train, test and valid.

        Returns:
            A tuple ``(latent_data, clin_data)``:
                - ``latent_data`` has one row per (sample, modality), with string
                  ``sample_ids`` and a ``modality`` column.
                - ``clin_data`` is indexed by string ``sample_ids``, with one row
                  per sample (first occurrence kept). Without any metadata it is
                  an empty frame indexed by the latent sample ids.
            Both frames are empty if no latent data is found.
        """
        split_list = ["train", "test", "valid"] if split == "all" else [split]
        latent_frames = []
        meta_frames = []
        for split_name in split_list:
            split_ds = getattr(result.datasets, split_name, None)
            if split_ds is None:
                continue
            split_epoch = -1 if split_name == "test" else epoch
            for key, ds in split_ds.datasets.items():
                df_latent = result.get_latent_df(
                    epoch=split_epoch, split=split_name, modality=key
                ).copy()
                df_latent["modality"] = key
                # Each sample can occur multiple times (once per modality). Ids
                # are strings to match the metadata index below.
                df_latent["sample_ids"] = df_latent.index.astype(str)
                latent_frames.append(df_latent)

                metadata = getattr(ds, "metadata", None)
                if metadata is not None:
                    df_meta = metadata.copy()
                    df_meta["sample_ids"] = df_meta.index.astype(str)
                    df_meta["split"] = split_name
                    df_meta["modality"] = key
                    meta_frames.append(df_meta)

        if not latent_frames:
            return pd.DataFrame(), pd.DataFrame()

        latent_data = pd.concat(latent_frames, axis=0, ignore_index=True)
        if meta_frames:
            clin_data = (
                pd.concat(meta_frames, axis=0, ignore_index=True)
                .drop_duplicates(subset="sample_ids")
                .set_index("sample_ids")
            )
        else:
            clin_data = pd.DataFrame(
                index=pd.Index(latent_data["sample_ids"].unique(), name="sample_ids")
            )
        return latent_data, clin_data

    @staticmethod
    def _default_annotation_columns(result: Result):
        """Return ``annotation_columns`` of the first sub-VAE config.

        The first model is used because the configs are the same for all
        sub-VAEs.

        Raises:
            ValueError: If ``result.model`` is empty or the first model is None.
        """
        modalities = list(result.model.keys())
        if not modalities:
            raise ValueError("result.model contains no sub-VAE models")
        modality = modalities[0]
        model = result.model.get(modality, None)
        if model is None:
            raise ValueError(f"Model for modality {modality} not found in result.model")
        return model.config.data_config.annotation_columns

    @staticmethod
    def _embed_latent_2d(latent_data: pd.DataFrame) -> pd.DataFrame:
        """Reduce the stacked latent space to two columns ``DIM1``/``DIM2``.

        UMAP is used when there are more than two latent dimensions. A
        two-dimensional latent space is used as is.

        Returns:
            DataFrame with ``DIM1``, ``DIM2``, ``sample_ids`` and ``modality``,
            row-aligned with ``latent_data``.

        Raises:
            ValueError: If the latent space has fewer than two dimensions.
        """
        features = latent_data.drop(columns=["sample_ids", "modality"])
        if features.shape[1] > 2:
            coords = UMAP(n_components=2).fit_transform(features)
        elif features.shape[1] == 2:
            coords = features.to_numpy()
        else:
            raise ValueError(
                "2D-scatter needs a latent space with at least two dimensions."
            )
        embedding = pd.DataFrame(coords, columns=["DIM1", "DIM2"])
        embedding["sample_ids"] = latent_data["sample_ids"].to_numpy()
        embedding["modality"] = latent_data["modality"].to_numpy()
        return embedding

    @staticmethod
    def _row_labels_for_param(
        p, labels, latent_data: pd.DataFrame, clin_data: pd.DataFrame, strict: bool
    ) -> list:
        """Return one label per row of ``latent_data`` for parameter ``p``.

        Sources are tried in this order:
            1. A user ``pd.Series``, matched by sample id (missing ids get
               "unknown").
            2. For ``p == "modality"``, each row's own modality.
            3. A metadata column named ``p`` (NaN for ids without metadata).
            4. A user list with one entry per metadata row.
            5. A user list with one entry per latent row.
        Otherwise all rows are labelled "all".

        Raises:
            ValueError: If ``strict`` is set and user-provided labels (other than
                the "all" default) cannot be matched.
        """
        sample_ids = latent_data["sample_ids"]
        n_rows = latent_data.shape[0]

        if isinstance(labels, pd.Series):
            lookup = labels[~labels.index.duplicated(keep="first")].copy()
            lookup.index = lookup.index.astype(str)
            return lookup.reindex(sample_ids, fill_value="unknown").tolist()
        if p == "modality":
            # The metadata keeps one modality per sample; each latent row has its own.
            return latent_data["modality"].tolist()
        if p in clin_data.columns:
            return clin_data[p].reindex(sample_ids).tolist()
        if labels is not None and len(labels) == clin_data.shape[0]:
            per_sample = pd.Series(list(labels), index=clin_data.index)
            return per_sample.reindex(sample_ids).tolist()
        if labels is not None and len(labels) == n_rows:
            return list(labels)

        if labels is None:
            warnings.warn(
                f"Parameter '{p}' not found in metadata; all samples are plotted as one group."
            )
        elif len(labels) > 0 and labels[0] != "all":
            if strict:
                raise ValueError(
                    "Labels must match the number of samples in the latent space."
                )
            warnings.warn(
                "Labels could not be matched to the samples; all samples are plotted as one group."
            )
        return ["all"] * n_rows

    def show_weights(self) -> None:
        """Not implemented for X-Modalix.

        Raises:
            NotImplementedError: Always.
        """
        ## TODO
        raise NotImplementedError(
            "Weight visualization for X-Modalix is not implemented."
        )

    @no_type_check
    def show_image_translation(  # ty: ignore
        self,
        result: Result,
        from_key: str,
        to_key: str,
        n_sample_per_class: int = 3,
        param: Optional[str] = None,
    ) -> None:  # ty: ignore
        """Show a grid of original, translated and reference images for test samples.

        Only paired samples of the test split are used. Columns are samples
        grouped by class value: at most ``n_sample_per_class`` per class, at most
        ten class values, drawn with ``random_state=42``. The three rows show:
            1. the original ``to_key`` image;
            2. the translation into ``to_key``;
            3. the ``to_key``-to-``to_key`` reference reconstruction.

        Args:
            result: Result object containing datasets, sample ids and reconstructions.
            from_key: Source modality key. It is not used by the plot itself and
                is kept for context.
            to_key: Target modality key. It must refer to an image dataset: the
                key must contain "img" (case-insensitive).
            n_sample_per_class: Number of samples shown per class value. Default is 3.
            param: Metadata column used to group samples. If None, all samples
                form one group "all".

        Raises:
            ValueError: In any of these cases:
                - ``to_key`` is not an image dataset;
                - no samples can be shown;
                - a selected sample is missing from the stored sample ids.
        """
        if "img" not in to_key.lower():
            raise ValueError(
                f"You provided as 'to_key' {to_key} a non-image dataset. "
                "Image translation grid visualization is only possible for translation to IMG data type."
            )

        split = "test"  # Currently only test split is supported
        test_ds = result.datasets.test

        ## Get n samples per class, restricted to paired sample ids
        meta = test_ds.datasets[to_key].metadata.loc[test_ds.paired_sample_ids].copy()
        if param is None:
            param = "user-label"
            meta[param] = "all"  # Default to all samples if no parameter is provided

        class_values = meta[param].unique()
        if len(class_values) > 10:
            warnings.warn(
                f"Found {len(class_values)} class values for parameter '{param}'. Only first 10 will be used to limit figure size"
            )
            class_values = class_values[:10]

        sample_per_class = {}
        for class_value in class_values:
            members = meta[meta[param] == class_value]
            sample_per_class[class_value] = members.sample(
                n=min(n_sample_per_class, members.shape[0]),
                random_state=42,
            ).index.tolist()

        print(f"Sample per class: {sample_per_class}")

        # One column per selected sample. The ids are kept directly instead of
        # being parsed back out of the column labels.
        col_sample_ids = []
        col_labels = []
        for class_value, class_ids in sample_per_class.items():
            for sample_id in class_ids:
                col_sample_ids.append(sample_id)
                col_labels.append(f"{class_value} {split}-sample:{sample_id}")
        n_cols = len(col_sample_ids)
        if n_cols == 0:
            raise ValueError("No paired test samples available to visualize.")

        stored_ids = result.sample_ids.get(epoch=-1, split=split)
        reconstructions = result.reconstructions.get(epoch=-1, split=split)
        reference_key = f"reference_{to_key}_to_{to_key}"
        target_images = test_ds.datasets[to_key]

        original_pos = self._position_lookup(target_images.sample_ids)
        translated_pos = self._position_lookup(stored_ids["translation"])
        reference_pos = self._position_lookup(stored_ids[reference_key])

        def _position(lookup, sample_id, source):
            try:
                return lookup[str(sample_id)]
            except KeyError:
                raise ValueError(
                    f"Sample '{sample_id}' not found in {source} sample ids."
                ) from None

        row_specs = [
            # Dataset items are stored as tuple (index, tensor, sample_id).
            (
                "Original",
                lambda sid: target_images[_position(original_pos, sid, "original")][1],
            ),
            (
                "Translated",
                lambda sid: reconstructions["translation"][
                    _position(translated_pos, sid, "translation")
                ],
            ),
            (
                "Reference",
                lambda sid: reconstructions[reference_key][
                    _position(reference_pos, sid, reference_key)
                ],
            ),
        ]

        fig, axes = plt.subplots(
            ncols=n_cols,
            nrows=3,  # Original, translated, reference
            figsize=(n_cols * 2, 3 * 2),
            squeeze=False,
        )
        for row_idx, (row_label, fetch_image) in enumerate(row_specs):
            for col_idx, sample_id in enumerate(col_sample_ids):
                ax = axes[row_idx, col_idx]
                ax.imshow(self._to_image_array(fetch_image(sample_id)))
                ax.axis("off")
                if row_idx == 0:
                    # Sample label above the first row
                    ax.text(
                        0.5,
                        1.1,
                        col_labels[col_idx],
                        va="bottom",
                        ha="center",
                        rotation=45,
                        transform=ax.transAxes,
                    )
                if col_idx == 0:
                    # Row label left of the first column
                    ax.text(
                        -0.1,
                        0.5,
                        row_label,
                        va="center",
                        ha="right",
                        transform=ax.transAxes,
                    )

        self.plots["Image-translation"][to_key][split][param] = fig
        plt.show()

    @staticmethod
    def _position_lookup(ids) -> Dict[str, int]:
        """Map each sample id (as string) to its first position in ``ids``."""
        lookup: Dict[str, int] = {}
        for pos, sample_id in enumerate(ids):
            lookup.setdefault(str(sample_id), pos)
        return lookup

    @staticmethod
    def _to_image_array(img) -> np.ndarray:
        """Convert an image tensor/array into a NumPy array accepted by ``imshow``.

        Torch tensors are detached and moved to the CPU. Singleton axes are
        squeezed. A channel-first colour image (3 or 4 leading channels) is moved
        to channel-last, because ``imshow`` rejects that layout.
        """
        if isinstance(img, torch.Tensor):
            img = img.detach().cpu().numpy()
        arr = np.asarray(img).squeeze()
        if arr.ndim == 3 and arr.shape[0] in (3, 4) and arr.shape[-1] not in (3, 4):
            arr = np.moveaxis(arr, 0, -1)
        return arr

    @no_type_check
    def show_2D_translation(
        self,
        result: Result,
        translated_modality: str,
        split: str = "test",
        param: Optional[str] = None,
        reducer: str = "UMAP",
    ) -> None:
        """Compare input and translated test samples in a shared 2D embedding.

        Processed input data of ``translated_modality`` and the flattened
        translated reconstructions are aligned on their shared sample ids. They
        are reduced jointly to two dimensions and plotted side by side, one panel
        per origin ("input" / "translated").

        Args:
            result: Result object. If ``result.new_datasets.test`` is set (after
                predicting with other data), it is used instead of
                ``result.datasets.test``. ``result`` itself is not modified.
            translated_modality: Modality key of the translation target.
            split: Only "test" is implemented, since reconstruction is only
                performed on the test split.
            param: Metadata column used for colouring. If None, all points share
                the label "all".
            reducer: "UMAP", "PCA" or "TSNE".

        Raises:
            ValueError: In any of these cases:
                - unknown reducer or split;
                - missing test split;
                - no translated samples;
                - input and translation share no sample ids.
            NotImplementedError: For any split other than "test".
        """
        ## TODO add similar labels/param logic from other visualizations
        reducer_classes = {"UMAP": UMAP, "PCA": PCA, "TSNE": TSNE}
        if reducer not in reducer_classes:
            raise ValueError(
                f"Unknown reducer: {reducer}. Choose one of {list(reducer_classes)}."
            )
        if split not in ["train", "valid", "test", "all"]:
            raise ValueError(f"Unknown split: {split}")

        # Prefer new test data (after predict with other data) without
        # overwriting result.datasets.
        test_split = getattr(result.datasets, "test", None)
        new_datasets = getattr(result, "new_datasets", None)
        if new_datasets is not None and bool(new_datasets.test):
            test_split = new_datasets.test
        if test_split is None:
            raise ValueError("test of dataset is None")
        if split != "test":
            raise NotImplementedError(
                "2D translation visualization is currently only implemented for the 'test' split since reconstruction is only performed on test-split."
            )

        df_processed = test_split._to_df(modality=translated_modality)

        # Translated reconstructions, flattened to one row per sample (image case)
        translations = result.reconstructions.get(epoch=-1, split=split)[  # ty: ignore
            "translation"
        ]
        translation_ids = result.sample_ids.get(epoch=-1, split=split)["translation"]
        if len(translations) == 0:
            raise ValueError("No translated reconstructions found for the test split.")
        flat_rows = [
            t.detach().cpu().numpy().flatten()
            if isinstance(t, torch.Tensor)
            else np.asarray(t).flatten()
            for t in translations
        ]
        df_translated = pd.DataFrame(
            flat_rows,
            columns=[f"Feature_{i}" for i in range(len(flat_rows[0]))],
            index=translation_ids,
        )

        # Align input and translation on their shared sample ids
        common_ids = df_processed.index.intersection(df_translated.index)
        if len(common_ids) == 0:
            raise ValueError(
                "Input and translated samples share no sample ids; cannot align them."
            )
        df_processed = df_processed.loc[common_ids]
        df_translated = df_translated.loc[common_ids]

        stacked = np.vstack([df_processed.to_numpy(), df_translated.to_numpy()])
        df_red_comb = pd.DataFrame(
            reducer_classes[reducer](n_components=2).fit_transform(stacked)
        )
        df_red_comb["origin"] = ["input"] * df_processed.shape[0] + [
            "translated"
        ] * df_translated.shape[0]

        if param is None:
            param = "user-label"
            df_red_comb[param] = "all"
        else:
            # Labels are matched to the plotted rows by sample id, not by position.
            label_lookup = test_split.datasets[translated_modality].metadata[param]
            label_lookup = label_lookup[
                ~label_lookup.index.duplicated(keep="first")
            ].copy()
            label_lookup.index = label_lookup.index.astype(str)
            df_red_comb[param] = (
                label_lookup.reindex(df_processed.index.astype(str)).tolist()
                + label_lookup.reindex(df_translated.index.astype(str)).tolist()
            )

        g = sns.FacetGrid(
            df_red_comb,
            col="origin",
            hue=param,
            sharex=True,
            sharey=True,
            height=8,
            aspect=1,
        )
        g.map_dataframe(sns.scatterplot, x=0, y=1, alpha=0.7)
        g.add_legend()
        g.set_axis_labels(reducer + " DIM 1", reducer + " DIM 2")
        g.set_titles(col_template="{col_name}")

        self.plots["2D-translation"][translated_modality][split][param] = g
        plt.show()

    ## Utilities specific for X-Modalix
    @staticmethod
    def _plot_translate_latent(
        embedding,
        color_param,
        style_param=None,
    ):
        """Create a scatter plot of a 2D latent-space embedding.

        Numeric labels with more than three distinct values are coloured with
        the diverging "bwr" palette. Other numeric labels (e.g. 0/1 flags) are
        converted to strings and treated as categories. The caller's
        ``embedding`` is not modified.

        Args:
            embedding: DataFrame with ``DIM1``/``DIM2`` coordinates (a prior 2D
                reduction is assumed), the ``color_param`` column and, if given,
                the ``style_param`` column.
            color_param: Column used for the point colour (hue).
            style_param: Optional column, e.g. "modality". It is used to facet
                the plot (one panel per value) and as the marker style.

        Returns:
            The ``seaborn.FacetGrid`` returned by ``sns.relplot``.
        """
        plot_df = embedding.copy()
        hue_values = plot_df[color_param]
        numeric = False
        if pd.api.types.is_numeric_dtype(hue_values):
            # TODO Decide if numeric to category should be optional in new Package
            if hue_values.nunique(dropna=False) > 3:
                numeric = True
            else:
                plot_df[color_param] = hue_values.astype(str)

        return sns.relplot(
            data=plot_df,
            x="DIM1",
            y="DIM2",
            hue=color_param,
            palette="bwr" if numeric else None,
            col=style_param,
            style=style_param,
            markers=True,
            alpha=0.4,
            ec="black",
            height=10,
            aspect=1,
            s=150,
        )

    @staticmethod
    def _plot_latent_ridge_multi(
        lat_space: pd.DataFrame,
        modality: Optional[str] = None,
        labels: Optional[Union[list, pd.Series, None]] = None,
        param: Optional[Union[str, None]] = None,
    ) -> sns.FacetGrid:
        """Create a ridgeline plot with one density row per latent dimension.

        Args:
            lat_space: DataFrame of latent intensities, with samples as rows and
                latent dimensions as columns. It may also contain the
                ``modality`` column.
            modality: Optional column of ``lat_space`` used to facet the plot
                (one column of panels per value).
            labels: One label per row of ``lat_space``. If None, all samples are
                one group "all". How labels are grouped:
                - Numeric labels with more than three distinct values are binned
                  into quartiles.
                - All other labels are converted to strings.
                - Labels "unknown" and missing values ("nan") are excluded from
                  the plot.
            param: Name of the grouping/colouring column. If labels are given
                without a name, it defaults to "label".

        Returns:
            The ``seaborn.FacetGrid`` containing the ridgeline plot. The figure
            is closed before returning.
        """
        sns.set_theme(
            style="white", rc={"axes.facecolor": (0, 0, 0, 0)}
        )  ## Necessary to enforce overplotting

        dim_columns = [col for col in lat_space.columns if col != modality]
        df = pd.melt(
            lat_space,
            id_vars=[modality] if modality is not None else None,
            value_vars=dim_columns,
            var_name="latent dim",
            value_name="latent intensity",
        )
        n_dims = len(dim_columns)
        # pd.melt stacks the dimensions column by column, so per-sample values
        # are repeated once per latent dimension.
        df["sample"] = n_dims * list(lat_space.index)

        if labels is None:
            param = "all"
            labels = ["all"] * lat_space.shape[0]
        elif param is None:
            param = "label"

        label_series = pd.Series(list(labels))
        if (
            pd.api.types.is_numeric_dtype(label_series)
            and label_series.nunique(dropna=False) > 3
        ):
            # Continuous labels are binned into quartiles; NaN becomes "nan".
            group_labels = (
                pd.qcut(
                    label_series.astype(float),
                    q=4,
                    labels=["1stQ", "2ndQ", "3rdQ", "4thQ"],
                )
                .astype(str)
                .tolist()
            )
        else:
            # Missing values are mapped to "nan" so they are excluded below.
            group_labels = (
                label_series.where(label_series.notna(), "nan").astype(str).tolist()
            )
        df[param] = n_dims * group_labels

        kept = df[~df[param].isin(["unknown", "nan"])]

        grouped = kept.groupby([param, "latent dim"], observed=False)[
            "latent intensity"
        ]
        xmin = grouped.quantile(0.05).min()
        xmax = grouped.quantile(0.9).max()
        xlim = (xmin, xmax) if pd.notna(xmin) and pd.notna(xmax) else None

        n_levels = max(kept[param].nunique(), 1)
        if n_levels > 8:
            cat_pal = sns.husl_palette(n_levels)
        else:
            cat_pal = sns.color_palette(n_colors=n_levels)

        g = sns.FacetGrid(
            kept,
            row="latent dim",
            col=modality,
            hue=param,
            aspect=12,
            height=0.8,
            xlim=xlim,
            palette=cat_pal,
        )
        g.map_dataframe(
            sns.kdeplot,
            "latent intensity",
            bw_adjust=0.5,
            clip_on=True,
            fill=True,
            alpha=0.5,
            warn_singular=False,
            ec="k",
            lw=1,
        )

        def _label_row(data, color, label, text="latent dim"):
            # Write the latent dimension name at the left of each ridge.
            ax = plt.gca()
            ax.text(
                0.0,
                0.2,
                data[text].unique()[0],
                fontweight="bold",
                ha="right",
                va="center",
                transform=ax.transAxes,
            )

        g.map_dataframe(_label_row, text="latent dim")

        if xlim is not None:
            g.set(xlim=xlim)
        # Set the subplots to overlap
        g.figure.subplots_adjust(hspace=-0.5)

        # Remove axes details that don't play well with overlap
        g.set_titles("")
        g.set(yticks=[], ylabel="")
        g.despine(bottom=True, left=True)

        # Modality names as titles of the top row, in the grid's own column order
        if modality is not None:
            for ax, col_name in zip(g.axes[0], g.col_names):
                ax.set_title(col_name)

        g.add_legend()

        plt.close(g.figure)
        return g

    def _plot_evaluation(
        self,
        result: Result,
    ) -> dict:
        """Plot the embedding evaluation results from the Result object.

        One bar plot (rows: MODALITY, columns: ML_TASK, bars: score_split) is
        made for each combination of clinical parameter, metric and ML
        algorithm. The y-axis starts at zero unless some values are negative.

        Args:
            result: The Result object containing ``embedding_evaluation``.

        Returns:
            Nested dictionary ``{CLINIC_PARAM: {metric: {ML_ALG: FacetGrid}}}``.
            It is also stored in ``self.plots["ML_Evaluation"]``.
        """
        evaluation = result.embedding_evaluation
        ml_plots = dict()
        plt.ioff()

        for clin_param in pd.unique(evaluation.CLINIC_PARAM):
            param_mask = evaluation.CLINIC_PARAM == clin_param
            ml_plots[clin_param] = dict()
            for metric in pd.unique(evaluation.loc[param_mask, "metric"]):  # ty: ignore
                metric_mask = param_mask & (evaluation.metric == metric)
                ml_plots[clin_param][metric] = dict()
                for alg in pd.unique(evaluation.loc[metric_mask, "ML_ALG"]):  # ty: ignore
                    data = evaluation[metric_mask & (evaluation.ML_ALG == alg)]
                    grid = sns.catplot(
                        data=data,
                        x="score_split",
                        y="value",
                        col="ML_TASK",
                        row="MODALITY",
                        hue="score_split",
                        kind="bar",
                    )
                    min_y = data.value.min()
                    if min_y > 0:
                        min_y = 0
                    ml_plots[clin_param][metric][alg] = grid.set(ylim=(min_y, None))

        self.plots["ML_Evaluation"] = ml_plots

        return ml_plots