Coverage for src/driada/utils/data.py: 71.63%

141 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-25 15:40 +0300

1import hashlib 

2import h5py 

3import scipy.sparse as ssp 

4from sklearn.preprocessing import MinMaxScaler 

5from scipy.signal import hilbert 

6import numpy as np 

7import scipy.stats as st 

8from numba import njit 

9 

10 

def create_correlated_gaussian_data(n_features=10, n_samples=10000,
                                    correlation_pairs=None, seed=42):
    """Generate multivariate Gaussian data with specified correlations.

    Parameters
    ----------
    n_features : int
        Number of features (dimensions).
    n_samples : int
        Number of samples.
    correlation_pairs : list of tuples or None
        List of (i, j, correlation) tuples specifying correlated features.
        Pairs with an index >= n_features are skipped.
        If None, uses default pattern: [(1, 9, 0.9), (2, 8, 0.8), (3, 7, 0.7)]
    seed : int
        Random seed for reproducibility. Note that this reseeds NumPy's
        global random state via ``np.random.seed``.

    Returns
    -------
    data : np.ndarray
        Data array of shape (n_features, n_samples).
    cov_matrix : np.ndarray
        Covariance matrix used to generate the data. If the requested
        correlations do not form a positive-definite matrix, the diagonal is
        inflated so that the smallest eigenvalue becomes 0.01 (the diagonal
        entries then exceed 1).
    """
    np.random.seed(seed)
    if correlation_pairs is None:
        correlation_pairs = [(1, 9, 0.9), (2, 8, 0.8), (3, 7, 0.7)]

    # Build the (symmetric) correlation matrix, skipping out-of-range pairs.
    C = np.eye(n_features)
    for i, j, corr in correlation_pairs:
        if i < n_features and j < n_features:
            C[i, j] = C[j, i] = corr

    # Ensure positive definiteness. C is symmetric, so eigvalsh is the right
    # solver: it always returns real eigenvalues, whereas the general eigvals
    # may return complex values with tiny imaginary parts, which cannot be
    # added in place to the float64 matrix below.
    min_eig = np.min(np.linalg.eigvalsh(C))
    if min_eig < 0:
        # Shifting by k*I shifts every eigenvalue by k; the minimum becomes 0.01.
        C += (-min_eig + 0.01) * np.eye(n_features)

    # Draw samples (rows) and transpose to (n_features, n_samples).
    data = np.random.multivariate_normal(
        np.zeros(n_features), C, size=n_samples
    ).T

    return data, C

55 

56 

def populate_nested_dict(content, outer, inner):
    """Build a two-level dict whose leaves are independent copies of a template.

    Parameters
    ----------
    content : object with a ``copy()`` method (typically a dict)
        Template stored at every leaf; each leaf gets its own
        ``content.copy()`` (a shallow copy for dicts).
    outer : iterable
        First-level keys.
    inner : iterable
        Second-level keys, repeated under every first-level key.

    Returns
    -------
    dict
        ``{o: {i: content.copy() for i in inner} for o in outer}``.
    """
    # Materialize inner once so one-shot iterables (generators) are reused for
    # every outer key, and iterate outer exactly once. Iterating outer twice
    # would leave only empty inner dicts when outer is a generator.
    inner_keys = tuple(inner)
    return {o: {i: content.copy() for i in inner_keys} for o in outer}

63 

64 

def nested_dict_to_seq_of_tables(datadict, ordered_names1=None, ordered_names2=None):
    """Convert a two-level nested dict of records into one 2D table per field.

    ``datadict[n1][n2]`` is expected to be a dict mapping field names to
    scalar values. For every field present in the first record (first outer
    key, first inner key), a float table is built whose entry ``[i, j]`` is
    ``datadict[ordered_names1[i]][ordered_names2[j]][field]``, or NaN when any
    of those keys is missing.

    Parameters
    ----------
    datadict : dict
        Nested dict ``{name1: {name2: {field: value}}}``. Must be non-empty.
    ordered_names1 : iterable or None
        Row order. Defaults to the sorted outer keys.
    ordered_names2 : iterable or None
        Column order. Defaults to the sorted inner keys of the first outer entry.

    Returns
    -------
    dict
        ``{field: np.ndarray of shape (len(ordered_names1), len(ordered_names2))}``.
    """
    names1 = list(datadict.keys())
    names2 = list(datadict[names1[0]].keys())
    datakeys = list(datadict[names1[0]][names2[0]].keys())

    ordered_names1 = sorted(names1) if ordered_names1 is None else list(ordered_names1)
    ordered_names2 = sorted(names2) if ordered_names2 is None else list(ordered_names2)

    # Size the tables from the orderings actually iterated below, so a
    # caller-supplied subset or superset of names gives a correctly shaped table.
    shape = (len(ordered_names1), len(ordered_names2))
    table_seq = {dkey: np.zeros(shape) for dkey in datakeys}
    for dkey in datakeys:
        for i, n1 in enumerate(ordered_names1):
            for j, n2 in enumerate(ordered_names2):
                try:
                    table_seq[dkey][i, j] = datadict[n1][n2][dkey]
                except KeyError:
                    # Missing name pair or missing field -> NaN placeholder.
                    table_seq[dkey][i, j] = np.nan

    return table_seq

