Coverage for C: \ Users \ peaco \ OneDrive \ Documents \ GitHub \ mt_metadata \ mt_metadata \ transfer_functions \ core.py: 74%

912 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-10 00:11 -0800

1# -*- coding: utf-8 -*- 

2""" 

3.. module:: TF 

4 :synopsis: The main container for transfer functions 

5 

6.. moduleauthor:: Jared Peacock <jpeacock@usgs.gov> 

7""" 

8 

9from collections import OrderedDict 

10from copy import deepcopy 

11 

12# ============================================================================== 

13from pathlib import Path 

14from typing import Any, Literal 

15 

16import numpy as np 

17import xarray as xr 

18from loguru import logger 

19from typing_extensions import Self 

20 

21from mt_metadata import DEFAULT_CHANNEL_NOMENCLATURE 

22from mt_metadata.base.helpers import validate_name 

23from mt_metadata.common.list_dict import ListDict 

24from mt_metadata.timeseries import Electric, Magnetic, Run 

25from mt_metadata.timeseries import Station as TSStation 

26from mt_metadata.timeseries import Survey 

27from mt_metadata.transfer_functions.io import EDI, EMTFXML, JFile, ZMM, ZongeMTAvg 

28from mt_metadata.transfer_functions.io.zfiles.metadata import Channel as ZChannel 

29from mt_metadata.transfer_functions.tf import Station 

30 

31 

32# ============================================================================= 

33 

34 

class TF:
    """
    Generic container to hold information about an electromagnetic
    transfer function.

    The central container is ``TF.dataset``, an ``xarray.Dataset`` holding
    the transfer function (impedance and tipper), its errors, and
    covariance values (inverse signal power and residual covariance).
    Helper properties get and set these from ``TF.dataset`` because most of
    the time the user only wants the impedance or the tipper and the
    associated errors. EMTF-style covariances are kept so that data errors
    can be rotated accurately.

    The covariance information is lost when reading and writing EDI files.
    """

    # Class-level template cache shared by every TF instance: maps the key
    # returned by _get_template_key() (the sorted channel nomenclature) to a
    # zero-filled template dataset built in _initialize_transfer_function().
    _template_cache = {}

53 

def __init__(self, fn: str | Path | None = None, **kwargs):
    """
    Create a transfer function container.

    Parameters
    ----------
    fn : str | Path | None, optional
        File associated with this transfer function. When given, the ``fn``
        setter also points ``save_dir`` at its parent directory.
    **kwargs
        ``period`` (or, when ``period`` is absent, ``frequency``) defines the
        period axis of the internal dataset. ``frequency`` is converted with
        ``1 / frequency`` and may be a scalar, list or array.
        ``channel_nomenclature`` is applied before the dataset is built, so
        the dataset coordinates use those channel names. Every remaining
        keyword is assigned with ``setattr`` after the dataset exists.
    """
    # one survey -> one station -> one run, with default channels
    self._survey_metadata = self._initialize_metadata()
    # applied before the dataset is created so the dataset coordinates use
    # the requested channel names (not the defaults)
    self.channel_nomenclature = kwargs.pop(
        "channel_nomenclature", DEFAULT_CHANNEL_NOMENCLATURE
    )
    self._inverse_channel_nomenclature = {}

    self._rotation_angle = 0
    self.save_dir = Path.cwd()

    # dataset attribute name -> "<metadata object>.<attribute path>"
    self._dataset_attr_dict = {
        "survey": "survey_metadata.id",
        "project": "survey_metadata.project",
        "id": "station_metadata.id",
        "name": "station_metadata.geographic_name",
        "latitude": "station_metadata.location.latitude",
        "longitude": "station_metadata.location.longitude",
        "elevation": "station_metadata.location.elevation",
        "declination": "station_metadata.location.declination.value",
        "datum": "station_metadata.location.datum",
        "acquired_by": "station_metadata.acquired_by.author",
        "start": "station_metadata.time_period.start",
        "end": "station_metadata.time_period.end",
        "runs_processed": "station_metadata.run_list",
        "coordinate_system": "station_metadata.orientation.reference_frame",
    }

    # file extension -> reader/writer methods
    self._read_write_dict = {
        "edi": {"write": self.to_edi, "read": self.from_edi},
        "xml": {"write": self.to_emtfxml, "read": self.from_emtfxml},
        "emtfxml": {"write": self.to_emtfxml, "read": self.from_emtfxml},
        "j": {"write": self.to_jfile, "read": self.from_jfile},
        "zmm": {"write": self.to_zmm, "read": self.from_zmm},
        "zrr": {"write": self.to_zrr, "read": self.from_zrr},
        "zss": {"write": self.to_zss, "read": self.from_zss},
        "avg": {"write": self.to_avg, "read": self.from_avg},
    }

    # explicit membership tests rather than try/except KeyError, so a
    # KeyError raised while building the dataset is not silently swallowed
    if "period" in kwargs:
        periods = kwargs.pop("period")
        if periods is not None:
            periods = np.atleast_1d(periods)
    elif "frequency" in kwargs:
        # atleast_1d lets scalars and plain lists work, not only arrays
        periods = 1.0 / np.atleast_1d(
            np.asarray(kwargs.pop("frequency"), dtype=float)
        )
    else:
        periods = None

    # build the dataset before the remaining kwargs are applied so setters
    # such as ``impedance`` have a dataset to write into
    if periods is None:
        self._transfer_function = self._initialize_transfer_function()
    else:
        self._transfer_function = self._initialize_transfer_function(
            periods=periods
        )

    for key, value in kwargs.items():
        setattr(self, key, value)

    self.fn = fn

113 

@property
def inverse_channel_nomenclature(self) -> dict[str, str]:
    """
    Mapping from channel label back to the standard channel name.

    This is the inverse of ``channel_nomenclature``, e.g. ``{"e1": "ex"}``
    for a nomenclature ``{"ex": "e1"}``. It is rebuilt on every access, so
    it never goes stale after ``channel_nomenclature`` is reassigned. The
    private ``_inverse_channel_nomenclature`` attribute is kept in sync.

    Returns
    -------
    dict[str, str]
        Label -> standard channel name.
    """
    inverse = {label: name for name, label in self.channel_nomenclature.items()}
    self._inverse_channel_nomenclature = inverse
    return inverse

121 

def __str__(self) -> str:
    """
    Return a multi-line, human-readable summary of the station.

    The summary covers survey/project, acquisition author and date,
    location, declination, coordinate system, and whether impedance and
    tipper are present. When ``period`` is not None, it also lists the
    number of periods and the period and frequency ranges.
    """
    station_md = self.station_metadata
    declination = station_md.location.declination
    summary = [
        f"Station: {self.station}",
        "-" * 50,
        f"\tSurvey: {self.survey_metadata.id}",
        f"\tProject: {self.survey_metadata.project}",
        f"\tAcquired by: {station_md.acquired_by.author}",
        f"\tAcquired date: {station_md.time_period.start}",
        f"\tLatitude: {self.latitude:.3f}",
        f"\tLongitude: {self.longitude:.3f}",
        f"\tElevation: {self.elevation:.3f}",
        "\tDeclination: ",
        f"\t\tValue: {declination.value}",
        f"\t\tModel: {declination.model}",
        f"\tCoordinate System: {station_md.orientation.reference_frame}",
        f"\tImpedance: {self.has_impedance()}",
        f"\tTipper: {self.has_tipper()}",
    ]

    periods = self.period
    if periods is not None:
        shortest = periods.min()
        longest = periods.max()
        summary += [
            f"\tN Periods: {len(periods)}",
            "\tPeriod Range:",
            f"\t\tMin: {shortest:.5E} s",
            f"\t\tMax: {longest:.5E} s",
            "\tFrequency Range:",
            # shortest frequency comes from the longest period and vice versa
            f"\t\tMin: {1.0 / longest:.5E} Hz",
            f"\t\tMax: {1.0 / shortest:.5E} Hz",
        ]
    return "\n".join(summary)

156 

def __repr__(self) -> str:
    """Return ``TF( survey=..., station=..., latitude=..., ... )``."""
    fields = (
        f"survey='{self.survey}'",
        f"station='{self.station}'",
        f"latitude={self.latitude:.2f}",
        f"longitude={self.longitude:.2f}",
        f"elevation={self.elevation:.2f}",
    )
    return "TF( " + ", ".join(fields) + " )"

166 

def __eq__(self, other: object) -> bool:
    """
    Check if two TF objects are equal.

    Compares station metadata, survey metadata and transfer function
    values. Each difference found is logged at INFO level.

    Parameters
    ----------
    other: object
        Another object to compare with

    Returns
    -------
    bool
        True if equal, False otherwise (always False for non-TF objects).
    """
    if not isinstance(other, TF):
        logger.info(f"Comparing object is not TF, type {type(other)}")
        # stop here: the checks below read TF-only attributes of ``other``
        # and would raise AttributeError for an arbitrary object
        return False

    is_equal = True
    if self.station_metadata != other.station_metadata:
        logger.info("Station metadata is not equal")
        is_equal = False
    if self.survey_metadata != other.survey_metadata:
        logger.info("Survey Metadata is not equal")
        is_equal = False

    self_has_tf = self.has_transfer_function()
    other_has_tf = other.has_transfer_function()
    if self_has_tf and other_has_tf:
        if not self.transfer_function.equals(other.transfer_function):
            logger.info("TF is not equal")
            is_equal = False
    elif self_has_tf != other_has_tf:
        # only one of the two objects holds transfer function values
        logger.info("TF is not equal")
        is_equal = False

    return is_equal

202 

def __ne__(self, other: object) -> bool:
    """Return the negation of ``__eq__``."""
    return False if self.__eq__(other) else True

205 

def __deepcopy__(self, memo: dict[int, Any]) -> Self:
    """
    Deep-copy every instance attribute except ``logger``.

    The new instance is created with ``__new__`` (``__init__`` is not run)
    and registered in ``memo`` before its attributes are copied, so
    self-references resolve to the copy.
    """
    klass = type(self)
    duplicate = klass.__new__(klass)
    memo[id(self)] = duplicate
    excluded = ("logger",)
    for name, attribute in self.__dict__.items():
        if name not in excluded:
            setattr(duplicate, name, deepcopy(attribute, memo))
    return duplicate

216 

def copy(self) -> Self:
    """
    Return an independent deep copy of this object.

    Delegates to ``copy.deepcopy``, so any ``__deepcopy__`` defined on the
    class is honoured.

    Returns
    -------
    Self
        The deep copy.
    """
    duplicate = deepcopy(self)
    return duplicate

227 

def _add_channels(
    self,
    run_metadata: Run,
    default: list[str] | tuple[str, ...] = ("ex", "ey", "hx", "hy", "hz"),
) -> Run:
    """
    Add channels to a run.

    Components starting with "e" are added as ``Electric`` channels first,
    then components starting with "h" as ``Magnetic`` channels. Any other
    component is ignored.

    Parameters
    ----------
    run_metadata: Run
        The run metadata to add channels to; it is modified in place.
    default: list[str] | tuple[str, ...], optional
        Channel components to add. Any iterable of strings is accepted.
        The default is an immutable tuple, so there is no shared mutable
        default argument.

    Returns
    -------
    Run
        The same ``run_metadata`` object, with the channels added.
    """
    # materialise once so one-shot iterables (generators) feed both passes
    components = tuple(default)
    for component in (cc for cc in components if cc.startswith("e")):
        run_metadata.add_channel(Electric(component=component))
    for component in (cc for cc in components if cc.startswith("h")):
        run_metadata.add_channel(Magnetic(component=component))

    return run_metadata

252 

