Coverage for src / autoencodix / utils / _result.py: 64%
174 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1from dataclasses import dataclass, field
2from typing import Any, Optional, Dict, Union
3from anndata import AnnData
5import torch
6import pandas as pd
8from autoencodix.data._datasetcontainer import DatasetContainer
9from autoencodix.data.datapackage import DataPackage
11from mudata import MuData # type: ignore
12from ._traindynamics import TrainingDynamics
@dataclass
class LossRegistry:
    """Registry of named TrainingDynamics objects, one per loss term.

    The registry keeps the Result class extensible: additional losses that
    different autoencoder architectures may need can be tracked without
    adding new fields.

    Attributes:
        _losses: A dictionary mapping a loss name to the TrainingDynamics
            object that stores the values recorded for that loss.
    """

    _losses: Dict[str, TrainingDynamics] = field(default_factory=dict)

    def __post_init__(self):
        """Validate that every initial entry is a TrainingDynamics object.

        Raises:
            ValueError: If any value in ``_losses`` is not a TrainingDynamics.
        """
        for name, data in self._losses.items():
            if not isinstance(data, TrainingDynamics):
                raise ValueError(
                    f"Expected TrainingDynamics object, got {type(data)} for loss {name}"
                )

    def __eq__(self, other: object) -> bool:
        """Compare two registries by loss names and per-loss dynamics.

        Args:
            other: The object to compare against.
        Returns:
            True if both registries hold the same loss names and every pair of
            entries compares equal. ``NotImplemented`` for foreign types, so
            Python can try the reflected comparison before falling back to
            ``False``.
        """
        if not isinstance(other, LossRegistry):
            return NotImplemented
        if self._losses.keys() != other._losses.keys():
            return False
        return all(self._losses[key] == other._losses[key] for key in self._losses)

    def add(self, data: Dict, split: str, epoch: int) -> None:
        """Record one value per loss for the given epoch and split.

        Args:
            data: Mapping from loss name to the value to record.
            split: The split the values belong to.
            epoch: The epoch the values belong to.
        """
        for key, value in data.items():
            # get() creates the TrainingDynamics entry on first use.
            self.get(key).add(epoch=epoch, data=value, split=split)

    def get(self, key: str) -> TrainingDynamics:
        """Return the entry stored under ``key``, creating it if missing.

        Args:
            key: The loss name.
        Returns:
            The object registered for ``key``. A new, empty TrainingDynamics
            is registered (and returned) when the key is unknown.
        """
        if key not in self._losses:
            self._losses[key] = TrainingDynamics()
        return self._losses[key]

    def losses(self):
        """Return a view of ``(name, entry)`` pairs for all registered losses."""
        return self._losses.items()

    def set(self, key: str, value: TrainingDynamics) -> None:
        """Store ``value`` under ``key``, replacing any existing entry.

        Args:
            key: The loss name.
            value: The object to register for that loss.
        """
        self._losses[key] = value

    def keys(self):
        """Return a view of all registered loss names."""
        return self._losses.keys()
