Coverage for src/driada/dim_reduction/data.py: 62.73%

110 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-25 15:40 +0300

1 

2import numpy as np 

3import matplotlib.pyplot as plt 

4import scipy.sparse as sp 

5 

6from .dr_base import * 

7from ..utils.data import correlation_matrix, to_numpy_array, rescale 

8from .embedding import Embedding 

9from .graph import ProximityGraph 

10 

11# TODO: refactor this 

12def check_data_for_errors(d): 

13 sums = np.sum(np.abs(d), axis=0) 

14 if len(sums.nonzero()[1]) != d.shape[1]: 

15 bad_points = np.where(sums == 0)[1] 

16 print('zero points:', bad_points) 

17 print(d.todense()[:, bad_points[0]]) 

18 raise Exception('Data contains zero points!') 

19 

20 

21class MVData(object): 

22 ''' 

23 Main class for multivariate data storage & processing 

24 ''' 

25 

26 def __init__(self, 

27 data, 

28 labels=None, 

29 distmat=None, 

30 rescale_rows=False, 

31 data_name=None, 

32 downsampling=None): 

33 

34 if downsampling is None: 

35 self.ds = 1 

36 else: 

37 self.ds = int(downsampling) 

38 

39 self.data = to_numpy_array(data)[:, ::self.ds] 

40 

41 # TODO: add support for various preprocessing methods (wvt, med_filt, etc.) 

42 self.rescale_rows = rescale_rows 

43 if self.rescale_rows: 

44 for i, row in enumerate(self.data): 

45 self.data[i] = rescale(row) 

46 

47 self.data_name = data_name 

48 self.n_dim = self.data.shape[0] 

49 self.n_points = self.data.shape[1] 

50 

51 if labels is None: 

52 self.labels = np.zeros(self.n_points) 

53 else: 

54 self.labels = to_numpy_array(labels) 

55 

56 self.distmat = distmat 

57 

58 def median_filter(self, window): 

59 from scipy.signal import medfilt 

60 d = self.data.todense() 

61 

62 new_d = medfilt(d, window) 

63 

64 self.data = sp.csr_matrix(new_d) 

65 

66 def corr_mat(self): 

67 cm = correlation_matrix(self.data) 

68 return cm 

69 

70 def get_distmat(self, m_params=None): 

71 """Compute pairwise distance matrix. 

72  

73 Parameters 

74 ---------- 

75 m_params : dict or str, optional 

76 If dict: metric parameters with 'metric_name' key and optional metric-specific params 

77 If str: metric name directly 

78 If None: defaults to 'euclidean' 

79  

80 Returns 

81 ------- 

82 np.ndarray 

83 Distance matrix of shape (n_samples, n_samples) 

84 """ 

85 from scipy.spatial.distance import pdist, squareform 

86 

87 # Handle different input types 

88 if m_params is None: 

89 metric = 'euclidean' 

90 metric_kwargs = {} 

91 elif isinstance(m_params, str): 

92 metric = m_params 

93 metric_kwargs = {} 

94 elif isinstance(m_params, dict): 

95 metric = m_params.get('metric_name', 'euclidean') 

96 # Convert l2 to euclidean for scipy 

97 if metric == 'l2': 

98 metric = 'euclidean' 

99 # Extract additional parameters for the metric 

100 metric_kwargs = {k: v for k, v in m_params.items() if k not in ['metric_name', 'sigma']} 

101 # For minkowski distance, 'p' parameter is needed 

102 if metric == 'minkowski' and 'p' in m_params: 

103 metric_kwargs['p'] = m_params['p'] 

104 else: 

105 metric = 'euclidean' 

106 metric_kwargs = {} 

107 

108 # Compute distance matrix 

109 if metric_kwargs: 

110 distances = pdist(self.data.T, metric=metric, **metric_kwargs) 

111 else: 

112 distances = pdist(self.data.T, metric=metric) 

113 

114 self.distmat = squareform(distances) 

115 return self.distmat 

116 

117 def get_embedding(self, e_params=None, g_params=None, m_params=None, kwargs=None, 

118 method=None, **method_kwargs): 

