Coverage for src / autoencodix / utils / _result.py: 64%

174 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1from dataclasses import dataclass, field 

2from typing import Any, Optional, Dict, Union 

3from anndata import AnnData 

4 

5import torch 

6import pandas as pd 

7 

8from autoencodix.data._datasetcontainer import DatasetContainer 

9from autoencodix.data.datapackage import DataPackage 

10 

11from mudata import MuData # type: ignore 

12from ._traindynamics import TrainingDynamics 

13 

14 

15@dataclass 

16class LossRegistry: 

17 """Dataclass to store multiple TrainingDynamics objects for different losses. 

18 

19 Objective is to make the Result class extensible for multiple losses that might 

20 be needed in the future for diffrent autoencoder architectures. 

21 

22 Attributes: 

23 _losses: A dictionary to store TrainingDynamics objects for different losses. 

24 """ 

25 

26 _losses: Dict[str, TrainingDynamics] = field(default_factory=dict) 

27 

28 def __post_init__(self): 

29 for name, data in self._losses.items(): 

30 if not isinstance(data, TrainingDynamics): 

31 raise ValueError( 

32 f"Expected TrainingDynamics object, got {type(data)} for loss {name}" 

33 ) 

34 

35 def __eq__(self, other: object) -> bool: 

36 if not isinstance(other, LossRegistry): 

37 return False 

38 if self._losses.keys() != other._losses.keys(): 

39 return False 

40 return all(self._losses[key] == other._losses[key] for key in self._losses) 

41 

42 def add(self, data: Dict, split: str, epoch: int) -> None: 

43 for key, value in data.items(): 

44 if key not in self._losses: 

45 self._losses[key] = TrainingDynamics() 

46 self._losses[key].add(epoch=epoch, data=value, split=split) 

47 

48 def get(self, key: str) -> TrainingDynamics: 

49 if key not in self._losses: 

50 self._losses[key] = TrainingDynamics() 

51 return self._losses[key] 

52 

53 def losses(self): 

54 return self._losses.items() 

55 

56 def set(self, key: str, value: TrainingDynamics) -> None: 

57 self._losses[key] = value 

58 

59 def keys(self): 

60 return self._losses.keys() 

61 

62 

63@dataclass 

64class Result: 

65 """A dataclass to store results from the pipeline with predefined keys. 

66 

67 Attributes: 

68 latentspaces: TrainingDynamics object storing latent space representations for 'train', 'valid', and 'test' splits. 

69 sample_ids: TrainingDynamics object storing sample identifiers for 'train', 'valid', and 'test' splits. 

70 reconstructions: TrainingDynamics object storing reconstructed outputs for 'train', 'valid', and 'test' splits. 

71 mus: TrainingDynamics object storing mean values of latent distributions for 'train', 'valid', and 'test' splits. 

72 sigmas: TrainingDynamics object storing standard deviations of latent distributions for 'train', 'valid', and 'test' splits. 

73 losses: TrainingDynamics object storing the total loss for different epochs and splits ('train', 'valid', 'test'). 

74 sub_losses: LossRegistry object (extendable) for all sublosses. 

75 preprocessed_data: torch.Tensor containing data after preprocessing. 

76 model: final trained torch.nn.Module model. 

77 model_checkpoints: TrainingDynamics object storing model state at each checkpoint. 

78 datasets: Optional[DatasetContainer] containing train, valid, and test datasets. 

79 new_datasets: Optional[DatasetContainer] containing new train, valid, and test datasets. 

80 adata_latent: Optional[AnnData] containing latent representations as AnnData. 

81 final_reconstruction: Optional[Union[DataPackage, MuData]] containing final reconstruction results. 

82 sub_results: Optional[Dict[str, Any]] containing sub-results for multi-task or multi-modal models. 

83 sub_reconstructions: Optional[Dict[str, Any]] containing sub-reconstructions for multi-task or multi-modal models. 

84 embedding_evaluation: pd.DataFrame containing embedding evaluation results. 

85 """ 

86 

87 latentspaces: TrainingDynamics = field(default_factory=TrainingDynamics) 

88 sample_ids: TrainingDynamics = field(default_factory=TrainingDynamics) 

