Coverage for /Users/Newville/Codes/xraylarch/larch/larchlib.py: 65%

437 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-11-09 10:08 -0600

1#!/usr/bin/env python 

2""" 

3Helper classes for larch interpreter 

4""" 

5from __future__ import division 

6import sys, os, time 

7from datetime import datetime 

8import ast 

9import numpy as np 

10import traceback 

11import toml 

12import inspect 

13from collections import namedtuple 

14import ctypes 

15import ctypes.util 

16 

17from .symboltable import Group, isgroup 

18from .site_config import user_larchdir 

19from .closure import Closure 

20from .utils import uname, bindir, get_cwd, read_textfile 

21 

# termcolor is optional: colored terminal output is only used when it imports.
try:
    from termcolor import colored
except ImportError:
    HAS_TERMCOLOR = False
else:
    # HACK (hopefully temporary): disable color output for the Windows
    # command terminal, because it interferes with the wx event loop.
    HAS_TERMCOLOR = uname != 'win'

36 

37 

class Empty:
    """Placeholder object that always evaluates as False.

    Used as a sentinel that can be told apart from ``None`` by identity
    while still testing false in a boolean context.
    """
    def __bool__(self):
        return False

    # Python 2 spelling of __bool__, kept for backward compatibility
    __nonzero__ = __bool__


# holder for 'returned None' from Larch procedure
ReturnedNone = Empty()

43 

def get_filetext(fname, lineno):
    """Try to extract one line of text from a source file.

    Parameters
    ----------
    fname : str
        path of the file to read.
    lineno : int
        1-based line number; values past the end of the file are clamped
        to the last line.

    Returns
    -------
    str
        the requested line without its trailing newline, or
        '<could not find text>' if the file cannot be read or the line
        number is not usable (e.g. < 1, or the file is empty).
    """
    out = '<could not find text>'
    try:
        with open(fname, 'r') as fh:
            lines = fh.readlines()
        # lineno 0 (the default used by callers) does not name a real line
        if lineno < 1 or not lines:
            return out
        out = lines[min(lineno, len(lines)) - 1].rstrip('\r\n')
    except (OSError, ValueError, TypeError):
        # best effort: missing/unreadable file, bad encoding, or bad arguments
        pass
    return out

56 

