Coverage for src/driada/dim_reduction/dr_base.py: 74.68%
79 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-25 15:40 +0300
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-25 15:40 +0300
1from pynndescent.distances import named_distances
2import sys
3from typing import Dict, Optional, Any
class DRMethod(object):
    """Static description of one dimensionality reduction method.

    An instance records which inputs the method needs (a proximity graph,
    a distance matrix) and the default parameter dictionaries that are used
    when the caller does not provide its own values.

    Attributes
    ----------
    is_linear : bool or int
        Truthy if the method is linear.
    requires_graph : bool or int
        Truthy if the method needs a proximity graph.
    requires_distmat : bool or int
        Truthy if the method needs a distance matrix.
    nn_based : bool or int
        Truthy if the method is neural-network based.
    default_params : dict
        Default embedding parameters; an empty dict when none (or an empty
        dict) was supplied.
    default_graph_params : dict or None
        Default graph construction parameters, stored as passed (not copied).
    default_metric_params : dict or None
        Default metric parameters, stored as passed (not copied).
    """

    def __init__(self, is_linear, requires_graph, requires_distmat, nn_based,
                 default_params=None, default_graph_params=None, default_metric_params=None):
        # Capability flags are kept exactly as given (the registry passes 0/1).
        self.is_linear, self.requires_graph = is_linear, requires_graph
        self.requires_distmat, self.nn_based = requires_distmat, nn_based

        # Any falsy value (None or an empty dict) is replaced by a fresh {}.
        self.default_params = default_params if default_params else {}

        self.default_graph_params = default_graph_params
        self.default_metric_params = default_metric_params
# Default graph construction parameters shared by the graph-based entries
# of METHODS_DICT: a 'knn' graph with 15 neighbours.
DEFAULT_KNN_GRAPH = dict(
    g_method_name='knn',
    nn=15,
    weighted=0,
    max_deleted_nodes=0.2,
    dist_to_aff='hk',
)

# Default metric parameters paired with DEFAULT_KNN_GRAPH.
DEFAULT_METRIC = dict(
    metric_name='l2',
    sigma=1.0,
)
# Registry of supported dimensionality reduction methods, keyed by short name.
# Capability flags are 0/1 integers. Graph-based entries either share the
# DEFAULT_KNN_GRAPH / DEFAULT_METRIC objects directly, or get their own copy
# of DEFAULT_KNN_GRAPH with an explicit neighbour count ('nn').
METHODS_DICT = {
    # Linear method; needs neither a graph nor a distance matrix.
    'pca': DRMethod(
        is_linear=0 + 1, requires_graph=0, requires_distmat=0, nn_based=0,
        default_params={'dim': 2},
    ),
    'le': DRMethod(
        is_linear=0, requires_graph=1, requires_distmat=0, nn_based=0,
        default_params={'dim': 2},
        default_graph_params=DEFAULT_KNN_GRAPH,
        default_metric_params=DEFAULT_METRIC,
    ),
    'auto_le': DRMethod(
        is_linear=0, requires_graph=1, requires_distmat=0, nn_based=0,
        default_params={'dim': 2},
        default_graph_params=DEFAULT_KNN_GRAPH,
        default_metric_params=DEFAULT_METRIC,
    ),
    'dmaps': DRMethod(
        is_linear=0, requires_graph=1, requires_distmat=0, nn_based=0,
        default_params={'dim': 2, 'dm_alpha': 0.5},
        default_graph_params=DEFAULT_KNN_GRAPH,
        default_metric_params=DEFAULT_METRIC,
    ),
    'auto_dmaps': DRMethod(
        is_linear=0, requires_graph=1, requires_distmat=0, nn_based=0,
        default_params={'dim': 2, 'dm_alpha': 0.5},
        default_graph_params=DEFAULT_KNN_GRAPH,
        default_metric_params=DEFAULT_METRIC,
    ),
    # Needs a distance matrix instead of a graph.
    'mds': DRMethod(
        is_linear=0, requires_graph=0, requires_distmat=1, nn_based=0,
        default_params={'dim': 2},
    ),
    # Own copy of the kNN defaults (same values, distinct dict object).
    'isomap': DRMethod(
        is_linear=0, requires_graph=1, requires_distmat=0, nn_based=0,
        default_params={'dim': 2},
        default_graph_params=dict(DEFAULT_KNN_GRAPH, nn=15),
        default_metric_params=DEFAULT_METRIC,
    ),
    # kNN defaults with 10 neighbours instead of 15.
    'lle': DRMethod(
        is_linear=0, requires_graph=1, requires_distmat=0, nn_based=0,
        default_params={'dim': 2},
        default_graph_params=dict(DEFAULT_KNN_GRAPH, nn=10),
        default_metric_params=DEFAULT_METRIC,
    ),
    'hlle': DRMethod(
        is_linear=0, requires_graph=1, requires_distmat=0, nn_based=0,
        default_params={'dim': 2},
        default_graph_params=dict(DEFAULT_KNN_GRAPH, nn=10),
        default_metric_params=DEFAULT_METRIC,
    ),
    'mvu': DRMethod(
        is_linear=0, requires_graph=1, requires_distmat=0, nn_based=0,
        default_params={'dim': 2},
        default_graph_params=DEFAULT_KNN_GRAPH,
        default_metric_params=DEFAULT_METRIC,
    ),
    # Flagged as neural-network based.
    'ae': DRMethod(
        is_linear=0, requires_graph=0, requires_distmat=0, nn_based=1,
        default_params={'dim': 2},
    ),
    'vae': DRMethod(
        is_linear=0, requires_graph=0, requires_distmat=0, nn_based=1,
        default_params={'dim': 2},
    ),
    # No graph, distance matrix or NN flag; carries a default 'perplexity'.
    'tsne': DRMethod(
        is_linear=0, requires_graph=0, requires_distmat=0, nn_based=0,
        default_params={'dim': 2, 'perplexity': 30},
    ),
    'umap': DRMethod(
        is_linear=0, requires_graph=1, requires_distmat=0, nn_based=0,
        default_params={'dim': 2, 'min_dist': 0.1},
        default_graph_params=dict(DEFAULT_KNN_GRAPH, nn=15),
        default_metric_params=DEFAULT_METRIC,
    ),
}
# Graph construction method names; g_param_filter has a branch for each.
GRAPH_CONSTRUCTION_METHODS = [
    'knn', 'auto_knn', 'eps', 'eknn', 'umap', 'tsne',
]

