Coverage for src / autoencodix / visualize / _general_visualizer.py: 11%

288 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1from dataclasses import field 

2from typing import Any, Dict, Optional, Union, Literal, no_type_check 

3import warnings 

4 

5import matplotlib.figure 

6import numpy as np 

7import pandas as pd 

8import seaborn as sns # type: ignore 

9from matplotlib import pyplot as plt 

10from umap import UMAP # type: ignore 

11 

12from autoencodix.base._base_visualizer import BaseVisualizer 

13from autoencodix.utils._result import Result 

14from autoencodix.utils._utils import nested_dict, show_figure 

15from autoencodix.configs.default_config import DefaultConfig 

16 

17 

18class GeneralVisualizer(BaseVisualizer): 

# Class-level declaration of the plot registry: a nested dictionary that maps
# plot names (and further sub-keys) to figure handles.
# NOTE(review): `field()` only takes effect on classes processed by
# `@dataclass`; no decorator is visible on this class and `__init__` assigns
# `self.plots` directly — confirm against BaseVisualizer whether this default
# is ever used.
plots: Dict[str, Any] = field(
    default_factory=nested_dict
)  ## Nested dictionary of plots as figure handles

22 

def __init__(self):
    """Initialise the visualizer with a fresh ``nested_dict()`` as plot registry."""
    # NOTE(review): the base-class initialiser is not invoked here — confirm
    # that BaseVisualizer needs no further setup.
    self.plots = nested_dict()

25 

def __setitem__(self, key, elem):
    """Store ``elem`` under ``key`` in ``self.plots``.

    Allows ``visualizer[key] = elem`` as a shorthand for writing to the
    plot registry.

    Args:
        key: Key under which the element is stored.
        elem: Object to store (figure handles elsewhere in this class).
    """
    self.plots[key] = elem

28 

def visualize(self, result: Result, config: DefaultConfig) -> Result:
    """Create the standard plots for a finished run and store them in ``self.plots``.

    A model-weights plot is stored under ``"ModelWeights"`` unless the model's
    input dimension exceeds 3000. Loss plots are stored under
    ``"loss_absolute"`` and ``"loss_relative"``; if anything goes wrong while
    building them, a warning is issued instead of raising.

    Args:
        result: Result object holding the trained model and its losses.
        config: Configuration forwarded to the loss formatting helper.

    Returns:
        The ``result`` object that was passed in.
    """
    model = result.model

    # Weight heatmap: only drawn for models that are small enough.
    if model.input_dim > 3000:
        warnings.warn(
            f"Model weights plot is skipped since input dimension {model.input_dim} is larger than 3000 and heatmap would be too large."
        )
    else:
        self.plots["ModelWeights"] = self._plot_model_weights(model=model)

    # Loss curves: one long-format frame feeds both plot variants.
    loss_plot_variants = (
        ("loss_absolute", "absolute"),
        ("loss_relative", "relative"),
    )
    try:
        melted_losses = self._make_loss_format(result=result, config=config)
        for plot_key, variant in loss_plot_variants:
            self.plots[plot_key] = self._make_loss_plot(
                df_plot=melted_losses, plot_type=variant
            )
    except Exception:
        # Best effort: missing loss history must not abort visualization.
        warnings.warn(
            "We could not create visualizations for the loss plots.\n"
            "This usually happens if you try to visualize after saving and loading "
            "the pipeline object with `save_all=False`. This memory-efficient saving mode "
            "does not retain past training loss data.\n\n"
        )

    return result

60 

61 ## Plotting methods ## 

