Coverage for /Users/Newville/Codes/xraylarch/larch/symboltable.py: 74%

311 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-11-09 10:08 -0600

1#!/usr/bin/env python 

2''' 

3SymbolTable for Larch interpreter 

4''' 

5import copy 

6 

7import numpy 

8 

9from . import site_config 

10from .utils import fixName, isValidName 

11 

12class Group(): 

13 """ 

14 Generic Group: a container for variables, modules, and subgroups. 

15 """ 

16 __private = ('_main', '_larch', '_parents', '__name__', '__doc__', 

17 '__private', '_subgroups', '_members', '_repr_html_') 

18 

19 __generic_functions = ('keys', 'values', 'items') 

20 

21 def __init__(self, name=None, **kws): 

22 if name is None: 

23 name = hex(id(self)) 

24 self.__name__ = name 

25 for key, val in kws.items(): 

26 setattr(self, key, val) 

27 

28 def __len__(self): 

29 return len(dir(self)) 

30 

31 def __repr__(self): 

32 if self.__name__ is not None: 

33 return f'<Group {self.__name__}>' 

34 return '<Group>' 

35 

36 def __copy__(self): 

37 out = Group() 

38 for key, val in self.__dict__.items(): 

39 if key != '__name__': 

40 setattr(out, key, copy.copy(val)) 

41 return out 

42 

43 def __deepcopy__(self, memo): 

44 out = Group() 

45 for key, val in self.__dict__.items(): 

46 if key != '__name__': 

47 setattr(out, key, copy.deepcopy(val, memo)) 

48 return out 

49 

50 def __id__(self): 

51 return id(self) 

52 

53 def __dir__(self): 

54 "return list of member names" 

55 cls_members = [] 

56 cname = self.__class__.__name__ 

57 if cname != 'SymbolTable' and hasattr(self, '__class__'): 

58 cls_members = dir(self.__class__) 

59 

60 dict_keys = [key for key in self.__dict__ if key not in cls_members] 

61 

62 return [key for key in cls_members + dict_keys 

63 if (not key.startswith('_SymbolTable_') and 

64 not key.startswith('_Group_') and 

65 not key.startswith(f'_{cname}_') and 

66 not (key.startswith('__') and key.endswith('__')) and 

67 key not in self.__generic_functions and 

68 key not in self.__private)] 

69 

70 def __getitem__(self, key): 

71 

72 if isinstance(key, int): 

73 raise IndexError("Group does not support Integer indexing") 

74 

75 return getattr(self, key) 

76 

77 def __setitem__(self, key, value): 

78 

79 if isinstance(key, int): 

80 raise IndexError("Group does not support Integer indexing") 

81 

82 return setattr(self, key, value) 

83 

84 def __iter__(self): 

85 return iter(self.keys()) 

86 

87 def keys(self): 

88 return self.__dir__() 

89 

90 def values(self): 

91 return [getattr(self, key) for key in self.__dir__()] 

92 

93 def items(self): 

94 return [(key, getattr(self, key)) for key in self.__dir__()] 

95 

96 def _subgroups(self): 

97 "return list of names of members that are sub groups" 

98 return [k for k in self._members() if isgroup(self.__dict__[k])] 

99 

100 def _members(self): 

101 "return members" 

102 out = {} 

103 for key in self.__dir__(): 

104 if key in self.__dict__: 

105 out[key] = self.__dict__[key] 

106 return out 

107 

108 def _repr_html_(self): 

109 """HTML representation for Jupyter notebook""" 

110 

111 html = [f"Group {self.__name__}"] 

112 html.append("<table>") 

113 html.append("<tr><td><b>Attribute</b></td><td><b>Type</b></td></tr>") 

114 attrs = self.__dir__() 

115 atypes = [type(getattr(self, attr)).__name__ for attr in attrs] 

116 hwords = [f"<tr><td>{attr}</td><td><i>{atp}</i></td></tr>" \ 

117 for attr, atp in zip(attrs, atypes)] 

118 html.append(''.join(hwords)) 

119 html.append("</table>") 

120 return ''.join(html) 

121 

122 

123def isgroup(grp, *args): 

