Coverage for /Users/Newville/Codes/xraylarch/larch/io/athena_project.py: 9%

627 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-11-09 10:08 -0600

1#!/usr/bin/env python 

2""" 

3Code to read and write Athena Project files 

4 

5""" 

6 

7import os 

8import io 

9import sys 

10import time 

11import json 

12import platform 

13from fnmatch import fnmatch 

14from gzip import GzipFile 

15from copy import deepcopy 

16import numpy as np 

17from numpy.random import randint 

18 

19from larch import Group 

20from larch import __version__ as larch_version 

21from larch.utils.strutils import bytes2str, str2bytes, fix_varname, asfloat 

22 

23from xraydb import guess_edge 

24import asteval 

25 

# Perl stores non-ASCII characters as \x{hhhh} escapes; these strings
# delimit such an escape.
hexopen = '\\x{'
hexclose = '}'

# Translation table that turns a Perl list literal into JSON text:
#   '(' -> '[',  ')' -> ']',  ';' -> ' ',  "'" -> '"',  newline -> ' '
alist2json = str.maketrans("();'\n", "[] \" ")


def plarray2json(text):
    """Parse the right-hand side of a Perl list assignment as JSON.

    Everything after the first '=' is stripped of surrounding whitespace,
    translated with `alist2json`, and decoded with `json.loads`.
    """
    rhs = text.split('=', 1)[1]
    as_json = rhs.strip().translate(alist2json)
    return json.loads(as_json)


def parse_arglist(text):
    """Parse the right-hand side of a Perl assignment as a JSON value.

    Like `plarray2json`, but one trailing ';' is removed before the
    `alist2json` translation is applied.
    """
    rhs = text.split('=', 1)[1].strip()
    rhs = rhs[:-1] if rhs.endswith(';') else rhs
    return json.loads(rhs.translate(alist2json))

39 

40 

41 

# Common prefix for exceptions raised while reading Athena Project files.
ERR_MSG = 'Error reading Athena Project File'

43 

44 

def _read_raw_athena(filename):
    """try to read athena project file as plain text,
    to determine validity

    The file is first read as gzip-compressed data; if that fails, it is
    read as an uncompressed text file.  Both attempts are best-effort:
    any error simply moves on to the next attempt.

    Parameters
    ----------
    filename : str
        path of the file to read.

    Returns
    -------
    str or None
        the file contents as text, or None if neither attempt succeeded.
    """
    text = None
    # try gzip; a non-gzip file fails when read, not when opened
    try:
        with GzipFile(filename) as fh:
            text = bytes2str(fh.read())
    except Exception:
        text = None

    if text is None:
        # try plain text file
        try:
            with open(filename, 'r') as fh:
                text = bytes2str(fh.read())
        except Exception:
            text = None

    return text

68 

69 

70def _test_athena_text(text): 

71 return "Athena project file -- " in text[:500] 

72 

73 

def is_athena_project(filename):
    """tests whether file is a valid Athena Project file

    Returns False if the file cannot be read (gzipped or plain), otherwise
    whether its text starts with the Athena project signature.
    """
    text = _read_raw_athena(filename)
    return text is not None and _test_athena_text(text)

80 

81 

def make_hashkey(length=5):
    """generate an 'athena hash key': a string of random lower-case letters

    Parameters
    ----------
    length : int, default 5
        number of letters in the key.

    Returns
    -------
    str
        `length` characters, each drawn uniformly from 'a' through 'z'.
    """
    # numpy's randint excludes its upper bound, so ord('z') + 1 is needed
    # for 'z' to be reachable.
    return ''.join(chr(randint(ord('a'), ord('z') + 1)) for _ in range(length))

86 

