Coverage for C:\Users\peaco\OneDrive\Documents\GitHub\mth5\mth5\timeseries\spectre\spectrogram.py: 67%
132 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-27 20:09 -0800
1"""
Module contains a class that represents a spectrogram,
i.e. a 2D time series of Fourier coefficients, with one axis being time and the other frequency.
The datasets are xarray/dataframe and are fundamentally multivariate.
6"""
8from typing import List, Literal, Optional, Tuple, Union
10# Third-party imports
11import pandas as pd
12import xarray as xr
14# Standard library imports
15from loguru import logger
17# Local imports
18from mt_metadata.common.band import Band
19from mt_metadata.processing.aurora.frequency_bands import FrequencyBands
21from mth5.timeseries.xarray_helpers import covariance_xr, initialize_xrda_2d
class Spectrogram(object):
    """
    Class to contain methods for STFT objects.

    TODO: Add OLS Z-estimates -- actually, these are properties of cross powers, not direct properties of spectrograms.
    TODO: Add Sims/Vozoff Z-estimates -- actually, these are properties of cross powers as well.
    **Note** Coherence is, similarly, a property of cross powers.
    There are, in fact, very few features that we would derive from an unaveraged spectrogram. Pretty much
    everything except statistical moments comes from cross powers.

    Development Notes:
     - The spectrogram class is fundamental to MT Processing, and normally appears during the STFT operation.
     - The extract_band method returns another Spectrogram, having the same time axis as the parent
       object, but only a slice of the frequency range. Both of these have in common that their frequency axes
       are uniformly spaced, delta-f, where delta-f is dictated by the time series sample rate and the FFT window
       length.
     - There is a sibling spectral-time-series container that should be considered. Call it, for now, a
       FrequencyChunkedSpectrogram (or an AveragedSpectrogram). This is a container similar to a spectrogram, but
       the frequencies are not uniformly spaced (instead, they are often logarithmically spaced); they are made
       from one or more (possibly multivariate) spectrograms and a FrequencyBands object. The key difference
       is that a FrequencyChunkedSpectrogram object has a non-uniformly spaced frequency axis which was prescribed
       by a metadata object. Most features, as well as TFs, have a FrequencyChunkedSpectrogram representation,
       where final TFs are just time-averaged FrequencyChunkedSpectrograms.

    TODO: consider factoring a simpler class that does not make the uniform frequency axis assumption.
    Spectrogram would extend this class and add the _frequency_increment property (taken from the difference in
    the first two values of the frequency axis), and num_harmonics_in_band.
    """

    def __init__(self, dataset: Optional[xr.Dataset] = None):
        """
        Constructor.

        :param dataset: The underlying (multichannel) spectrogram data, or None for a
            dataless Spectrogram. Presumably each data variable is one channel with
            "time" and "frequency" dimensions -- TODO confirm against callers.
        :type dataset: Optional[xr.Dataset]
        """
        self._dataset = dataset
        # Lazily-computed caches. The dataset is exposed only through a read-only
        # property, so these are never invalidated.
        self._frequency_increment = None
        self._frequency_band = None

    def _lowest_frequency(self) -> float:
        """Return the smallest value on the frequency axis."""
        return self.dataset.frequency.min().item()

    def _highest_frequency(self) -> float:
        """Return the largest value on the frequency axis."""
        return self.dataset.frequency.max().item()

    def __str__(self) -> str:
        """
        Return a description of the frequency, time and channel coverage.

        Returns "Dataless Spectrogram" when no dataset is attached.
        """
        if self.dataset is None:
            return "Dataless Spectrogram"
        intro = "Spectrogram:"

        frequencies = self.dataset.frequency.data
        if len(frequencies) == 0:
            # An empty frequency axis is reachable (extract_band returns an empty
            # result for a non-overlapping band); indexing [0] / [-1] would raise.
            frequency_coverage = f"{self.dataset.sizes['frequency']} harmonics"
        else:
            frequency_coverage = (
                f"{self.dataset.sizes['frequency']} harmonics, {self.frequency_increment}Hz spaced \n"
                f" from {frequencies[0]} to {frequencies[-1]} Hz."
            )

        times = self.dataset.time.data
        time_coverage = f"\n{self.dataset.sizes['time']} Time observations"
        if len(times) > 0:
            time_coverage = f"{time_coverage} \nStart: {times[0]}"
            time_coverage = f"{time_coverage} \nEnd: {times[-1]}"

        channel_coverage = "\n".join(list(self.dataset.data_vars.keys()))
        channel_coverage = f"\nChannels present: \n{channel_coverage}"
        return "\n".join([intro, frequency_coverage, time_coverage, channel_coverage])

    def __repr__(self) -> str:
        return self.__str__()

    @property
    def dataset(self):
        """returns the underlying xarray data"""
        return self._dataset

    @property
    def dataarray(self):
        """returns the underlying data as a single xr.DataArray (via Dataset.to_array)"""
        return self._dataset.to_array()

    @property
    def time_axis(self):
        """returns the time axis of the underlying xarray"""
        return self.dataset.time

    @property
    def frequency_axis(self):
        """returns the frequency axis of the underlying xarray"""
        return self.dataset.frequency

    @property
    def frequency_band(self) -> Band:
        """
        Returns a Band spanning the min to max of the frequency axis.

        Assumes the frequency coverage is continuous. The Band is computed once and cached.
        """
        if self._frequency_band is None:
            band = Band(
                frequency_min=self.frequency_axis.min().item(),
                frequency_max=self.frequency_axis.max().item(),
            )
            self._frequency_band = band
        return self._frequency_band

    @property
    def frequency_increment(self):
        """
        Returns the "delta f" of the frequency axis.

        - Assumes the axis is uniformly sampled in the frequency domain, so the
          difference between the first two values is representative.
        - The value is computed once and cached.
        - If the axis has fewer than two values, the increment is not defined and
          the string "undefined" is returned instead of a number.
        """
        if self._frequency_increment is None:
            frequency_axis = self.dataset.frequency
            try:
                self._frequency_increment = (
                    frequency_axis.data[1] - frequency_axis.data[0]
                )
            except IndexError:
                msg = "frequency increment for spectrogram with frequency axis of length 1 is not defined"
                logger.debug(msg)
                self._frequency_increment = "undefined"
        return self._frequency_increment

    def num_harmonics_in_band(self, frequency_band: Band, epsilon: float = 1e-7) -> int:
        """
        Returns the number of harmonics of the underlying dataset within the frequency band.

        Parameters
        ----------
        frequency_band: mt_metadata.common.band.Band
            The band in which to count harmonics.
        epsilon: float
            Tolerance added to either side of the band; passed through to extract_band.

        Returns
        -------
        num_harmonics: int
            The number of harmonics in the underlying dataset within the given frequency band.
        """
        extracted_spectrogram = self.extract_band(frequency_band, epsilon=epsilon)
        num_harmonics = len(extracted_spectrogram.frequency_axis)
        return num_harmonics

    def extract_band(
        self,
        frequency_band: Band,
        channels: Optional[list] = None,
        epsilon: Optional[float] = None,
    ):
        """
        Returns another instance of Spectrogram, with the frequency axis reduced to the input band.

        Parameters
        ----------
        frequency_band: mt_metadata.common.band.Band
            The band whose harmonics are kept.
        channels: list, optional
            Channel names to keep. If None, all channels are kept.
        epsilon: float, optional
            Tolerance added to either side of the band when selecting harmonics.
            Defaults to half the frequency increment (the legacy default). If the
            frequency increment is undefined (fewer than two harmonics), it
            defaults to 1e-7 instead.

        Returns
        -------
        spectrogram: mth5.timeseries.spectre.spectrogram.Spectrogram
            A Spectrogram containing only the harmonics within the band.
        """
        if epsilon is None:
            frequency_increment = self.frequency_increment
            if isinstance(frequency_increment, str):
                # frequency_increment is the sentinel "undefined" when the axis has
                # fewer than two values; "undefined" / 2.0 would raise TypeError.
                epsilon = 1e-7
            else:
                # Half a harmonic spacing is the legacy default pad.
                epsilon = frequency_increment / 2.0

        extracted_band_dataset = extract_band(
            frequency_band, self.dataset, channels=channels, epsilon=epsilon
        )
        return Spectrogram(dataset=extracted_band_dataset)

    def cross_power_label(self, ch1: str, ch2: str, join_char: str = "_"):
        """joins channel names with join_char"""
        return f"{ch1}{join_char}{ch2}"

    def _validate_frequency_bands(
        self,
        frequency_bands: FrequencyBands,
        strict: bool = True,
    ):
        """
        Make sure that the frequency bands passed are relevant. If not, drop and warn.

        :param frequency_bands: A collection of bands
        :type frequency_bands: FrequencyBands
        :param strict: If true, a band must be contained in this spectrogram's
            frequency band to be valid; if false, any overlapping band is valid.
        :type strict: bool
        :return: A new FrequencyBands holding only the valid bands.
        :rtype: FrequencyBands
        """
        candidate_bands = list(frequency_bands.bands())
        if strict:
            valid_bands = [
                x for x in candidate_bands if self.frequency_band.contains(x)
            ]
        else:
            valid_bands = [
                x for x in candidate_bands if self.frequency_band.overlaps(x)
            ]

        num_dropped = len(candidate_bands) - len(valid_bands)
        if num_dropped:
            relation = "contained in" if strict else "overlapping"
            logger.warning(
                f"Dropping {num_dropped} of {len(candidate_bands)} frequency bands "
                f"not {relation} the spectrogram's frequency range."
            )

        lower_bounds = [x.lower_bound for x in valid_bands]
        upper_bounds = [x.upper_bound for x in valid_bands]
        valid_frequency_bands = FrequencyBands(
            pd.DataFrame(
                data={
                    "lower_bound": lower_bounds,
                    "upper_bound": upper_bounds,
                }
            )
        )
        return valid_frequency_bands

    def cross_powers(
        self,
        frequency_bands: FrequencyBands,
        channel_pairs: Optional[List[Tuple[str, str]]] = None,
    ):
        """
        Compute cross powers between channel pairs for given frequency bands.

        TODO: Add handling for the case when a band in frequency_bands is not contained
        in self.frequencies. Currently such bands are dropped by _validate_frequency_bands,
        but they still appear on the output frequency axis. Their values are left as
        initialized by initialize_xrda_2d.

        Parameters
        ----------
        frequency_bands : FrequencyBands
            The frequency bands to compute cross powers for. Each element of this iterable
            gives the lower and upper bounds of the cross-power calculation bands.
            These may become objects with information about tapers as well.
        channel_pairs : list of tuples, optional
            List of channel pairs to compute cross powers for.
            If None, all pairs of the dataset's channels are used. Pairs are drawn with
            replacement, so auto-powers such as ("ex", "ex") are included.

        Returns
        -------
        xr.DataArray
            Presumably a DataArray (built by initialize_xrda_2d) with a "variable"
            dimension labelled by channel pair (e.g. 'ex_hy'), plus "frequency"
            (band centers) and "time" coordinates. TODO: confirm the dimension order
            against initialize_xrda_2d.
            Each value is the mean over the band's harmonics of ch1 * conj(ch2).
        """
        from itertools import combinations_with_replacement

        valid_frequency_bands = self._validate_frequency_bands(frequency_bands)

        # If no channel pairs specified, use all possible pairs.
        if channel_pairs is None:
            channels = list(self.dataset.data_vars.keys())
            channel_pairs = combinations_with_replacement(channels, 2)
        # Materialize once: the pairs are iterated for labels AND inside the band loop,
        # which would silently do nothing if a one-shot iterator were passed.
        channel_pairs = list(channel_pairs)

        var_names = [self.cross_power_label(ch1, ch2) for ch1, ch2 in channel_pairs]

        # Initialize a single multi-channel 2D xarray.
        xpower_array = initialize_xrda_2d(
            var_names,
            coords={
                "frequency": frequency_bands.band_centers(),
                "time": self.dataset.time.values,
            },
            dtype=complex,
        )

        for band in valid_frequency_bands.bands():
            band_data = self.extract_band(band).dataset
            for label, (ch1, ch2) in zip(var_names, channel_pairs):
                # Always compute as ch1 * conj(ch2), averaged over the band's harmonics.
                xpower = (band_data[ch1] * band_data[ch2].conj()).mean(dim="frequency")
                xpower_array.loc[
                    dict(
                        frequency=band.center_frequency,
                        variable=label,
                        time=slice(None),
                    )
                ] = xpower

        return xpower_array

    def covariance_matrix(
        self, band_data: Optional["Spectrogram"] = None, method: str = "numpy_cov"
    ) -> xr.DataArray:
        """
        TODO: Add tests for this WIP (work-in-progress) method.
        Compute full covariance matrix for spectrogram data.

        For complex-valued data, the result is intended to be a Hermitian matrix where:
        - diagonal elements are real-valued variances
        - off-diagonal element [i,j] is E[ch_i * conj(ch_j)]
        - off-diagonal element [j,i] is the complex conjugate of [i,j]
        (The exact convention is determined by covariance_xr.)

        Parameters
        ----------
        band_data : Spectrogram, optional
            If provided, compute covariance for this data.
            If None, use the full spectrogram.
        method : str
            Computation method. Currently only supports 'numpy_cov'.

        Returns
        -------
        xr.DataArray
            Covariance matrix with channel labeling, as returned by covariance_xr.

        Raises
        ------
        ValueError
            If method is not 'numpy_cov'.
        """
        # Validate before doing the (potentially large) flatten/stack work.
        if method != "numpy_cov":
            raise ValueError(f"Unknown method: {method}")

        data = self if band_data is None else band_data
        flat_data = data.flatten(chunk_by="time")
        # Convert to DataArray for covariance_xr.
        stacked = flat_data.to_array(dim="variable")
        return covariance_xr(stacked)

    def _get_all_channel_pairs(self) -> List[Tuple[str, str]]:
        """Get all unique channel pairs (upper triangle, no self-pairs), in channel order."""
        from itertools import combinations

        channels = list(self.dataset.data_vars.keys())
        return list(combinations(channels, 2))

    def flatten(self, chunk_by: Literal["time", "frequency"] = "time") -> xr.Dataset:
        """
        Reshape the 2D spectrogram into a 1D flattened xarray (time-chunked by default).

        Parameters
        ----------
        chunk_by: Literal["time", "frequency"]
            Reshaping the 2D spectrogram can be done two ways (basically "row-major"
            or "column-major"). Either time is the outer index and frequency varies
            fastest ("time"), or frequency is the outer index and time varies
            fastest ("frequency").

        Returns
        -------
        xarray.Dataset : The dataset stacked along a new "observation" dimension.

        Raises
        ------
        ValueError
            If chunk_by is not "time" or "frequency".

        Development Notes:
            The flattening used in tf calculation by default is opposite to here:
            dataset.stack(observation=("frequency", "time"))
            However, for feature extraction, it may make sense to swap the order:
            xrds = band_spectrogram.dataset.stack(observation=("time", "frequency"))
            This is like chunking into time windows, and allows individual features to be computed on
            each time window -- if desired. The time series would still need to be split, though;
            splitting by time would be a reshape by (last_freq_index - first_freq_index).
            Using pure xarray this may not matter, but if we drop down into numpy it could be useful.
        """
        if chunk_by == "time":
            observation = ("time", "frequency")
        elif chunk_by == "frequency":
            observation = ("frequency", "time")
        else:
            msg = f"Invalid argument chunk_by={chunk_by}, must be one of ['time', 'frequency']"
            logger.error(msg)
            raise ValueError(msg)

        return self.dataset.stack(observation=observation)