class LarchExceptionHolder:
    """Basic exception holder / formatter for errors in Larch code.

    Stores the context of an error (AST node, file name, function,
    expression text, message, exception class, line number) together with
    ``sys.exc_info()`` captured at construction time, and formats it into a
    traceback-style report with :meth:`get_error`.
    """
    def __init__(self, node=None, msg='', fname='<stdin>',
                 func=None, expr=None, exc=None, lineno=0):
        self.node = node
        self.fname = fname
        self.func = func
        self.expr = expr
        self.msg = msg
        self.exc = exc
        self.lineno = lineno
        # snapshot of the exception being handled when the holder is built;
        # (None, None, None) when created outside an ``except`` block
        self.exc_info = sys.exc_info()

        # fill in exception class and message from the active exception
        if self.exc is None and self.exc_info[0] is not None:
            self.exc = self.exc_info[0]
        if self.msg in ('', None) and self.exc_info[1] is not None:
            self.msg = self.exc_info[1]

    def _resolve_fname(self):
        """Return the file name to report in the 'File ...' line.

        ``self.fname`` is used when set.  When it is None and a function is
        attached, the file of the module defining that function (unwrapping
        a Closure first) is used, or 'unknown' if that cannot be found.  With
        a function attached, a trailing '.pyc' is reported as '.py'.
        """
        fname = self.fname
        if self.func is None:
            return fname
        if fname is None:
            func = self.func
            if isinstance(func, Closure):
                func = func.func
            try:
                fname = inspect.getmodule(func).__file__
            except AttributeError:
                # getmodule() returned None, or the module has no __file__
                fname = 'unknown'
        if isinstance(fname, str) and fname.endswith('.pyc'):
            fname = fname[:-1]
        return fname

    def get_error(self):
        """Format the stored error.

        Returns
        -------
        tuple of (str, str)
            the exception class name ('UnknownError' when unknown) and a
            multi-line report: the failing expression (with a '^^^' marker
            when the node has a positive col_offset), the 'File ..., line ...'
            location, any traceback frames not from interpreter internals,
            and a final 'ExcName: message' line.
        """
        _, e_val, e_tb = self.exc_info
        # getattr on None also yields the default, so no separate None check
        col_offset = getattr(self.node, 'col_offset', -1)

        try:
            exc_name = self.exc.__name__
        except AttributeError:
            exc_name = str(self.exc)
        if exc_name in (None, 'None'):
            exc_name = 'UnknownError'

        out = []
        if isinstance(self.expr, ast.AST):
            self.expr = 'In compiled script'
        if self.expr is None:
            out.append('unknown error\n')
        else:
            expr = str(self.expr)
            if '\n' in expr:
                out.append("\n%s" % expr)
            else:
                out.append(" %s" % expr)
        if col_offset > 0:
            out.append("%s^^^" % (col_offset*' '))

        fline = ' File %s, line %i' % (self._resolve_fname(), self.lineno)
        if self.func is not None and hasattr(self.func, 'name'):
            dec = 'procedure ' if isinstance(self.func, Procedure) else ''
            ftext = get_filetext(self.fname, self.lineno)
            fline = "%s, in %s%s\n%s" % (fline, dec, self.func.name, ftext)
        out.append(fline)

        # hide frames from interpreter internals installed under sys.prefix
        hidden = ('ast.py',
                  os.path.join('larch', 'utils'),
                  os.path.join('larch', 'interpreter'),
                  os.path.join('larch', 'symboltable'))
        tblist = [tb for tb in traceback.extract_tb(e_tb)
                  if not (sys.prefix in tb[0] and
                          any(part in tb[0] for part in hidden))]
        if tblist:
            out.append(''.join(traceback.format_list(tblist)))

        def _blank(val):
            return val is None or (isinstance(val, str) and val == '')

        # try to get last error message, as from e_val.args; fall back to
        # e_val.message and then to self.msg when that gives nothing
        ex_msg = getattr(e_val, 'args', None)
        try:
            ex_msg = ' '.join(ex_msg)
        except TypeError:
            pass    # no args, or args are not all strings: keep as-is
        if _blank(ex_msg):
            ex_msg = getattr(e_val, 'message', None)
        if _blank(ex_msg):
            ex_msg = self.msg
        out.append("%s: %s" % (exc_name, ex_msg))

        out.append("")
        return (exc_name, '\n'.join(out))

157 

158 

159 

class StdWriter(object):
    """Standard writer method for Larch,
    to be used in place of sys.stdout

    supports methods:
        set_textstyle(mode)  # a key of _sys.display.colors, e.g. 'text', 'error'
        write(text)
        flush()
    """
    valid_termcolors = ('grey', 'red', 'green', 'yellow',
                        'blue', 'magenta', 'cyan', 'white')

    termcolor_attrs = ('bold', 'underline', 'blink', 'reverse')

    def __init__(self, stdout=None, has_color=True, _larch=None):
        if stdout is None:
            stdout = sys.stdout
        # color is used only when requested AND termcolor could be imported
        self.has_color = has_color and HAS_TERMCOLOR
        self.writer = stdout
        self._larch = _larch
        self.textstyle = None

    def set_textstyle(self, mode='text'):
        """set text style for output

        The style is looked up as ``mode`` in
        ``_larch.symtable._sys.display.colors`` (an empty dict if missing).
        It is cleared to None when color is disabled or no interpreter is
        attached.
        """
        if not self.has_color or self._larch is None:
            self.textstyle = None
            return
        display_colors = self._larch.symtable._sys.display.colors
        self.textstyle = display_colors.get(mode, {})

    def write(self, text):
        """write text to writer, colored with the current text style
           write('hello')
        """
        if self.textstyle is not None and HAS_TERMCOLOR:
            text = colored(text, **self.textstyle)
        self.writer.write(text)

    def flush(self):
        """flush the underlying writer"""
        self.writer.flush()

