Coverage for src/driada/dim_reduction/dr_base.py: 74.68%

79 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-25 15:40 +0300

1from pynndescent.distances import named_distances 

2import sys 

3from typing import Dict, Optional, Any 

4 

class DRMethod:
    """Static description of a dimensionality reduction method.

    Each instance bundles the method's capability flags together with the
    default parameter dicts that callers merge with user input.

    Attributes
    ----------
    is_linear : bool or int
        Truthy if the method is linear.
    requires_graph : bool or int
        Truthy if the method needs a proximity graph.
    requires_distmat : bool or int
        Truthy if the method needs a distance matrix.
    nn_based : bool or int
        Truthy if the method is neural-network based.
    default_params : dict
        Default embedding parameters; an empty dict when none were supplied.
    default_graph_params : dict or None
        Default graph-construction parameters, or None.
    default_metric_params : dict or None
        Default distance-metric parameters, or None.
    """

    def __init__(self, is_linear, requires_graph, requires_distmat, nn_based,
                 default_params=None, default_graph_params=None,
                 default_metric_params=None):
        # Capability flags are stored exactly as passed (no bool() coercion).
        self.is_linear = is_linear
        self.requires_graph = requires_graph
        self.requires_distmat = requires_distmat
        self.nn_based = nn_based

        # Any falsy default_params (None or an empty dict) becomes a new {};
        # a non-empty dict is stored by reference, not copied.
        self.default_params = default_params if default_params else {}
        # Graph / metric defaults are kept as given (None when absent).
        self.default_graph_params = default_graph_params
        self.default_metric_params = default_metric_params

35 

36 

# Default graph-construction parameters shared by the graph-based methods.
# 'g_method_name' selects the graph builder; 'nn' is the neighbour count
# kept for the 'knn' builder.
DEFAULT_KNN_GRAPH = dict(
    g_method_name='knn',
    nn=15,
    weighted=0,
    max_deleted_nodes=0.2,
    dist_to_aff='hk',
)

# Default distance-metric parameters shared by the graph-based methods.
DEFAULT_METRIC = dict(
    metric_name='l2',
    sigma=1.0,
)

50 

# Registry mapping each supported method name to its DRMethod descriptor.
# Entries that pass DEFAULT_KNN_GRAPH / DEFAULT_METRIC directly share those
# dicts by reference; entries using {**DEFAULT_KNN_GRAPH, 'nn': k} get their
# own copy with the neighbour count overridden.
METHODS_DICT = {
    'pca': DRMethod(
        is_linear=1, requires_graph=0, requires_distmat=0, nn_based=0,
        default_params={'dim': 2},
    ),
    'le': DRMethod(
        is_linear=0, requires_graph=1, requires_distmat=0, nn_based=0,
        default_params={'dim': 2},
        default_graph_params=DEFAULT_KNN_GRAPH,
        default_metric_params=DEFAULT_METRIC,
    ),
    'auto_le': DRMethod(
        is_linear=0, requires_graph=1, requires_distmat=0, nn_based=0,
        default_params={'dim': 2},
        default_graph_params=DEFAULT_KNN_GRAPH,
        default_metric_params=DEFAULT_METRIC,
    ),
    'dmaps': DRMethod(
        is_linear=0, requires_graph=1, requires_distmat=0, nn_based=0,
        default_params={'dim': 2, 'dm_alpha': 0.5},
        default_graph_params=DEFAULT_KNN_GRAPH,
        default_metric_params=DEFAULT_METRIC,
    ),
    'auto_dmaps': DRMethod(
        is_linear=0, requires_graph=1, requires_distmat=0, nn_based=0,
        default_params={'dim': 2, 'dm_alpha': 0.5},
        default_graph_params=DEFAULT_KNN_GRAPH,
        default_metric_params=DEFAULT_METRIC,
    ),
    'mds': DRMethod(
        is_linear=0, requires_graph=0, requires_distmat=1, nn_based=0,
        default_params={'dim': 2},
    ),
    'isomap': DRMethod(
        is_linear=0, requires_graph=1, requires_distmat=0, nn_based=0,
        default_params={'dim': 2},
        # nn=15 equals the shared default, but the spread yields an independent copy.
        default_graph_params={**DEFAULT_KNN_GRAPH, 'nn': 15},
        default_metric_params=DEFAULT_METRIC,
    ),
    'lle': DRMethod(
        is_linear=0, requires_graph=1, requires_distmat=0, nn_based=0,
        default_params={'dim': 2},
        default_graph_params={**DEFAULT_KNN_GRAPH, 'nn': 10},
        default_metric_params=DEFAULT_METRIC,
    ),
    'hlle': DRMethod(
        is_linear=0, requires_graph=1, requires_distmat=0, nn_based=0,
        default_params={'dim': 2},
        default_graph_params={**DEFAULT_KNN_GRAPH, 'nn': 10},
        default_metric_params=DEFAULT_METRIC,
    ),
    'mvu': DRMethod(
        is_linear=0, requires_graph=1, requires_distmat=0, nn_based=0,
        default_params={'dim': 2},
        default_graph_params=DEFAULT_KNN_GRAPH,
        default_metric_params=DEFAULT_METRIC,
    ),
    'ae': DRMethod(
        is_linear=0, requires_graph=0, requires_distmat=0, nn_based=1,
        default_params={'dim': 2},
    ),
    'vae': DRMethod(
        is_linear=0, requires_graph=0, requires_distmat=0, nn_based=1,
        default_params={'dim': 2},
    ),
    'tsne': DRMethod(
        is_linear=0, requires_graph=0, requires_distmat=0, nn_based=0,
        default_params={'dim': 2, 'perplexity': 30},
    ),
    'umap': DRMethod(
        is_linear=0, requires_graph=1, requires_distmat=0, nn_based=0,
        default_params={'dim': 2, 'min_dist': 0.1},
        default_graph_params={**DEFAULT_KNN_GRAPH, 'nn': 15},
        default_metric_params=DEFAULT_METRIC,
    ),
}

