Coverage for src / autoencodix / base / _base_visualizer.py: 13%

261 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1import abc 

2import os 

3from typing import Optional, Union 

4import pandas as pd 

5import matplotlib 

6from matplotlib import pyplot as plt 

7import seaborn as sns # type: ignore 

8import seaborn.objects as so 

9import torch 

10import warnings 

11 

12from autoencodix.utils._result import Result 

13from autoencodix.utils._utils import nested_dict, nested_to_tuple, show_figure 

14from autoencodix.configs.default_config import DefaultConfig 

15 

16 

17class BaseVisualizer(abc.ABC): 

18 """Defines the interface for visualizing training results. 

19 

20 Attributes: 

21 plots: A nested dictionary to store various plots. 

22 """ 

23 

def __init__(self) -> None:
    """Create the visualizer with an empty plot store.

    ``plots`` is initialised from ``nested_dict()``; the other methods of
    this class look figures up in it by key (e.g. ``"loss_absolute"``).
    """
    self.plots = nested_dict()

26 

27 def __setitem__(self, key, elem): 

28 self.plots[key] = elem 

29 

30 ### Abstract Methods ### 

31 @abc.abstractmethod 

32 def visualize(self, result: Result, config: DefaultConfig) -> Result: 

33 pass 

34 

35 @abc.abstractmethod 

36 def show_latent_space( 

37 self, 

38 result: Result, 

39 plot_type: str = "2D-scatter", 

40 labels: Optional[Union[list, pd.Series, None]] = None, 

41 param: Optional[Union[list, str]] = None, 

42 epoch: Optional[Union[int, None]] = None, 

43 split: str = "all", 

44 ) -> None: 

45 pass 

46 

@abc.abstractmethod
def show_weights(self) -> None:
    """Display the model weights; must be implemented by subclasses.

    The base implementation has no body.

    Returns:
        None
    """

50 

51 ### General Functions used by all Visualizers in similar way ### 

52 

def show_loss(self, plot_type: str = "absolute") -> None:
    """Display one of the stored loss plots.

    Args:
        plot_type: Which loss plot to display, either "absolute" or
            "relative". Defaults to "absolute".

    Returns:
        None. If the plot type is unknown or the plot is not stored, a
        message is printed instead of showing a figure.
    """
    if plot_type not in ["absolute", "relative"]:
        print(
            "Type of loss plot not recognized. Please use 'absolute' or 'relative'"
        )
        return

    if plot_type == "absolute":
        if "loss_absolute" not in self.plots.keys():
            print("Absolute loss plot not found in the plots dictionary")
            print(
                "This happens, when you did not run visualize() or if you saved and loaded the model with `save_all=False`"
            )
            return
        show_figure(self.plots["loss_absolute"])
        plt.show()
        return

    # plot_type == "relative"
    if "loss_relative" not in self.plots.keys():
        print("Relative loss plot not found in the plots dictionary")
        print(
            "This happens, when you did not run visualize() or if you saved and loaded the model with `save_all=False`"
        )
        return
    # The relative plot is displayed through its own ``show()`` method
    # rather than through show_figure()/plt.show().
    self.plots["loss_relative"].show()

91 

def show_evaluation(
    self,
    param: str,
    metric: str,
    ml_alg: Optional[str] = None,
) -> None:
    """Display stored ML evaluation plots for one parameter and metric.

    Plots are looked up as
    ``self.plots["ML_Evaluation"][param][metric][algorithm]``. If any level
    is missing, a message listing the available keys is printed and
    nothing is shown.

    Args:
        param: Clinical parameter to visualize.
        metric: Metric to visualize.
        ml_alg: ML algorithm to visualize. If None, the plots of all
            available algorithms are shown one after another.

    Returns:
        None
    """
    # Switches matplotlib's interactive mode off; it is not switched back on here.
    plt.ioff()

    if "ML_Evaluation" not in self.plots.keys():
        print("ML Evaluation plots not found in the plots dictionary")
        print("You need to run evaluate() method first")
        return None

    eval_plots = self.plots["ML_Evaluation"]
    if param not in eval_plots.keys():
        print(f"Parameter {param} not found in the ML Evaluation plots")
        print(f"Available parameters: {list(eval_plots.keys())}")
        return None

    param_plots = eval_plots[param]
    if metric not in param_plots.keys():
        print(f"Metric {metric} not found in the ML Evaluation plots for {param}")
        print(f"Available metrics: {list(param_plots.keys())}")
        return None

    metric_plots = param_plots[metric]
    algs = list(metric_plots.keys())

    if ml_alg is None:
        # No algorithm requested: show every stored one.
        for alg in algs:
            print(f"Showing plot for ML algorithm: {alg}")
            show_figure(metric_plots[alg].figure)
            plt.show()
        return None

    if ml_alg not in algs:
        print(f"ML algorithm {ml_alg} not found for {param} and {metric}")
        print(f"Available ML algorithms: {algs}")
        return None

    show_figure(metric_plots[ml_alg].figure)
    plt.show()

