Coverage for src / autoencodix / data / _nanremover.py: 13%
105 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
import logging
import warnings

import anndata as ad  # type: ignore
import mudata as md  # type: ignore
import numpy as np
import pandas as pd
from pandas.api.types import is_numeric_dtype, is_categorical_dtype
from scipy import sparse  # type: ignore
from scipy.sparse import issparse  # type: ignore

from autoencodix.configs.default_config import DefaultConfig
from autoencodix.data.datapackage import DataPackage
class NaNRemover:
    """Cleans NaN values out of the components of a multi-modal DataPackage.

    Different components are cleaned in different ways:

    * bulk DataFrames and DataFrame modalities: every column that contains a
      NaN is dropped,
    * annotation DataFrames: NaNs in the non-numeric ``relevant_cols`` are
      replaced with the label ``"missing"``,
    * AnnData objects (stand-alone or inside MuData): NaNs in ``X`` and in all
      layers become ``0``; NaNs in the ``relevant_cols`` of ``obs`` become
      ``0`` (numeric columns) or ``"missing"`` (all other columns, which are
      converted to categorical).

    Attributes:
        config: Configuration object containing settings for data processing.
        relevant_cols: List of columns in metadata to check for NaNs.
    """

    def __init__(
        self,
        config: DefaultConfig,
    ):
        """Initialize the NaNRemover with configuration settings.

        Args:
            config: Configuration object containing settings for data
                processing. ``config.data_config.annotation_columns`` is
                stored as ``relevant_cols``.
        """
        self.config = config
        self.relevant_cols = self.config.data_config.annotation_columns

    @staticmethod
    def _nan_to_zero(matrix):
        """Return ``matrix`` with NaN entries replaced by zero.

        Sparse matrices are modified in place (and the very same object is
        returned); dense input yields a new array.

        NOTE: ``np.nan_to_num`` with default arguments also maps +/-inf to
        the largest/smallest finite value of the dtype.

        Args:
            matrix: A scipy sparse matrix or anything ``np.nan_to_num`` accepts.

        Returns:
            The cleaned matrix.
        """
        if sparse.issparse(matrix):
            # Only formats that expose their stored values via ``.data`` are
            # touched; other sparse formats pass through unchanged.
            if hasattr(matrix, "data"):
                matrix.data = np.nan_to_num(matrix.data, nan=0.0)
                # The former NaNs are now explicit zeros; drop them so the
                # sparsity structure stays minimal.
                matrix.eliminate_zeros()
            return matrix
        return np.nan_to_num(matrix, nan=0.0)

    @staticmethod
    def _fill_missing_category(series: pd.Series) -> pd.Series:
        """Fill the NaNs of a categorical series with the label ``"missing"``.

        Args:
            series: A Series with categorical dtype.

        Returns:
            A Series whose NaNs are replaced by ``"missing"``.
        """
        # fillna on a categorical raises unless the fill value is already a
        # registered category, so register it first.
        if "missing" not in series.cat.categories:
            series = series.cat.add_categories(["missing"])
        return series.fillna("missing")

    def _process_modality(self, adata: ad.AnnData) -> ad.AnnData:
        """Converts NaN values in AnnData object to zero and metadata NaNs to 'missing'.

        Args:
            adata: The AnnData object to process. It is copied first, so the
                caller's object is left untouched.

        Returns:
            The processed AnnData object with NaN values replaced.
        """
        adata = adata.copy()

        # Handle X matrix. Sparse input is cleaned in place, so only write
        # back when a new (dense) object was produced.
        x_matrix = adata.X
        cleaned_x = self._nan_to_zero(x_matrix)
        if cleaned_x is not x_matrix:
            adata.X = cleaned_x

        # Handle all layers. Iterate over a snapshot because entries may be
        # reassigned inside the loop.
        for layer_name, layer_data in list(adata.layers.items()):
            cleaned_layer = self._nan_to_zero(layer_data)
            if cleaned_layer is not layer_data:
                adata.layers[layer_name] = cleaned_layer

        # Handle obs metadata
        if not self.relevant_cols:
            return adata

        for col in self.relevant_cols:
            if col not in adata.obs.columns:
                warnings.warn(f"Column {col} not found in obs.")
                continue
            s = adata.obs[col]

            if is_numeric_dtype(s):
                adata.obs[col] = np.nan_to_num(s, nan=0.0)
                continue
            # isinstance check replaces the deprecated
            # pandas.api.types.is_categorical_dtype.
            if not isinstance(s.dtype, pd.CategoricalDtype):
                s = s.astype("category")
            adata.obs[col] = self._fill_missing_category(s)

        return adata

    def remove_nan(self, data: DataPackage) -> DataPackage:
        """Cleans NaN values from all applicable DataPackage components.

        Iterates through the bulk data, annotation data, single-cell data
        (``multi_sc``) and the ``from_modality`` / ``to_modality`` entries of
        the provided DataPackage. DataFrames lose every column containing a
        NaN; annotation columns and AnnData contents have their NaNs replaced
        (see the class docstring). The DataPackage is modified in place.

        Args:
            data: The DataPackage object containing multi-modal data.

        Returns:
            The same DataPackage object with NaN values cleaned from its
            components.
        """
        # Handle bulk data: drop every column that contains a NaN.
        if data.multi_bulk:
            for key, df in data.multi_bulk.items():
                data.multi_bulk[key] = df.dropna(axis=1)

        # Handle annotation data. Entries that are None are dropped.
        if data.annotation is not None:
            non_na = {}
            for key, frame in data.annotation.items():
                if frame is None:
                    continue
                if self.relevant_cols is not None:
                    for col in self.relevant_cols:
                        # Only non-numeric columns are filled with "missing".
                        if col not in frame.columns or is_numeric_dtype(frame[col]):
                            continue
                        column = frame[col]
                        if isinstance(column.dtype, pd.CategoricalDtype):
                            # A plain fillna would raise for an unregistered
                            # category.
                            frame[col] = self._fill_missing_category(column)
                        else:
                            frame[col] = column.fillna("missing")
                non_na[key] = frame
            data.annotation = non_na  # type: ignore

        # Handle MuData in multi_sc
        if data.multi_sc is not None and self.config.requires_paired:
            mudata = data.multi_sc["multi_sc"]
            # Process each modality; snapshot the items because the mapping
            # is written to inside the loop.
            for mod_name, mod_data in list(mudata.mod.items()):
                mudata.mod[mod_name] = self._process_modality(adata=mod_data)

        elif data.multi_sc is not None:
            logging.getLogger(__name__).debug("data in multi_sc: %s", data.multi_sc)
            # Pre-populate so the original key order is kept; placeholders
            # that never get a value are filtered out below.
            processed = {k: None for k in data.multi_sc}

            for k, v in data.multi_sc.items():
                if isinstance(v, dict):
                    # A dict of AnnData: each entry becomes its own
                    # single-modality MuData, stored under the inner key.
                    for sub_k, sub_v in v.items():
                        processed[sub_k] = md.MuData(
                            {sub_k: self._process_modality(adata=sub_v)}
                        )
                else:
                    # Keep *all* modalities of the MuData together instead of
                    # overwriting the entry once per modality (which kept
                    # only the last one).
                    cleaned_mods = {
                        modkey: self._process_modality(adata=adata)
                        for modkey, adata in v.mod.items()
                    }
                    if cleaned_mods:
                        processed[k] = md.MuData(cleaned_mods)
            # Truthiness filter removes the None placeholders.
            # NOTE(review): a container with zero observations may also be
            # falsy and would then be dropped here as well - confirm intent.
            data.multi_sc = {k: v for k, v in processed.items() if v}

        # Handle from_modality and to_modality (for translation cases)
        for direction in ["from_modality", "to_modality"]:
            modality_dict = getattr(data, direction)
            if not modality_dict:
                continue

            for mod_key, mod_value in list(modality_dict.items()):
                # Handle MuData objects
                if isinstance(mod_value, md.MuData):
                    # Process each modality in the MuData
                    for inner_name, inner_data in list(mod_value.mod.items()):
                        mod_value.mod[inner_name] = self._process_modality(inner_data)

                    # Ensure cell alignment if there are multiple modalities
                    if len(mod_value.mod) > 1:
                        shared = set.intersection(
                            *(set(mod.obs_names) for mod in mod_value.mod.values())
                        )
                        # Keep the container's own cell order: iterating the
                        # set directly would give a hash-dependent order that
                        # can differ between runs.
                        common_cells = [
                            name for name in mod_value.obs_names if name in shared
                        ]
                        mod_value = mod_value[common_cells]

                    modality_dict[mod_key] = mod_value

                # Handle AnnData objects directly
                elif isinstance(mod_value, ad.AnnData):
                    modality_dict[mod_key] = self._process_modality(mod_value)

                # Handle dictionaries of AnnData objects; other value types
                # inside the dict are left as they are.
                elif isinstance(mod_value, dict):
                    for sub_key, sub_value in list(mod_value.items()):
                        if isinstance(sub_value, ad.AnnData):
                            mod_value[sub_key] = self._process_modality(sub_value)

                elif isinstance(mod_value, pd.DataFrame):
                    # Same treatment as bulk data; a new frame is stored so
                    # the caller's DataFrame object is not mutated.
                    modality_dict[mod_key] = mod_value.dropna(axis=1)

                else:
                    warnings.warn(
                        f"Skipping unknown type in {direction}.{mod_key}: {type(mod_value)}"
                    )

        return data