Coverage for C: \ Users \ peaco \ OneDrive \ Documents \ GitHub \ mth5 \ mth5 \ timeseries \ spectre \ spectrogram.py: 67%

132 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-27 20:09 -0800

1""" 

2Module contains a class that represents a spectrogram. 

3i.e. A 2D time series of Fourier coefficients with axes time and the other frequency. 

4The datasets are xarray/dataframe and are fundamentally multivariate. 

5 

6""" 

7 

8from typing import List, Literal, Optional, Tuple, Union 

9 

10# Third-party imports 

11import pandas as pd 

12import xarray as xr 

13 

14# Standard library imports 

15from loguru import logger 

16 

17# Local imports 

18from mt_metadata.common.band import Band 

19from mt_metadata.processing.aurora.frequency_bands import FrequencyBands 

20 

21from mth5.timeseries.xarray_helpers import covariance_xr, initialize_xrda_2d 

22 

23 

class Spectrogram(object):
    """
    Class to contain methods for STFT objects.

    The wrapped dataset is accessed through its ``time`` and ``frequency``
    coordinates, and each of its data variables is treated as one channel.

    TODO: Add OLS Z-estimates -- actually, these are properties of cross powers, not direct
     properties of spectrograms.
    TODO: Add Sims/Vozoff Z-estimates -- actually, these are properties of cross powers as well.
    **Note** Coherence is, similarly, a property of cross powers.
    There are, in fact, very few features that we would derive from an unaveraged spectrogram.
    Pretty much everything except statistical moments comes from cross powers.

    Development Notes:
    - The spectrogram class is fundamental to MT Processing, and normally appears during the
      STFT operation.
    - The extract_band method returns another Spectrogram, having the same time axis as the
      parent object, but only a slice of the frequency range.  Both of these have in common that
      their frequency axes are uniformly spaced, delta-f, where delta-f is dictated by the time
      series sample rate and the FFT window length.
    - There is a sibling spectral-time-series container that should be considered.  Call it, for
      now, a FrequencyChunkedSpectrogram (or an AveragedSpectrogram).  This is a container similar
      to Spectrogram, but the frequencies are not uniformly spaced (instead, often logarithmically
      spaced); they are made from one or more (possibly multivariate) spectrograms and a
      FrequencyBands object.  The key difference is that a FrequencyChunkedSpectrogram has a
      non-uniformly spaced frequency axis which was prescribed by a metadata object.  Most
      features, as well as TFs, have a FrequencyChunkedSpectrogram representation, where final
      TFs are just time-averaged FrequencyChunkedSpectrograms.

    TODO: consider factoring a simpler class that does not make the uniform frequency axis
     assumption.  Spectrogram would extend this class and add the _frequency_increment property
     (taken from the difference in the first two values of the frequency axis), and
     num_harmonics_in_band.

    """

    def __init__(self, dataset: Optional[xr.Dataset] = None):
        """
        Constructor.

        Parameters
        ----------
        dataset : xr.Dataset, optional
            The spectrogram data.  May be None, in which case the object is "dataless".

        """
        self._dataset = dataset
        # Lazily computed caches, filled on first access of the matching property.
        self._frequency_increment = None
        self._frequency_band = None

    def _lowest_frequency(self):  # -> float:
        """Placeholder, not implemented: currently always returns None."""
        pass  # return self.dataset.frequency.min

    def _highest_frequency(self):  # -> float:
        """Placeholder, not implemented: currently always returns None."""
        pass  # return self.dataset.frequency.max

    def __str__(self) -> str:
        """Returns a description of the frequency, time and channel coverage."""
        if self.dataset is None:
            return "Dataless Spectrogram"
        intro = "Spectrogram:"
        frequency_coverage = (
            f"{self.dataset.sizes['frequency']} harmonics, {self.frequency_increment}Hz spaced \n"
            f" from {self.dataset.frequency.data[0]} to {self.dataset.frequency.data[-1]} Hz."
        )
        time_coverage = f"\n{self.dataset.sizes['time']} Time observations"
        time_coverage = f"{time_coverage} \nStart: {self.dataset.time.data[0]}"
        time_coverage = f"{time_coverage} \nEnd: {self.dataset.time.data[-1]}"

        channel_coverage = "\n".join(self.dataset.data_vars.keys())
        channel_coverage = f"\nChannels present: \n{channel_coverage}"
        return "\n".join((intro, frequency_coverage, time_coverage, channel_coverage))

    def __repr__(self) -> str:
        """Same text as ``__str__``."""
        return self.__str__()

    @property
    def dataset(self):
        """returns the underlying xarray data"""
        return self._dataset

    @property
    def dataarray(self):
        """returns the underlying xarray data converted with ``to_array()``"""
        return self._dataset.to_array()

    @property
    def time_axis(self):
        """returns the time axis of the underlying xarray"""
        return self.dataset.time

    @property
    def frequency_axis(self):
        """returns the frequency axis of the underlying xarray"""
        return self.dataset.frequency

    @property
    def frequency_band(self) -> Band:
        """
        returns a frequency band object representing the spectrogram's band (assumes continuous)

        The band spans the minimum to the maximum of the frequency axis.  It is built on
        first access and cached.
        """
        if self._frequency_band is None:
            self._frequency_band = Band(
                frequency_min=self.frequency_axis.min().item(),
                frequency_max=self.frequency_axis.max().item(),
            )
        return self._frequency_band

    @property
    def frequency_increment(self):
        """
        returns the "delta f" of the frequency axis

        - assumes uniformly sampled in frequency domain: the value is the difference
          between the first two frequencies.
        - computed on first access and cached.
        - when the frequency axis is too short to take that difference, the string
          "undefined" is cached and returned instead of a number.
        """
        if self._frequency_increment is None:
            frequency_axis = self.dataset.frequency
            try:
                self._frequency_increment = (
                    frequency_axis.data[1] - frequency_axis.data[0]
                )
            except IndexError:
                msg = "frequency increment for spectrogram with frequency axis of length 1 is not defined"
                logger.debug(msg)
                self._frequency_increment = "undefined"
        return self._frequency_increment

    def num_harmonics_in_band(self, frequency_band: Band, epsilon: float = 1e-7) -> int:
        """
        Returns the number of harmonics within the frequency band in the underlying dataset

        Parameters
        ----------
        frequency_band : Band
            The band whose harmonics are counted.
        epsilon : float
            Tolerance passed through to ``extract_band`` when selecting frequencies.
            Note that the default here is a fixed small value, not the half-increment
            that ``extract_band`` uses when it is given no epsilon.

        Returns
        -------
        num_harmonics: int
            The number of harmonics in the underlying dataset within the given frequency band.

        """
        extracted_spectrogram = self.extract_band(frequency_band, epsilon=epsilon)
        return len(extracted_spectrogram.frequency_axis)

    def extract_band(
        self,
        frequency_band: Band,
        channels: Optional[list] = None,
        epsilon: Optional[float] = None,
    ):
        """
        Returns another instance of Spectrogram, with the frequency axis reduced to the input band.

        Parameters
        ----------
        frequency_band : Band
            The band to extract.
        channels : list, optional
            Channel names to keep.  If None (or empty) all channels are kept.
        epsilon : float, optional
            Tolerance applied to both band edges.  If None, half of the frequency
            increment is used (the legacy default).  If the increment is not defined
            (single-harmonic spectrogram), the module-level ``extract_band`` is called
            without an epsilon so that its own default tolerance applies.

        Returns
        -------
        spectrogram: aurora.time_series.spectrogram.Spectrogram
            Returns a Spectrogram object with only the extracted band for a dataset

        """
        band_kwargs = {"channels": channels}
        if epsilon is None:
            increment = self.frequency_increment
            # frequency_increment is the string "undefined" when it cannot be computed;
            # dividing that by 2.0 used to raise a TypeError here.
            if not isinstance(increment, str):
                epsilon = increment / 2.0
        if epsilon is not None:
            band_kwargs["epsilon"] = epsilon

        # Inside a method body the bare name resolves to the module-level function,
        # not to this method.
        extracted_band_dataset = extract_band(
            frequency_band, self.dataset, **band_kwargs
        )
        # Note: NaN values along the frequency dimension are not dropped here
        # (a dropna(dim='frequency', how='any') call was previously disabled).
        return Spectrogram(dataset=extracted_band_dataset)

    def cross_power_label(self, ch1: str, ch2: str, join_char: str = "_"):
        """joins channel names with join_char"""
        return f"{ch1}{join_char}{ch2}"

    def _validate_frequency_bands(
        self,
        frequency_bands: FrequencyBands,
        strict: bool = True,
    ):
        """
        Make sure that the frequency bands passed are relevant.  If not, drop and warn.

        :param frequency_bands: A collection of bands
        :type frequency_bands: FrequencyBands
        :param strict: If true, band must be contained to be valid, if false, any overlapping band is valid.
        :type strict: bool
        :return: A new FrequencyBands built from the lower/upper bounds of the bands that were kept.
        :rtype: FrequencyBands
        """
        candidate_bands = list(frequency_bands.bands())
        spectrogram_band = self.frequency_band
        is_valid = spectrogram_band.contains if strict else spectrogram_band.overlaps
        valid_bands = [x for x in candidate_bands if is_valid(x)]

        # The docstring has always promised a warning for dropped bands; emit it.
        num_dropped = len(candidate_bands) - len(valid_bands)
        if num_dropped:
            relation = "contained in" if strict else "overlapping"
            msg = (
                f"Dropping {num_dropped} of {len(candidate_bands)} frequency bands "
                f"that are not {relation} the spectrogram band {spectrogram_band}."
            )
            logger.warning(msg)

        return FrequencyBands(
            pd.DataFrame(
                data={
                    "lower_bound": [x.lower_bound for x in valid_bands],
                    "upper_bound": [x.upper_bound for x in valid_bands],
                }
            )
        )

    def cross_powers(
        self,
        frequency_bands: FrequencyBands,
        channel_pairs: Optional[List[Tuple[str, str]]] = None,
    ):
        """
        Compute cross powers between channel pairs for given frequency bands.

        TODO: Add handling for case when band in frequency_bands is not contained
        in self.frequencies.

        Parameters
        ----------
        frequency_bands : FrequencyBands
            The frequency bands to compute cross powers for.  Each element of this iterable
            tells the lower and upper bounds of the cross-power calculation bands.
            These may become objects with information about tapers as well.
        channel_pairs : list of tuples, optional
            List of channel pairs to compute cross powers for.
            If None, all possible pairs will be used (including each channel with itself).

        Returns
        -------
        xr.Dataset
            Dataset containing cross powers for all channel pairs.
            Each variable is named by the channel pair (e.g. 'ex_hy')
            and contains a 2D array with dimensions (frequency, time).
            All variables share common frequency and time coordinates.
        """
        from itertools import combinations_with_replacement

        valid_frequency_bands = self._validate_frequency_bands(frequency_bands)

        # If no channel pairs specified, use all possible pairs
        if channel_pairs is None:
            channels = list(self.dataset.data_vars.keys())
            channel_pairs = list(combinations_with_replacement(channels, 2))

        # Create variable names from channel pairs
        var_names = [self.cross_power_label(ch1, ch2) for ch1, ch2 in channel_pairs]

        # Initialize a single multi-channel 2D xarray
        # NOTE(review): the frequency coordinate is built from ALL requested bands, while
        # only the validated bands are filled in below; entries for dropped bands keep
        # whatever initialize_xrda_2d put there -- confirm this is intended.
        xpower_array = initialize_xrda_2d(
            var_names,
            coords={
                "frequency": frequency_bands.band_centers(),
                "time": self.dataset.time.values,
            },
            dtype=complex,
        )

        # Compute cross powers for each band and channel pair
        for band in valid_frequency_bands.bands():
            # Extract band data
            band_data = self.extract_band(band).dataset

            # Compute cross powers for each channel pair
            for ch1, ch2 in channel_pairs:
                label = self.cross_power_label(ch1, ch2)
                # Always compute as ch1 * conj(ch2), averaged over the band's harmonics
                xpower = (band_data[ch1] * band_data[ch2].conj()).mean(dim="frequency")

                # Store the cross power
                xpower_array.loc[
                    dict(
                        frequency=band.center_frequency,
                        variable=label,
                        time=slice(None),
                    )
                ] = xpower

        return xpower_array

    def covariance_matrix(
        self, band_data: Optional["Spectrogram"] = None, method: str = "numpy_cov"
    ) -> xr.DataArray:
        """
        TODO: Add tests for this WIP Work-in-progress method
        Compute full covariance matrix for spectrogram data.

        For complex-valued data, the result is a Hermitian matrix where:
        - diagonal elements are real-valued variances
        - off-diagonal element [i,j] is E[ch_i * conj(ch_j)]
        - off-diagonal element [j,i] is the complex conjugate of [i,j]

        Parameters
        ----------
        band_data : Spectrogram, optional
            If provided, compute covariance for this data
            If None, use the full spectrogram
        method : str
            Computation method.  Currently only supports 'numpy_cov'

        Returns
        -------
        xr.DataArray
            Hermitian covariance matrix with proper channel labeling
            For channels i,j: matrix[i,j] = E[ch_i * conj(ch_j)]

        Raises
        ------
        ValueError
            If ``method`` is not a supported method.  This is checked before any data
            is touched.
        """
        # Fail fast on an unsupported method instead of flattening the data first.
        if method != "numpy_cov":
            raise ValueError(f"Unknown method: {method}")

        # Explicit None test: do not rely on the truthiness of a Spectrogram.
        data = self if band_data is None else band_data
        flat_data = data.flatten(chunk_by="time")

        # Convert to DataArray for covariance_xr
        stacked = flat_data.to_array(dim="variable")
        return covariance_xr(stacked)

    def _get_all_channel_pairs(self) -> List[Tuple[str, str]]:
        """Get all unique channel pairs (upper triangle, excluding self-pairs)"""
        from itertools import combinations

        return list(combinations(self.dataset.data_vars.keys(), 2))

    def flatten(self, chunk_by: Literal["time", "frequency"] = "time") -> xr.Dataset:
        """
        Reshape the 2D spectrogram into a 1D flattened xarray (time-chunked by default).

        Parameters
        ----------
        chunk_by: Literal["time", "frequency"]
            Reshaping the 2D spectrogram can be done two ways (basically "row-major"
            or "column-major").  We either keep frequency constant and iterate
            over time, or keep time constant and iterate over frequency (in the inner loop).

        Returns
        -------
        xarray.Dataset : The dataset from the band spectrogram, stacked along a new
            ``observation`` dimension.

        Raises
        ------
        ValueError
            If ``chunk_by`` is neither "time" nor "frequency".

        Development Notes:
         The flattening used in tf calculation by default is opposite to here
         dataset.stack(observation=("frequency", "time"))
         However, for feature extraction, it may make sense to swap the order:
         xrds = band_spectrogram.dataset.stack(observation=("time", "frequency"))
         This is like chunking into time windows and allows individual features to be computed
         on each time window -- if desired.
         Still need to split the time series though -- splitting to time would be a reshape by
         (last_freq_index-first_freq_index).
         Using pure xarray this may not matter but if we drop down into numpy it could be useful.

        """
        if chunk_by == "time":
            observation = ("time", "frequency")
        elif chunk_by == "frequency":
            observation = ("frequency", "time")
        else:
            msg = f"Invalid argument chunk_by={chunk_by}, must be one of ['time', 'frequency']"
            logger.error(msg)
            raise ValueError(msg)

        return self.dataset.stack(observation=observation)