138 

def save_plots(
    self, path: str, which: Union[str, list] = "all", format: str = "png"
) -> None:
    """Save stored plots as files inside ``path``.

    File names are built by joining the keys that lead to a figure with
    underscores and appending ``.{format}``. Objects with a ``savefig``
    method are saved through it; objects with a ``save`` method (seaborn
    objects plots) are saved through that. Anything else is skipped.

    Args:
        path: Directory in which the plots are saved. It is created
            (including parents) if it does not exist.
        which: "all" to save every stored plot, a single top-level plot
            name, or a list of top-level plot names.
        format: File extension / format of the saved plots (e.g. 'png', 'jpg').

    Returns:
        None. Unknown plot names are reported with a printed message and
        skipped; no exception is raised for them.
    """

    def _write(fig, stem: str) -> None:
        # Save one figure-like object under ``path`` using ``stem`` as file name.
        target = f"{os.path.join(path, stem)}.{format}"
        if hasattr(fig, "savefig"):
            fig.savefig(target)
        elif hasattr(fig, "save"):  # for seaborn objects plots
            fig.save(target)

    def _save_entry(name) -> None:
        # Save everything stored under one top-level key of ``self.plots``.
        entry = self.plots[name]
        if hasattr(entry, "savefig") or hasattr(entry, "save"):
            # The entry is itself a figure (e.g. a loss plot), not a nested
            # mapping, so there is nothing to flatten.
            # NOTE(review): assumes nested_to_tuple() expects a mapping — confirm.
            _write(entry, str(name))
            return
        for item in nested_to_tuple(entry):
            # The figure is the last element; the preceding elements are its keys.
            _write(item[-1], "_".join(str(x) for x in (name, *item[:-1])))

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(path, exist_ok=True)

    if not isinstance(which, list) and which == "all":
        if len(self.plots) == 0:
            print("No plots found in the plots dictionary")
            print("You need to run visualize() method first")
            return
        for item in nested_to_tuple(self.plots):
            _write(item[-1], "_".join(str(x) for x in item[:-1]))
        return

    # A single name and a list of names share the same code path.
    requested = which if isinstance(which, list) else [which]
    for name in requested:
        if name not in self.plots.keys():
            print(f"Plot {name} not found in the plots dictionary")
            print(f"All available plots are: {list(self.plots.keys())}")
            continue
        _save_entry(name)

211 

212 ### Utilities ### 

213 

214 @staticmethod 

215 def _make_loss_format(result: Result, config: DefaultConfig) -> pd.DataFrame: 

216 loss_df_melt = pd.DataFrame() 

217 for term in result.sub_losses.keys(): 

218 # Get the loss values and ensure it's a dictionary 

219 loss_values = result.sub_losses.get(key=term).get() 

220 

221 # Add explicit type checking/conversion 

222 if not isinstance(loss_values, dict): 

223 # If it's not a dict, try to convert it or handle appropriately 

224 if hasattr(loss_values, "to_dict"): 

225 loss_values = loss_values.to_dict() # type: ignore 

226 else: 

227 # For non-convertible types, you might need a custom solution 

228 # For numpy arrays, you could do something like: 

229 if hasattr(loss_values, "shape"): 

230 # For numpy arrays, create a dict with indices as keys 

231 loss_values = {i: val for i, val in enumerate(loss_values)} 

232 

233 # Now create the DataFrame 

234 loss_df = pd.DataFrame.from_dict(loss_values, orient="index") # type: ignore 

235 

236 # Rest of your code remains the same 