197 self.writer.flush() 

198 

199 

class Procedure(object):
    """larch procedure: a function defined in Larch code.

    Holds the body (a sequence of statement nodes, each with a ``lineno``)
    and the argument specification, and runs the body in a new local Group
    when called.
    """
    def __init__(self, name, _larch=None, doc=None,
                 fname='<stdin>', lineno=0,
                 body=None, args=None, kwargs=None,
                 vararg=None, varkws=None):
        self.name = name
        self._larch = _larch
        self.modgroup = _larch.symtable._sys.moduleGroup
        self.body = body
        # positional argument names, and (name, default) keyword pairs;
        # None is normalized to an empty list so len()/iteration work
        self.argnames = args if args is not None else []
        self.kwargs = kwargs if kwargs is not None else []
        self.vararg = vararg
        self.varkws = varkws
        self.__doc__ = doc
        self.lineno = lineno
        self.__file__ = fname
        self.__name__ = name

    def __repr__(self):
        return "<Procedure %s, file=%s>" % (self.name, self.__file__)

    def _signature(self):
        """Return the call signature, e.g. "f(a, b, *args, c=1, **kws)"."""
        parts = list(self.argnames or [])
        if self.vararg is not None:
            parts.append('*%s' % self.vararg)
        parts.extend('%s=%s' % (k, repr(v)) for k, v in (self.kwargs or []))
        if self.varkws is not None:
            parts.append('**%s' % self.varkws)
        return "%s(%s)" % (self.name, ', '.join(parts))

    def raise_exc(self, **kws):
        """Report an error through the interpreter.

        lineno, func and fname default to this procedure's values; ``kws``
        (e.g. exc, msg) are passed on and override them.
        """
        ekws = dict(lineno=self.lineno, func=self, fname=self.__file__)
        ekws.update(kws)
        self._larch.raise_exception(None, **ekws)

    def __call__(self, *args, **kwargs):
        """Run the procedure body with the given arguments.

        Arguments are bound into a new local Group: positional names first
        (missing ones may be supplied by keyword), extra positionals fill the
        keyword arguments in order (or go to ``*vararg``), and leftover
        keywords go to ``**varkws``.  Argument errors are reported through
        :meth:`raise_exc`.  The body runs statement by statement in the frame
        (local group, module group), stopping at the first recorded error or
        when a return value is set.

        Returns the procedure's return value, or None.
        """
        lgroup = Group()
        lgroup.__name__ = hex(id(lgroup))
        args = list(args)
        nargs = len(args)
        nargs_expected = len(self.argnames)

        # case 1: too few positional arguments, but the missing names are
        # given as keywords: move them into the positional list
        if nargs < nargs_expected and len(kwargs) > 0:
            for name in self.argnames[nargs:]:
                if name in kwargs:
                    args.append(kwargs.pop(name))
            nargs = len(args)

        # case 2: multiple values for named argument
        msg = "%s() got multiple values for keyword argument '%s'"
        for targ in self.argnames:
            if targ in kwargs:
                self.raise_exc(exc=TypeError,
                               msg=msg % (self.name, targ))
                return

        # case 3: too few args given
        if nargs < nargs_expected:
            mod = 'at least'
            if len(self.kwargs) == 0:
                mod = 'exactly'
            msg = '%s() expected %s %i arguments (got %i)'
            self.raise_exc(exc=TypeError,
                           msg=msg % (self.name, mod, nargs_expected, nargs))
            return

        # case 4: more args given than expected, varargs not given:
        # extra positionals fill the keyword arguments in order
        if nargs > nargs_expected and self.vararg is None:
            if nargs - nargs_expected > len(self.kwargs):
                msg = 'too many arguments for %s() expected at most %i, got %i'
                msg = msg % (self.name, len(self.kwargs)+nargs_expected, nargs)
                self.raise_exc(exc=TypeError, msg=msg)
                return
            for i, xarg in enumerate(args[nargs_expected:]):
                kw_name = self.kwargs[i][0]
                if kw_name in kwargs:
                    # same value given positionally and by keyword: previously
                    # the positional value was silently dropped
                    msg = "%s() got multiple values for keyword argument '%s'"
                    self.raise_exc(exc=TypeError,
                                   msg=msg % (self.name, kw_name))
                    return
                kwargs[kw_name] = xarg

        for argname, value in zip(self.argnames, args):
            setattr(lgroup, argname, value)
        args = args[nargs_expected:]
        try:
            if self.vararg is not None:
                setattr(lgroup, self.vararg, tuple(args))

            for key, val in self.kwargs:
                if key in kwargs:
                    val = kwargs.pop(key)
                setattr(lgroup, key, val)

            if self.varkws is not None:
                setattr(lgroup, self.varkws, kwargs)
            elif len(kwargs) > 0:
                msg = 'extra keyword arguments for procedure %s (%s)'
                msg = msg % (self.name, ','.join(list(kwargs.keys())))
                self.raise_exc(exc=TypeError, msg=msg)
                return

        except (ValueError, LookupError, TypeError,
                NameError, AttributeError):
            msg = 'incorrect arguments for procedure %s' % self.name
            self.raise_exc(msg=msg)
            return

        stable = self._larch.symtable
        stable.save_frame()
        stable.set_frame((lgroup, self.modgroup))
        retval = None
        self._larch.retval = None
        self._larch._calldepth += 1
        self._larch.debug = True
        try:
            for node in self.body:
                self._larch.run(node, fname=self.__file__, func=self,
                                lineno=node.lineno+self.lineno-1,
                                with_raise=False)
                if len(self._larch.error) > 0:
                    break
                if self._larch.retval is not None:
                    retval = self._larch.retval
                    if retval is ReturnedNone:
                        retval = None
                    break
        finally:
            # always restore interpreter state, even if run() raises
            stable.restore_frame()
            self._larch._calldepth -= 1
            self._larch.debug = False
            self._larch.retval = None
        return retval

