Coverage for C:\Users\peaco\OneDrive\Documents\GitHub\mth5\mth5\timeseries\spectre\multiple_station.py: 67%

150 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-10 00:01 -0800

1""" 

2 Work In progress 

3 

4 This module is concerned with working with Fourier coefficient data 

5 

6 TODO: 

7 2. Give MultivariateDataset a covariance() method 

8 

9 Tools include prototypes for 

10 - extracting portions of an FC Run Time Series 

11 - merging multiple stations runs together into an xarray 

12 - relabelling channels to avoid namespace clashes for multi-station data 

13 

14""" 

15 

16from dataclasses import dataclass 

17from typing import List, Literal, Optional, Tuple, Union 

18 

19import numpy as np 

20import pandas as pd 

21import xarray as xr 

22from loguru import logger 

23 

24import mth5.mth5 

25from mth5.timeseries.spectre.spectrogram import Spectrogram 

26from mth5.utils.exceptions import MTH5Error 

27 

28 

@dataclass
class FCRunChunk:
    """
    Metadata needed to specify a chunk of a time series of Fourier coefficients.

    This may move to mt_metadata -- for now a dataclass is used as a prototype.

    ``start`` and ``end`` are kept as strings and are only converted to
    ``pandas.Timestamp`` objects when the ``*_timestamp`` properties are read.

    Parameters
    ----------
    :type survey_id: str
    :param survey_id: Identifier of the survey that holds the station.
    :type station_id: str
    :param station_id: Identifier of the station.
    :type run_id: str
    :param run_id: Identifier of the run.
    :type decimation_level_id: str
    :param decimation_level_id: Identifier of the decimation level.
    :type start: str
    :param start: Start time of the chunk, in any form ``pandas.Timestamp`` accepts.
    :type end: str
    :param end: End time of the chunk, in any form ``pandas.Timestamp`` accepts.
    :type channels: Tuple[str, ...]
    :param channels: Names of the channels to use; any number of names.
    """

    survey_id: str = "none"
    station_id: str = ""
    run_id: str = ""
    decimation_level_id: str = "0"
    start: str = ""
    end: str = ""
    # Variable-length tuple of channel names (Tuple[str] would mean "exactly one").
    channels: Tuple[str, ...] = ()

    @property
    def start_timestamp(self) -> pd.Timestamp:
        """Return ``start`` converted to a ``pandas.Timestamp``."""
        return pd.Timestamp(self.start)

    @property
    def end_timestamp(self) -> pd.Timestamp:
        """Return ``end`` converted to a ``pandas.Timestamp``."""
        return pd.Timestamp(self.end)

    @property
    def duration(self) -> pd.Timedelta:
        """Return the length of the chunk, ``end_timestamp - start_timestamp``."""
        # Subtracting two Timestamps yields a Timedelta, not a Timestamp.
        return self.end_timestamp - self.start_timestamp

57 

58 

@dataclass
class MultivariateLabelScheme:
    """
    Describes how the channels of a multivariate (MV) dataset are labelled.

    This is a placeholder meant to absorb possible future complexity.  By
    default an MV channel label is f"{station}_{component}"; the scheme records
    both which pieces of information make up the label and which string joins
    them.  Should station names ever contain underscores, the join character
    could for example be set to "__".

    TODO: Consider rename default to ("station", "data_var") instead of ("station", "component")

    Parameters
    ----------
    :type label_elements: tuple
    :param label_elements: Names of the pieces of information concatenated into an MV channel label.
    :type join_char: str
    :param join_char: The string placed between the label elements.

    """

    label_elements: tuple = ("station", "component")
    join_char: str = "_"

    @property
    def id(self) -> str:
        """Return the label elements joined by ``join_char``, e.g. "station_component"."""
        return self.join(self.label_elements)

    def join(self, elements: Union[list, tuple]) -> str:
        """
        Concatenate label elements into a single string.

        :type elements: Union[list, tuple]
        :param elements: The strings to join; by default these correspond to (station, component).

        :return: The joined label, i.e. the channel name in a multiple-station context.
        :rtype: str

        """
        separator = self.join_char
        return separator.join(elements)

    def split(self, mv_channel_name) -> dict:
        """
        Break a multi-station channel name into its label elements (the inverse of ``join``).

        :param mv_channel_name: a multivariate channel name string
        :type mv_channel_name: str

        :return: The pieces of the name, keyed by the entries of ``label_elements``.
        :rtype: dict

        :raises ValueError: If the number of pieces differs from the number of label elements.

        """
        parts = mv_channel_name.split(self.join_char)
        n_parts = len(parts)
        n_expected = len(self.label_elements)

        if n_parts == n_expected:
            return {key: value for key, value in zip(self.label_elements, parts)}

        # The detailed mismatch goes to the log; the exception carries only the counts.
        logger.error(f"Incompatable map {parts} and {self.label_elements}")
        raise ValueError(f"cannot map {n_parts} to {n_expected}")