def _initialize_metadata(self) -> Survey:
    """
    Build the placeholder metadata hierarchy.

    Returns a ``Survey`` with id "0" holding one ``Station`` with id "0",
    which holds one ``Run`` with id "0". The run is populated with the
    default channels from ``_add_channels``.
    """
    survey_metadata = Survey(id="0")
    survey_metadata.stations.append(Station(id="0"))

    first_station = survey_metadata.stations[0]
    first_station.runs.append(Run(id="0"))
    self._add_channels(first_station.runs[0])

    return survey_metadata

268 

def _validate_run_metadata(self, run_metadata: Run) -> Run:
    """
    Return ``run_metadata`` as a ``Run`` object.

    A ``Run`` is returned unchanged. A dict is loaded into a new ``Run``
    with ``from_dict``; if no top-level key equals "run" (case-insensitive),
    the dict is first wrapped as ``{"Run": ...}``.

    Raises
    ------
    TypeError
        If ``run_metadata`` is neither a ``Run`` nor a dict.
    """
    if isinstance(run_metadata, Run):
        return run_metadata

    if not isinstance(run_metadata, dict):
        msg = (
            f"input metadata must be type {type(self.run_metadata)} "
            f"or dict, not {type(run_metadata)}"
        )
        logger.error(msg)
        raise TypeError(msg)

    lowered_keys = [name.lower() for name in run_metadata.keys()]
    if "run" not in lowered_keys:
        run_metadata = {"Run": run_metadata}
    loaded = Run()
    loaded.from_dict(run_metadata)
    logger.debug("Loading from metadata dict")
    return loaded

301 

def _validate_station_metadata(self, station_metadata: Station) -> Station:
    """
    Return ``station_metadata`` as a ``Station`` object.

    A ``Station`` is returned unchanged. A dict is loaded into a new
    ``Station`` with ``from_dict``; if no top-level key equals "station"
    (case-insensitive), the dict is first wrapped as ``{"Station": ...}``.

    Raises
    ------
    TypeError
        If ``station_metadata`` is neither a ``Station`` nor a dict.
    """
    if isinstance(station_metadata, Station):
        return station_metadata

    if not isinstance(station_metadata, dict):
        msg = (
            f"input metadata must be type {type(self.station_metadata)}"
            f" or dict, not {type(station_metadata)}"
        )
        logger.error(msg)
        raise TypeError(msg)

    lowered_keys = [name.lower() for name in station_metadata.keys()]
    if "station" not in lowered_keys:
        station_metadata = {"Station": station_metadata}
    loaded = Station()
    loaded.from_dict(station_metadata)
    logger.debug("Loading from metadata dict")
    return loaded

334 

def _validate_survey_metadata(self, survey_metadata: Survey) -> Survey:
    """
    Return ``survey_metadata`` as a ``Survey`` object.

    A ``Survey`` is returned unchanged. A dict is loaded into a new
    ``Survey`` with ``from_dict``; if no top-level key equals "survey"
    (case-insensitive), the dict is first wrapped as ``{"Survey": ...}``.

    Raises
    ------
    TypeError
        If ``survey_metadata`` is neither a ``Survey`` nor a dict.
    """
    if isinstance(survey_metadata, Survey):
        return survey_metadata

    if not isinstance(survey_metadata, dict):
        msg = (
            f"input metadata must be type {type(self.survey_metadata)}"
            f" or dict, not {type(survey_metadata)}"
        )
        logger.error(msg)
        raise TypeError(msg)

    lowered_keys = [name.lower() for name in survey_metadata.keys()]
    if "survey" not in lowered_keys:
        survey_metadata = {"Survey": survey_metadata}
    loaded = Survey()
    loaded.from_dict(survey_metadata)
    logger.debug("Loading from metadata dict")
    return loaded

366 

367 ### Properties ------------------------------------------------------------ 

@property
def survey_metadata(self) -> Survey:
    """Survey metadata; the root of the station/run/channel hierarchy."""
    return self._survey_metadata

@survey_metadata.setter
def survey_metadata(self, survey_metadata: Survey) -> None:
    """
    Merge survey metadata into the existing survey.

    The input (a ``Survey`` or dict) is validated and its attributes are
    merged with ``update``. Each of its stations gets its time period
    refreshed and is added. Once more than one station exists, the "0"
    placeholder station is removed. ``None`` is ignored.

    Parameters
    ----------
    survey_metadata: Survey
        The survey metadata object or dictionary to set.
    """
    if survey_metadata is None:
        return

    incoming = self._validate_survey_metadata(survey_metadata)
    current = self._survey_metadata
    current.update(incoming)
    for station in incoming.stations:
        station.update_time_period()
        current.add_station(station)

    # the "0" station is the placeholder created at initialisation
    if len(current.stations.keys()) > 1 and "0" in current.stations.keys():
        current.stations.remove("0")

    current.update_time_period()

399 

@property
def station_metadata(self) -> Station:
    """Metadata of the first station in ``survey_metadata``."""
    return self.survey_metadata.stations[0]

@station_metadata.setter
def station_metadata(self, station_metadata: Station | None = None) -> None:
    """
    Replace the station metadata with a validated ``Station`` (or dict).

    The new station's runs are the current run (unless it is still the "0"
    placeholder) followed by the new station's own runs. If that leaves no
    runs, a "0" run is added. The first run gets default channels when it
    has none. ``None`` is ignored.

    Parameters
    ----------
    station_metadata: Station | None
        The station metadata object or dictionary to set.
    """
    if station_metadata is None:
        return
    station_metadata = self._validate_station_metadata(station_metadata)

    current_run = self.run_metadata
    runs = ListDict()
    if current_run.id not in ["0", 0, None]:
        runs.append(current_run.copy())
    runs.extend(station_metadata.runs)
    if len(runs) == 0:
        runs[0] = Run(id="0")

    first_run = runs[0]
    # be sure there is a channel level below the run
    if len(first_run.channels) == 0:
        self._add_channels(first_run)

    stations = ListDict()
    stations.append(station_metadata)
    new_station = stations[0]
    new_station.runs = runs
    new_station.update_time_period()

    self.survey_metadata.stations = stations
    self._survey_metadata.update_time_period()

438 

@property
def run_metadata(self) -> Run:
    """Metadata of the first run of the first station."""
    return self.survey_metadata.stations[0].runs[0]

@run_metadata.setter
def run_metadata(self, run_metadata: Run | None = None) -> None:
    """
    Make ``run_metadata`` the first run of the first station.

    The validated run becomes runs[0]. Its channels are the input run's
    channels, preceded (when a ``component`` attribute is set) by that
    channel and the other existing channels. The remaining existing runs
    follow, except one with the same id and the "0" placeholder.
    ``None`` is ignored.

    Parameters
    ----------
    run_metadata: Run | None
        The run metadata object or dictionary to set.
    """
    if run_metadata is None:
        return
    run_metadata = self._validate_run_metadata(run_metadata)

    runs = ListDict()
    runs.append(run_metadata)
    channels = ListDict()
    # ``component`` is not assigned anywhere in the visible TF code, so a
    # plain ``self.component`` lookup raises AttributeError; read it
    # defensively and skip the channel-reordering step when it is absent
    component = getattr(self, "component", None)
    if component is not None:
        key = str(component)
        # put the requested channel first
        channels.append(self.station_metadata.runs[0].channels[key])
        # then the other existing channels
        channels.extend(self.run_metadata.channels, skip_keys=[key, "0"])
    # add channels from input metadata
    channels.extend(run_metadata.channels)

    runs[0].channels = channels
    # keep the other existing runs, minus the replaced one and the placeholder
    runs.extend(self.station_metadata.runs, skip_keys=[run_metadata.id, "0"])

    self._survey_metadata.stations[0].runs = runs

478 

def _get_template_key(self):
    """
    Return a hashable cache key for the current channel nomenclature.

    The key is the nomenclature's ``(name, label)`` pairs, sorted and packed
    into a tuple, so equal nomenclatures share one template.
    """
    pairs = list(self.channel_nomenclature.items())
    pairs.sort()
    return tuple(pairs)

482 

def _initialize_transfer_function(self, periods=(1,)):
    """
    Create an all-zero transfer function dataset from a cached template.

    Parameters
    ----------
    periods : array-like or scalar, optional
        Values of the ``period`` coordinate, by default ``(1,)``. A scalar
        is treated as a single period; ``None`` is treated like the default.
        The default is an immutable tuple (it used to be a mutable list).

    Returns
    -------
    xr.Dataset
        Dataset with variables ``transfer_function``,
        ``transfer_function_error``, ``transfer_function_model_error``,
        ``inverse_signal_power`` and ``residual_covariance``. Each has dims
        ``(period, output, input)`` and is zero-filled. The error variables
        are real; the others are complex.

    Notes
    -----
    Templates are cached on the class (``_template_cache``) keyed by
    ``_get_template_key()``. The template is deep-copied before it is
    returned, so callers never mutate the cached object.
    """
    template_key = self._get_template_key()

    if template_key not in self._template_cache:
        coords = {
            "period": [1],  # single period for the template
            "output": self._ch_output_dict["all"],
            "input": self._ch_input_dict["all"],
        }
        # (variable name, scalar fill broadcast over the coords)
        variable_specs = (
            ("transfer_function", 0.0 + 0j),
            ("transfer_function_error", 0.0),
            ("transfer_function_model_error", 0.0),
            ("inverse_signal_power", 0.0 + 0j),
            ("residual_covariance", 0.0 + 0j),
        )
        data_vars = {
            name: xr.DataArray(
                data=fill,
                dims=["period", "output", "input"],
                coords=coords,
                name=name,
            )
            for name, fill in variable_specs
        }
        self._template_cache[template_key] = xr.Dataset(data_vars, coords=coords)

    dataset = self._template_cache[template_key].copy(deep=True)

    # atleast_1d accepts scalars, lists, arrays and DataArray coordinates
    periods = np.atleast_1d(1 if periods is None else periods)
    if periods.size != 1 or periods[0] != 1:
        # expand/adjust to the requested periods; new rows are zero-filled
        dataset = dataset.reindex(period=periods, fill_value=0.0)

    return dataset

571 

572 # ========================================================================== 

573 # Properties 

574 # ========================================================================== 

@property
def channel_nomenclature(self) -> dict:
    """Channel nomenclature dictionary keyed by channel names.

    For example:

    {'ex': 'ex', 'ey': 'ey', 'hx': 'hx', 'hy': 'hy', 'hz': 'hz'}
    """
    return self._channel_nomenclature

@channel_nomenclature.setter
def channel_nomenclature(self, ch_dict: dict) -> None:
    """
    Set the channel nomenclature dictionary.

    The dict is validated before any state changes. The convenience
    attributes ``ex``, ``ey``, ``hx``, ``hy``, ``hz``, ``ex_ey``, ``hx_hy``
    and ``ex_ey_hz`` are then unpacked from it. The cached inverse mapping
    is cleared so it is rebuilt from the new dict.

    Parameters
    ----------
    ch_dict : dict
        Maps each of 'ex', 'ey', 'hx', 'hy', 'hz' to its channel label.

    Raises
    ------
    TypeError
        If ``ch_dict`` is not a dict.
    KeyError
        If any of the required keys is missing (no attributes are changed).
    """
    if not isinstance(ch_dict, dict):
        raise TypeError(
            "Channel_nomenclature must be a dictionary with keys "
            "['ex', 'ey', 'hx', 'hy', 'hz']."
        )
    missing = [key for key in ("ex", "ey", "hx", "hy", "hz") if key not in ch_dict]
    if missing:
        raise KeyError(f"Channel_nomenclature is missing required keys {missing}")

    self._channel_nomenclature = ch_dict
    # invalidate the cached inverse so it is rebuilt from the new mapping
    self._inverse_channel_nomenclature = {}
    # unpack channel nomenclature dict
    self.ex = ch_dict["ex"]
    self.ey = ch_dict["ey"]
    self.hx = ch_dict["hx"]
    self.hy = ch_dict["hy"]
    self.hz = ch_dict["hz"]
    self.ex_ey = [self.ex, self.ey]
    self.hx_hy = [self.hx, self.hy]
    self.ex_ey_hz = [self.ex, self.ey, self.hz]