def make_athena_args(group, hashkey=None, **kws):
    """make athena args line from a group

    Builds the dictionary written as ``@args`` for one record of an Athena
    project file.  It starts from Athena's default record arguments, then
    fills in values from `group`:
      - energy range and number of points
      - label/file name, e0, edge step, valence, atsym and edge
      - pre-edge/normalization settings from `pre_edge_details`
      - spline range derived from the energy range and e0
      - background and FFT settings from the `call_args` saved by autobk
        and xftf, when present

    Parameters
    ----------
    group : Group
        XAFS data group; `energy` and `e0` must be present.
    hashkey : str or None
        Athena hash key for the record; a random one is made if None.
    **kws
        extra key/value pairs that override all other values.

    Returns
    -------
    dict
        Athena argument names mapped to their values.
    """
    from larch.xafs.xafsutils import etok

    if hashkey is None:
        hashkey = make_hashkey()

    # start with default args (each key listed exactly once)
    args = {
        'annotation': '', 'beamline': '', 'beamline_identified': '0',
        'bft_dr': '0.0', 'bft_rmax': '3', 'bft_rmin': '1',
        'bft_rwindow': 'hanning', 'bkg_algorithm': 'autobk',
        'bkg_cl': '0', 'bkg_clamp1': '0', 'bkg_clamp2': '24',
        'bkg_delta_eshift': '0', 'bkg_dk': '1',
        'bkg_e0_fraction': '0.5', 'bkg_eshift': '0',
        'bkg_fixstep': '0', 'bkg_flatten': '1',
        'bkg_former_e0': '0', 'bkg_funnorm': '0',
        'bkg_int': '7.', 'bkg_kw': '1',
        'bkg_kwindow': 'hanning', 'bkg_nclamp': '5',
        'bkg_rbkg': '1.0', 'bkg_slope': '0',
        'bkg_stan': 'None', 'bkg_tie_e0': '0',
        'bkg_nc0': '0', 'bkg_nc1': '0', 'bkg_nc2': '0', 'bkg_nc3': '0',
        'bkg_pre1': '-150', 'bkg_pre2': '-30',
        'bkg_nor1': '150', 'bkg_nor2': '800',
        'bkg_nnorm': '1',
        'prjrecord': 'athena.prj, 1', 'chi_column': '',
        'chi_string': '', 'collided': '0', 'columns': '',
        'daq': '', 'denominator': '1', 'display': '0',
        'energy': '', 'energy_string': '', 'epsk': '',
        'epsr': '', 'fft_dk': '4', 'fft_edge': 'k',
        'fft_kmax': '15.', 'fft_kmin': '2.00',
        'fft_kwindow': 'kaiser-bessel', 'fft_pc': '0',
        'fft_pcpathgroup': '', 'fft_pctype': 'central',
        'forcekey': '0', 'from_athena': '1',
        'from_yaml': '0', 'frozen': '0', 'generated': '0',
        'i0_scale': '1', 'i0_string': '1',
        'importance': '1', 'inv': '0', 'is_col': '1',
        'is_fit': '0', 'is_kev': '0', 'is_merge': '',
        'is_nor': '0', 'is_pixel': '0', 'is_special': '0',
        'is_xmu': '1', 'ln': '0', 'mark': '0',
        'marked': '0', 'maxk': '15', 'merge_weight': '1',
        'multiplier': '1', 'nidp': '5', 'nknots': '4',
        'numerator': '', 'plot_scale': '1',
        'plot_yoffset': '0', 'plotkey': '',
        'plotspaces': 'any', 'provenance': '',
        'quenched': '0', 'quickmerge': '0',
        'read_as_raw': '0', 'rebinned': '0',
        'recommended_kmax': '1', 'recordtype': 'mu(E)',
        'referencegroup': '', 'rmax_out': '10',
        'signal_scale': '1', 'signal_string': '-1',
        'trouble': '', 'tying': '0',
        'unreadable': '0', 'update_bft': '1',
        'update_bkg': '1', 'update_columns': '0',
        'update_data': '0', 'update_fft': '1',
        'update_norm': '1', 'xdi_will_be_cloned': '0',
        'xdifile': '', 'xmu_string': '',
        'valence': '', 'lasso_yvalue': '',
        'atsym': '', 'edge': '',
    }

    args['datagroup'] = args['tag'] = args['label'] = hashkey
    en = getattr(group, 'energy', [])
    args['npts'] = len(en)
    if len(en) > 0:
        args['xmin'] = '%.1f' % min(en)
        args['xmax'] = '%.1f' % max(en)

    # Athena arg name -> larch group attribute (copied when not None)
    main_map = dict(source='filename', file='filename', label='filename',
                    bkg_e0='e0', bkg_step='edge_step',
                    bkg_fitted_step='edge_step', valence='valence',
                    lasso_yvalue='lasso_yvalue', atsym='atsym',
                    edge='edge')
    for aname, lname in main_map.items():
        val = getattr(group, lname, None)
        if val is not None:
            args[aname] = val

    # 'bkg_<name>' -> pre_edge_details attribute (copied when not None)
    bkg_map = dict(nnorm='nnorm', nor1='norm1', nor2='norm2', pre1='pre1',
                   pre2='pre2')
    pre_details = getattr(group, 'pre_edge_details', None)
    if pre_details is not None:
        for aname, lname in bkg_map.items():
            val = getattr(pre_details, lname, None)
            if val is not None:
                args['bkg_%s' % aname] = val

    # spline range: from e0 to the end of the data, in energy and in k
    emax = max(group.energy) - group.e0
    args['bkg_spl1e'] = '0'
    args['bkg_spl2e'] = '%.5f' % emax
    args['bkg_spl1'] = '0'
    args['bkg_spl2'] = '%.5f' % etok(emax)

    autobk_details = getattr(group, 'autobk_details', None)
    autobk_args = getattr(autobk_details, 'call_args', None)
    if autobk_args is not None:
        args['bkg_rbkg'] = autobk_args['rbkg']
        args['bkg_spl1'] = autobk_args['kmin']
        args['bkg_spl2'] = autobk_args['kmax']
        args['bkg_kw'] = autobk_args['kweight']
        args['bkg_dk'] = autobk_args['dk']
        args['bkg_kwindow'] = autobk_args['win']
        args['bkg_nclamp'] = autobk_args['nclamp']
        args['bkg_clamp1'] = autobk_args['clamp_lo']
        args['bkg_clamp2'] = autobk_args['clamp_hi']

    xftf_details = getattr(group, 'xftf_details', None)
    xftf_args = getattr(xftf_details, 'call_args', None)
    if xftf_args is not None:
        args['fft_kmin'] = xftf_args['kmin']
        args['fft_kmax'] = xftf_args['kmax']
        args['fft_kw'] = xftf_args['kweight']
        args['fft_dk'] = xftf_args['dk']
        args['fft_kwindow'] = xftf_args['window']

    args.update(kws)
    return args

206 

207 

def athena_array(group, arrname):
    """Return the array stored as attribute `arrname` of `group`.

    The value is returned unchanged; None is returned when `group` has no
    such attribute.
    """
    return getattr(group, arrname, None)

215 

216 

