Coverage for src / autoencodix / utils / _utils.py: 27%

299 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1""" 

2Stores utility functions for the autoencodix package. 

3Use of OOP would be overkill for the simple functions in this module. 

4""" 

5 

6from pathlib import Path 

7 

8# import zipfile 

9import inspect 

10import os 

11from collections import defaultdict 

12from dataclasses import MISSING, fields, is_dataclass 

13from functools import wraps 

14from typing import Any, Callable, Dict, List, Optional, no_type_check 

15from autoencodix.data._datasetcontainer import DatasetContainer 

16from autoencodix.utils._result import Result 

17 

18import dill as pickle # type: ignore 

19import torch 

20import pandas as pd 

21from matplotlib import pyplot as plt 

22 

23from autoencodix.configs.default_config import DefaultConfig 

24 

25 

# Stand-in used only for type hints; importing the real class here would
# create a circular import.
class BasePipeline:
    """Placeholder for type hints in utils; not the real BasePipeline class."""

31 

32 

def get_dataset(result: Result) -> Optional[DatasetContainer]:
    """Pick the dataset container to use from a Result object.

    ``result.new_datasets`` is preferred whenever it is set and at least one
    of its "train", "valid" or "test" entries is not None; otherwise
    ``result.datasets`` is returned.

    Args:
        result: The Result object holding ``datasets`` and ``new_datasets``.
    Returns:
        The selected DatasetContainer object.

    """
    if not result.new_datasets:
        return result.datasets
    # All three splits are looked up before deciding, so a missing split
    # surfaces as an error instead of being silently skipped.
    split_values: List[Any] = [
        result.new_datasets[name] for name in ("train", "valid", "test")
    ]
    if any(value is not None for value in split_values):
        return result.new_datasets
    return result.datasets

51 

52 

def nested_dict():
    """Build an arbitrarily deep, auto-vivifying dictionary.

    The returned defaultdict uses this very function as its default factory,
    so accessing a missing key at any depth creates another nested
    defaultdict instead of raising ``KeyError``.

    Returns:
        A defaultdict whose missing values are themselves nested defaultdicts.
    """
    tree = defaultdict(nested_dict)
    return tree

64 

65 

def nested_to_tuple(d, base=()):
    """Flatten a nested dictionary into key-path tuples.

    Args:
        d: The dictionary to flatten. A non-dict value is treated as a leaf.
        base: Tuple of keys already traversed. Defaults to ().

    Yields:
        tuple: One tuple per leaf, made of the keys leading to the leaf
            followed by the leaf value itself. Empty dicts yield nothing.
    """
    # Leaf: emit the accumulated key path plus the value.
    if not isinstance(d, dict):
        yield base + (d,)
        return

    # Branch: descend into every entry, extending the key path.
    for key, value in d.items():
        yield from nested_to_tuple(value, base + (key,))

86 

87 

@no_type_check
def show_figure(fig):
    """Attach a Matplotlib figure to the canvas of a newly created figure.

    A throwaway figure is created through pyplot to obtain a canvas manager;
    the given figure is then installed on that manager's canvas.
    NOTE(review): presumably used to re-display a figure that lost its
    manager (e.g. after unpickling) - confirm against callers.

    Args:
        fig: The figure to be displayed.

    """
    placeholder = plt.figure()
    manager = placeholder.canvas.manager
    canvas = manager.canvas
    canvas.figure = fig
    fig.set_canvas(canvas)

100 

101 

