Coverage for src / autoencodix / visualize / visualize.py: 10%

428 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1import os 

2from dataclasses import field 

3from typing import Any, Dict, Optional, Union, Literal, no_type_check 

4import warnings 

5 

6import matplotlib.figure 

7import numpy as np 

8import pandas as pd 

9import seaborn as sns # type: ignore 

10import torch 

11from matplotlib import pyplot as plt 

12from umap import UMAP # type: ignore 

13 

14from autoencodix.base._base_visualizer import BaseVisualizer 

15from autoencodix.utils._result import Result 

16from autoencodix.utils._utils import nested_dict, nested_to_tuple, show_figure 

17from autoencodix.configs.default_config import DefaultConfig 

18 

19 

20class Visualizer(BaseVisualizer): 

# Registry of created figures, keyed by plot name and, for latent-space plots,
# further nested by epoch, split and parameter. Each instance gets its own
# registry in __init__. A `dataclasses.field(default_factory=...)` default is
# intentionally not used: it is only interpreted by the @dataclass decorator,
# which this class does not use, so it would merely leave a
# `dataclasses.Field` object behind as the class attribute.
plots: Dict[str, Any]

24 

def __init__(self):
    """Start with an empty, auto-nesting registry of plots."""
    registry = nested_dict()
    self.plots = registry

27 

def __setitem__(self, key, elem):
    """Register ``elem`` under ``key`` so that ``visualizer[key] = fig`` works."""
    registry = self.plots
    registry[key] = elem

30 

def visualize(self, result: Result, config: DefaultConfig) -> Result:
    """Build the standard training plots and keep them in ``self.plots``.

    Stores the model-weight heatmaps under "ModelWeights" and the absolute and
    relative loss plots under "loss_absolute" and "loss_relative".

    Args:
        result: Training result providing the model and the recorded losses.
        config: Configuration forwarded to the loss formatting step.

    Returns:
        The same ``result`` object, unchanged.
    """
    self.plots["ModelWeights"] = self.plot_model_weights(model=result.model)

    # Long-format loss table shared by both loss plots.
    long_losses = self.make_loss_format(result=result, config=config)
    for plot_key, kind in (
        ("loss_absolute", "absolute"),
        ("loss_relative", "relative"),
    ):
        self.plots[plot_key] = self.make_loss_plot(
            df_plot=long_losses, plot_type=kind
        )

    return result

48 

49 ## Plotting methods ## 

50 

def save_plots(
    self, path: str, which: Union[str, list] = "all", format: str = "png"
) -> None:
    """Save stored plots to ``path`` in the given file format.

    Every stored figure is written to ``<path>/<key>_<subkey>_....<format>``,
    where the name parts are the nesting path of the figure inside
    ``self.plots`` (e.g. ``2D-scatter_<epoch>_<split>_<param>.png``).

    Args:
        path: Directory the plots are written to; created if it does not exist.
        which: "all" to save every stored plot, a single plot name, or a
            list/tuple of plot names. Unknown names are reported and skipped.
        format: File format/extension passed on to ``savefig`` (e.g. "png", "jpg").

    Raises:
        ValueError: If ``which`` is neither a string nor a list/tuple.
    """
    if isinstance(which, str):
        if which == "all":
            if len(self.plots) == 0:
                print("No plots found in the plots dictionary")
                print("You need to run visualize() method first")
                return
            requested = None  # None means: no filtering, save everything
        else:
            requested = [which]
    elif isinstance(which, (list, tuple)):
        requested = list(which)
    else:
        raise ValueError(
            "'which' must be 'all', a plot name or a list of plot names, "
            f"got {type(which).__name__}"
        )

    if requested is not None:
        found = []
        for key in requested:
            if key in self.plots:
                found.append(key)
            else:
                print(f"Plot {key} not found in the plots dictionary")
                print(f"All available plots are: {list(self.plots.keys())}")
        if not found:
            return
        requested = set(found)

    if path:
        os.makedirs(path, exist_ok=True)

    # Walk the whole registry once. nested_to_tuple yields
    # (key, subkey, ..., figure) tuples for the full registry, including
    # figures stored directly under a top-level key such as "ModelWeights",
    # so nested and flat entries are handled uniformly; the selection is a
    # filter on the top-level key.
    for item in nested_to_tuple(self.plots):
        if requested is not None and item[0] not in requested:
            continue
        fig = item[-1]  # the figure is the last element of the tuple
        filename = "_".join(str(part) for part in item[:-1])
        fig.savefig(f"{os.path.join(path, filename)}.{format}")

111 

def show_loss(
    self, plot_type: Literal["absolute", "relative"] = "absolute"
) -> None:
    """Show the stored loss figure of the requested kind.

    Nothing is raised on failure: a hint is printed when ``plot_type`` is not
    recognised or when the figure has not been created yet by ``visualize()``.

    Args:
        plot_type: "absolute" (default) or "relative".
    """
    if plot_type not in ["absolute", "relative"]:
        print(
            "Type of loss plot not recognized. Please use 'absolute' or 'relative'"
        )
        return

    if plot_type == "absolute":
        plot_key = "loss_absolute"
        missing_msg = "Absolute loss plot not found in the plots dictionary"
    else:
        plot_key = "loss_relative"
        missing_msg = "Relative loss plot not found in the plots dictionary"

    if plot_key in self.plots.keys():
        show_figure(self.plots[plot_key])
        plt.show()
    else:
        print(missing_msg)
        print("You need to run visualize() method first")

