Coverage for src/driada/dim_reduction/dim_reduction.py: 0.00%

55 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-25 15:40 +0300

1from .dr_base import * 

2from .data import MVData 

3from .graph import ProximityGraph 

4from .embedding import Embedding 

5 

6 

7# TODO: refactor this 

def dr_series(d,
              n_jumps,
              all_metric_params,
              all_graph_params,
              all_embedding_params,
              recalculate_if_error=0):
    """Run a chain of successive embeddings ("jumps") starting from ``d``.

    Jump ``k`` (0-based) embeds the output of jump ``k - 1`` (or ``d`` itself
    for the first jump) using the ``k``-th entry of every per-jump parameter
    sequence. The coordinates of each intermediate embedding are wrapped in a
    new ``MVData`` carrying the labels of the original ``d``.

    Parameters
    ----------
    d
        Starting data object. It must expose ``dim``, ``labels`` and
        ``get_embedding(m_params, g_params, e_params)``.
    n_jumps
        Number of jumps to perform. The first jump is always executed, even
        if ``n_jumps`` is smaller than 1.
    all_metric_params
        Mapping from metric parameter name to a sequence indexed by jump.
    all_graph_params
        Mapping from graph parameter name to a sequence indexed by jump.
    all_embedding_params
        Mapping from embedding parameter name to a sequence indexed by jump.
        Must contain ``'dim'`` and ``'e_method_name'``; the latter is looked
        up in ``METHODS_DICT`` for every jump.
    recalculate_if_error
        Retry switch. Every jump gets ``20 * int(recalculate_if_error) + 1``
        attempts, i.e. a single attempt when this is falsy.

    Returns
    -------
    The embedding object produced by the last jump.

    Raises
    ------
    Exception
        If a jump still fails after all of its attempts. The last underlying
        error is attached as ``__cause__``.
    """
    # One attempt, plus 20 retries per unit of recalculate_if_error.
    # Clamped so that a negative switch cannot result in zero attempts.
    n_attempts = max(1, 20 * int(recalculate_if_error) + 1)

    current = d
    emb = None
    # max(1, ...) keeps the historic guarantee that jump 1 always runs.
    for jump in range(max(1, n_jumps)):
        print('--------------------------- JUMP ' + str(jump + 1)
              + ' --------------------------------')
        print('Performing jump from dim', current.dim,
              'to dim', all_embedding_params['dim'][jump], ':')

        emb = _embed_with_retries(current, jump, all_metric_params,
                                  all_graph_params, all_embedding_params,
                                  n_attempts)

        # Only wrap the result when another jump will consume it; labels are
        # always taken from the original data object.
        if jump + 1 < n_jumps:
            current = MVData(emb.coords, labels=d.labels)

    return emb


def _jump_params(all_params, jump):
    """Pick the ``jump``-th value of every per-jump parameter sequence."""
    return {name: values[jump] for name, values in all_params.items()}


def _embed_with_retries(data, jump, all_metric_params, all_graph_params,
                        all_embedding_params, n_attempts):
    """Assemble the parameters of one jump and embed ``data``, retrying on error."""
    metric_params = _jump_params(all_metric_params, jump)
    graph_params = _jump_params(all_graph_params, jump)
    embedding_params = _jump_params(all_embedding_params, jump)
    # Resolve the method name into whatever METHODS_DICT stores for it.
    embedding_params['e_method'] = METHODS_DICT[embedding_params['e_method_name']]

    m_params = m_param_filter(metric_params)
    g_params = g_param_filter(graph_params)
    e_params = e_param_filter(embedding_params)

    last_error = None
    for attempt in range(1, n_attempts + 1):
        try:
            return data.get_embedding(m_params, g_params, e_params)
        except Exception as err:
            # Catch Exception (not a bare except) so that KeyboardInterrupt
            # and SystemExit still abort the run immediately.
            last_error = err
            print('Attempt', attempt, 'of', n_attempts, 'failed:', repr(err))

    jump_label = 'First jump' if jump == 0 else 'Jump %s' % (jump + 1)
    raise Exception(
        '%s failed after %s attempts' % (jump_label, n_attempts)
    ) from last_error