89 reconstructions: TrainingDynamics = field(default_factory=TrainingDynamics) 

90 mus: TrainingDynamics = field(default_factory=TrainingDynamics) 

91 sigmas: TrainingDynamics = field(default_factory=TrainingDynamics) 

92 losses: TrainingDynamics = field(default_factory=TrainingDynamics) 

93 sub_losses: LossRegistry = field(default_factory=LossRegistry) 

94 preprocessed_data: torch.Tensor = field(default_factory=torch.Tensor) 

95 model: Union[Dict[str, torch.nn.Module], torch.nn.Module] = field( 

96 default_factory=torch.nn.Module 

97 ) 

98 model_checkpoints: TrainingDynamics = field(default_factory=TrainingDynamics) 

99 

100 datasets: Optional[DatasetContainer] = field( 

101 default_factory=lambda: DatasetContainer(train=None, valid=None, test=None) 

102 ) 

103 new_datasets: Optional[DatasetContainer] = field( 

104 default_factory=lambda: DatasetContainer(train=None, valid=None, test=None) 

105 ) 

106 

107 adata_latent: Optional[AnnData] = field(default_factory=AnnData) 

108 final_reconstruction: Optional[ 

109 Union[DataPackage, MuData] # ty: ignore[invalid-type-form] 

110 ] = field(default=None) 

111 sub_results: Optional[Dict[str, Any]] = field(default=None) 

112 sub_reconstructions: Optional[Dict[str, Any]] = field(default=None) 

113 

114 # Embedding evaluation results 

115 embedding_evaluation: pd.DataFrame = field(default_factory=pd.DataFrame) 

116 embedding_attributions: pd.DataFrame = field(default_factory=pd.DataFrame) 

117 embedding_explanations: Dict[str, str] = field(default_factory=dict) 

118 

119 # plots: Dict[str, Any] = field( 

120 # default_factory=nested_dict 

121 # ) ## Nested dictionary of plots as figure handles 

122 

123 def __getitem__(self, key: str) -> Any: 

124 """Retrieve the value associated with a specific key. 

125 

126 Args: 

127 key: The name of the attribute to retrieve. 

128 Returns: 

129 The value of the specified attribute. 

130 Raises: 

131 KeyError - If the key is not a valid attribute of the Results class. 

132 

133 """ 

134 if not hasattr(self, key): 

135 raise KeyError( 

136 f"Invalid key: '{key}'. Allowed keys are: {', '.join(self.__annotations__.keys())}" 

137 ) 

138 return getattr(self, key) 

139 

140 def __setitem__(self, key: str, value: Any) -> None: 

141 """Assign a value to a specific attribute. 

142 

143 Args: 

144 key: The name of the attribute to set. 

145 value: The value to assign to the attribute. 

146 Raises: 

147 KeyError: If the key is not a valid attribute of the Results class. 

148 

149 """ 

150 if not hasattr(self, key): 

151 raise KeyError( 

152 f"Invalid key: '{key}'. Allowed keys are: {', '.join(self.__annotations__.keys())}" 

153 ) 

154 setattr(self, key, value) 

155 

156 def _is_empty_value(self, value: Any) -> bool: 

157 """ 

158 Helper method to check if an attribute of the Result object is empty. 

159 

160 Parameters: 

161 value (Any): The value to check 

162 Returns: 

163 bool: True if the value is empty, False otherwise 

164 

165 """ 

166 

167 if isinstance(value, TrainingDynamics): 

168 return len(value._data) == 0 

169 elif isinstance(value, torch.Tensor): 

170 return value.numel() == 0 

171 elif isinstance(value, torch.nn.Module): 

172 return sum(p.numel() for p in value.parameters()) == 0 

173 elif isinstance(value, DatasetContainer): 

174 return all(v is None for v in [value.train, value.valid, value.test]) 

175 elif isinstance(value, LossRegistry): 

176 # single Nones are handled in update method (skipped) 

177 return all(v is None for _, v in value.losses()) 

178 

179 return False 

180 

181 def update(self, other: "Result") -> None: 

