Coverage for src/driada/dim_reduction/dim_reduction.py: 0.00%
55 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-25 15:40 +0300
1from .dr_base import *
2from .data import MVData
3from .graph import ProximityGraph
4from .embedding import Embedding
def _params_for_jump(all_params, i):
    """Select the ``i``-th value of every per-jump parameter sequence.

    ``all_params`` maps parameter names to indexable sequences holding one
    value per jump; the result maps the same names to the value for jump ``i``.
    """
    return {name: values[i] for name, values in all_params.items()}


def _embed_with_retries(data, m_params, g_params, e_params, n_attempts, jump_no):
    """Call ``data.get_embedding`` up to ``n_attempts`` times.

    Returns the first successful result. If every attempt raises, a
    ``RuntimeError`` is raised and chained (``from``) to the last error so
    the underlying cause is not lost.
    """
    last_error = None
    for attempt in range(1, n_attempts + 1):
        try:
            return data.get_embedding(m_params, g_params, e_params)
        except Exception as err:  # KeyboardInterrupt/SystemExit still propagate
            last_error = err
            print('Jump', jump_no, 'attempt', attempt, 'failed:', repr(err))
    raise RuntimeError(
        'Jump %s failed after %s attempts' % (jump_no, n_attempts)
    ) from last_error


def dr_series(d,
              n_jumps,
              all_metric_params,
              all_graph_params,
              all_embedding_params,
              recalculate_if_error=0,
              max_retries=20):
    """Apply a chain of dimensionality reductions, each starting from the previous result.

    Jump ``i`` (0-based) embeds the current data using the ``i``-th value of
    every entry of ``all_metric_params``, ``all_graph_params`` and
    ``all_embedding_params``. The first jump starts from ``d``. Every later
    jump starts from ``MVData(previous_embedding.coords, labels=d.labels)``.

    Parameters
    ----------
    d : object
        Input of the first jump. It must provide ``dim``, ``labels`` and
        ``get_embedding(m_params, g_params, e_params)``; presumably an
        ``MVData`` instance.
    n_jumps : int
        Number of consecutive reductions. Values below 1 still perform one
        jump, matching the previous behaviour.
    all_metric_params, all_graph_params, all_embedding_params : dict
        Map each parameter name to a sequence with one value per jump.
        ``all_embedding_params`` must contain ``'dim'`` and
        ``'e_method_name'``; the latter is looked up in ``METHODS_DICT``.
    recalculate_if_error : bool or int, default 0
        When truthy, a failing jump is retried. Each jump gets
        ``max_retries * int(recalculate_if_error) + 1`` attempts in total.
    max_retries : int, default 20
        Retry budget used when ``recalculate_if_error`` is set.

    Returns
    -------
    object
        The object returned by ``get_embedding`` for the last jump.

    Raises
    ------
    RuntimeError
        If every attempt of some jump fails. The last underlying error is
        attached as ``__cause__``.
    """
    n_attempts = max(1, max_retries * int(recalculate_if_error) + 1)

    data = d
    emb = None
    for i in range(max(1, n_jumps)):
        if i > 0:
            # Later jumps reduce the previous embedding; labels always come
            # from the original input.
            data = MVData(emb.coords, labels=d.labels)

        print('--------------------------- JUMP ' + str(i + 1) + ' --------------------------------')
        print('Performing jump from dim', data.dim, 'to dim', all_embedding_params['dim'][i], ':')

        metric_params = _params_for_jump(all_metric_params, i)
        graph_params = _params_for_jump(all_graph_params, i)
        embedding_params = _params_for_jump(all_embedding_params, i)
        embedding_params['e_method'] = METHODS_DICT[embedding_params['e_method_name']]

        emb = _embed_with_retries(
            data,
            m_param_filter(metric_params),
            g_param_filter(graph_params),
            e_param_filter(embedding_params),
            n_attempts,
            i + 1,
        )

    return emb