141 

@no_type_check
def show_latent_space(
    self,
    result: Result,
    plot_type: str = "2D-scatter",
    labels: Optional[Union[list, pd.Series, None]] = None,
    param: Optional[Union[list, str]] = None,
    epoch: Optional[Union[int, None]] = None,
    split: str = "all",
    **kwargs,
) -> None:
    """Visualize the latent space of ``result`` and store the figures in ``self.plots``.

    Args:
        result: The result object containing latent spaces, datasets and the model.
        plot_type: Type of plot to generate: "2D-scatter" (UMAP projection to 2D
            when the latent space has more than two dimensions), "Ridgeline"
            (per-dimension density plots) or "Coverage-Correlation" (total
            correlation and coverage of the train/valid latent spaces at every
            checkpoint). Default is "2D-scatter".
        labels: Labels for the data points in the latent space. A ``pd.Series``
            is aligned to the latent-space index (missing samples become
            "unknown"); a list is used as given. Mutually exclusive with ``param``.
            Default is None.
        param: List of metadata column names used to colour the points, or the
            string "all" for every available column. Default is None, in which
            case all points share one label unless ``labels`` is given.
        epoch: Zero-based epoch index to visualize. If None, the last epoch
            (``config.epochs - 1``) is used. The test split is always read with
            ``epoch=-1``.
        split: Data split to visualize: "train", "valid", "test" or "all".
            Default is "all".
        **kwargs: Accepted for backward compatibility; currently unused.

    Raises:
        ValueError: For an unknown ``plot_type``, for inconsistent
            ``labels``/``param`` arguments, for parameters missing from the
            metadata, or for missing/invalid datasets or metadata.
    """
    valid_plot_types = ("2D-scatter", "Ridgeline", "Coverage-Correlation")
    if plot_type not in valid_plot_types:
        raise ValueError(
            f"Unknown plot_type {plot_type!r}. "
            f"Options are: {', '.join(valid_plot_types)}."
        )

    plt.ioff()
    if plot_type == "Coverage-Correlation":
        # Computed once and cached; later calls only re-display the figure.
        if "Coverage-Correlation" not in self.plots:
            self.plots["Coverage-Correlation"] = self._coverage_correlation_figure(
                result
            )
        show_figure(self.plots["Coverage-Correlation"])
        plt.show()
        return

    if epoch is None:
        epoch = result.model.config.epochs - 1

    clin_data = self._latent_space_clin_data(result)

    if split == "all":
        df_latent = pd.concat(
            [
                result.get_latent_df(epoch=epoch, split="train"),
                result.get_latent_df(epoch=epoch, split="valid"),
                result.get_latent_df(epoch=-1, split="test"),
            ]
        )
    elif split == "test":
        df_latent = result.get_latent_df(epoch=-1, split=split)
    else:
        df_latent = result.get_latent_df(epoch=epoch, split=split)

    # Resolve which labels colour the points and under which name(s).
    if labels is not None and param is not None:
        raise ValueError(
            "Please provide either labels or param, not both. If you want to plot all parameters, set param to 'all' and labels to None."
        )
    labels_from_metadata = labels is None and param is not None
    if labels_from_metadata:
        if isinstance(param, str):
            if param != "all":
                raise ValueError(
                    "Please provide parameter to plot as a list not as string. If you want to plot all parameters, set param to 'all' and labels to None."
                )
            param = list(clin_data.columns)
        else:
            param = list(param)
        missing = [p for p in param if p not in clin_data.columns]
        if missing:
            raise ValueError(
                f"Parameter(s) {missing} not found in the annotation data. "
                f"Available columns are: {list(clin_data.columns)}"
            )
    elif labels is None:
        # Neither labels nor param given: every sample is in the same group.
        labels = ["all"] * df_latent.shape[0]
        param = ["user_label"]
    elif isinstance(labels, pd.Series):
        param = [labels.name]
        # Order by index of df_latent, fill missing samples with "unknown".
        labels = labels.reindex(df_latent.index, fill_value="unknown").tolist()
    else:
        param = ["user_label"]  # default legend title for user-provided labels

    embedding = None
    if plot_type == "2D-scatter":
        # Project once so that every parameter is shown on the same embedding.
        if df_latent.shape[1] > 2:
            reducer = UMAP(n_components=2)
            embedding = pd.DataFrame(reducer.fit_transform(df_latent))
        else:
            embedding = df_latent

    for p in param:
        # Labels are looked up per parameter so one parameter's labels never
        # leak into the plot of another; user-provided labels are used as-is.
        if labels_from_metadata:
            p_labels = clin_data.loc[df_latent.index, p].tolist()
        else:
            p_labels = labels

        if plot_type == "2D-scatter":
            fig = self.plot_2D(
                embedding=embedding,
                labels=p_labels,
                param=p,
                layer=f"2D latent space (epoch {epoch + 1})",  # epochs are counted from 0 internally
                figsize=(12, 8),
                center=True,
            )
            self.plots["2D-scatter"][epoch][split][p] = fig
        else:  # "Ridgeline"
            grid = self.plot_latent_ridge(
                lat_space=df_latent, labels=p_labels, param=p
            )
            self.plots["Ridgeline"][epoch][split][p] = grid
            fig = grid.figure

        show_figure(fig)
        plt.show()