237 if term == "var_loss": 

238 loss_df = loss_df * config.beta 

239 loss_df["Epoch"] = loss_df.index + 1 

240 loss_df["Loss Term"] = term 

241 

242 loss_df_melt = pd.concat( 

243 [ 

244 loss_df_melt, 

245 loss_df.melt( 

246 id_vars=["Epoch", "Loss Term"], 

247 var_name="Split", 

248 value_name="Loss Value", 

249 ), 

250 ], 

251 axis=0, 

252 ).reset_index(drop=True) 

253 

254 # Similar handling for the total losses 

255 loss_values = result.losses.get() 

256 if not isinstance(loss_values, dict): 

257 if hasattr(loss_values, "to_dict"): 

258 loss_values = loss_values.to_dict() # type: ignore 

259 else: 

260 if hasattr(loss_values, "shape"): 

261 loss_values = {i: val for i, val in enumerate(loss_values)} 

262 

263 loss_df = pd.DataFrame.from_dict(loss_values, orient="index") # type: ignore 

264 loss_df["Epoch"] = loss_df.index + 1 

265 loss_df["Loss Term"] = "total_loss" 

266 

267 loss_df_melt = pd.concat( 

268 [ 

269 loss_df_melt, 

270 loss_df.melt( 

271 id_vars=["Epoch", "Loss Term"], 

272 var_name="Split", 

273 value_name="Loss Value", 

274 ), 

275 ], 

276 axis=0, 

277 ).reset_index(drop=True) 

278 

279 loss_df_melt["Loss Value"] = loss_df_melt["Loss Value"].astype(float) 

280 return loss_df_melt 

281 

@staticmethod
def _make_loss_plot(
    df_plot: pd.DataFrame, plot_type: str
) -> matplotlib.figure.Figure:  # type: ignore
    """Build a loss plot from a long-format loss DataFrame.

    Args:
        df_plot: DataFrame with the columns
            - "Loss Term": The type of loss term (e.g., "total_loss").
            - "Epoch": The epoch number.
            - "Loss Value": The value of the loss.
            - "Split": The data split (e.g., "train", "valid").
            The DataFrame is not modified.
        plot_type: Either "absolute" or "relative".
            - "absolute": one line-plot panel per loss term, coloured by split.
            - "relative": stacked area plot of each term's share of the
              summed loss per split and epoch, faceted by split. The
              "total_loss" term, terms containing "_factor" and terms
              whose values are all zero or NaN are left out.

    Returns:
        For "absolute" a matplotlib figure (already closed in pyplot so it
        is not displayed implicitly); for "relative" the seaborn objects
        plot built from ``so.Plot``.

    Raises:
        ValueError: If ``plot_type`` is neither "absolute" nor "relative".
    """
    if plot_type == "absolute":
        terms = df_plot["Loss Term"].unique()
        # squeeze=False keeps ``axes`` two-dimensional, so a single loss
        # term (one panel) is handled like any other count.
        fig, axes = plt.subplots(
            1,
            len(terms),
            figsize=(5 * len(terms), 5),
            sharey=False,
            squeeze=False,
        )
        for axis, term in zip(axes[0], terms):
            sns.lineplot(
                data=df_plot[df_plot["Loss Term"] == term],
                x="Epoch",
                y="Loss Value",
                hue="Split",
                ax=axis,
            ).set_title(term)
        plt.close(fig)
        return fig

    if plot_type == "relative":
        # Work on a copy so clipping and the helper column do not leak
        # into the caller's DataFrame.
        df_plot = df_plot.copy()
        fig_width_rel = 5 * len(df_plot["Split"].unique())

        if (df_plot["Loss Value"] < 0).any():
            warnings.warn(
                "Loss values contain negative values. Check your loss function if correct. Loss will be clipped to zero for plotting."
            )
            df_plot["Loss Value"] = df_plot["Loss Value"].clip(lower=0)

        def _is_informative(values: pd.Series) -> bool:
            # True unless every value of the term is NaN or zero.
            return bool(values.notna().any() and (values != 0).any())

        valid_terms = [
            term
            for term in df_plot["Loss Term"].unique()
            if _is_informative(
                df_plot.loc[df_plot["Loss Term"] == term, "Loss Value"]
            )
        ]
        # Rows that take part in the relative plot.
        keep = (
            (df_plot["Loss Term"] != "total_loss")
            & ~(df_plot["Loss Term"].str.contains("_factor"))
            & (df_plot["Loss Term"].isin(valid_terms))
        )

        relative = df_plot[keep].copy()
        # Share of each term in the summed loss of its split and epoch.
        relative["Relative Loss Value"] = relative.groupby(["Split", "Epoch"])[
            "Loss Value"
        ].transform(lambda x: x / x.sum())

        return (
            so.Plot(
                relative,
                "Epoch",
                "Relative Loss Value",
                color="Loss Term",
            )
            .add(so.Area(alpha=0.7), so.Stack())
            .facet("Split")
            .layout(size=(fig_width_rel, 5))
        )

    raise ValueError(
        f"Unknown plot_type {plot_type!r}. Please use 'absolute' or 'relative'"
    )