@no_type_check
def show_latent_space(
    self,
    result: Result,
    plot_type: Literal[
        "2D-scatter", "Ridgeline", "Clustermap", "Coverage-Correlation"
    ] = "2D-scatter",
    labels: Optional[Union[list, pd.Series, None]] = None,
    focus_labels: Optional[Union[list, None]] = None,
    param: Optional[Union[list, str]] = None,
    epoch: Optional[Union[int, None]] = None,
    split: str = "all",
    n_downsample: Optional[int] = 10000,
    **kwargs,
) -> None:
    """Visualizes the latent space of the given result using different types of plots.

    Every created figure is stored in ``self.plots`` and then displayed.

    Args:
        result: The result object containing latent spaces and losses.
        plot_type: The type of plot to generate. Options are "2D-scatter", "Ridgeline", "Clustermap" and "Coverage-Correlation". Default is "2D-scatter". Any other value triggers a warning and no plot.
        labels: List of labels for the data points in the latent space. A pandas Series is re-ordered to the index of the latent space, missing entries become "unknown". Default is None.
        focus_labels: List of labels which should be considered for coloring. All other labels are set to 'other'. Defaults to None where all labels are considered.
        param: List of parameters provided and stored as metadata. Strings must match column names. If not a list, string "all" is expected for convenient way to make plots for all parameters available. Default is None where no colored labels are plotted.
        epoch: The epoch number to visualize. If None, the last epoch is taken from the model config. Default is None.
        split: The data split to visualize. Options are "train", "valid", "test", and "all". Default is "all". The test split is always read with ``epoch=-1``.
        n_downsample: If provided, downsample the data to this number of samples for faster visualization. Default is 10000. Set to None to disable downsampling.
        **kwargs: additional arguments (currently unused).

    Raises:
        ValueError: If both ``labels`` and ``param`` are given, or if ``param`` is a string other than "all".
        TypeError: If ``param`` could not be resolved to a list.
    """
    plt.ioff()

    if plot_type == "Coverage-Correlation":
        if "Coverage-Correlation" in self.plots:
            # Re-use the figure from an earlier call.
            fig = self.plots["Coverage-Correlation"]
        else:
            model_config = result.model.config
            interval = model_config.checkpoint_interval
            records = []
            # Dedicated loop names so the `epoch`/`split` arguments are not overwritten.
            for ckpt_epoch in range(interval, model_config.epochs + 1, interval):
                for ckpt_split in ["train", "valid"]:
                    # Latent spaces are stored with zero-based epochs.
                    latent_df = result.get_latent_df(
                        epoch=ckpt_epoch - 1, split=ckpt_split
                    )
                    records.append(
                        {
                            "epoch": ckpt_epoch,
                            "split": ckpt_split,
                            "total_correlation": self._total_correlation(latent_df),
                            "coverage": self._coverage_calc(latent_df),
                        }
                    )

            df_metrics = pd.DataFrame(records)

            fig, axes = plt.subplots(1, 2, figsize=(12, 5))

            # Total Correlation plot
            sns.lineplot(
                data=df_metrics,
                x="epoch",
                y="total_correlation",
                hue="split",
                ax=axes[0],
            )
            axes[0].set_title("Total Correlation")
            axes[0].set_xlabel("Epoch")
            axes[0].set_ylabel("Total Correlation")

            # Coverage plot
            sns.lineplot(
                data=df_metrics, x="epoch", y="coverage", hue="split", ax=axes[1]
            )
            axes[1].set_title("Coverage")
            axes[1].set_xlabel("Epoch")
            axes[1].set_ylabel("Coverage")

            fig.tight_layout()
            self.plots["Coverage-Correlation"] = fig

        show_figure(fig)
        plt.show()
        return

    if plot_type not in ("2D-scatter", "Ridgeline", "Clustermap"):
        # Fail visibly and early instead of doing all the work and plotting nothing.
        warnings.warn(
            f"Unknown plot_type '{plot_type}'. Supported options are '2D-scatter', "
            "'Ridgeline', 'Clustermap' and 'Coverage-Correlation'. No plot was created."
        )
        return

    # Set Defaults
    if epoch is None:
        epoch = result.model.config.epochs - 1

    clin_data = self._collect_all_metadata(result=result)

    if split == "all":
        df_latent = pd.concat(
            [
                result.get_latent_df(epoch=epoch, split="train"),
                result.get_latent_df(epoch=epoch, split="valid"),
                result.get_latent_df(epoch=-1, split="test"),
            ]
        )
    elif split == "test":
        df_latent = result.get_latent_df(epoch=-1, split=split)
    else:
        df_latent = result.get_latent_df(epoch=epoch, split=split)

    ## Label options
    if labels is None and param is None:
        labels = ["all"] * df_latent.shape[0]

    if labels is None and isinstance(param, str):
        if param == "all":
            param = list(clin_data.columns)
        else:
            raise ValueError(
                "Please provide parameter to plot as a list not as string. If you want to plot all parameters, set param to 'all' and labels to None."
            )

    if labels is not None and param is not None:
        raise ValueError(
            "Please provide either labels or param, not both. If you want to plot all parameters, set param to 'all' and labels to None."
        )

    if labels is not None and param is None:
        if isinstance(labels, pd.Series):
            param = [labels.name]
            # Order by index of df_latent first, fill missing with "unknown"
            labels = labels.reindex(df_latent.index, fill_value="unknown").tolist()
        else:
            param = ["user_label"]  # Default label if none provided

    if not isinstance(param, list):
        raise TypeError("Param needs to be converted to a list")

    # True when the labels come from the caller (or the "all" default) rather
    # than from a metadata column.
    has_explicit_labels = labels is not None

    for p in param:
        if p in clin_data.columns:
            labels = clin_data.loc[df_latent.index, p].tolist()  # ty: ignore
        elif not has_explicit_labels:
            # Without this guard the labels of the previous parameter (or None)
            # would be plotted under the wrong name.
            warnings.warn(
                f"Parameter '{p}' was not found in the metadata columns and is skipped."
            )
            continue

        if n_downsample is not None and df_latent.shape[0] > n_downsample:
            sample_idx = np.random.choice(
                df_latent.shape[0], n_downsample, replace=False
            )
            df_latent = df_latent.iloc[sample_idx]
            labels = [labels[i] for i in sample_idx]

        if plot_type == "2D-scatter":
            ## Make 2D Embedding with UMAP
            if df_latent.shape[1] > 2:
                reducer = UMAP(n_components=2)
                embedding = pd.DataFrame(reducer.fit_transform(df_latent))
            else:
                embedding = df_latent

            fig = self._plot_2D(
                embedding=embedding,
                labels=labels,
                focus_labels=focus_labels,
                param=p,
                layer=f"2D latent space (epoch {epoch+1})",  # we start counting epochs at 0, so add 1 for display
                figsize=(12, 8),
                center=True,
            )
            figure_to_show = fig
        elif plot_type == "Ridgeline":
            fig = self._plot_latent_ridge(
                lat_space=df_latent,
                labels=labels,
                focus_labels=focus_labels,
                param=p,
            )
            # The ridgeline helper returns a FacetGrid; display its figure.
            figure_to_show = fig.figure
        else:  # "Clustermap"
            fig = self._plot_latent_clustermap(
                lat_space=df_latent,
                labels=labels,
                focus_labels=focus_labels,
                param=p,
            )
            figure_to_show = fig

        # Store under plots[plot_type][epoch][split]; focus plots are numbered
        # consecutively so that repeated calls do not overwrite each other.
        split_plots = self.plots[plot_type][epoch][split]
        if focus_labels is None:
            split_plots[p] = fig
        else:
            focus_plots = split_plots[f"{p}_focus"]
            focus_plots["group_" + str(len(focus_plots.keys()) + 1)] = fig

        show_figure(figure_to_show)
        plt.show()