def format_dict(d):
    """Format a dictionary for an Athena Project file.

    Keys are visited in sorted order.  Each key and its value are wrapped
    in single quotes (a None value becomes ''), and everything is joined
    with commas: ``'key1','val1','key2','val2'``.
    """
    parts = []
    for key in sorted(d):
        val = d[key]
        parts.extend(("'%s'" % key, "'%s'" % ('' if val is None else val)))
    return ','.join(parts)

226 

def format_array(arr):
    """Format a sequence of values for an Athena Project file.

    Each element is written with '%s' formatting, wrapped in single
    quotes, and the results are joined with commas: ``'1.0','2.0'``.
    """
    return ','.join("'%s'" % value for value in arr)

231 

def clean_bkg_params(grp):
    """Fill in default background-subtraction parameters on `grp`.

    Attributes that are missing get defaults (e.g. nnorm=2, pre1=-150,
    pre2=-25, nor1=100, nor2=1200, spl2=30); existing values are kept.
    `kwindow` falls back to `win`, then 'hanning', when missing or None.
    `clamp1` and `clamp2` are converted to float, and set to 1 if missing
    or not convertible.  `grp` is modified in place and returned.
    """
    defaults = (('nnorm', 2), ('e0', -1), ('rbkg', 1), ('pre1', -150),
                ('pre2', -25), ('nor1', 100), ('nor2', 1200), ('spl1', 0),
                ('spl2', 30), ('kw', 1), ('dk', 3), ('flatten', 0))
    for attr, default in defaults:
        setattr(grp, attr, getattr(grp, attr, default))

    if getattr(grp, 'kwindow', None) is None:
        grp.kwindow = getattr(grp, 'win', 'hanning')

    for attr in ('clamp1', 'clamp2'):
        try:
            value = float(getattr(grp, attr))
        except Exception:
            value = 1
        setattr(grp, attr, value)

    return grp

258 

259 

def clean_fft_params(grp):
    """Fill in default Fourier-transform parameters on `grp`.

    Missing `kmin`, `kmax`, `kweight`, `dk` and `kwindow` attributes are
    set to 0, 25, 2, 3 and 'hanning'; existing values are kept.
    `grp` is modified in place and returned.
    """
    for attr, default in (('kmin', 0), ('kmax', 25), ('kweight', 2),
                          ('dk', 3), ('kwindow', 'hanning')):
        setattr(grp, attr, getattr(grp, attr, default))
    return grp

267 

268 

def text2list(text):
    r"""Extract and normalize the right-hand side of a Perl assignment line.

    For a line such as ``@args = ('a' => 'b');``, return the text after the
    first '=':
      - surrounding whitespace and one trailing ';' are removed
      - Perl's '=>' is replaced by ':'
      - newline, carriage-return and tab characters become spaces
      - Perl unicode escapes ``\x{hhhh}`` (1 to 6 hex digits) are decoded
        to the corresponding character, e.g. ``\x{e4}`` -> chr(0xe4)

    Escapes that do not hold a valid hex code point are left as-is.

    Raises
    ------
    ValueError
        if `text` contains no '='.
    """
    import re

    _key, txt = [part.strip() for part in text.split('=', 1)]
    if txt.endswith(';'):
        txt = txt[:-1]
    for old, new in (('=>', ':'), ('\n', ' '), ('\r', ' '), ('\t', ' ')):
        txt = txt.replace(old, new)

    def _decode(match):
        code = int(match.group(1), 16)
        if code > 0x10FFFF:   # beyond the unicode range: keep escape text
            return match.group(0)
        return chr(code)

    # re-cast unicode stored by perl (\x{e4} -> chr(0xe4))
    return re.sub(r'\\x\{([0-9a-fA-F]{1,6})\}', _decode, txt)

289 

290 

