Coverage for C:\Users\peaco\OneDrive\Documents\GitHub\mt_metadata\mt_metadata\features\striding_window_coherence.py: 98%

40 statements  

« prev | ^ index | » next — coverage.py v7.13.1, created at 2026-01-10 00:11 -0800

1# ============================================================================== 

2# Imports 

3# ============================================================================== 

4from typing import Annotated 

5 

6import numpy as np 

7import scipy.signal as ssig 

8from pydantic import Field, model_validator 

9 

10from mt_metadata.features.coherence import Coherence 

11from mt_metadata.processing.window import Window 

12 

13 

14# from pathos.multiprocessing import ProcessingPool as Pool 

15 

16 

17# ============================================================================== 

18 

19 

class StridingWindowCoherence(Coherence):
    """
    Coherence computed in striding main windows across a pair of time series.

    For each main window (length ``window.num_samples``, advancing by
    ``window.num_samples_advance``), coherence is estimated with
    ``scipy.signal.coherence`` using segments described by ``subwindow``.
    The result is a 2D array indexed by (main window index, frequency).
    """

    subwindow: Annotated[
        Window,
        Field(
            default_factory=Window,  # type: ignore
            description="The window used for the subwindow coherence calculation.",
            json_schema_extra={
                "units": None,
                "required": False,
                "examples": ["hann", "hamming", "blackman"],
            },
        ),
    ]

    @model_validator(mode="before")
    @classmethod
    def set_defaults(cls, data: dict) -> dict:
        """
        Force the fixed name, domain and description of this feature.

        These three keys are always overwritten, whatever the caller supplied.
        Non-dict input (e.g. an already-built model instance) is returned
        unchanged, because there are no keys to set on it.
        """
        if not isinstance(data, dict):
            return data

        # Work on a copy so a mapping passed to ``model_validate`` is not
        # mutated as a side effect of validation.
        data = dict(data)
        data["name"] = "striding_window_coherence"
        data["domain"] = "frequency"
        data["description"] = (
            "Computes coherence for each sub-window "
            "(FFT window) across the time series."
        )
        return data

    def set_subwindow_from_window(self, fraction=0.2):
        """
        Set ``subwindow`` as a fraction of the main ``window``.

        The new subwindow takes the main window's ``type`` and a shallow copy
        of its ``additional_args``. Its length is
        ``int(window.num_samples * fraction)`` and its overlap is half that
        length (floor division). The main-window stride
        (``window.num_samples_advance``) is not touched.

        Parameters
        ----------
        fraction : float, default 0.2
            Subwindow length as a fraction of the main window length.
        """
        import copy

        self.subwindow = Window()  # type: ignore
        self.subwindow.type = self.window.type
        self.subwindow.num_samples = int(self.window.num_samples * fraction)
        self.subwindow.overlap = int(self.subwindow.num_samples // 2)
        # Copy rather than alias, so later edits to one window's extra
        # parameters do not silently change the other.
        self.subwindow.additional_args = copy.copy(self.window.additional_args)

    def compute(
        self, ts_1: np.ndarray, ts_2: np.ndarray, parallel: bool = False
    ) -> tuple[np.ndarray, np.ndarray]:
        """
        Compute coherence in each striding main window.

        The main windows have length ``self.window.num_samples``. Consecutive
        windows start ``self.window.num_samples_advance`` samples apart; if
        that attribute is absent, the windows do not overlap. Within each main
        window, ``scipy.signal.coherence`` is called with
        ``nperseg=self.subwindow.num_samples``,
        ``noverlap=self.subwindow.overlap`` and ``detrend=self.detrend``.

        No ``fs`` is passed, so scipy's default ``fs=1.0`` applies and the
        frequencies are in cycles per sample.

        Parameters
        ----------
        ts_1, ts_2 : np.ndarray
            Input time series. Non-finite values are replaced by
            ``np.nan_to_num`` before processing. The series length is taken
            from ``ts_1``. NOTE(review): ``ts_2`` is assumed to be the same
            length; confirm that callers guarantee this.
        parallel : bool, default False
            Currently ignored; windows are always processed serially.

        Returns
        -------
        frequencies : np.ndarray
            1D array of frequencies, from the last window processed.
        coherences : np.ndarray
            2D array of shape (n_main_windows, n_frequencies).

        Raises
        ------
        ValueError
            If ``ts_1`` is shorter than one main window, so no window fits.
        """
        n_samples = len(ts_1)
        main_win_len = self.window.num_samples
        main_stride = (
            self.window.num_samples_advance
            if hasattr(self.window, "num_samples_advance")
            else main_win_len
        )

        # Without this guard the loop below never runs and the frequency
        # array is never bound.
        if n_samples < main_win_len:
            raise ValueError(
                f"Time series length ({n_samples}) is shorter than the main "
                f"window ({main_win_len} samples); no windows to compute "
                "coherence over."
            )

        # These scipy window types need extra parameters. They are passed as
        # a tuple (name, *params), built from the subwindow's additional_args
        # in insertion order.
        parametric_windows = (
            "kaiser",
            "kaiser_bessel_derived",
            "gaussian",
            "general_cosine",
            "general_gaussian",
            "general_hamming",
            "dpss",
            "chebwin",
        )
        if self.subwindow.type in parametric_windows:
            win_spec = (
                self.subwindow.type,
                *self.subwindow.additional_args.values(),
            )
        else:
            win_spec = self.subwindow.type

        ts_1 = np.nan_to_num(ts_1)
        ts_2 = np.nan_to_num(ts_2)

        # Hoist loop-invariant parameters.
        nperseg = self.subwindow.num_samples
        noverlap = self.subwindow.overlap
        detrend = self.detrend

        coherences = []
        for start in range(0, n_samples - main_win_len + 1, main_stride):
            stop = start + main_win_len
            frequencies, coh = ssig.coherence(
                ts_1[start:stop],
                ts_2[start:stop],
                window=win_spec,
                nperseg=nperseg,
                noverlap=noverlap,
                detrend=detrend,
            )
            coherences.append(coh)

        return frequencies, np.array(coherences)