Coverage for /home/martinb/.local/share/virtualenvs/camcops/lib/python3.6/site-packages/asteval/asteval.py : 12%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
#!/usr/bin/env python
"""Safe(ish) evaluation of mathematical expressions using Python's ast
module.

This module provides an Interpreter class that compiles a restricted set of
Python expressions and statements to Python's AST representation, and then
executes that representation using values held in a symbol table.

The symbol table is a simple dictionary, giving a simple, flat namespace.
This comes pre-loaded with many functions from Python's builtin and math
modules. If numpy is installed, many numpy functions are also included.
Additional symbols can be added when an Interpreter is created, but the
user of that interpreter will not be able to import additional modules.

Expressions, including loops, conditionals, and function definitions, can be
compiled into AST nodes and then evaluated later, using the current values
in the symbol table.

The result is a restricted, simplified version of Python meant for
numerical calculations that is somewhat safer than 'eval' because many
unsafe operations (such as 'import' and 'eval') are simply not allowed.

Many parts of Python syntax are supported, including:
    for loops, while loops, if-then-elif-else conditionals
    try-except (including 'finally')
    function definitions with def
    advanced slicing:    a[::-1], array[-3:, :, ::2]
    if-expressions:      out = one_thing if TEST else other
    list comprehension:  out = [sqrt(i) for i in values]

The following Python syntax elements are not supported:
    Import, Exec, Lambda, Class, Global, Generators,
    Yield, Decorators

In addition, while many builtin functions are supported, several builtin
functions that are considered unsafe are missing ('eval', 'exec', and
'getattr' for example).
"""
39import ast
40import time
41import inspect
42from sys import exc_info, stdout, stderr, version_info
44from .astutils import (UNSAFE_ATTRS, HAS_NUMPY, make_symbol_table, numpy,
45 op2func, ExceptionHolder, ReturnedNone,
46 valid_symbol_name)
# Compare the whole version tuple: checking the minor number on its own
# would wrongly reject a future major version such as 4.0.
if version_info < (3, 5):
    raise SystemError("Python 3.5 or higher required")

# `__builtins__` is the builtins module when this file runs as __main__ and
# a plain dict when it is imported; normalise to a dict for name lookups.
builtins = __builtins__
if not isinstance(builtins, dict):
    builtins = builtins.__dict__

# Lower-cased AST node class names that the Interpreter can handle.  Each
# name maps to an ``on_<name>`` method on Interpreter.
ALL_NODES = ['arg', 'assert', 'assign', 'attribute', 'augassign', 'binop',
             'boolop', 'break', 'call', 'compare', 'continue', 'delete',
             'dict', 'ellipsis', 'excepthandler', 'expr', 'extslice',
             'for', 'functiondef', 'if', 'ifexp', 'index', 'interrupt',
             'list', 'listcomp', 'module', 'name', 'nameconstant', 'num',
             'pass', 'raise', 'repr', 'return', 'slice', 'str',
             'subscript', 'try', 'tuple', 'unaryop', 'while', 'constant']