def parse_perlathena(text, filename):
    """
    parse old (Perl-style) athena file format to Group of Groups

    Values on each data line are rewritten with text2list() and evaluated
    with a minimal asteval interpreter, whose output and errors are
    captured rather than printed.

    Parameters
    ----------
    text : str
        full text of the project file.
    filename : str
        file name, used in messages and in the output docstring.

    Returns
    -------
    Group
        with `journal` (str), `header` (str), `group_names` (list), and one
        sub-Group per record.  Each sub-Group holds `energy`, `mu`, optional
        `i0`, `signal`, `stddev`, and `athena_params` (with `bkg` and `fft`
        sub-Groups).

    Raises
    ------
    ValueError
        if the first line is not an Athena project header, or its version
        cannot be read.
    """
    aout = io.StringIO()
    aeval = asteval.Interpreter(minimal=True, writer=aout, err_writer=aout,
                                max_statement_length=12543000)

    lines = text.split('\n')
    vline = lines.pop(0)
    if "Athena project file -- " not in vline:
        raise ValueError("%s '%s': invalid Athena File" % (ERR_MSG, filename))
    if 'Demeter' in vline:
        vmarker = "Athena project file -- Demeter version"
    else:
        vmarker = "Athena project file -- Athena version"
    try:
        # the version must be of the form 'major.minor.fix'
        _major, _minor, _fix = vline.split(vmarker)[1].split('.')
    except (IndexError, ValueError) as exc:
        raise ValueError("%s '%s': cannot read version" % (ERR_MSG, filename)) from exc

    header = [vline]
    journal = ['']
    athenagroups = []
    raw = {'name': ''}
    is_header = True
    for t in lines:
        # comments, short lines and lines with 'undef' carry no data;
        # those appearing before the first data line form the header
        if t.startswith('#') or len(t) < 2 or 'undef' in t:
            if is_header:
                header.append(t)
            continue
        is_header = False
        key = t.split()[0].strip()
        key = key.replace('$', '').replace('@', '').replace('%', '').strip()
        if key == 'old_group':
            raw['name'] = aeval(text2list(t))
        elif key == '[record]':
            # end of a record: keep it and start a new one
            athenagroups.append(raw)
            raw = {'name': ''}
        elif key == 'journal':
            try:
                journal = aeval(text2list(t))
            except ValueError:
                pass
            if len(aeval.error) > 0:
                print(f" warning: may not read journal from '{filename:s}' completely")
                journal = text2list(t)
        elif key == 'args':
            raw['args'] = aeval(text2list(t))
        elif key == 'xdi':
            raw['xdi'] = t
        elif key in ('x', 'y', 'i0', 'signal', 'stddev'):
            raw[key] = np.array([float(x) for x in aeval(text2list(t))])
        elif key in ('1;', 'indicator', 'lcf_data', 'plot_features'):
            pass
        else:
            print(" do not know what to do with key '%s' at '%s'" % (key, raw['name']))

    out = Group()
    out.__doc__ = """XAFS Data from Athena Project File %s""" % (filename)
    # the journal is a sequence of lines when evaluated, or the raw text
    # when evaluation failed; joining a str would split it per character
    if journal is None:
        journal = ''
    if isinstance(journal, str):
        out.journal = journal
    else:
        out.journal = '\n'.join(str(line) for line in journal)
    out.group_names = []
    out.header = '\n'.join(header)
    for dat in athenagroups:
        label = dat.get('name', 'unknown')
        this = Group(energy=dat['x'], mu=dat['y'],
                     athena_params=Group(id=label, bkg=Group(), fft=Group()))

        if 'i0' in dat:
            this.i0 = dat['i0']
        if 'signal' in dat:
            this.signal = dat['signal']
        if 'stddev' in dat:
            this.stddev = dat['stddev']
        if 'args' in dat:
            # args is a flat sequence of alternating names and values
            args = dat['args']
            for key, val in zip(args[0::2], args[1::2]):
                if key.startswith('bkg_'):
                    setattr(this.athena_params.bkg, key[4:], asfloat(val))
                elif key.startswith('fft_'):
                    setattr(this.athena_params.fft, key[4:], asfloat(val))
                elif key == 'label':
                    label = this.label = val
                elif key in ('valence', 'lasso_yvalue', 'epsk', 'epsr'):
                    setattr(this, key, asfloat(val))
                elif key in ('atsym', 'edge'):
                    setattr(this, key, val)
                else:
                    setattr(this.athena_params, key, asfloat(val))
        this.__doc__ = """Athena Group Name %s (key='%s')""" % (label, dat['name'])
        if label.startswith(' '):
            label = 'd_' + label.strip()
        name = fix_varname(label)
        if name.startswith('_'):
            name = 'd' + name
        setattr(out, name, this)
        out.group_names.append(name)
    return out

398 

399 

def parse_perlathena_old(text, filename):
    """
    parse old (Perl-style) athena file format to Group of Groups

    This older variant converts the Perl literals to JSON with
    plarray2json() / parse_arglist() instead of evaluating them.

    Parameters
    ----------
    text : str
        full text of the project file.
    filename : str
        file name, used in messages and in the output docstring.

    Returns
    -------
    Group
        with `journal` (str), `header` (str), `group_names` (list), and one
        sub-Group per record, holding `energy`, `mu`, optional `i0`,
        `signal`, `stddev`, and `athena_params`.

    Raises
    ------
    ValueError
        if the first line is not an Athena project header, or its version
        cannot be read.
    """
    lines = text.split('\n')
    vline = lines.pop(0)
    if "Athena project file -- " not in vline:
        raise ValueError("%s '%s': invalid Athena File" % (ERR_MSG, filename))
    if 'Demeter' in vline:
        vmarker = "Athena project file -- Demeter version"
    else:
        vmarker = "Athena project file -- Athena version"
    try:
        # the version must be of the form 'major.minor.fix'
        _major, _minor, _fix = vline.split(vmarker)[1].split('.')
    except (IndexError, ValueError) as exc:
        raise ValueError("%s '%s': cannot read version" % (ERR_MSG, filename)) from exc

    header = [vline]
    journal = ['']
    athenagroups = []
    raw = {'name': ''}
    is_header = True
    for t in lines:
        # comments, short lines and lines with 'undef' carry no data;
        # those appearing before the first data line form the header
        if t.startswith('#') or len(t) < 2 or 'undef' in t:
            if is_header:
                header.append(t)
            continue
        is_header = False
        key = t.split()[0].strip()
        key = key.replace('$', '').replace('@', '').replace('%', '').strip()
        if key == 'old_group':
            raw['name'] = plarray2json(t)
        elif key == '[record]':
            # end of a record: keep it and start a new one
            athenagroups.append(raw)
            raw = {'name': ''}
        elif key == 'journal':
            journal = parse_arglist(t)
        elif key == 'args':
            raw['args'] = parse_arglist(t)
        elif key == 'xdi':
            raw['xdi'] = t
        elif key in ('x', 'y', 'i0', 'signal', 'stddev'):
            raw[key] = np.array([float(x) for x in plarray2json(t)])
        elif key in ('1;', 'indicator', 'lcf_data', 'plot_features'):
            pass
        else:
            print(" do not know what to do with key '%s' at '%s'" % (key, raw['name']))

    out = Group()
    out.__doc__ = """XAFS Data from Athena Project File %s""" % (filename)
    # a journal decoded as a single string must not be split per character
    if journal is None:
        journal = ''
    if isinstance(journal, str):
        out.journal = journal
    else:
        out.journal = '\n'.join(str(line) for line in journal)
    out.group_names = []
    out.header = '\n'.join(header)
    for dat in athenagroups:
        label = dat.get('name', 'unknown')
        this = Group(energy=dat['x'], mu=dat['y'],
                     athena_params=Group(id=label, bkg=Group(), fft=Group()))

        if 'i0' in dat:
            this.i0 = dat['i0']
        if 'signal' in dat:
            this.signal = dat['signal']
        if 'stddev' in dat:
            this.stddev = dat['stddev']
        if 'args' in dat:
            # args is a flat list of alternating names and values
            args = dat['args']
            for key, val in zip(args[0::2], args[1::2]):
                if key.startswith('bkg_'):
                    setattr(this.athena_params.bkg, key[4:], asfloat(val))
                elif key.startswith('fft_'):
                    setattr(this.athena_params.fft, key[4:], asfloat(val))
                elif key == 'label':
                    label = this.label = val
                elif key in ('valence', 'lasso_yvalue', 'epsk', 'epsr'):
                    setattr(this, key, asfloat(val))
                elif key in ('atsym', 'edge'):
                    setattr(this, key, val)
                else:
                    setattr(this.athena_params, key, asfloat(val))
        this.__doc__ = """Athena Group Name %s (key='%s')""" % (label, dat['name'])
        name = fix_varname(label)
        if name.startswith('_'):
            name = 'd' + name
        setattr(out, name, this)
        out.group_names.append(name)

    return out