612 

@property
def _ch_input_dict(self) -> dict:
    """
    Input channel names for each array type.

    Every type uses ``hx_hy`` except "res" (``ex_ey_hz``) and "all", which
    lists all five channels in the order ex, ey, hz, hx, hy.
    """
    h_pair = self.hx_hy
    lookup = dict.fromkeys(
        (
            "impedance",
            "tipper",
            "impedance_error",
            "impedance_model_error",
            "tipper_error",
            "tipper_model_error",
            "isp",
        ),
        h_pair,
    )
    lookup["res"] = self.ex_ey_hz
    lookup["tf"] = h_pair
    lookup["tf_error"] = h_pair
    lookup["all"] = [self.ex, self.ey, self.hz, self.hx, self.hy]
    return lookup

628 

@property
def _ch_output_dict(self) -> dict:
    """
    Output channel names for each array type.

    Impedance types use ``ex_ey``, tipper types use ``[hz]``, "isp" uses
    ``hx_hy``, "res"/"tf"/"tf_error" use ``ex_ey_hz``, and "all" lists all
    five channels in the order ex, ey, hz, hx, hy.
    """
    e_pair = self.ex_ey
    full_out = self.ex_ey_hz
    vertical = self.hz
    return {
        "impedance": e_pair,
        "tipper": [vertical],
        "impedance_error": e_pair,
        "impedance_model_error": e_pair,
        "tipper_error": [vertical],
        "tipper_model_error": [vertical],
        "isp": self.hx_hy,
        "res": full_out,
        "tf": full_out,
        "tf_error": full_out,
        "all": [self.ex, self.ey, vertical, self.hx, self.hy],
    }

644 

@property
def index_zxx(self) -> dict:
    """Selector for Z_xx: input ``hx``, output ``ex``."""
    return dict(input=self.hx, output=self.ex)

@property
def index_zxy(self) -> dict:
    """Selector for Z_xy: input ``hy``, output ``ex``."""
    return dict(input=self.hy, output=self.ex)

@property
def index_zyx(self) -> dict:
    """Selector for Z_yx: input ``hx``, output ``ey``."""
    return dict(input=self.hx, output=self.ey)

@property
def index_zyy(self) -> dict:
    """Selector for Z_yy: input ``hy``, output ``ey``."""
    return dict(input=self.hy, output=self.ey)

@property
def index_tzx(self) -> dict:
    """Selector for T_zx: input ``hx``, output ``hz``."""
    return dict(input=self.hx, output=self.hz)

@property
def index_tzy(self) -> dict:
    """Selector for T_zy: input ``hy``, output ``hz``."""
    return dict(input=self.hy, output=self.hz)

668 

@property
def fn(self) -> Path:
    """Path of the file this transfer function is associated with, or None."""
    return self._fn

@fn.setter
def fn(self, value: Path | str | None) -> None:
    """
    Set the file name.

    A non-None value is converted to ``Path``, and ``save_dir`` is set to
    its parent directory. ``None`` clears the file name and leaves
    ``save_dir`` untouched.

    Parameters
    ----------
    value : Path | str | None
        The file name to set.
    """
    if value is not None:
        path = Path(value)
        self._fn = path
        self.save_dir = path.parent
    else:
        self._fn = None

688 

@property
def latitude(self) -> float:
    """Station latitude, read from ``station_metadata.location``."""
    location = self.station_metadata.location
    return location.latitude

@latitude.setter
def latitude(self, latitude: float) -> None:
    """
    Store ``latitude`` on ``station_metadata.location``.

    Any conversion or validation is left to the location object's own
    ``latitude`` attribute handling.
    """
    location = self.station_metadata.location
    location.latitude = latitude

702 

@property
def longitude(self) -> float:
    """Station longitude, read from ``station_metadata.location``."""
    location = self.station_metadata.location
    return location.longitude

@longitude.setter
def longitude(self, longitude: float) -> None:
    """
    Store ``longitude`` on ``station_metadata.location``.

    Any conversion or validation is left to the location object's own
    ``longitude`` attribute handling.
    """
    location = self.station_metadata.location
    location.longitude = longitude

716 

@property
def elevation(self) -> float:
    """Station elevation, read from ``station_metadata.location``."""
    location = self.station_metadata.location
    return location.elevation

@elevation.setter
def elevation(self, elevation: float) -> None:
    """
    Store ``elevation`` on ``station_metadata.location``.

    Expected in meters.
    """
    location = self.station_metadata.location
    location.elevation = elevation

729 

@property
def dataset(self) -> xr.Dataset:
    """
    The internal transfer function dataset, with metadata copied into attrs.

    Each entry of ``_dataset_attr_dict`` maps an attrs key to
    ``"<object>.<attribute path>"``. The value is read with
    ``get_attr_from_name`` on that object and written into the dataset's
    ``attrs``.

    Returns
    -------
    xr.Dataset
        The instance's own dataset object (not a copy).
    """
    target_attrs = self._transfer_function.attrs
    for attr_name, source in self._dataset_attr_dict.items():
        owner_name, attr_path = source.split(".", 1)
        owner = getattr(self, owner_name)
        target_attrs[attr_name] = owner.get_attr_from_name(attr_path)
    return self._transfer_function

747 

def _validate_input_ndarray(
    self, ndarray: np.ndarray, atype: str = "impedance"
) -> None:
    """
    Validate the input based on array type and component

    Parameters
    ----------
    ndarray : np.ndarray
        The input array to validate; expected shape is
        ``(n_periods, n_outputs, n_inputs)``.
    atype : str
        The type of the array (e.g. "impedance", "tipper", "tf_error").

    Raises
    ------
    KeyError
        If ``atype`` is not a known array type.
    TFError
        If the trailing two dimensions do not match ``atype``, or the number
        of periods differs from the current ``period`` size.
    """
    # (n_outputs, n_inputs) per array type
    shape_dict = {
        "impedance": (2, 2),
        "tipper": (1, 2),
        "impedance_error": (2, 2),
        "impedance_model_error": (2, 2),
        "tipper_error": (1, 2),
        "tipper_model_error": (1, 2),
        "isp": (2, 2),
        "res": (3, 3),
        "transfer_function": (3, 2),
        "transfer_function_error": (3, 2),
        "transfer_function_model_error": (3, 2),
        "tf": (3, 2),
        "tf_error": (3, 2),
        "tf_model_error": (3, 2),
    }

    shape = shape_dict[atype]
    if ndarray.shape[1:] != shape:
        msg = (
            f"{atype} must have shape (n_periods, {shape[0]}, "
            f"{shape[1]}), not {ndarray.shape}"
        )
        logger.error(msg)
        raise TFError(msg)

    n_periods = self.period.size
    if ndarray.shape[0] != n_periods:
        # report the period counts; the old message compared against the
        # per-period (outputs, inputs) shape, which was misleading
        msg = (
            f"New {atype} has {ndarray.shape[0]} periods but the transfer "
            f"function has {n_periods}, suggest creating a new instance."
        )
        logger.error(msg)
        raise TFError(msg)

792 

def _validate_input_dataarray(
    self, da: xr.DataArray, atype: str = "impedance"
) -> xr.DataArray:
    """
    Validate an input data array and align it with the expected channels.

    Parameters
    ----------
    da : xr.DataArray
        The input data array; must have ``period``, ``output`` and
        ``input`` coordinates.
    atype : str
        Key into the channel lookup tables (e.g. "impedance", "tipper").

    Returns
    -------
    xr.DataArray
        ``da`` reindexed so ``output``/``input`` follow the expected order.

    Raises
    ------
    TFError
        If a coordinate is missing, the output/input channels differ from
        the expected ones, or the number of periods differs from an already
        populated transfer function.
    """
    ch_in = self._ch_input_dict[atype]
    ch_out = self._ch_output_dict[atype]

    coord_names = list(da.coords.keys())
    # all three coordinates are required; "output" used to go unchecked and
    # then failed below with a bare KeyError instead of a TFError
    if any(name not in coord_names for name in ("period", "output", "input")):
        msg = f"Coordinates must be period, output, input, not {coord_names}"
        logger.error(msg)
        raise TFError(msg)
    if sorted(ch_out) != sorted(da.coords["output"].data.tolist()):
        msg = (
            f"Output dimensions must be {ch_out} not "
            f"{da.coords['output'].data.tolist()}"
        )
        logger.error(msg)
        raise TFError(msg)
    if sorted(ch_in) != sorted(da.coords["input"].data.tolist()):
        msg = (
            f"Input dimensions must be {ch_in} not "
            f"{da.coords['input'].data.tolist()}"
        )
        logger.error(msg)
        raise TFError(msg)

    # reorder the data array to the expected coordinate order
    da = da.reindex(output=ch_out, input=ch_in)

    n_current = self._transfer_function.transfer_function.data.shape[0]
    # first real assignment into the one-period placeholder dataset:
    # rebuild the dataset on the incoming periods
    if n_current == 1 and not self.has_tipper() and not self.has_impedance():
        self._transfer_function = self._initialize_transfer_function(da.period)
        return da
    if n_current == da.data.shape[0]:
        return da

    msg = "Reassigning with a different shape is dangerous. Should re-initialize transfer_function or make a new instance of TF"
    logger.error(msg)
    raise TFError(msg)

851 

def _set_data_array(
    self, value: xr.DataArray | np.ndarray | list | tuple | None, atype: str
) -> None:
    """
    Write ``value`` into the dataset variable that backs ``atype``.

    Parameters
    ----------
    value : xr.DataArray | np.ndarray | list | tuple | None
        The data to set. ``None`` is a no-op. Lists and tuples are converted
        to numpy arrays.
    atype : str
        The type of the array (e.g. "impedance", "tipper", "tf_error",
        "tf_model_error").

    Raises
    ------
    KeyError
        If ``atype`` is unknown.
    TFError
        If ``value`` fails validation or is of an unsupported type.
    """
    if value is None:
        return
    # atype -> dataset variable name
    key_dict = {
        "tf": "transfer_function",
        "impedance": "transfer_function",
        "tipper": "transfer_function",
        "isp": "inverse_signal_power",
        "res": "residual_covariance",
        "transfer_function": "transfer_function",
        "impedance_error": "transfer_function_error",
        "impedance_model_error": "transfer_function_model_error",
        "tipper_error": "transfer_function_error",
        "tipper_model_error": "transfer_function_model_error",
        "tf_error": "transfer_function_error",
        "tf_model_error": "transfer_function_model_error",
        "transfer_function_error": "transfer_function_error",
        "transfer_function_model_error": "transfer_function_model_error",
    }
    # The channel tables (_ch_input_dict/_ch_output_dict) only know the short
    # names. Map these atypes onto the short name with the same channels;
    # otherwise e.g. "tf_model_error" (used by the
    # transfer_function_model_error setter) raised KeyError.
    channel_alias = {
        "transfer_function": "tf",
        "transfer_function_error": "tf_error",
        "transfer_function_model_error": "tf_error",
        "tf_model_error": "tf_error",
    }
    key = key_dict[atype]
    ch_type = channel_alias.get(atype, atype)
    ch_in = self._ch_input_dict[ch_type]
    ch_out = self._ch_output_dict[ch_type]
    comps = dict(input=ch_in, output=ch_out)

    if isinstance(value, (list, tuple, np.ndarray)):
        value = np.array(value)
        self._validate_input_ndarray(value, atype=ch_type)

        self._transfer_function[key].loc[comps] = value
    elif isinstance(value, xr.DataArray):
        # may rebuild self._transfer_function on the incoming periods
        nda = self._validate_input_dataarray(value, atype=ch_type)

        self._transfer_function[key].loc[comps] = nda
    else:
        msg = (
            f"Data type {type(value)} not supported use a numpy "
            "array or xarray.DataArray"
        )
        logger.error(msg)
        raise TFError(msg)

