Coverage for src / autoencodix / utils / _utils.py: 27%
299 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1"""
2Stores utility functions for the autoencodix package.
3Use of OOP would be overkill for the simple functions in this module.
4"""
6from pathlib import Path
8# import zipfile
9import inspect
10import os
11from collections import defaultdict
12from dataclasses import MISSING, fields, is_dataclass
13from functools import wraps
14from typing import Any, Callable, Dict, List, Optional, no_type_check
15from autoencodix.data._datasetcontainer import DatasetContainer
16from autoencodix.utils._result import Result
18import dill as pickle # type: ignore
19import torch
20import pandas as pd
21from matplotlib import pyplot as plt
23from autoencodix.configs.default_config import DefaultConfig
# Only used for type hints, to avoid a circular import.
class BasePipeline:
    """Stand-in used solely for type hints inside this utils module.

    This is not the real BasePipeline class; it carries no attributes or
    behaviour and exists so annotations such as ``"BasePipeline"`` have a
    name to refer to.
    """
def get_dataset(result: Result) -> Optional[DatasetContainer]:
    """Pick the dataset container that should be used for a Result.

    ``result.new_datasets`` takes precedence, but only when it is truthy and
    at least one of its ``train`` / ``valid`` / ``test`` entries is not None.
    In every other case ``result.datasets`` is returned.

    Args:
        result: The Result object holding ``datasets`` and ``new_datasets``.
    Returns:
        ``result.new_datasets`` if it holds any non-None split, otherwise
        ``result.datasets``.

    """
    replacement = result.new_datasets
    if not replacement:
        return result.datasets

    # Look up all three splits up front, then decide which container wins.
    per_split: List[Any] = [
        replacement[split_name] for split_name in ("train", "valid", "test")
    ]
    if any(entry is not None for entry in per_split):
        return replacement
    return result.datasets
def nested_dict():
    """Build an auto-vivifying, arbitrarily deep dictionary.

    The returned defaultdict uses this very function as its default factory,
    so accessing a missing key creates another nested defaultdict. Levels
    never have to be declared explicitly before assigning into them.

    Returns:
        A defaultdict whose missing values are themselves nested defaultdicts.
    """
    tree = defaultdict(nested_dict)
    return tree
def nested_to_tuple(d, base=()):
    """Flatten a nested dictionary into key-path tuples.

    Args:
        d: The dictionary to flatten. A non-dict value is treated as a leaf.
        base: Tuple of keys already traversed; prepended to every result.
            Defaults to ().

    Yields:
        tuple: One tuple per leaf, made of ``base``, the keys leading to the
            leaf, and finally the leaf value itself. Empty dictionaries
            contribute nothing.
    """
    # A non-dict input is itself the leaf.
    if not isinstance(d, dict):
        yield base + (d,)
        return

    for key, value in d.items():
        path = base + (key,)
        if isinstance(value, dict):
            yield from nested_to_tuple(value, path)
        else:
            yield path + (value,)
@no_type_check
def show_figure(fig):
    """Attach an existing Matplotlib figure to a freshly created canvas manager.

    A new (empty) pyplot figure is created only to obtain a canvas manager;
    that manager's canvas is then pointed at ``fig`` and ``fig`` is told to
    use that canvas. The function itself does not call ``plt.show()``.

    Args:
        fig: The figure to be displayed.

    """
    manager = plt.figure().canvas.manager
    manager.canvas.figure = fig
    fig.set_canvas(manager.canvas)