# Embedding method names: exactly the keys of METHODS_DICT, in registry
# order, so this list cannot drift out of sync with the registry.
EMBEDDING_CONSTRUCTION_METHODS = list(METHODS_DICT)

# TODO: implement random projections
def m_param_filter(para):
    """Prune metric parameters that the chosen metric does not use.

    Parameters
    ----------
    para : dict
        Metric parameters. Must contain 'metric_name'. 'sigma' is kept only
        when present and not None. 'p' is required when the metric is
        'minkowski'.

    Returns
    -------
    dict
        Subset of ``para``: 'metric_name', plus 'sigma' if set, plus 'p'
        for the 'minkowski' metric.

    Raises
    ------
    ValueError
        If 'metric_name' is neither a key of pynndescent's
        ``named_distances`` nor the custom name 'hyperbolic'.
    KeyError
        If 'metric_name' is missing, or 'p' is missing for 'minkowski'.
    """
    name = para['metric_name']

    # 'hyperbolic' is accepted as a custom metric even though it is not one
    # of pynndescent's named distances; every other unknown name is rejected.
    if name not in named_distances and name != 'hyperbolic':
        raise ValueError(f'this custom metric is not implemented: {name!r}')

    appr_keys = ['metric_name']
    # An absent 'sigma' is treated the same as sigma=None (previously a KeyError).
    if para.get('sigma') is not None:
        appr_keys.append('sigma')
    if name == 'minkowski':
        appr_keys.append('p')

    return {key: para[key] for key in appr_keys}
def g_param_filter(para):
    """Prune graph parameters that the chosen graph method does not use.

    Every method keeps 'g_method_name', 'max_deleted_nodes', 'weighted' and
    'dist_to_aff'. The methods 'knn', 'auto_knn', 'umap', 'eps', 'eknn' and
    'tsne' each add their own extra keys; any other method name keeps only
    the common keys. A KeyError is raised if a kept key is missing from
    ``para``.
    """
    common_keys = ('g_method_name', 'max_deleted_nodes', 'weighted', 'dist_to_aff')
    extra_keys = {
        'knn': ('nn',),
        'auto_knn': ('nn',),
        'umap': ('nn',),
        'eps': ('eps', 'eps_min'),
        'eknn': ('eps', 'eps_min', 'nn'),
        'tsne': ('perplexity',),
    }
    wanted = common_keys + extra_keys.get(para['g_method_name'], ())
    return {k: para[k] for k in wanted}