395 

396 

def extract_band(
    frequency_band: Band,
    fft_obj: Union[xr.Dataset, xr.DataArray],
    channels: Optional[list] = None,
    epsilon: float = 1e-7,
) -> Union[xr.Dataset, xr.DataArray]:
    """
    Select the part of a spectrogram whose frequencies fall inside a band.

    A frequency is kept when it lies in the closed interval
    [lower_bound - epsilon, upper_bound + epsilon] of ``frequency_band``.

    TODO: Update variable names.

    Development Notes:
        Base dataset object should be a xr.DataArray (not xr.Dataset)
        - drop=True does not play nice with h5py and Dataset, results in a type error.
        File "stringsource", line 2, in h5py.h5r.Reference.__reduce_cython__
        TypeError: no default __reduce__ due to non-trivial __cinit__
        However, it works OK with DataArray.  This is why a TypeError from the direct
        selection triggers a second attempt on the ``to_array()`` form of the input.

    Parameters
    ----------
    frequency_band: mt_metadata.common.band.Band
        Specifies interval corresponding to a frequency band
    fft_obj: xarray.core.dataset.Dataset
        Short-time-Fourier-transformed data.  Can be multichannel.
    channels: list
        Channel names to extract.  When None or empty, no channel selection is done.
    epsilon: float
        Use this when you are worried about missing a frequency due to
        round off error.  This is in general not needed if we use a df/2 pad
        around true harmonics.

    Returns
    -------
    extracted_band: xr.DataArray
        The frequencies within the band passed into this function.  If no frequency
        falls in the band, a warning is logged and the empty selection is still returned.
    """
    # Widen the band by epsilon on each side before comparing against the axis.
    lowest_allowed = frequency_band.lower_bound - epsilon
    highest_allowed = frequency_band.upper_bound + epsilon
    in_band = (fft_obj.frequency >= lowest_allowed) & (
        fft_obj.frequency <= highest_allowed
    )

    try:
        band_data = fft_obj.where(in_band, drop=True)
    except TypeError:
        # See Development Notes in the docstring: retry on the DataArray form of the
        # input, then split it back into a Dataset keyed by "variable".
        as_array = fft_obj.to_array()
        band_data = as_array.where(in_band, drop=True).to_dataset("variable")

    if channels:
        band_data = band_data[channels]

    if len(band_data.frequency) == 0:
        logger.warning(
            f"Frequency band {frequency_band} does not overlap with the frequencies "
            f"of the input dataset. Frequencies in dataset are: {fft_obj.frequency.values}. "
            "Skipping band extraction. Consider reforming the bands."
        )
    return band_data