493 

494 

def parse_jsonathena(text, filename):
    """parse a JSON-style athena file

    Top-level keys starting with '_____head' are header lines, a key
    starting with '_____journ' holds the journal, and a key starting with
    '_____order' lists the record names.  Each listed record is a dict with
    'x', 'y', optional 'i0', 'signal', 'stddev' arrays and an 'args' dict.

    Parameters
    ----------
    text : str
        full JSON text of the project file.
    filename : str
        file name, used in the output docstring.

    Returns
    -------
    Group
        with `journal`, `header` (str), `group_names` (list), and one
        sub-Group per record, holding `energy`, `mu`, optional `i0`,
        `signal`, `stddev`, and `athena_params`.
    """
    jsdict = json.loads(text)

    out = Group()
    out.__doc__ = """XAFS Data from Athena Project File %s""" % (filename)

    header = []
    journal = ''        # used when the file has no journal entry
    athena_names = []
    for key, val in jsdict.items():
        if key.startswith('_____head'):
            header.append(val)
        elif key.startswith('_____journ'):
            journal = val
        elif key.startswith('_____order'):
            athena_names = val

    out.journal = journal
    out.header = '\n'.join(header)
    out.group_names = []
    for athena_name in athena_names:
        label = athena_name
        dat = jsdict[athena_name]
        x = np.array(dat['x'], dtype='float64')
        y = np.array(dat['y'], dtype='float64')
        this = Group(energy=x, mu=y,
                     athena_params=Group(id=athena_name, bkg=Group(), fft=Group()))

        if 'i0' in dat:
            this.i0 = np.array(dat['i0'], dtype='float64')
        if 'signal' in dat:
            this.signal = np.array(dat['signal'], dtype='float64')
        if 'stddev' in dat:
            this.stddev = np.array(dat['stddev'], dtype='float64')
        if 'args' in dat:
            for key, val in dat['args'].items():
                if key.startswith('bkg_'):
                    setattr(this.athena_params.bkg, key[4:], asfloat(val))
                elif key.startswith('fft_'):
                    setattr(this.athena_params.fft, key[4:], asfloat(val))
                elif key == 'label':
                    label = this.label = val
                elif key in ('valence', 'lasso_yvalue', 'epsk', 'epsr'):
                    setattr(this, key, asfloat(val))
                elif key in ('atsym', 'edge'):
                    setattr(this, key, val)
                else:
                    setattr(this.athena_params, key, asfloat(val))
        this.__doc__ = """Athena Group Name %s (key='%s')""" % (label, athena_name)
        name = fix_varname(label)
        if name.startswith('_'):
            name = 'd' + name
        setattr(out, name, this)
        out.group_names.append(name)
    return out

550 

551 