88 

89 

def add_names_to_nested_dict(datadict, names1, names2):
    """Re-key a positionally indexed two-level nested dict with the given names.

    ``datadict`` is read as ``datadict[i][j]`` for ``i in range(n1)`` and
    ``j in range(n2)``, where ``n1`` is its number of outer keys and ``n2`` the
    number of inner keys of its first outer entry. Its keys are therefore
    presumably the integers 0..n-1 (confirm with callers). Each leaf dict is
    shallow-merged into a fresh dict stored at ``result[names1[i]][names2[j]]``.

    Parameters
    ----------
    datadict : dict
        Non-empty nested dict (an empty one raises IndexError).
    names1, names2 : sequence or None
        New outer / inner keys, indexed positionally. ``None`` keeps
        positional integers for that level. If both are None, ``datadict``
        itself is returned unchanged.

    Returns
    -------
    dict
    """
    outer_keys = list(datadict.keys())
    size1 = len(outer_keys)
    size2 = len(datadict[outer_keys[0]])

    # Nothing to rename: hand back the very same object.
    if names1 is None and names2 is None:
        return datadict

    labels1 = range(size1) if names1 is None else names1
    labels2 = range(size2) if names2 is None else names2

    renamed = populate_nested_dict(dict(), labels1, labels2)
    for i in range(size1):
        for j in range(size2):
            renamed[labels1[i]][labels2[j]].update(datadict[i][j])
    return renamed

109 

110 

def retrieve_relevant_from_nested_dict(nested_dict,
                                       target_key,
                                       target_value,
                                       operation='=',
                                       allow_missing_keys=False):
    """Return the ``(key1, key2)`` pairs whose record satisfies a comparison.

    Each record ``nested_dict[key1][key2]`` is a dict; its value at
    ``target_key`` is compared with ``target_value`` using ``operation``.

    Parameters
    ----------
    nested_dict : dict
        ``{key1: {key2: record_dict}}``.
    target_key : hashable
        Field looked up in every record.
    target_value : object
        Value compared against.
    operation : {'=', '>', '<'}
        '=' tests ``value == target_value`` (a missing key is compared as
        None). '>' / '<' test ``value > / < target_value``; a missing or None
        value never matches.
    allow_missing_keys : bool
        If False, a record without ``target_key`` raises ValueError.

    Returns
    -------
    list of tuple
        Matching ``(key1, key2)`` pairs in iteration order.

    Raises
    ------
    ValueError
        If ``operation`` is invalid (checked before iterating, so also for an
        empty ``nested_dict``), or if a record lacks ``target_key`` while
        ``allow_missing_keys`` is False.
    """
    if operation not in ('=', '>', '<'):
        raise ValueError(f'Operation should be one of "=", ">", "<", but {operation} was passed')

    relevant_pairs = []
    for key1, inner in nested_dict.items():
        for key2, data in inner.items():
            if target_key not in data and not allow_missing_keys:
                raise ValueError(f'Target key {target_key} not found in data dict')

            value = data.get(target_key)
            if operation == '=':
                criterion = value == target_value
            elif value is None:
                # Ordering comparisons against a missing/None value never match.
                criterion = False
            elif operation == '>':
                criterion = value > target_value
            else:
                criterion = value < target_value

            if criterion:
                relevant_pairs.append((key1, key2))

    return relevant_pairs

136 

137 

def rescale(data):
    """Min-max scale all values of ``data`` into [0, 1] and return them flat.

    ``data`` is reshaped into a single column before fitting, so one global
    minimum and maximum are used across all of its elements; the scaled
    values are returned as a 1-D array.
    """
    column = data.reshape(-1, 1)
    scaled = MinMaxScaler(feature_range=(0, 1)).fit_transform(column)
    return scaled.ravel()