341 

342 

def add2path(envvar='PATH', dirname='.'):
    """Prepend a directory to a search-path environment variable.

    Parameters
    ----------
    envvar : str
        name of the environment variable to modify (default 'PATH').
    dirname : str
        directory to add; it is always stored as an absolute path.

    Returns
    -------
    str
        the previous value of the variable ('' if unset or empty), so that
        it can be restored later.
    """
    oldpath = os.environ.get(envvar, '')
    newdir = os.path.abspath(dirname)
    if oldpath == '':
        os.environ[envvar] = newdir
    else:
        # os.pathsep is ';' on Windows and ':' elsewhere
        os.environ[envvar] = newdir + os.pathsep + oldpath
    return oldpath

358 

359 

def isNamedClass(obj, cls):
    """Return True when the class of ``obj`` has the same name as ``cls``.

    A loose stand-in for ``isinstance(obj, cls)``: only class names are
    compared (``obj.__class__.__name__ == cls.__name__``), so an unrelated
    class that happens to share the name matches, while a subclass with a
    different name does not.
    """
    obj_class_name = obj.__class__.__name__
    target_name = cls.__name__
    return obj_class_name == target_name

367 

def get_dll(libname):
    """Find and load a shared library.

    The larch ``bindir`` is searched first, using the platform naming
    convention ('%s.dll' on Windows, 'lib%s.so' on Linux, 'lib%s.dylib' on
    macOS).  Otherwise the library reported by ``ctypes.util.find_library``
    is loaded.

    Returns the loaded library object, or None if it was not found (or the
    fallback library could not be loaded).
    """
    dylib_formats = {'win': '%s.dll', 'linux': 'lib%s.so',
                     'darwin': 'lib%s.dylib'}

    loaddll = ctypes.cdll.LoadLibrary
    if uname == 'win':
        loaddll = ctypes.windll.LoadLibrary

    # normally, we expect the dll to be here in the larch dlls tree:
    # if we find it there, use that one
    fmt = dylib_formats.get(uname)
    if fmt is not None:
        dllpath = os.path.join(bindir, fmt % libname)
        if os.path.exists(dllpath):
            return loaddll(dllpath)

    # if not found in the larch dlls tree, try your best!
    # On Linux find_library() returns a bare file name such as 'libm.so.6',
    # not a full path, so do not require it to exist on disk: let the
    # dynamic loader resolve it.
    dllpath = ctypes.util.find_library(libname)
    if dllpath is not None:
        try:
            return loaddll(dllpath)
        except OSError:
            return None
    return None