182 """Update the current Result object with values from another Result object. 

183 

184 For TrainingDynamics, merges the data across epochs and splits and overwrites if already exists. 

185 For all other attributes, replaces the current value with the other value. 

186 

187 Args: 

188 other: The Result object to update from. 

189 Raises: 

190 TypeError: If the input object is not a Result instance 

191 

192 """ 

193 if not isinstance(other, Result): 

194 raise TypeError(f"Expected Result object, got {type(other)}") 

195 

196 for field_name in self.__annotations__.keys(): 

197 current_value = getattr(self, field_name) 

198 other_value = getattr(other, field_name) 

199 if self._is_empty_value(other_value): 

200 continue 

201 

202 # Handle TrainingDynamics - merge data 

203 if isinstance(current_value, TrainingDynamics): 

204 current_value = self._update_traindynamics(current_value, other_value) 

205 # For all other types - replace with other value 

206 if isinstance(current_value, LossRegistry): 

207 for key, value in other_value.losses(): 

208 if value is None: 

209 continue 

210 if not isinstance(value, TrainingDynamics): 

211 raise ValueError( 

212 f"Expected TrainingDynamics object, got {type(value)}" 

213 ) 

214 updated_dynamic = self._update_traindynamics( 

215 current_value=current_value.get(key=key), other_value=value 

216 ) 

217 current_value.set(key=key, value=updated_dynamic) 

218 else: 

219 setattr(self, field_name, other_value) 

220 

221 def _update_traindynamics( 

222 self, current_value: TrainingDynamics, other_value: TrainingDynamics 

223 ) -> TrainingDynamics: 

224 """Update TrainingDynamics object with values from another TrainingDynamics object. 

225 

226 Args: 

227 current_value: The current TrainingDynamics object to update. 

228 other_value: The TrainingDynamics object to update from. 

229 

230 Returns: 

231 Updated TrainingDynamics object. 

232 

233 Examples: 

234 >>> current = TrainingDynamics() 

235 >>> current._data = {1: {"train": np.array([1, 2, 3])}, 

236 ... 2: None} 

237 

238 >>> other = TrainingDynamics() 

239 >>> other._data = {1: {"train": np.array([4, 5, 6])}, 

240 ... 2: {"train": np.array([7, 8, 9])}} 

241 >>> # after update 

242 >>> print(current._data) 

243 {1: {"train": np.array([4, 5, 6])}, # updated 

244 2: {"train": np.array([7, 8, 9])}} # kept, because other was None 

245 

246 """ 

247 

248 if current_value is None: 

249 return other_value 

250 if current_value._data is None: 

251 return other_value 

252 

253 for epoch, split_data in other_value._data.items(): 

254 if split_data is None: 

255 continue 

256 if len(split_data) == 0: 

257 continue 

258 

259 # If current epoch is None, it should be updated 

260 if epoch in current_value._data and current_value._data[epoch] is None: 

261 current_value._data[epoch] = {} 

262 for split, data in split_data.items(): 

263 if data is None: 

264 continue 

265 current_value.add(epoch=epoch, data=data, split=split) 

266 continue 

267 

268 if epoch not in current_value._data: 

269 for split, data in split_data.items(): 

270 if data is None: 

271 continue 

272 current_value.add(epoch=epoch, data=data, split=split) 

273 continue 

274 # case when current epoch exists, then update all but None values 

275 for split, value in split_data.items(): 

276 if value is not None: 

277 current_value.add(epoch=epoch, data=value, split=split) 

278 

279 # Ensure ordering 

280 current_value._data = dict(sorted(current_value._data.items())) 

281 

282 return current_value 

283 

284 def __str__(self) -> str: 

285 """Provide a readable string representation of the Result object's public attributes. 

286 

287 Returns: 

288 Formatted string showing all public attributes and their values 

289 """ 

290 output = ["Result Object Public Attributes:", "-" * 30] 

291 

292 for name in self.__annotations__.keys(): 

293 if name.startswith("__"): 

294 continue 

295 

296 value = getattr(self, name) 

297 if isinstance(value, TrainingDynamics): 

298 output.append(f"{name}: TrainingDynamics object") 

299 elif isinstance(value, torch.nn.Module): 