119 """Get embedding using specified method. 

120  

121 Parameters 

122 ---------- 

123 e_params : dict, optional 

124 Embedding parameters (legacy format) 

125 g_params : dict, optional 

126 Graph parameters (legacy format) 

127 m_params : dict, optional 

128 Metric parameters (legacy format) 

129 kwargs : dict, optional 

130 Additional kwargs for the embedding method 

131 method : str, optional 

132 Method name for simplified API (e.g., 'pca', 'umap') 

133 **method_kwargs 

134 Additional parameters when using simplified API 

135  

136 Returns 

137 ------- 

138 Embedding 

139 The computed embedding 

140  

141 Examples 

142 -------- 

143 # Legacy format (still supported) 

144 >>> emb = mvdata.get_embedding(e_params, g_params, m_params) 

145  

146 # New simplified format 

147 >>> emb = mvdata.get_embedding(method='pca', dim=3) 

148 >>> emb = mvdata.get_embedding(method='umap', n_components=2, n_neighbors=30) 

149 """ 

150 # Handle new simplified API 

151 if method is not None: 

152 # Merge with defaults 

153 from .dr_base import merge_params_with_defaults 

154 params = merge_params_with_defaults(method, method_kwargs) 

155 e_params = params['e_params'] 

156 g_params = params['g_params'] 

157 m_params = params['m_params'] 

158 elif e_params is None: 

159 raise ValueError("Either 'method' or 'e_params' must be provided") 

160 

161 # Legacy compatibility: ensure e_method is set 

162 if 'e_method' not in e_params or e_params['e_method'] is None: 

163 method_name = e_params.get('e_method_name') 

164 if method_name and method_name in METHODS_DICT: 

165 e_params['e_method'] = METHODS_DICT[method_name] 

166 

167 method = e_params['e_method'] 

168 method_name = e_params['e_method_name'] 

169 

170 if method_name not in EMBEDDING_CONSTRUCTION_METHODS: 

171 raise Exception('Unknown embedding construction method!') 

172 

173 graph = None 

174 if method.requires_graph: 

175 if g_params is None: 

176 raise ValueError(f'Method {method_name} requires proximity graph, but ' 

177 f'graph params were not provided') 

178 if g_params['weighted'] and m_params is None: 

179 raise ValueError(f'Method {method_name} requires weights for proximity graph, but ' 

180 f'metric params were not provided') 

181 

182 graph = self.get_proximity_graph(m_params, g_params) 

183 

184 if method.requires_distmat and self.distmat is None: 

185 raise Exception(f'No distmat provided for {method_name} method.' 

186 f' Try constructing it first with get_distmat() method') 

187 

188 emb = Embedding(self.data, self.distmat, self.labels, e_params, g=graph) 

189 

190 # For neural network methods, extract NN-specific params from e_params to pass as kwargs 

191 if method.nn_based: 

192 nn_kwargs = kwargs or {} 

193 # Extract neural network specific parameters from e_params 

194 nn_params = ['epochs', 'lr', 'batch_size', 'seed', 'verbose', 

195 'feature_dropout', 'enc_kwargs', 'dec_kwargs', 

196 'kld_weight', 'inter_dim', 'train_size', 

197 'add_corr_loss', 'corr_hyperweight', 

198 'add_mi_loss', 'mi_hyperweight', 'minimize_mi_data', 

199 'log_every', 'device', 'continue_learning'] 

200 for param in nn_params: 

201 if param in e_params: 

202 nn_kwargs[param] = e_params[param] 

203 emb.build(kwargs=nn_kwargs) 

204 else: 

205 emb.build(kwargs=kwargs) 

206 

207 return emb 

208 

209 def get_proximity_graph(self, m_params, g_params): 

210 if g_params['g_method_name'] not in GRAPH_CONSTRUCTION_METHODS: 

211 raise Exception('Unknown graph construction method!') 

212 

213 graph = ProximityGraph(self.data, m_params, g_params) 

214 # print('Graph succesfully constructed') 

215 return graph 

216 

217 def draw_vector(self, num): 

218 data = self.data[:, num] 

219 plt.matshow(data.reshape(1, self.n_dim)) 

220 plt.matshow(self.data) 

221 

222 def draw_row(self, num): 

223 data = self.data[num, :] 

224 plt.figure(figsize=(12, 10)) 

225 plt.plot(data)