Coverage for src/driada/dim_reduction/data.py: 62.73%
110 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-25 15:40 +0300
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-25 15:40 +0300
2import numpy as np
3import matplotlib.pyplot as plt
4import scipy.sparse as sp
6from .dr_base import *
7from ..utils.data import correlation_matrix, to_numpy_array, rescale
8from .embedding import Embedding
9from .graph import ProximityGraph
11# TODO: refactor this
def check_data_for_errors(d):
    """Raise if any column of ``d`` consists entirely of zeros.

    Parameters
    ----------
    d : array-like or scipy sparse matrix
        Two-dimensional data; columns are the "points" being checked.

    Raises
    ------
    ValueError
        If at least one column has a zero sum of absolute values. Before
        raising, the indices of all such columns and the contents of the
        first one are printed for diagnostics.
    """
    is_sparse = sp.issparse(d)
    if is_sparse:
        col_sums = abs(d).sum(axis=0)
    else:
        d = np.asarray(d)
        col_sums = np.abs(d).sum(axis=0)

    # Sparse reductions return a 2-D np.matrix while dense ones return a 1-D
    # array; flatten so the same indexing works for both.
    col_sums = np.asarray(col_sums).ravel()
    bad_points = np.flatnonzero(col_sums == 0)
    if bad_points.size == 0:
        return

    print('zero points:', bad_points)
    if is_sparse:
        # Densify first: not every sparse format supports column indexing.
        print(d.todense()[:, bad_points[0]])
    else:
        print(d[:, bad_points[0]])
    raise ValueError('Data contains zero points!')