@dataclass
class Result:
    """A dataclass to store results from the pipeline with predefined keys.

    Fields can be read and written either as attributes or with item syntax
    (``result["losses"]``).

    Attributes:
        latentspaces: TrainingDynamics object storing latent space representations for 'train', 'valid', and 'test' splits.
        sample_ids: TrainingDynamics object storing sample identifiers for 'train', 'valid', and 'test' splits.
        reconstructions: TrainingDynamics object storing reconstructed outputs for 'train', 'valid', and 'test' splits.
        mus: TrainingDynamics object storing mean values of latent distributions for 'train', 'valid', and 'test' splits.
        sigmas: TrainingDynamics object storing standard deviations of latent distributions for 'train', 'valid', and 'test' splits.
        losses: TrainingDynamics object storing the total loss for different epochs and splits ('train', 'valid', 'test').
        sub_losses: LossRegistry object (extendable) for all sublosses.
        preprocessed_data: torch.Tensor containing data after preprocessing.
        model: final trained torch.nn.Module model, or a dictionary of named modules.
        model_checkpoints: TrainingDynamics object storing model state at each checkpoint.
        datasets: Optional[DatasetContainer] containing train, valid, and test datasets.
        new_datasets: Optional[DatasetContainer] containing new train, valid, and test datasets.
        adata_latent: Optional[AnnData] containing latent representations as AnnData.
        final_reconstruction: Optional[Union[DataPackage, MuData]] containing final reconstruction results.
        sub_results: Optional[Dict[str, Any]] containing sub-results for multi-task or multi-modal models.
        sub_reconstructions: Optional[Dict[str, Any]] containing sub-reconstructions for multi-task or multi-modal models.
        embedding_evaluation: pd.DataFrame containing embedding evaluation results.
        embedding_attributions: pd.DataFrame, empty by default.
        embedding_explanations: Dict[str, str], empty by default.
    """

    latentspaces: TrainingDynamics = field(default_factory=TrainingDynamics)
    sample_ids: TrainingDynamics = field(default_factory=TrainingDynamics)
    reconstructions: TrainingDynamics = field(default_factory=TrainingDynamics)
    mus: TrainingDynamics = field(default_factory=TrainingDynamics)
    sigmas: TrainingDynamics = field(default_factory=TrainingDynamics)
    losses: TrainingDynamics = field(default_factory=TrainingDynamics)
    sub_losses: LossRegistry = field(default_factory=LossRegistry)
    preprocessed_data: torch.Tensor = field(default_factory=torch.Tensor)
    model: Union[Dict[str, torch.nn.Module], torch.nn.Module] = field(
        default_factory=torch.nn.Module
    )
    model_checkpoints: TrainingDynamics = field(default_factory=TrainingDynamics)

    datasets: Optional[DatasetContainer] = field(
        default_factory=lambda: DatasetContainer(train=None, valid=None, test=None)
    )
    new_datasets: Optional[DatasetContainer] = field(
        default_factory=lambda: DatasetContainer(train=None, valid=None, test=None)
    )

    adata_latent: Optional[AnnData] = field(default_factory=AnnData)
    final_reconstruction: Optional[
        Union[DataPackage, MuData]  # ty: ignore[invalid-type-form]
    ] = field(default=None)
    sub_results: Optional[Dict[str, Any]] = field(default=None)
    sub_reconstructions: Optional[Dict[str, Any]] = field(default=None)

    # Embedding evaluation results
    embedding_evaluation: pd.DataFrame = field(default_factory=pd.DataFrame)
    embedding_attributions: pd.DataFrame = field(default_factory=pd.DataFrame)
    embedding_explanations: Dict[str, str] = field(default_factory=dict)

    # plots: Dict[str, Any] = field(
    #     default_factory=nested_dict
    # )  ## Nested dictionary of plots as figure handles

    def _require_attribute(self, key: str) -> None:
        """Raise KeyError unless ``key`` names an existing attribute.

        Args:
            key: The attribute name to validate.
        Raises:
            KeyError: If the object has no attribute called ``key``.

        """
        # NOTE(review): hasattr() also accepts method names and any attribute
        # set dynamically, not only the annotated fields listed in the message.
        if not hasattr(self, key):
            raise KeyError(
                f"Invalid key: '{key}'. Allowed keys are: {', '.join(self.__annotations__.keys())}"
            )

    def __getitem__(self, key: str) -> Any:
        """Retrieve the value associated with a specific key.

        Args:
            key: The name of the attribute to retrieve.
        Returns:
            The value of the specified attribute.
        Raises:
            KeyError: If the key is not a valid attribute of the Results class.

        """
        self._require_attribute(key)
        return getattr(self, key)

    def __setitem__(self, key: str, value: Any) -> None:
        """Assign a value to a specific attribute.

        Args:
            key: The name of the attribute to set.
            value: The value to assign to the attribute.
        Raises:
            KeyError: If the key is not a valid attribute of the Results class.

        """
        self._require_attribute(key)
        setattr(self, key, value)

    def _is_empty_value(self, value: Any) -> bool:
        """Check whether an attribute value of the Result object is empty.

        Only the container types handled below can be reported as empty; any
        other value (including None) is treated as non-empty.

        Args:
            value: The value to check.
        Returns:
            True if the value is empty, False otherwise.

        """
        if isinstance(value, TrainingDynamics):
            return len(value._data) == 0
        elif isinstance(value, torch.Tensor):
            return value.numel() == 0
        elif isinstance(value, torch.nn.Module):
            # A module without any parameters counts as "no model".
            return sum(p.numel() for p in value.parameters()) == 0
        elif isinstance(value, DatasetContainer):
            return all(v is None for v in [value.train, value.valid, value.test])
        elif isinstance(value, LossRegistry):
            # single Nones are handled in update method (skipped)
            return all(v is None for _, v in value.losses())

        return False

    def update(self, other: "Result") -> None:
        """Update the current Result object with values from another Result object.

        For TrainingDynamics, merges the data across epochs and splits and overwrites if already exists.
        For LossRegistry, merges every contained TrainingDynamics the same way.
        For all other attributes, replaces the current value with the other value.
        Values of ``other`` that are empty (see ``_is_empty_value``) are skipped.

        Args:
            other: The Result object to update from.
        Raises:
            TypeError: If the input object is not a Result instance
            ValueError: If a sub-loss entry of ``other`` is neither None nor a
                TrainingDynamics object.

        """
        if not isinstance(other, Result):
            raise TypeError(f"Expected Result object, got {type(other)}")

        for field_name in self.__annotations__.keys():
            current_value = getattr(self, field_name)
            other_value = getattr(other, field_name)
            if self._is_empty_value(other_value):
                continue

            if isinstance(current_value, TrainingDynamics):
                # Keep the merged history. Falling through to the generic
                # replacement below would throw away the epochs that only
                # exist in the current object.
                merged = self._update_traindynamics(current_value, other_value)
                setattr(self, field_name, merged)
            elif isinstance(current_value, LossRegistry):
                for key, value in other_value.losses():
                    if value is None:
                        continue
                    if not isinstance(value, TrainingDynamics):
                        raise ValueError(
                            f"Expected TrainingDynamics object, got {type(value)}"
                        )
                    updated_dynamic = self._update_traindynamics(
                        current_value=current_value.get(key=key), other_value=value
                    )
                    current_value.set(key=key, value=updated_dynamic)
            else:
                # For all other types - replace with other value
                setattr(self, field_name, other_value)

    def _update_traindynamics(
        self, current_value: TrainingDynamics, other_value: TrainingDynamics
    ) -> TrainingDynamics:
        """Update TrainingDynamics object with values from another TrainingDynamics object.

        Epochs and splits that are None (or empty) in ``other_value`` are
        skipped, so they never erase data already stored in ``current_value``.

        Args:
            current_value: The current TrainingDynamics object to update.
            other_value: The TrainingDynamics object to update from.

        Returns:
            Updated TrainingDynamics object. ``other_value`` is returned as-is
            when there is no current object (or its data is None), and
            ``current_value`` is returned unchanged when ``other_value`` (or
            its data) is None.

        Examples:
            >>> current = TrainingDynamics()
            >>> current._data = {1: {"train": np.array([1, 2, 3])},
            ...                  2: None}

            >>> other = TrainingDynamics()
            >>> other._data = {1: {"train": np.array([4, 5, 6])},
            ...                2: {"train": np.array([7, 8, 9])}}
            >>> # after update
            >>> print(current._data)
            {1: {"train": np.array([4, 5, 6])}, # updated
             2: {"train": np.array([7, 8, 9])}} # filled in, because current was None

        """
        if current_value is None or current_value._data is None:
            return other_value
        if other_value is None or other_value._data is None:
            # Nothing to merge in.
            return current_value

        for epoch, split_data in other_value._data.items():
            if split_data is None or len(split_data) == 0:
                continue

            # If current epoch is None, it should be updated: swap the None
            # placeholder for an empty mapping before adding the splits.
            if epoch in current_value._data and current_value._data[epoch] is None:
                current_value._data[epoch] = {}

            # Same handling whether the epoch is new or already present:
            # add every split that actually carries data.
            for split, data in split_data.items():
                if data is not None:
                    current_value.add(epoch=epoch, data=data, split=split)

        # Ensure ordering
        current_value._data = dict(sorted(current_value._data.items()))

        return current_value

    def __str__(self) -> str:
        """Provide a readable string representation of the Result object's public attributes.

        Returns:
            Formatted string showing all public attributes and their values
        """
        output = ["Result Object Public Attributes:", "-" * 30]

        for name in self.__annotations__.keys():
            if name.startswith("__"):
                continue

            value = getattr(self, name)
            if isinstance(value, TrainingDynamics):
                output.append(f"{name}: TrainingDynamics object")
            elif isinstance(value, torch.nn.Module):
                output.append(f"{name}: {value.__class__.__name__}")
            elif isinstance(value, dict):
                output.append(f"{name}: Dict with {len(value)} items")
            elif isinstance(value, torch.Tensor):
                output.append(f"{name}: Tensor of shape {tuple(value.shape)}")
            else:
                output.append(f"{name}: {value}")

        return "\n".join(output)

    def __repr__(self) -> str:
        """Return the same representation as __str__ for consistency."""
        return self.__str__()

    def get_latent_df(
        self, epoch: int, split: str, modality: Optional[str] = None
    ) -> pd.DataFrame:
        """Return latent representations as a DataFrame.

        Retrieves latent vectors and their corresponding sample IDs for a given
        epoch and data split. If a specific modality is provided, the results
        are restricted to that modality. Column names are inferred from model
        ontologies if available; otherwise, generic latent dimension labels are
        used.

        This is a best-effort accessor: any failure is reported as a warning
        (including the underlying error) and an empty DataFrame is returned.

        Args:
            epoch: The epoch number to retrieve latents from.
            split: The dataset split to query (e.g., "train", "valid", "test").
            modality: Optional modality name to filter the latents and sample IDs.

        Returns:
            A DataFrame where rows correspond to samples, columns represent latent
            dimensions, and the index contains sample IDs. Empty if the latents
            could not be retrieved.
        """
        import warnings

        try:
            latents = self.latentspaces.get(epoch=epoch, split=split)
            ids = self.sample_ids.get(epoch=epoch, split=split)
            if modality is not None:  # for x-modalix and other multi-modal models
                latents = latents[modality]
                ids = ids[modality]
            if hasattr(self.model, "ontologies") and self.model.ontologies is not None:
                cols = list(self.model.ontologies[0].keys())
            else:
                cols = ["LatDim_" + str(i) for i in range(latents.shape[1])]
            return pd.DataFrame(latents, index=ids, columns=cols)
        except Exception as e:
            # Deliberately broad: callers get an empty frame instead of a
            # crash, but the cause is surfaced in the warning text.
            warnings.warn(
                f"Could not retrieve latent representations for epoch {epoch} and split '{split}'. "
                f"Returning empty DataFrame. This may be due to missing data in the Result object or incorrect keys.\n\n"
                f"Original error: {type(e).__name__}: {e}",
                stacklevel=2,
            )

            return pd.DataFrame()

    def get_reconstructions_df(
        self, epoch: int, split: str, modality: Optional[str] = None
    ) -> pd.DataFrame:
        """Return reconstructions as a DataFrame.

        Retrieves reconstructed features and their corresponding sample IDs for a
        given epoch and data split. If a specific modality is provided, the results
        are restricted to that modality. Column names are taken from the
        ``feature_ids`` of the matching split in ``new_datasets``; when that
        split (or ``new_datasets`` itself) is missing, generic ``Feature_<i>``
        labels are used.

        Args:
            epoch: The epoch number to retrieve reconstructions from.
            split: The dataset split to query (e.g., "train", "valid", "test").
            modality: Optional modality name to filter the reconstructions and
                sample IDs.

        Returns:
            A DataFrame where rows correspond to samples, columns represent
            reconstructed features, and the index contains sample IDs.
        """
        reconstructions = self.reconstructions.get(epoch=epoch, split=split)
        ids = self.sample_ids.get(epoch=epoch, split=split)
        if modality is not None:
            reconstructions = reconstructions[modality]
            ids = ids[modality]

        # Always bind the name so a None ``new_datasets`` falls back to the
        # generic column labels instead of raising UnboundLocalError.
        datasets = self.new_datasets

        split_dataset = None
        if datasets is not None and split in ("train", "valid", "test"):
            split_dataset = getattr(datasets, split)

        if split_dataset is not None:
            cols = split_dataset.feature_ids
        else:
            cols = [f"Feature_{i}" for i in range(reconstructions.shape[1])]
        return pd.DataFrame(reconstructions, index=ids, columns=cols)