def config_method(valid_params: Optional[set[str]] = None):
    """Decorator for methods that accept configuration parameters via kwargs or an explicit 'config' object.

    It separates kwargs intended for the function's signature from those
    intended as configuration overrides, validates the latter against
    `valid_params`, applies valid overrides to a copy of `self.config`,
    and passes the appropriate arguments to the decorated function.

    The decorated function is always called with a ``config`` keyword
    argument. The decorator also appends the list of valid override names to
    the function's docstring and exposes ``valid_params`` as an attribute of
    the wrapper.

    Args:
        valid_params: Set of valid configuration parameter names that can be overridden
            via kwargs for this method. If None (or empty), all kwargs not matching the
            function signature are applied as config overrides without filtering.
    """

    def decorator(func: Callable) -> Callable:
        sig = inspect.signature(func)

        # Names the decorated function itself accepts by keyword (excluding
        # self and config). The signature is fixed, so compute this once at
        # decoration time rather than on every call.
        func_sig_kwarg_names = {
            name
            for name, param in sig.parameters.items()
            if param.kind
            in (
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
                inspect.Parameter.KEYWORD_ONLY,
            )
            and name not in ("self", "config")
        }

        # Document the accepted overrides in the function's docstring.
        param_docs = "\n\nValid configuration parameters (passed via **kwargs):\n"
        if valid_params:
            param_docs += "\n".join(f"- `{param}`" for param in sorted(valid_params))
        else:
            param_docs += (
                "All keyword arguments not matching the function's "
                "signature are treated as potential configuration overrides."
            )

        if func.__doc__ is None:
            func.__doc__ = ""
        # Avoid duplicating if decorator is applied multiple times (though unlikely)
        if "Valid configuration parameters" not in func.__doc__:
            func.__doc__ += param_docs

        @wraps(func)
        def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
            # Pop the explicit config object if provided
            user_config = kwargs.pop("config", None)

            # Separate kwargs into those matching the signature and potential config overrides
            func_specific_kwargs = {}
            potential_config_kwargs = {}
            for k, v in kwargs.items():
                if k in func_sig_kwarg_names:
                    func_specific_kwargs[k] = v
                else:
                    potential_config_kwargs[k] = v

            if user_config is None:
                # No explicit config object, use self.config and apply overrides
                if not hasattr(self, "config"):
                    raise AttributeError(
                        f"{type(self).__name__} instance is missing 'config' attribute."
                    )
                # Ensure self.config is the right type (or duck-types model_copy/model_dump)
                if not hasattr(self.config, "model_copy") or not hasattr(
                    self.config, "model_dump"
                ):
                    raise TypeError(
                        f"'self.config' on {type(self).__name__} must have 'model_copy' and 'model_dump' methods."
                    )

                # Work on a deep copy so self.config itself is never mutated.
                final_config = self.config.model_copy(deep=True)

                if valid_params:
                    # Warn about overrides that are not allowed for this method.
                    invalid_config_params = (
                        set(potential_config_kwargs.keys()) - valid_params
                    )
                    if invalid_config_params:
                        print(
                            f"\nWarning: The following parameters are not valid "
                            f"configuration overrides for {func.__name__}:"  # type: ignore
                        )
                        # Sorted so the warning text is deterministic.
                        print(
                            f"Invalid config parameters: {', '.join(sorted(invalid_config_params))}"
                        )
                        print(
                            f"Valid config parameters are: {', '.join(sorted(valid_params))}"
                        )
                        print("These parameters will be ignored.")

                    # Keep only overrides that are listed in valid_params AND
                    # exist as attributes on the config object, so no
                    # arbitrary attributes get added to the config.
                    valid_config_fields = {
                        p for p in valid_params if hasattr(final_config, p)
                    }
                    config_overrides = {
                        k: v
                        for k, v in potential_config_kwargs.items()
                        if k in valid_config_fields
                    }
                else:
                    # No whitelist: every non-signature kwarg is an override.
                    config_overrides = potential_config_kwargs

                # Apply the valid overrides
                if config_overrides:
                    final_config = final_config.model_copy(update=config_overrides)

            else:
                # User provided an explicit config object: accept a
                # DefaultConfig or anything duck-typing model_copy/model_dump.
                if not isinstance(user_config, DefaultConfig) and not (
                    hasattr(user_config, "model_copy")
                    and hasattr(user_config, "model_dump")
                ):
                    raise TypeError(
                        "The 'config' parameter must be a valid configuration object "
                        "(e.g., an instance of DefaultConfig or similar)."
                    )
                final_config = user_config
                # Extra kwargs are deliberately NOT applied on top of an
                # explicit config object; only warn about them.
                if potential_config_kwargs:
                    print(
                        f"\nWarning: Additional keyword arguments provided "
                        f"({', '.join(potential_config_kwargs.keys())}) "
                        f"while an explicit 'config' object was also passed to {func.__name__}. "  # type: ignore
                        f"These additional arguments will be ignored as configuration overrides."
                    )

            # Call the original function with: self, the original *args, the
            # resolved config object, and only the kwargs matching its signature.
            return func(self, *args, config=final_config, **func_specific_kwargs)

        # Preserve information about valid_params on the wrapper if needed elsewhere
        setattr(wrapper, "valid_params", valid_params)
        return wrapper

    return decorator

258 

259 