413 

414 def show_weights(self) -> None: 

415 """Display the model weights plot if it exists in the plots dictionary.""" 

416 

417 if "ModelWeights" not in self.plots.keys(): 

418 print("Model weights not found in the plots dictionary") 

419 print("You need to run visualize() method first") 

420 else: 

421 fig = self.plots["ModelWeights"] 

422 show_figure(fig) 

423 plt.show() 

424 

425 ### Moved to Base 

426 # def show_evaluation( 

427 # self, 

428 # param: str, 

429 # metric: str, 

430 # ml_alg: Optional[str] = None, 

431 # ) -> None: 

432 

433 ### Utilities ### 

@staticmethod
def _plot_2D(
    embedding: pd.DataFrame,
    labels: list,
    focus_labels: Optional[Union[list, None]] = None,
    param: Optional[Union[str, None]] = None,
    layer: str = "latent space",
    figsize: tuple = (24, 15),
    center: bool = True,
    plot_numeric: bool = False,
    xlim: Optional[Union[tuple, None]] = None,
    ylim: Optional[Union[tuple, None]] = None,
    scale: Optional[Union[str, None]] = None,
    no_leg: bool = False,
) -> matplotlib.figure.Figure:
    """Plots a 2D scatter plot of the given embedding with labels.

    Non-string labels with more than three distinct values are binned into
    quartiles ("1stQ" ... "4thQ") unless ``plot_numeric`` is set; all other
    labels are converted to strings. The figure is closed before it is
    returned, so it is not displayed by this function.

    Args:
        embedding: DataFrame containing the 2D embedding coordinates (first two columns are plotted).
        labels: List of labels corresponding to each point in the embedding.
        focus_labels: List of labels which should be considered for coloring. All other labels are set to 'other'. Defaults to None where all labels are considered.
        param: Title for the legend. Defaults to None.
        layer: Title for the plot. Defaults to "latent space".
        figsize: Size of the figure. Defaults to (24, 15). Widened for many or long legend labels.
        center: If True, marks the mean position of each label group with a star. Defaults to True. Disabled for numeric coloring.
        plot_numeric: If True, treats labels as numeric. Defaults to False.
        xlim: Limits for the x-axis. Defaults to None.
        ylim: Limits for the y-axis. Defaults to None.
        scale: Scale for the axes (e.g., 'log'). Defaults to None.
        no_leg: If True, no legend is displayed. Defaults to False.

    Returns:
        The resulting matplotlib figure.
    """
    numeric = False
    if not isinstance(labels[0], str):
        # pandas counts distinct values without sorting, so mixed-type
        # columns (numbers plus None/"unknown") do not break the check.
        if pd.Series(labels).nunique(dropna=False) > 3:
            if not plot_numeric:
                print(
                    "The provided label column is numeric and converted to categories."
                )
                # Keep every real number (ints included) and blank out the
                # rest; checking for `float` only would turn integer columns
                # into all-NaN and make qcut fail.
                numeric_values = [
                    x
                    if isinstance(x, (int, float, np.integer, np.floating))
                    and not isinstance(x, bool)
                    else float("nan")
                    for x in labels
                ]
                labels = (
                    pd.qcut(
                        x=pd.Series(numeric_values, dtype=float),
                        q=4,
                        labels=["1stQ", "2ndQ", "3rdQ", "4thQ"],
                    )
                    .astype(str)
                    .to_list()
                )
            else:
                center = False  ## Disable centering for numeric params
                numeric = True
        else:
            labels = [str(x) for x in labels]

    # Check whether labels or embedding is longer and adjust the labels
    if len(labels) < embedding.shape[0]:
        print(
            "Given labels do not have the same length as given sample size. Labels will be duplicated."
        )
        repeats = embedding.shape[0] // len(labels)
        labels = [label for label in labels for _ in range(repeats)]
    elif len(labels) > embedding.shape[0]:
        labels = list(set(labels))

    # The top-20 restriction is for categorical legends only; applying it to a
    # continuous colouring would replace most values with "other".
    if not numeric and focus_labels is None and len(np.unique(labels)) > 20:
        warnings.warn(
            f"The provided label column has {len(np.unique(labels))} unique labels which might make the scatter plot unclear."
        )
        # Restrict to top 20 labels
        focus_labels = pd.Series(labels).value_counts().nlargest(20).index.tolist()
        print("Focusing on top 20 labels instead")

    if focus_labels is not None:
        labels = [label if label in focus_labels else "other" for label in labels]

    # Sorted distinct labels; used for legend order, palette size and group means.
    unique_labels = np.unique(labels)
    n_unique = len(unique_labels)

    # Increase figure width if the legend has more than 10 labels (two columns)
    if n_unique > 10:
        figsize = (figsize[0] * 1.5, figsize[1])
    # Increase figure width if legend labels are very long
    max_label_length = max(len(str(label)) for label in unique_labels)
    figsize = (int(figsize[0] + max_label_length * 0.2), figsize[1])

    fig, ax2 = plt.subplots(1, 1, figsize=figsize)
    if numeric:
        sns.scatterplot(
            x=embedding.iloc[:, 0],
            y=embedding.iloc[:, 1],
            hue=labels,
            palette="bwr",
            s=40,
            alpha=0.5,
            ec="black",
            ax=ax2,
        )
    else:
        palette_name = "tab20" if n_unique > 8 else "tab10"
        cat_pal = sns.color_palette(palette_name, n_colors=n_unique)

        if "other" in unique_labels:
            # Draw the catch-all "other" group in grey
            cat_pal[list(unique_labels).index("other")] = (0.3, 0.3, 0.3)

        # Adjust alpha and marker size depending on number of points
        if len(labels) > 10000:
            point_alpha, point_size = 0.2, 10
        elif len(labels) > 5000:
            point_alpha, point_size = 0.4, 20
        else:
            point_alpha, point_size = 0.7, 40

        sns.scatterplot(
            x=embedding.iloc[:, 0],
            y=embedding.iloc[:, 1],
            hue=labels,
            hue_order=unique_labels,
            palette=cat_pal,
            s=point_size,
            alpha=point_alpha,
            ec="black",
            ax=ax2,
        )
        if center:
            # Overlay one star per label group at the group's mean position.
            means = embedding.groupby(by=labels).mean()
            sns.scatterplot(
                x=means.iloc[:, 0],
                y=means.iloc[:, 1],
                hue=unique_labels,
                hue_order=unique_labels,
                palette=cat_pal,
                s=200,
                ec="black",
                alpha=0.7,
                marker="*",
                legend=False,
                ax=ax2,
            )

    if xlim is not None:
        ax2.set_xlim(xlim[0], xlim[1])

    if ylim is not None:
        ax2.set_ylim(ylim[0], ylim[1])

    if scale is not None:
        ax2.set_yscale(scale)
        ax2.set_xscale(scale)
    ax2.set_xlabel("Dim 1")
    ax2.set_ylabel("Dim 2")

    if no_leg:
        ax2.legend([], [], frameon=False)
    else:
        sns.move_legend(
            ax2,
            "upper left",
            bbox_to_anchor=(1, 1),
            ncol=2 if n_unique > 10 else 1,
            title=param,
            frameon=False,
        )

    # Add title to the plot
    ax2.set_title(layer)
    fig.tight_layout()

    # Close this figure explicitly so it is not rendered a second time.
    plt.close(fig)
    return fig