142 

143 

def get_hash(data):
    """Return the MD5 hex digest of ``repr(data)``.

    With default print options NumPy abbreviates the repr of arrays with more
    than 1000 elements (``[0, 1, 2, ..., 9997, 9998, 9999]``) and rounds floats
    to 8 digits. Distinct arrays could then produce identical reprs and hence
    identical hashes. The repr is therefore computed with summarization
    disabled and round-trip-exact float formatting. Non-NumPy objects hash
    exactly as ``md5(repr(data))``.

    MD5 is not collision-resistant; do not rely on this for security.
    """
    import sys  # local import: only needed for the print threshold below

    with np.printoptions(threshold=sys.maxsize, floatmode='unique'):
        text = repr(data)
    return hashlib.md5(text.encode('utf-8')).hexdigest()

149 

150 

def phase_synchrony(vec1, vec2):
    """Per-sample phase synchrony of two real signals.

    Each signal's instantaneous phase is the angle (in radians) of its
    analytic signal from ``scipy.signal.hilbert``. Synchrony at each sample is
    ``1 - sin(|phi1 - phi2| / 2)``. Because ``np.angle`` returns values in
    [-pi, pi], the half-difference lies in [0, pi], so the result is in
    [0, 1], with 1 for identical phases.
    """
    phase_diff = np.abs(np.angle(hilbert(vec1)) - np.angle(hilbert(vec2)))
    return 1 - np.sin(phase_diff / 2)

156 

157 

def correlation_matrix_old(a, b):
    """Pearson correlations between every row of ``a`` and every row of ``b``.

    Slow, loop-based reference implementation.

    Parameters
    ----------
    a : np.ndarray, shape (n1, n_obs)
    b : np.ndarray, shape (n2, n_obs)

    Returns
    -------
    np.ndarray, shape (n1, n2)
        Entry ``[i, j]`` is the Pearson correlation of ``a[i]`` and ``b[j]``.
    """
    n1 = a.shape[0]
    n2 = b.shape[0]

    # Fast path when both inputs hold the same data: np.corrcoef(a) is the
    # (n1, n1) row-wise correlation matrix. (np.corrcoef(a, a) stacks a twice
    # and returns (2*n1, 2*n1).) The reshape covers n1 == 1, where corrcoef
    # returns a 0-d value. The shape check keeps np.allclose from raising on
    # non-broadcastable inputs and from matching a broadcast single row.
    if a.shape == b.shape and np.allclose(a, b):
        return np.corrcoef(a).reshape(n1, n1)

    corrmat = np.zeros((n1, n2))
    for i in range(n1):
        for j in range(n2):
            corrmat[i, j] = st.pearsonr(a[i, :], b[j, :])[0]

    return corrmat

170 

171 

def correlation_matrix(A):
    '''
    Pearson correlation matrix between the rows of A (fast, vectorized).

    A: numpy array of shape (nvars, nobs) -- each row is one variable and
       observations run along axis 1, as in np.corrcoef(A).

    returns: numpy array of shape (nvars, nvars) whose entry [i, j] is the
    Pearson correlation of A[i] and A[j]. Entries involving a zero-variance
    row are NaN.
    '''
    am = A - np.mean(A, axis=1, keepdims=True)
    # Normalize by sqrt(ss_i * ss_j). Dividing by ss_j alone is only correct
    # when every row has the same variance.
    norms = np.sqrt(np.sum(am ** 2, axis=1))
    return (am @ am.T) / np.outer(norms, norms)

182 

183 

def cross_correlation_matrix(A, B):
    '''
    Pearson correlations between the columns of A and the columns of B
    (fast, vectorized).

    A: numpy array of shape (ndims, nvars1)
    B: numpy array of shape (ndims, nvars2)
    Both share axis 0 (observations); each column is one variable.

    returns: numpy array of shape (nvars1, nvars2) whose entry [i, j] is the
    Pearson correlation of A[:, i] and B[:, j].
    '''
    centered_a = A - np.mean(A, axis=0, keepdims=True)
    centered_b = B - np.mean(B, axis=0, keepdims=True)
    # Column norms of the centered data, kept as row vectors (1, nvars).
    scale_a = np.sqrt(np.sum(centered_a**2, axis=0, keepdims=True))
    scale_b = np.sqrt(np.sum(centered_b**2, axis=0, keepdims=True))
    numerator = centered_a.T @ centered_b
    # (nvars1, 1) * (1, nvars2) broadcasts to the outer product of the norms.
    return numerator / (scale_a.T * scale_b)