@staticmethod
def _latent_space_clin_data(result: Result) -> pd.DataFrame:
    """Concatenate the annotation data of the train, test and valid splits.

    The metadata of each split is either a DataFrame or a dict holding the
    DataFrame under the key "paired" (the train split decides which).

    Raises:
        ValueError: If a split or its metadata is missing, or the metadata has
            an unsupported type.
    """
    datasets = result.datasets
    for name in ("train", "valid"):
        if not hasattr(datasets, name):
            raise ValueError(f"no {name} split in datasets")
    for name in ("train", "valid", "test"):
        if getattr(datasets, name) is None:
            raise ValueError(f"{name} is None")
    for name in ("train", "valid", "test"):
        if not hasattr(getattr(datasets, name), "metadata"):
            raise ValueError(f"{name} dataset has no metadata")

    splits = (datasets.train, datasets.test, datasets.valid)  # concat order
    train_meta = datasets.train.metadata
    if isinstance(train_meta, dict):
        if "paired" not in train_meta:
            raise ValueError(
                "Please provide paired annotation data with key 'paired' in metadata dictionary."
            )
        frames = [ds.metadata["paired"] for ds in splits]
    elif isinstance(train_meta, pd.DataFrame):
        frames = [ds.metadata for ds in splits]
    else:
        raise ValueError(
            "No annotation data found. Please provide a valid annotation data type."
        )
    return pd.concat(frames, axis=0)


def _coverage_correlation_figure(self, result: Result):
    """Plot total correlation and coverage of the latent space per checkpoint.

    Both metrics are evaluated for the train and valid splits at every
    ``checkpoint_interval``-th epoch up to ``epochs`` (inclusive).
    """
    cfg = result.model.config
    records = []
    for ckpt_epoch in range(
        cfg.checkpoint_interval, cfg.epochs + 1, cfg.checkpoint_interval
    ):
        for ckpt_split in ("train", "valid"):
            latent_df = result.get_latent_df(epoch=ckpt_epoch - 1, split=ckpt_split)
            records.append(
                {
                    "epoch": ckpt_epoch,
                    "split": ckpt_split,
                    "total_correlation": self._total_correlation(latent_df),
                    "coverage": self._coverage_calc(latent_df),
                }
            )
    df_metrics = pd.DataFrame(records)

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    for ax, column, title in (
        (axes[0], "total_correlation", "Total Correlation"),
        (axes[1], "coverage", "Coverage"),
    ):
        sns.lineplot(data=df_metrics, x="epoch", y=column, hue="split", ax=ax)
        ax.set_title(title)
        ax.set_xlabel("Epoch")
        ax.set_ylabel(title)
    fig.tight_layout()
    return fig

359 

def show_weights(self) -> None:
    """Show the stored model-weights figure, or print a hint if it is absent."""
    if "ModelWeights" in self.plots.keys():
        show_figure(self.plots["ModelWeights"])
        plt.show()
        return
    print("Model weights not found in the plots dictionary")
    print("You need to run visualize() method first")

370 