386 

@staticmethod
def _plot_model_weights(model: torch.nn.Module) -> matplotlib.figure.Figure:  # type: ignore
    """Plot the 2-D weight matrices of a model as heatmaps.

    Parameters whose name contains "weight" and that are two-dimensional
    are collected. Names containing "_mu", or "encoder" without "var", go
    to the top row; names containing "decoder" without "var" go to the
    bottom row, so the two rows may have different lengths and unused
    panels are hidden. If no parameter matches either group, the remaining
    weights (names without "var") are split in half over the two rows; with
    an odd count the last one is not plotted.

    If the model has a non-None ``ontologies`` attribute, the keys of each
    ontology followed by ``model.feature_order`` are used as tick labels
    of the decoder heatmaps.

    Args:
        model: PyTorch model instance.

    Returns:
        The matplotlib figure holding all heatmaps (closed in pyplot so it
        is not displayed implicitly).

    Raises:
        ValueError: If the model has no weight matrices that can be plotted.
    """

    def _heatmap(weights, ax, title) -> None:
        # Draw one weight matrix with a diverging colour map centred on 0.
        sns.heatmap(
            weights,
            cmap=sns.color_palette("Spectral", as_cmap=True),
            center=0,
            ax=ax,
        ).set(title=title)

    # Tick labels for decoder layers; only available for ontology models.
    node_names = None
    ontologies = getattr(model, "ontologies", None)
    if ontologies is not None:
        node_names = [list(ontology.keys()) for ontology in ontologies]
        node_names.append(model.feature_order)

    encoder_weights, encoder_names = [], []
    decoder_weights, decoder_names = [], []
    all_weights, names = [], []  # fallback when names carry no encoder/decoder hint
    for name, param in model.named_parameters():
        if "weight" not in name or len(param.shape) != 2:
            continue
        layer_name = name[:-7]  # drop the trailing ".weight"
        if "_mu" in name or ("encoder" in name and "var" not in name):
            encoder_weights.append(param.detach().cpu().numpy())
            encoder_names.append(layer_name)
        elif "decoder" in name and "var" not in name:
            decoder_weights.append(param.detach().cpu().numpy())
            decoder_names.append(layer_name)
        elif (
            "encoder" not in name
            and "decoder" not in name
            and "var" not in name
        ):
            all_weights.append(param.detach().cpu().numpy())
            names.append(layer_name)

    if encoder_weights or decoder_weights:
        n_enc = len(encoder_weights)
        n_dec = len(decoder_weights)
        n_cols = max(n_enc, n_dec)
        # squeeze=False keeps ``axes`` two-dimensional even for one column.
        fig, axes = plt.subplots(
            2, n_cols, sharex=False, figsize=(15 * n_cols, 15), squeeze=False
        )

        for i in range(n_enc):
            _heatmap(encoder_weights[i], axes[0, i], encoder_names[i])
            axes[0, i].set_ylabel("Out Node", size=12)
        for i in range(n_enc, n_cols):
            axes[0, i].axis("off")  # hide unused encoder panels

        for i in range(n_dec):
            ax = axes[1, i]
            _heatmap(decoder_weights[i], ax, decoder_names[i])
            # Test ``node_names`` rather than ``model.ontologies``: models
            # without that attribute must not raise here.
            if node_names is not None:
                ax.set_xticks(
                    ticks=range(len(node_names[i])),
                    labels=node_names[i],
                    rotation=90,
                    fontsize=8,
                )
                ax.set_yticks(
                    ticks=range(len(node_names[i + 1])),
                    labels=node_names[i + 1],
                    rotation=0,
                    fontsize=8,
                )
            ax.set_xlabel("In Node", size=12)
            ax.set_ylabel("Out Node", size=12)
        for i in range(n_dec, n_cols):
            axes[1, i].axis("off")  # hide unused decoder panels
    else:
        # Fallback: first half of the weights on top, second half below.
        n_layers = len(all_weights) // 2
        if n_layers == 0:
            raise ValueError(
                "No weight matrices found in the model that can be plotted"
            )
        fig, axes = plt.subplots(
            2, n_layers, sharex=False, figsize=(5 * n_layers, 10), squeeze=False
        )
        for layer in range(n_layers):
            _heatmap(all_weights[layer], axes[0, layer], names[layer])
            _heatmap(
                all_weights[n_layers + layer],
                axes[1, layer],
                names[n_layers + layer],
            )
            axes[1, layer].set_xlabel("In Node", size=12)
            axes[0, layer].set_ylabel("Out Node", size=12)
            axes[1, layer].set_ylabel("Out Node", size=12)

    fig.suptitle("Model Weights", size=20)
    plt.close(fig)
    return fig