124 """tests if input is a Group 

125 

126 With additional arguments (all must be strings), it also tests 

127 that the group has an an attribute named for each argument. This 

128 can be used to test not only if a object is a Group, but whether 

129 it a group with expected arguments. 

130 """ 

131 ret = isinstance(grp, Group) 

132 if ret and len(args) > 0: 

133 try: 

134 ret = all([hasattr(grp, a) for a in args]) 

135 except TypeError: 

136 return False 

137 return ret 

138 

139 

140class InvalidName: 

141 """ used to create a value that will NEVER be a useful symbol. 

142 symboltable._lookup() uses this to check for invalid names""" 

143 

144 

145GroupDocs = {} 

146GroupDocs['_sys'] = """ 

147Larch system-wide status variables, including 

148configuration variables and lists of Groups used 

149for finding variables. 

150""" 

151 

152GroupDocs['_builtin'] = """ 

153core built-in functions, most taken from Python 

154""" 

155 

156GroupDocs['_math'] = """ 

157Mathematical functions, including a host of functtion from numpy and scipy 

158""" 

159 

160 

161class SymbolTable(Group): 

162 """Main Symbol Table for Larch. 

163 """ 

164 top_group = '_main' 

165 core_groups = ('_sys', '_builtin', '_math') 

166 __invalid_name = InvalidName() 

167 _private = ('save_frame', 'restore_frame', 'set_frame', 

168 'has_symbol', 'has_group', 'get_group', 

169 'create_group', 'new_group', 'isgroup', 

170 'get_symbol', 'set_symbol', 'del_symbol', 

171 'get_parent', '_path', '__parents') 

172 

173 def __init__(self, larch=None): 

174 Group.__init__(self, name=self.top_group) 

175 self._larch = larch 

176 self._sys = None 

177 setattr(self, self.top_group, self) 

178 

179 for gname in self.core_groups: 

180 thisgroup = Group(name=gname) 

181 if gname in GroupDocs: 

182 thisgroup.__doc__ = GroupDocs[gname] 

183 setattr(self, gname, thisgroup) 

184 

185 self._sys.frames = [] 

186 self._sys.searchGroups = [self.top_group] 

187 self._sys.path = ['.'] 

188 self._sys.localGroup = self 

189 self._sys.valid_commands = [] 

190 self._sys.moduleGroup = self 

191 self._sys.__cache__ = [None]*4 

192 self._sys.saverestore_groups = [] 

193 for grp in self.core_groups: 

194 self._sys.searchGroups.append(grp) 

195 self._sys.core_groups = tuple(self._sys.searchGroups[:]) 

196 

197 # self.__callbacks = {} 

198 

199 self._sys.modules = {'_main':self} 

200 for gname in self.core_groups: 

201 self._sys.modules[gname] = getattr(self, gname) 

202 self._fix_searchGroups() 

203 

204 self._sys.config = Group(home_dir = site_config.home_dir, 

205 history_file= site_config.history_file, 

206 init_files = site_config.init_files, 

207 user_larchdir= site_config.user_larchdir, 

208 larch_version= site_config.larch_version, 

209 release_version = site_config.larch_release_version) 

210 

211 def save_frame(self): 

212 " save current local/module group" 

213 self._sys.frames.append((self._sys.localGroup, self._sys.moduleGroup)) 

214 

215 def restore_frame(self): 

216 "restore last saved local/module group" 

217 try: 

218 lgrp, mgrp = self._sys.frames.pop() 

219 self._sys.localGroup = lgrp 

220 self._sys.moduleGroup = mgrp 

221 self._fix_searchGroups() 

222 except: 

223 pass 

224 

225 def set_frame(self, groups): 

226 "set current execution frame (localGroup, moduleGroup)" 

227 self._sys.localGroup, self._sys.moduleGroup = groups 

228 self._fix_searchGroups() 

229 

230 

231 def _fix_searchGroups(self, force=False): 

232 """resolve list of groups to search for symbol names: 

233 

234 The variable self._sys.searchGroups holds the list of group 

235 names for searching for symbol names. A user can set this 

236 dynamically. The names need to be absolute (that is, relative to 

237 _main, and can omit the _main prefix). 

238 

239 This calclutes and returns self._sys.searchGroupObjects, 

240 which is the list of actual group objects (not names) resolved from 

241 the list of names in _sys.searchGroups) 

242 

243 _sys.localGroup,_sys.moduleGroup come first in the search list, 

244 followed by any search path associated with that module (from 

245 imports for that module) 

246 """ 

