Coverage for src / autoencodix / utils / adata_converter.py: 32%

47 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1from typing import Literal, Dict, Optional 

2from autoencodix.data._multimodal_dataset import MultiModalDataset 

3from autoencodix.data._numeric_dataset import NumericDataset 

4from autoencodix.data._datasetcontainer import DatasetContainer 

5import anndata as ad 

6import pandas as pd 

7import torch 

8from scipy.sparse import issparse 

9 

10 

class AnnDataConverter:
    """Utility class for converting datasets into AnnData dictionaries.

    All methods are static; the class only serves as a namespace.
    """

    @staticmethod
    def _build_adata(ds: NumericDataset) -> ad.AnnData:
        """Build a single AnnData object from a NumericDataset.

        Args:
            ds: The numeric dataset to convert. Its ``metadata`` must be a
                ``pd.DataFrame`` or a dict holding exactly one ``pd.DataFrame``.

        Returns:
            An AnnData object with ``X`` taken from ``ds.data``, ``var`` indexed
            by ``ds.feature_ids`` (as strings) and ``obs`` set to a copy of the
            metadata with a string index.

        Raises:
            NotImplementedError: If the metadata is a dict that does not
                contain exactly one entry.
            ValueError: If the (unwrapped) metadata is not a ``pd.DataFrame``.
        """
        metadata = ds.metadata
        if isinstance(metadata, dict):
            if len(metadata) != 1:
                raise NotImplementedError(
                    "Unpaired metadata conversion not implemented yet. Expected single modality."
                )
            metadata = next(iter(metadata.values()))
        # Validate before copying so unsupported types (e.g. None) raise the
        # documented ValueError instead of an AttributeError from `.copy()`.
        if not isinstance(metadata, pd.DataFrame):
            raise ValueError(f"metadata needs to be pd.DataFrame, got {type(metadata)}")
        # Copy the DataFrame itself (not just the enclosing dict) so that
        # rewriting the index below never touches the dataset's own metadata.
        obs = metadata.copy()
        obs.index = obs.index.astype(str)

        var = pd.DataFrame(index=pd.Index(ds.feature_ids, dtype=str))

        if issparse(ds.data):
            # A scipy sparse matrix densifies straight to a numpy array; no
            # round trip through torch is needed.
            x = ds.data.toarray()
        else:
            # Clone so the resulting numpy array does not share memory with
            # the dataset's tensor.
            x = ds.data.clone().detach().cpu().numpy()

        return ad.AnnData(X=x, var=var, obs=obs)

    @staticmethod
    def _numeric_ds_to_adata(ds: NumericDataset) -> Dict[str, ad.AnnData]:
        """Convert a NumericDataset to a single-entry AnnData dictionary.

        Args:
            ds: The numeric dataset to convert.

        Returns:
            A dictionary with the single key ``"global"`` mapping to the
            AnnData object built from the dataset's data, features, and
            metadata.

        Raises:
            NotImplementedError: If the metadata is a dict that does not
                contain exactly one entry.
            ValueError: If the metadata is not a ``pd.DataFrame``.
        """
        return {"global": AnnDataConverter._build_adata(ds)}

    @staticmethod
    def _parse_multimodal(mds: MultiModalDataset) -> Dict[str, ad.AnnData]:
        """Convert a MultiModalDataset into a dictionary of AnnData objects.

        Args:
            mds: The multimodal dataset to convert.

        Returns:
            A dictionary mapping modality names to AnnData objects.

        Raises:
            NotImplementedError: If any modality is not a NumericDataset.
        """
        result_dict: Dict[str, ad.AnnData] = {}
        for mod_name, dataset in mds.datasets.items():
            if not isinstance(dataset, NumericDataset):
                raise NotImplementedError(
                    f"Feature Importance is only implemented for NumericDataset, got type: {type(dataset)}"
                )
            # Store the AnnData itself; wrapping it in a {"global": ...} dict
            # here would yield a nested dict that violates the return type.
            result_dict[mod_name] = AnnDataConverter._build_adata(dataset)
        return result_dict

    @staticmethod
    def dataset_to_adata(
        datasetcontainer: DatasetContainer,
        split: Literal["train", "valid", "test"] = "train",
    ) -> Optional[Dict[str, ad.AnnData]]:
        """Convert a DatasetContainer split to a dictionary of AnnData objects.

        Args:
            datasetcontainer: Container holding train/valid/test datasets.
            split: The dataset split to convert. Defaults to "train".

        Returns:
            For a NumericDataset, a dictionary with the single key ``"global"``;
            for a MultiModalDataset, a dictionary keyed by modality name.
            ``None`` (with a warning) if the split holds no dataset.

        Raises:
            ValueError: If the specified split does not exist in the DatasetContainer.
            NotImplementedError: If the dataset type is not supported.
        """
        if not hasattr(datasetcontainer, split):
            raise ValueError(
                f"Split: {split} not present in DatasetContainer: {datasetcontainer}"
            )

        ds = datasetcontainer[split]

        if isinstance(ds, MultiModalDataset):
            return AnnDataConverter._parse_multimodal(ds)
        if isinstance(ds, NumericDataset):
            return AnnDataConverter._numeric_ds_to_adata(ds)
        if ds is None:
            import warnings

            # stacklevel=2 attributes the warning to the caller's line.
            warnings.warn(
                f"No dataset found for split: {split}, returning None",
                stacklevel=2,
            )
            return None

        raise NotImplementedError(
            f"Conversion not implemented for type: {type(ds)}"
        )

110 f"Conversion not implemented for type: {type(ds)}" 

111 )