Coverage for C: \ Users \ peaco \ OneDrive \ Documents \ GitHub \ mt_metadata \ mt_metadata \ timeseries \ experiment.py: 75%

305 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-10 00:11 -0800

1# -*- coding: utf-8 -*- 

2""" 

3Containers for the full metadata tree 

4 

5Experiment --> Survey --> Station --> Run --> Channel 

6 

7Each level has a list attribute 

8 

9Created on Mon Feb 8 21:25:40 2021 

10 

11:copyright: 

12 Jared Peacock (jpeacock@usgs.gov) 

13 

14:license: MIT 

15 

16""" 

17import json 

18 

19# ============================================================================= 

20# Imports 

21# ============================================================================= 

22from collections import OrderedDict 

23from pathlib import Path 

24from typing import Annotated 

25from xml.etree import cElementTree as et 

26 

27from loguru import logger 

28from pydantic import computed_field, Field, field_validator 

29 

30from mt_metadata.base import helpers, MetadataBase 

31from mt_metadata.common.list_dict import ListDict 

32 

33from . import Auxiliary, Electric, Magnetic, Run, Station, Survey 

34from .filters import ( 

35 CoefficientFilter, 

36 FIRFilter, 

37 FrequencyResponseTableFilter, 

38 PoleZeroFilter, 

39 TimeDelayFilter, 

40) 

41 

42 

43# ============================================================================= 

44 

45 

46class Experiment(MetadataBase): 

47 """ 

48 Top level of the metadata 

49 """ 

50 

51 surveys: Annotated[ 

52 ListDict | list | dict | OrderedDict, 

53 Field( 

54 default_factory=ListDict, 

55 description="List of surveys in the experiment", 

56 title="List of Surveys", 

57 json_schema_extra={ 

58 "required": False, 

59 "units": None, 

60 "examples": [{"id": "survey_1"}, {"id": "survey_2"}], 

61 }, 

62 ), 

63 ] 

64 

def __str__(self) -> str:
    """
    Build a text summary of the experiment.

    One line is emitted per survey, filter, station and run attribute,
    walking the tree from surveys down to runs.  When there are no
    surveys only the two header lines are returned.

    :return: summary lines joined with newline characters
    :rtype: str

    """
    divider = "-" * 20
    out = ["Experiment Contents", divider]
    if len(self.surveys) == 0:
        return "\n".join(out)

    out.append(f"Number of Surveys: {len(self.surveys)}")
    for survey in self.surveys:
        out += [
            f" Survey ID: {survey.id}",
            f" Number of Stations: {survey.n_stations}",
            f" Number of Filters: {len(survey.filters.keys())}",
            f" {divider}",
        ]
        for filter_name, filter_obj in survey.filters.items():
            out += [
                f" Filter Name: {filter_name}",
                f" Filter Type: {filter_obj.type}",
                f" {divider}",
            ]
        for station in survey.stations:
            out += [
                f" Station ID: {station.id}",
                f" Number of Runs: {station.n_runs}",
                f" {divider}",
            ]
            for run in station.runs:
                out += [
                    f" Run ID: {run.id}",
                    f" Number of Channels: {run.n_channels}",
                    " Recorded Channels: " + ", ".join(run.channels_recorded_all),
                    f" Start: {run.time_period.start}",
                    f" End: {run.time_period.end}",
                    f" {divider}",
                ]

    return "\n".join(out)

95 

def __repr__(self) -> str:
    """
    Return the same text as ``__str__``.

    :return: the experiment summary text
    :rtype: str

    """
    summary = self.__str__()
    return summary

98 

def __eq__(self, other) -> bool:
    """
    Compare this object with another one.

    :param other: object to compare against
    :return: True only when ``other`` is an instance of this object's class
     and both instance ``__dict__`` mappings compare equal
    :rtype: bool

    """
    if not isinstance(other, self.__class__):
        return False
    return self.__dict__ == other.__dict__

101 

def __ne__(self, other) -> bool:
    """
    Negation of ``__eq__``.

    :param other: object to compare against
    :return: True when ``__eq__`` reports the objects as not equal
    :rtype: bool

    """
    is_equal = self.__eq__(other)
    return not is_equal

104 

def merge(self, other: "Experiment") -> "Experiment":
    """
    Merge another Experiment into this one.

    The surveys of ``other`` are handed to ``self.surveys.extend``, so this
    object is modified in place.

    :param other: experiment whose surveys are added
    :type other: Experiment
    :return: this experiment
    :rtype: Experiment
    :raises TypeError: if ``other`` is not an Experiment

    """
    if not isinstance(other, Experiment):
        msg = f"Can only merge Experiment objects, not {type(other)}"
        logger.error(msg)
        raise TypeError(msg)

    self.surveys.extend(other.surveys)
    return self