247 ## 

248 # check (and cache) whether searchGroups needs to be changed. 

249 sys = self._sys 

250 cache = sys.__cache__ 

251 if len(cache) < 4: 

252 cache = [None]*4 

253 if (sys.localGroup == cache[0] and 

254 sys.moduleGroup == cache[1] and 

255 sys.searchGroups == cache[2] and 

256 cache[3] is not None and not force): 

257 return cache[3] 

258 

259 if sys.moduleGroup is None: 

260 sys.moduleGroup = self.top_group 

261 if sys.localGroup is None: 

262 sys.localGroup = sys.moduleGroup 

263 

264 cache[0] = sys.localGroup 

265 cache[1] = sys.moduleGroup 

266 snames = [] 

267 sgroups = [] 

268 for grp in (sys.localGroup, sys.moduleGroup): 

269 if grp is not None and grp not in sgroups: 

270 sgroups.append(grp) 

271 snames.append(grp.__name__) 

272 

273 sysmods = list(self._sys.modules.values()) 

274 searchGroups = sys.searchGroups[:] 

275 searchGroups.extend(self._sys.core_groups) 

276 for name in searchGroups: 

277 grp = None 

278 if name in self._sys.modules: 

279 grp = self._sys.modules[name] 

280 elif hasattr(self, name): 

281 gtest = getattr(self, name) 

282 if isinstance(gtest, Group): 

283 grp = gtest 

284 elif '.' in name: 

285 parent, child= name.split('.') 

286 for sgrp in sysmods: 

287 if (parent == sgrp.__name__ and 

288 hasattr(sgrp, child)): 

289 grp = getattr(sgrp, child) 

290 break 

291 else: 

292 for sgrp in sysmods: 

293 if hasattr(sgrp, name): 

294 grp = getattr(sgrp, name) 

295 break 

296 if grp is not None and grp not in sgroups: 

297 sgroups.append(grp) 

298 snames.append(name) 

299 

300 self._sys.searchGroups = cache[2] = snames[:] 

301 sys.searchGroupObjects = cache[3] = sgroups[:] 

302 return sys.searchGroupObjects 

303 

304 def get_parentpath(self, sym): 

305 """ get parent path for a symbol""" 

306 obj = self._lookup(sym) 

307 if obj is None: 

308 return 

309 out = [] 

310 for s in reversed(self.__parents): 

311 if s.__name__ != '_main' or '_main' not in out: 

312 out.append(s.__name__) 

313 out.reverse() 

314 return '.'.join(out) 

315 

316 def _lookup(self, name=None, create=False): 

317 """looks up symbol in search path 

318 returns symbol given symbol name, 

319 creating symbol if needed (and create=True)""" 

320 debug = False # not ('force'in name) 

321 if debug: 

322 print( '====\nLOOKUP ', name) 

323 searchGroups = self._fix_searchGroups() 

324 self.__parents = [] 

325 if self not in searchGroups: 

326 searchGroups.append(self) 

327 

328 def public_attr(grp, name): 

329 return (hasattr(grp, name) and 

330 not (grp is self and name in self._private)) 

331 

332 parts = name.split('.') 

333 if len(parts) == 1: 

334 for grp in searchGroups: 

335 if public_attr(grp, name): 

336 self.__parents.append(grp) 

337 return getattr(grp, name) 

338 

339 # more complex case: not immediately found in Local or Module Group 

340 parts.reverse() 

341 top = parts.pop() 

342 out = self.__invalid_name 

343 if top == self.top_group: 

344 out = self 

345 else: 

346 for grp in searchGroups: 

347 if public_attr(grp, top): 

348 self.__parents.append(grp) 

349 out = getattr(grp, top) 

350 if out is self.__invalid_name: 

351 raise NameError(f"'{name}' is not defined") 

352 

353 if len(parts) == 0: 

354 return out 

355 

356 while parts: 

357 prt = parts.pop() 

358 if hasattr(out, prt): 

359 out = getattr(out, prt) 

360 elif create: 