def config_method(valid_params: Optional[set[str]] = None):
    """Decorator for methods that accept config overrides via kwargs or an explicit 'config' object.

    The wrapped method is always called with a ``config`` keyword argument.
    Keyword arguments that match the method's own signature are forwarded
    unchanged; all remaining keyword arguments are treated as configuration
    overrides. Overrides are applied to a deep copy of ``self.config``, so the
    instance's own config is never mutated.

    If the caller passes an explicit ``config=...`` object, it is used as is
    and any additional non-signature kwargs are ignored with a printed warning.

    Args:
        valid_params: Set of configuration parameter names that may be
            overridden via kwargs for this method. If None (or empty), all
            kwargs not matching the function signature are applied as
            overrides without further validation.

    Returns:
        A decorator. The resulting wrapper exposes ``valid_params`` as an
        attribute and has the list of valid parameters appended to its
        docstring.

    Raises (from the wrapper, at call time):
        AttributeError: If no ``config`` is passed and ``self`` has no
            ``config`` attribute.
        TypeError: If the config in use lacks ``model_copy`` / ``model_dump``.
    """

    def decorator(func: Callable) -> Callable:
        sig = inspect.signature(func)

        # Names the decorated function can receive by keyword (excluding
        # self and config). This depends only on the signature, so it is
        # computed once here instead of on every call.
        keyword_kinds = (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        )
        func_sig_kwarg_names = frozenset(
            name
            for name, param in sig.parameters.items()
            if param.kind in keyword_kinds and name not in ("self", "config")
        )

        # --- Docstring modification: advertise the accepted overrides ---
        param_docs = "\n\nValid configuration parameters (passed via **kwargs):\n"
        if valid_params:
            param_docs += "\n".join(f"- `{param}`" for param in sorted(valid_params))
        else:
            param_docs += (
                "All keyword arguments not matching the function's "
                "signature are treated as potential configuration overrides."
            )

        if func.__doc__ is None:
            func.__doc__ = ""
        # Avoid duplicating the section if the decorator is applied twice.
        if "Valid configuration parameters" not in func.__doc__:
            func.__doc__ += param_docs
        # --- End docstring modification ---

        @wraps(func)
        def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
            # Pop the explicit config object if provided.
            user_config = kwargs.pop("config", None)

            # Split kwargs: those the function declares vs. config overrides.
            func_specific_kwargs = {}
            potential_config_kwargs = {}
            for k, v in kwargs.items():
                if k in func_sig_kwarg_names:
                    func_specific_kwargs[k] = v
                else:
                    potential_config_kwargs[k] = v

            if user_config is None:
                # No explicit config: start from self.config and apply overrides.
                if not hasattr(self, "config"):
                    raise AttributeError(
                        f"{type(self).__name__} instance is missing 'config' attribute."
                    )
                # Duck-type check for a Pydantic-style config object.
                if not hasattr(self.config, "model_copy") or not hasattr(
                    self.config, "model_dump"
                ):
                    raise TypeError(
                        f"'self.config' on {type(self).__name__} must have 'model_copy' and 'model_dump' methods."
                    )

                final_config = self.config.model_copy(deep=True)

                if valid_params:
                    invalid_config_params = (
                        set(potential_config_kwargs.keys()) - valid_params
                    )
                    if invalid_config_params:
                        print(
                            f"\nWarning: The following parameters are not valid "
                            f"configuration overrides for {func.__name__}:"  # type: ignore
                        )
                        # Sorted so the warning text is deterministic; set
                        # iteration order varies between runs.
                        print(
                            f"Invalid config parameters: {', '.join(sorted(invalid_config_params))}"
                        )
                        print(
                            f"Valid config parameters are: {', '.join(sorted(valid_params))}"
                        )
                        print("These parameters will be ignored.")

                    # Keep only overrides that are listed in valid_params and
                    # also exist as attributes on the config object.
                    valid_config_fields = {
                        p for p in valid_params if hasattr(final_config, p)
                    }
                    config_overrides = {
                        k: v
                        for k, v in potential_config_kwargs.items()
                        if k in valid_config_fields
                    }
                else:
                    # No whitelist: every non-signature kwarg is an override.
                    config_overrides = potential_config_kwargs

                if config_overrides:
                    final_config = final_config.model_copy(update=config_overrides)

            else:
                # Explicit config object: accept DefaultConfig instances or
                # anything duck-typing model_copy/model_dump.
                looks_like_config = isinstance(user_config, DefaultConfig) or (
                    hasattr(user_config, "model_copy")
                    and hasattr(user_config, "model_dump")
                )
                if not looks_like_config:
                    raise TypeError(
                        "The 'config' parameter must be a valid configuration object "
                        "(e.g., an instance of DefaultConfig or similar)."
                    )
                final_config = user_config
                # Overrides are deliberately ignored (not applied, not an
                # error) when an explicit config object is given.
                if potential_config_kwargs:
                    print(
                        f"\nWarning: Additional keyword arguments provided "
                        f"({', '.join(potential_config_kwargs.keys())}) "
                        f"while an explicit 'config' object was also passed to {func.__name__}. "  # type: ignore
                        f"These additional arguments will be ignored as configuration overrides."
                    )

            # Forward self, the positional args, the resolved config, and only
            # those kwargs that the function's signature actually declares.
            return func(self, *args, config=final_config, **func_specific_kwargs)

        # Expose valid_params on the wrapper for use elsewhere.
        setattr(wrapper, "valid_params", valid_params)
        return wrapper

    return decorator