389 

390 

def read_workdir(conffile):
    """read working dir from a config file in the users larch dir
    compare save_workdir(conffile) which will save this value

    can be used to ensure that application startup starts in
    last working directory.  Failures (missing or unreadable file, or a
    directory that no longer exists) are silently ignored.
    """
    try:
        w_file = os.path.join(user_larchdir, conffile)
        if os.path.exists(w_file):
            # same encoding that save_workdir() writes with
            with open(w_file, 'r', encoding=sys.getdefaultencoding()) as fh:
                workdir = fh.readline().rstrip('\r\n')
            os.chdir(workdir)
    except (OSError, ValueError, TypeError):
        # best effort: keep the current working directory
        pass

407 

def save_workdir(conffile):
    """write working dir to a config file in the users larch dir
    compare read_workdir(conffile) which will read this value

    can be used to ensure that application startup starts in
    last working directory.  Write failures are silently ignored.
    """
    try:
        w_file = os.path.join(user_larchdir, conffile)
        with open(w_file, 'w', encoding=sys.getdefaultencoding()) as fh:
            fh.write("%s\n" % get_cwd())
    except (OSError, ValueError, TypeError):
        # best effort: failing to record the directory is not fatal
        pass

423 

424 

def read_config(conffile):
    """read toml config file from users larch dir
    compare save_config(conffile) which will save such a config

    returns dictionary / configuration, or None if the file does not exist
    or cannot be read or parsed.
    """
    cfile = os.path.join(user_larchdir, conffile)
    if not os.path.exists(cfile):
        return None
    try:
        return toml.loads(read_textfile(cfile))
    except Exception:
        # deliberate best effort: an unreadable or malformed config file is
        # treated as "no config" (but KeyboardInterrupt/SystemExit propagate)
        return None

440 

def save_config(conffile, config):
    """write toml config file in the users larch dir
    compare read_config(conffile) which will read this value

    Parameters
    ----------
    conffile : str
        file name, relative to the user larch directory.
    config : dict
        configuration to serialize with toml.dumps.

    The file is written as UTF-8 bytes.  Errors (e.g. an unwritable
    directory) propagate to the caller.
    """
    cfile = os.path.join(user_larchdir, conffile)
    dat = toml.dumps(config).encode('utf-8')
    with open(cfile, 'wb') as fh:
        fh.write(dat)

452 