904 

def has_transfer_function(self) -> bool:
    """
    Report whether the dataset holds any non-zero transfer function values.

    Returns
    -------
    bool
        False when none of ex/ey/hz is an output coordinate, or when every
        element of the (ex, ey, hz) x (hx, hy) block is zero; True otherwise.
    """
    tf_array = self._transfer_function.transfer_function
    output_channels = tf_array.coords["output"].data.tolist()
    if not any(ch in output_channels for ch in (self.ex, self.ey, self.hz)):
        return False

    selection = dict(
        input=self._ch_input_dict["tf"],
        output=self._ch_output_dict["tf"],
    )
    return not np.all(tf_array.loc[selection].data == 0)

932 

@property
def transfer_function(self) -> xr.DataArray | None:
    """
    The full transfer function: outputs (ex, ey, hz) x inputs (hx, hy).

    Returns
    -------
    xr.DataArray | None
        The selection with station/survey metadata copied into ``attrs``,
        or None when ``has_transfer_function()`` is False.
    """
    if not self.has_transfer_function():
        return None

    selection = self.dataset.transfer_function.loc[
        dict(input=self.hx_hy, output=self.ex_ey_hz)
    ]
    for attr_name, source in self._dataset_attr_dict.items():
        owner_name, attr_path = source.split(".", 1)
        selection.attrs[attr_name] = getattr(self, owner_name).get_attr_from_name(
            attr_path
        )
    return selection

@transfer_function.setter
def transfer_function(self, value: xr.DataArray | np.ndarray | list | tuple | None):
    """
    Set the transfer function values (outputs ex, ey, hz; inputs hx, hy).

    Parameters
    ----------
    value : xr.DataArray | np.ndarray | list | tuple | None
        The data to set; arrays must be shaped (n_periods, 3, 2). ``None`` is
        ignored.
    """
    self._set_data_array(value, "tf")

968 

@property
def transfer_function_error(self) -> xr.DataArray | None:
    """
    Transfer function errors: outputs (ex, ey, hz) x inputs (hx, hy).

    Returns
    -------
    xr.DataArray | None
        The error selection with station/survey metadata in ``attrs``, or
        None when ``has_transfer_function()`` is False.
    """
    if not self.has_transfer_function():
        return None

    selection = self.dataset.transfer_function_error.loc[
        dict(input=self.hx_hy, output=self.ex_ey_hz)
    ]
    for attr_name, source in self._dataset_attr_dict.items():
        owner_name, attr_path = source.split(".", 1)
        selection.attrs[attr_name] = getattr(self, owner_name).get_attr_from_name(
            attr_path
        )
    return selection

@transfer_function_error.setter
def transfer_function_error(
    self, value: xr.DataArray | np.ndarray | list | tuple | None
):
    """
    Set the transfer function errors (outputs ex, ey, hz; inputs hx, hy).

    Parameters
    ----------
    value : xr.DataArray | np.ndarray | list | tuple | None
        The data to set; arrays must be shaped (n_periods, 3, 2). ``None`` is
        ignored.
    """
    self._set_data_array(value, "tf_error")

1005 

@property
def transfer_function_model_error(self) -> xr.DataArray | None:
    """
    Transfer function model errors: outputs (ex, ey, hz) x inputs (hx, hy).

    Returns
    -------
    xr.DataArray | None
        The model-error selection with station/survey metadata in
        ``attrs``, or None when ``has_transfer_function()`` is False.
    """
    if not self.has_transfer_function():
        return None

    selection = self.dataset.transfer_function_model_error.loc[
        dict(input=self.hx_hy, output=self.ex_ey_hz)
    ]
    for attr_name, source in self._dataset_attr_dict.items():
        owner_name, attr_path = source.split(".", 1)
        selection.attrs[attr_name] = getattr(self, owner_name).get_attr_from_name(
            attr_path
        )
    return selection

@transfer_function_model_error.setter
def transfer_function_model_error(
    self, value: xr.DataArray | np.ndarray | list | tuple | None
):
    """
    Set the transfer function model errors via ``_set_data_array``.

    Parameters
    ----------
    value : xr.DataArray | np.ndarray | list | tuple | None
        The data to set. ``None`` is ignored.
    """
    self._set_data_array(value, "tf_model_error")

1042 

def has_impedance(self) -> bool:
    """
    Check whether the transfer function holds impedance components.

    The impedance is considered present when ex or ey is one of the
    output channels and at least one impedance element is non-zero.
    NaN entries are treated as zero (via ``np.nan_to_num``), matching
    ``has_tipper``, so an all-NaN impedance block reports False.

    Returns
    -------
    bool
        True if the transfer function has impedance components, False otherwise.
    """
    tf_array = self._transfer_function.transfer_function
    outputs = tf_array.coords["output"].data.tolist()
    if self.ex not in outputs and self.ey not in outputs:
        return False
    z_values = tf_array.loc[
        dict(
            input=self._ch_input_dict["impedance"],
            output=self._ch_output_dict["impedance"],
        )
    ].data
    # NaN means "no data" here, exactly as in has_tipper
    return not np.all(np.nan_to_num(z_values) == 0)

1070 

@property
def impedance(self) -> xr.DataArray | None:
    """
    Impedance part of the transfer function.

    Returns
    -------
    xr.DataArray | None
        Transfer function values for the impedance input/output channels,
        named ``"impedance"`` with the metadata attributes named in
        ``_dataset_attr_dict`` attached; None when ``has_impedance()`` is
        False.
    """
    if not self.has_impedance():
        return None
    z = self.dataset.transfer_function.loc[
        dict(
            input=self._ch_input_dict["impedance"],
            output=self._ch_output_dict["impedance"],
        )
    ]
    z.name = "impedance"
    for attr_key, meta_key in self._dataset_attr_dict.items():
        owner, attr_name = meta_key.split(".", 1)
        z.attrs[attr_key] = getattr(self, owner).get_attr_from_name(attr_name)
    return z

1094 

@impedance.setter
def impedance(self, value: xr.DataArray | np.ndarray | list | tuple | None):
    """
    Set the impedance values.

    Parameters
    ----------
    value : xr.DataArray | np.ndarray | list | tuple | None
        Impedance values, handed to ``_set_data_array`` with the
        ``"impedance"`` type key.
    """
    array_type = "impedance"
    self._set_data_array(value, array_type)

1106 

@property
def impedance_error(self) -> xr.DataArray | None:
    """
    Impedance errors.

    Returns
    -------
    xr.DataArray | None
        Transfer function errors for the impedance input/output channels,
        named ``"impedance_error"`` with the metadata attributes named in
        ``_dataset_attr_dict`` attached; None when ``has_impedance()`` is
        False.
    """
    if not self.has_impedance():
        return None
    z_err = self.dataset.transfer_function_error.loc[
        dict(
            input=self._ch_input_dict["impedance"],
            output=self._ch_output_dict["impedance"],
        )
    ]
    z_err.name = "impedance_error"
    for attr_key, meta_key in self._dataset_attr_dict.items():
        owner, attr_name = meta_key.split(".", 1)
        z_err.attrs[attr_key] = getattr(self, owner).get_attr_from_name(attr_name)
    return z_err

1132 

@impedance_error.setter
def impedance_error(self, value: xr.DataArray | np.ndarray | list | tuple | None):
    """
    Set the impedance errors.

    Parameters
    ----------
    value : xr.DataArray | np.ndarray | list | tuple | None
        Error values, handed to ``_set_data_array`` with the
        ``"impedance_error"`` type key.
    """
    array_type = "impedance_error"
    self._set_data_array(value, array_type)

1144 

@property
def impedance_model_error(self) -> xr.DataArray | None:
    """
    Impedance model errors.

    Returns
    -------
    xr.DataArray | None
        Model errors for the impedance input/output channels, named
        ``"impedance_model_error"`` with the metadata attributes named in
        ``_dataset_attr_dict`` attached; None when ``has_impedance()`` is
        False.
    """
    if not self.has_impedance():
        return None
    z_model_err = self.dataset.transfer_function_model_error.loc[
        dict(
            input=self._ch_input_dict["impedance"],
            output=self._ch_output_dict["impedance"],
        )
    ]
    z_model_err.name = "impedance_model_error"
    for attr_key, meta_key in self._dataset_attr_dict.items():
        owner, attr_name = meta_key.split(".", 1)
        z_model_err.attrs[attr_key] = getattr(self, owner).get_attr_from_name(
            attr_name
        )
    return z_model_err

1170 

@impedance_model_error.setter
def impedance_model_error(
    self, value: xr.DataArray | np.ndarray | list | tuple | None
):
    """
    Set the impedance model errors.

    Parameters
    ----------
    value : xr.DataArray | np.ndarray | list | tuple | None
        Model error values, handed to ``_set_data_array`` with the
        ``"impedance_model_error"`` type key.
    """
    array_type = "impedance_model_error"
    self._set_data_array(value, array_type)

1184 

def has_tipper(self) -> bool:
    """
    Check whether the transfer function holds tipper components.

    The tipper is considered present when hz is one of the output channels
    and at least one tipper element is non-zero; NaN entries count as zero.

    Returns
    -------
    bool
        True if the transfer function has tipper components, False otherwise.
    """
    tf_array = self._transfer_function.transfer_function
    if self.hz not in tf_array.coords["output"].data.tolist():
        return False
    tipper_values = tf_array.loc[
        dict(
            input=self._ch_input_dict["tipper"],
            output=self._ch_output_dict["tipper"],
        )
    ].data
    return not np.all(np.nan_to_num(tipper_values) == 0)

1213 

@property
def tipper(self) -> xr.DataArray | None:
    """
    Tipper part of the transfer function.

    Returns
    -------
    xr.DataArray | None
        Transfer function values for the tipper input/output channels,
        named ``"tipper"`` with the metadata attributes named in
        ``_dataset_attr_dict`` attached; None when ``has_tipper()`` is False.
    """
    if not self.has_tipper():
        return None
    t = self.dataset.transfer_function.loc[
        dict(
            input=self._ch_input_dict["tipper"],
            output=self._ch_output_dict["tipper"],
        )
    ]
    t.name = "tipper"
    for attr_key, meta_key in self._dataset_attr_dict.items():
        owner, attr_name = meta_key.split(".", 1)
        t.attrs[attr_key] = getattr(self, owner).get_attr_from_name(attr_name)
    return t

1239 

@tipper.setter
def tipper(self, value: xr.DataArray | np.ndarray | list | tuple | None):
    """
    Set the tipper values.

    Parameters
    ----------
    value : xr.DataArray | np.ndarray | list | tuple | None
        Tipper values, handed to ``_set_data_array`` with the
        ``"tipper"`` type key.
    """
    array_type = "tipper"
    self._set_data_array(value, array_type)

1251 

@property
def tipper_error(self) -> xr.DataArray | None:
    """
    Tipper errors.

    Returns
    -------
    xr.DataArray | None
        Transfer function errors for the tipper input/output channels,
        named ``"tipper_error"`` with the metadata attributes named in
        ``_dataset_attr_dict`` attached; None when ``has_tipper()`` is False.
    """
    if not self.has_tipper():
        return None
    t_err = self.dataset.transfer_function_error.loc[
        dict(
            input=self._ch_input_dict["tipper"],
            output=self._ch_output_dict["tipper"],
        )
    ]
    t_err.name = "tipper_error"
    for attr_key, meta_key in self._dataset_attr_dict.items():
        owner, attr_name = meta_key.split(".", 1)
        t_err.attrs[attr_key] = getattr(self, owner).get_attr_from_name(attr_name)
    return t_err