class Saver:
    """Handles the saving of BasePipeline objects.

    Three kinds of files are written next to each other in ``folder``:
    the pickled pipeline object, the pickled preprocessor, and one
    ``state_dict`` file per model.

    Attributes:
        file_path: Path the pickled pipeline object is written to.
        preprocessor_path: Path the pickled preprocessor is written to.
        model_state_path: Path the model state dict is written to when the
            pipeline holds a single model.
        model_state_paths: Paths of all state dict files written by the
            last call to ``_save_model_state``.
        save_all: If True, the full pipeline is pickled; if False, results
            and cached data are reset first to keep the pickle small.
    """

    def __init__(self, file_path: str, save_all: bool):
        """Initializes the Saver and creates the target folder if needed.

        Args:
            file_path: Path of the pipeline file. Its stem is used to derive
                the preprocessor and model state file names.
            save_all: Whether to keep all results in the pickled pipeline.
        """

        self.save_all = save_all
        self.file_stem: str = Path(file_path).stem
        self.file_name: str = Path(file_path).name
        self.folder: str = Path(file_path).parent.as_posix()

        self.preprocessor_path: str = os.path.join(
            self.folder, f"{self.file_stem}_preprocessor.pkl"
        )
        self.model_state_path: str = os.path.join(
            self.folder, f"{self.file_stem}_model.pth"
        )
        self.file_path: str = os.path.join(self.folder, self.file_name)
        os.makedirs(self.folder, exist_ok=True)

    @no_type_check
    def save(self, pipeline: "BasePipeline"):
        """Saves the BasePipeline object.

        The preprocessor and model weights are written first, because the
        memory-efficient path below replaces them on the pipeline before the
        pipeline object itself is pickled. Note that this mutates the passed
        pipeline in place.

        Args:
            pipeline: The BasePipeline object to save.
        """

        self._save_preprocessor(pipeline._preprocessor)
        self._save_model_state(pipeline)
        self.pipeline = pipeline

        self.pipeline._trainer.purge()

        if not self.save_all:
            print("saving memory efficient")
            self.reset_to_defaults(pipeline.result)

            self.pipeline.preprocessed_data = None
            self.pipeline._datasets = None
            self.pipeline.raw_user_data = None
            # Replace the preprocessor with a fresh instance of the same
            # class. A missing preprocessor stays None; calling
            # type(None)(config=...) would raise a TypeError.
            if self.pipeline._preprocessor is not None:
                self.pipeline._preprocessor = type(self.pipeline._preprocessor)(
                    config=pipeline.config
                )
            self.pipeline.visualizer = type(self.pipeline.visualizer)()

        # NOTE: bundling the files into a zip archive was removed because it
        # caused issues with filesystems; the files stay side by side.
        self._save_pipeline_object(self.pipeline)

    def _save_pipeline_object(self, pipeline: "BasePipeline"):
        """Pickles the pipeline object to ``self.file_path``.

        Raises:
            pickle.PicklingError, OSError: Re-raised after printing a message.
        """
        try:
            with open(self.file_path, "wb") as f:
                pickle.dump(pipeline, f)
            print("Pipeline object saved successfully.")
        except (pickle.PicklingError, OSError) as e:
            print(f"Error saving pipeline object: {e}")
            raise

    def _save_preprocessor(self, preprocessor):
        """Pickles the preprocessor to ``self.preprocessor_path``.

        Does nothing when ``preprocessor`` is None.

        Raises:
            pickle.PickleError, OSError: Re-raised after printing a message.
        """
        if preprocessor is None:
            return
        try:
            with open(self.preprocessor_path, "wb") as f:
                pickle.dump(preprocessor, f)
            print("Preprocessor saved successfully.")
        except (pickle.PickleError, OSError) as e:
            print(f"Error saving preprocessor: {e}")
            raise

    @no_type_check
    def _save_model_state(self, pipeline: "BasePipeline"):
        """Writes the state dict(s) of ``pipeline.result.model`` to disk.

        A single ``torch.nn.Module`` goes to ``self.model_state_path``. A dict
        of models writes one file per entry, named
        ``{model_name}_{file_stem}_model.pth``; entries exposing a ``module``
        attribute are unwrapped first. Models are moved to CPU before saving.
        Nothing is written if the pipeline has no result or no model.

        Raises:
            TypeError: If the model is neither a ``torch.nn.Module`` nor a dict.
        """
        self.model_state_paths: List[str] = []
        if pipeline.result is None or pipeline.result.model is None:
            return

        trained = pipeline.result.model
        if isinstance(trained, torch.nn.Module):
            try:
                trained.to("cpu")
                torch.save(trained.state_dict(), self.model_state_path)
                self.model_state_paths.append(self.model_state_path)
            except (TypeError, OSError) as e:
                print(f"Error saving model state: {e}")
                raise
        elif isinstance(trained, dict):
            for model_name, model in trained.items():
                if hasattr(model, "module"):
                    model = model.module
                model.to("cpu")
                cur_path: str = os.path.join(
                    self.folder, f"{model_name}_{self.file_stem}_model.pth"
                )
                torch.save(model.state_dict(), cur_path)
                self.model_state_paths.append(cur_path)
        else:
            raise TypeError(
                f"pipeline.result.model is neither a torch.nn.Module nor a dict, got {type(pipeline.result.model)}"
            )

    @no_type_check
    def reset_to_defaults(self, obj):
        """Resets the fields of a dataclass instance to their defaults, in place.

        ``adata_latent``, ``losses`` and ``sub_losses`` are kept. ``model`` is
        replaced by freshly constructed instance(s) of the same class built
        from ``init_args``: the class is needed to load the state dict later,
        but the trained weights are already saved separately and should not
        be pickled a second time with the pipeline. Every other field is set
        to its default factory result, its default, or None.

        Args:
            obj: Dataclass instance to reset (the pipeline result).

        Raises:
            ValueError: If ``obj`` is not a dataclass.
        """
        if not is_dataclass(obj):
            raise ValueError("Object must be a dataclass")

        for f in fields(obj):
            # adata_latent is kept as a "core result"; losses keep the
            # training dynamics.
            if f.name in ("adata_latent", "losses", "sub_losses"):
                continue
            if f.name == "model":
                if isinstance(obj.model, dict):
                    empty_models = {}
                    for model_name, model in obj.model.items():
                        if hasattr(model, "module"):
                            model = model.module
                        empty_models[model_name] = type(model)(**model.init_args)
                    # Actually swap in the fresh models; previously they were
                    # built and then discarded, so the trained models were
                    # still pickled along with the pipeline.
                    obj.model = empty_models
                elif isinstance(obj.model, torch.nn.Module):
                    obj.model.init_args["ontologies"] = self.pipeline.ontologies
                    obj.model.init_args["feature_order"] = (
                        self.pipeline._trainer.feature_order
                    )
                    obj.model = type(obj.model)(**obj.model.init_args)
                continue
            if f.default_factory is not MISSING:
                setattr(obj, f.name, f.default_factory())
            elif f.default is not MISSING:
                setattr(obj, f.name, f.default)
            else:
                setattr(obj, f.name, None)