129 

130 

class MultivariateDataset(Spectrogram):
    """
    Container for a multivariate spectral dataset.

    The xarray held by the parent Spectrogram is the main underlying item; this
    subclass adds helpers that, for example, list the stations present or the
    channels that belong to one station.

    This is intended to be used as a multivariate spectral dataset at one frequency band.

    TODO: Rename this class to MultivariateSpectrogram.

    """

    def __init__(
        self,
        dataset: xr.Dataset,
        label_scheme: Optional[MultivariateLabelScheme] = None,
    ):
        """
        :param dataset: The multivariate data, forwarded to the Spectrogram constructor.
        :type dataset: xr.Dataset
        :param label_scheme: How channel names are built; when None a default
         scheme is created on first access of ``label_scheme``.
        :type label_scheme: Optional[MultivariateLabelScheme]
        """
        super().__init__(dataset=dataset)
        self._label_scheme = label_scheme

        # Caches filled lazily by channels, stations and station_channels().
        self._channels = None
        self._stations = None
        self._station_channels = None

    @property
    def label_scheme(self) -> MultivariateLabelScheme:
        """Return the label scheme, creating (and warning about) a default one if unset."""
        if self._label_scheme is not None:
            return self._label_scheme

        logger.warning(
            f"No label scheme found for {self.__class__} -- setting to default"
        )
        self._label_scheme = MultivariateLabelScheme()
        return self._label_scheme

    @property
    def channels(self) -> list:
        """
        Return the channel names, read from the "variable" coordinate of the
        dataarray on first access and cached afterwards.
        """
        cached = self._channels
        if cached is None:
            cached = list(self.dataarray.coords["variable"].values)
            self._channels = cached
        return cached

    @property
    def num_channels(self) -> int:
        """Return the total number of channels in the dataset."""
        return len(self.channels)

    @property
    def stations(self) -> List[str]:
        """
        Return the unique station names parsed from the channel names, in
        order of first appearance.

        :raises NotImplementedError: If the label scheme id is not "station_component".
        """
        if self._stations is not None:
            return self._stations

        if self.label_scheme.id != "station_component":
            msg = f"No rule for parsting station names from label scheme {self.label_scheme.id}"
            raise NotImplementedError(msg)

        station_per_channel = [
            self.label_scheme.split(name)["station"] for name in self.channels
        ]
        # dict keys keep insertion order, giving order-preserving de-duplication.
        self._stations = list(dict.fromkeys(station_per_channel))
        return self._stations

    def station_channels(
        self,
        station: str,
    ) -> List[str]:
        """
        Return the multivariate channel names that belong to one station.

        The answer is looked up in self._station_channels, a dict keyed by
        station id whose values are lists of channel names.  The dict is built
        for all stations the first time this method is called.

        :param station: The name of the station.
        :type station: str

        :rtype: List[str]
        :returns: list of channel names for the input station.

        """
        if self._station_channels is None:
            self._station_channels = {
                station_id: self._get_station_channel_names(
                    station_id,
                    multivariate_labels=True,
                )
                for station_id in self.stations
            }

        return self._station_channels[station]

    def _get_station_channel_names(
        self, station: str, multivariate_labels: bool = True
    ) -> List[str]:
        """
        Collect every channel name whose leading label element equals ``station``.

        :param station: The name of the station.
        :type station: str
        :param multivariate_labels: When True (the default) the full multivariate
         names are returned, e.g. "mt1_ex", "mt1_ey", "mt1_hx".  When False the
         second label element is returned instead, e.g. "ex", "ey", "hx".
        :type multivariate_labels: bool

        :rtype: List[str]
        :returns: Channel names for the input station.

        """
        matching = []
        for name in self.channels:
            if station == name.split(self.label_scheme.join_char)[0]:
                matching.append(name)

        if multivariate_labels:
            return matching

        return [name.split(self.label_scheme.join_char)[1] for name in matching]

    def archive_cross_powers(
        self,
        tf_station: str,
        with_fcs: bool = True,
    ):
        """
        Placeholder -- no archiving is implemented yet; calling this does nothing.

        :param tf_station: The station under which the output should be stored.
         TODO: Consider moving this to another function which performs archiving in future.
        :type tf_station: str
        :param with_fcs: If True, the features are to be packed into the same
         hdf5-group as the FCs, as its own dataset.  If False, the features are
         to be packed into the hdf5 features-group.
        :type with_fcs: bool

        """

    # TODO: Replace with Spectrogram's covariance_matrix
    def cross_power(
        self, aweights: Optional[np.ndarray] = None, bias: Optional[bool] = True
    ) -> xr.DataArray:
        """
        Compute the cross-power (covariance) matrix of the Fourier coefficients
        held in ``self.dataarray``.

        Caveats and Notes:
        - The computation is delegated to numpy.cov, so the cross-power is formed
         as X@XH (rather than XH@X).  X@XH is sometimes referred to as the Vozoff
         convention, whereas XH@X could be the Bendat & Piersol convention.
        - numpy.cov subtracts the mean of each variable before forming the cross terms.
        - The entire dataarray (i.e. the whole band of the spectrogram) is used.
        - NOTE(review): assumes each row of ``self.dataarray`` is one channel of
         the "variable" coordinate -- confirm against Spectrogram.dataarray.

        :param aweights: Passed straight through to numpy.cov.  Relative weights,
         typically large for observations considered "important" and smaller for
         those considered less "important".
        :type aweights: Optional[np.ndarray]
        :param bias: Passed straight through to numpy.cov; bias=True normalizes
         by N instead of (N-1).
        :type bias: bool

        :rtype: xr.DataArray
        :return: The covariance matrix, with dims ("channel_1", "channel_2") both
         labelled by the channel names.

        """
        fc_array = self.dataarray
        channel_labels = list(fc_array.coords["variable"].values)
        covariance = np.cov(fc_array, aweights=aweights, bias=bias)

        return xr.DataArray(
            covariance,
            coords={"channel_1": channel_labels, "channel_2": channel_labels},
            dims=["channel_1", "channel_2"],
        )