117 

@computed_field
@property
def n_surveys(self) -> int:
    """
    Number of surveys held in ``surveys``.

    :return: length of the survey container
    :rtype: int

    """
    count = len(self.surveys)
    return count

122 

@field_validator("surveys", mode="before")
@classmethod
def validate_surveys(cls, value) -> ListDict:
    """
    Coerce the ``surveys`` input into a ListDict of Survey objects.

    Registered as a ``mode="before"`` validator for the ``surveys`` field.
    Mapping-like inputs contribute their values, lists and tuples their
    items.  Dictionary items are converted with ``Survey.from_dict`` and
    Survey instances are kept as they are.

    :param value: surveys as a list, tuple, dict, OrderedDict or ListDict
    :return: container holding the accepted surveys
    :rtype: ListDict
    :raises TypeError: if ``value`` is not one of the accepted containers,
     or if any item is neither a dictionary nor a Survey (all offending
     items are reported together)

    """
    if not isinstance(value, (list, tuple, dict, ListDict, OrderedDict)):
        # name the field that is actually validated (message used to say "station_list")
        msg = (
            "input surveys must be an iterable, should be a list or dict "
            f"not {type(value)}"
        )
        logger.error(msg)
        raise TypeError(msg)

    if isinstance(value, (dict, ListDict, OrderedDict)):
        candidates = value.values()
    else:
        # only list and tuple remain after the container check above
        candidates = value

    fails = []
    surveys = ListDict()
    for ii, survey in enumerate(candidates):
        if isinstance(survey, (dict, OrderedDict)):
            s = Survey()
            s.from_dict(survey)
            surveys.append(s)
        elif not isinstance(survey, Survey):
            msg = f"Item {ii} is not type(Survey); type={type(survey)}"
            fails.append(msg)
            logger.error(msg)
        else:
            surveys.append(survey)

    # report every bad item at once instead of stopping at the first one
    if len(fails) > 0:
        raise TypeError("\n".join(fails))
    return surveys

158 

@property
def survey_names(self) -> list[str]:
    """
    Names of the surveys in the experiment.

    :return: whatever ``self.surveys.keys()`` returns
    :rtype: list[str]

    """
    names = self.surveys.keys()
    return names

163 

def has_survey(self, survey_id: str) -> bool:
    """
    Check whether a survey id is present in the experiment.

    :param survey_id: survey id to look for
    :type survey_id: str
    :return: True if ``survey_id`` is in ``survey_names``, otherwise False
    :rtype: bool

    """
    return survey_id in self.survey_names

177 

178 def survey_index(self, survey_id: str) -> int | None: 

179 """ 

180 Get survey index 

181 

182 :param survey_id: DESCRIPTION 

183 :type survey_id: TYPE 

184 :return: DESCRIPTION 

185 :rtype: TYPE 

186 

187 """ 

188 

189 if self.has_survey(survey_id): 

190 return self.survey_names.index(survey_id) 

191 return None 

192 

def add_survey(self, survey_obj: "Survey") -> None:
    """
    Add a survey to the experiment.

    If a survey with the same id is already present, the stored survey's
    ``update`` method is called with ``survey_obj``; otherwise
    ``survey_obj`` is appended to ``surveys``.

    :param survey_obj: survey to add
    :type survey_obj: :class:`mt_metadata.timeseries.Survey`
    :raises TypeError: if ``survey_obj`` is not a Survey

    """
    if not isinstance(survey_obj, Survey):
        raise TypeError(
            f"Input must be a mt_metadata.timeseries.Survey object not {type(survey_obj)}"
        )

    if not self.has_survey(survey_obj.id):
        self.surveys.append(survey_obj)
        return

    existing = self.surveys[survey_obj.id]
    existing.update(survey_obj)
    logger.debug(f"survey {survey_obj.id} already exists, updating metadata")

214 

def get_survey(self, survey_id: str) -> "Survey":
    """
    Get a survey by its id.

    :param survey_id: id of the survey to fetch
    :type survey_id: str
    :return: ``self.surveys[survey_id]`` when the id is known; None (after
     logging a warning) when it is not
    :rtype: Survey or None

    """
    if not self.has_survey(survey_id):
        logger.warning(f"Could not find survey {survey_id}")
        return None

    return self.surveys[survey_id]