1277 

@tipper_error.setter
def tipper_error(self, value: xr.DataArray | np.ndarray | list | tuple | None):
    """
    Set the tipper errors.

    Parameters
    ----------
    value : xr.DataArray | np.ndarray | list | tuple | None
        Error values, handed to ``_set_data_array`` with the
        ``"tipper_error"`` type key.
    """
    array_type = "tipper_error"
    self._set_data_array(value, array_type)

1289 

@property
def tipper_model_error(self) -> xr.DataArray | None:
    """
    Tipper model errors.

    Returns
    -------
    xr.DataArray | None
        Model errors for the tipper input/output channels, named
        ``"tipper_model_error"`` with the metadata attributes named in
        ``_dataset_attr_dict`` attached; None when ``has_tipper()`` is False.
    """
    if not self.has_tipper():
        return None
    t_model_err = self.dataset.transfer_function_model_error.loc[
        dict(
            input=self._ch_input_dict["tipper"],
            output=self._ch_output_dict["tipper"],
        )
    ]
    t_model_err.name = "tipper_model_error"
    for attr_key, meta_key in self._dataset_attr_dict.items():
        owner, attr_name = meta_key.split(".", 1)
        t_model_err.attrs[attr_key] = getattr(self, owner).get_attr_from_name(
            attr_name
        )
    return t_model_err

1314 

@tipper_model_error.setter
def tipper_model_error(
    self, value: xr.DataArray | np.ndarray | list | tuple | None
):
    """
    Set the tipper model errors.

    Parameters
    ----------
    value : xr.DataArray | np.ndarray | list | tuple | None
        Model error values, handed to ``_set_data_array`` with the
        ``"tipper_model_error"`` type key.
    """
    array_type = "tipper_model_error"
    self._set_data_array(value, array_type)

1328 

def has_inverse_signal_power(self) -> bool:
    """
    Check whether the inverse signal power holds any non-zero values.

    Only the ``"isp"`` input/output channel selection is inspected. NaN
    entries are *not* treated as zero here.

    Returns
    -------
    bool
        True if the inverse signal power is set and not zero, False otherwise.
    """
    isp_values = self._transfer_function.inverse_signal_power.loc[
        dict(
            input=self._ch_input_dict["isp"],
            output=self._ch_output_dict["isp"],
        )
    ].data
    return not np.all(isp_values == 0)

1352 

@property
def inverse_signal_power(self) -> xr.DataArray | None:
    """
    Inverse signal power for the ``"isp"`` input/output channels.

    Returns
    -------
    xr.DataArray | None
        The inverse signal power with the metadata attributes named in
        ``_dataset_attr_dict`` attached, or None when
        ``has_inverse_signal_power()`` is False.
    """
    if not self.has_inverse_signal_power():
        return None
    isp = self.dataset.inverse_signal_power.loc[
        dict(
            input=self._ch_input_dict["isp"],
            output=self._ch_output_dict["isp"],
        )
    ]
    for attr_key, meta_key in self._dataset_attr_dict.items():
        owner, attr_name = meta_key.split(".", 1)
        isp.attrs[attr_key] = getattr(self, owner).get_attr_from_name(attr_name)
    return isp

1377 

@inverse_signal_power.setter
def inverse_signal_power(
    self, value: xr.DataArray | np.ndarray | list | tuple | None
):
    """
    Set the inverse signal power.

    After storing the values (``"isp"`` type key), errors are recomputed
    from the covariances when a residual covariance is also present.

    Parameters
    ----------
    value : xr.DataArray | np.ndarray | list | tuple | None
        The inverse signal power values.
    """
    self._set_data_array(value, "isp")
    if not self.has_residual_covariance():
        return
    self._compute_error_from_covariance()

1394 

def has_residual_covariance(self) -> bool:
    """
    Check whether the residual covariance holds any non-zero values.

    Only the ``"res"`` input/output channel selection is inspected. NaN
    entries are *not* treated as zero here.

    Returns
    -------
    bool
        True if the residual covariance is set and not zero, False otherwise.
    """
    residual_values = self._transfer_function.residual_covariance.loc[
        dict(
            input=self._ch_input_dict["res"],
            output=self._ch_output_dict["res"],
        )
    ].data
    return not np.all(residual_values == 0)

1418 

@property
def residual_covariance(self) -> xr.DataArray | None:
    """
    Residual covariance for the ``"res"`` input/output channels.

    Returns
    -------
    xr.DataArray | None
        The residual covariance with the metadata attributes named in
        ``_dataset_attr_dict`` attached, or None when
        ``has_residual_covariance()`` is False.
    """
    if not self.has_residual_covariance():
        return None
    residual = self.dataset.residual_covariance.loc[
        dict(
            input=self._ch_input_dict["res"],
            output=self._ch_output_dict["res"],
        )
    ]
    for attr_key, meta_key in self._dataset_attr_dict.items():
        owner, attr_name = meta_key.split(".", 1)
        residual.attrs[attr_key] = getattr(self, owner).get_attr_from_name(
            attr_name
        )
    return residual

1443 

@residual_covariance.setter
def residual_covariance(
    self, value: xr.DataArray | np.ndarray | list | tuple | None
):
    """
    Set the residual covariance.

    After storing the values (``"res"`` type key), errors are recomputed
    from the covariances when an inverse signal power is also present.

    Parameters
    ----------
    value : xr.DataArray | np.ndarray | list | tuple | None
        The residual covariance values.
    """
    self._set_data_array(value, "res")
    if not self.has_inverse_signal_power():
        return
    self._compute_error_from_covariance()

1460 

def _compute_impedance_error_from_covariance(self) -> None:
    """
    Compute impedance errors from the covariance matrices.

    For each impedance element Z[e, h] the error is
    ``sqrt(|residual_cov[e, e] * inverse_signal_power[h, h]|)``, with e in
    (ex, ey) and h in (hx, hy). The result is written into
    ``dataset.transfer_function_error`` for inputs (hx, hy) and outputs
    (ex, ey). This matters when writing EDI files.

    Translated from code written by Ben Murphy.
    """
    e_cov = self.residual_covariance.loc[
        dict(input=self.ex_ey, output=self.ex_ey)
    ]
    h_cov = self.inverse_signal_power.loc[
        dict(input=self.hx_hy, output=self.hx_hy)
    ]

    z_var = np.zeros((self.period.size, 2, 2), dtype=float)
    for row, e_ch in enumerate((self.ex, self.ey)):
        e_diag = e_cov.loc[dict(input=[e_ch], output=[e_ch])].data.flatten()
        for col, h_ch in enumerate((self.hx, self.hy)):
            h_diag = h_cov.loc[dict(input=[h_ch], output=[h_ch])].data.flatten()
            z_var[:, row, col] = np.abs(e_diag * h_diag)

    self.dataset.transfer_function_error.loc[
        dict(input=self.hx_hy, output=self.ex_ey)
    ] = np.sqrt(np.abs(z_var))

1500 

def _compute_tipper_error_from_covariance(self) -> None:
    """
    Compute tipper errors from the covariance matrices.

    For each tipper element T[hz, h] the error is
    ``sqrt(|residual_cov[hz, hz] * inverse_signal_power[h, h]|)``, with h in
    (hx, hy). The result is written into ``dataset.transfer_function_error``
    for inputs (hx, hy) and output hz. This matters when writing EDI files.

    Translated from code written by Ben Murphy.
    """
    hz_cov = self.residual_covariance.loc[dict(input=[self.hz], output=[self.hz])]
    h_cov = self.inverse_signal_power.loc[
        dict(input=self.hx_hy, output=self.hx_hy)
    ]

    hz_diag = hz_cov.loc[dict(input=[self.hz], output=[self.hz])].data.flatten()
    t_var = np.zeros((self.period.size, 1, 2), dtype=float)
    for col, h_ch in enumerate((self.hx, self.hy)):
        h_diag = h_cov.loc[dict(input=[h_ch], output=[h_ch])].data.flatten()
        t_var[:, 0, col] = np.abs(hz_diag * h_diag)

    self.dataset.transfer_function_error.loc[
        dict(input=self.hx_hy, output=[self.hz])
    ] = np.sqrt(np.abs(t_var))

1530 

def _compute_error_from_covariance(self) -> None:
    """
    Recompute impedance errors and then tipper errors from the
    covariance matrices.
    """
    for compute in (
        self._compute_impedance_error_from_covariance,
        self._compute_tipper_error_from_covariance,
    ):
        compute()

1538 

1539 @property 

1540 def period(self) -> np.ndarray | None: 

1541 """Periods of the transfer function""" 

1542 return self.dataset.period.data 

1543 

@period.setter
def period(self, value: np.ndarray | None):
    """
    Set the periods of the transfer function.

    When there are no current periods, or the current periods are the
    single value ``[1]`` (which appears to be a placeholder), the dataset
    is rebuilt with ``_initialize_transfer_function``. Otherwise the number
    of periods must match the existing number. Matching sizes with
    different values replace the ``period`` coordinate.

    Parameters
    ----------
    value : np.ndarray | None
        The new periods. Any sized sequence (ndarray, list, tuple) is accepted.

    Raises
    ------
    TFError
        If the new periods are not the same size as the existing ones.
    """
    if self.period is None:
        self._transfer_function = self._initialize_transfer_function(periods=value)
        return

    if len(self.period) == 1 and (self.period == np.array([1])).all():
        self._transfer_function = self._initialize_transfer_function(
            periods=value
        )
    elif len(value) != len(self.period):
        # np.size works for lists/tuples as well; value.size raised an
        # AttributeError for them and hid the intended TFError.
        msg = (
            f"New period size {np.size(value)} is not the same size as "
            f"old ones {self.period.size}, suggest creating a new "
            "instance of TF."
        )
        logger.error(msg)
        raise TFError(msg)
    elif not (self.period == value).all():
        self.dataset["period"] = value

1577 

1578 @property 

1579 def frequency(self) -> np.ndarray | None: 

1580 if self.period is not None: 

1581 return 1.0 / self.period 

1582 return None 

1583 

@frequency.setter
def frequency(self, value: np.ndarray | list | tuple | None):
    """
    Set the frequencies by setting ``period = 1 / frequency``.

    Parameters
    ----------
    value : np.ndarray | list | tuple | None
        Frequencies. Lists and tuples are converted with ``np.asarray``; a
        bare list previously raised TypeError on ``1.0 / value``. None is
        ignored and leaves the periods unchanged.
    """
    if value is None:
        return
    self.period = 1.0 / np.asarray(value)

1588 

@property
def station(self) -> str:
    """Station name, i.e. ``station_metadata.id``."""
    metadata = self.station_metadata
    return metadata.id

1593 

@station.setter
def station(self, station_name: str):
    """
    Set the station name (validated with ``validate_name``).

    If the first run of the station has no id, it is popped and re-appended
    with the id ``"<station id>a"``.
    """
    self.station_metadata.id = validate_name(station_name)
    if self.station_metadata.runs[0].id is not None:
        return
    first_run = self.station_metadata.runs.pop(None)
    first_run.id = f"{self.station_metadata.id}a"
    self.station_metadata.runs.append(first_run)

1604 

@property
def survey(self) -> str:
    """Survey ID, i.e. ``survey_metadata.id``."""
    metadata = self.survey_metadata
    return metadata.id

1611 

@survey.setter
def survey(self, survey_id: str):
    """
    Set the survey id (validated with ``validate_name``).

    None is replaced by the default id ``"unkown_survey"``.
    """
    # NOTE: the default's spelling is kept verbatim; changing it would
    # change the ids stored in existing metadata.
    self.survey_metadata.id = validate_name(
        "unkown_survey" if survey_id is None else survey_id
    )