511 

512 @staticmethod 

513 def _collect_all_metadata(result): 

514 all_metadata = pd.DataFrame() 

515 

516 # 1) collect metadata from results.datasets 

517 

518 # 1a) iterate over splits [train, valid, test] if they exist 

519 for split in ["train", "valid", "test"]: 

520 

521 if hasattr(result.datasets, split) and result.datasets[split] is not None: 

522 if hasattr(result.datasets[split], "metadata"): 

523 split_metadata = result.datasets[split].metadata 

524 

525 # 1b) if result.datasets.split is a dictionary, iterate over keys (modalities) 

526 if isinstance(split_metadata, dict): 

527 for modality, modality_data in split_metadata.items(): 

528 all_metadata = pd.concat( 

529 [all_metadata, modality_data], axis=0 

530 ) 

531 # 1c) if result.datasets.split is a Dataframe, just collect metadata directly 

532 elif isinstance(split_metadata, pd.DataFrame): 

533 all_metadata = pd.concat([all_metadata, split_metadata], axis=0) 

534 else: 

535 split_modalities = result.datasets[split].datasets 

536 if isinstance(split_modalities, dict): 

537 for modality, modality_data in split_modalities.items(): 

538 if hasattr(modality_data, "metadata"): 

539 modality_metadata = modality_data.metadata 

540 if isinstance(modality_metadata, pd.DataFrame): 

541 all_metadata = pd.concat( 

542 [all_metadata, modality_metadata], axis=0 

543 ) 

544 

545 # 2) collect metadata from results.new_datasets in the same way 

546 if hasattr(result, "new_datasets"): 

547 for split in ["train", "valid", "test"]: 

548 if ( 

549 hasattr(result.new_datasets, split) 

550 and result.new_datasets[split] is not None 

551 ): 

552 if hasattr(result.new_datasets[split], "metadata"): 

553 split_metadata = result.new_datasets[split].metadata 

554 

555 if isinstance(split_metadata, dict): 

556 for modality, modality_data in split_metadata.items(): 

557 all_metadata = pd.concat( 

558 [all_metadata, modality_data], axis=0 

559 ) 

560 elif isinstance(split_metadata, pd.DataFrame): 

561 all_metadata = pd.concat( 

562 [all_metadata, split_metadata], axis=0 

563 ) 

564 else: 

565 split_modalities = result.new_datasets[split].datasets 

566 if isinstance(split_modalities, dict): 

567 for modality, modality_data in split_modalities.items(): 

568 if hasattr(modality_data, "metadata"): 

569 modality_metadata = modality_data.metadata 

570 if isinstance(modality_metadata, pd.DataFrame): 

571 all_metadata = pd.concat( 

572 [all_metadata, modality_metadata], axis=0 

573 ) 

574 

575 # Remove duplicate rows if any 

576 all_metadata = all_metadata.loc[~all_metadata.index.duplicated(keep="first")] 

577 

578 return all_metadata