class Interpreter(object):
    """create an asteval Interpreter: a restricted, simplified interpreter
    of mathematical expressions using Python syntax.

    Parameters
    ----------
    symtable : dict or `None`
        dictionary to use as symbol table (if `None`, one will be created).
    usersyms : dict or `None`
        dictionary of user-defined symbols to add to symbol table
        (only used when `symtable` is `None`).
    writer : file-like or `None`
        file-like object where standard output will be sent
        (default: `sys.stdout`).
    err_writer : file-like or `None`
        file-like object where standard error will be sent
        (default: `sys.stderr`).
    use_numpy : bool
        whether to use functions from numpy.
    minimal : bool
        create a minimal interpreter: disable all options (see Note 1).
    no_if : bool
        disable `if` blocks.
    no_for : bool
        disable `for` blocks.
    no_while : bool
        disable `while` blocks.
    no_try : bool
        disable `try` blocks.
    no_functiondef : bool
        disable user-defined functions.
    no_ifexp : bool
        disable if expressions.
    no_listcomp : bool
        disable list comprehensions.
    no_augassign : bool
        disable augmented assignments (`a += 1`, etc).
    no_assert : bool
        disable `assert`.
    no_delete : bool
        disable `del`.
    no_raise : bool
        disable `raise`.
    no_print : bool
        silence `print` (calls become no-ops).
    max_time : float or `None`
        accepted for API compatibility; not used by this implementation.
    readonly_symbols : iterable or `None`
        symbols that the user can not assign to
    builtins_readonly : bool
        whether to make every symbol in the initial symtable read-only

    Notes
    -----
    1. setting `minimal=True` is equivalent to setting all
       `no_***` options to `True`.
    """

    def __init__(self, symtable=None, usersyms=None, writer=None,
                 err_writer=None, use_numpy=True, minimal=False,
                 no_if=False, no_for=False, no_while=False, no_try=False,
                 no_functiondef=False, no_ifexp=False, no_listcomp=False,
                 no_augassign=False, no_assert=False, no_delete=False,
                 no_raise=False, no_print=False, max_time=None,
                 readonly_symbols=None, builtins_readonly=False):
        """Create the symbol table and the table of enabled node handlers."""
        self.writer = writer or stdout
        self.err_writer = err_writer or stderr

        if symtable is None:
            if usersyms is None:
                usersyms = {}
            symtable = make_symbol_table(use_numpy=use_numpy, **usersyms)

        self.symtable = symtable
        self._interrupt = None
        self.error = []
        self.error_msg = None
        self.expr = None
        self.retval = None
        self.lineno = 0
        self.start_time = time.time()
        self.use_numpy = HAS_NUMPY and use_numpy

        symtable['print'] = self._printer
        self.no_print = no_print or minimal

        # Node types that can be switched off one by one, or all together
        # via `minimal`.  Every other node in ALL_NODES is always enabled.
        disabled = {'if': no_if, 'for': no_for, 'while': no_while,
                    'try': no_try, 'functiondef': no_functiondef,
                    'ifexp': no_ifexp, 'assert': no_assert,
                    'delete': no_delete, 'raise': no_raise,
                    'listcomp': no_listcomp, 'augassign': no_augassign}
        nodes = [name for name in ALL_NODES
                 if name not in disabled or not (minimal or disabled[name])]

        self.node_handlers = {node: getattr(self, "on_%s" % node)
                              for node in nodes}

        # to rationalize try/except try/finally for Python2.6 through Python3.3
        if 'try' in self.node_handlers:
            self.node_handlers['tryexcept'] = self.node_handlers['try']
            self.node_handlers['tryfinally'] = self.node_handlers['try']

        if readonly_symbols is None:
            self.readonly_symbols = set()
        else:
            self.readonly_symbols = set(readonly_symbols)

        if builtins_readonly:
            self.readonly_symbols |= set(self.symtable)

        # Names present at construction whose values are callables, modules
        # or numpy index-trick objects.  Kept as a list because node_assign
        # and on_functiondef remove a name from it when it is rebound.
        self.no_deepcopy = [key for key, val in symtable.items()
                            if (callable(val)
                                or 'numpy.lib.index_tricks' in repr(val)
                                or inspect.ismodule(val))]

    def remove_nodehandler(self, node):
        """remove support for a node
        returns current node handler (or None), so that it
        might be re-added with set_nodehandler()
        """
        return self.node_handlers.pop(node, None)

    def set_nodehandler(self, node, handler):
        """set node handler"""
        self.node_handlers[node] = handler

    def user_defined_symbols(self):
        """Return the symbols in self.symtable that are not in
        self.no_deepcopy.

        Note that non-callable values present at construction (constants,
        for example) are not in no_deepcopy and so are included.

        Returns
        -------
        unique_symbols : set
            symbols in symtable that are not in self.no_deepcopy
        """
        return set(self.symtable).difference(self.no_deepcopy)

    def unimplemented(self, node):
        """Unimplemented nodes."""
        self.raise_exception(node, exc=NotImplementedError,
                             msg="'%s' not supported" %
                             (node.__class__.__name__))

    def raise_exception(self, node, exc=None, msg='', expr=None,
                        lineno=None):
        """Record an error for `node` and raise it.

        An ExceptionHolder is appended to self.error and `_interrupt` is
        set so enclosing loops stop.  The raised type is `exc`, else the
        type recorded by the first error, else RuntimeError.  This method
        never returns normally.
        """
        if self.error is None:
            self.error = []
        if expr is None:
            expr = self.expr
        if len(self.error) > 0 and not isinstance(node, ast.Module):
            msg = '%s' % msg
        err = ExceptionHolder(node, exc=exc, msg=msg, expr=expr, lineno=lineno)
        self._interrupt = ast.Break()
        self.error.append(err)
        if self.error_msg is None:
            self.error_msg = "at expr='%s'" % (self.expr)
        elif len(msg) > 0:
            self.error_msg = msg
        if exc is None:
            try:
                exc = self.error[0].exc
            except (IndexError, AttributeError):
                exc = RuntimeError
        if exc is None:
            exc = RuntimeError
        raise exc(self.error_msg)

    # main entry point for Ast node evaluation
    #  parse:  text of statements -> ast
    #  run:    ast -> result
    #  eval:   string statement -> result = run(parse(statement))
    def parse(self, text):
        """Parse statement/expression to Ast representation."""
        self.expr = text
        try:
            return ast.parse(text)
        except SyntaxError:
            self.raise_exception(None, msg='Syntax Error', expr=text)
        except Exception:
            self.raise_exception(None, msg='Runtime Error', expr=text)

    def run(self, node, expr=None, lineno=None, with_raise=True):
        """Execute parsed Ast representation for an expression.

        Returns None without doing anything if an error is already pending
        or `node` is None.  With `with_raise=False`, an exception raised
        by the handler is left recorded in self.error and None is returned.
        """
        # Note: keep the 'node is None' test: internal code here may run
        #    run(None) and expect a None in return.
        if len(self.error) > 0 or node is None:
            return None
        if isinstance(node, str):
            node = self.parse(node)
        if lineno is not None:
            self.lineno = lineno
        if expr is not None:
            self.expr = expr

        # get handler for this node:
        #   on_xxx with handle nodes of type 'xxx', etc
        try:
            handler = self.node_handlers[node.__class__.__name__.lower()]
        except KeyError:
            return self.unimplemented(node)

        # run the handler:  this will likely generate
        # recursive calls into this run method.
        try:
            ret = handler(node)
            if isinstance(ret, enumerate):
                ret = list(ret)
            return ret
        except BaseException:
            # Deliberately broad: every failure is recorded in self.error.
            if with_raise:
                if len(self.error) == 0:
                    # Unhandled exception that didn't go through raise_exception
                    self.raise_exception(node, expr=expr)
                raise

    def __call__(self, expr, **kw):
        """Call class instance as function."""
        return self.eval(expr, **kw)

    def eval(self, expr, lineno=0, show_errors=True):
        """Evaluate a single statement.

        On error, prints the message to err_writer and returns None when
        `show_errors` is true; otherwise raises the recorded exception type.
        """
        self.lineno = lineno
        self.error = []
        self.start_time = time.time()
        try:
            node = self.parse(expr)
        except Exception:
            return self._report_error(show_errors)
        try:
            return self.run(node, expr=expr, lineno=lineno)
        except Exception:
            return self._report_error(show_errors)

    def _report_error(self, show_errors):
        """Print or re-raise the exception eval() is currently handling.

        Must be called from inside an ``except`` clause.
        """
        errmsg = exc_info()[1]
        if len(self.error) > 0:
            errmsg = "\n".join(self.error[0].get_error())
        if not show_errors:
            try:
                exc = self.error[0].exc
            except (IndexError, AttributeError):
                exc = RuntimeError
            if exc is None:
                exc = RuntimeError
            raise exc(errmsg)
        print(errmsg, file=self.err_writer)
        return None

    @staticmethod
    def dump(node, **kw):
        """Simple ast dumper."""
        return ast.dump(node, **kw)

    # handlers for ast components
    def on_expr(self, node):
        """Expression."""
        return self.run(node.value)  # ('value',)

    def on_index(self, node):
        """Index."""
        return self.run(node.value)  # ('value',)

    def on_return(self, node):  # ('value',)
        """Return statement: look for None, return special sentinal."""
        self.retval = self.run(node.value)
        if self.retval is None:
            self.retval = ReturnedNone
        return

    def on_repr(self, node):
        """Repr."""
        return repr(self.run(node.value))  # ('value',)

    def on_module(self, node):  # ():('body',)
        """Module def: run each statement, return the last result."""
        out = None
        for tnode in node.body:
            out = self.run(tnode)
        return out

    def on_expression(self, node):
        "basic expression"
        return self.on_module(node)  # ():('body',)

    def on_pass(self, node):
        """Pass statement."""
        return None  # ()

    def on_ellipsis(self, node):
        """Ellipses."""
        return Ellipsis

    # for break and continue: set the instance variable _interrupt
    def on_interrupt(self, node):  # ()
        """Interrupt handler."""
        self._interrupt = node
        return node

    def on_break(self, node):
        """Break."""
        return self.on_interrupt(node)

    def on_continue(self, node):
        """Continue."""
        return self.on_interrupt(node)

    def on_assert(self, node):  # ('test', 'msg')
        """Assert statement."""
        if not self.run(node.test):
            # node.msg is an AST expression (or None): evaluate it to text
            msg = ''
            if node.msg is not None:
                msg = str(self.run(node.msg))
            self.raise_exception(node, exc=AssertionError, msg=msg)
        return True

    def on_list(self, node):  # ('elt', 'ctx')
        """List."""
        return [self.run(e) for e in node.elts]

    def on_tuple(self, node):  # ('elts', 'ctx')
        """Tuple."""
        return tuple(self.on_list(node))

    def on_dict(self, node):  # ('keys', 'values')
        """Dictionary."""
        out = {}
        for key, value in zip(node.keys, node.values):
            if key is None:
                # ``{**other}``: a None key marks dictionary unpacking
                out.update(self.run(value))
            else:
                out[self.run(key)] = self.run(value)
        return out

    def on_constant(self, node):  # ('value', 'kind')
        """Return constant value."""
        return node.value

    def on_num(self, node):  # ('n',)
        """Return number."""
        return node.n

    def on_str(self, node):  # ('s',)
        """Return string."""
        return node.s

    def on_nameconstant(self, node):  # ('value',)
        """named constant: True, False, None in python >= 3.4"""
        return node.value

    def on_name(self, node):  # ('id', 'ctx')
        """Name node."""
        # Compare by class name: ast.Param is deprecated (never produced by
        # Python 3 parsers) and may be unavailable on newer versions.
        if node.ctx.__class__.__name__ in ('Param', 'Del'):
            return str(node.id)
        if node.id in self.symtable:
            return self.symtable[node.id]
        msg = "name '%s' is not defined" % node.id
        self.raise_exception(node, exc=NameError, msg=msg)

    def node_assign(self, node, val):
        """Assign a value (not the node.value object) to a node.

        This is used by on_assign, but also by for, list comprehension,
        etc.
        """
        if node.__class__ == ast.Name:
            if not valid_symbol_name(node.id) or node.id in self.readonly_symbols:
                errmsg = "invalid symbol name (reserved word?) %s" % node.id
                self.raise_exception(node, exc=NameError, msg=errmsg)
            self.symtable[node.id] = val
            if node.id in self.no_deepcopy:
                self.no_deepcopy.remove(node.id)

        elif node.__class__ == ast.Attribute:
            # writing an unsafe attribute is refused just like reading one
            if node.ctx.__class__ == ast.Load or node.attr in UNSAFE_ATTRS:
                msg = "cannot assign to attribute %s" % node.attr
                self.raise_exception(node, exc=AttributeError, msg=msg)

            setattr(self.run(node.value), node.attr, val)

        elif node.__class__ == ast.Subscript:
            # run(node.slice) already yields a ready-to-use key (index,
            # slice object including its step, or tuple), whatever the
            # Python version's AST layout for the slice is.
            sym = self.run(node.value)
            xslice = self.run(node.slice)
            sym[xslice] = val

        elif node.__class__ in (ast.Tuple, ast.List):
            values = list(val)
            nexpected = len(node.elts)
            if len(values) > nexpected:
                raise ValueError('too many values to unpack (expected %i)'
                                 % nexpected)
            if len(values) < nexpected:
                raise ValueError('not enough values to unpack '
                                 '(expected %i, got %i)'
                                 % (nexpected, len(values)))
            for telem, tval in zip(node.elts, values):
                self.node_assign(telem, tval)

    def on_attribute(self, node):  # ('value', 'attr', 'ctx')
        """Extract attribute."""
        ctx = node.ctx.__class__
        if ctx == ast.Store:
            msg = "attribute for storage: shouldn't be here!"
            self.raise_exception(node, exc=RuntimeError, msg=msg)

        sym = self.run(node.value)
        if ctx == ast.Del:
            return delattr(sym, node.attr)

        # ctx is ast.Load
        if node.attr in UNSAFE_ATTRS:
            fmt = "cannot access attribute '%s' for %s"
        else:
            fmt = "no attribute '%s' for %s"
            try:
                return getattr(sym, node.attr)
            except AttributeError:
                pass

        # AttributeError or accessed unsafe attribute; reuse `sym` rather
        # than evaluating node.value a second time
        msg = fmt % (node.attr, sym)
        self.raise_exception(node, exc=AttributeError, msg=msg)

    def on_assign(self, node):  # ('targets', 'value')
        """Simple assignment."""
        val = self.run(node.value)
        for tnode in node.targets:
            self.node_assign(tnode, val)
        return

    def on_augassign(self, node):  # ('target', 'op', 'value')
        """Augmented assign."""
        return self.on_assign(ast.Assign(targets=[node.target],
                                         value=ast.BinOp(left=node.target,
                                                         op=node.op,
                                                         right=node.value)))

    def on_slice(self, node):  # ():('lower', 'upper', 'step')
        """Simple slice."""
        return slice(self.run(node.lower),
                     self.run(node.upper),
                     self.run(node.step))

    def on_extslice(self, node):  # ():('dims',)
        """Extended slice."""
        return tuple([self.run(tnode) for tnode in node.dims])

    def on_subscript(self, node):  # ('value', 'slice', 'ctx')
        """Subscript handling -- one of the tricky parts."""
        val = self.run(node.value)
        nslice = self.run(node.slice)
        ctx = node.ctx.__class__
        if ctx in (ast.Load, ast.Store):
            # nslice is already an index, slice object or tuple, so one
            # lookup covers Index/Slice/ExtSlice (<3.9) and plain
            # expressions (>=3.9) alike
            return val[nslice]
        msg = "subscript with unknown context"
        self.raise_exception(node, msg=msg)

    def on_delete(self, node):  # ('targets',)
        """Delete statement."""
        for tnode in node.targets:
            if tnode.ctx.__class__ != ast.Del:
                break
            children = []
            while tnode.__class__ == ast.Attribute:
                children.append(tnode.attr)
                tnode = tnode.value

            if tnode.__class__ == ast.Name and tnode.id not in self.readonly_symbols:
                children.append(tnode.id)
                children.reverse()
                self.symtable.pop('.'.join(children))
            else:
                msg = "could not delete symbol"
                self.raise_exception(node, msg=msg)

    def on_unaryop(self, node):  # ('op', 'operand')
        """Unary operator."""
        return op2func(node.op)(self.run(node.operand))

    def on_binop(self, node):  # ('left', 'op', 'right')
        """Binary operator."""
        return op2func(node.op)(self.run(node.left),
                                self.run(node.right))

    def on_boolop(self, node):  # ('op', 'values')
        """Boolean operator."""
        val = self.run(node.values[0])
        is_and = ast.And == node.op.__class__
        if (is_and and val) or (not is_and and not val):
            for n in node.values[1:]:
                val = op2func(node.op)(val, self.run(n))
                if (is_and and not val) or (not is_and and val):
                    break
        return val

    def on_compare(self, node):  # ('left', 'ops', 'comparators')
        """comparison operators, including chained comparisons (a<b<c)"""
        lval = self.run(node.left)
        results = []
        for op, rnode in zip(node.ops, node.comparators):
            rval = self.run(rnode)
            ret = op2func(op)(lval, rval)
            results.append(ret)
            # Stop at the first false scalar comparison, as Python does.
            # numpy arrays have no single truth value, so never test them.
            is_array = HAS_NUMPY and isinstance(ret, numpy.ndarray)
            if not is_array and not ret:
                break
            lval = rval
        if len(results) == 1:
            return results[0]
        out = True
        for r in results:
            out = out and r
        return out

    def _printer(self, *out, **kws):
        """Generic print function (installed as `print` in the symtable)."""
        if self.no_print:
            return
        flush = kws.pop('flush', True)
        fileh = kws.pop('file', self.writer)
        sep = kws.pop('sep', ' ')
        end = kws.pop('end', '\n')

        print(*out, file=fileh, sep=sep, end=end)
        if flush:
            fileh.flush()

    def on_if(self, node):  # ('test', 'body', 'orelse')
        """Regular if-then-else statement."""
        block = node.body
        if not self.run(node.test):
            block = node.orelse
        for tnode in block:
            self.run(tnode)

    def on_ifexp(self, node):  # ('test', 'body', 'orelse')
        """If expressions."""
        expr = node.orelse
        if self.run(node.test):
            expr = node.body
        return self.run(expr)

    def on_while(self, node):  # ('test', 'body', 'orelse')
        """While blocks."""
        while self.run(node.test):
            self._interrupt = None
            for tnode in node.body:
                self.run(tnode)
                if self._interrupt is not None:
                    break
            if isinstance(self._interrupt, ast.Break):
                break
        else:
            for tnode in node.orelse:
                self.run(tnode)
        self._interrupt = None

    def on_for(self, node):  # ('target', 'iter', 'body', 'orelse')
        """For blocks."""
        for val in self.run(node.iter):
            self.node_assign(node.target, val)
            self._interrupt = None
            for tnode in node.body:
                self.run(tnode)
                if self._interrupt is not None:
                    break
            if isinstance(self._interrupt, ast.Break):
                break
        else:
            for tnode in node.orelse:
                self.run(tnode)
        self._interrupt = None

    def on_listcomp(self, node):  # ('elt', 'generators')
        """List comprehension; multiple ``for`` clauses nest left to right."""
        out = []
        self._comprehension_loop(node.elt, list(node.generators), out)
        return out

    def _comprehension_loop(self, elt, generators, out):
        """Expand the first comprehension clause, recursing into the rest."""
        if not generators:
            out.append(self.run(elt))
            return
        gen = generators[0]
        for val in self.run(gen.iter):
            self.node_assign(gen.target, val)
            if all(self.run(cond) for cond in gen.ifs):
                self._comprehension_loop(elt, generators[1:], out)

    def on_excepthandler(self, node):  # ('type', 'name', 'body')
        """Exception handler..."""
        return (self.run(node.type), node.name, node.body)

    def _handler_types(self, type_node):
        """Resolve the exception class(es) named by an ``except`` clause.

        Accepts a single name or a tuple of names, looked up in the symbol
        table and then in builtins.  Returns a tuple of exception classes,
        or None meaning "match anything" (bare ``except:`` or no name that
        resolves to an exception class).
        """
        if type_node is None:
            return None
        if isinstance(type_node, ast.Tuple):
            name_nodes = type_node.elts
        else:
            name_nodes = [type_node]
        found = []
        for name_node in name_nodes:
            name = getattr(name_node, 'id', None)
            htype = self.symtable.get(name, builtins.get(name, None))
            if isinstance(htype, type) and issubclass(htype, BaseException):
                found.append(htype)
        return tuple(found) if found else None

    def on_try(self, node):  # ('body', 'handlers', 'orelse', 'finalbody')
        """Try/except/else/finally blocks."""
        no_errors = True
        for tnode in node.body:
            self.run(tnode, with_raise=False)
            no_errors = no_errors and len(self.error) == 0
            if len(self.error) > 0:
                holder = self.error[-1]
                e_type, e_value, _ = holder.exc_info
                if e_type is None:
                    # exc_info can be empty when the error was recorded
                    # outside an ``except`` clause: use the recorded type
                    e_type = holder.exc
                for hnd in node.handlers:
                    htype = self._handler_types(hnd.type)
                    if htype is None or (isinstance(e_type, type)
                                         and issubclass(e_type, htype)):
                        self.error = []
                        if hnd.name is not None:
                            # in Python 3 ExceptHandler.name is a plain str
                            target = hnd.name
                            if not isinstance(target, ast.AST):
                                target = ast.Name(id=target, ctx=ast.Store())
                            self.node_assign(target, e_value)
                        for tline in hnd.body:
                            self.run(tline)
                        break
                break
        if no_errors and hasattr(node, 'orelse'):
            for tnode in node.orelse:
                self.run(tnode)

        if hasattr(node, 'finalbody'):
            for tnode in node.finalbody:
                self.run(tnode)

    def on_raise(self, node):  # ('exc', 'cause')
        """Raise statement: note difference for python 2 and 3."""
        excnode = node.exc
        msgnode = node.cause
        out = self.run(excnode)
        if isinstance(out, type):
            # ``raise ValueError`` (a class): instantiate it like Python does
            out = out()
        msg = ' '.join(str(arg) for arg in out.args)
        msg2 = self.run(msgnode)
        if msg2 not in (None, 'None'):
            msg = "%s: %s" % (msg, msg2)
        self.raise_exception(None, exc=out.__class__, msg=msg, expr='')

    def on_call(self, node):
        """Function execution."""
        #  ('func', 'args', 'keywords'. Py<3.5 has 'starargs' and 'kwargs' too)
        func = self.run(node.func)
        if not callable(func):
            msg = "'%s' is not callable!!" % (func)
            self.raise_exception(node, exc=TypeError, msg=msg)

        args = [self.run(targ) for targ in node.args]
        starargs = getattr(node, 'starargs', None)
        if starargs is not None:
            args.extend(self.run(starargs))

        keywords = {}
        if func is print:
            keywords['file'] = self.writer

        for key in node.keywords:
            if not isinstance(key, ast.keyword):
                msg = "keyword error in function call '%s'" % (func)
                self.raise_exception(node, msg=msg)
            if key.arg is None:
                # ``f(**mapping)`` on Python >= 3.5: a keyword with no name
                keywords.update(self.run(key.value))
            else:
                keywords[key.arg] = self.run(key.value)

        kwargs = getattr(node, 'kwargs', None)
        if kwargs is not None:
            keywords.update(self.run(kwargs))

        try:
            return func(*args, **keywords)
        except Exception as ex:
            func_name = getattr(func, '__name__', str(func))
            self.raise_exception(
                node, msg="Error running function call '%s' with args %s and "
                "kwargs %s: %s" % (func_name, args, keywords, ex))

    def on_arg(self, node):  # ('arg', 'annotation')
        """Arg for function definitions."""
        return node.arg

    def on_functiondef(self, node):
        """Define procedures."""
        # ('name', 'args', 'body', 'decorator_list')
        if node.decorator_list:
            raise Warning("decorated procedures not supported!")
        kwargs = []

        if not valid_symbol_name(node.name) or node.name in self.readonly_symbols:
            errmsg = "invalid function name (reserved word?) %s" % node.name
            self.raise_exception(node, exc=NameError, msg=errmsg)

        # the last len(defaults) positional args carry default values
        offset = len(node.args.args) - len(node.args.defaults)
        for idef, defnode in enumerate(node.args.defaults):
            defval = self.run(defnode)
            keyval = self.run(node.args.args[idef+offset])
            kwargs.append((keyval, defval))

        args = [tnode.arg for tnode in node.args.args[:offset]]
        doc = None
        nb0 = node.body[0]
        if isinstance(nb0, ast.Expr):
            docnode = nb0.value
            # Python >= 3.8 parses string literals as ast.Constant; older
            # versions produce ast.Str, whose text lives in `.s`
            if isinstance(docnode, ast.Constant):
                text = docnode.value
            else:
                text = getattr(docnode, 's', None)
            if isinstance(text, str):
                doc = text

        varkws = node.args.kwarg
        vararg = node.args.vararg
        if isinstance(vararg, ast.arg):
            vararg = vararg.arg
        if isinstance(varkws, ast.arg):
            varkws = varkws.arg

        self.symtable[node.name] = Procedure(node.name, self, doc=doc,
                                             lineno=self.lineno,
                                             body=node.body,
                                             args=args, kwargs=kwargs,
                                             vararg=vararg, varkws=varkws)
        if node.name in self.no_deepcopy:
            self.no_deepcopy.remove(node.name)