def parse_group_args(arg0, members=None, group=None, defaults=None,
                     fcn_name=None, check_outputs=True):
    """parse arguments for functions supporting First Argument Group convention

    That is, if the first argument is a Larch Group and contains members
    named in 'members', this will return data extracted from that group.

    Arguments
    ----------
    arg0:      first argument for function call.
    members:   list/tuple of names of required members (in order)
    defaults:  tuple of default values for remaining required
               arguments past the first (in order); None means no defaults
    group:     group sent to parent function, used for outputs
    fcn_name:  name of parent function, used for error messages
    check_outputs: True/False (default True) setting whether a Warning should
               be raised if any of the outputs (except for the final group)
               are None or missing.  This effectively checks that all
               expected inputs have been specified
    Returns
    -------
     tuple of output values in the order listed by members, followed by the
     output group (which could be None).

    Notes
    -----
    This implements the First Argument Group convention, used for many Larch functions.
    As an example, the function _xafs.find_e0 is defined like this:
       find_e0(energy, mu=None, group=None, ...)

    and uses this function as
       energy, mu, group = parse_group_args(energy, members=('energy', 'mu'),
                                            defaults=(mu,), group=group,
                                            fcn_name='find_e0', check_outputs=True)

    This allows the caller to use
         find_e0(grp)
    as a shorthand for
         find_e0(grp.energy, grp.mu, group=grp)

    as long as the Group grp has member 'energy', and 'mu'.

    With 'check_outputs=True', the value for 'mu' is not actually allowed to be None.

    The defaults tuple should be passed so that correct values are assigned
    if the caller actually specifies arrays as for the full call signature.
    """
    if members is None:
        members = []
    if defaults is None:
        defaults = ()
    if isgroup(arg0, *members):
        if group is None:
            group = arg0
        out = [getattr(arg0, attr) for attr in members]
    else:
        out = [arg0] + list(defaults)

    # test that all outputs are non-None
    if check_outputs:
        _errmsg = """%s: needs First Argument Group or valid arguments for
  %s"""
        if fcn_name is None:
            fcn_name = 'unknown function'
        # a member with no output at all (defaults too short) is also missing
        for i in range(len(members)):
            if i >= len(out) or out[i] is None:
                raise Warning(_errmsg % (fcn_name, ', '.join(members)))

    out.append(group)
    return out

521 

def Make_CallArgs(skipped_args):
    """
    decorator to create a 'call_args' dictionary
    containing function arguments
    If a Group is included in the call arguments,
    these call_args will be added to the group's journal

    skipped_args: sequence of argument names.  The first two are passed to
    parse_group_args as the First-Argument-Group member names; all of them,
    plus 'group' and '_larch', are left out of the recorded call_args.
    """
    def wrap(fcn):
        # the argument spec of fcn is fixed: inspect it once, not on each call
        argspec = inspect.getfullargspec(fcn)
        argnames = argspec.args
        defaults = argspec.defaults or ()
        offset = len(argnames) - len(defaults)

        def wrapper(*args, **kwargs):
            result = fcn(*args, **kwargs)

            # declared defaults first (None for required arguments) ...
            call_args = dict.fromkeys(argnames[:offset])
            call_args.update(zip(argnames[offset:], defaults))
            # ... then the values actually passed
            call_args.update(zip(argnames, args))
            call_args.update(kwargs)

            skipped = list(skipped_args)
            at0 = skipped[0]
            at1 = skipped[1]
            _, _, groupx = parse_group_args(call_args[at0],
                                            members=(at0, at1),
                                            defaults=(call_args[at1],),
                                            group=call_args.get('group'),
                                            fcn_name=fcn.__name__)

            for k in skipped + ['group', '_larch']:
                call_args.pop(k, None)

            if groupx is not None:
                fname = fcn.__name__
                if not hasattr(groupx, 'journal'):
                    groupx.journal = Journal()
                if not hasattr(groupx, 'callargs'):
                    groupx.callargs = Group()
                setattr(groupx.callargs, fname, call_args)
                groupx.journal.add(f'{fname}_callargs', call_args)

            return result
        wrapper.__doc__ = fcn.__doc__
        wrapper.__name__ = fcn.__name__
        wrapper._larchfunc_ = fcn
        wrapper.__filename__ = fcn.__code__.co_filename
        wrapper.__dict__.update(fcn.__dict__)
        return wrapper
    return wrap

575 

576 

def ensuremod(_larch, modname=None):
    """ensure that a group exists in the larch symbol table

    Returns the symbol table of ``_larch``, after creating group
    ``modname`` if it is given and not already present, or None when
    ``_larch`` is None.
    """
    if _larch is None:
        return None
    symtable = _larch.symtable
    if modname is not None and not symtable.has_group(modname):
        symtable.newgroup(modname)
    return symtable

