Coverage for C: \ Users \ peaco \ OneDrive \ Documents \ GitHub \ mth5 \ mth5 \ timeseries \ spectre \ multiple_station.py: 67%
150 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-10 00:01 -0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-10 00:01 -0800
1"""
2 Work In progress
4 This module is concerned with working with Fourier coefficient data
6 TODO:
7 2. Give MultivariateDataset a covariance() method
9 Tools include prototypes for
10 - extracting portions of an FC Run Time Series
11 - merging multiple stations runs together into an xarray
12 - relabelling channels to avoid namespace clashes for multi-station data
14"""
16from dataclasses import dataclass
17from typing import List, Literal, Optional, Tuple, Union
19import numpy as np
20import pandas as pd
21import xarray as xr
22from loguru import logger
24import mth5.mth5
25from mth5.timeseries.spectre.spectrogram import Spectrogram
26from mth5.utils.exceptions import MTH5Error
@dataclass
class FCRunChunk:
    """
    Metadata required to specify a chunk of a time series of Fourier coefficients.

    This may move to mt_metadata -- for now a dataclass is used as a prototype.

    :param survey_id: Survey identifier, passed together with ``station_id`` when
        the station is looked up in ``make_multistation_spectrogram``.
    :type survey_id: str
    :param station_id: Station identifier; also used as the station part of the
        multivariate channel labels.
    :type station_id: str
    :param run_id: Identifier of the FC group (run) to read.
    :type run_id: str
    :param decimation_level_id: Identifier of the decimation level to read.
    :type decimation_level_id: str
    :param start: Start time in any form accepted by ``pd.Timestamp``. An empty
        string means that no start trimming is requested.
    :type start: str
    :param end: End time in any form accepted by ``pd.Timestamp``. An empty
        string means that no end trimming is requested.
    :type end: str
    :param channels: Channel names to load. An empty tuple means that all
        available channels are requested.
    :type channels: Tuple[str, ...]
    """

    survey_id: str = "none"
    station_id: str = ""
    run_id: str = ""
    decimation_level_id: str = "0"
    start: str = ""
    end: str = ""
    # Variable-length tuple of channel names (the default is the empty tuple).
    channels: Tuple[str, ...] = ()

    @property
    def start_timestamp(self) -> pd.Timestamp:
        """``self.start`` parsed as a pandas Timestamp."""
        return pd.Timestamp(self.start)

    @property
    def end_timestamp(self) -> pd.Timestamp:
        """``self.end`` parsed as a pandas Timestamp."""
        return pd.Timestamp(self.end)

    @property
    def duration(self) -> pd.Timedelta:
        """Elapsed time between ``start_timestamp`` and ``end_timestamp``."""
        # The difference of two Timestamps is a Timedelta, not a Timestamp.
        return self.end_timestamp - self.start_timestamp
@dataclass
class MultivariateLabelScheme:
    """
    Describes how the channels of a multivariate (MV) dataset are labelled.

    This is a place holder to manage possible future complexity. By default an
    MV channel label is f"{station}_{component}", and the join character is
    recorded explicitly so that it can be changed (for example to "__") should
    station names ever contain underscores.

    TODO: Consider rename default to ("station", "data_var") instead of ("station", "component")

    Parameters
    ----------
    :type label_elements: tuple
    :param label_elements: Names of the pieces of information that are
        concatenated, in order, into an MV channel label.
    :type join_char: str
    :param join_char: The string placed between the label elements.

    """

    label_elements: tuple = ("station", "component")
    join_char: str = "_"

    @property
    def id(self) -> str:
        """The label element names joined by ``join_char``, e.g. "station_component"."""
        return self.join(self.label_elements)

    def join(self, elements: Union[list, tuple]) -> str:
        """
        Concatenate label elements into a single string.

        :type elements: tuple
        :param elements: The values to join; by default these are (station, component).

        :return: The name of the channel (in a multiple-station context).
        :rtype: str

        """
        separator = self.join_char
        return separator.join(elements)

    def split(self, mv_channel_name) -> dict:
        """
        Break a multi-station channel name into its parts (the reverse of ``join``).

        :param mv_channel_name: a multivariate channel name string
        :type mv_channel_name: str
        :return: The parts of the channel name, keyed by ``self.label_elements``.
        :rtype: dict
        :raises ValueError: If the number of parts obtained by splitting on
            ``join_char`` differs from the number of label elements.

        """
        parts = mv_channel_name.split(self.join_char)
        n_parts = len(parts)
        n_elements = len(self.label_elements)

        if n_parts == n_elements:
            return dict(zip(self.label_elements, parts))

        # The detailed mismatch is logged; the raised error carries only the counts.
        logger.error(f"Incompatable map {parts} and {self.label_elements}")
        raise ValueError(f"cannot map {n_parts} to {n_elements}")