618 

@staticmethod
def _plot_latent_clustermap(
    lat_space: pd.DataFrame,
    labels: Optional[Union[list, pd.Series, None]] = None,
    focus_labels: Optional[Union[list, None]] = None,
    param: Optional[Union[str, None]] = None,
) -> matplotlib.figure.Figure:
    """Creates a clustermap of the latent space where each row shows the mean latent intensities of one label group and rows are clustered.

    The input frame is not modified. The figure is closed before it is
    returned, so it is not displayed by this function.

    Args:
        lat_space: DataFrame containing the latent space intensities for samples (rows) and latent dimensions (columns)
        labels: List of labels for each sample. If None, all samples are considered as one group.
        focus_labels: List of labels which should be considered for coloring. All other labels are set to 'other'. Defaults to None where all labels are considered.
        param: Name of the grouping, used as name of the temporary label column. Defaults to "label" if labels are given without a name.
    Returns:
        fig: Figure object containing the clustermap
    """
    if labels is None:
        # Documented default: treat all samples as a single group.
        param = "all"
        labels = ["all"] * len(lat_space)
    elif param is None:
        param = "label"

    n_unique = pd.Series(labels).nunique(dropna=False)
    if n_unique > 50 and focus_labels is None:
        warnings.warn(
            f"The provided label column has {n_unique} unique labels which might make the clustermap plot too big."
        )
        # Restrict to top 50 labels
        focus_labels = pd.Series(labels).value_counts().nlargest(50).index.tolist()
        print("Focusing on top 50 labels instead")

    if focus_labels is not None:
        labels = [label if label in focus_labels else "other" for label in labels]

    # Work on a copy: the caller's frame keeps its columns even if plotting
    # fails or `param` clashes with a latent dimension name.
    labelled = lat_space.copy()
    labelled[param] = labels
    group_means = labelled.groupby(param).mean()

    cluster_grid = sns.clustermap(
        group_means,
        col_cluster=False,
        # Hierarchical clustering needs at least two rows.
        row_cluster=group_means.shape[0] > 1,
        figsize=(1 * labelled.shape[1], 4 + 0.5 * len(set(labels))),
        dendrogram_ratio=0.1,
        cmap="icefire",
        cbar_kws={"orientation": "horizontal"},
        cbar_pos=(0.2, 0.95, 0.3, 0.02),
    )
    cluster_figure = cluster_grid.figure

    plt.close(cluster_figure)
    return cluster_figure