def extract_band(
    frequency_band: Band,
    fft_obj: Union[xr.Dataset, xr.DataArray],
    channels: Optional[list] = None,
    epsilon: float = 1e-7,
) -> Union[xr.Dataset, xr.DataArray]:
    """
    Extracts the harmonics of a spectrogram that fall within a frequency band.

    TODO: Update variable names.

    Development Notes:
    Note #1: The base dataset object should be an xr.DataArray (not an xr.Dataset).
    where(..., drop=True) does not play nice with h5py and Dataset; it results in a type error:
        File "stringsource", line 2, in h5py.h5r.Reference.__reduce_cython__
        TypeError: no default __reduce__ due to non-trivial __cinit__
    However, it works OK with DataArray. So on TypeError the Dataset is converted
    to a DataArray, filtered, and converted back to a Dataset.

    Parameters
    ----------
    frequency_band: mt_metadata.common.band.Band
        Specifies the interval corresponding to a frequency band.
    fft_obj: Union[xr.Dataset, xr.DataArray]
        Short-time-Fourier-transformed data. Can be multichannel.
    channels: list, optional
        Channel names to extract. If None or empty, all channels are kept.
        For a DataArray input, channels are selected along its "variable"
        dimension when it has one.
    epsilon: float
        Use this when you are worried about missing a frequency due to
        round-off error. This is in general not needed if we use a df/2 pad
        around true harmonics.

    Returns
    -------
    extracted_band: Union[xr.Dataset, xr.DataArray]
        The data at the frequencies within the band (plus/minus epsilon). A Dataset
        input yields a Dataset and a DataArray input yields a DataArray. The
        result may have an empty frequency axis, in which case a warning is logged.
    """
    cond1 = fft_obj.frequency >= frequency_band.lower_bound - epsilon
    cond2 = fft_obj.frequency <= frequency_band.upper_bound + epsilon
    try:
        extracted_band = fft_obj.where(cond1 & cond2, drop=True)
    except TypeError:  # see Note #1
        tmp = fft_obj.to_array()
        extracted_band = tmp.where(cond1 & cond2, drop=True)
        extracted_band = extracted_band.to_dataset("variable")

    if channels:
        if (
            isinstance(extracted_band, xr.DataArray)
            and "variable" in extracted_band.dims
        ):
            # A stacked DataArray (e.g. from Dataset.to_array) holds channels along
            # "variable". Plain [] indexing on a DataArray is positional and fails
            # for channel-name strings, so select by label instead.
            extracted_band = extracted_band.sel(variable=channels)
        else:
            extracted_band = extracted_band[channels]

    if len(extracted_band.frequency) == 0:
        msg = (
            f"Frequency band {frequency_band} does not overlap with the frequencies "
            f"of the input dataset. Frequencies in dataset are: {fft_obj.frequency.values}. "
            "Skipping band extraction. Consider reforming the bands."
        )
        logger.warning(msg)
    return extracted_band