class Saver:
    """Handles the saving of BasePipeline objects.

    Attributes:
        file_path: path the pickled pipeline object is written to.
        preprocessor_path: path where the pickled preprocessor is saved.
        model_state_path: path where the model state dict is saved
            (single-model case).
        save_all: if True, the full pipeline is saved; if False, results are
            reset to their defaults and only core pipeline functionality is kept.

    """

    def __init__(self, file_path: str, save_all: bool):
        """Initializes the Saver and creates the target folder if needed.

        Args:
            file_path: Path of the pipeline file. The pipeline object is
                written to this path as given; the preprocessor and model
                state files are placed next to it, named after its stem.
            save_all: Whether to save all results or a memory-efficient subset.
        """

        self.save_all = save_all
        self.file_stem: str = Path(file_path).stem
        self.file_name: str = Path(file_path).name
        self.folder: str = Path(file_path).parent.as_posix()

        self.preprocessor_path: str = os.path.join(
            self.folder, f"{self.file_stem}_preprocessor.pkl"
        )
        self.model_state_path: str = os.path.join(
            self.folder, f"{self.file_stem}_model.pth"
        )
        self.file_path: str = os.path.join(self.folder, self.file_name)
        os.makedirs(self.folder, exist_ok=True)

    @no_type_check
    def save(self, pipeline: "BasePipeline"):
        """Saves the BasePipeline object.

        The preprocessor and the model state dict(s) are written to their own
        files first; afterwards the pipeline object itself is pickled. When
        ``save_all`` is False the pipeline is stripped of data and results
        before pickling. Note that this mutates the passed pipeline.

        Args:
            pipeline: The BasePipeline object to save.
        """

        self._save_preprocessor(pipeline._preprocessor)  # type: ignore
        self._save_model_state(pipeline)
        self.pipeline = pipeline

        self.pipeline._trainer.purge()

        if not self.save_all:
            print("saving memory efficient")
            self.reset_to_defaults(pipeline.result)  # ty: ignore

            self.pipeline.preprocessed_data = None  # ty: ignore
            self.pipeline._datasets = None  # ty: ignore
            self.pipeline.raw_user_data = None  # ty: ignore
            # Replace preprocessor and visualizer by fresh instances of the
            # same classes; the fitted preprocessor was already saved above.
            self.pipeline._preprocessor = type(
                self.pipeline._preprocessor
            )(  # ty: ignore
                config=pipeline.config  # ty: ignore
            )  # ty: ignore
            self.pipeline.visualizer = type(self.pipeline.visualizer)()  # ty: ignore

        self._save_pipeline_object(self.pipeline)

        # NOTE: bundling the written files into a zip archive was removed
        # because it caused issues with filesystems.

    def _save_pipeline_object(self, pipeline: "BasePipeline"):
        """Pickles the pipeline object to ``self.file_path``; re-raises on failure."""
        try:
            with open(self.file_path, "wb") as f:
                pickle.dump(pipeline, f)
            print("Pipeline object saved successfully.")
        except (pickle.PicklingError, OSError) as e:
            print(f"Error saving pipeline object: {e}")
            raise

    def _save_preprocessor(self, preprocessor):
        """Pickles the preprocessor to ``self.preprocessor_path`` if it is not None."""
        if preprocessor is not None:
            try:
                with open(self.preprocessor_path, "wb") as f:
                    pickle.dump(preprocessor, f)
                print("Preprocessor saved successfully.")
            except (pickle.PickleError, OSError) as e:
                print(f"Error saving preprocessor: {e}")
                raise

    @no_type_check
    def _save_model_state(self, pipeline: "BasePipeline"):
        """Writes the state dict(s) of ``pipeline.result.model`` to disk.

        A single ``torch.nn.Module`` is saved to ``self.model_state_path``.
        A dict of models is saved one file per entry, named
        ``{model_name}_{file_stem}_model.pth``; entries exposing a ``module``
        attribute are unwrapped first. Models are moved to CPU before saving.
        All written paths are collected in ``self.model_state_paths``.

        Raises:
            TypeError: If the model is neither a torch.nn.Module nor a dict.
        """
        self.model_state_paths: List[str] = []
        if pipeline.result is not None and pipeline.result.model is not None:  # type: ignore
            if isinstance(pipeline.result.model, torch.nn.Module):
                try:
                    pipeline.result.model.to("cpu")
                    torch.save(
                        pipeline.result.model.state_dict(), self.model_state_path
                    )  # type: ignore
                    self.model_state_paths.append(self.model_state_path)

                except (TypeError, OSError) as e:
                    print(f"Error saving model state: {e}")
                    raise
            elif isinstance(pipeline.result.model, dict):
                for model_name, model in pipeline.result.model.items():
                    if hasattr(model, "module"):
                        model = model.module
                    model.to("cpu")
                    cur_path: str = os.path.join(
                        self.folder, f"{model_name}_{self.file_stem}_model.pth"
                    )
                    torch.save(model.state_dict(), cur_path)  # type: ignore
                    self.model_state_paths.append(cur_path)
            else:
                raise TypeError(
                    f"pipeline.result.model is neither a torch.nn.Module nor a dict, got {type(pipeline.result.model)}"
                )

    @no_type_check
    def reset_to_defaults(self, obj):
        """Resets the fields of a dataclass instance to their defaults, in place.

        ``adata_latent``, ``losses`` and ``sub_losses`` are kept untouched.
        The ``model`` field is replaced by freshly instantiated model(s) of the
        same class built from ``init_args``, so the weights are not pickled a
        second time alongside the separately saved state dict. All other
        fields are set to their default (factory) value, or None if the field
        declares no default.

        Args:
            obj: The dataclass instance to reset.

        Raises:
            ValueError: If ``obj`` is not a dataclass.
        """
        if not is_dataclass(obj):
            raise ValueError("Object must be a dataclass")

        for f in fields(obj):
            # adata_latent is kept as a "core result"; losses/sub_losses are
            # kept to preserve the loss dynamics.
            if f.name in ("adata_latent", "losses", "sub_losses"):
                continue
            if f.name == "model":
                # Keep an instantiated model so the state dict can be loaded
                # later, but re-instantiate it to drop the trained weights.
                if isinstance(obj.model, dict):
                    empty_models = {}
                    for model_name, model in obj.model.items():
                        if hasattr(model, "module"):
                            model = model.module
                        empty_models[model_name] = type(model)(**model.init_args)
                    # Store the fresh models; without this assignment the
                    # trained models would stay in the result.
                    obj.model = empty_models
                elif isinstance(obj.model, torch.nn.Module):
                    # Relies on self.pipeline having been set by save().
                    obj.model.init_args["ontologies"] = self.pipeline.ontologies
                    obj.model.init_args["feature_order"] = (
                        self.pipeline._trainer.feature_order
                    )
                    obj.model = type(obj.model)(**obj.model.init_args)
                continue
            if f.default_factory is not MISSING:
                setattr(obj, f.name, f.default_factory())
            elif f.default is not MISSING:
                setattr(obj, f.name, f.default)
            else:
                setattr(obj, f.name, None)

