Coverage for C:\Users\peaco\OneDrive\Documents\GitHub\mt_metadata\mt_metadata\processing\fourier_coefficients\fc.py: 92%

134 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-10 00:11 -0800

1# ===================================================== 

2# Imports 

3# ===================================================== 

4from collections import OrderedDict 

5from typing import Annotated 

6 

7import numpy as np 

8from loguru import logger 

9from pydantic import Field, field_validator, model_validator, ValidationInfo 

10 

11from mt_metadata import NULL_VALUES 

12from mt_metadata.base import MetadataBase 

13from mt_metadata.common import ListDict, TimePeriod 

14from mt_metadata.common.enumerations import StrEnumerationBase 

15from mt_metadata.processing.fourier_coefficients.decimation import Decimation 

16 

17 

18# ===================================================== 

class MethodEnum(StrEnumerationBase):
    """Allowed values for ``FC.method``, the Fourier transform method used."""

    fft = "fft"
    wavelet = "wavelet"
    # catch-all for any method not listed above
    other = "other"

23 

24 

class FC(MetadataBase):
    """
    Metadata describing a group of Fourier coefficients (FCs).

    The group is split into decimation levels. Level names are listed in
    ``decimation_levels`` and the matching :class:`Decimation` objects are
    stored in ``levels``. :meth:`synchronize_levels` makes the two agree at
    validation time; :meth:`add_decimation_level` and
    :meth:`remove_decimation_level` keep them agreeing afterwards.
    """

    decimation_levels: Annotated[
        list[str],
        Field(
            default_factory=list,
            description="List of decimation levels",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": ["[1, 2, 3]"],
            },
        ),
    ]

    id: Annotated[
        str,
        Field(
            default="",
            description="ID given to the FC group",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": ["aurora_01"],
            },
        ),
    ]

    channels_estimated: Annotated[
        list[str],
        Field(
            default_factory=list,
            description="list of channels estimated",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": [["ex", "hy"]],
            },
        ),
    ]

    starting_sample_rate: Annotated[
        float,
        Field(
            default=1.0,
            description="Starting sample rate of the time series used to estimate FCs.",
            alias=None,
            json_schema_extra={
                "units": "samples per second",
                "required": True,
                "examples": [60],
            },
        ),
    ]

    method: Annotated[
        MethodEnum,
        Field(
            default=MethodEnum.fft,
            description="Fourier transform method",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": ["fft"],
            },
        ),
    ]

    time_period: Annotated[
        TimePeriod,
        Field(
            default_factory=TimePeriod,  # type: ignore
            description="Time period of the FCs",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": [TimePeriod(start="2020-01-01", end="2020-01-02")],
            },
        ),
    ]

    levels: Annotated[
        ListDict,
        Field(
            default_factory=ListDict,  # type: ignore
            description="ListDict of decimation levels and their parameters",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": ["ListDict containing Decimation objects"],
            },
        ),
    ]

    @field_validator("channels_estimated", "decimation_levels", mode="before")
    @classmethod
    def validate_channels_estimated(
        cls, value: list[str] | np.ndarray | str, info: ValidationInfo
    ) -> list[str]:
        """
        Coerce input for ``channels_estimated`` / ``decimation_levels`` to a list.

        * numpy arrays are converted with ``tolist()`` first
        * any value in ``NULL_VALUES`` -> ``[]``
        * list / tuple -> list; numeric items (e.g. level names ``[1, 2, 3]``)
          are converted to ``str`` because both fields are ``list[str]`` and
          level names must match the string keys of ``levels``
        * comma-separated string -> split on ``","`` with surrounding
          whitespace stripped and empty entries dropped

        :raises TypeError: for any other input type.
        """
        if isinstance(value, np.ndarray):
            value = value.tolist()

        if value in NULL_VALUES:
            return []

        if isinstance(value, (list, tuple)):
            return [
                str(item) if isinstance(item, (int, float, np.number)) else item
                for item in value
            ]

        if isinstance(value, str):
            # "ex, hy" -> ["ex", "hy"], not ["ex", " hy"]
            return [item.strip() for item in value.split(",") if item.strip()]

        raise TypeError(
            f"'{info.field_name}' must be set with a list not {type(value)}."
        )

    @field_validator("levels", mode="before")
    @classmethod
    def validate_levels(cls, value, info: ValidationInfo):
        """
        Build a ``ListDict`` of :class:`Decimation` objects from ``value``.

        * ``None`` -> empty ``ListDict``
        * a string (e.g. a representation read back from HDF5 storage) ->
          empty ``ListDict``; a warning is logged unless the string is one of
          the recognised empty representations
        * dict / ``ListDict`` / ``OrderedDict`` -> its values are used
        * list / tuple -> its items are used

        Each item is kept as-is if it is already a ``Decimation``; otherwise a
        new ``Decimation`` is filled via ``from_dict`` (objects exposing
        ``to_dict`` are converted to a dict first).

        :raises TypeError: if ``value`` is of an unsupported type, or if any
            item could not be converted (all failures are reported together).
        """
        if value is None:
            return ListDict()

        if isinstance(value, str):
            if value not in ["", "none", "None", "ListDict()", "{}"]:
                logger.warning(f"Converting string representation of levels: {value}")
            return ListDict()

        if isinstance(value, (dict, ListDict, OrderedDict)):
            value_list = value.values()
        elif isinstance(value, (list, tuple)):
            value_list = value
        else:
            msg = (
                "input dl_list must be an iterable, should be a list or dict "
                f"not {type(value)}"
            )
            logger.error(msg)
            raise TypeError(msg)

        fails = []
        levels = ListDict()
        for decimation_level in value_list:
            try:
                if isinstance(decimation_level, Decimation):
                    dl = decimation_level
                else:
                    dl = Decimation()  # type: ignore
                    if hasattr(decimation_level, "to_dict"):
                        decimation_level = decimation_level.to_dict()
                    dl.from_dict(decimation_level)
                levels.append(dl)
            except Exception as error:
                # loguru uses str.format-style args, so build the full message
                # here instead of passing a "%s" template plus arguments.
                msg = f"Could not create decimation_level from dictionary: {error}"
                fails.append(msg)
                logger.error(msg)

        if fails:
            raise TypeError("\n".join(fails))

        return levels

    @model_validator(mode="after")
    def synchronize_levels(self) -> "FC":
        """
        Ensure that decimation_levels and levels are synchronized.

        - Creates Decimation objects for any levels in decimation_levels that
          don't exist in levels
        - Adds level names to decimation_levels for any existing levels not in
          the list
        """
        # First, ensure all levels in decimation_levels have corresponding
        # Decimation objects. keys() is re-read each pass so a name repeated
        # in decimation_levels only creates one Decimation.
        for level_name in self.decimation_levels:
            level_name_str = str(level_name)
            if level_name_str not in self.levels.keys():
                new_decimation = Decimation(id=level_name_str)  # type: ignore
                self.levels.append(new_decimation)

        # Second, ensure all existing levels in the ListDict are in
        # decimation_levels
        for level_name in self.levels.keys():
            if level_name not in self.decimation_levels:
                self.decimation_levels.append(level_name)

        return self

    def has_decimation_level(self, level):
        """
        Check whether a decimation level exists in ``levels``.

        The lookup uses ``str(level)``, matching how
        :meth:`get_decimation_level` and :meth:`decimation_level_index`
        index ``levels``, so a True result guarantees those lookups succeed.

        :param level: decimation level name to look for
        :type level: str or int
        :return: True if found, False if not
        :rtype: bool
        """
        return str(level) in self.levels.keys()

    def decimation_level_index(self, level):
        """
        Get the position of a decimation level within ``levels``.

        :param level: decimation level name to look for
        :type level: str or int
        :return: index of the level, or None if it does not exist
        :rtype: int or None
        """
        if self.has_decimation_level(level):
            return self.levels.keys().index(str(level))
        return None

    def get_decimation_level(self, level):
        """
        Get a decimation level.

        :param level: decimation level name to look for
        :type level: str or int
        :return: the matching decimation level object, or None if not found
        :rtype: :class:`mt_metadata.processing.fourier_coefficients.decimation.Decimation`
            or None
        """
        if self.has_decimation_level(level):
            return self.levels[str(level)]
        return None

    def add_decimation_level(self, fc_decimation):
        """
        Add a decimation level; if one with the same id exists, update it.

        The level id is also added to ``decimation_levels`` if missing, and
        the group time period is refreshed afterwards.

        :param fc_decimation: decimation level object to add
        :type fc_decimation: :class:`mt_metadata.processing.fourier_coefficients.decimation.Decimation`
        :raises ValueError: if ``fc_decimation`` is not a ``Decimation``
        """
        if not isinstance(fc_decimation, Decimation):
            msg = f"Input must be metadata.decimation_level not {type(fc_decimation)}"
            logger.error(msg)
            raise ValueError(msg)

        level_id = fc_decimation.id
        if self.has_decimation_level(level_id):
            self.levels[level_id].update(fc_decimation)
            logger.debug(f"level {level_id} already exists, updating metadata")
        else:
            self.levels.append(fc_decimation)

        # Keep decimation_levels in sync in both branches.
        if level_id not in self.decimation_levels:
            self.decimation_levels.append(level_id)

        self.update_time_period()

    def remove_decimation_level(self, decimation_level_id):
        """
        Remove a decimation level from this FC group.

        The id is removed from both ``levels`` and ``decimation_levels``
        (a stale entry in ``decimation_levels`` is removed even if ``levels``
        has no such level), and the group time period is refreshed.

        :param decimation_level_id: name of the decimation level to remove
        :type decimation_level_id: str or int
        """
        level_id = str(decimation_level_id)
        if self.has_decimation_level(level_id):
            self.levels.remove(level_id)
        else:
            logger.warning(f"Could not find {decimation_level_id} to remove.")

        if level_id in self.decimation_levels:
            self.decimation_levels.remove(level_id)

        self.update_time_period()

    @property
    def n_decimation_levels(self):
        """Number of decimation levels stored in ``levels``."""
        return len(self.levels)

    def update_time_period(self):
        """
        Update the group time period from the decimation levels.

        Start/end values equal to the null time "1980-01-01T00:00:00+00:00"
        are ignored. The group start is moved to the earliest level start
        (and the end to the latest level end) when the current value is the
        null time or lies inside the levels' span; it is never narrowed.
        """
        null_time = "1980-01-01T00:00:00+00:00"
        start = []
        end = []
        for dl in self.levels:
            if dl.time_period.start != null_time:
                start.append(dl.time_period.start)
            # filter ends on the *end* value (previously checked start)
            if dl.time_period.end != null_time:
                end.append(dl.time_period.end)

        if start:
            earliest = min(start)
            if (
                self.time_period.start == null_time
                or self.time_period.start > earliest
            ):
                self.time_period.start = earliest

        if end:
            latest = max(end)
            if self.time_period.end == null_time or self.time_period.end < latest:
                self.time_period.end = latest