231 

def remove_survey(self, survey_id: str, update: bool = True) -> None:
    """
    Remove a survey from the experiment.

    A warning is logged and nothing is removed when the id is unknown.

    :param survey_id: id of the survey to remove
    :type survey_id: str
    :param update: accepted but not used by this method, defaults to True
    :type update: bool, optional

    """
    if not self.has_survey(survey_id):
        logger.warning(f"Could not find survey {survey_id} to remove")
        return

    self.surveys.remove(survey_id)
    logger.debug(f"Removed survey {survey_id} from experiment")

249 

def to_dict(self, nested: bool = False, required: bool = True) -> dict:
    """
    Create a dictionary describing the whole experiment tree.

    Every survey, station, run, channel and filter is converted with its
    own ``to_dict`` (always called with ``single=True``).  Children are
    attached under the keys ``stations``, ``filters``, ``runs`` and
    ``channels``.

    :param nested: forwarded to each object's ``to_dict``, defaults to False
    :type nested: bool, optional
    :param required: forwarded to each object's ``to_dict``, defaults to True
    :type required: bool, optional
    :return: ``{"experiment": {"surveys": [...]}}``
    :rtype: dict

    """
    options = {"nested": nested, "single": True, "required": required}

    def run_entry(run):
        entry = run.to_dict(**options)
        entry["channels"] = [channel.to_dict(**options) for channel in run.channels]
        return entry

    def station_entry(station):
        entry = station.to_dict(**options)
        entry["runs"] = [run_entry(run) for run in station.runs]
        return entry

    def survey_entry(survey):
        entry = survey.to_dict(**options)
        # stations are converted before filters, and in that key order
        entry["stations"] = [station_entry(station) for station in survey.stations]
        entry["filters"] = [
            filter_obj.to_dict(**options) for _, filter_obj in survey.filters.items()
        ]
        return entry

    return {"experiment": {"surveys": [survey_entry(survey) for survey in self.surveys]}}

287 

288 def from_dict(self, ex_dict: dict | OrderedDict, skip_none: bool = True) -> None: 

289 """ 

290 fill from an input dictionary 

291 

292 :param ex_dict: DESCRIPTION 

293 :type ex_dict: TYPE 

294 :return: DESCRIPTION 

295 :rtype: TYPE 

296 

297 """ 

298 

299 if not isinstance(ex_dict, dict): 

300 msg = f"experiemnt input must be a dictionary not {type(ex_dict)}" 

301 logger.debug(msg) 

302 raise TypeError(msg) 

303 if "experiment" not in ex_dict.keys(): 

304 return 

305 

306 for survey_dict in ex_dict["experiment"]["surveys"]: 

307 survey_object = Survey() 

308 survey_object.from_dict(survey_dict, skip_none=skip_none) 

309 self.add_survey(survey_object) 

310 

def to_json(
    self,
    fn: str | Path = None,
    nested: bool = False,
    indent: str = " " * 4,
    required: bool = True,
) -> str | None:
    """
    Serialize the experiment (via ``to_dict``) to JSON.

    :param fn: file to write to; when None the JSON text is returned instead
    :type fn: str or :class:`pathlib.Path`, optional
    :param nested: forwarded to ``to_dict``, defaults to False
    :type nested: bool, optional
    :param indent: indentation passed to the json encoder
    :type indent: str, optional
    :param required: forwarded to ``to_dict``, defaults to True
    :type required: bool, optional
    :return: JSON string when ``fn`` is None, otherwise None
    :rtype: str | None

    """
    if fn is None:
        return json.dumps(
            self.to_dict(nested=nested, required=required),
            cls=helpers.NumpyEncoder,
            indent=indent,
        )

    with open(fn, "w") as fid:
        payload = self.to_dict(nested=nested, required=required)
        json.dump(payload, fid, cls=helpers.NumpyEncoder, indent=indent)
    return None

342 