class Loader:
    """Handles the loading of BasePipeline objects.

    Expects the file layout produced by ``Saver``: the pickled pipeline, an
    optional pickled preprocessor, and the model state dict file(s) all in
    the same folder.

    Attributes:
        file_path: Path of the saved pipeline object.
        preprocessor_path: Path where the pickled preprocessor was saved.
        model_state_path: Path where the model state dict was saved (single
            model case).
    """

    def __init__(self, file_path: str):
        """Initializes the Loader and derives the companion file paths.

        Args:
            file_path: Path of the pipeline file. Its stem is used to derive
                the preprocessor and model state file names.
        """

        self.file_stem: str = Path(file_path).stem
        self.file_name: str = Path(file_path).name
        self.folder: str = Path(file_path).parent.as_posix()

        self.preprocessor_path: str = os.path.join(
            self.folder, f"{self.file_stem}_preprocessor.pkl"
        )
        self.model_state_path: str = os.path.join(
            self.folder, f"{self.file_stem}_model.pth"
        )
        self.file_path: str = os.path.join(self.folder, self.file_name)

    def load(self) -> Any:
        """Loads the pipeline object and restores its preprocessor and model weights.

        Returns:
            The loaded BasePipeline object.

        Raises:
            ValueError: If the pipeline object could not be loaded, or it has
                no result / no model.
            FileNotFoundError: If a required model state file is missing.
            TypeError: If the stored model is neither a dict nor a
                ``torch.nn.Module``.
        """
        # NOTE: zip extraction was removed because it caused issues with
        # filesystems; the files are read directly from the folder.
        loaded_obj = self._load_pipeline_object()
        if loaded_obj is None:
            raise ValueError("Error while loading pipeline object")

        loaded_obj._preprocessor = self._load_preprocessor()
        loaded_obj.result.model = self._load_model_state(loaded_obj)

        return loaded_obj

    def _load_pipeline_object(self) -> Any:
        """Unpickles the pipeline from ``self.file_path``.

        Returns:
            The unpickled object, or None if the file is missing or cannot be
            read (a message is printed in that case).
        """
        print(f"Attempting to load a pipeline from {self.file_path}...")
        try:
            if not os.path.exists(self.file_path):
                print(f"Error: File not found at {self.file_path}")
                return None
            # NOTE(security): unpickling executes arbitrary code; only load
            # files that come from a trusted source.
            with open(self.file_path, "rb") as f:
                loaded_obj = pickle.load(f)
            print(
                f"Pipeline object loaded successfully. Actual type: {type(loaded_obj).__name__}"
            )
            return loaded_obj
        except (pickle.UnpicklingError, EOFError, OSError) as e:
            # FileNotFoundError is a subclass of OSError and is covered here.
            print(f"Error loading pipeline object: {e}")
            return None

    def _load_preprocessor(self):
        """Unpickles the preprocessor from ``self.preprocessor_path``.

        Returns:
            The preprocessor, or None if the file is missing or unreadable.
        """
        if not os.path.exists(self.preprocessor_path):
            print("Preprocessor file not found. Skipping preprocessor load.")
            return None
        try:
            # NOTE(security): see _load_pipeline_object; trusted files only.
            with open(self.preprocessor_path, "rb") as f:
                preprocessor = pickle.load(f)
            print("Preprocessor loaded successfully.")
            return preprocessor
        except (pickle.UnpicklingError, EOFError, OSError) as e:
            print(f"Error loading preprocessor: {e}")
            return None

    @no_type_check
    def _load_model_state(self, loaded_obj: "BasePipeline"):
        """Loads the saved state dict(s) into ``loaded_obj.result.model`` on CPU.

        Args:
            loaded_obj: The unpickled pipeline whose model(s) receive the weights.

        Returns:
            ``loaded_obj.result.model`` with the weights loaded.

        Raises:
            ValueError: If the pipeline has no result or no model.
            FileNotFoundError: If an expected state dict file does not exist.
            TypeError: If the model is neither a dict nor a ``torch.nn.Module``.
        """
        if loaded_obj.result is None:
            raise ValueError("Loaded pipeline has no result attribute")
        if loaded_obj.result.model is None:
            raise ValueError("Loaded pipeline result has no model attribute")

        stored = loaded_obj.result.model
        if isinstance(stored, dict):
            for model_name, model in stored.items():
                # The state dict was saved from the unwrapped module, so it
                # has to be loaded into the unwrapped module as well.
                # Loading it into the wrapper would fail on mismatched
                # parameter names.
                if hasattr(model, "module"):
                    model = model.module
                cur_path: str = os.path.join(
                    self.folder, f"{model_name}_{self.file_stem}_model.pth"
                )
                if not os.path.exists(cur_path):
                    raise FileNotFoundError(f"Model state file not found at {cur_path}")
                model_state = torch.load(cur_path, map_location="cpu")
                model.to("cpu")
                model.load_state_dict(model_state)
            return stored
        elif isinstance(stored, torch.nn.Module):
            if not os.path.exists(self.model_state_path):
                raise FileNotFoundError(
                    f"Model state file not found at {self.model_state_path}"
                )
            model_state = torch.load(self.model_state_path, map_location="cpu")
            stored.to("cpu")
            stored.load_state_dict(model_state)
            return stored
        else:
            raise TypeError(
                f"Loaded model is neither a dict nor a torch.nn.Module, got {type(loaded_obj.result.model)}"
            )