@staticmethod
def plot_model_weights(model: torch.nn.Module) -> matplotlib.figure.Figure:
    """Visualize the model's weight matrices as one heatmap per layer.

    Encoder layers (including the ``_mu`` layer of variational models) are
    drawn in the top row and decoder layers in the bottom row. The rows may
    differ in length, so non-symmetrical architectures are supported; unused
    cells are hidden. Only 2-D parameters whose name contains "weight" are
    plotted, and parameters whose name contains "var" are skipped (for a VAE
    only the mu weights are shown). If no parameter name mentions "encoder",
    "decoder" or "_mu", the remaining weights are split in half: first half on
    the top row, second half on the bottom row (an odd last matrix is dropped).

    If the model has ``ontologies``, decoder heatmaps are annotated with the
    ontology node names (plus ``model.feature_order`` for the output layer)
    when the name lists match the weight-matrix dimensions.

    This is a ``staticmethod`` because it does not use the instance and is
    called as ``self.plot_model_weights(model=...)``; as a plain function the
    instance would be bound to ``model`` and that call would raise TypeError.

    Args:
        model: PyTorch model instance.

    Returns:
        The figure holding all heatmaps. It is closed in pyplot so that it is
        not displayed implicitly.
    """
    cmap = sns.color_palette("Spectral", as_cmap=True)

    # Not every model has an `ontologies` attribute.
    ontologies = getattr(model, "ontologies", None)
    node_names = []
    if ontologies is not None:
        node_names = [list(ontology.keys()) for ontology in ontologies]
        node_names.append(model.feature_order)  # output layer

    encoder_weights, encoder_names = [], []
    decoder_weights, decoder_names = [], []
    other_weights, other_names = [], []
    for name, param in model.named_parameters():
        if "weight" not in name or len(param.shape) != 2:
            continue
        matrix = param.detach().cpu().numpy()
        # Subplot title: parameter name without the trailing ".weight".
        title = name[: -len(".weight")] if name.endswith(".weight") else name
        if "_mu" in name or ("encoder" in name and "var" not in name):
            encoder_weights.append(matrix)
            encoder_names.append(title)
        elif "decoder" in name and "var" not in name:
            decoder_weights.append(matrix)
            decoder_names.append(title)
        elif "encoder" not in name and "decoder" not in name and "var" not in name:
            # fallback for models without explicit encoder/decoder in names
            other_weights.append(matrix)
            other_names.append(title)

    def _annotate_with_node_names(ax, matrix, col):
        """Label decoder heatmap `col` with ontology node names if they fit."""
        if ontologies is None or col + 1 >= len(node_names):
            return
        in_names, out_names = node_names[col], node_names[col + 1]
        # Rows are out nodes, columns are in nodes (see axis labels below);
        # only label when the name lists match the matrix dimensions.
        if len(in_names) != matrix.shape[1] or len(out_names) != matrix.shape[0]:
            return
        # Heatmap cells are centred at i + 0.5.
        ax.set_xticks(
            ticks=np.arange(len(in_names)) + 0.5,
            labels=in_names,
            rotation=90,
            fontsize=8,
        )
        ax.set_yticks(
            ticks=np.arange(len(out_names)) + 0.5,
            labels=out_names,
            rotation=0,
            fontsize=8,
        )

    if encoder_weights or decoder_weights:
        n_cols = max(len(encoder_weights), len(decoder_weights))
        # squeeze=False keeps `axes` 2-D even for a single column.
        fig, axes = plt.subplots(
            2, n_cols, sharex=False, figsize=(15 * n_cols, 15), squeeze=False
        )
        for col in range(n_cols):
            top, bottom = axes[0, col], axes[1, col]

            if col < len(encoder_weights):
                sns.heatmap(encoder_weights[col], cmap=cmap, center=0, ax=top).set(
                    title=encoder_names[col]
                )
                top.set_ylabel("Out Node", size=12)
            else:
                top.axis("off")

            if col < len(decoder_weights):
                sns.heatmap(
                    decoder_weights[col], cmap=cmap, center=0, ax=bottom
                ).set(title=decoder_names[col])
                _annotate_with_node_names(bottom, decoder_weights[col], col)
                bottom.set_xlabel("In Node", size=12)
                bottom.set_ylabel("Out Node", size=12)
            else:
                bottom.axis("off")
    else:
        n_layers = len(other_weights) // 2
        if n_layers == 0:
            # Nothing (or a single matrix) to split into two rows.
            fig, ax = plt.subplots(figsize=(5, 5))
            ax.axis("off")
            ax.text(0.5, 0.5, "No weight matrices to plot", ha="center", va="center")
        else:
            fig, axes = plt.subplots(
                2, n_layers, sharex=False, figsize=(5 * n_layers, 10), squeeze=False
            )
            for layer in range(n_layers):
                top, bottom = axes[0, layer], axes[1, layer]
                sns.heatmap(other_weights[layer], cmap=cmap, center=0, ax=top).set(
                    title=other_names[layer]
                )
                sns.heatmap(
                    other_weights[n_layers + layer], cmap=cmap, center=0, ax=bottom
                ).set(title=other_names[n_layers + layer])
                bottom.set_xlabel("In Node", size=12)
                top.set_ylabel("Out Node", size=12)
                bottom.set_ylabel("Out Node", size=12)

    fig.suptitle("Model Weights", size=20)
    plt.close(fig)
    return fig

666 