319 

320 

321# Weights vs masks 

322 

323 

def calculate_mask_from_feature(
    feature_series,
    threshold_obj,  # has lower/upper bound, can be -inf, inf
):
    """
    Flag the feature values that fall outside the threshold bounds.

    A value is flagged when it is strictly below ``threshold_obj.lower_bound``
    OR strictly above ``threshold_obj.upper_bound``.  The two conditions are
    combined with "or": they are mutually exclusive whenever
    lower_bound <= upper_bound, so combining them with "and" would never flag
    anything (including for one-sided thresholds that use -inf / inf).

    Parameters
    ----------
    feature_series
        Array-like of feature values supporting element-wise ``<`` and ``>``
        (e.g. a numpy array, pandas Series or xarray DataArray).
    threshold_obj
        Object exposing ``lower_bound`` and ``upper_bound`` attributes.

    Returns
    -------
    Boolean mask shaped like ``feature_series``; True where the value lies
    outside the closed interval [lower_bound, upper_bound].

    """
    below_lower = feature_series < threshold_obj.lower_bound
    above_upper = feature_series > threshold_obj.upper_bound
    return below_lower | above_upper

337 

338 

def calculate_weight_from_feature(
    feature_series,
    threshold_obj,  # has lower/upper bound, can be -inf, inf
):
    """
    Placeholder for deriving weights from a feature -- not implemented yet.

    The intended design is a weighting function based on the thresholds (and
    possibly other information, such as the distribution of the features) that
    is interpolated over the range of the feature values and then evaluated at
    the feature values.

    Parameters
    ----------
    feature_series
        The feature values (currently unused).
    threshold_obj
        Object carrying the lower/upper bounds (currently unused).

    Returns
    -------
    None, until the weighting is implemented.

    """
    return None

358 

359 

def merge_masks():
    """
    Placeholder -- not implemented yet; currently returns None.

    Intended to calculate a "final mask" that is loaded and applied to the
    data input to regression.
    """
    return None

365 

366 

def merge_weights():
    """
    Placeholder -- not implemented yet; currently returns None.

    NOTE(review): the previous description spoke of a "final mask" that is
    loaded and applied to the data input to regression; given the function
    name this presumably should be a final set of weights -- confirm.

    Returns
    -------
    None, until the merging is implemented.

    """
    return None

375 

376 

# TODO: add this method to tf-estimation right before robust regression.
def apply_masks_and_weights():
    """Placeholder -- not implemented yet; currently returns None."""
    return None

380 

381 

382def make_multistation_spectrogram( 

383 m: mth5.mth5.MTH5, 

384 fc_run_chunks: list, 

385 label_scheme: Optional[MultivariateLabelScheme] = MultivariateLabelScheme(), 

386 rtype: Optional[Literal["xrds"]] = None, 

387) -> Union[xr.Dataset, MultivariateDataset]: 

