Coverage for src / autoencodix / data / _filter.py: 82%

136 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1""" 

2Ignoring types for scipy and sklearn, because their stubs require Python > 3.9, and we want to allow older versions 

3""" 

4 

5import pandas as pd 

6import warnings 

7import numpy as np 

8from typing import Optional, List, Set, Tuple, Any, Union 

9from scipy.stats import median_abs_deviation # type: ignore 

10from sklearn.preprocessing import ( # type: ignore 

11 MinMaxScaler, 

12 StandardScaler, 

13 RobustScaler, 

14 MaxAbsScaler, 

15) # type: ignore 

16from enum import Enum 

17from autoencodix.configs.default_config import DataInfo, DefaultConfig 

18from sklearn.cluster import AgglomerativeClustering # type: ignore 

19from scipy.spatial.distance import pdist, squareform # type: ignore 

20 

21 

class FilterMethod(Enum):
    """Names of the feature-filtering strategies understood by `DataFilter`.

    Each member's value is the plain string used to select it, so a member
    can be obtained from that string via ``FilterMethod("VAR")``.

    Attributes:
        VARCORR: Filter by variance first, then by correlation.
        NOFILT: No filtering.
        VAR: Filter by variance.
        MAD: Filter by median absolute deviation.
        NONZEROVAR: Drop features whose variance is zero.
        CORR: Filter by correlation.
    """

    # Member order is part of the observable behavior (iteration order).
    VARCORR = "VARCORR"
    NOFILT = "NOFILT"
    VAR = "VAR"
    MAD = "MAD"
    NONZEROVAR = "NONZEROVAR"
    CORR = "CORR"

40 

41 