@staticmethod
def plot_2D(
    embedding: pd.DataFrame,
    labels: list,
    param: Optional[Union[str, None]] = None,
    layer: str = "latent space",
    figsize: tuple = (24, 15),
    center: bool = True,
    plot_numeric: bool = False,
    xlim: Optional[Union[tuple, None]] = None,
    ylim: Optional[Union[tuple, None]] = None,
    scale: Optional[Union[str, None]] = None,
    no_leg: bool = False,
) -> matplotlib.figure.Figure:
    """Plot a 2D scatter of ``embedding`` coloured by ``labels``.

    Numeric labels with more than three distinct values are binned into
    quartiles ("1stQ" ... "4thQ") unless ``plot_numeric`` is True, in which
    case they are drawn on a continuous "bwr" palette and centring is
    disabled. All other labels are treated as categories (converted to str).

    Args:
        embedding: DataFrame whose first two columns are the 2D coordinates.
        labels: One label per row of ``embedding``. If shorter, each label is
            repeated ``len(embedding) // len(labels)`` times; if longer, the
            distinct labels (in first-seen order) are used.
        param: Title for the legend. Defaults to None.
        layer: Title for the plot. Defaults to "latent space".
        figsize: Size of the figure. Defaults to (24, 15).
        center: If True, mark the mean position of each category with a star.
            Defaults to True.
        plot_numeric: Draw numeric labels on a continuous colour scale instead
            of binning them into quartiles. Defaults to False.
        xlim: Optional (min, max) x-axis limits. Defaults to None.
        ylim: Optional (min, max) y-axis limits. Defaults to None.
        scale: Optional axis scale (e.g. "log") applied to both axes.
            Defaults to None.
        no_leg: If True, suppress the legend. Defaults to False.

    Returns:
        The resulting matplotlib figure (closed in pyplot so it is not shown
        implicitly).

    Raises:
        ValueError: If ``labels`` is empty.
    """
    labels = list(labels)
    if not labels:
        raise ValueError("labels must not be empty")

    # Decide numeric vs. categorical from all labels, not only the first one,
    # so a leading NaN in a categorical column does not flip the decision.
    label_series = pd.Series(labels)
    is_numeric = pd.api.types.is_numeric_dtype(
        label_series
    ) and not pd.api.types.is_bool_dtype(label_series)

    numeric = False
    if is_numeric and label_series.nunique(dropna=False) > 3:
        if plot_numeric:
            center = False  # Disable centering for numeric params
            numeric = True
        else:
            print("The provided label column is numeric and converted to categories.")
            # Bin on the numeric values directly: ints (Python or numpy) are
            # kept instead of being replaced by NaN.
            labels = (
                pd.qcut(
                    x=label_series,
                    q=4,
                    labels=["1stQ", "2ndQ", "3rdQ", "4thQ"],
                )
                .astype(str)
                .to_list()
            )
    else:
        labels = [str(x) for x in labels]

    # Align the number of labels with the number of points.
    n_points = embedding.shape[0]
    if len(labels) < n_points:
        print(
            "Given labels do not have the same length as given sample size. Labels will be duplicated."
        )
        repeats = n_points // len(labels)
        labels = [label for label in labels for _ in range(repeats)]
    elif len(labels) > n_points:
        labels = list(dict.fromkeys(labels))  # distinct labels, deterministic order

    fig, ax = plt.subplots(figsize=figsize)
    point_kwargs = {
        "x": embedding.iloc[:, 0],
        "y": embedding.iloc[:, 1],
        "hue": labels,
        "s": 40,
        "alpha": 0.5,
        "ec": "black",
        "ax": ax,
    }
    if numeric:
        sns.scatterplot(palette="bwr", **point_kwargs)
    else:
        hue_order = np.unique(labels)
        sns.scatterplot(hue_order=hue_order, **point_kwargs)
        if center:
            # Group centroids, one star per category (groupby sorts like np.unique).
            means = embedding.groupby(np.asarray(labels)).mean()
            sns.scatterplot(
                x=means.iloc[:, 0],
                y=means.iloc[:, 1],
                hue=means.index.to_numpy(),
                hue_order=hue_order,
                s=200,
                ec="black",
                alpha=0.9,
                marker="*",
                legend=False,
                ax=ax,
            )

    if xlim is not None:
        ax.set_xlim(xlim[0], xlim[1])
    if ylim is not None:
        ax.set_ylim(ylim[0], ylim[1])
    if scale is not None:
        ax.set_yscale(scale)
        ax.set_xscale(scale)
    ax.set_xlabel("Dim 1")
    ax.set_ylabel("Dim 2")

    legend_cols = 2 if len(np.unique(labels)) > 10 else 1
    if no_leg:
        ax.legend([], [], frameon=False)
    else:
        sns.move_legend(
            ax,
            "upper left",
            bbox_to_anchor=(1, 1),
            ncol=legend_cols,
            title=param,
            frameon=False,
        )

    ax.set_title(layer)

    plt.close(fig)
    return fig

809 