127 

# Graph-construction methods; each has a branch in g_param_filter.
GRAPH_CONSTRUCTION_METHODS = [
    'knn',
    'auto_knn',
    'eps',
    'eknn',
    'umap',
    'tsne',
]

# Embedding-construction methods, listed in the same order as METHODS_DICT.
EMBEDDING_CONSTRUCTION_METHODS = [
    'pca', 'le', 'auto_le', 'dmaps', 'auto_dmaps',
    'mds', 'isomap', 'lle', 'hlle', 'mvu',
    'ae', 'vae', 'tsne', 'umap',
]

# TODO: implement random projections

146 

147 

def m_param_filter(para):
    """Prune metric parameters down to those used by the chosen distance metric.

    Parameters
    ----------
    para : dict
        Metric parameters. Must contain 'metric_name'. 'sigma' is optional and
        is kept only when present and not None. When
        ``metric_name == 'minkowski'`` the exponent 'p' must also be present.

    Returns
    -------
    dict
        A new dict holding only 'metric_name', plus 'sigma' and/or 'p' when
        applicable.

    Raises
    ------
    ValueError
        If 'metric_name' is neither a key of pynndescent's ``named_distances``
        nor 'hyperbolic'. ValueError subclasses Exception, so existing
        ``except Exception`` handlers still catch it.
    KeyError
        If 'metric_name' is missing, or 'p' is missing for 'minkowski'.
    """
    name = para['metric_name']
    appr_keys = ['metric_name']

    # A missing 'sigma' is treated the same as sigma=None (dropped), instead
    # of raising KeyError.
    if para.get('sigma') is not None:
        appr_keys.append('sigma')

    # 'hyperbolic' is accepted even though pynndescent does not define it;
    # any other name must be one of pynndescent's named distances.
    if name not in named_distances and name != 'hyperbolic':
        raise ValueError('this custom metric is not implemented!')

    # Minkowski is parameterised by its exponent p.
    if name == 'minkowski':
        appr_keys.append('p')

    return {key: para[key] for key in appr_keys}

170 

171 

def g_param_filter(para):
    """Keep only the graph parameters relevant to the chosen graph method.

    Every method keeps 'g_method_name', 'max_deleted_nodes', 'weighted' and
    'dist_to_aff'. Known methods add their own keys:

    - 'knn', 'auto_knn', 'umap': 'nn'
    - 'eps': 'eps', 'eps_min'
    - 'eknn': 'eps', 'eps_min', 'nn'
    - 'tsne': 'perplexity'

    An unrecognised method keeps only the common keys; no error is raised.

    Parameters
    ----------
    para : dict
        Graph parameters; must contain 'g_method_name' and every key selected
        for that method.

    Returns
    -------
    dict
        New dict with only the selected keys, in the order listed above.

    Raises
    ------
    KeyError
        If 'g_method_name' or any selected key is missing.
    """
    graph_method = para['g_method_name']

    # (methods, extra keys) pairs, checked in order; the first match wins.
    extras_table = (
        (('knn', 'auto_knn', 'umap'), ('nn',)),
        (('eps',), ('eps', 'eps_min')),
        (('eknn',), ('eps', 'eps_min', 'nn')),
        (('tsne',), ('perplexity',)),
    )

    selected = ['g_method_name', 'max_deleted_nodes', 'weighted', 'dist_to_aff']
    for methods, extra in extras_table:
        if graph_method in methods:
            selected.extend(extra)
            break

    return {key: para[key] for key in selected}