1620 

@property
def tf_id(self) -> str:
    """Transfer function id, i.e. ``station_metadata.transfer_function.id``."""
    tf_metadata = self.station_metadata.transfer_function
    return tf_metadata.id

1625 

@tf_id.setter
def tf_id(self, value: str):
    """Set the transfer function id (validated with ``validate_name``)."""
    validated = validate_name(value)
    self.station_metadata.transfer_function.id = validated

1630 

def to_ts_station_metadata(self) -> TSStation:
    """
    Translate the station metadata to time series station metadata (for MTH5).

    Every flattened attribute of ``station_metadata`` except those whose key
    contains ``"transfer_function"`` is copied. Attributes the time series
    station does not accept (AttributeError) are skipped with a debug log.

    Returns
    -------
    TSStation
        The populated time series station metadata.
    """
    ts_station_metadata = TSStation()  # type: ignore
    flat_station = self.station_metadata.to_dict(single=True)
    for key, value in flat_station.items():
        if "transfer_function" not in key:
            try:
                ts_station_metadata.update_attribute(key, value)
            except AttributeError:
                logger.debug(f"Attribute {key} could not be set.")
    return ts_station_metadata

1647 

def from_ts_station_metadata(self, ts_station_metadata: TSStation):
    """
    Fill ``station_metadata`` from time series station metadata (for MTH5).

    Every flattened attribute of ``ts_station_metadata`` is copied;
    attributes that raise AttributeError are silently skipped.

    Parameters
    ----------
    ts_station_metadata : TSStation
        Source time series station metadata.
    """
    flat_station = ts_station_metadata.to_dict(single=True)
    for key, value in flat_station.items():
        try:
            self.station_metadata.update_attribute(key, value)
        except AttributeError:
            pass

1660 

def merge(
    self,
    other: "TF",
    period_min: float | None = None,
    period_max: float | None = None,
    inplace: bool = False,
) -> "TF | None":
    """
    Merge transfer functions together; metadata is taken from ``self``.

    ``other`` can be another ``TF`` object, a dictionary, or a list or
    tuple of either.

    To set bounds use dictionaries of the form

        [{"tf": tf_01, "period_min": .01, "period_max": 100},
         {"tf": tf_02, "period_min": 100, "period_max": 1000}]

    or, to use all the periods of each transfer function,

        [tf_01, tf_02, ...]

    The bounds are inclusive. To merge at, say, 1 s, choose the best one
    and set the other bound slightly lower or higher depending on the
    periods of that transfer function, for example

        [{"tf": tf_01, "period_min": .01, "period_max": 100},
         {"tf": tf_02, "period_min": 100.1, "period_max": 1000}]

    Every dataset, including the one from ``self``, is sorted by period
    before slicing so label-based slices and ``xr.combine_by_coords`` see
    monotonic period coordinates. The input transfer functions are not
    modified.

    Parameters
    ----------
    other : TF, dict, or list/tuple of TF or dict
        Other transfer functions to merge with.
    period_min : float | None
        Minimum period kept from ``self``; None means no lower bound.
    period_max : float | None
        Maximum period kept from ``self``; None means no upper bound.
    inplace : bool
        Whether to modify ``self`` or return a new TF.

    Returns
    -------
    TF | None
        The merged transfer function, or None if ``inplace=True``.

    Raises
    ------
    KeyError
        If a dictionary item does not have exactly the keys
        ``tf``, ``period_min`` and ``period_max``.
    TypeError
        If an item of ``other`` is neither a ``TF`` nor a dictionary.
    """

    def get_slice_dict(
        period_min: float | None, period_max: float | None
    ) -> dict[str, slice]:
        """Build the ``.loc`` indexer for an inclusive period range."""
        return {"period": slice(period_min, period_max)}

    def sort_by_period(ds: xr.Dataset) -> xr.Dataset:
        """Return a copy of ``ds`` sorted by increasing period."""
        return ds.sortby("period")

    def get_slice(ds: xr.Dataset, period_slice: dict[str, slice]) -> xr.Dataset:
        """
        Slice ``ds`` by period.

        ``.loc`` is tried first. If it raises KeyError (non-monotonic
        period index), fall back to an inclusive boolean mask with
        ``where(..., drop=True)``. With no bounds the dataset is returned
        unchanged.
        """
        try:
            return ds.loc[period_slice]
        except KeyError:
            bounds = period_slice["period"]
            mask = None
            if bounds.start is not None:
                mask = ds.period >= bounds.start
            if bounds.stop is not None:
                upper = ds.period <= bounds.stop
                mask = upper if mask is None else mask & upper
            if mask is None:
                return ds
            return ds.where(mask, drop=True)

    def validate_dict(item: dict[str, Any]) -> dict[str, Any]:
        """Ensure ``item`` has exactly the keys tf, period_min, period_max."""
        accepted_keys = sorted(["tf", "period_min", "period_max"])
        if accepted_keys != sorted(item.keys()):
            msg = f"Input dictionary must have keys of {accepted_keys}"
            logger.error(msg)
            raise KeyError(msg)
        return item

    def is_tf(item: "TF") -> xr.Dataset:
        """Return the dataset of a TF object sorted by period."""
        return sort_by_period(item._transfer_function)

    def is_dict(item: dict[str, Any]) -> xr.Dataset:
        """Return the period-sorted, period-sliced dataset described by ``item``."""
        item = validate_dict(item)
        # sort a copy instead of re-assigning the caller's dataset
        ds = sort_by_period(item["tf"]._transfer_function)
        return get_slice(ds, get_slice_dict(item["period_min"], item["period_max"]))

    period_slice_self = get_slice_dict(period_min, period_max)
    # self must be sorted too: a descending period index gives an empty
    # label slice and breaks combine_by_coords' monotonicity requirement
    tf_list = [get_slice(sort_by_period(self._transfer_function), period_slice_self)]

    if isinstance(other, tuple):
        other = list(other)
    elif not isinstance(other, list):
        other = [other]

    for item in other:
        if isinstance(item, TF):
            tf_list.append(is_tf(item))
        elif isinstance(item, dict):
            tf_list.append(is_dict(item))
        else:
            msg = f"Type {type(item)} not supported"
            logger.error(msg)
            raise TypeError(msg)

    new_tf = xr.combine_by_coords(tf_list, combine_attrs="override")

    if inplace:
        self._transfer_function = new_tf
        return None
    return_tf = self.copy()
    return_tf._transfer_function = new_tf
    return return_tf

1883 

1884 def write( 

1885 self, 

1886 fn: str | Path | None = None, 

1887 save_dir: str | Path | None = None, 

1888 fn_basename: str | None = None, 

1889 file_type: Literal["edi", "xml", "zmm", "avg", "j"] = "edi", 

1890 **kwargs, 

1891 ): 

1892 """ 

1893 Write an mt file, the supported file types are EDI and XML. 

1894 

1895 .. todo:: j-files 

1896 

1897 Parameters 

1898 ---------- 

1899 fn: str | Path | None 

1900 Full path to file to save to. 

1901 save_dir: str | Path | None 

1902 Full path save directory. 

1903 fn_basename: str | None 

1904 Name of file with or without extension. 

1905 file_type: Literal["edi", "xml", "zmm", "avg", "j"] 

1906 Type of file to write. 

1907 

1908 Optional Keyword Arguments 

1909 --------------------------- 

1910 longitude_format: str 

1911 whether to write longitude as longitude or LONG. 

1912 options are 'longitude' or 'LONG', default 'longitude' 

1913 

1914 longitude_format: string 

1915 latlon_format: format of latitude and longitude in output edi, 

1916 degrees minutes seconds ('dms') or decimal 

1917 degrees ('dd') 

1918 

1919 Returns 

1920 ------- 

1921 str 

1922 Full path to the written file. 

1923 

1924 :Example: :: 

1925 

1926 >>> tf_obj.write(file_type='xml') 

1927 

1928 """ 

1929 

1930 if fn is not None: 

1931 new_fn = Path(fn) 

1932 self.save_dir = new_fn.parent 

1933 fn_basename = new_fn.name 

1934 file_type = new_fn.suffix.lower()[1:] 

1935 if save_dir is not None: 

1936 self.save_dir = Path(save_dir) 

1937 if fn_basename is not None: 

1938 fn_basename = Path(fn_basename) 

1939 if fn_basename.suffix in ["", None]: 

1940 fn_basename = fn_basename.with_name(f"{fn_basename.name}.{file_type}") 

1941 if fn_basename is None: 

1942 fn_basename = Path(f"{self.station}.{file_type}") 

1943 if file_type is None: 

1944 file_type = fn_basename.suffix.lower()[1:] 

1945 if file_type not in self._read_write_dict.keys(): 

1946 msg = f"File type {file_type} not supported yet." 

1947 logger.error(msg) 

1948 raise TFError(msg) 

1949 fn = self.save_dir.joinpath(fn_basename) 

1950 

1951 obj = self._read_write_dict[file_type]["write"]() 

1952 obj._fn = fn 

1953 obj.write(fn, **kwargs) 

1954 

1955 return obj 

1956 

def write_tf_file(self, **kwargs):
    """Deprecated: only logs an error pointing to ``write()``; writes nothing."""
    deprecation_msg = "'write_tf_file' has been deprecated use 'write()'"
    logger.error(deprecation_msg)

1959 

def read_tf_file(self, **kwargs):
    """Deprecated: only logs an error pointing to ``read()``; reads nothing."""
    deprecation_msg = "'read_tf_file' has been deprecated use 'read()'"
    logger.error(deprecation_msg)

1962 

1963 def read( 

1964 self, 

1965 fn: str | Path | None = None, 

1966 file_type: str | None = None, 

1967 get_elevation: bool = False, 

1968 **kwargs, 

1969 ): 

1970 """ 

1971 

1972 Read an TF response file. 

1973 

1974 .. note:: Currently only .edi, .xml, .j, .zmm/rr/ss, .avg 

1975 files are supported 

1976 

1977 Parameters 

1978 ---------- 

1979 fn: str | Path | None 

1980 Full path to input file. 

1981 file_type: str | None 

1982 Type of file to read. If None, automatically detects file type by 

1983 the extension. Options are [edi | j | xml | avg | zmm | zrr | zss | ...] 

1984 get_elevation: bool 

1985 Whether to get elevation from US National Map DEM 

1986 

1987 :Example: :: 

1988 

1989 >>> import mt_metadata.transfer_functions import TF 

1990 >>> tf_obj = TF() 

1991 >>> tf_obj.read(fn=r"/home/mt/mt01.xml") 

1992 

1993 .. note:: If your internet is slow try setting 'get_elevation' = False, 

1994 It can get hooked in a slow loop and slow down reading. 

1995 

1996 """ 

1997 if fn is not None: 

1998 self.fn = fn 

1999 self.save_dir = self.fn.parent 

2000 if file_type is None: 

2001 file_type = self.fn.suffix.lower()[1:] 

2002 self._read_write_dict[file_type]["read"]( 

2003 self.fn, get_elevation=get_elevation, **kwargs 

2004 ) 

2005 

2006 self.station_metadata.update_time_period() 

2007 self.survey_metadata.update_bounding_box() 

2008 self.survey_metadata.update_time_period() 

2009 

