Coverage for src / autoencodix / data / _nanremover.py: 13%

105 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1import anndata as ad # type: ignore 

2from pandas.api.types import is_numeric_dtype, is_categorical_dtype 

3 

4import warnings 

5import pandas as pd 

6import mudata as md # type: ignore 

7import numpy as np 

8from scipy.sparse import issparse # type: ignore 

9from scipy import sparse # type: ignore 

10 

11from autoencodix.data.datapackage import DataPackage 

12from autoencodix.configs.default_config import DefaultConfig 

13 

14 

class NaNRemover:
    """Replaces or drops NaN values in the components of a DataPackage.

    Handles the data structures used for single-cell and multi-modal omics:
    AnnData, MuData and pandas DataFrames. For AnnData objects the X matrix,
    all layers and the configured observation annotation columns are
    processed. Bulk DataFrames lose every column that contains a NaN.

    Attributes:
        config: Configuration object containing settings for data processing.
        relevant_cols: Metadata columns to check for NaNs, read from
            ``config.data_config.annotation_columns``.
    """

    def __init__(
        self,
        config: DefaultConfig,
    ):
        """Initialize the NaNRemover with configuration settings.

        Args:
            config: Configuration object containing settings for data processing.
        """
        self.config = config
        self.relevant_cols = self.config.data_config.annotation_columns

    @staticmethod
    def _fill_categorical(values: pd.Series) -> pd.Series:
        """Fill missing entries of a categorical Series with 'missing'.

        Args:
            values: Series with a categorical dtype.
        Returns:
            A Series whose missing entries are the category 'missing'.
        """
        # fillna on a categorical only accepts values that are already
        # categories, so the label has to be registered first.
        if "missing" not in values.cat.categories:
            values = values.cat.add_categories(["missing"])
        return values.fillna("missing")

    @staticmethod
    def _zero_sparse_nans(matrix) -> None:
        """Set NaNs stored in a sparse matrix to zero, in place.

        Args:
            matrix: A scipy sparse matrix. Formats without a ``data``
                attribute are left untouched.
        """
        if hasattr(matrix, "data"):
            matrix.data = np.nan_to_num(matrix.data, nan=0.0)
            # The former NaNs are now explicit zeros; drop them from storage.
            matrix.eliminate_zeros()

    def _process_modality(self, adata: ad.AnnData) -> ad.AnnData:
        """Converts NaN values in AnnData object to zero and metadata NaNs to 'missing'.

        Works on a copy, so the AnnData passed in is not modified.

        Args:
            adata: The AnnData object to process.
        Returns:
            The processed AnnData object with NaN values replaced.
        """
        adata = adata.copy()

        # Handle X matrix
        if sparse.issparse(adata.X):
            self._zero_sparse_nans(adata.X)
        else:
            adata.X = np.nan_to_num(adata.X, nan=0.0)

        # Handle all layers (snapshot the items because dense layers are
        # reassigned while looping)
        for layer_name, layer_data in list(adata.layers.items()):
            if sparse.issparse(layer_data):
                self._zero_sparse_nans(layer_data)
            else:
                adata.layers[layer_name] = np.nan_to_num(layer_data, nan=0.0)

        # Handle obs metadata
        if not self.relevant_cols:
            return adata

        for col in self.relevant_cols:
            if col not in adata.obs.columns:
                warnings.warn(f"Column {col} not found in obs.")
                continue
            s = adata.obs[col]

            if is_numeric_dtype(s):
                # NOTE: np.nan_to_num also maps +/-inf to large finite values.
                adata.obs[col] = np.nan_to_num(s, nan=0.0)
                continue

            # isinstance check replaces the deprecated is_categorical_dtype.
            if not isinstance(s.dtype, pd.CategoricalDtype):
                s = s.astype("category")
            adata.obs[col] = self._fill_categorical(s)

        return adata

    def _fill_annotation(self, annotation: dict) -> dict:
        """Fill missing values in the relevant non-numeric annotation columns.

        The DataFrames are updated in place; entries whose value is None are
        left out of the result. Numeric columns are not touched.

        Args:
            annotation: Mapping of keys to annotation DataFrames (or None).
        Returns:
            A new dict holding the non-None DataFrames.
        """
        filled = {}
        for key, frame in annotation.items():
            if frame is None:
                continue
            for col in self.relevant_cols or ():
                if col not in frame.columns:
                    continue
                column = frame[col]
                # Fill with "missing" only if column is not integer or float
                if is_numeric_dtype(column):
                    continue
                if isinstance(column.dtype, pd.CategoricalDtype):
                    # A plain fillna raises when "missing" is not a category.
                    frame[col] = self._fill_categorical(column)
                else:
                    frame[col] = column.fillna("missing")
            filled[key] = frame
        return filled

    def _process_unpaired(self, multi_sc: dict) -> dict:
        """Process every modality of an unpaired ``multi_sc`` mapping.

        Args:
            multi_sc: Mapping whose values are either dicts of AnnData objects
                or objects exposing a ``mod`` mapping of AnnData objects.
        Returns:
            A dict of MuData objects built from the processed modalities.
        """
        import logging

        logging.getLogger(__name__).debug("data in multi_sc: %s", multi_sc)

        # Placeholders keep the outer keys in their original order.
        processed = dict.fromkeys(multi_sc)
        for key, value in multi_sc.items():
            if isinstance(value, dict):
                for sub_key, sub_adata in value.items():
                    cleaned = self._process_modality(adata=sub_adata)
                    processed[sub_key] = md.MuData({sub_key: cleaned})
            else:
                # Keep all modalities of the container together instead of
                # overwriting the entry with only the last one.
                cleaned_mods = {
                    mod_key: self._process_modality(adata=adata)
                    for mod_key, adata in value.mod.items()
                }
                if cleaned_mods:
                    processed[key] = md.MuData(cleaned_mods)

        # Filter on `is not None` rather than truthiness so that a container
        # whose truth value depends on its length is not dropped by accident.
        return {k: v for k, v in processed.items() if v is not None}

    def _process_translation_modalities(
        self, modality_dict: dict, direction: str
    ) -> None:
        """Process the entries of ``from_modality`` / ``to_modality`` in place.

        Args:
            modality_dict: Mapping of modality keys to MuData, AnnData, dicts
                of AnnData, or DataFrames.
            direction: Attribute name the mapping came from; only used in the
                warning for unsupported value types.
        """
        for mod_key, mod_value in list(modality_dict.items()):
            # Handle MuData objects
            if isinstance(mod_value, md.MuData):
                for inner_name, inner_adata in list(mod_value.mod.items()):
                    mod_value.mod[inner_name] = self._process_modality(inner_adata)

                # Ensure cell alignment if there are multiple modalities
                if len(mod_value.mod) > 1:
                    shared = set.intersection(
                        *(set(mod.obs_names) for mod in mod_value.mod.values())
                    )
                    # Iterating a set gives an arbitrary order; follow the
                    # first modality so the subset is reproducible.
                    first = next(iter(mod_value.mod.values()))
                    common_cells = [
                        name
                        for name in dict.fromkeys(first.obs_names)
                        if name in shared
                    ]
                    mod_value = mod_value[common_cells]

                modality_dict[mod_key] = mod_value

            # Handle AnnData objects directly
            elif isinstance(mod_value, ad.AnnData):
                modality_dict[mod_key] = self._process_modality(mod_value)

            # Handle dictionaries of AnnData objects
            elif isinstance(mod_value, dict):
                for sub_key, sub_value in list(mod_value.items()):
                    if isinstance(sub_value, ad.AnnData):
                        mod_value[sub_key] = self._process_modality(sub_value)

            elif isinstance(mod_value, pd.DataFrame):
                # Same treatment as bulk data: drop columns containing NaN.
                modality_dict[mod_key] = mod_value.dropna(axis=1)

            else:
                warnings.warn(
                    f"Skipping unknown type in {direction}.{mod_key}: {type(mod_value)}"
                )

    def remove_nan(self, data: DataPackage) -> DataPackage:
        """Removes or replaces NaN values in all applicable DataPackage components.

        - ``multi_bulk``: columns containing NaN are dropped.
        - ``annotation``: NaNs in the relevant non-numeric columns become
          'missing'; entries that are None are removed.
        - ``multi_sc`` and ``from_modality`` / ``to_modality``: AnnData
          modalities are processed with ``_process_modality``; DataFrames
          lose columns containing NaN.

        The DataPackage is modified in place and also returned.

        Args:
            data: The DataPackage object containing multi-modal data.

        Returns:
            The DataPackage object with NaN values handled in its components.
        """
        # Handle bulk data
        if data.multi_bulk:
            for key, df in data.multi_bulk.items():
                data.multi_bulk[key] = df.dropna(axis=1)

        # Handle annotation data
        if data.annotation is not None:
            data.annotation = self._fill_annotation(data.annotation)  # type: ignore

        # Handle single-cell data
        if data.multi_sc is not None:
            if self.config.requires_paired:
                # Paired data lives in one MuData under the "multi_sc" key.
                mudata = data.multi_sc["multi_sc"]
                for mod_name, mod_data in list(mudata.mod.items()):
                    mudata.mod[mod_name] = self._process_modality(adata=mod_data)
            else:
                data.multi_sc = self._process_unpaired(data.multi_sc)

        # Handle from_modality and to_modality (for translation cases)
        for direction in ("from_modality", "to_modality"):
            modality_dict = getattr(data, direction)
            if not modality_dict:
                continue
            self._process_translation_modalities(modality_dict, direction)

        return data