def flip_labels(labels: torch.Tensor) -> torch.Tensor:
    """Randomly flip modality labels with probability (1 - 1/n_modalities), vectorized.

    This is mainly used for adversarial training in multi-modal xmodalix.

    ``n_modalities`` is taken as the number of distinct values present in
    ``labels``, and replacement labels are drawn from ``[0, n_modalities)``.
    A flipped label is always different from the original one. An empty
    input is returned unchanged (as a copy).

    Args:
        labels: Tensor of integer modality labels; the first dimension is
            the batch dimension.
    Returns:
        A new tensor of the same shape with some labels replaced. The input
        tensor is not modified.

    """
    # An empty batch has zero modalities, which would otherwise divide by
    # zero when computing the flip probability.
    if labels.numel() == 0:
        return labels.clone()

    device = labels.device
    n_modalities = labels.unique().numel()
    batch_size = labels.size(0)
    flip_prob = 1.0 - 1.0 / n_modalities

    # Decide which labels to flip.
    flip_mask = torch.rand(batch_size, device=device) < flip_prob

    # Sample random candidate labels for every position.
    rand_labels = torch.randint(0, n_modalities, size=(batch_size,), device=device)

    # Re-draw candidates that coincide with the original label, but only
    # where a flip is actually requested. With a single modality flip_prob
    # is 0, so the mask is empty and this loop never runs.
    needs_resample = (rand_labels == labels) & flip_mask
    while needs_resample.any():
        rand_labels[needs_resample] = torch.randint(
            0, n_modalities, size=(int(needs_resample.sum()),), device=device
        )
        needs_resample = (rand_labels == labels) & flip_mask

    # Apply the flipped labels where needed.
    flipped = labels.clone()
    flipped[flip_mask] = rand_labels[flip_mask]

    return flipped