def to_edi(self) -> EDI:
    """
    Convert the TF object to a
    :class:`mt_metadata.transfer_functions.io.edi.EDI` object.

    From there the attributes of the EDI object can be changed before
    writing it to a file.

    Impedance and tipper (with errors) are copied only when present. A
    scalar rotation angle (Python or numpy int/float) is repeated once per
    period; any other value is passed through unchanged.

    Returns
    -------
    EDI
        The populated EDI object.

    >>> from mt_metadata.transfer_functions import TF
    >>> from mt_metadata import TF_XML
    >>> t = TF(TF_XML)
    >>> t.read()
    >>> edi_object = t.to_edi()
    >>> edi_object.Header.acqby = "me"
    >>> edi_object.write()

    """

    edi_obj = EDI()
    if self.has_impedance():
        edi_obj.z = self.impedance.data
        edi_obj.z_err = self.impedance_error.data
    if self.has_tipper():
        edi_obj.t = self.tipper.data
        edi_obj.t_err = self.tipper_error.data
    edi_obj.frequency = 1.0 / self.period

    # numpy integer scalars are not subclasses of int, so check them explicitly
    if isinstance(self._rotation_angle, (int, float, np.integer, np.floating)):
        edi_obj.rotation_angle = np.repeat(self._rotation_angle, self.period.size)
    else:
        edi_obj.rotation_angle = self._rotation_angle

    # fill from survey metadata
    edi_obj.survey_metadata = self.survey_metadata

    # fill from station metadata
    edi_obj.station_metadata = self.station_metadata

    # input data section
    edi_obj.Data.data_type = self.station_metadata.data_type
    edi_obj.Data.nfreq = self.period.size
    edi_obj.Data.sectid = self.station
    edi_obj.Data.nchan = len(edi_obj.Measurement.channel_ids.keys())
    edi_obj.Data.maxblks = 999

    for comp in ["ex", "ey", "hx", "hy", "hz", "rrhx", "rrhy"]:
        if hasattr(edi_obj.Measurement, f"meas_{comp}"):
            setattr(
                edi_obj.Data,
                comp,
                getattr(edi_obj.Measurement, f"meas_{comp}").id,
            )
    # round-trip the sections through their text form so they are consistent
    edi_obj.Data.read_data(edi_obj.Data.write_data())

    edi_obj.Measurement.read_measurement(edi_obj.Measurement.write_measurement())

    return edi_obj

2071 

def from_edi(
    self, edi_obj: str | Path | EDI, get_elevation: bool = False, **kwargs
) -> None:
    """
    Fill this TF from an EDI file or a
    :class:`mt_metadata.transfer_functions.io.edi.EDI` object.

    When the EDI object carries a full transfer function (``tf`` with
    trailing shape (3, 2)), the full tf, its errors, the inverse signal
    power and the residual covariance are copied. Otherwise impedance and
    tipper with their errors are copied. In both cases the period, survey
    metadata and rotation angle are copied as well.

    Parameters
    ----------
    edi_obj : str | Path | EDI
        Path to an EDI file, which is read from disk with ``EDI(**kwargs)``,
        or an EDI object, which is used directly.
    get_elevation : bool
        Try to get the elevation from the US National Map, defaults to False.
        Only used when a path is given.

    Raises
    ------
    TypeError
        If the input is neither a path nor an EDI object.
    """
    if isinstance(edi_obj, (str, Path)):
        self._fn = Path(edi_obj)
        edi_obj = EDI(**kwargs)
        edi_obj.read(self._fn, get_elevation=get_elevation)
    if not isinstance(edi_obj, EDI):
        raise TypeError(f"Input must be a EDI object not {type(edi_obj)}")

    has_full_tf = edi_obj.tf is not None and edi_obj.tf.shape[1:] == (3, 2)
    if has_full_tf:
        attribute_pairs = (
            ("period", "period"),
            ("transfer_function", "tf"),
            ("inverse_signal_power", "signal_inverse_power"),
            ("residual_covariance", "residual_covariance"),
            ("transfer_function_error", "tf_err"),
            ("survey_metadata", "survey_metadata"),
            ("_rotation_angle", "rotation_angle"),
        )
    else:
        attribute_pairs = (
            ("period", "period"),
            ("impedance", "z"),
            ("impedance_error", "z_err"),
            ("tipper", "t"),
            ("tipper_error", "t_err"),
            ("survey_metadata", "survey_metadata"),
            ("_rotation_angle", "rotation_angle"),
        )
    # order matters: period must be set before the data arrays
    for tf_key, edi_key in attribute_pairs:
        setattr(self, tf_key, getattr(edi_obj, edi_key))

2131 

def to_emtfxml(self) -> EMTFXML:
    """
    Convert this TF into a
    :class:`mt_metadata.transfer_functions.io.EMTFXML` object.

    Survey and station metadata are copied over, a default description and
    product id are filled in when absent, and the impedance and/or tipper
    (as values and variances, plus covariances when both residual
    covariance and inverse signal power are present) are written into the
    EMTFXML data block.

    Returns
    -------
    EMTFXML
        The populated EMTFXML object.

    Examples
    --------
    >>> from mt_metadata.transfer_functions import TF
    >>> from mt_metadata import TF_XML
    >>> t = TF(TF_XML)
    >>> t.read()
    >>> xml_object = t.to_emtfxml()
    >>> xml_object.site.country = "Here"
    >>> xml_object.write()
    """
    xml = EMTFXML()
    xml.survey_metadata = self.survey_metadata
    xml.station_metadata = self.station_metadata

    if xml.description is None:
        xml.description = "Magnetotelluric Transfer Functions"
    if xml.product_id is None:
        project = xml.survey_metadata.project
        station_id = xml.station_metadata.id
        start_year = xml.station_metadata.time_period.start.year
        xml.product_id = f"{project}.{station_id}.{start_year}"

    data_tags = []
    xml.data.period = self.period

    if self.has_impedance():
        data_tags.append("impedance")
        xml.data.z = self.impedance.data
        xml.data.z_var = self.impedance_error.data**2
        if self.has_residual_covariance() and self.has_inverse_signal_power():
            h_sel = dict(input=self.hx_hy, output=self.hx_hy)
            xml.data.z_invsigcov = self.inverse_signal_power.loc[h_sel].data
            e_sel = dict(input=self.ex_ey, output=self.ex_ey)
            xml.data.z_residcov = self.residual_covariance.loc[e_sel].data

    if self.has_tipper():
        data_tags.append("tipper")
        xml.data.t = self.tipper.data
        xml.data.t_var = self.tipper_error.data**2
        if self.has_residual_covariance() and self.has_inverse_signal_power():
            h_sel = dict(input=self.hx_hy, output=self.hx_hy)
            xml.data.t_invsigcov = self.inverse_signal_power.loc[h_sel].data
            hz_name = [self.channel_nomenclature["hz"]]
            hz_sel = dict(input=hz_name, output=hz_name)
            xml.data.t_residcov = self.residual_covariance.loc[hz_sel].data

    xml.tags = ", ".join(data_tags)
    xml.period_range.min = xml.data.period.min()
    xml.period_range.max = xml.data.period.max()

    xml._get_data_types()
    xml._get_statistical_estimates()
    # the site layout is rebuilt only after the data is set so that its
    # channels reflect what was actually written
    xml._update_site_layout()

    return xml

2203 

def from_emtfxml(
    self, emtfxml_obj: str | Path | EMTFXML, get_elevation: bool = False, **kwargs
) -> None:
    """
    Populate this TF from an EMTF XML file or a
    :class:`mt_metadata.transfer_functions.io.EMTFXML` object.

    Parameters
    ----------
    emtfxml_obj: str | Path | EMTFXML
        Path to an EMTF XML file (read from disk into a new EMTFXML
        object) or an EMTFXML object used directly.
    get_elevation: bool
        Passed to ``EMTFXML.read`` when a path is given; try to get
        elevation from US National Map. Defaults to False.
    **kwargs
        Keyword arguments forwarded to the ``EMTFXML`` constructor when a
        path is given.

    Returns
    -------
    None

    Raises
    ------
    TypeError
        If ``emtfxml_obj`` is neither a path nor an EMTFXML object.
    """

    def _std_from_variance(variance):
        # Variances should be non-negative; map negative/invalid entries to
        # NaN and silence the floating-point "invalid" warning.
        with np.errstate(invalid="ignore"):
            return np.sqrt(np.where(variance >= 0, variance, np.nan))

    if isinstance(emtfxml_obj, (str, Path)):
        self._fn = Path(emtfxml_obj)
        emtfxml_obj = EMTFXML(**kwargs)
        emtfxml_obj.read(self._fn, get_elevation=get_elevation)
    if not isinstance(emtfxml_obj, EMTFXML):
        raise TypeError(f"Input must be a EMTFXML object not {type(emtfxml_obj)}")
    self.survey_metadata = emtfxml_obj.survey_metadata
    self.station_metadata = self.survey_metadata.stations[0]

    # Select covariance blocks through the channel nomenclature (as
    # to_emtfxml does) instead of hard-coded "hx"/"ex"/"hz" so that TFs
    # with non-default channel names do not fail on the .loc lookups.
    hx_hy = self.hx_hy
    ex_ey = self.ex_ey
    hz = [self.channel_nomenclature["hz"]]

    self.period = emtfxml_obj.data.period
    self.impedance = emtfxml_obj.data.z
    self.impedance_error = _std_from_variance(emtfxml_obj.data.z_var)
    self._transfer_function.inverse_signal_power.loc[
        dict(input=hx_hy, output=hx_hy)
    ] = emtfxml_obj.data.z_invsigcov
    self._transfer_function.residual_covariance.loc[
        dict(input=ex_ey, output=ex_ey)
    ] = emtfxml_obj.data.z_residcov

    self.tipper = emtfxml_obj.data.t
    # same guard as the impedance variance (previously a bare np.sqrt)
    self.tipper_error = _std_from_variance(emtfxml_obj.data.t_var)
    self._transfer_function.inverse_signal_power.loc[
        dict(input=hx_hy, output=hx_hy)
    ] = emtfxml_obj.data.t_invsigcov
    self._transfer_function.residual_covariance.loc[
        dict(input=hz, output=hz)
    ] = emtfxml_obj.data.t_residcov

2252 

def to_jfile(self) -> None:
    """
    Translate this TF into a
    :class:`mt_metadata.transfer_functions.io.jfile.JFile` object.

    .. note:: Not implemented yet.

    Raises
    ------
    NotImplementedError
        Always; writing J-files is not supported yet.
    """
    message = "to_jfile not implemented yet."
    raise NotImplementedError(message)

2266 

def from_jfile(
    self, j_obj: str | Path | JFile, get_elevation: bool = False, **kwargs
) -> None:
    """
    Populate this TF from a J-file or a
    :class:`mt_metadata.transfer_functions.io.jfile.JFile` object.

    Parameters
    ----------
    j_obj: str | Path | JFile
        Path to a J-file (read from disk into a new JFile object) or a
        JFile object used directly.
    get_elevation: bool
        Passed to ``JFile.read`` when a path is given; try to get
        elevation from US National Map. Defaults to False.
    **kwargs
        Keyword arguments forwarded to the ``JFile`` constructor when a
        path is given.

    Returns
    -------
    None

    Raises
    ------
    TypeError
        If ``j_obj`` is neither a path nor a JFile object.
    """
    if isinstance(j_obj, (str, Path)):
        self._fn = Path(j_obj)
        j_obj = JFile(**kwargs)
        j_obj.read(self._fn, get_elevation=get_elevation)
    if not isinstance(j_obj, JFile):
        raise TypeError(f"Input must be a JFile object not {type(j_obj)}")

    # TF attribute -> JFile attribute; copied in insertion order
    k_dict = {
        "period": "periods",
        "impedance": "z",
        "impedance_error": "z_err",
        "tipper": "t",
        "tipper_error": "t_err",
        "survey_metadata": "survey_metadata",
    }
    for tf_key, j_key in k_dict.items():
        setattr(self, tf_key, getattr(j_obj, j_key))

2304 

