Coverage for C:\Users\peaco\OneDrive\Documents\GitHub\mt_metadata\mt_metadata\timeseries\filters\channel_response.py: 89%

228 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-10 00:11 -0800

1# ===================================================== 

2# Imports 

3# ===================================================== 

4from copy import deepcopy 

5from typing import Annotated 

6 

7import numpy as np 

8from loguru import logger 

9from pydantic import ( 

10 computed_field, 

11 Field, 

12 field_validator, 

13 model_validator, 

14 PrivateAttr, 

15 ValidationInfo, 

16) 

17 

18from mt_metadata.base.helpers import object_to_array, requires 

19from mt_metadata.common.units import get_unit_object 

20from mt_metadata.timeseries.filters import ( 

21 CoefficientFilter, 

22 FilterBase, 

23 FIRFilter, 

24 FrequencyResponseTableFilter, 

25 PoleZeroFilter, 

26 TimeDelayFilter, 

27) 

28from mt_metadata.timeseries.filters.plotting_helpers import plot_response 

29 

30 

31try: 

32 from obspy.core import inventory 

33except ImportError: 

34 inventory = None 

35 

36 

37# ===================================================== 

38 

39 

class ChannelResponse(FilterBase):
    """
    Cascade of filters describing the response of a single channel.

    Filters in ``filters_list`` are kept in order.  On validation the
    channel ``units_in`` is taken from the first filter and ``units_out``
    from the last, and adjacent filters must have matching units.
    """

    _supported_filters: list = PrivateAttr(
        [
            PoleZeroFilter,
            CoefficientFilter,
            TimeDelayFilter,
            FrequencyResponseTableFilter,
            FIRFilter,
        ]
    )

    normalization_frequency: Annotated[
        float,
        Field(
            default=0.0,
            description="Pass band frequency",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": "100",
            },
        ),
    ]

    filters_list: Annotated[
        list[
            PoleZeroFilter
            | CoefficientFilter
            | TimeDelayFilter
            | FrequencyResponseTableFilter
            | FIRFilter
        ],
        Field(
            default_factory=list,
            description="List of filters applied to the channel.",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": "[PoleZeroFilter, CoefficientFilter]",
            },
        ),
    ]

    frequencies: Annotated[
        np.ndarray | list[float],
        Field(
            default_factory=lambda: np.empty(0, dtype=float),
            description="The frequencies at which a calibration of the filter were performed.",
            alias=None,
            json_schema_extra={
                "units": "hertz",
                "required": True,
                "items": {"type": "number"},
                "examples": '"[-0.0001., 0.0002, 0.0005, 0.001, 0.002, 0.005, 0.001, ... 1, 2, 5, 10]"',
            },
        ),
    ]

    def __str__(self):
        """Return a header followed by each filter's string form, separated by rules."""
        lines = ["Filters Included:\n", "=" * 25, "\n"]
        for f in self.filters_list:
            lines.append(f.__str__())
            lines.append(f"\n{'-'*20}\n")

        return "".join(lines)

    def __repr__(self):
        return self.__str__()

    # ------------------------------------------------------------------
    # private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _geometric_mean_frequency(pass_band) -> float:
        """
        Return the geometric mean of a pass band, rounded to 3 decimals.

        :param pass_band: pass band edge frequencies
        :type pass_band: array-like
        :return: ``10 ** mean(log10(pass_band))`` rounded to 3 decimals
        :rtype: float
        """
        # A pass band edge at 0 gives log10(0) = -inf; silence the
        # divide-by-zero warning and let the -inf propagate (result 0.0).
        with np.errstate(divide="ignore"):
            return np.round(10 ** np.mean(np.log10(pass_band)), 3)

    @staticmethod
    def _round_to_significant_figures(value, sig_figs):
        """
        Round ``value`` to ``sig_figs - floor(log10(|value|))`` decimals.

        Zero and non-finite values are returned unchanged: ``log10`` is
        -inf/NaN for them and ``int()`` of an infinite or NaN float raises.

        NOTE(review): this keeps ``sig_figs + 1`` significant digits
        (e.g. 123.456789 with sig_figs=6 -> 123.4568).  Kept as-is for
        backward compatibility with previously computed values.

        :param value: number to round
        :param sig_figs: precision control, see above
        :return: rounded value
        """
        if value == 0 or not np.isfinite(value):
            return value
        return round(value, sig_figs - int(np.floor(np.log10(abs(value)))))

    # ------------------------------------------------------------------
    # validators
    # ------------------------------------------------------------------
    @field_validator("normalization_frequency", mode="after")
    @classmethod
    def validate_normalization_frequency(
        cls, value: float, info: ValidationInfo
    ) -> float:
        """
        Derive the normalization frequency from the pass band when unset.

        If ``value`` is 0.0 or None, a temporary instance is built (without
        validation) from the fields validated so far (``info.data``).  If
        that instance reports a pass band, the geometric mean of the pass
        band is returned instead of ``value``.

        :param value: the incoming normalization frequency
        :param info: pydantic validation info holding already-validated data
        :return: the normalization frequency to store
        :rtype: float
        """
        if value in [0.0, None]:
            # model_construct skips validation, so this cannot recurse.
            instance = cls.model_construct(**info.data)
            pass_band = getattr(instance, "pass_band", None)
            if pass_band is not None:
                norm_freq = cls._geometric_mean_frequency(pass_band)
                logger.info(
                    f"Setting normalization frequency to {norm_freq} Hz based on pass band"
                )
                return norm_freq

        return value

    @field_validator("frequencies", mode="before")
    @classmethod
    def validate_frequencies(cls, value: np.ndarray | list[float]) -> np.ndarray:
        """
        Convert the incoming frequencies to a float array via ``object_to_array``.
        """
        return object_to_array(value, dtype=float)

    @field_validator("filters_list", mode="before")
    @classmethod
    def validate_filters_list(cls, value: list) -> list:
        """
        Validate ``filters_list`` before pydantic's own type validation.

        :raises ValueError: if ``value`` is not a list
        :raises TypeError: if an item is not a supported filter instance
        :raises ValueError: if adjacent filters have inconsistent units
        """
        if not isinstance(value, list):
            raise ValueError("filters_list must be a list of filter objects.")

        value = cls._validate_filters_list(value)
        value = cls._check_consistency_of_units(value)
        return value

    @model_validator(mode="after")
    def update_units_and_normalization_frequency_from_filters_list(
        self,
    ) -> "ChannelResponse":
        """
        Sync channel units (and, if unset, normalization frequency) with the filters.

        ``units_in`` comes from the first filter and ``units_out`` from the
        last.  If ``normalization_frequency`` is 0.0 and a pass band can be
        computed, it is set to the geometric mean of the pass band.
        """
        if self.filters_list:
            # object.__setattr__ bypasses the model's own __setattr__ (and
            # therefore any assignment validation) from inside a validator.
            object.__setattr__(self, "units_in", self.filters_list[0].units_in)
            object.__setattr__(self, "units_out", self.filters_list[-1].units_out)
            if self.normalization_frequency == 0.0:
                pass_band = self.pass_band
                if pass_band is not None:
                    norm_freq = self._geometric_mean_frequency(pass_band)
                    logger.debug(
                        f"Setting normalization frequency to {norm_freq} Hz based on pass band"
                    )
                    object.__setattr__(self, "normalization_frequency", norm_freq)
        return self

    @classmethod
    def _validate_filters_list(cls, filters_list):
        """
        Check that every item of ``filters_list`` is a supported filter instance.

        :param filters_list: candidate filters
        :type filters_list: list or None
        :return: a new list with the same filters (``[]`` for ``[]``/None)
        :rtype: list
        :raises TypeError: if ``filters_list`` is not a list, or if any item
            is not an instance of one of ``_supported_filters``
        """
        if filters_list in [[], None]:
            return []

        if not isinstance(filters_list, list):
            msg = f"Input filters list must be a list not {type(filters_list)}"
            logger.error(msg)
            raise TypeError(msg)

        # _supported_filters is a PrivateAttr; at class level its value
        # lives in .default.  Build the type tuple once for all items.
        supported_filter_types = tuple(cls._supported_filters.default)
        fails = [
            f"Item is not a supported filter type, {type(item)}"
            for item in filters_list
            if not isinstance(item, supported_filter_types)
        ]
        if fails:
            raise TypeError(", ".join(fails))

        return list(filters_list)

    @classmethod
    def _check_consistency_of_units(cls, filters_list):
        """
        Confirm each filter's ``units_in`` equals the previous filter's ``units_out``.

        :param filters_list: ordered filters
        :return: ``filters_list`` unchanged
        :raises ValueError: on the first unit mismatch
        """
        if len(filters_list) > 1:
            previous_units = filters_list[0].units_out
            for mt_filter in filters_list[1:]:
                if mt_filter.units_in != previous_units:
                    msg = (
                        "Unit consistency is incorrect. "
                        f"The input units for {mt_filter.name} should be "
                        f"{previous_units} not {mt_filter.units_in}"
                    )
                    logger.error(msg)
                    raise ValueError(msg)
                previous_units = mt_filter.units_out

        return filters_list

    # ------------------------------------------------------------------
    # computed fields
    # ------------------------------------------------------------------
    @computed_field
    @property
    def names(self) -> list[str]:
        """Names of the filters, in order."""
        return [f.name for f in self.filters_list]

    @computed_field
    @property
    def pass_band(self) -> list[float]:
        """
        Estimate the pass band common to all filters at ``self.frequencies``.

        Each filter exposing ``pass_band`` contributes its (min, max); the
        result is ``[max of minima, min of maxima]``.

        NOTE(review): the annotation says ``list[float]`` but a 2-element
        ``np.ndarray`` or ``None`` is returned; kept for schema compatibility.

        :return: ``[low, high]`` or None if there are no frequencies or no
            filter reports a pass band
        """
        if self.frequencies is None or len(self.frequencies) == 0:
            logger.debug("No frequencies provided, cannot calculate pass band")
            return None

        bands = []
        for f in self.filters_list:
            if not hasattr(f, "pass_band"):
                continue
            f_pb = f.pass_band(self.frequencies)
            if f_pb is None:
                continue
            bands.append((f_pb.min(), f_pb.max()))

        if not bands:
            return None

        bands = np.array(bands)
        # Overlap of all bands: highest lower edge, lowest upper edge.
        return np.array([bands[:, 0].max(), bands[:, 1].min()])

    @computed_field
    @property
    def non_delay_filters(self) -> list:
        """
        :return: all filters whose ``type`` is not "time delay"
        """
        return [x for x in self.filters_list if x.type != "time delay"]

    @computed_field
    @property
    def delay_filters(self) -> list[TimeDelayFilter]:
        """
        :return: all filters whose ``type`` is "time delay"
        """
        return [x for x in self.filters_list if x.type == "time delay"]

    @computed_field
    @property
    def total_delay(self) -> float:
        """
        :return: sum of ``delay`` over all time delay filters (0.0 if none)
        """
        return sum((delay_filter.delay for delay_filter in self.delay_filters), 0.0)

    # ------------------------------------------------------------------
    # filter selection
    # ------------------------------------------------------------------
    def get_indices_of_filters_to_remove(
        self, include_decimation=False, include_delay=False
    ):
        """
        Return indices into ``filters_list`` of the selected filters.

        :param include_decimation: keep filters with ``decimation_active``
        :type include_decimation: bool
        :param include_delay: keep filters of type "time delay"
        :type include_delay: bool
        :return: ascending indices
        :rtype: list[int]
        """
        indices = list(range(len(self.filters_list)))

        if not include_delay:
            indices = [i for i in indices if self.filters_list[i].type != "time delay"]

        if not include_decimation:
            indices = [i for i in indices if not self.filters_list[i].decimation_active]

        return indices

    def get_list_of_filters_to_remove(
        self, include_decimation=False, include_delay=False
    ):
        """
        Return the filters selected by ``get_indices_of_filters_to_remove``.

        :param include_decimation: keep filters with ``decimation_active``
        :type include_decimation: bool
        :param include_delay: keep filters of type "time delay"
        :type include_delay: bool
        :return: selected filters, in order
        :rtype: list
        """
        # Experimental idea, intentionally not enabled: allow filters with the
        # opposite convention into the channel response by inverting them
        # when correction_operation == "multiply".
        indices = self.get_indices_of_filters_to_remove(
            include_decimation=include_decimation, include_delay=include_delay
        )
        return [self.filters_list[i] for i in indices]

    # ------------------------------------------------------------------
    # responses
    # ------------------------------------------------------------------
    def complex_response(
        self,
        frequencies=None,
        filters_list=None,
        include_decimation=False,
        include_delay=False,
        normalize=False,
        **kwargs,
    ):
        """
        Compute the complex response of the channel as a product of filters.

        :param frequencies: frequencies to evaluate at; if given they are
            stored on ``self.frequencies``, defaults to None
        :type frequencies: np.ndarray, optional
        :param filters_list: explicit subset of filters to multiply; if None
            it is built with ``get_list_of_filters_to_remove``
        :type filters_list: list, optional
        :param include_decimation: include decimation filters when building
            the list, defaults to False
        :type include_decimation: bool, optional
        :param include_delay: include time delay filters when building the
            list, defaults to False
        :type include_delay: bool, optional
        :param normalize: divide by the peak magnitude, defaults to False
        :type normalize: bool, optional
        :param kwargs: accepted for call compatibility; not forwarded to the
            individual filters
        :return: complex response along ``self.frequencies`` (all ones when
            there are no filters)
        :rtype: np.ndarray
        """
        if frequencies is not None:
            self.frequencies = frequencies

        # make filters list if not supplied
        if filters_list is None:
            logger.warning(
                "Filters list not provided, building list assuming all are applied"
            )
            filters_list = self.get_list_of_filters_to_remove(
                include_decimation=include_decimation,
                include_delay=include_delay,
            )

        if len(filters_list) == 0:
            logger.warning(f"No filters associated with {self.__class__}, returning 1")
            return np.ones(len(self.frequencies), dtype=complex)

        # Start from a complex *copy* of the first response: multiplying in
        # place on the filter's returned array would modify it if the filter
        # keeps a reference, and would raise a casting error if it is real.
        result = np.array(
            filters_list[0].complex_response(self.frequencies), dtype=complex
        )
        for ff in filters_list[1:]:
            result *= ff.complex_response(self.frequencies)

        if normalize:
            peak = np.max(np.abs(result))
            # An all-zero response cannot be normalized (would give NaN).
            if peak > 0:
                result /= peak
        return result

    def compute_instrument_sensitivity(self, normalization_frequency=None, sig_figs=6):
        """
        Compute the StationXML instrument sensitivity at the normalization frequency.

        The sensitivity is ``|prod(filter.complex_response(f_norm))|`` over
        all filters.

        :param normalization_frequency: if given, stored on
            ``self.normalization_frequency`` before computing
        :type normalization_frequency: float, optional
        :param sig_figs: rounding precision passed to
            ``_round_to_significant_figures``, defaults to 6
        :type sig_figs: int, optional
        :return: rounded sensitivity; 1.0 if the product is 0 or NaN
        :rtype: float
        """
        if normalization_frequency is not None:
            self.normalization_frequency = normalization_frequency

        sensitivity = 1.0
        for mt_filter in self.filters_list:
            complex_response = mt_filter.complex_response(self.normalization_frequency)
            sensitivity *= complex_response.astype(complex)
        # Responses may be length-1 arrays or scalars.
        try:
            sensitivity = np.abs(sensitivity[0])
        except (IndexError, TypeError):
            sensitivity = np.abs(sensitivity)

        if sensitivity == 0.0:
            logger.warning(
                "Sensitivity is zero, cannot compute instrument sensitivity. "
                "Returning 1.0"
            )
            return 1.0
        if np.isnan(sensitivity):
            logger.warning("Sensitivity is NaN, setting to 1.0")
            sensitivity = 1.0
        return self._round_to_significant_figures(sensitivity, sig_figs)

    def compute_total_gain(self, sig_figs=16):
        """
        Return the product of the ``gain`` of every filter, rounded.

        This generally differs from the instrument sensitivity, which is
        evaluated from the full complex responses at the normalization
        frequency.

        :param sig_figs: rounding precision passed to
            ``_round_to_significant_figures``, defaults to 16
        :type sig_figs: int, optional
        :return: total gain (1 if there are no filters; returned unrounded
            if it is 0 or non-finite)
        :rtype: float
        """
        total_gain = 1
        for mt_filter in self.filters_list:
            total_gain *= mt_filter.gain

        return self._round_to_significant_figures(total_gain, sig_figs)

    @requires(obspy=inventory)
    def to_obspy(self, sample_rate=1):
        """
        Build an :class:`obspy.core.inventory.Response` for a StationXML file.

        The instrument sensitivity value is the total gain (see
        ``compute_total_gain``); a differing ``compute_instrument_sensitivity``
        result is only logged.  Each filter becomes one response stage.

        :param sample_rate: sample rate passed to each stage's ``to_obspy``,
            defaults to 1
        :type sample_rate: float, optional
        :return: response with instrument sensitivity and stages
        :rtype: obspy.core.inventory.Response
        """
        total_sensitivity = self.compute_instrument_sensitivity()
        total_gain = self.compute_total_gain()

        if total_sensitivity != total_gain:
            logger.info(
                f"total sensitivity {total_sensitivity} != total gain {total_gain}. Using total_gain."
            )
            total_sensitivity = total_gain

        units_in_obj = get_unit_object(self.units_in)
        units_out_obj = get_unit_object(self.units_out)

        total_response = inventory.Response()
        total_response.instrument_sensitivity = inventory.InstrumentSensitivity(
            total_sensitivity,
            self.normalization_frequency,
            units_in_obj.symbol,
            units_out_obj.symbol,
            input_units_description=units_in_obj.name,
            output_units_description=units_out_obj.name,
        )

        for ii, f in enumerate(self.filters_list, 1):
            stage_filter = f
            # Coefficient stages whose output is not digital counts are
            # re-expressed as a gain-only PoleZeroFilter for export
            # (presumably a StationXML convention — confirm).
            if f.type in ["coefficient"] and f.units_out not in [
                "count",
                "digital counts",
            ]:
                logger.debug(f"converting CoefficientFilter {f.name} to PZ")
                stage_filter = PoleZeroFilter()
                stage_filter.gain = f.gain
                stage_filter.units_in = f.units_in
                stage_filter.units_out = f.units_out
                stage_filter.comments = f.comments
                stage_filter.name = f.name

            total_response.response_stages.append(
                stage_filter.to_obspy(
                    stage_number=ii,
                    normalization_frequency=self.normalization_frequency,
                    sample_rate=sample_rate,
                )
            )

        return total_response

    def plot_response(
        self,
        frequencies=None,
        x_units="period",
        unwrap=True,
        pb_tol=1e-1,
        interpolation_method="slinear",
        include_delay=False,
        include_decimation=False,
    ):
        """
        Plot the individual filter responses and the total response.

        :param frequencies: frequencies to compute response; if given they
            are stored on ``self.frequencies``, defaults to None
        :type frequencies: np.ndarray, optional
        :param x_units: [ period | frequency ], defaults to "period"
        :type x_units: string, optional
        :param unwrap: Unwrap phase, defaults to True
        :type unwrap: bool, optional
        :param pb_tol: pass band tolerance, defaults to 1e-1 (currently
            not used by this method)
        :type pb_tol: float, optional
        :param interpolation_method: interpolation method passed to each
            filter's ``complex_response``
            [ slinear | nearest | cubic | quadratic | ], defaults to "slinear"
        :type interpolation_method: string, optional
        :param include_delay: include delays in response, defaults to False
        :type include_delay: bool, optional
        :param include_decimation: Include decimation in response,
            defaults to False
        :type include_decimation: bool, optional
        """
        if frequencies is not None:
            self.frequencies = frequencies

        # select filters: optionally drop time delays, then decimation stages
        selected = self.filters_list if include_delay else self.non_delay_filters
        if not include_decimation:
            selected = [x for x in selected if not x.decimation_active]
        # work on copies so plotting cannot alter the stored filters
        filters_list = deepcopy(selected)

        cr_kwargs = {"interpolation_method": interpolation_method}

        # get response of individual filters
        cr_list = [
            f.complex_response(self.frequencies, **cr_kwargs) for f in filters_list
        ]

        # The total response is the product of the curves computed above, so
        # it uses the same interpolation_method (self.complex_response does
        # not forward that option to the filters).
        total_response = np.ones(np.shape(self.frequencies), dtype=complex)
        for cr in cr_list:
            total_response = total_response * cr

        cr_list.append(total_response)
        labels = [f.name for f in filters_list] + ["Total Response"]

        # plot with proper attributes.
        kwargs = {
            "title": f"Channel Response: [{', '.join([f.name for f in filters_list])}]",
            "unwrap": unwrap,
            "x_units": x_units,
            "pass_band": self.pass_band,
            "label": labels,
            "normalization_frequency": self.normalization_frequency,
        }

        plot_response(self.frequencies, cr_list, **kwargs)