663 

@staticmethod
def _plot_latent_ridge(
    lat_space: pd.DataFrame,
    labels: Optional[Union[list, pd.Series, None]] = None,
    focus_labels: Optional[Union[list, None]] = None,
    param: Optional[Union[str, None]] = None,
) -> sns.FacetGrid:
    """Creates a ridge line plot of latent space dimension where each row shows the density of a latent dimension and groups (ridges).

    Non-string labels with more than three distinct values are binned into
    quartiles; samples labelled "unknown" or "nan" are left out of the plot.
    Note that this function changes the global seaborn theme via
    ``sns.set_theme``. The figure is closed before the grid is returned.

    Args:
        lat_space: DataFrame containing the latent space intensities for samples (rows) and latent dimensions (columns)
        labels: List of labels for each sample. If None, all samples are considered as one group.
        focus_labels: List of labels which should be considered for coloring. All other labels are set to 'other'. Defaults to None where all labels are considered.
        param: Name of the grouping, used as column name and legend title. Defaults to "label" if labels are given without a name.
    Returns:
        g: FacetGrid object containing the ridge line plot
    """
    sns.set_theme(
        style="white", rc={"axes.facecolor": (0, 0, 0, 0)}
    )  ## Necessary to enforce overplotting

    n_dims = len(lat_space.columns)
    df = pd.melt(lat_space, var_name="latent dim", value_name="latent intensity")
    df["sample"] = n_dims * list(lat_space.index)

    if labels is None:
        param = "all"
        # One label per sample; the list is repeated per latent dimension below.
        labels = ["all"] * len(lat_space)
    else:
        # Positional list, so that repeating it below never triggers
        # element-wise arithmetic on a Series.
        labels = list(labels)
        if param is None:
            param = "label"

    if not isinstance(labels[0], str):
        if pd.Series(labels).nunique(dropna=False) > 3:
            # Keep every real number (ints included), set the rest to NaN;
            # checking for `float` only would wipe out integer columns.
            numeric_values = [
                x
                if isinstance(x, (int, float, np.integer, np.floating))
                and not isinstance(x, bool)
                else float("nan")
                for x in labels
            ]
            labels = list(
                pd.qcut(
                    x=pd.Series(numeric_values, dtype=float),
                    q=4,
                    labels=["1stQ", "2ndQ", "3rdQ", "4thQ"],
                ).astype(str)
            )
        else:
            labels = [str(x) for x in labels]

    if len(np.unique(labels)) > 20 and focus_labels is None:
        warnings.warn(
            f"The provided label column has {len(np.unique(labels))} unique labels which might make the ridgeline plot unclear."
        )
        # Restrict to top 20 labels
        focus_labels = pd.Series(labels).value_counts().nlargest(20).index.tolist()
        print("Focusing on top 20 labels instead")

    if focus_labels is not None:
        labels = [label if label in focus_labels else "other" for label in labels]

    # The melted frame stacks the latent dimensions, so repeat the labels once per dimension.
    df[param] = n_dims * labels  # type: ignore

    exclude_missing_info = (df[param] == "unknown") | (df[param] == "nan")
    plot_df = df[~exclude_missing_info]

    # Common x-range from the per-group 5% / 90% quantiles.
    grouped = plot_df[["latent intensity", "latent dim", param]].groupby(
        [param, "latent dim"], observed=False
    )
    xmin = grouped.quantile(0.05).min()
    xmax = grouped.quantile(0.9).max()

    # One colour per plotted group, keyed by label so that the colour
    # assignment does not depend on the order in which labels appear.
    hue_levels = sorted(plot_df[param].dropna().unique(), key=str)
    palette_name = "tab20" if len(hue_levels) > 8 else "tab10"
    palette = dict(
        zip(hue_levels, sns.color_palette(palette_name, n_colors=len(hue_levels)))
    )
    if "other" in palette:
        # Draw the catch-all "other" group in grey
        palette["other"] = (0.3, 0.3, 0.3)

    # Length of longest latent dim string for aspect ratio
    len_longest_latent_dim = max(len(str(x)) for x in lat_space.columns)

    g = sns.FacetGrid(
        plot_df,
        row="latent dim",
        hue=param,
        hue_order=hue_levels,
        aspect=12 + len_longest_latent_dim / 4,
        height=0.8,
        xlim=(xmin.iloc[0], xmax.iloc[0]),
        palette=palette,
    )

    g.map_dataframe(
        sns.kdeplot,
        "latent intensity",
        bw_adjust=0.5,
        clip_on=True,
        fill=True,
        alpha=0.5,
        warn_singular=False,
        ec="k",
        lw=1,
    )

    def label(data, color, label, text="latent dim"):
        # Write the name of the latent dimension to the left of its ridge.
        ax = plt.gca()
        label_text = data[text].unique()[0]
        ax.text(
            0.0,
            0.2,
            label_text,
            fontweight="bold",
            ha="right",
            va="center",
            transform=ax.transAxes,
        )

    g.map_dataframe(label, text="latent dim")

    g.set(xlim=(xmin.iloc[0], xmax.iloc[0]))
    # Set the subplots to overlap
    g.figure.subplots_adjust(hspace=-0.5)

    # Remove axes details that don't play well with overlap
    g.set_titles("")
    g.set(yticks=[], ylabel="")
    g.despine(bottom=True, left=True)

    g.add_legend()

    plt.close(g.figure)
    return g