@staticmethod
def plot_latent_ridge(
    lat_space: pd.DataFrame,
    labels: Optional[Union[list, pd.Series, None]] = None,
    param: Optional[Union[str, None]] = None,
) -> sns.FacetGrid:
    """Create a ridgeline plot: one row per latent dimension, one density per label group.

    Numeric labels with more than three distinct values are binned into
    quartiles ("1stQ" ... "4thQ"); all other labels are converted to str.
    Samples labelled "unknown" or "nan" are left out of the densities. The
    x-range spans from the lowest 5 % quantile to the highest 90 % quantile
    over all (group, latent dimension) pairs.

    Note: calls ``sns.set_theme`` and therefore changes the global
    seaborn/matplotlib style.

    Args:
        lat_space: Latent space with one row per sample and one column per
            latent dimension.
        labels: One label per sample (row of ``lat_space``), matched by
            position. If None, all samples form the single group "all".
        param: Name of the label column, used as legend title (typically a
            column name of the annotation data). Set to "all" when ``labels``
            is None and defaults to "label" when labels are given without a name.

    Returns:
        g: FacetGrid object containing the ridge line plot.
    """
    sns.set_theme(
        style="white", rc={"axes.facecolor": (0, 0, 0, 0)}
    )  ## Necessary to enforce overplotting

    n_dims = len(lat_space.columns)
    # melt stacks the columns one after another, so the sample order repeats
    # once per latent dimension.
    df = pd.melt(lat_space, var_name="latent dim", value_name="latent intensity")
    df["sample"] = n_dims * list(lat_space.index)

    if labels is None:
        param = "all"
        labels = ["all"] * len(lat_space)  # one label per sample
    elif param is None:
        param = "label"
    labels = list(labels)  # a Series would be repeated element-wise by `n * labels`

    label_series = pd.Series(labels)
    is_numeric = pd.api.types.is_numeric_dtype(
        label_series
    ) and not pd.api.types.is_bool_dtype(label_series)
    if is_numeric and label_series.nunique(dropna=False) > 3:
        labels = (
            pd.qcut(
                x=label_series,
                q=4,
                labels=["1stQ", "2ndQ", "3rdQ", "4thQ"],
            )
            .astype(str)
            .to_list()
        )
    else:
        labels = [str(x) for x in labels]

    df[param] = n_dims * labels

    missing_info = df[param].isin(["unknown", "nan"])
    shown = df.loc[~missing_info]

    intensity_by_group = shown.groupby([param, "latent dim"], observed=False)[
        "latent intensity"
    ]
    x_min = intensity_by_group.quantile(0.05).min()
    x_max = intensity_by_group.quantile(0.9).max()

    n_groups = df[param].nunique()
    if n_groups > 8:
        cat_pal = sns.husl_palette(n_groups)
    else:
        cat_pal = sns.color_palette(n_colors=n_groups)

    g = sns.FacetGrid(
        shown,
        row="latent dim",
        hue=param,
        aspect=12,
        height=0.8,
        xlim=(x_min, x_max),
        palette=cat_pal,
    )

    g.map_dataframe(
        sns.kdeplot,
        "latent intensity",
        bw_adjust=0.5,
        clip_on=True,
        fill=True,
        alpha=0.5,
        warn_singular=False,
        ec="k",
        lw=1,
    )

    def label(data, color, label, text="latent dim"):
        # Write the latent dimension name at the left edge of each row.
        ax = plt.gca()
        label_text = data[text].unique()[0]
        ax.text(
            0.0,
            0.2,
            label_text,
            fontweight="bold",
            ha="right",
            va="center",
            transform=ax.transAxes,
        )

    g.map_dataframe(label, text="latent dim")

    g.set(xlim=(x_min, x_max))
    # Set the subplots to overlap
    g.figure.subplots_adjust(hspace=-0.5)

    # Remove axes details that don't play well with overlap
    g.set_titles("")
    g.set(yticks=[], ylabel="")
    g.despine(bottom=True, left=True)

    g.add_legend()

    plt.close(g.figure)
    return g

919 

@staticmethod
def make_loss_plot(
    df_plot: pd.DataFrame, plot_type: str
) -> matplotlib.figure.Figure:
    """Generates a plot for visualizing loss values from a DataFrame.

    Args:
        df_plot: DataFrame containing the loss values to be plotted. It should have the columns:
            - "Loss Term": The type of loss term (e.g., "total_loss", "reconstruction_loss").
            - "Epoch": The epoch number.
            - "Loss Value": The value of the loss.
            - "Split": The data split (e.g., "train", "validation").
            The DataFrame is not modified.

        plot_type: The type of plot to generate. It can be either "absolute" or "relative".
            - "absolute": One line plot per unique loss term, coloured by split.
            - "relative": One stacked density plot per data split, showing each
              loss term's share over the epochs. "total_loss", "*_factor*" terms
              and terms that are zero/NaN for all epochs are excluded; negative
              values are clipped to zero (with a warning).

    Returns:
        The generated matplotlib figure containing the loss plots.

    Raises:
        ValueError: If ``plot_type`` is not "absolute" or "relative".
    """
    if plot_type not in ("absolute", "relative"):
        raise ValueError(
            f"plot_type must be 'absolute' or 'relative', got {plot_type!r}"
        )

    terms = df_plot["Loss Term"].unique()
    splits = df_plot["Split"].unique()

    if plot_type == "absolute":
        # squeeze=False keeps `axes` 2-D even for a single loss term.
        fig, axes = plt.subplots(
            1, len(terms), figsize=(5 * len(terms), 5), sharey=False, squeeze=False
        )
        for ax, term in zip(axes[0], terms):
            sns.lineplot(
                data=df_plot[df_plot["Loss Term"] == term],
                x="Epoch",
                y="Loss Value",
                hue="Split",
                ax=ax,
            ).set_title(term)
    else:
        df_rel = df_plot.copy()  # do not clip the caller's DataFrame
        if (df_rel["Loss Value"] < 0).any():
            warnings.warn(
                "Loss values contain negative values. Check your loss function if correct. Loss will be clipped to zero for plotting."
            )
            df_rel["Loss Value"] = df_rel["Loss Value"].clip(lower=0)

        # Exclude loss terms where all Loss Value are zero or NaN over all epochs
        valid_terms = [
            term
            for term, values in df_rel.groupby("Loss Term", sort=False)["Loss Value"]
            if values.notna().any() and (values != 0).any()
        ]
        keep = (
            (df_rel["Loss Term"] != "total_loss")
            & ~(df_rel["Loss Term"].str.contains("_factor", regex=False))
            & (df_rel["Loss Term"].isin(valid_terms))
        )

        # One subplot per split (previously hard-coded to two).
        fig, axes = plt.subplots(
            1, len(splits), figsize=(5 * len(splits), 5), sharey=True, squeeze=False
        )
        epoch_max = df_rel["Epoch"].max()
        for ax, split in zip(axes[0], splits):
            sns.kdeplot(
                data=df_rel[keep & (df_rel["Split"] == split)],
                x="Epoch",
                hue="Loss Term",
                multiple="fill",
                weights="Loss Value",
                clip=[0, epoch_max],
                ax=ax,
            ).set_title(split)

    plt.close(fig)
    return fig