300 output.append(f"{name}: {value.__class__.__name__}") 

301 elif isinstance(value, dict): 

302 output.append(f"{name}: Dict with {len(value)} items") 

303 elif isinstance(value, torch.Tensor): 

304 output.append(f"{name}: Tensor of shape {tuple(value.shape)}") 

305 else: 

306 output.append(f"{name}: {value}") 

307 

308 return "\n".join(output) 

309 

310 def __repr__(self) -> str: 

311 """Return the same representation as __str__ for consistency.""" 

312 return self.__str__() 

313 

314 def get_latent_df( 

315 self, epoch: int, split: str, modality: Optional[str] = None 

316 ) -> pd.DataFrame: 

317 """Return latent representations as a DataFrame. 

318 

319 Retrieves latent vectors and their corresponding sample IDs for a given 

320 epoch and data split. If a specific modality is provided, the results 

321 are restricted to that modality. Column names are inferred from model 

322 ontologies if available; otherwise, generic latent dimension labels are 

323 used. 

324 

325 Args: 

326 epoch: The epoch number to retrieve latents from. 

327 split: The dataset split to query (e.g., "train", "valid", "test"). 

328 modality: Optional modality name to filter the latents and sample IDs. 

329 

330 Returns: 

331 A DataFrame where rows correspond to samples, columns represent latent 

332 dimensions, and the index contains sample IDs. 

333 """ 

334 try: 

335 latents = self.latentspaces.get(epoch=epoch, split=split) 

336 ids = self.sample_ids.get(epoch=epoch, split=split) 

337 if modality is not None: # for x-modalix and other multi-modal models 

338 latents = latents[modality] 

339 ids = ids[modality] 

340 if hasattr(self.model, "ontologies") and self.model.ontologies is not None: 

341 cols = list(self.model.ontologies[0].keys()) 

342 else: 

343 cols = ["LatDim_" + str(i) for i in range(latents.shape[1])] 

344 return pd.DataFrame(latents, index=ids, columns=cols) 

345 except Exception as e: 

346 import warnings 

347 

348 warnings.warn( 

349 f"Could not retrieve latent representations for epoch {epoch} and split '{split}'. " 

350 f"Returning empty DataFrame. This may be due to missing data in the Result object or incorrect keys.\n\n" 

351 ) 

352 

353 return pd.DataFrame() 

354 

355 def get_reconstructions_df( 

356 self, epoch: int, split: str, modality: Optional[str] = None 

357 ) -> pd.DataFrame: 

358 """Return reconstructions as a DataFrame. 

359 

360 Retrieves reconstructed features and their corresponding sample IDs for a 

361 given epoch and data split. If a specific modality is provided, the results 

362 are restricted to that modality. Column names are based on the dataset's 

363 feature identifiers. 

364 

365 Args: 

366 epoch: The epoch number to retrieve reconstructions from. 

367 split: The dataset split to query (e.g., "train", "valid", "test"). 

368 modality: Optional modality name to filter the reconstructions and 

369 sample IDs. 

370 

371 Returns: 

372 A DataFrame where rows correspond to samples, columns represent 

373 reconstructed features, and the index contains sample IDs. 

374 """ 

375 reconstructions = self.reconstructions.get(epoch=epoch, split=split) 

376 ids = self.sample_ids.get(epoch=epoch, split=split) 

377 if modality is not None: 

378 reconstructions = reconstructions[modality] 

379 ids = ids[modality] 

380 

381 if self.new_datasets is not None: 

382 datasets = self.new_datasets 

383 

384 # cols = self.datasets.train.feature_ids 

385 if split == "train" and datasets is not None and datasets.train is not None: 

386 cols = datasets.train.feature_ids 

387 elif split == "valid" and datasets is not None and datasets.valid is not None: 

388 cols = datasets.valid.feature_ids 

389 elif split == "test" and datasets is not None and datasets.test is not None: 

390 cols = datasets.test.feature_ids 

391 else: 

392 cols = [f"Feature_{i}" for i in range(reconstructions.shape[1])] 

393 return pd.DataFrame(reconstructions, index=ids, columns=cols)