361 val = None 

362 if len(parts) > 0: 

363 val = Group(name=prt) 

364 setattr(out, prt, val) 

365 out = getattr(out, prt) 

366 else: 

367 raise LookupError( 

368 f"cannot locate member '{prt}' of '{out}'") 

369 return out 

370 

371 def has_symbol(self, symname): 

372 try: 

373 _ = self.get_symbol(symname) 

374 return True 

375 except (LookupError, NameError, ValueError): 

376 return False 

377 

378 def has_group(self, gname): 

379 try: 

380 _ = self.get_group(gname) 

381 return True 

382 except (NameError, LookupError): 

383 return False 

384 

385 def isgroup(self, sym): 

386 "test if symbol is a group" 

387 return isgroup(sym) 

388 

389 def get_group(self, gname): 

390 "find group by name" 

391 sym = self._lookup(gname, create=False) 

392 if isgroup(sym): 

393 return sym 

394 raise LookupError(f"symbol '{gname}' found, but not a group") 

395 

396 def create_group(self, **kw): 

397 "create a new Group, not placed anywhere in symbol table" 

398 return Group(**kw) 

399 

400 def new_group(self, name, **kws): 

401 name = fixName(name) 

402 grp = Group(__name__ = name, **kws) 

403 self.set_symbol(name, value=grp) 

404 return grp 

405 

406 def get_symbol(self, sym, create=False): 

407 "lookup and return a symbol by name" 

408 return self._lookup(sym, create=create) 

409 

410 def set_symbol(self, name, value=None, group=None): 

411 "set a symbol in the table" 

412 grp = self._sys.localGroup 

413 if group is not None: 

414 grp = self.get_group(group) 

415 names = [] 

416 

417 for n in name.split('.'): 

418 if not isValidName(n): 

419 raise SyntaxError(f"invalid symbol name '{n}'") 

420 names.append(n) 

421 

422 child = names.pop() 

423 for nam in names: 

424 if hasattr(grp, nam): 

425 grp = getattr(grp, nam) 

426 if not isgroup(grp): 

427 raise ValueError( 

428 f"cannot create subgroup of non-group '{grp}'") 

429 else: 

430 setattr(grp, nam, Group()) 

431 

432 setattr(grp, child, value) 

433 return value 

434 

435 def del_symbol(self, name): 

436 "delete a symbol" 

437 sym = self._lookup(name, create=False) 

438 parent, child = self.get_parent(name) 

439 delattr(parent, child) 

440 

441 def clear_callbacks(self, name, index=None): 

442 """clear 1 or all callbacks for a symbol 

443 """ 

444 pass 

445 

446 def add_callback(self, name, func, args=None, kws=None): 

447 """disabled: 

448 set a callback to be called when set_symbol() is called 

449 for a named variable 

450 """ 

451 print("adding callback on symbol disabled") 

452 

453 

454 def get_parent(self, name): 

455 """return parent group, child name for an absolute symbol name 

456 (as from _lookup) that is, a pair suitable for hasattr, 

457 getattr, or delattr 

458 """ 

459 tnam = name.split('.') 

460 if len(tnam) < 1 or name == self.top_group: 

461 return (self, None) 

462 child = tnam.pop() 

463 sym = self 

464 if len(tnam) > 0: 

465 sym = self._lookup('.'.join(tnam)) 

466 return sym, child 

467 

468 def show_group(self, groupname): 

469 """display group members --- simple version for tests""" 

470 out = [] 

471 try: 

472 group = self.get_group(groupname) 

473 except (NameError, LookupError): 

474 return 'Group %s not found' % groupname 

475 

476 members = dir(group) 

477 out = ['f== {group.__name__}: {len(members)} symbols =='] 

478 for item in members: 

479 obj = getattr(group, item) 

480 dval = None 

481 if isinstance(obj, numpy.ndarray): 

482 if len(obj) > 10 or len(obj.shape)>1: 

483 dval = "array<shape=%s, type=%s>" % (repr(obj.shape), 

484 repr(obj.dtype)) 

485 if dval is None: 

486 dval = repr(obj) 

487 out.append(f' {item}: {dval}') 

488 out.append('\n') 

489 self._larch.writer.write('\n'.join(out))