def find_translation_keys(
    config: "DefaultConfig",
    trained_modalities: List[str],
    from_key: Optional[str] = None,
    to_key: Optional[str] = None,
) -> Dict[str, str]:  # type: ignore
    """Find translation source and target modalities.

    Determines which modalities serve as the "from" and "to" directions for
    cross-modal prediction. If both ``from_key`` and ``to_key`` are given
    they are matched against the trained modalities; otherwise the
    directions are read from ``config.data_config.data_info``.

    A key matches a trained modality either by its full name or by its
    simple name, i.e. the part after the first ``"."``.

    Args:
        config: Experiment configuration containing data information. Only
            accessed when the keys are not both passed explicitly.
        trained_modalities: List of trained modality names.
        from_key: Optional name of the source modality.
        to_key: Optional name of the target modality.

    Returns:
        A dictionary with two entries:
            - "from": Full name of the source modality.
            - "to": Full name of the target modality.

    Raises:
        ValueError: If no valid "from" or "to" modality is found, or if
            multiple conflicting directions are specified in the config.
    """
    from_key_final: Optional[str] = None
    to_key_final: Optional[str] = None

    def _simple(full_name: str) -> str:
        # Strip everything up to and including the first "." if present.
        return full_name.split(".", 1)[1] if "." in full_name else full_name

    if from_key and to_key:
        for name in trained_modalities:
            candidates = (_simple(name), name)
            if from_key in candidates:
                from_key_final = name
            # Independent `if` (not elif) so reference prediction with
            # from_key == to_key works. The target is compared against
            # to_key; this previously compared from_key to the full name.
            if to_key in candidates:
                to_key_final = name
        # Explicitly requested keys that cannot be resolved are an error.
        if not (from_key_final and to_key_final):
            simple_names: List[str] = [_simple(tm) for tm in trained_modalities]
            raise ValueError(
                f"Invalid translation keys: {from_key} => {to_key}, valid keys are: {simple_names}"
            )
        return {"from": from_key_final, "to": to_key_final}

    # Fall back to the directions declared in the config.
    data_info = config.data_config.data_info
    for name in trained_modalities:
        cur_info = data_info.get(_simple(name))
        if not cur_info or not hasattr(cur_info, "translate_direction"):
            continue

        direction = cur_info.translate_direction
        if direction == "to":
            if to_key_final is not None:
                raise ValueError(
                    f"Multiple 'to' directions found: '{to_key_final}' and '{name}'"
                )
            to_key_final = name
        elif direction == "from":
            if from_key_final is not None:
                raise ValueError(
                    f"Multiple 'from' directions found: '{from_key_final}' and '{name}'"
                )
            from_key_final = name

    if from_key_final is None:
        raise ValueError(
            "No modality with a 'from' direction was specified in the config."
        )
    if to_key_final is None:
        raise ValueError(
            "No modality with a 'to' direction was specified in the config."
        )

    return {"from": from_key_final, "to": to_key_final}
def preprocess_explanations(
    df: pd.DataFrame, n: int = 10, max_dims: int = 8
) -> Dict[str, List[str]]:
    """
    Map the most informative latent dimensions to their top-n genes.

    Columns of ``df`` are treated as latent dimensions and the index as gene
    names. The function:
    - Ranks the columns by their mean attribution.
    - Keeps the ``max_dims`` columns with the highest mean.
    - For each kept column, collects the index labels of its ``n`` largest
      values.

    Args:
        df: Output DataFrame of gene attributions from explainix.
        n: Number of top genes to include per latent dimension.
        max_dims: Maximum number of latent dimensions to retain.

    Returns:
        Dictionary with latent dimensions as keys and lists of top gene names
        as values, ordered from highest to lowest mean attribution.
    """
    column_means = df.mean(axis=0)
    selected_dims = column_means.nlargest(max_dims).index

    return {
        dim: df.nlargest(n, dim).index.tolist()
        for dim in selected_dims
    }