class MultivariateDataset(Spectrogram):
    """
    Container for a multivariate spectral dataset.

    The xarray object is the main underlying item; this class adds helpers that,
    for example, list the stations present or list the channels belonging to a
    given station.

    This is intended to be used as a multivariate spectral dataset at one
    frequency band.

    TODO: Rename this class to MultivariateSpectrogram.

    """

    def __init__(
        self,
        dataset: xr.Dataset,
        label_scheme: Optional[MultivariateLabelScheme] = None,
    ):
        """
        :param dataset: The multivariate dataset, handed on to ``Spectrogram``.
        :type dataset: xr.Dataset
        :param label_scheme: How channel labels are built. When None, a default
            scheme is created the first time ``label_scheme`` is accessed.
        :type label_scheme: Optional[MultivariateLabelScheme]
        """
        super().__init__(dataset=dataset)
        self._label_scheme = label_scheme

        # Caches, filled in on first use by the accessors of the same names.
        self._channels = None
        self._stations = None
        self._station_channels = None

    @property
    def label_scheme(self) -> MultivariateLabelScheme:
        """The label scheme; a default one is created (with a warning) if none was given."""
        if self._label_scheme is not None:
            return self._label_scheme

        logger.warning(
            f"No label scheme found for {self.__class__} -- setting to default"
        )
        self._label_scheme = MultivariateLabelScheme()
        return self._label_scheme

    @property
    def channels(self) -> list:
        """
        The channel names, read from the "variable" coordinate of ``self.dataarray``.
        The list is computed once and then cached.
        """
        if self._channels is None:
            variable_coord = self.dataarray.coords["variable"]
            self._channels = list(variable_coord.values)
        return self._channels

    @property
    def num_channels(self) -> int:
        """The total number of channels in the dataset."""
        return len(self.channels)

    @property
    def stations(self) -> List[str]:
        """
        The station names parsed out of the channel names.

        Returns the unique stations in order of first appearance. Only the
        "station_component" label scheme is supported; any other scheme raises
        NotImplementedError.
        """
        if self._stations is not None:
            return self._stations

        scheme = self.label_scheme
        if scheme.id != "station_component":
            msg = f"No rule for parsting station names from label scheme {scheme.id}"
            raise NotImplementedError(msg)

        station_of_channel = (scheme.split(name)["station"] for name in self.channels)
        # dict keys keep insertion order, so this de-duplicates without reordering.
        self._stations = list(dict.fromkeys(station_of_channel))
        return self._stations

    def station_channels(
        self,
        station: str,
    ) -> List[str]:
        """
        Return the multivariate channel names associated with one station.

        The lookup table ``self._station_channels`` (station id -> list of
        channel names) is built for all stations on the first call and reused
        afterwards.

        :param station: The name of the station.
        :type station: str

        :rtype: List[str]
        :returns: list of channel names for the input station.

        """
        if self._station_channels is None:
            self._station_channels = {
                station_id: self._get_station_channel_names(
                    station_id,
                    multivariate_labels=True,
                )
                for station_id in self.stations
            }

        return self._station_channels[station]

    def _get_station_channel_names(
        self, station: str, multivariate_labels: bool = True
    ) -> List[str]:
        """
        Collect all channel names in the multivariate array that belong to a station.

        A channel belongs to the station when the text before the first join
        character equals ``station``.

        :param station: The name of the station.
        :type station: str
        :param multivariate_labels: When True (the default) the full multivariate
            names are returned, e.g. "mt1_ex", "mt1_ey", "mt1_hx". When False the
            second label element is returned instead, e.g. "ex", "ey", "hx".
        :type multivariate_labels: bool

        :rtype: List[str]
        :returns: Channel names for the input station.

        """
        matching = []
        for name in self.channels:
            if name.split(self.label_scheme.join_char)[0] == station:
                matching.append(name)

        if multivariate_labels:
            return matching

        return [name.split(self.label_scheme.join_char)[1] for name in matching]

    def archive_cross_powers(
        self,
        tf_station: str,
        with_fcs: bool = True,
    ):
        """
        Placeholder for archiving cross powers -- not implemented yet; the body
        is empty and the method returns None.

        tf_station: str
          Tells under which station the output of this function should be stored.
          TODO: Consider moving this to another function which performs archiving in future.

        with_fcs: bool
          If True, the features are packed into the same hdf5-group as the FCs,
          as its own dataset.
          If False: the features are packed into the hdf5 features-group.

        """

    # TODO: Replace with Spectrogram's covariance_matrix
    def cross_power(
        self, aweights: Optional[np.ndarray] = None, bias: Optional[bool] = True
    ) -> xr.DataArray:
        """
        Calculate the cross-power (covariance) matrix of ``self.dataarray``.

        The result is a two-dimensional array with dimensions "channel_1" and
        "channel_2", both labelled by the channel names.

        Caveats and Notes:
        - This method calls numpy.cov, which means that the cross-power is computed as X@XH (rather than
        XH@X). Sometimes X*XH is referred to as the Vozoff convention, whereas XH*X could be the
        Bendat & Piersol convention.
        - np.cov subtracts the mean before computing the cross terms.
        - np.cov treats each row of its input as one variable; this assumes "variable" is the
        leading dimension of ``self.dataarray`` -- TODO confirm.
        - This method will use the entire band of the spectrogram.

        :param aweights: This is a "passthrough" parameter to numpy.cov These relative weights are typically large for
         observations considered "important" and smaller for observations considered less "important". If ``ddof=0``
         the array of weights can be used to assign probabilities to observation vectors.
        :type aweights: Optional[np.ndarray]
        :param bias: bias=True normalizes by N instead of (N-1).
        :type bias: bool

        :rtype: xr.DataArray
        :return: The covariance matrix of the data in xarray form.

        """
        fc_array = self.dataarray
        channel_labels = list(fc_array.coords["variable"].values)
        covariance = np.cov(fc_array, aweights=aweights, bias=bias)

        return xr.DataArray(
            covariance,
            dims=["channel_1", "channel_2"],
            coords={"channel_1": channel_labels, "channel_2": channel_labels},
        )