1005 

@staticmethod
def make_loss_format(result: Result, config: DefaultConfig) -> pd.DataFrame:
    """Collect the per-epoch loss histories of ``result`` into one long table.

    Each sub-loss in ``result.sub_losses`` and the total loss from
    ``result.losses`` is turned into a DataFrame (rows = history entries,
    columns = splits) and melted. Each output row therefore holds a single
    (epoch, term, split) observation.

    Args:
        result: Training result. ``result.sub_losses.get(key=term).get()``
            yields the history of each sub-loss term, and
            ``result.losses.get()`` yields the total-loss history. A history
            may be a dict, an object with ``to_dict()``, an array-like with
            ``shape``, or a list/tuple.
        config: Training configuration. ``config.beta`` rescales the
            ``"var_loss"`` term.

    Returns:
        DataFrame with the columns "Epoch", "Loss Term", "Split" and
        "Loss Value" (cast to float). Sub-loss rows come first, followed by
        the "total_loss" rows. "Epoch" is the history index plus one.
    """

    def _as_dict(loss_values):
        # Normalize a loss history into a mapping accepted by from_dict.
        if isinstance(loss_values, dict):
            return loss_values
        if hasattr(loss_values, "to_dict"):
            return loss_values.to_dict()
        if hasattr(loss_values, "shape") or isinstance(loss_values, (list, tuple)):
            # Array-like history: key each entry by its position.
            return dict(enumerate(loss_values))
        return loss_values

    def _melt(loss_values, term: str, scale: Optional[float] = None) -> pd.DataFrame:
        # Convert one history to long format tagged with its loss term.
        loss_df = pd.DataFrame.from_dict(_as_dict(loss_values), orient="index")  # type: ignore
        if scale is not None:
            loss_df = loss_df * scale
        loss_df["Epoch"] = loss_df.index + 1
        loss_df["Loss Term"] = term
        return loss_df.melt(
            id_vars=["Epoch", "Loss Term"],
            var_name="Split",
            value_name="Loss Value",
        )

    frames = []
    for term in result.sub_losses.keys():
        loss_values = result.sub_losses.get(key=term).get()
        # The variational term is reported weighted by beta.
        scale = config.beta if term == "var_loss" else None
        frames.append(_melt(loss_values, term, scale))
    frames.append(_melt(result.losses.get(), "total_loss"))

    # Concatenate once instead of growing the frame inside the loop.
    loss_df_melt = pd.concat(frames, axis=0, ignore_index=True)
    loss_df_melt["Loss Value"] = loss_df_melt["Loss Value"].astype(float)
    return loss_df_melt

1073 

@no_type_check
def plot_evaluation(
    self,
    result: Result,
) -> dict:
    """Build bar plots of the embedding evaluation scores.

    One ``sns.catplot`` (``kind="bar"``) is created for every combination of
    ``CLINIC_PARAM``, ``metric`` and ``ML_ALG`` found in
    ``result.embedding_evaluation``. Each plot shows ``value`` per
    ``score_split``, with one column per ``ML_TASK``. The y-axis always
    includes zero. Rows with a missing (NaN) key in any of the three
    grouping columns are skipped.

    Args:
        result: The Result object whose ``embedding_evaluation`` DataFrame
            provides the columns CLINIC_PARAM, metric, ML_ALG, ML_TASK,
            score_split and value.

    Returns:
        Nested dict ``{clinic_param: {metric: {ml_alg: grid}}}`` holding the
        seaborn grids returned by ``catplot``. The same dict is stored in
        ``self.plots["ML_Evaluation"]``.
    """
    ml_plots = dict()
    plt.ioff()

    evaluation = result.embedding_evaluation
    # sort=False yields groups in order of first appearance. The nested
    # dict's key order therefore matches enumerating each level with pd.unique.
    grouped = evaluation.groupby(["CLINIC_PARAM", "metric", "ML_ALG"], sort=False)
    for (c, m, alg), data in grouped:
        sns_plot = sns.catplot(
            data=data,
            x="score_split",
            y="value",
            col="ML_TASK",
            hue="score_split",
            kind="bar",
        )

        # Anchor the axis at zero but keep negative minima visible.
        min_y = min(data.value.min(), 0)

        ml_plots.setdefault(c, dict()).setdefault(m, dict())[alg] = sns_plot.set(
            ylim=(min_y, None)
        )
        # Close the figure so repeated evaluations do not pile up open
        # figures. The grid (and its .figure) stays stored for later display.
        plt.close(sns_plot.figure)

    self.plots["ML_Evaluation"] = ml_plots

    return ml_plots