803 

def _plot_evaluation(
    self,
    result: Result,
) -> dict:
    """Plots the evaluation results from the Result object.

    One bar plot is created per combination of clinical parameter, metric
    and ML algorithm found in ``result.embedding_evaluation``. The plots are
    also stored in ``self.plots["ML_Evaluation"]``.

    Args:
        result: The Result object containing evaluation data.

    Returns:
        Nested dictionary ``{parameter: {metric: {algorithm: plot}}}`` with
        the evaluation plots; empty if no evaluation data is available.
    """
    ml_plots = {}
    plt.ioff()
    if not hasattr(result.embedding_evaluation, "CLINIC_PARAM"):
        warnings.warn(
            "We could not create visualizations for the evaluation plots.\n"
            "This usually happens if you try to visualize after saving and loading "
            "the pipeline object with `save_all=False`. This memory-efficient saving mode "
            "does not retain the evaluation results. "
            "Set save_all=True to avoid this; this might also be fixed soon."
        )
        return {}

    eval_df = result.embedding_evaluation

    for c in pd.unique(eval_df.CLINIC_PARAM):
        is_param = eval_df.CLINIC_PARAM == c
        ml_plots[c] = {}
        for m in pd.unique(eval_df.loc[is_param, "metric"]):  # ty: ignore
            is_param_metric = is_param & (eval_df.metric == m)
            ml_plots[c][m] = {}
            for alg in pd.unique(eval_df.loc[is_param_metric, "ML_ALG"]):  # ty: ignore
                data = eval_df[is_param_metric & (eval_df.ML_ALG == alg)]

                # Check for missing values
                if data["value"].isnull().any():
                    warnings.warn(
                        f"Missing values found in evaluation data for parameter '{c}', metric '{m}', and algorithm '{alg}'. These will be ignored in the plot."
                    )
                    # Only the score itself decides whether a row can be plotted.
                    data = data.dropna(subset=["value"])

                sns_plot = sns.catplot(
                    data=data,
                    x="score_split",
                    y="value",
                    col="ML_TASK",
                    hue="score_split",
                    kind="bar",
                )

                # Start the y-axis at zero unless there are negative scores.
                min_y = data.value.min()
                if min_y > 0:
                    min_y = 0

                ml_plots[c][m][alg] = sns_plot.set(ylim=(min_y, None))

    self.plots["ML_Evaluation"] = ml_plots

    return ml_plots