422 

423 

class Loader:
    """Handles the loading of BasePipeline objects.

    Attributes:
        file_path: path of the saved pipeline object.
        preprocessor_path: path where the pickled preprocessor was saved.
        model_state_path: path where the model state dict was saved
            (single-model case).
    """

    def __init__(self, file_path: str):
        """Initializes the Loader and derives the companion file paths.

        Args:
            file_path: Path of the saved pipeline file. The preprocessor and
                model state files are expected next to it, named after its stem.
        """

        self.file_stem: str = Path(file_path).stem
        self.file_name: str = Path(file_path).name
        self.folder: str = Path(file_path).parent.as_posix()

        self.preprocessor_path: str = os.path.join(
            self.folder, f"{self.file_stem}_preprocessor.pkl"
        )
        self.model_state_path: str = os.path.join(
            self.folder, f"{self.file_stem}_model.pth"
        )
        self.file_path: str = os.path.join(self.folder, self.file_name)

    def load(self) -> Any:
        """Loads the BasePipeline object with its preprocessor and model weights.

        Returns:
            The loaded BasePipeline object.

        Raises:
            ValueError: If the pipeline object could not be loaded.
        """
        # NOTE: extracting from a zip archive was removed because it caused
        # issues with filesystems; the files are read directly.
        loaded_obj = self._load_pipeline_object()
        if loaded_obj is None:
            raise ValueError("Error while loading pipeline object")

        loaded_obj._preprocessor = self._load_preprocessor()
        loaded_obj.result.model = self._load_model_state(loaded_obj)

        return loaded_obj

    def _load_pipeline_object(self) -> Any:
        """Unpickles the pipeline object; returns None if missing or unreadable."""
        print(f"Attempting to load a pipeline from {self.file_path}...")
        try:
            if not os.path.exists(self.file_path):
                print(f"Error: File not found at {self.file_path}")
                return None
            # SECURITY: unpickling executes arbitrary code; only load files
            # from trusted sources.
            with open(self.file_path, "rb") as f:
                loaded_obj = pickle.load(f)
            print(
                f"Pipeline object loaded successfully. Actual type: {type(loaded_obj).__name__}"
            )
            return loaded_obj
        except (pickle.UnpicklingError, EOFError, OSError) as e:
            print(f"Error loading pipeline object: {e}")
            return None

    def _load_preprocessor(self):
        """Unpickles the preprocessor; returns None if missing or unreadable."""
        if os.path.exists(self.preprocessor_path):
            try:
                # SECURITY: same trust requirement as for the pipeline file.
                with open(self.preprocessor_path, "rb") as f:
                    preprocessor = pickle.load(f)
                print("Preprocessor loaded successfully.")
                return preprocessor
            except (pickle.UnpicklingError, EOFError, OSError) as e:
                print(f"Error loading preprocessor: {e}")
                return None
        else:
            print("Preprocessor file not found. Skipping preprocessor load.")
            return None

    @no_type_check
    def _load_model_state(self, loaded_obj: "BasePipeline"):
        """Loads the saved state dict(s) into ``loaded_obj.result.model`` on CPU.

        For a dict of models, each entry is restored from
        ``{model_name}_{file_stem}_model.pth``; for a single module the state
        is read from ``self.model_state_path``.

        Args:
            loaded_obj: The unpickled pipeline whose model(s) receive the weights.

        Returns:
            ``loaded_obj.result.model`` with the weights loaded.

        Raises:
            ValueError: If the pipeline has no result or the result has no model.
            FileNotFoundError: If an expected model state file is missing.
            TypeError: If the model is neither a dict nor a torch.nn.Module.
        """
        if loaded_obj.result is None:  # type: ignore
            raise ValueError("Loaded pipeline has no result attribute")
        if loaded_obj.result.model is None:
            raise ValueError("Loaded pipeline result has no model attribute")
        if isinstance(loaded_obj.result.model, dict):
            for model_name, model in loaded_obj.result.model.items():
                # The state dict is saved from the unwrapped module, so it
                # has to be loaded into the unwrapped module as well.
                target = model.module if hasattr(model, "module") else model
                cur_path: str = os.path.join(
                    self.folder, f"{model_name}_{self.file_stem}_model.pth"
                )
                if not os.path.exists(cur_path):
                    raise FileNotFoundError(f"Model state file not found at {cur_path}")
                model_state = torch.load(cur_path, map_location="cpu")
                model.to("cpu")
                target.load_state_dict(model_state)  # type: ignore
            return loaded_obj.result.model  # type: ignore
        elif isinstance(loaded_obj.result.model, torch.nn.Module):
            if not os.path.exists(self.model_state_path):
                raise FileNotFoundError(
                    f"Model state file not found at {self.model_state_path}"
                )
            model_state = torch.load(self.model_state_path, map_location="cpu")
            loaded_obj.result.model.to("cpu")
            loaded_obj.result.model.load_state_dict(model_state)  # type: ignore
            return loaded_obj.result.model  # type: ignore
        else:
            raise TypeError(
                f"Loaded model is neither a dict nor a torch.nn.Module, got {type(loaded_obj.result.model)}"
            )