def make_zmm_run(self, zmm_obj: ZMM, number_dict: dict) -> Run:
    """
    Build a run describing the channels of a z-file, to aid writing it.

    As a side effect, a z-file channel object is created for every entry of
    ``number_dict`` and set on ``zmm_obj`` as an attribute named after the
    channel.

    Parameters
    ----------
    zmm_obj: ZMM
        The ZMM object that will be written to file and needs a run.
    number_dict: dict
        Mapping of channel name to z-file channel number, e.g.
        ``{"hx": 1, "hy": 2, "hz": 3, "ex": 4, "ey": 5}``.

    Returns
    -------
    Run
        A run holding an ``Electric`` channel for each of ``ex``/``ey`` and
        a ``Magnetic`` channel for each of ``hx``/``hy``/``hz`` present in
        ``number_dict``. Other names are only registered on ``zmm_obj``.
    """
    electric_names = ("ex", "ey")
    magnetic_names = ("hx", "hy", "hz")

    zrun = Run()
    for name, number in number_dict.items():
        z_channel = ZChannel()
        z_channel.channel = name
        z_channel.number = number
        setattr(zmm_obj, z_channel.channel, z_channel)

        if name in electric_names:
            zrun.add_channel(Electric(component=name, channel_number=number))
        elif name in magnetic_names:
            zrun.add_channel(Magnetic(component=name, channel_number=number))
    return zrun

2335 

def to_zmm(self) -> ZMM:
    """
    Translate this TF into a
    :class:`mt_metadata.transfer_functions.io.ZMM` object.

    The z-file channel numbering puts the inputs (hx, hy) first, followed
    by the outputs (hz, then ex, ey). If the station has no run, or its
    first run has no recorded channels, a run built by ``make_zmm_run`` is
    added to / replaces the first run of ``self.station_metadata``.

    Returns
    -------
    ZMM
        The populated ZMM object.

    Raises
    ------
    ValueError
        If the TF has neither an impedance nor a tipper, so there is no
        z-file channel layout to write.

    Examples
    --------
    >>> from mt_metadata.transfer_functions import TF
    >>> from mt_metadata import TF_XML
    >>> t = TF(TF_XML)
    >>> t.read()
    >>> zmm_object = t.to_zmm()
    >>> zmm_object.processing_type = "new and fancy"
    >>> zmm_object.write()
    """
    # Decide the z-file channel layout first: without impedance or tipper
    # there is none, and previously this fell through with ``number_dict``
    # unbound, failing later with an UnboundLocalError.
    has_tipper = self.has_tipper()
    has_impedance = self.has_impedance()
    if has_tipper and has_impedance:
        num_channels = 5
        number_dict = {"hx": 1, "hy": 2, "hz": 3, "ex": 4, "ey": 5}
    elif has_tipper:
        num_channels = 3
        number_dict = {"hx": 1, "hy": 2, "hz": 3}
    elif has_impedance:
        num_channels = 4
        number_dict = {"hx": 1, "hy": 2, "ex": 3, "ey": 4}
    else:
        raise ValueError(
            "Cannot write a z-file: transfer function has neither "
            "impedance nor tipper."
        )

    zmm_kwargs = {
        "channel_nomenclature": self.channel_nomenclature,
        "inverse_channel_nomenclature": self.inverse_channel_nomenclature,
    }
    if hasattr(self, "decimation_dict"):
        zmm_kwargs["decimation_dict"] = self.decimation_dict
    zmm_obj = ZMM(**zmm_kwargs)

    zmm_obj.dataset = self.dataset
    zmm_obj.station_metadata = self.station_metadata
    zmm_obj.num_channels = num_channels

    if len(self.station_metadata.runs) == 0:
        run = self.make_zmm_run(zmm_obj, number_dict)
        self.station_metadata.add_run(run)
    elif len(self.station_metadata.runs[0].channels_recorded_all) == 0:
        # avoid the default metadata getting interpreted as a real metadata
        # object: overwrite this "spoof" run with one that has channels
        run = self.make_zmm_run(zmm_obj, number_dict)
        self.station_metadata.runs[0] = run
    else:
        first_run = self.station_metadata.runs[0]
        for comp in first_run.channels_recorded_all:
            if "rr" in comp:
                continue
            ch = first_run.get_channel(comp)
            # temporarily map the component back to its generic name so the
            # z-file channel picks it up, and always restore it afterwards
            ch.component = self.inverse_channel_nomenclature[comp]
            try:
                c = ZChannel()
                c.from_dict(ch.to_dict(single=True))
            finally:
                ch.component = comp
            try:
                c.number = number_dict[c.channel]
            except KeyError:
                logger.debug(f"Could not find channel {c.channel}")
                continue
            setattr(zmm_obj, c.channel, c)

    zmm_obj.survey_metadata.update(self.survey_metadata)
    zmm_obj.num_freq = self.period.size

    return zmm_obj

2403 

def from_zmm(
    self, zmm_obj: str | Path | ZMM, get_elevation: bool = False, **kwargs
) -> None:
    """
    Populate this TF from a z-file or a
    :class:`mt_metadata.transfer_functions.io.ZMM` object.

    Copies the decimation dictionary, survey/station metadata and periods,
    then the transfer function, inverse signal power and residual
    covariance blocks for the ZMM's input/output channels. Errors are
    recomputed from the covariances, and the rotation angle is set to the
    negative of the ZMM declination.

    Parameters
    ----------
    zmm_obj: str | Path | ZMM
        Path to a .zmm/.zrr/.zss file (read from disk into a new ZMM
        object) or a ZMM object used directly.
    get_elevation: bool
        Passed to ``ZMM.read`` when a path is given; try to get elevation
        from US National Map. Defaults to False.
    **kwargs
        Keyword arguments forwarded to the ``ZMM`` constructor when a path
        is given.

    Raises
    ------
    TypeError
        If ``zmm_obj`` is neither a path nor a ZMM object.
    """
    if isinstance(zmm_obj, (str, Path)):
        self._fn = Path(zmm_obj)
        zmm_obj = ZMM(**kwargs)
        zmm_obj.read(self._fn, get_elevation=get_elevation)
    if not isinstance(zmm_obj, ZMM):
        raise TypeError(f"Input must be a ZMM object not {type(zmm_obj)}")
    self.decimation_dict = zmm_obj.decimation_dict

    # TF attribute -> ZMM attribute; copied in insertion order
    k_dict = {
        "survey_metadata": "survey_metadata",
        "station_metadata": "station_metadata",
        "period": "periods",
    }
    for tf_key, z_key in k_dict.items():
        setattr(self, tf_key, getattr(zmm_obj, z_key))

    inputs = zmm_obj.input_channels
    outputs = zmm_obj.output_channels
    tf_sel = dict(input=inputs, output=outputs)
    isp_sel = dict(input=inputs, output=inputs)
    res_sel = dict(input=outputs, output=outputs)

    zmm_ds = zmm_obj.dataset
    self._transfer_function["transfer_function"].loc[tf_sel] = (
        zmm_ds.transfer_function.loc[tf_sel]
    )
    self._transfer_function["inverse_signal_power"].loc[isp_sel] = (
        zmm_ds.inverse_signal_power.loc[isp_sel]
    )
    self._transfer_function["residual_covariance"].loc[res_sel] = (
        zmm_ds.residual_covariance.loc[res_sel]
    )

    self._compute_error_from_covariance()
    self._rotation_angle = -1 * zmm_obj.declination

2455 

def to_zrr(self) -> ZMM:
    """
    Translate this TF into a ZMM object for writing a .zrr file.

    This is an alias of :meth:`to_zmm`; the same ZMM object type is used
    for .zmm, .zrr and .zss files.

    Returns
    -------
    ZMM
        The object returned by ``self.to_zmm()``.
    """
    zrr_obj = self.to_zmm()
    return zrr_obj

2474 

def from_zrr(
    self, zrr_obj: str | Path | ZMM, get_elevation: bool = False, **kwargs
) -> None:
    """
    Populate this TF from a .zrr file or ZMM object.

    Delegates entirely to :meth:`from_zmm`.

    Parameters
    ----------
    zrr_obj: str | Path | ZMM
        Path to a .zrr file or a ZMM object.
    get_elevation: bool
        Try to get elevation from US National Map. Defaults to False.
    **kwargs
        Keyword arguments for the ZMM object, forwarded to ``from_zmm``.
    """
    self.from_zmm(zrr_obj, get_elevation=get_elevation, **kwargs)

2491 

def to_zss(self) -> ZMM:
    """
    Translate this TF into a ZMM object for writing a .zss file.

    This is an alias of :meth:`to_zmm`; the same ZMM object type is used
    for .zmm, .zrr and .zss files.

    Returns
    -------
    ZMM
        The object returned by ``self.to_zmm()``.
    """
    zss_obj = self.to_zmm()
    return zss_obj

2510 

def from_zss(
    self, zss_obj: str | Path | ZMM, get_elevation: bool = False, **kwargs
) -> None:
    """
    Populate this TF from a .zss file or ZMM object.

    Delegates entirely to :meth:`from_zmm`.

    Parameters
    ----------
    zss_obj: str | Path | ZMM
        Path to a .zss file or a ZMM object.
    get_elevation: bool
        Try to get elevation from US National Map. Defaults to False.
    **kwargs
        Keyword arguments for the ZMM object, forwarded to ``from_zmm``.
    """
    self.from_zmm(zss_obj, get_elevation=get_elevation, **kwargs)

2525 

def to_avg(self) -> ZongeMTAvg:
    """
    Translate this TF into a
    :class:`mt_metadata.transfer_functions.io.ZongeMTAvg` object.

    Only the frequencies, impedance, tipper and their errors are copied.
    Metadata is not transferred yet, and a warning is logged to say so.

    Returns
    -------
    ZongeMTAvg
        AVG object holding this TF's frequency, impedance and tipper data.
    """
    avg_obj = ZongeMTAvg()
    avg_obj.frequency = self.frequency
    avg_obj.z = self.impedance
    avg_obj.z_err = self.impedance_error
    avg_obj.t = self.tipper
    avg_obj.t_err = self.tipper_error

    logger.warning("Metadata is not properly set for a AVG file yet.")
    return avg_obj

2548 

def from_avg(
    self, avg_obj: str | Path | ZongeMTAvg, get_elevation: bool = False, **kwargs
) -> None:
    """
    Populate this TF from a Zonge .avg file or a
    :class:`mt_metadata.transfer_functions.io.ZongeMTAvg` object.

    Parameters
    ----------
    avg_obj: str | Path | ZongeMTAvg
        Path to an .avg file (read from disk into a new ZongeMTAvg object)
        or a ZongeMTAvg object used directly.
    get_elevation: bool
        Passed to ``ZongeMTAvg.read`` when a path is given; try to get
        elevation from US National Map. Defaults to False.
    **kwargs
        Keyword arguments forwarded to the ``ZongeMTAvg`` constructor when
        a path is given.

    Raises
    ------
    TypeError
        If ``avg_obj`` is neither a path nor a ZongeMTAvg object.
    """
    if isinstance(avg_obj, (str, Path)):
        self._fn = Path(avg_obj)
        avg_obj = ZongeMTAvg(**kwargs)
        avg_obj.read(self._fn, get_elevation=get_elevation)
    if not isinstance(avg_obj, ZongeMTAvg):
        # message previously said "ZMM object", which was misleading
        raise TypeError(f"Input must be a ZongeMTAvg object not {type(avg_obj)}")
    self.survey_metadata = avg_obj.survey_metadata

    self.period = 1.0 / avg_obj.frequency
    self.impedance = avg_obj.z
    self.impedance_error = avg_obj.z_err

    # tipper is optional in AVG files
    if avg_obj.t is not None:
        self.tipper = avg_obj.t
        self.tipper_error = avg_obj.t_err

2577 

2578 

2579# ============================================================================== 

2580# Error 

2581# ============================================================================== 

2582 

2583 

class TFError(Exception):
    """Exception type for errors raised by the transfer-function container."""