class Procedure(object):
    """Procedure: user-defined function for asteval.

    This stores the parsed ast nodes as from the 'functiondef' ast node
    for later evaluation.
    """

    def __init__(self, name, interp, doc=None, lineno=0,
                 body=None, args=None, kwargs=None,
                 vararg=None, varkws=None):
        """Store a procedure definition.

        Parameters
        ----------
        name : str
            procedure name.
        interp : Interpreter
            interpreter whose symtable and `run` execute the body.
        doc : str or None
            docstring.
        lineno : int
            line number used in error reports.
        body : list
            AST statements forming the body.
        args : list of str or None
            positional parameter names without defaults.
        kwargs : list of (name, default) pairs or None
            parameters with default values, in declaration order.
        vararg, varkws : str or None
            names of the ``*args`` and ``**kwargs`` parameters, if any.
        """
        self.__ininit__ = True
        self.name = name
        self.__name__ = self.name
        self.__asteval__ = interp
        self.raise_exc = self.__asteval__.raise_exception
        self.__doc__ = doc
        self.body = body
        self.argnames = [] if args is None else args
        self.kwargs = [] if kwargs is None else kwargs
        self.vararg = vararg
        self.varkws = varkws
        self.lineno = lineno
        self.__ininit__ = False

    def __setattr__(self, attr, val):
        # attributes may only be set while __init__ is running
        if not getattr(self, '__ininit__', True):
            self.raise_exc(None, exc=TypeError,
                           msg="procedure is read-only")
        self.__dict__[attr] = val

    def __dir__(self):
        return ['name']

    def __repr__(self):
        """Return a signature-style description, e.g. <Procedure f(a, b=1)>."""
        parts = list(self.argnames)
        parts.extend("%s=%s" % (key, val) for key, val in self.kwargs)
        if self.vararg is not None:
            parts.append("*%s" % self.vararg)
        if self.varkws is not None:
            parts.append("**%s" % self.varkws)
        sig = "<Procedure %s(%s)>" % (self.name, ', '.join(parts))
        if self.__doc__ is not None:
            sig = "%s\n %s" % (sig, self.__doc__)
        return sig

    def __call__(self, *args, **kwargs):
        """Bind the arguments as locals and run the procedure body.

        The interpreter's symbol table is restored in place afterwards,
        even if the body raises, so names bound inside the call do not
        leak out.
        """
        interp = self.__asteval__
        symlocals = {}
        args = list(args)
        nargs = len(args)
        nargs_expected = len(self.argnames)
        # check for too few arguments, but the correct keyword given
        if (nargs < nargs_expected) and len(kwargs) > 0:
            for name in self.argnames[nargs:]:
                if name in kwargs:
                    args.append(kwargs.pop(name))
            nargs = len(args)
        if nargs < nargs_expected:
            msg = "%s() takes at least %i arguments, got %i"
            self.raise_exc(None, exc=TypeError,
                           msg=msg % (self.name, nargs_expected, nargs))
        # check for multiple values for named argument
        if len(self.argnames) > 0 and kwargs is not None:
            msg = "multiple values for keyword argument '%s' in Procedure %s"
            for targ in self.argnames:
                if targ in kwargs:
                    self.raise_exc(None, exc=TypeError,
                                   msg=msg % (targ, self.name),
                                   lineno=self.lineno)

        # Surplus positional values fill the defaulted parameters first
        # (an explicit keyword still wins); anything left goes to *vararg.
        extra = args[nargs_expected:]
        if self.vararg is None and len(extra) > len(self.kwargs):
            msg = 'too many arguments for %s() expected at most %i, got %i'
            msg = msg % (self.name, len(self.kwargs)+nargs_expected, nargs)
            self.raise_exc(None, exc=TypeError, msg=msg)

        for (kw_name, _), xarg in zip(self.kwargs, extra):
            if kw_name not in kwargs:
                kwargs[kw_name] = xarg
        varargs = extra[len(self.kwargs):]

        for argname, xarg in zip(self.argnames, args):
            symlocals[argname] = xarg

        try:
            if self.vararg is not None:
                symlocals[self.vararg] = tuple(varargs)

            for key, val in self.kwargs:
                if key in kwargs:
                    val = kwargs.pop(key)
                symlocals[key] = val

            if self.varkws is not None:
                symlocals[self.varkws] = kwargs

            elif len(kwargs) > 0:
                msg = 'extra keyword arguments for Procedure %s (%s)'
                msg = msg % (self.name, ','.join(list(kwargs.keys())))
                self.raise_exc(None, msg=msg, exc=TypeError,
                               lineno=self.lineno)

        except (ValueError, LookupError, TypeError,
                NameError, AttributeError):
            msg = 'incorrect arguments for Procedure %s' % self.name
            self.raise_exc(None, msg=msg, lineno=self.lineno)

        symtable = interp.symtable
        saved = symtable.copy()
        symtable.update(symlocals)
        interp.retval = None
        retval = None
        try:
            # evaluate script of function
            for node in self.body:
                interp.run(node, expr='<>', lineno=self.lineno)
                if len(interp.error) > 0:
                    break
                if interp.retval is not None:
                    retval = interp.retval
                    interp.retval = None
                    if retval is ReturnedNone:
                        retval = None
                    break
        finally:
            # restore in place (and re-attach) so existing references to
            # the interpreter's symtable dict stay valid
            symtable.clear()
            symtable.update(saved)
            interp.symtable = symtable
        return retval