584 

585Entry = namedtuple('Entry', ('key', 'value', 'datetime')) 

586 

587def _get_dtime(dtime=None): 

588 """get datetime from input 

589 dtime can be: 

590 datetime : used as is 

591 str : assumed to be isoformat 

592 float : assumed to unix timestamp 

593 None : means now 

594 """ 

595 if isinstance(dtime, datetime): 

596 return dtime 

597 if isinstance(dtime, (int, float)): 

598 return datetime.fromtimestamp(dtime) 

599 elif isinstance(dtime, str): 

600 return datetime.fromisoformat(dtime) 

601 return datetime.now() 

602 

603class Journal: 

604 """list of journal entries""" 

605 def __init__(self, *args, **kws): 

606 self.data = [] 

607 for arg in args: 

608 if isinstance(arg, Journal): 

609 for entry in arg.data: 

610 self.add(entry.key, entry.value, dtime=entry.datetime) 

611 elif isinstance(arg, (list, tuple)): 

612 for entry in arg: 

613 self.add(entry[0], entry[1], dtime=entry[2]) 

614 

615 for k, v in kws.items(): 

616 self.add(k, v) 

617 

618 def tolist(self): 

619 return [(x.key, x.value, x.datetime.isoformat()) for x in self.data] 

620 

621 def __repr__(self): 

622 return repr(self.tolist()) 

623 

624 def __iter__(self): 

625 return iter(self.data) 

626 

627 

628 def add(self, key, value, dtime=None): 

629 """add journal entry: 

630 key, value pair with optional datetime 

631 """ 

632 self.data.append(Entry(key, value, _get_dtime(dtime))) 

633 

634 def add_ifnew(self, key, value, dtime=None): 

635 """add journal entry unless it already matches latest 

636 value (and dtime if supplied) 

637 """ 

638 needs_add = True 

639 latest = self.get(key, latest=True) 

640 if latest is not None: 

641 needs_add = (latest.value != value) 

642 if not needs_add and dtime is not None: 

643 dtime = _get_dtime(dtime) 

644 needs_add = needs_add or (latest.dtime != dtime) 

645 

646 if needs_add: 

647 self.add(key, value, dtime=dtime) 

648 

649 def get(self, key, latest=True): 

650 """get journal entries by key 

651 

652 Arguments 

653 ---------- 

654 latest [bool] whether to return latest matching entry only [True] 

655 

656 Notes: 

657 ------- 

658 if latest is True, one value will be returned, 

659 otherwise a list of entries (possibly length 1) will be returned. 

660 

661 """ 

662 matches = [x for x in self.data if x.key==key] 

663 if latest: 

664 tlatest = 0 

665 latest = None 

666 for m in matches: 

667 if m.datetime.timestamp() > tlatest: 

668 latest = m 

669 return latest 

670 return matches 

671 

672 def keys(self): 

673 return [x.key for x in self.data] 

674 

675 def values(self): 

676 return [x.values for x in self.data] 

677 

678 def items(self): 

679 return [(x.key, x.value) for x in self.data] 

680 

681 def get_latest(self, key): 

682 return self.get(key, latest=True) 

683 

684 def get_matches(self, key): 

685 return self.get(key, latest=False) 

686 

687 def sorted(self, sortby='time'): 

688 "return all entries, sorted by time or alphabetically by key" 

689 if 'time' in sortby.lower(): 

690 return sorted(self.data, key=lambda x: x.datetime.timestamp()) 

691 else: 

692 return sorted(self.data, key=lambda x: x.key) 

693 

694 def __getstate__(self): 

695 "get state for pickle / json encoding" 

696 return [(x.key, x.value, x.datetime.isoformat()) for x in self.data] 

697 

698 def __setstate__(self, state): 

699 "set state from pickle / json encoding" 

700 self.data = [] 

701 for key, value, dt in state: 

702 self.data.append(Entry(key, value, datetime.fromisoformat(dt)))