321# Weights vs masks
def calculate_mask_from_feature(
    feature_series,
    threshold_obj,  # has lower/upper bound, can be -inf, inf
):
    """
    Flag the feature values that lie outside the threshold bounds.

    A value is flagged when it is strictly below ``threshold_obj.lower_bound``
    OR strictly above ``threshold_obj.upper_bound``. Values equal to a bound
    are not flagged.

    NOTE(review): True marks out-of-bounds values here -- confirm that
    downstream consumers interpret True as "reject".

    :param feature_series: Feature values supporting elementwise ``<`` and ``>``
        comparison (for example a numpy array or a pandas Series).
    :param threshold_obj: Object with ``lower_bound`` and ``upper_bound``
        attributes; these can be -inf / inf to disable one side.

    Returns
    -------
    Boolean mask with the same shape as ``feature_series``.

    """
    below_lower = feature_series < threshold_obj.lower_bound
    above_upper = feature_series > threshold_obj.upper_bound
    # A value cannot be below the lower bound and above the upper bound at the
    # same time, so the two conditions must be combined with OR, not AND.
    return below_lower | above_upper
def calculate_weight_from_feature(
    feature_series,
    threshold_obj,  # has lower/upper bound, can be -inf, inf
):
    """
    Placeholder for computing weights from a feature -- not implemented yet.

    The intent is to calculate a weighting function based on the thresholds
    and possibly some other info, such as the distribution of the features.
    The weight function would be interpolated over the range of the feature
    values and then evaluated at the feature values.

    Parameters
    ----------
    feature_series
    threshold_obj

    Returns
    -------
    None, until this function is implemented.

    """
    return None
def merge_masks():
    """
    Placeholder -- not implemented yet.

    Intended to calculate a "final mask" that is loaded and applied to the
    data input to regression. Currently returns None.
    """
    return None
def merge_weights():
    """
    Placeholder -- not implemented yet.

    The original note describes calculating a "final mask" that is loaded and
    applied to the data input to regression; presumably final *weights* are
    meant here -- TODO confirm.

    Returns
    -------
    None, until this function is implemented.

    """
    return None
377# TODO: add this method to tf-estimation right before robust regression.
def apply_masks_and_weights():
    """Placeholder -- not implemented yet; takes no arguments and returns None."""
    return None
def _trim_to_chunk_window(fc_xrds: xr.Dataset, fcrc: FCRunChunk) -> xr.Dataset:
    """Drop the time samples of ``fc_xrds`` that lie outside ``fcrc``'s start/end."""
    # TODO: Push slicing into the to_xarray() command so we only access what we need -- See issue #212
    # NOTE: where() blanks the rejected samples and dropna() then removes them,
    # which presumably also removes samples that already contained NaN.
    if fcrc.start:
        keep = fc_xrds.time >= fcrc.start_timestamp
        logger.info(f"trimming {int((~keep.data).sum())} samples to {fcrc.start} ")
        fc_xrds = fc_xrds.where(keep).dropna(dim="time")

    if fcrc.end:
        keep = fc_xrds.time <= fcrc.end_timestamp
        logger.info(f"trimming {int((~keep.data).sum())} samples to {fcrc.end} ")
        fc_xrds = fc_xrds.where(keep).dropna(dim="time")

    return fc_xrds