class AthenaGroup(Group):
    """A special Group for handling datasets loaded from Athena project files

    The datasets are held in the `groups` mapping.
    `keys()`, `values()` and `items()` return lists built from that
    mapping.  String keys can be used for item access (``grp['name']``),
    which maps to attribute access; integer indexing raises IndexError.
    """

    def __init__(self, show_sel=False):
        """Constructor

        Parameters
        ----------

        show_sel : boolean, False
            if True, it shows the selection flag in HTML representation
        """
        super().__init__()
        self.show_sel = show_sel
        # start with an empty mapping so that `groups`, keys(), values(),
        # items() and _repr_html_() work before any groups are assigned
        self._athena_groups = {}

    def _repr_html_(self):
        """HTML representation for Jupyter notebook

        A table of group names (HTML-escaped).  When `show_sel` is True and
        any group has a `sel` attribute, a second column shows a check mark
        for groups with ``sel == 1``.
        """
        from html import escape

        show_sel = self.show_sel and any(hasattr(g, 'sel')
                                         for g in self.groups.values())
        rows = ["<table>", "<tr>", "<td><b>Group</b></td>"]
        if show_sel:
            rows.append("<td><b>Sel</b></td>")
        rows.append("</tr>")
        for name, grp in self.groups.items():
            sel = "\u2714" if getattr(grp, 'sel', None) == 1 else ""
            rows.append("<tr>")
            rows.append(f"<td>{escape(str(name))}</td>")
            if show_sel:
                rows.append(f"<td>{sel}</td>")
            rows.append("</tr>")
        rows.append("</table>")
        return ''.join(rows)

    @property
    def groups(self):
        """mapping of dataset name -> dataset Group"""
        return self._athena_groups

    @groups.setter
    def groups(self, groups):
        self._athena_groups = groups

    def __getitem__(self, key):
        """return attribute `key`; integer indexing is not supported"""
        if isinstance(key, int):
            raise IndexError("AthenaGroup does not support integer indexing")

        return getattr(self, key)

    def __setitem__(self, key, value):
        """set attribute `key`; integer indexing is not supported"""
        if isinstance(key, int):
            raise IndexError("AthenaGroup does not support integer indexing")

        return setattr(self, key, value)

    def keys(self):
        """list of dataset names"""
        return list(self.groups.keys())

    def values(self):
        """list of dataset Groups"""
        return list(self.groups.values())

    def items(self):
        """list of (name, dataset Group) pairs"""
        return list(self.groups.items())

623 