1131 

def show_evaluation(
    self,
    param: str,
    metric: str,
    ml_alg: Optional[str] = None,
) -> None:
    """Show stored ML evaluation plot(s) for a clinical parameter and metric.

    Reads the nested plots stored under ``self.plots["ML_Evaluation"]``. If
    the evaluation plots, ``param``, ``metric`` or ``ml_alg`` are missing,
    an explanatory message is printed and nothing is shown.

    Args:
        param: Clinical parameter whose plots should be shown.
        metric: Metric whose plots should be shown.
        ml_alg: Single ML algorithm to show. When None, the plot of every
            available algorithm is shown in turn.
    """
    plt.ioff()

    # Check membership before indexing so the nested dict is not mutated.
    if "ML_Evaluation" not in self.plots.keys():
        print("ML Evaluation plots not found in the plots dictionary")
        print("You need to run evaluate() method first")
        return None
    evaluation = self.plots["ML_Evaluation"]

    if param not in evaluation.keys():
        print(f"Parameter {param} not found in the ML Evaluation plots")
        print(f"Available parameters: {list(evaluation.keys())}")
        return None
    per_metric = evaluation[param]

    if metric not in per_metric.keys():
        print(f"Metric {metric} not found in the ML Evaluation plots for {param}")
        print(f"Available metrics: {list(per_metric.keys())}")
        return None
    per_alg = per_metric[metric]

    available_algs = list(per_alg.keys())

    if ml_alg is None:
        for alg_name in available_algs:
            print(f"Showing plot for ML algorithm: {alg_name}")
            show_figure(per_alg[alg_name].figure)
            plt.show()
        return None

    if ml_alg not in available_algs:
        print(f"ML algorithm {ml_alg} not found for {param} and {metric}")
        print(f"Available ML algorithms: {available_algs}")
        return None

    show_figure(per_alg[ml_alg].figure)
    plt.show()

1176 

@staticmethod
def _total_correlation(latent_space: pd.DataFrame) -> float:
    """Compute the total correlation of a latent space.

    Implements Equation 2 of https://doi.org/10.3390/e21100921:
    ``0.5 * (sum_i log(Sigma_ii) - log det(Sigma))``, where ``Sigma`` is the
    sample covariance matrix of the latent dimensions. This is the closed
    form of total correlation for a multivariate Gaussian.

    Args:
        latent_space: Samples in rows, latent dimensions in columns.

    Returns:
        Total correlation across the latent dimensions. A single latent
        dimension yields 0.
    """
    # rowvar=False treats columns as variables. atleast_2d keeps the
    # single-dimension case usable: np.cov returns a 0-d array there, which
    # np.diag rejects.
    lat_cov = np.atleast_2d(
        np.cov(np.asarray(latent_space, dtype=float), rowvar=False)
    )
    # slogdet avoids overflow/underflow of the plain determinant in high dims.
    _, log_det = np.linalg.slogdet(lat_cov)
    tc = 0.5 * (np.sum(np.log(np.diag(lat_cov))) - log_det)
    return tc

1189 

@staticmethod
def _coverage_calc(latent_space: pd.DataFrame) -> float:
    """Compute the coverage of a latent space.

    Implements Equation 3 of https://doi.org/10.3390/e21100921. Each of the
    ``d`` latent dimensions is split into ``b`` equal-width bins, where
    ``b = floor(n_samples ** (1 / d))``. Coverage is the fraction of the
    ``b ** d`` grid cells that contain at least one sample. Rows with a
    missing value fall into no cell and are ignored.

    Args:
        latent_space: Samples in rows, latent dimensions in columns.

    Returns:
        Coverage across latent dimensions, or NaN (with a warning) when
        fewer than 2 bins per dimension are possible.
    """
    n_samples, n_dims = latent_space.shape

    if n_dims == 0:
        bins_per_dim = 0
    else:
        # Exact integer floor of the d-th root. A plain int(n ** (1 / d))
        # truncates floating-point error, e.g. 64 ** (1 / 3) == 3.999...
        bins_per_dim = int(round(n_samples ** (1.0 / n_dims)))
        while bins_per_dim > 0 and bins_per_dim ** n_dims > n_samples:
            bins_per_dim -= 1
        while (bins_per_dim + 1) ** n_dims <= n_samples:
            bins_per_dim += 1

    if bins_per_dim < 2:
        warnings.warn(
            "Coverage calculation fails since combination of sample size and latent dimension results in less than 2 bins."
        )
        cov = np.nan
    else:
        # Integer bin index per dimension (NaN for missing values).
        bin_codes = pd.DataFrame(
            {
                i: np.asarray(
                    pd.cut(latent_space.iloc[:, i], bins=bins_per_dim, labels=False),
                    dtype=float,
                )
                for i in range(n_dims)
            }
        )
        occupied_cells = len(bin_codes.dropna().drop_duplicates())
        cov = occupied_cells / np.power(bins_per_dim, n_dims)

    return cov