def make_multistation_spectrogram(
    m: mth5.mth5.MTH5,
    fc_run_chunks: list,
    label_scheme: Optional[MultivariateLabelScheme] = None,
    rtype: Optional[Literal["xrds"]] = None,
) -> Union[xr.Dataset, MultivariateDataset]:
    """

    See notes in mth5 issue #209.  Takes a list of FCRunChunks and returns a single
    multichannel block of FC data, with the goal of obtaining the largest
    contiguous block available.

    |----------Station 1 ------------|
            |----------Station 2 ------------|
    |--------------------Station 3 ----------------------|


            |-------RETURNED------|

    Handle additional runs in a separate call to this function and then concatenate time series afterwards.

    Input must specify N (station-run-start-end-channel_list) tuples.
    If channel_list is not provided, get all channels.
    If start-end are not provided, the whole run is read.

    NOTE: no automatic truncation to the common time window is performed here.
    Each chunk is trimmed only to its own start/end, the chunks are merged, and
    a warning is logged if the merged result contains NaN.

    Station IDs must be unique (this is not checked).

    :param m: The mth5 object to get the FCs from.
    :type m: mth5.mth5.MTH5
    :param fc_run_chunks: Each element of this describes a chunk of a run to load from stored FCs.
    :type fc_run_chunks: list
    :param label_scheme: Specifies how the channels are to be named in the multivariate xarray.
     When None, a fresh default MultivariateLabelScheme is used.
    :type label_scheme: Optional[MultivariateLabelScheme]
    :param rtype: Specifies whether to return an xarray or a MultivariateDataset. Currently only supports "xrds",
     otherwise will return MultivariateDataset.
    :type rtype: Optional[Literal["xrds"]]

    :rtype: Union[xarray.Dataset, MultivariateDataset]:
    :return: The multivariate dataset, either as an xarray or as a MultivariateDataset
    :raises NotImplementedError: If the label scheme is not "station_component".
    :raises MTH5Error: If the FC group for a chunk's run cannot be found.
    :raises ValueError: If ``fc_run_chunks`` is empty.

    """
    # Create the default per call so that no instance is shared between calls
    # or between the datasets returned by separate calls.
    if label_scheme is None:
        label_scheme = MultivariateLabelScheme()

    # The scheme does not depend on the chunk, so validate it once, before any I/O.
    if label_scheme.id != "station_component":
        msg = f"Label Scheme elements {label_scheme.id} not implemented"
        raise NotImplementedError(msg)

    xrds = None
    for fcrc in fc_run_chunks:
        station_obj = m.get_station(fcrc.station_id, fcrc.survey_id)
        station_fc_group = station_obj.fourier_coefficients_group
        try:
            run_fc_group = station_fc_group.get_fc_group(fcrc.run_id)
        except MTH5Error:
            logger.error(f"Failed to get fc group {fcrc.run_id}")
            msg = f"Available FC Groups for station {fcrc.station_id}: "
            msg = f"{msg} {station_fc_group.groups_list}"
            logger.error(msg)
            logger.error(f"Maybe try adding FCs for {fcrc.run_id}")
            raise

        fc_dec_level = run_fc_group.get_decimation_level(fcrc.decimation_level_id)

        # An empty channel tuple means "all channels", requested by passing None.
        channels = list(fcrc.channels) if fcrc.channels else None
        fc_dec_level_xrds = fc_dec_level.to_xarray(channels=channels)
        # could create name mapper dict from run_fc_group.channel_summary here if we wanted to.

        fc_dec_level_xrds = _trim_to_chunk_window(fc_dec_level_xrds, fcrc)

        # Prefix every data variable with its station id to avoid name clashes.
        name_dict = {
            str(x): label_scheme.join((fcrc.station_id, x))
            for x in fc_dec_level_xrds.data_vars
        }
        fc_dec_level_xrds = fc_dec_level_xrds.rename_vars(name_dict=name_dict)

        xrds = fc_dec_level_xrds if xrds is None else xrds.merge(fc_dec_level_xrds)

    if xrds is None:
        msg = "fc_run_chunks is empty -- at least one FCRunChunk is required"
        raise ValueError(msg)

    # Check that no nan came about as a result of the merge
    if bool(xrds.to_array().isnull().any()):
        msg = "Nan detected in multistation spectrogram"
        logger.warning(msg)

    if rtype == "xrds":
        return xrds
    return MultivariateDataset(dataset=xrds, label_scheme=label_scheme)