def from_json(self, json_str: str, skip_none: bool = True) -> None:
    """
    Read JSON and fill the experiment through ``from_dict``.

    :param json_str: JSON text, or the path of a JSON file
    :type json_str: str or :class:`pathlib.Path`
    :param skip_none: forwarded to ``from_dict``, defaults to True
    :type skip_none: bool, optional
    :raises TypeError: if ``json_str`` is neither a str nor a Path
    :raises FileNotFoundError: if a Path is given that does not exist

    """
    if isinstance(json_str, Path):
        if not json_str.exists():
            raise FileNotFoundError(f"Could not find JSON file {json_str}")
        with open(json_str, "r") as fid:
            json_dict = json.load(fid)

    elif isinstance(json_str, str):
        # A string may be either a file name or the JSON text itself.  Long
        # or oddly formed JSON text can make the path check raise; treat
        # that as "not a file".
        try:
            is_file = Path(json_str).is_file()
        except (OSError, ValueError):
            is_file = False

        if is_file:
            with open(json_str, "r") as fid:
                json_dict = json.load(fid)
        else:
            # only parse the text itself when it did not name a file
            json_dict = json.loads(json_str)

    else:
        msg = f"Input must be valid JSON string not {type(json_str)}"
        logger.error(msg)
        raise TypeError(msg)

    self.from_dict(json_dict, skip_none=skip_none)

369 

370 def to_xml( 

371 self, fn: str | Path = None, required: bool = True, sort: bool = True 

372 ) -> et.Element: 

373 """ 

374 Write XML version of the experiment 

375 

376 :param fn: DESCRIPTION 

377 :type fn: TYPE 

378 :return: DESCRIPTION 

379 :rtype: TYPE 

380 

381 """ 

382 

383 experiment_element = et.Element(self.__class__.__name__) 

384 if sort: 

385 self.surveys.sort() 

386 for survey in self.surveys: 

387 survey.update_bounding_box() 

388 survey.update_time_period() 

389 survey_element = survey.to_xml(required=required) 

390 filter_element = et.SubElement(survey_element, "filters") 

391 for key, value in survey.filters.items(): 

392 filter_element.append(value.to_xml(required=required)) 

393 if sort: 

394 survey.stations.sort() 

395 for station in survey.stations: 

396 station.update_time_period() 

397 station_element = station.to_xml(required=required) 

398 if sort: 

399 station.runs.sort() 

400 for run in station.runs: 

401 run.update_time_period() 

402 run_element = run.to_xml(required=required) 

403 if sort: 

404 run.channels.sort() 

405 for channel in run.channels: 

406 if channel.type in ["electric"]: 

407 if ( 

408 channel.positive.latitude == 0 

409 and channel.positive.longitude == 0 

410 and channel.positive.elevation == 0 

411 ): 

412 channel.positive.latitude = station.location.latitude 

413 channel.positive.longitude = station.location.longitude 

414 channel.positive.elevation = station.location.elevation 

415 else: 

416 if ( 

417 channel.location.latitude == 0 

418 and channel.location.longitude == 0 

419 and channel.location.elevation == 0 

420 ): 

421 channel.location.latitude = station.location.latitude 

422 channel.location.longitude = station.location.longitude 

423 channel.location.elevation = station.location.elevation 

424 

425 run_element.append(channel.to_xml(required=required)) 

426 station_element.append(run_element) 

427 survey_element.append(station_element) 

428 experiment_element.append(survey_element) 

429 

430 if fn: 

431 with open(fn, "w") as fid: 

432 fid.write(helpers.element_to_string(experiment_element)) 

433 return experiment_element 

434 

435 def from_xml( 

436 self, 

437 fn: str | Path = None, 

438 element: et.Element | None = None, 

439 sort: bool = True, 

440 skip_none: bool = True, 

441 ) -> None: 

442 """ 

443 

444 :param fn: DESCRIPTION, defaults to None 

445 :type fn: TYPE, optional 

446 :param element: DESCRIPTION, defaults to None 

447 :type element: TYPE, optional 

448 :return: DESCRIPTION 

449 :rtype: TYPE 

450 

451 

452 

453 """ 

454 if fn: 

455 experiment_element = et.parse(fn).getroot() 

456 if element is not None: 

457 experiment_element = element 

458 

459 # need to set the lists for each layer, otherwise you get duplicates. 

460 for survey_element in list(experiment_element): 

461 survey_dict = helpers.element_to_dict(survey_element) 

462 stations = self._pop_dictionary(survey_dict["survey"], "station") 

463 survey_obj = Survey() 

464 survey_obj.from_dict(survey_dict, skip_none=skip_none) 

465 fd = survey_dict["survey"].pop("filters") 

466 filter_dict = self._read_filter_dict(fd) 

467 survey_obj.filters.update(filter_dict) 

468 

469 for station_dict in stations: 

470 station_obj = Station() 

471 runs = self._pop_dictionary(station_dict, "run") 

472 station_obj.from_dict(station_dict, skip_none=skip_none) 