541 

542 

def flip_labels(labels: torch.Tensor) -> torch.Tensor:
    """Randomly flip modality labels with probability (1 - 1/n_modalities), vectorized.

    This is mainly used for adversarial training in multi-modal xmodalix.

    The number of modalities is taken from the distinct values present in
    ``labels``. Replacement labels are drawn from ``[0, n_modalities)`` and
    re-drawn until they differ from the original label, so the input is
    expected to use the label values ``0 .. n_modalities - 1``.

    Args:
        labels: tensor of labels
    Returns:
        flipped tensor (a new tensor; the input is not modified)

    """
    device = labels.device
    n_modalities = labels.unique().numel()
    # An empty batch has nothing to flip; returning early also avoids the
    # division by zero in the flip probability below.
    if n_modalities == 0:
        return labels.clone()

    batch_size = labels.size(0)
    flip_prob = 1.0 - 1.0 / n_modalities

    # Decide which labels to flip
    flip_mask = torch.rand(batch_size, device=device) < flip_prob

    # Sample random labels for flipping
    rand_labels = torch.randint(0, n_modalities, size=(batch_size,), device=device)

    # Ensure sampled labels are different from original labels
    needs_resample = (rand_labels == labels) & flip_mask
    while needs_resample.any():
        rand_labels[needs_resample] = torch.randint(
            0, n_modalities, size=(int(needs_resample.sum()),), device=device
        )
        needs_resample = (rand_labels == labels) & flip_mask

    # Apply flipped labels where needed
    flipped = labels.clone()
    flipped[flip_mask] = rand_labels[flip_mask]

    return flipped