class AthenaProject(object):
    """read and write Athena Project files, mapping to Larch group
    containing sub-groups for each spectra / record

    note that two generations of Project files are supported for reading:

       1. Perl save file (custom format?)
       2. JSON format

    In addition, project files may be Gzipped or not.

    Files are always saved in the Perl-style format, Gzipped by default.
    """

    def __init__(self, filename=None):
        """Create a project, reading `filename` if it is an existing
        Athena Project file.

        Parameters
        ----------
        filename : str or None
            project file name.  It is remembered for later read() / save()
            calls even if the file does not exist yet.
        """
        self.groups = {}
        self.header = None
        self.journal = None
        self.filename = filename
        if filename is not None:
            if os.path.exists(filename) and is_athena_project(filename):
                self.read(filename)

    def add_group(self, group, signal=None):
        """add Larch group (presumably XAFS data) to Athena project

        The group is modified in place:
          - pre-edge subtraction is run if `e0` or `edge_step` is missing
          - `args`, `x`, `y`, `i0`, `signal` and `sel` attributes are set
        It is then stored in `self.groups` under its Athena hash key.

        Parameters
        ----------
        group : Group
            must have an `energy` array and one of `mu`, `mutrans`,
            `mufluor`.
        signal : str or None
            name of the group attribute to save as the signal array.  If
            None and the data is `mu` or `mutrans`, `i1` or `itrans` is used
            when present.

        Raises
        ------
        ValueError
            if the group has no `energy` or mu-like array.
        """
        from larch.xafs import pre_edge

        x = athena_array(group, 'energy')
        yname = next((n for n in ('mu', 'mutrans', 'mufluor')
                      if hasattr(group, n)), None)
        if x is None or yname is None:
            raise ValueError("can only add XAFS data to Athena project")

        y = athena_array(group, yname)
        i0 = athena_array(group, 'i0')
        if signal is not None:
            signal = athena_array(group, signal)
        elif yname in ('mu', 'mutrans'):
            sname = next((n for n in ('i1', 'itrans') if hasattr(group, n)),
                         None)
            if sname is not None:
                signal = athena_array(group, sname)

        # reuse the group's Athena id as hash key, unless missing or taken
        hashkey = getattr(group, 'id', None)
        if hashkey is None or hashkey in self.groups:
            hashkey = make_hashkey()
            while hashkey in self.groups:
                hashkey = make_hashkey()

        # fill in data from pre-edge subtraction
        if not (hasattr(group, 'e0') and hasattr(group, 'edge_step')):
            pre_edge(group)
        group.args = make_athena_args(group, hashkey)

        # fix parameters that are incompatible with athena: clamp nnorm to
        # 0..3 (going through float() also accepts strings such as '2.0')
        group.args['bkg_nnorm'] = max(0, min(3, int(float(group.args['bkg_nnorm']))))

        _elem, _edge = guess_edge(group.e0)
        group.args['bkg_z'] = _elem
        group.x = x
        group.y = y
        group.i0 = i0
        group.signal = signal

        # add a selection flag
        group.sel = 1

        self.groups[hashkey] = group

    def save(self, filename=None, use_gzip=True):
        """save project as a Perl-style Athena Project file

        Only groups that have an `args` attribute (as set by add_group())
        are written.

        Parameters
        ----------
        filename : str or None
            output file name; if None, `self.filename` is used.
        use_gzip : bool
            whether to gzip-compress the file [True]
        """
        if filename is not None:
            self.filename = filename
        iso_now = time.strftime('%Y-%m-%dT%H:%M:%S')
        pyosversion = "Python %s on %s" % (platform.python_version(),
                                           platform.platform())

        buff = ["# Athena project file -- Demeter version 0.9.24",
                "# This file created at %s" % iso_now,
                "# Using Larch version %s, %s" % (larch_version, pyosversion)]

        for key, dat in self.groups.items():
            if not hasattr(dat, 'args'):
                continue
            buff.append("")
            groupname = getattr(dat, 'groupname', key)

            buff.append("$old_group = '%s';" % groupname)
            buff.append("@args = (%s);" % format_dict(dat.args))
            buff.append("@x = (%s);" % format_array(dat.x))
            buff.append("@y = (%s);" % format_array(dat.y))
            if getattr(dat, 'i0', None) is not None:
                buff.append("@i0 = (%s);" % format_array(dat.i0))
            if getattr(dat, 'signal', None) is not None:
                buff.append("@signal = (%s);" % format_array(dat.signal))
            if getattr(dat, 'stddev', None) is not None:
                buff.append("@stddev = (%s);" % format_array(dat.stddev))
            buff.append("[record] # ")

        buff.extend(["", "@journal = {};", "", "1;", "", "",
                     "# Local Variables:", "# truncate-lines: t",
                     "# End:", ""])

        text = "\n".join([bytes2str(t) for t in buff])
        # the text is written as bytes, so a plain file must be opened in
        # binary mode too (text mode 'w' rejects bytes)
        fopen = GzipFile if use_gzip else open
        with fopen(self.filename, 'wb') as fh:
            fh.write(str2bytes(text))

    def read(self, filename=None, match=None, do_preedge=True, do_bkg=False,
             do_fft=False, use_hashkey=False):
        """
        read Athena project to group of groups, one for each Athena dataset
        in the project file.  This supports both gzipped and unzipped files
        and old-style perl-like project files and new-style JSON project files

        Arguments:
            filename (string): name of Athena Project file; if None, the
                      current `filename` attribute is used
            match (string): pattern to use to limit imported groups (see Note 1)
            do_preedge (bool): whether to do pre-edge subtraction [True]
            do_bkg (bool): whether to do XAFS background subtraction [False]
            do_fft (bool): whether to do XAFS Fast Fourier transform [False]
            use_hashkey (bool): whether to use Athena's hash key as the
                       group name instead of the Athena label [False]
        Returns:
            None, fills in attributes `header`, `journal`, `filename`, `groups`

        Raises:
            IOError if the file does not exist, OSError if it cannot be read,
            ValueError if it is not a valid Athena Project file.

        Notes:
            1. To limit the imported groups, use the pattern in `match`,
               using '*' to match 'all', '?' to match any single character,
               or [sequence] to match any of a sequence of letters.  The match
               will always be insensitive to case.
            2. do_preedge, do_bkg, and do_fft will attempt to reproduce the
               pre-edge, background subtraction, and FFT from Athena by using
               the parameters saved in the project file.  The FFT is only
               done together with background subtraction.
            3. use_hashkey=True will name groups from the internal 5 character
               string used by Athena, instead of the group label.

        Example:
            1. read in all groups from a project file:
               cr_data = read_athena('My Cr Project.prj')

            2. read in only the "merged" data from a Project, do BKG and FFT:
               zn_data = read_athena('Zn on Stuff.prj', match='*merge*', do_bkg=True, do_fft=True)
        """
        if filename is not None:
            self.filename = filename
        if self.filename is None or not os.path.exists(self.filename):
            raise IOError("%s '%s': cannot find file" % (ERR_MSG, self.filename))

        from larch.xafs import pre_edge, autobk, xftf

        text = _read_raw_athena(self.filename)
        # failed to read:
        if text is None:
            raise OSError("failed to read '%s'" % self.filename)
        if not _test_athena_text(text):
            raise ValueError("%s '%s': invalid Athena File" % (ERR_MSG, self.filename))

        # decode JSON or Perl format
        data = None
        if '____header' in text[:500]:
            try:
                data = parse_jsonathena(text, self.filename)
            except Exception:
                # deliberate fallback: try the Perl-style parser below
                data = None

        if data is None:
            data = parse_perlathena(text, self.filename)

        if data is None:
            raise ValueError("cannot read file '%s' as Athena Project File" % (self.filename))

        self.header = data.header
        self.journal = data.journal
        self.group_names = data.group_names

        # fnmatch() is case-sensitive on POSIX, so lower-case both the
        # names and the pattern to get the documented case-insensitivity
        pattern = None if match is None else match.lower()
        for gname in data.group_names:
            if pattern is not None and not fnmatch(gname.lower(), pattern):
                continue
            this = getattr(data, gname)

            this.athena_id = this.athena_params.id
            oname = this.athena_params.id if use_hashkey else gname

            is_xmu = bool(int(getattr(this.athena_params, 'is_xmu', 1.0)))
            is_chi = bool(int(getattr(this.athena_params, 'is_chi', 0.0)))
            is_xmu = is_xmu and not is_chi
            # records flagged as any of these are not treated as mu(E)
            for aname in ('is_xmudat', 'is_bkg', 'is_diff',
                          'is_proj', 'is_pixel', 'is_rsp'):
                val = bool(int(getattr(this.athena_params, aname, 0.0)))
                is_xmu = is_xmu and not val

            if is_xmu and (do_preedge or do_bkg):
                pars = clean_bkg_params(this.athena_params.bkg)
                eshift = getattr(this.athena_params.bkg, 'eshift', None)
                if eshift is not None:
                    this.energy = this.energy + eshift
                pre_edge(this, e0=float(pars.e0),
                         pre1=float(pars.pre1), pre2=float(pars.pre2),
                         norm1=float(pars.nor1), norm2=float(pars.nor2),
                         nnorm=float(pars.nnorm),
                         make_flat=bool(pars.flatten))
                if do_bkg and hasattr(pars, 'rbkg'):
                    autobk(this, e0=float(pars.e0), rbkg=float(pars.rbkg),
                           kmin=float(pars.spl1), kmax=float(pars.spl2),
                           kweight=float(pars.kw), dk=float(pars.dk),
                           clamp_lo=float(pars.clamp1),
                           clamp_hi=float(pars.clamp2))
                    if do_fft:
                        pars = clean_fft_params(this.athena_params.fft)
                        kweight = 2
                        if hasattr(pars, 'kw'):
                            kweight = float(pars.kw)
                        xftf(this, kmin=float(pars.kmin),
                             kmax=float(pars.kmax), kweight=kweight,
                             window=pars.kwindow, dk=float(pars.dk))
            if is_chi:
                # chi(k) records store k in 'energy' and chi in 'mu'
                this.k = this.energy*1.0
                this.chi = this.mu*1.0
                del this.energy
                del this.mu

            # add a selection flag
            this.sel = 1

            self.groups[oname] = this

    def as_group(self):
        """convert AthenaProject to Larch group"""
        out = AthenaGroup()
        out.__doc__ = """XAFS Data from Athena Project File %s""" % (self.filename)
        out._athena_journal = self.journal
        out._athena_header = self.header
        out._athena_groups = self.groups

        for name, group in self.groups.items():
            setattr(out, name, group)
        return out

    def as_dict(self):
        """convert AthenaProject to a nested dictionary

        Returns a dict with keys '_doc', '_journal', '_header' and 'groups'.
        Each entry of 'groups' is a deep copy of that group's attributes
        (without '__name__').  Every attribute that is itself a Group is
        moved into a '_params' dict, keyed by the attribute name up to
        '_params' (e.g. 'athena_params' -> 'athena').

        The project's groups are not modified.
        """
        out = dict()
        out["_doc"] = """XAFS Data from Athena Project File %s""" % (self.filename)
        out["_journal"] = self.journal  # str
        out["_header"] = self.header    # str
        out["groups"] = dict()

        par_key = "_params"
        for name, group in self.groups.items():
            # copy first: popping from group.__dict__ itself would delete
            # the group's own __name__ (and fail on a second call)
            gdict = dict(group.__dict__)
            gdict.pop("__name__", None)
            gout = deepcopy(gdict)
            gout[par_key] = dict()
            for subname, subgroup in gdict.items():
                if isinstance(subgroup, Group):
                    subdict = gout.pop(subname).__dict__
                    subdict.pop("__name__", None)
                    # group all parameters in a common dictionary
                    par_name = subname.split(par_key)[0]
                    gout[par_key][par_name] = subdict
            out["groups"][name] = gout

        return out