193 

194 

def e_param_filter(para):
    """Keep only the embedding parameters relevant to the chosen method.

    Every method keeps 'e_method', 'e_method_name' and 'dim'. In addition,
    'umap' keeps 'min_dist', and 'dmaps' / 'auto_dmaps' keep 'dm_alpha'.

    Parameters
    ----------
    para : dict
        Embedding parameters; must contain 'e_method_name' and every key
        selected for that method.

    Returns
    -------
    dict
        New dict with only the selected keys, common keys first.

    Raises
    ------
    KeyError
        If 'e_method_name' or any selected key is missing.
    """
    method_name = para['e_method_name']

    # NOTE(review): the 'tsne' registry entry declares a default 'perplexity',
    # but it is not kept here -- confirm that dropping it is intended.
    if method_name == 'umap':
        extra_keys = ('min_dist',)
    elif method_name in ('dmaps', 'auto_dmaps'):
        extra_keys = ('dm_alpha',)
    else:
        extra_keys = ()

    wanted = ('e_method', 'e_method_name', 'dim') + extra_keys
    return {key: para[key] for key in wanted}

210 

211 

def merge_params_with_defaults(
    method_name: str,
    user_params: Optional[Dict[str, Any]] = None,
) -> Dict[str, Optional[Dict[str, Any]]]:
    """Merge user parameters with method defaults.

    Parameters
    ----------
    method_name : str
        Name of the DR method; must be a key of ``METHODS_DICT``.
    user_params : dict or None
        User-provided parameters. May contain the structured keys 'e_params',
        'g_params' and 'm_params' (each a dict of overrides) and/or flat
        parameter values. Flat keys are routed as follows:

        - 'n_neighbors' -> g_params['nn']
        - 'max_deleted_nodes' -> g_params['max_deleted_nodes']
        - 'metric' -> m_params['metric_name']
        - 'sigma' -> m_params['sigma']
        - every other key, or any of the above when the method has no
          graph / metric defaults -> e_params

        Flat keys are applied first and structured sub-dicts afterwards, so
        an explicit sub-dict wins on conflict. Structured 'g_params' /
        'm_params' are ignored for methods without graph / metric defaults.

    Returns
    -------
    dict
        ``{'e_params': dict, 'g_params': dict or None, 'm_params': dict or None}``.
        'e_params' always contains 'e_method_name' and 'e_method'. Each
        returned dict is a new (shallow) copy; the method defaults are never
        mutated.

    Raises
    ------
    ValueError
        If ``method_name`` is not a known method.
    """
    if method_name not in METHODS_DICT:
        raise ValueError(f"Unknown method: {method_name}")

    method = METHODS_DICT[method_name]

    # Start from copies of the defaults so shared default dicts stay untouched.
    e_params = method.default_params.copy()
    e_params['e_method_name'] = method_name
    e_params['e_method'] = method

    g_params = method.default_graph_params.copy() if method.default_graph_params else None
    m_params = method.default_metric_params.copy() if method.default_metric_params else None

    if user_params is None:
        return {'e_params': e_params, 'g_params': g_params, 'm_params': m_params}

    structured_keys = ('e_params', 'g_params', 'm_params')
    # Flat user-facing names that belong in the graph / metric dicts.
    graph_aliases = {'n_neighbors': 'nn', 'max_deleted_nodes': 'max_deleted_nodes'}
    metric_aliases = {'metric': 'metric_name', 'sigma': 'sigma'}

    # Flat keys are honoured even when structured keys are also present, so
    # mixed input no longer silently drops them.
    for key, value in user_params.items():
        if key in structured_keys:
            continue
        if key in graph_aliases and g_params is not None:
            g_params[graph_aliases[key]] = value
        elif key in metric_aliases and m_params is not None:
            m_params[metric_aliases[key]] = value
        else:
            e_params[key] = value

    # Structured overrides are applied last so explicit sub-dicts take precedence.
    for section, target in (('e_params', e_params), ('g_params', g_params), ('m_params', m_params)):
        overrides = user_params.get(section)
        if overrides and target is not None:
            target.update(overrides)

    # A user override may have set e_method to None; restore the descriptor.
    if e_params.get('e_method') is None:
        e_params['e_method'] = method

    return {'e_params': e_params, 'g_params': g_params, 'm_params': m_params}