473 for run_dict in runs: 

474 run_obj = Run() 

475 

476 for ch in ["electric", "magnetic", "auxiliary"]: 

477 try: 

478 for ch_dict in self._pop_dictionary(run_dict, ch): 

479 if ch == "electric": 

480 channel = Electric() 

481 elif ch == "magnetic": 

482 channel = Magnetic() 

483 elif ch == "auxiliary": 

484 channel = Auxiliary() 

485 channel.from_dict(ch_dict, skip_none=skip_none) 

486 run_obj.add_channel(channel) 

487 except KeyError: 

488 logger.debug(f"Could not find channel {ch}") 

489 run_obj.from_dict(run_dict, skip_none=skip_none) 

490 station_obj.add_run(run_obj) 

491 survey_obj.add_station(station_obj) 

492 self.add_survey(survey_obj) 

493 

494 if sort: 

495 self.sort() 

496 

def _pop_dictionary(self, in_dict: dict, element: str) -> list:
    """
    Pop a key from a dictionary and return its value as a list.

    :param in_dict: dictionary to pop from (modified in place)
    :type in_dict: dict
    :param element: key to remove
    :type element: str
    :return: the popped value if it already is a list, otherwise a
     one-item list wrapping it
    :rtype: list
    :raises KeyError: if ``element`` is not a key of ``in_dict``

    """
    popped = in_dict.pop(element)
    return popped if isinstance(popped, list) else [popped]

515 

516 def to_pickle(self, fn: str | Path = None) -> None: 

517 """ 

518 Write a pickle version of the experiment 

519 

520 :param fn: DESCRIPTION 

521 :type fn: TYPE 

522 :return: DESCRIPTION 

523 :rtype: TYPE 

524 

525 """ 

526 

527 def from_pickle(self, fn: str | Path = None) -> None: 

528 """ 

529 Read pickle version of experiment 

530 

531 :param fn: DESCRIPTION 

532 :type fn: TYPE 

533 :return: DESCRIPTION 

534 :rtype: TYPE 

535 

536 """ 

537 

538 # def validate_experiment(self): 

539 # """ 

540 # Validate experiment is legal 

541 

542 # :return: DESCRIPTION 

543 # :rtype: TYPE 

544 

545 # """ 

546 # pass 

547 

def _read_filter_dict(self, filters_dict: dict | None) -> ListDict:
    """
    Turn the ``filters`` dictionary of a survey into filter objects.

    Recognized keys are ``pole_zero_filter``, ``coefficient_filter``,
    ``time_delay_filter``, ``frequency_response_table_filter`` and
    ``fir_filter``; any other key is ignored.  The value under a key may be
    a single dictionary or a list of dictionaries; each dictionary is
    unpacked as keyword arguments into the matching filter class.

    :param filters_dict: mapping of filter type to filter description(s),
     or None
    :type filters_dict: dict | None
    :return: filters keyed by their lower-cased name; empty when
     ``filters_dict`` is None
    :rtype: ListDict

    """
    filter_classes = {
        "pole_zero_filter": PoleZeroFilter,
        "coefficient_filter": CoefficientFilter,
        "time_delay_filter": TimeDelayFilter,
        "frequency_response_table_filter": FrequencyResponseTableFilter,
        "fir_filter": FIRFilter,
    }

    return_dict = ListDict()
    if filters_dict is None:
        return return_dict

    for key, value in filters_dict.items():
        filter_class = filter_classes.get(key)
        if filter_class is None:
            continue

        # a single filter arrives as one dict, several as a list of dicts;
        # build both the same way
        entries = value if isinstance(value, list) else [value]
        for entry in entries:
            mt_filter = filter_class(**entry)
            return_dict[mt_filter.name.lower()] = mt_filter

    return return_dict

609 

def sort(self, inplace: bool = True) -> "Experiment":
    """
    Sort surveys, stations, runs and channels using each container's
    ``sort`` method.

    :param inplace: sort this object's containers when True; when False
     build a new Experiment from ``self.to_dict()``, sort that one and
     return it.  Defaults to True
    :type inplace: bool, optional
    :return: the sorted copy when ``inplace`` is False, otherwise None
    :rtype: Experiment or None

    """
    if not inplace:
        ex = Experiment()
        ex.from_dict(self.to_dict())
        ex.sort()
        return ex

    self.surveys.sort()
    for survey in self.surveys:
        survey.stations.sort()
        for station in survey.stations:
            station.runs.sort()
            for run in station.runs:
                run.channels.sort()