class MVData:
    """Main class for multivariate data storage & processing.

    The data are held as a 2-D array indexed as ``(n_dim, n_points)``:
    rows are dimensions/features, columns are points/samples.
    """

    def __init__(self,
                 data,
                 labels=None,
                 distmat=None,
                 rescale_rows=False,
                 data_name=None,
                 downsampling=None):
        """Store the data and its metadata.

        Parameters
        ----------
        data : array-like
            Input data, converted with ``to_numpy_array`` and indexed as
            ``(n_dim, n_points)``.
        labels : array-like, optional
            Labels for the points. Defaults to an array of zeros with one
            entry per point.
        distmat : optional
            Precomputed distance matrix, stored as given.
        rescale_rows : bool, optional
            If True, every row is passed through ``rescale``.
        data_name : optional
            Name of the dataset, stored as given.
        downsampling : int, optional
            Keep every ``downsampling``-th column. ``None`` keeps all columns.
        """
        if downsampling is None:
            self.ds = 1
        else:
            self.ds = int(downsampling)

        self.data = to_numpy_array(data)[:, ::self.ds]

        # TODO: add support for various preprocessing methods (wvt, med_filt, etc.)
        self.rescale_rows = rescale_rows
        if self.rescale_rows:
            # NOTE(review): rows are overwritten in place. If to_numpy_array
            # returns the caller's array without copying, the slice above is a
            # view and the caller's data is modified too - confirm.
            for i, row in enumerate(self.data):
                self.data[i] = rescale(row)

        self.data_name = data_name
        self.n_dim = self.data.shape[0]
        self.n_points = self.data.shape[1]

        if labels is None:
            self.labels = np.zeros(self.n_points)
        else:
            self.labels = to_numpy_array(labels)

        self.distmat = distmat

    def median_filter(self, window):
        """Apply a median filter to the stored data in place.

        Parameters
        ----------
        window : int or sequence of int
            Kernel size forwarded to ``scipy.signal.medfilt``.

        Notes
        -----
        Dense data stays dense and sparse data is stored back as a CSR
        matrix, so the storage kind of ``self.data`` is preserved.
        """
        from scipy.signal import medfilt

        was_sparse = sp.issparse(self.data)
        dense = self.data.toarray() if was_sparse else np.asarray(self.data)

        # NOTE(review): per the scipy docs a scalar kernel size is used along
        # every axis, so this also filters across rows - confirm intended.
        filtered = medfilt(dense, window)

        self.data = sp.csr_matrix(filtered) if was_sparse else filtered

    def corr_mat(self):
        """Return ``correlation_matrix`` evaluated on the stored data."""
        return correlation_matrix(self.data)

    def get_distmat(self, m_params=None):
        """Compute pairwise distance matrix.

        Parameters
        ----------
        m_params : dict or str, optional
            If dict: metric parameters with 'metric_name' key and optional metric-specific params
            If str: metric name directly
            If None: defaults to 'euclidean'

        Returns
        -------
        np.ndarray
            Distance matrix of shape (n_samples, n_samples)
        """
        from scipy.spatial.distance import pdist, squareform

        # Any unsupported m_params type falls back to plain euclidean.
        metric = 'euclidean'
        metric_kwargs = {}
        if isinstance(m_params, str):
            metric = m_params
        elif isinstance(m_params, dict):
            metric = m_params.get('metric_name', 'euclidean')
            # Everything except the name and 'sigma' is forwarded to pdist as
            # a metric-specific keyword (e.g. 'p' for minkowski).
            metric_kwargs = {k: v for k, v in m_params.items()
                             if k not in ('metric_name', 'sigma')}

        # Convert l2 to euclidean for scipy, whichever way the name was given.
        if metric == 'l2':
            metric = 'euclidean'

        # pdist treats rows as observations, hence the transpose.
        distances = pdist(self.data.T, metric=metric, **metric_kwargs)

        self.distmat = squareform(distances)
        return self.distmat

    def get_embedding(self, e_params=None, g_params=None, m_params=None, kwargs=None,
                      method=None, **method_kwargs):
        """Get embedding using specified method.

        Parameters
        ----------
        e_params : dict, optional
            Embedding parameters (legacy format)
        g_params : dict, optional
            Graph parameters (legacy format)
        m_params : dict, optional
            Metric parameters (legacy format)
        kwargs : dict, optional
            Additional kwargs for the embedding method
        method : str, optional
            Method name for simplified API (e.g., 'pca', 'umap')
        **method_kwargs
            Additional parameters when using simplified API

        Returns
        -------
        Embedding
            The computed embedding

        Raises
        ------
        ValueError
            If neither ``method`` nor ``e_params`` is given, if no method
            object can be resolved, or if a graph-based method lacks the
            graph/metric parameters it needs.
        Exception
            If the method name is unknown, or the method needs a distance
            matrix and none has been computed.

        Examples
        --------
        # Legacy format (still supported)
        >>> emb = mvdata.get_embedding(e_params, g_params, m_params)

        # New simplified format
        >>> emb = mvdata.get_embedding(method='pca', dim=3)
        >>> emb = mvdata.get_embedding(method='umap', n_components=2, n_neighbors=30)
        """
        # Handle new simplified API
        if method is not None:
            # Merge with defaults
            from .dr_base import merge_params_with_defaults
            params = merge_params_with_defaults(method, method_kwargs)
            e_params = params['e_params']
            g_params = params['g_params']
            m_params = params['m_params']
        elif e_params is None:
            raise ValueError("Either 'method' or 'e_params' must be provided")

        # Legacy compatibility: ensure e_method is set
        method_name = e_params.get('e_method_name')
        if e_params.get('e_method') is None:
            if method_name and method_name in METHODS_DICT:
                e_params['e_method'] = METHODS_DICT[method_name]

        if method_name not in EMBEDDING_CONSTRUCTION_METHODS:
            raise Exception('Unknown embedding construction method!')

        method = e_params.get('e_method')
        if method is None:
            raise ValueError(f'No method object could be resolved for {method_name}')

        graph = None
        if method.requires_graph:
            if g_params is None:
                raise ValueError(f'Method {method_name} requires proximity graph, but '
                                 f'graph params were not provided')
            if g_params['weighted'] and m_params is None:
                raise ValueError(f'Method {method_name} requires weights for proximity graph, but '
                                 f'metric params were not provided')

            graph = self.get_proximity_graph(m_params, g_params)

        if method.requires_distmat and self.distmat is None:
            raise Exception(f'No distmat provided for {method_name} method.'
                            f' Try constructing it first with get_distmat() method')

        emb = Embedding(self.data, self.distmat, self.labels, e_params, g=graph)

        # For neural network methods, extract NN-specific params from e_params to pass as kwargs
        if method.nn_based:
            # Work on a copy so the caller's kwargs dict is not modified.
            nn_kwargs = dict(kwargs) if kwargs else {}
            # Extract neural network specific parameters from e_params
            nn_params = ['epochs', 'lr', 'batch_size', 'seed', 'verbose',
                         'feature_dropout', 'enc_kwargs', 'dec_kwargs',
                         'kld_weight', 'inter_dim', 'train_size',
                         'add_corr_loss', 'corr_hyperweight',
                         'add_mi_loss', 'mi_hyperweight', 'minimize_mi_data',
                         'log_every', 'device', 'continue_learning']
            for param in nn_params:
                if param in e_params:
                    nn_kwargs[param] = e_params[param]
            emb.build(kwargs=nn_kwargs)
        else:
            emb.build(kwargs=kwargs)

        return emb

    def get_proximity_graph(self, m_params, g_params):
        """Build a ``ProximityGraph`` over the stored data.

        Parameters
        ----------
        m_params : dict
            Metric parameters, forwarded to ``ProximityGraph``.
        g_params : dict
            Graph parameters; must contain 'g_method_name'.

        Raises
        ------
        Exception
            If 'g_method_name' is not in ``GRAPH_CONSTRUCTION_METHODS``.
        """
        if g_params['g_method_name'] not in GRAPH_CONSTRUCTION_METHODS:
            raise Exception('Unknown graph construction method!')

        graph = ProximityGraph(self.data, m_params, g_params)
        return graph

    def draw_vector(self, num):
        """Show column ``num`` as a one-row image, then the whole data matrix."""
        data = self.data[:, num]
        plt.matshow(data.reshape(1, self.n_dim))
        plt.matshow(self.data)

    def draw_row(self, num):
        """Plot row ``num`` of the data as a line on a new 12x10 figure."""
        data = self.data[num, :]
        plt.figure(figsize=(12, 10))
        plt.plot(data)