875 

876 @staticmethod 

877 def _total_correlation(latent_space: pd.DataFrame) -> float: 

878 """Function to compute the total correlation as described here (Equation2): https://doi.org/10.3390/e21100921 

879 

880 Args: 

881 latent_space: latent space with dimension sample vs. latent dimensions 

882 Returns: 

883 tc: total correlation across latent dimensions 

884 """ 

885 lat_cov = np.cov(latent_space.T) 

886 tc = 0.5 * (np.sum(np.log(np.diag(lat_cov))) - np.linalg.slogdet(lat_cov)[1]) 

887 return tc 

888 

889 @staticmethod 

890 def _coverage_calc(latent_space: pd.DataFrame) -> float: 

891 """Function to compute the coverage as described here (Equation3): https://doi.org/10.3390/e21100921 

892 

893 Args: 

894 latent_space: latent space with dimension sample vs. latent dimensions 

895 Returns: 

896 cov: coverage across latent dimensions 

897 """ 

898 bins_per_dim = int( 

899 np.power(len(latent_space.index), 1 / len(latent_space.columns)) 

900 ) 

901 if bins_per_dim < 2: 

902 warnings.warn( 

903 "Coverage calculation fails since combination of sample size and latent dimension results in less than 2 bins." 

904 ) 

905 cov = np.nan 

906 else: 

907 latent_bins = latent_space.apply(lambda x: pd.cut(x, bins=bins_per_dim)) 

908 latent_bins = pd.Series(zip(*[latent_bins[col] for col in latent_bins])) 

909 cov = len(latent_bins.unique()) / np.power( 

910 bins_per_dim, len(latent_space.columns) 

911 ) 

912 

913 return cov