388 """ 

389 

390 See notes in mth5 issue #209. Takes a list of FCRunChunks and returns the largest contiguous 

391 block of multichannel FC data available. 

392 

393 |----------Station 1 ------------| 

394 |----------Station 2 ------------| 

395 |--------------------Station 3 ----------------------| 

396 

397 

398 |-------RETURNED------| 

399 

400 Handle additional runs in a separate call to this function and then concatenate time series afterwards. 

401 

402 Input must specify N (station-run-start-end-channel_list) tuples. 

403 If channel_list is not provided, get all channels. 

404 If start-end are not provided, read the whole run -- warn if runs are not all synchronous, and 

405 truncate all to max(starts), min(ends) after the start and end times are sorted out. 

406 

407 Station IDs must be unique. 

408 

409 :param m: The mth5 object to get the FCs from. 

410 :type m: mth5.mth5.MTH5 

411 :param fc_run_chunks: Each element of this describes a chunk of a run to load from stored FCs. 

412 :type fc_run_chunks: list 

413 :param label_scheme: Specifies how the channels are to be named in the multivariate xarray. 

414 :type label_scheme: Optional[MultivariateLabelScheme] 

415 :param rtype: Specifies whether to return an xarray or a MultivariateDataset. Currently only supports "xrds", 

416 otherwise will return MultivariateDataset. 

417 :type rtype: Optional[Literal["xrds"]] 

418 

419 :rtype: Union[xarray.Dataset, MultivariateDataset]: 

420 :return: The multivariate dataset, either as an xarray or as a MultivariateDataset 

421 

422 """ 

423 for i_fcrc, fcrc in enumerate(fc_run_chunks): 

424 station_obj = m.get_station(fcrc.station_id, fcrc.survey_id) 

425 station_fc_group = station_obj.fourier_coefficients_group 

426 try: 

427 run_fc_group = station_obj.fourier_coefficients_group.get_fc_group( 

428 fcrc.run_id 

429 ) 

430 except MTH5Error as e: 

431 error_msg = f"Failed to get fc group {fcrc.run_id}" 

432 logger.error(error_msg) 

433 msg = f"Available FC Groups for station {fcrc.station_id}: " 

434 msg = f"{msg} {station_fc_group.groups_list}" 

435 logger.error(msg) 

436 logger.error(f"Maybe try adding FCs for {fcrc.run_id}") 

437 raise e # MTH5Error(error_msg) 

438 

439 fc_dec_level = run_fc_group.get_decimation_level(fcrc.decimation_level_id) 

440 if fcrc.channels: 

441 channels = list(fcrc.channels) 

442 else: 

443 channels = None 

444 

445 fc_dec_level_xrds = fc_dec_level.to_xarray(channels=channels) 

446 # could create name mapper dict from run_fc_group.channel_summary here if we wanted to. 

447 

448 if fcrc.start: 

449 # TODO: Push slicing into the to_xarray() command so we only access what we need -- See issue #212 

450 cond = fc_dec_level_xrds.time >= fcrc.start_timestamp 

451 msg = f"trimming {sum(~cond.data)} samples to {fcrc.start} " 

452 logger.info(msg) 

453 fc_dec_level_xrds = fc_dec_level_xrds.where(cond) 

454 fc_dec_level_xrds = fc_dec_level_xrds.dropna(dim="time") 

455 

456 if fcrc.end: 

457 # TODO: Push slicing into the to_xarray() command so we only access what we need -- See issue #212 

458 cond = fc_dec_level_xrds.time <= fcrc.end_timestamp 

459 msg = f"trimming {sum(~cond.data)} samples to {fcrc.end} " 

460 logger.info(msg) 

461 fc_dec_level_xrds = fc_dec_level_xrds.where(cond) 

462 fc_dec_level_xrds = fc_dec_level_xrds.dropna(dim="time") 

463 

464 if label_scheme.id == "station_component": 

465 name_dict = { 

466 f"{x}": label_scheme.join((fcrc.station_id, x)) 

467 for x in fc_dec_level_xrds.data_vars 

468 } 

469 else: 

470 msg = f"Label Scheme elements {label_scheme.id} not implemented" 

471 raise NotImplementedError(msg) 

472 

473 if i_fcrc == 0: 

474 xrds = fc_dec_level_xrds.rename_vars(name_dict=name_dict) 

475 else: 

476 fc_dec_level_xrds = fc_dec_level_xrds.rename_vars(name_dict=name_dict) 

477 xrds = xrds.merge(fc_dec_level_xrds) 

478 

479 # Check that no nan came about as a result of the merge 

480 if bool(xrds.to_array().isnull().any()): 

481 msg = "Nan detected in multistation spectrogram" 

482 logger.warning(msg) 

483 

484 if rtype == "xrds": 

485 output = xrds 

486 else: 

487 output = MultivariateDataset(dataset=xrds, label_scheme=label_scheme) 

488 

489 return output