898 

899 

def read_athena(filename, match=None, do_preedge=True, do_bkg=False,
                do_fft=False, use_hashkey=False):
    """read athena project file
    returns a Group of Groups, one for each Athena Group in the project file

    Arguments:
        filename (string): name of Athena Project file
        match (string): pattern to use to limit imported groups (see Note 1)
        do_preedge (bool): whether to do pre-edge subtraction [True]
        do_bkg (bool): whether to do XAFS background subtraction [False]
        do_fft (bool): whether to do XAFS Fast Fourier transform [False]
        use_hashkey (bool): whether to use Athena's hash key as the
                   group name instead of the Athena label [False]

    Returns:
        group of groups each named according the label used by Athena.

    Raises:
        IOError if `filename` does not exist.

    Notes:
        1. To limit the imported groups, use the pattern in `match`,
           using '*' to match 'all', '?' to match any single character,
           or [sequence] to match any of a sequence of letters.  The match
           will always be insensitive to case.
        2. do_preedge, do_bkg, and do_fft will attempt to reproduce the
           pre-edge, background subtraction, and FFT from Athena by using
           the parameters saved in the project file.
        3. use_hashkey=True will name groups from the internal 5 character
           string used by Athena, instead of the group label.

    Example:
        1. read in all groups from a project file:
           cr_data = read_athena('My Cr Project.prj')

        2. read in only the "merged" data from a Project, and do BKG and FFT:
           zn_data = read_athena('Zn on Stuff.prj', match='*merge*', do_bkg=True, do_fft=True)

    """
    if not os.path.exists(filename):
        raise IOError("%s '%s': cannot find file" % (ERR_MSG, filename))

    project = AthenaProject()
    read_options = dict(match=match, do_preedge=do_preedge, do_bkg=do_bkg,
                        do_fft=do_fft, use_hashkey=use_hashkey)
    project.read(filename, **read_options)
    return project.as_group()

943 

944 

def create_athena(filename=None):
    """create athena project file

    Returns a new AthenaProject for `filename`.  If that file already
    exists and is an Athena Project file, its contents are read.
    """
    project = AthenaProject(filename=filename)
    return project

948 

949 

def extract_athenagroup(dgroup):
    '''extract xas group from athena group

    Modifies `dgroup` in place and returns it: marks it as 'xas' data,
    takes `filename` from its `label` ('unknown' if missing), copies
    `energy` and `mu` into `xdat` and `ydat`, sets `yerr` to 1.0, and sets
    the plot axis labels.
    '''
    grp = dgroup
    grp.datatype = 'xas'
    grp.filename = getattr(grp, 'label', 'unknown')
    # multiplying by 1.0 makes new arrays rather than aliases
    grp.xdat, grp.ydat = 1.0*grp.energy, 1.0*grp.mu
    grp.yerr = 1.0
    grp.plot_xlabel, grp.plot_ylabel = 'energy', 'mu'
    return grp

961#enddef