578 

579 

def find_translation_keys(
    config: "DefaultConfig",
    trained_modalities: List[str],
    from_key: Optional[str] = None,
    to_key: Optional[str] = None,
) -> Dict[str, str]:
    """Find translation source and target modalities.

    Determines which modalities serve as the "from" and "to" directions for
    cross-modal prediction, either from explicit arguments or from the
    configuration. The explicit arguments are only used when BOTH
    ``from_key`` and ``to_key`` are given; otherwise the directions are read
    from ``config.data_config.data_info``.

    Modality names may carry a prefix separated by a dot ("prefix.name").
    Explicit keys may match either the full name or the part after the first
    dot; config lookups use the part after the first dot.

    Args:
        config: Experiment configuration containing data information.
        trained_modalities: List of trained modality names.
        from_key: Optional name of the source modality.
        to_key: Optional name of the target modality.

    Returns:
        A dictionary with two entries:
            - "from": Name of the source modality.
            - "to": Name of the target modality.
        Both values are entries of ``trained_modalities``.

    Raises:
        ValueError: If no valid "from" or "to" modality is found, or if
            multiple conflicting directions are specified.
    """
    from_key_final: Optional[str] = None
    to_key_final: Optional[str] = None

    if from_key and to_key:
        for name in trained_modalities:
            simple_name = name.split(".", 1)[1] if "." in name else name

            if from_key == simple_name or from_key == name:
                from_key_final = name
            # `if` rather than `elif`: reference prediction may use the same
            # modality as source and target (from_key == to_key).
            if to_key == simple_name or to_key == name:
                to_key_final = name
        # Explicit keys that cannot be resolved are an error.
        if not (from_key_final and to_key_final):
            simple_names: List[str] = [
                tm.split(".", 1)[1] if "." in tm else tm for tm in trained_modalities
            ]
            raise ValueError(
                f"Invalid translation keys: {from_key} => {to_key}, valid keys are: {simple_names}"
            )
        return {"from": from_key_final, "to": to_key_final}

    # Fall back to the translate_direction declared per modality in the config.
    data_info = config.data_config.data_info
    for name in trained_modalities:
        simple_name = name.split(".", 1)[1] if "." in name else name
        cur_info = data_info.get(simple_name)
        if not cur_info or not hasattr(cur_info, "translate_direction"):
            continue

        direction = cur_info.translate_direction
        if direction == "to":
            if to_key_final is not None:
                raise ValueError(
                    f"Multiple 'to' directions found: '{to_key_final}' and '{name}'"
                )
            to_key_final = name
        elif direction == "from":
            if from_key_final is not None:
                raise ValueError(
                    f"Multiple 'from' directions found: '{from_key_final}' and '{name}'"
                )
            from_key_final = name

    if from_key_final is None:
        raise ValueError(
            "No modality with a 'from' direction was specified in the config."
        )
    if to_key_final is None:
        raise ValueError(
            "No modality with a 'to' direction was specified in the config."
        )

    return {"from": from_key_final, "to": to_key_final}

661 

662 

def preprocess_explanations(
    df: pd.DataFrame, n: int = 10, max_dims: int = 8
) -> Dict[str, List[str]]:
    """
    Map the most informative latent dimensions of an attribution DataFrame
    to their top-n genes.

    Steps:
    - Average each column (latent dimension) over all rows.
    - Keep the `max_dims` columns with the highest mean.
    - For every kept column, collect the index labels of its `n` largest rows.

    Args:
        df: Output DataFrame of gene attributions from explainix
            (columns are read as latent dimensions, index labels as genes).
        n: Number of top genes to include per latent dimension.
        max_dims: Maximum number of latent dimensions to retain.

    Returns:
        Dictionary with latent dimensions as keys and lists of top gene names as values.
    """
    # Rank latent dimensions by their mean attribution.
    dim_means = df.mean(axis=0)
    selected_dims = dim_means.nlargest(max_dims).index

    return {dim: df.nlargest(n, dim).index.tolist() for dim in selected_dims}

689 

690 return result