Coverage for src / autoencodix / utils / adata_converter.py: 32%
47 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1from typing import Literal, Dict, Optional
2from autoencodix.data._multimodal_dataset import MultiModalDataset
3from autoencodix.data._numeric_dataset import NumericDataset
4from autoencodix.data._datasetcontainer import DatasetContainer
5import anndata as ad
6import pandas as pd
7import torch
8from scipy.sparse import issparse
class AnnDataConverter:
    """Utility class for converting datasets into dictionaries of AnnData objects."""

    @staticmethod
    def _numeric_ds_to_adata(ds: NumericDataset) -> Dict[str, ad.AnnData]:
        """Convert a NumericDataset to a single-entry AnnData dictionary.

        Args:
            ds: The numeric dataset to convert. Its ``metadata`` must be a
                DataFrame, or a dict holding exactly one DataFrame.

        Returns:
            A dictionary with the single key ``"global"`` mapping to an AnnData
            object built from the dataset's data (``X``), feature ids (``var``)
            and metadata (``obs``).

        Raises:
            NotImplementedError: If the metadata is a dict that does not hold
                exactly one entry.
            ValueError: If the resolved metadata is not a pandas DataFrame.
        """
        metadata = ds.metadata
        if isinstance(metadata, dict):
            if len(metadata) != 1:
                raise NotImplementedError(
                    "Unpaired metadata conversion not implemented yet. Expected single modality."
                )
            metadata = next(iter(metadata.values()))
        if not isinstance(metadata, pd.DataFrame):
            raise ValueError(f"metadata needs to be pd.DataFrame, got {type(metadata)}")

        # Copy the DataFrame itself (not just an enclosing dict, whose copy would
        # be shallow) so that re-typing the index never touches the dataset's
        # own metadata.
        obs = metadata.copy()
        obs.index = obs.index.astype(str)

        var = pd.DataFrame(index=pd.Index(ds.feature_ids, dtype=str))

        data = ds.data
        if issparse(data):
            # toarray() already yields a fresh dense array; no tensor round trip needed.
            x = data.toarray()
        else:
            # assumes dense data is a torch.Tensor -- TODO confirm against NumericDataset.
            # clone() keeps the resulting array from sharing memory with the tensor.
            x = data.clone().detach().cpu().numpy()

        return {"global": ad.AnnData(X=x, var=var, obs=obs)}

    @staticmethod
    def _parse_multimodal(mds: MultiModalDataset) -> Dict[str, ad.AnnData]:
        """Convert every modality of a MultiModalDataset via ``_numeric_ds_to_adata``.

        Args:
            mds: The multimodal dataset to convert.

        Returns:
            A dictionary keyed by modality name. Each value is the result of
            ``_numeric_ds_to_adata`` for that modality, i.e. a single-entry
            ``{"global": AnnData}`` dictionary.

        Raises:
            NotImplementedError: If any modality is not a NumericDataset.
        """
        # NOTE(review): the values stored here are dicts, not bare AnnData
        # objects, which is why the assignment below needs ``type: ignore``.
        # Confirm with callers whether the nesting is intended before
        # flattening it or changing the return annotation.
        result_dict: Dict[str, ad.AnnData] = {}
        for mod_name, dataset in mds.datasets.items():
            if not isinstance(dataset, NumericDataset):
                raise NotImplementedError(
                    f"Feature Importance is only implemented for NumericDataset, got type: {type(dataset)}"
                )
            result_dict[mod_name] = AnnDataConverter._numeric_ds_to_adata(dataset)  # type: ignore
        return result_dict

    @staticmethod
    def dataset_to_adata(
        datasetcontainer: DatasetContainer,
        split: Literal["train", "valid", "test"] = "train",
    ) -> Optional[Dict[str, ad.AnnData]]:
        """Convert one split of a DatasetContainer into a dictionary of AnnData objects.

        Args:
            datasetcontainer: Container holding train/valid/test datasets.
            split: The dataset split to convert. Defaults to "train".

        Returns:
            For a NumericDataset, a ``{"global": AnnData}`` dictionary; for a
            MultiModalDataset, a dictionary keyed by modality name (see
            ``_parse_multimodal``); ``None`` (with a warning) when the split
            holds no dataset.

        Raises:
            ValueError: If the specified split does not exist in the DatasetContainer.
            NotImplementedError: If the dataset type is not supported.
        """
        if not hasattr(datasetcontainer, split):
            raise ValueError(
                f"Split: {split} not present in DatasetContainer: {datasetcontainer}"
            )

        ds = datasetcontainer[split]

        if isinstance(ds, MultiModalDataset):
            return AnnDataConverter._parse_multimodal(ds)
        if isinstance(ds, NumericDataset):
            return AnnDataConverter._numeric_ds_to_adata(ds)
        if ds is None:
            import warnings

            # stacklevel=2 attributes the warning to the caller of dataset_to_adata.
            warnings.warn(
                f"No dataset found for split: {split}, returning None", stacklevel=2
            )
            return None

        raise NotImplementedError(f"Conversion not implemented for type: {type(ds)}")