196 

197 

# NOTE(review): normalization is asymmetric (only `a` is divided by len(a)),
# so the zero-lag value is a Pearson r only when len(a) == len(b).
def norm_cross_corr(a, b):
    """Full-mode normalized cross-correlation of two 1-D signals.

    ``a`` is z-scored and additionally divided by ``len(a)``; ``b`` is only
    z-scored (both use the population std, ``np.std`` with ddof=0). The result
    is ``np.correlate(a_norm, b_norm, 'full')``, of length
    ``len(a) + len(b) - 1``. For equal-length inputs, the zero-lag entry (index
    ``len(b) - 1``, the middle element) equals the Pearson correlation of
    ``a`` and ``b``. A constant input divides by zero and yields NaN.
    """
    a_norm = (a - np.mean(a)) / (np.std(a) * len(a))
    b_norm = (b - np.mean(b)) / np.std(b)
    return np.correlate(a_norm, b_norm, 'full')

204 

205 

def to_numpy_array(data):
    """Convert ``data`` to a dense ``np.ndarray``.

    ndarrays (including subclasses) are returned as-is, without copying.
    SciPy sparse matrices and sparse arrays are densified with ``toarray()``.
    Anything else goes through ``np.array``.
    """
    if isinstance(data, np.ndarray):
        return data

    if ssp.issparse(data):
        # ``.A`` is deprecated (and removed in recent SciPy) on sparse
        # matrices and absent on sparse arrays; ``toarray()`` works for both.
        return data.toarray()
    return np.array(data)

214 

215 

def write_dict_to_hdf5(data, hdf5_file, group_name=''):
    """
    Recursively writes a dictionary to an HDF5 file.

    Nested dicts become HDF5 groups, lists and numpy arrays become float64
    datasets, and every other value is stored as an attribute of the
    enclosing group.

    Parameters:
    data (dict): The dictionary to write.
    hdf5_file (str): The path to the HDF5 file (opened in append mode 'a').
    group_name (str): The name of the group to write into; '' means the file
        root. Groups that already exist are reused instead of raising; writing
        a dataset name that already exists still raises (h5py behavior).
    """
    def _write_group(mapping, group):
        # Write one dict level into an already-open h5py group (or file).
        for key, value in mapping.items():
            if isinstance(value, dict):
                # Nested dict -> (possibly existing) subgroup, then recurse.
                _write_group(value, group.require_group(key))
            elif isinstance(value, list):
                group.create_dataset(key, data=np.array(value).astype(np.float64))
            elif isinstance(value, np.ndarray):
                group.create_dataset(key, data=value.astype(np.float64))
            else:
                # Scalars / strings are stored as attributes.
                group.attrs[key] = value

    # Open the file once and recurse on group handles rather than reopening
    # the file for every nested level.
    with h5py.File(hdf5_file, 'a') as f:
        group = f.require_group(group_name) if group_name else f
        _write_group(data, group)

243 

244 

def read_hdf5_to_dict(hdf5_file):
    """
    Reads an HDF5 file and converts it into a nested dictionary.

    Groups become nested dicts and datasets are loaded fully into memory
    (``item[()]``). The attributes of each group are stored in that group's
    dict alongside its members; an attribute whose name equals a member's
    name overwrites that member.

    Parameters:
    hdf5_file (str): The path to the HDF5 file.

    Returns:
    dict: A nested dictionary representing the contents of the HDF5 file.
    """

    def _read_group(group):
        """
        Recursively reads an HDF5 group and converts it to a dictionary.

        Parameters:
        group (h5py.Group): The HDF5 group to read.

        Returns:
        dict: A dictionary representation of the group.
        """
        data = {}

        for key, item in group.items():
            if isinstance(item, h5py.Group):
                # Subgroup -> nested dict.
                data[key] = _read_group(item)
            elif isinstance(item, h5py.Dataset):
                # Dataset -> its full contents read into memory.
                data[key] = item[()]
            else:
                # Any other HDF5 object (presumably a committed datatype):
                # keep its attributes. Copy them into a plain dict because the
                # h5py attribute manager is a live view that is unusable once
                # the file is closed when this function returns.
                data[key] = dict(item.attrs)

        # Attributes of the group itself (applied after members, so they win
        # on name clashes).
        for attr_key in group.attrs:
            data[attr_key] = group.attrs[attr_key]

        return data

    with h5py.File(hdf5_file, 'r') as f:
        return _read_group(f)