class DataFilter:
    """Preprocesses dataframes, including filtering and scaling.

    This class separates the filtering logic that needs to be applied consistently
    across train, validation, and test sets from the scaling logic that is
    typically fitted on the training data and then applied to the other sets.

    Attributes:
        data_info: Configuration object containing preprocessing parameters.
        filtered_features: Set of features to keep after filtering on the training data. None initially.
        _scaler: The scaler object selected by the configuration, or None.
        ontologies: Ontology information, if provided for Ontix.
        config: Configuration object containing default parameters.
        method: Name of the scaling method resolved by `_init_scaler`.
    """

    # Shared by every top-k filter so that they all report the same condition.
    _KEEP_ALL_WARNING = (
        "WARNING: k is None or greater than number of columns, keeping all features."
    )

    def __init__(
        self,
        data_info: "DataInfo",
        config: "DefaultConfig",
        ontologies: Optional[tuple] = None,
    ):  # Addition to Varix, mandatory for Ontix
        """Initializes the DataFilter with a configuration.

        Args:
            data_info: Configuration object containing preprocessing parameters.
            config: Configuration object containing default parameters.
            ontologies: Ontology information, if provided for Ontix.
        """
        self.data_info = data_info
        self.config = config
        self.filtered_features: Optional[Set[str]] = None
        self._scaler: Any = None
        self.ontologies = ontologies  # Addition to Varix, mandatory for Ontix
        self._init_scaler()

    def _keeps_all_features(self, df: pd.DataFrame, k: Optional[int]) -> bool:
        """Reports whether a top-k selection would keep every column of `df`.

        Emits the shared "keeping all features" warning when that is the case.

        Args:
            df: Input dataframe.
            k: Requested number of features, or None.

        Returns:
            True if `k` is None or larger than the number of columns.
        """
        if k is None or k > df.shape[1]:
            warnings.warn(self._KEEP_ALL_WARNING)
            return True
        return False

    def _filter_nonzero_variance(self, df: pd.DataFrame) -> pd.DataFrame:
        """Removes features with zero variance.

        Args:
            df: Input dataframe.

        Returns:
            Filtered dataframe containing only columns with non-zero variance.
        """
        variances = pd.Series(np.var(df, axis=0), index=df.columns)
        return df[variances[variances > 0].index]

    def _filter_by_variance(
        self, df: pd.DataFrame, k: Optional[int]
    ) -> Union[pd.Series, pd.DataFrame]:
        """Keeps top k features by variance.

        Args:
            df: Input dataframe.
            k: Number of top variance features to keep. If None or greater
                than number of columns, all features are kept (with a warning).

        Returns:
            Filtered dataframe with top k variance features, ordered by
            decreasing variance.
        """
        if self._keeps_all_features(df, k):
            return df
        variances = pd.Series(np.var(df, axis=0), index=df.columns)
        return df[variances.sort_values(ascending=False).index[:k]]

    def _filter_by_mad(
        self, df: pd.DataFrame, k: Optional[int]
    ) -> Union[pd.Series, pd.DataFrame]:
        """Keeps top k features by median absolute deviation.

        Args:
            df: Input dataframe.
            k: Number of top MAD features to keep. If None or greater
                than number of columns, all features are kept (with a warning,
                consistent with the variance and correlation filters).

        Returns:
            Filtered dataframe with top k MAD features, ordered by
            decreasing MAD.
        """
        if self._keeps_all_features(df, k):
            return df
        mads = pd.Series(median_abs_deviation(df, axis=0), index=df.columns)
        return df[mads.sort_values(ascending=False).index[:k]]

    def _filter_by_correlation(
        self, df: pd.DataFrame, k: Optional[int]
    ) -> Union[pd.Series, pd.DataFrame]:
        """Filters features using correlation-based clustering.

        This method clusters features based on their correlation distance and
        selects a representative feature (medoid) from each cluster.

        Args:
            df: Input dataframe.
            k: Number of clusters to create. If None or greater
                than number of columns, all features are kept (with a warning).

        Returns:
            Filtered dataframe with one representative feature (medoid) per cluster.
        """
        if self._keeps_all_features(df, k):
            return df

        # Features become rows so that pairwise distances are feature-to-feature.
        features = df.transpose().values
        dist_matrix = squareform(pdist(features, metric="correlation"))

        # NOTE(review): the distance matrix is handed to AgglomerativeClustering
        # with its default settings, i.e. it is not declared as a precomputed
        # distance matrix. Confirm that this is intended before changing it;
        # doing so would alter which features are selected.
        clustering = AgglomerativeClustering(
            n_clusters=k,
        ).fit(dist_matrix)

        medoid_indices = []
        for label in range(k):
            members = np.where(clustering.labels_ == label)[0]
            if len(members) == 0:
                continue
            # The medoid is the point with minimum sum of distances to the
            # other points in the cluster.
            within_cluster = dist_matrix[np.ix_(members, members)]
            medoid_indices.append(members[np.argmin(within_cluster.sum(axis=1))])

        df_filt: Union[pd.DataFrame, pd.Series] = df.iloc[:, medoid_indices]
        return df_filt

    def filter(
        self, df: pd.DataFrame, genes_to_keep: Optional[List] = None
    ) -> Tuple[Union[pd.Series, pd.DataFrame], List[str]]:
        """Applies the configured filtering method to the dataframe.

        This method is intended to be called on the training data to determine
        which features to keep.

        Args:
            df: Input dataframe to be filtered (typically the training set).
            genes_to_keep: A list of gene names to explicitly keep.
                If provided, other filtering methods will be ignored.

        Returns:
            A tuple containing:
                - The filtered dataframe.
                - A list of column names (features) that were kept. It always
                  matches the columns of the returned dataframe.

        Raises:
            KeyError: If some genes in `genes_to_keep` are not present in the dataframe.
            ValueError: If the configured filtering method is not a valid
                `FilterMethod` value.
        """
        if genes_to_keep is not None:
            try:
                selected = df[genes_to_keep]
            except KeyError as e:
                # Chain the original error so the pandas traceback is preserved.
                raise KeyError(
                    f"Some genes in genes_to_keep are not present in the dataframe: {e}"
                ) from e
            return selected, genes_to_keep

        MIN_FILTER = 2
        filtering_method = FilterMethod(self.data_info.filtering)

        if df.shape[0] < MIN_FILTER or df.empty:
            warnings.warn(
                f"WARNING: df is too small for filtering, needs to have at least {MIN_FILTER}"
            )
            return df, df.columns.tolist()

        filtered_df = df.copy()

        ## Remove features which are not in the ontology for Ontix architecture
        ## must be done before other filtering is applied
        ontologies = getattr(self, "ontologies", None)
        if ontologies is not None:
            # Only the last ontology level is consulted. A set keeps the
            # membership tests below O(1) per feature.
            ontology_features: Set = set()
            for values in ontologies[-1].values():
                ontology_features.update(values)

            feature_order = []
            missing_features = []
            for feature in filtered_df.columns.tolist():
                if feature in ontology_features:
                    feature_order.append(feature)
                else:
                    missing_features.append(feature)

            if missing_features:
                print(
                    f"Features in feature_order not found in all_feature_names: {missing_features}"
                )

            filtered_df = filtered_df.loc[:, feature_order]

        k_filter = self.data_info.k_filter
        # NOTE(review): a missing k_filter also skips NONZEROVAR, although that
        # method does not use k. Kept as-is to preserve existing behavior.
        if filtering_method == FilterMethod.NOFILT or k_filter is None:
            # Report the columns of the dataframe actually returned; the
            # ontology step above may already have dropped some of them.
            return filtered_df, filtered_df.columns.tolist()

        if filtering_method == FilterMethod.NONZEROVAR:
            filtered_df = self._filter_nonzero_variance(filtered_df)
        elif filtering_method == FilterMethod.VAR:
            filtered_df = self._filter_nonzero_variance(filtered_df)
            filtered_df = self._filter_by_variance(filtered_df, k_filter)
        elif filtering_method == FilterMethod.MAD:
            filtered_df = self._filter_nonzero_variance(filtered_df)
            filtered_df = self._filter_by_mad(filtered_df, k_filter)
        elif filtering_method == FilterMethod.CORR:
            filtered_df = self._filter_nonzero_variance(filtered_df)
            filtered_df = self._filter_by_correlation(filtered_df, k_filter)
        elif filtering_method == FilterMethod.VARCORR:
            filtered_df = self._filter_nonzero_variance(filtered_df)
            # Pre-select ten times as many features by variance, then let the
            # correlation clustering reduce them to k_filter.
            filtered_df = self._filter_by_variance(
                filtered_df,
                k_filter * 10 if k_filter else None,
            )
            # Apply correlation filter on the already variance-filtered data
            k_corr = min(k_filter, filtered_df.shape[1])
            filtered_df = self._filter_by_correlation(filtered_df, k_corr)

        return filtered_df, filtered_df.columns.tolist()

    def _init_scaler(self) -> None:
        """Initializes the scaler based on the configured scaling method.

        Sets `self.method` to the resolved method name and `self._scaler` to a
        new, unfitted scaler, or to None for any other method name.
        """
        self.method = self.data_info.scaling

        if self.method == "NOTSET":
            # if not set in data config, we use the global scaling config
            self.method = self.config.scaling

        if self.method == "MINMAX":
            self._scaler = MinMaxScaler(clip=True)
        elif self.method == "STANDARD":
            self._scaler = StandardScaler()
        elif self.method == "ROBUST":
            self._scaler = RobustScaler()
        elif self.method == "MAXABS":
            self._scaler = MaxAbsScaler()
        else:
            self._scaler = None

    def fit_scaler(self, df: Union[pd.Series, pd.DataFrame]) -> Any:
        """Fits the scaler to the input dataframe (typically the training set).

        The scaler is re-created from the configuration before fitting, so any
        previously fitted state is discarded.

        Args:
            df: Input dataframe to fit the scaler on.

        Returns:
            The fitted scaler object, or None if no scaler is configured.
        """
        self._init_scaler()
        if self._scaler is not None:
            self._scaler.fit(df)
        else:
            warnings.warn("No scaling applied.")
        return self._scaler

    def scale(
        self, df: Union[pd.Series, pd.DataFrame], scaler: Any
    ) -> Union[pd.Series, pd.DataFrame]:
        """Applies the fitted scaler to the input dataframe.

        Args:
            df: Input dataframe to be scaled.
            scaler: The fitted scaler object. Ignored for the "LOG1P" method.

        Returns:
            Scaled dataframe, or `df` itself (unchanged, with a warning) when
            no scaler is given and the method is not "LOG1P".
        """
        if self.method == "LOG1P":
            X_log = np.log1p(df.values)
            # NOTE(review): the denominator is log1p of the per-column maximum
            # of the already log-transformed values, and it is computed from
            # `df` itself rather than from the data the scaler was fitted on.
            # Confirm this is intended; kept unchanged to preserve results.
            X_norm = X_log / np.log1p(np.max(X_log, axis=0))
            return pd.DataFrame(X_norm, columns=df.columns, index=df.index)

        if scaler is None:
            warnings.warn("No scaler has been fitted yet or scaling is set to none.")
            return df

        return pd.DataFrame(scaler.transform(df), columns=df.columns, index=df.index)

    @property
    def available_methods(self) -> List[str]:
        """Lists all available filtering methods.

        Returns:
            List of available filtering method names.
        """
        return [method.value for method in FilterMethod]