def e_param_filter(para):
    """Prune embedding parameters that the chosen embedding method does not use.

    'e_method', 'e_method_name' and 'dim' are always kept. 'umap' also keeps
    'min_dist', and 'dmaps' / 'auto_dmaps' also keep 'dm_alpha'. A KeyError
    is raised if a kept key is missing from ``para``.

    NOTE(review): METHODS_DICT gives 'tsne' a default 'perplexity', but this
    filter does not retain it for 'tsne' -- confirm whether that is intended.
    """
    method_name = para['e_method_name']
    kept = ['e_method', 'e_method_name', 'dim']
    if method_name == 'umap':
        kept.append('min_dist')
    elif method_name in ('dmaps', 'auto_dmaps'):
        kept.append('dm_alpha')
    return {k: para[k] for k in kept}
def merge_params_with_defaults(
    method_name: str,
    user_params: Optional[Dict[str, Any]] = None,
) -> Dict[str, Optional[Dict[str, Any]]]:
    """Merge user parameters with method defaults.

    Parameters
    ----------
    method_name : str
        Name of the DR method; must be a key of METHODS_DICT.
    user_params : dict or None
        User-provided parameters, in one of two formats:

        * structured -- contains at least one of 'e_params', 'g_params',
          'm_params'. Each non-empty value is merged over the matching
          defaults. Any other top-level keys are ignored in this format.
        * flat -- any other dict. 'n_neighbors' (stored as 'nn') and
          'max_deleted_nodes' go to the graph params. 'metric' (stored as
          'metric_name') and 'sigma' go to the metric params. This routing
          applies only when the method has those defaults; otherwise, and
          for every other key, the value goes to the embedding params.

    Returns
    -------
    dict
        ``{'e_params': dict, 'g_params': dict or None, 'm_params': dict or None}``.
        'g_params' / 'm_params' are None when the method has no graph / metric
        defaults; user-supplied structured graph / metric params are dropped in
        that case. The returned dicts are shallow copies, so the METHODS_DICT
        defaults are never mutated.

    Raises
    ------
    ValueError
        If ``method_name`` is not a known method.
    """
    if method_name not in METHODS_DICT:
        raise ValueError(f"Unknown method: {method_name}")

    method = METHODS_DICT[method_name]

    # Start from copies of the registry defaults.
    e_params = method.default_params.copy()
    e_params['e_method_name'] = method_name
    e_params['e_method'] = method
    g_params = method.default_graph_params.copy() if method.default_graph_params else None
    m_params = method.default_metric_params.copy() if method.default_metric_params else None

    if user_params is None:
        return {'e_params': e_params, 'g_params': g_params, 'm_params': m_params}

    if any(key in user_params for key in ('e_params', 'g_params', 'm_params')):
        # Structured format: merge each non-empty sub-dict over its defaults.
        if user_params.get('e_params'):
            e_params.update(user_params['e_params'])
        if user_params.get('g_params') and g_params is not None:
            g_params.update(user_params['g_params'])
        if user_params.get('m_params') and m_params is not None:
            m_params.update(user_params['m_params'])
    else:
        # Flat format: route known graph/metric keys (renaming where needed);
        # everything else is treated as an embedding parameter.
        graph_keys = {'n_neighbors': 'nn', 'max_deleted_nodes': 'max_deleted_nodes'}
        metric_keys = {'metric': 'metric_name', 'sigma': 'sigma'}
        for key, value in user_params.items():
            if key in graph_keys and g_params is not None:
                g_params[graph_keys[key]] = value
            elif key in metric_keys and m_params is not None:
                m_params[metric_keys[key]] = value
            else:
                e_params[key] = value

    # A user may have overridden 'e_method' with None; restore the method object.
    if e_params.get('e_method') is None:
        e_params['e_method'] = method

    return {'e_params': e_params, 'g_params': g_params, 'm_params': m_params}