Coverage for /home/martinb/.local/share/virtualenvs/camcops/lib/python3.6/site-packages/asteval/astutils.py : 35%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
"""
utility functions for asteval

   Matthew Newville <newville@cars.uchicago.edu>,
   The University of Chicago
"""
import ast
import io
import keyword
import math
import numbers
import re
from sys import exc_info
from tokenize import (tokenize as generate_tokens,
                      ENCODING as tk_ENCODING,
                      NAME as tk_NAME,
                      TokenError)
# numpy is optional: HAS_NUMPY records whether the import succeeded, and the
# module-level name `numpy` stays None when it did not.
HAS_NUMPY = False
numpy = None
try:
    import numpy
    ndarr = numpy.ndarray  # only bound when numpy is importable
    HAS_NUMPY = True
except ImportError:
    pass
# Upper bounds checked by safe_pow, safe_mult/safe_add, safe_lshift and _open.
MAX_EXPONENT = 10000
MAX_STR_LEN = 2 << 17  # 256KiB
MAX_SHIFT = 1000
MAX_OPEN_BUFFER = 2 << 17
# Names rejected outright by valid_symbol_name(): keywords (including the
# Python 2 statements 'exec' and 'print') plus 'eval', 'execfile',
# '__import__' and '__package__'.
RESERVED_WORDS = ('and', 'as', 'assert', 'break', 'class', 'continue',
                  'def', 'del', 'elif', 'else', 'except', 'exec',
                  'finally', 'for', 'from', 'global', 'if', 'import',
                  'in', 'is', 'lambda', 'not', 'or', 'pass', 'print',
                  'raise', 'return', 'try', 'while', 'with', 'True',
                  'False', 'None', 'eval', 'execfile', '__import__',
                  '__package__')
# The hand-written list omits some keywords (e.g. 'nonlocal', 'yield');
# append every keyword of the running interpreter that is not listed yet,
# keeping the original entries and their order intact.
RESERVED_WORDS += tuple(word for word in keyword.kwlist
                        if word not in RESERVED_WORDS)
# Anchored with \Z rather than $: `$` also matches just before a trailing
# newline, which would let "abc\n" pass as a name.
NAME_MATCH = re.compile(r"[a-zA-Z_][a-zA-Z0-9_]*\Z").match
# Attribute names flagged as unsafe.
# NOTE(review): presumably consulted by the interpreter's attribute lookup;
# the enforcement is not in this module -- confirm against the caller.
UNSAFE_ATTRS = ('__subclasses__', '__bases__', '__globals__', '__code__',
                '__closure__', '__func__', '__self__', '__module__',
                '__dict__', '__class__', '__call__', '__get__',
                '__getattribute__', '__subclasshook__', '__new__',
                '__init__', 'func_globals', 'func_code', 'func_closure',
                'im_class', 'im_func', 'im_self', 'gi_code', 'gi_frame',
                '__asteval__', 'f_locals', '__mro__')
# inherit these from python's __builtins__
FROM_PY = ('ArithmeticError', 'AssertionError', 'AttributeError',
           'BaseException', 'BufferError', 'BytesWarning',
           'DeprecationWarning', 'EOFError', 'EnvironmentError',
           'Exception', 'False', 'FloatingPointError', 'GeneratorExit',
           'IOError', 'ImportError', 'ImportWarning', 'IndentationError',
           'IndexError', 'KeyError', 'KeyboardInterrupt', 'LookupError',
           'MemoryError', 'NameError', 'None',
           'NotImplementedError', 'OSError', 'OverflowError',
           'ReferenceError', 'RuntimeError', 'RuntimeWarning',
           'StopIteration', 'SyntaxError', 'SyntaxWarning', 'SystemError',
           'SystemExit', 'True', 'TypeError', 'UnboundLocalError',
           'UnicodeDecodeError', 'UnicodeEncodeError', 'UnicodeError',
           'UnicodeTranslateError', 'UnicodeWarning', 'ValueError',
           'Warning', 'ZeroDivisionError', 'abs', 'all', 'any', 'bin',
           'bool', 'bytearray', 'bytes', 'chr', 'complex', 'dict', 'dir',
           'divmod', 'enumerate', 'filter', 'float', 'format', 'frozenset',
           'hash', 'hex', 'id', 'int', 'isinstance', 'len', 'list', 'map',
           'max', 'min', 'oct', 'ord', 'pow', 'range', 'repr',
           'reversed', 'round', 'set', 'slice', 'sorted', 'str', 'sum',
           'tuple', 'zip')
# inherit these from python's math
FROM_MATH = ('acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', 'atanh',
             'ceil', 'copysign', 'cos', 'cosh', 'degrees', 'e', 'exp',
             'fabs', 'factorial', 'floor', 'fmod', 'frexp', 'fsum',
             'hypot', 'isinf', 'isnan', 'ldexp', 'log', 'log10', 'log1p',
             'modf', 'pi', 'pow', 'radians', 'sin', 'sinh', 'sqrt', 'tan',
             'tanh', 'trunc')
# Names looked up on the numpy module by make_symbol_table(); names the
# installed numpy does not provide are skipped there.
# NOTE(review): the list includes file-reading helpers ('load', 'loadtxt',
# 'genfromtxt', 'fromfile', 'memmap', ...) -- confirm exposing them is intended.
FROM_NUMPY = ('Inf', 'NAN', 'abs', 'add', 'alen', 'all', 'amax', 'amin',
              'angle', 'any', 'append', 'arange', 'arccos', 'arccosh',
              'arcsin', 'arcsinh', 'arctan', 'arctan2', 'arctanh',
              'argmax', 'argmin', 'argsort', 'argwhere', 'around', 'array',
              'array2string', 'asanyarray', 'asarray', 'asarray_chkfinite',
              'ascontiguousarray', 'asfarray', 'asfortranarray',
              'asmatrix', 'asscalar', 'atleast_1d', 'atleast_2d',
              'atleast_3d', 'average', 'bartlett', 'base_repr',
              'bitwise_and', 'bitwise_not', 'bitwise_or', 'bitwise_xor',
              'blackman', 'bool', 'broadcast', 'broadcast_arrays', 'byte',
              'c_', 'cdouble', 'ceil', 'cfloat', 'chararray', 'choose',
              'clip', 'clongdouble', 'clongfloat', 'column_stack',
              'common_type', 'complex', 'complex128', 'complex64',
              'complex_', 'complexfloating', 'compress', 'concatenate',
              'conjugate', 'convolve', 'copy', 'copysign', 'corrcoef',
              'correlate', 'cos', 'cosh', 'cov', 'cross', 'csingle',
              'cumprod', 'cumsum', 'datetime_data', 'deg2rad', 'degrees',
              'delete', 'diag', 'diag_indices', 'diag_indices_from',
              'diagflat', 'diagonal', 'diff', 'digitize', 'divide', 'dot',
              'double', 'dsplit', 'dstack', 'dtype', 'e', 'ediff1d',
              'empty', 'empty_like', 'equal', 'exp', 'exp2', 'expand_dims',
              'expm1', 'extract', 'eye', 'fabs', 'fill_diagonal', 'finfo',
              'fix', 'flatiter', 'flatnonzero', 'fliplr', 'flipud',
              'float', 'float32', 'float64', 'float_', 'floating', 'floor',
              'floor_divide', 'fmax', 'fmin', 'fmod', 'format_parser',
              'frexp', 'frombuffer', 'fromfile', 'fromfunction',
              'fromiter', 'frompyfunc', 'fromregex', 'fromstring', 'fv',
              'genfromtxt', 'getbufsize', 'geterr', 'gradient', 'greater',
              'greater_equal', 'hamming', 'hanning', 'histogram',
              'histogram2d', 'histogramdd', 'hsplit', 'hstack', 'hypot',
              'i0', 'identity', 'iinfo', 'imag', 'in1d', 'index_exp',
              'indices', 'inexact', 'inf', 'info', 'infty', 'inner',
              'insert', 'int', 'int0', 'int16', 'int32', 'int64', 'int8',
              'int_', 'int_asbuffer', 'intc', 'integer', 'interp',
              'intersect1d', 'intp', 'invert', 'ipmt', 'irr', 'iscomplex',
              'iscomplexobj', 'isfinite', 'isfortran', 'isinf', 'isnan',
              'isneginf', 'isposinf', 'isreal', 'isrealobj', 'isscalar',
              'issctype', 'iterable', 'ix_', 'kaiser', 'kron', 'ldexp',
              'left_shift', 'less', 'less_equal', 'linspace',
              'little_endian', 'load', 'loads', 'loadtxt', 'log', 'log10',
              'log1p', 'log2', 'logaddexp', 'logaddexp2', 'logical_and',
              'logical_not', 'logical_or', 'logical_xor', 'logspace',
              'long', 'longcomplex', 'longdouble', 'longfloat', 'longlong',
              'mafromtxt', 'mask_indices', 'mat', 'matrix',
              'maximum', 'maximum_sctype', 'may_share_memory', 'mean',
              'median', 'memmap', 'meshgrid', 'mgrid', 'minimum',
              'mintypecode', 'mirr', 'mod', 'modf', 'msort', 'multiply',
              'nan', 'nan_to_num', 'nanargmax', 'nanargmin', 'nanmax',
              'nanmin', 'nansum', 'ndarray', 'ndenumerate', 'ndfromtxt',
              'ndim', 'ndindex', 'negative', 'newaxis', 'nextafter',
              'nonzero', 'not_equal', 'nper', 'npv', 'number',
              'obj2sctype', 'ogrid', 'ones', 'ones_like', 'outer',
              'packbits', 'percentile', 'pi', 'piecewise', 'place', 'pmt',
              'poly', 'poly1d', 'polyadd', 'polyder', 'polydiv', 'polyfit',
              'polyint', 'polymul', 'polysub', 'polyval', 'power', 'ppmt',
              'prod', 'product', 'ptp', 'put', 'putmask', 'pv', 'r_',
              'rad2deg', 'radians', 'rank', 'rate', 'ravel', 'real',
              'real_if_close', 'reciprocal', 'record', 'remainder',
              'repeat', 'reshape', 'resize', 'restoredot', 'right_shift',
              'rint', 'roll', 'rollaxis', 'roots', 'rot90', 'round',
              'round_', 'row_stack', 's_', 'sctype2char', 'searchsorted',
              'select', 'setbufsize', 'setdiff1d', 'seterr', 'setxor1d',
              'shape', 'short', 'sign', 'signbit', 'signedinteger', 'sin',
              'sinc', 'single', 'singlecomplex', 'sinh', 'size',
              'sometrue', 'sort', 'sort_complex', 'spacing', 'split',
              'sqrt', 'square', 'squeeze', 'std', 'str', 'str_',
              'subtract', 'sum', 'swapaxes', 'take', 'tan', 'tanh',
              'tensordot', 'tile', 'trace', 'transpose', 'trapz', 'tri',
              'tril', 'tril_indices', 'tril_indices_from', 'trim_zeros',
              'triu', 'triu_indices', 'triu_indices_from', 'true_divide',
              'trunc', 'ubyte', 'uint', 'uint0', 'uint16', 'uint32',
              'uint64', 'uint8', 'uintc', 'uintp', 'ulonglong', 'union1d',
              'unique', 'unravel_index', 'unsignedinteger', 'unwrap',
              'ushort', 'vander', 'var', 'vdot', 'vectorize', 'vsplit',
              'vstack', 'where', 'who', 'zeros', 'zeros_like',
              'fft', 'linalg', 'polynomial', 'random')
# Extra symbol-table names (keys) bound to the numpy attribute named by the
# value, e.g. 'ln' -> numpy.log.
NUMPY_RENAMES = {'ln': 'log', 'asin': 'arcsin', 'acos': 'arccos',
                 'atan': 'arctan', 'atan2': 'arctan2',
                 'atanh': 'arctanh', 'acosh': 'arccosh',
                 'asinh': 'arcsinh'}
def _open(filename, mode='r', buffering=0):
    """Read-only version of the builtin ``open()``.

    Arguments
    ---------
      filename : path of the file to open.
      mode : str
         one of 'r', 'rb' or 'rU'; any other mode is refused.
      buffering : int
         buffer size handed to ``open()``; must not exceed MAX_OPEN_BUFFER.

    Returns
    --------
      the open file object; closing it is the caller's responsibility.

    Raises
    ------
      RuntimeError
         if ``mode`` is not a read-only mode or ``buffering`` is too large.
    """
    if mode not in ('r', 'rb', 'rU'):
        raise RuntimeError("Invalid open file mode, must be 'r', 'rb', or 'rU'")
    if buffering > MAX_OPEN_BUFFER:
        raise RuntimeError("Invalid buffering value, max buffer size is {}".format(MAX_OPEN_BUFFER))
    if mode != 'rb':
        # Python 3 raises ValueError for unbuffered text I/O, so the default
        # buffering=0 would make every text-mode call fail; fall back to the
        # interpreter's default buffering policy instead.
        if buffering == 0:
            buffering = -1
        # 'rU' is rejected by newer Python 3 releases; text mode already uses
        # universal newlines, so plain 'r' reads the same data.
        mode = 'r'
    return open(filename, mode, buffering)
170def _type(obj, *varargs, **varkws):
171 """type that prevents varargs and varkws"""
172 return type(obj).__name__
# Restricted replacements installed under the builtin names 'open' and 'type'.
LOCALFUNCS = {'open': _open, 'type': _type}
# Safe versions of functions to prevent denial of service issues

def safe_pow(base, exp):
    """Safe version of pow: ``base ** exp`` with a bounded exponent.

    Raises
    ------
      RuntimeError
         if a numeric exponent, or the largest non-NaN element of a numpy
         array exponent, is greater than MAX_EXPONENT.
    """
    if isinstance(exp, numbers.Number):
        # Non-real complex exponents cannot be ordered; comparing them with
        # `>` raises TypeError, so they are passed straight to `**`.
        ordered = (isinstance(exp, numbers.Real)
                   or not isinstance(exp, numbers.Complex))
        if ordered and exp > MAX_EXPONENT:
            raise RuntimeError("Invalid exponent, max exponent is {}".format(MAX_EXPONENT))
    elif HAS_NUMPY and isinstance(exp, numpy.ndarray):
        # nanmax raises ValueError on a zero-size array, which has no
        # exponent to check.
        if exp.size > 0 and numpy.nanmax(exp) > MAX_EXPONENT:
            raise RuntimeError("Invalid exponent, max exponent is {}".format(MAX_EXPONENT))
    return base ** exp
def safe_mult(a, b):
    """Safe version of multiply: ``a * b`` with bounded string repetition.

    Raises
    ------
      RuntimeError
         if one operand is a str, the other an int, and the repeated string
         would be longer than MAX_STR_LEN.
    """
    # String repetition is commutative, so check both operand orders;
    # checking only `str * int` would let `int * str` bypass the limit.
    if isinstance(a, str) and isinstance(b, int):
        text, count = a, b
    elif isinstance(b, str) and isinstance(a, int):
        text, count = b, a
    else:
        text, count = None, 0
    if text is not None and len(text) * count > MAX_STR_LEN:
        raise RuntimeError("String length exceeded, max string length is {}".format(MAX_STR_LEN))
    return a * b
def safe_add(a, b):
    """Safe version of add: ``a + b`` with bounded string concatenation.

    Raises
    ------
      RuntimeError
         if both operands are str and their combined length is greater
         than MAX_STR_LEN.
    """
    both_strings = isinstance(a, str) and isinstance(b, str)
    if both_strings and len(a) + len(b) > MAX_STR_LEN:
        raise RuntimeError("String length exceeded, max string length is {}".format(MAX_STR_LEN))
    return a + b
def safe_lshift(a, b):
    """Safe version of lshift: ``a << b`` with a bounded shift count.

    Raises
    ------
      RuntimeError
         if a numeric shift, or the largest non-NaN element of a numpy
         array shift, is greater than MAX_SHIFT.
    """
    if isinstance(b, numbers.Number):
        if b > MAX_SHIFT:
            raise RuntimeError("Invalid left shift, max left shift is {}".format(MAX_SHIFT))
    elif HAS_NUMPY and isinstance(b, numpy.ndarray):
        # nanmax raises ValueError on a zero-size array, which has no shift
        # count to check.
        if b.size > 0 and numpy.nanmax(b) > MAX_SHIFT:
            raise RuntimeError("Invalid left shift, max left shift is {}".format(MAX_SHIFT))
    return a << b
# Maps ast operator node classes to the callable implementing them.  Binary,
# boolean and comparison operators take two arguments, unary operators one.
# Add, LShift, Mult and Pow go through the size-limited safe_* wrappers.
OPERATORS = {ast.Is: lambda a, b: a is b,
             ast.IsNot: lambda a, b: a is not b,
             ast.In: lambda a, b: a in b,
             ast.NotIn: lambda a, b: a not in b,
             ast.Add: safe_add,
             ast.BitAnd: lambda a, b: a & b,
             ast.BitOr: lambda a, b: a | b,
             ast.BitXor: lambda a, b: a ^ b,
             ast.Div: lambda a, b: a / b,
             ast.FloorDiv: lambda a, b: a // b,
             ast.LShift: safe_lshift,
             ast.RShift: lambda a, b: a >> b,
             ast.Mult: safe_mult,
             # `@` parses to ast.MatMult; without an entry op2func raises
             # KeyError for it.
             ast.MatMult: lambda a, b: a @ b,
             ast.Pow: safe_pow,
             ast.Sub: lambda a, b: a - b,
             ast.Mod: lambda a, b: a % b,
             ast.And: lambda a, b: a and b,
             ast.Or: lambda a, b: a or b,
             ast.Eq: lambda a, b: a == b,
             ast.Gt: lambda a, b: a > b,
             ast.GtE: lambda a, b: a >= b,
             ast.Lt: lambda a, b: a < b,
             ast.LtE: lambda a, b: a <= b,
             ast.NotEq: lambda a, b: a != b,
             ast.Invert: lambda a: ~a,
             ast.Not: lambda a: not a,
             ast.UAdd: lambda a: +a,
             ast.USub: lambda a: -a}
def valid_symbol_name(name):
    """Determine whether the input symbol name is a valid name.

    Arguments
    ---------
      name  : str
         name to check for validity.

    Returns
    --------
      valid :  bool
        whether name is a valid symbol name

    A name is valid when it is not in RESERVED_WORDS and the Python
    tokenizer reads the whole string as one single NAME token.  Input the
    tokenizer cannot process is reported as invalid rather than raising.
    """
    if name in RESERVED_WORDS:
        return False

    gen = generate_tokens(io.BytesIO(name.encode('utf-8')).readline)
    try:
        typ, _, start, end, _ = next(gen)
        # the bytes-based tokenizer emits an ENCODING token first; skip it
        if typ == tk_ENCODING:
            typ, _, start, end, _ = next(gen)
    except (TokenError, SyntaxError):
        # e.g. an unterminated triple-quoted string: cannot be a name
        return False
    # the NAME token must span the entire input, from column 0 to its end
    return typ == tk_NAME and start == (1, 0) and end == (1, len(name))
def op2func(op):
    """Return function for operator nodes.

    ``op`` is an operator node instance; its class is looked up in
    OPERATORS, so an operator without an entry raises KeyError.
    """
    return OPERATORS[op.__class__]
class Empty:
    """Empty class whose instances are always falsy."""

    def __init__(self):
        """Create an instance; it holds no state."""
        pass

    def __nonzero__(self):
        """Truth-value hook (Python 2 name): always False."""
        return False

    # Python 3 calls __bool__, not __nonzero__; without this alias an
    # instance would evaluate as True.
    __bool__ = __nonzero__


ReturnedNone = Empty()
class ExceptionHolder(object):
    """Basic exception handler: holds the details of one error."""

    def __init__(self, node, exc=None, msg='', expr=None, lineno=None):
        """Store the error details, filling gaps from the active exception.

        ``sys.exc_info()`` is captured at construction time.  If ``exc`` is
        None the active exception type (if any) is used; if ``msg`` is the
        empty string the active exception instance (if any) is used.
        """
        self.node = node
        self.expr = expr
        self.msg = msg
        self.exc = exc
        self.lineno = lineno
        self.exc_info = exc_info()
        if self.exc is None and self.exc_info[0] is not None:
            self.exc = self.exc_info[0]
        if self.msg == '' and self.exc_info[1] is not None:
            self.msg = self.exc_info[1]

    def get_error(self):
        """Retrieve error data.

        Returns a tuple ``(exc_name, text)``: ``exc_name`` is the exception
        name ('UnknownError' when none is known) and ``text`` joins the
        expression, an optional ``^^^`` marker line and the message.
        """
        col_offset = -1
        if self.node is not None:
            col_offset = getattr(self.node, 'col_offset', -1)
        try:
            exc_name = self.exc.__name__
        except AttributeError:
            # not a class (e.g. None or an instance): use its string form
            exc_name = str(self.exc)
        if exc_name in (None, 'None'):
            exc_name = 'UnknownError'

        # NOTE(review): leading padding inside these two literals may have
        # been collapsed by the text extraction -- confirm against upstream.
        out = [" %s" % self.expr]
        if col_offset > 0:
            # marker line is emitted only for a positive column offset
            out.append(" %s^^^" % ((col_offset)*' '))
        out.append(str(self.msg))
        return (exc_name, '\n'.join(out))
class NameFinder(ast.NodeVisitor):
    """Find all symbol names used by a parsed node."""

    def __init__(self):
        """Start with an empty ``names`` list."""
        self.names = []
        super().__init__()

    def generic_visit(self, node):
        """Record ``node`` if it is a Name being loaded, then visit its children.

        Each identifier is appended to ``self.names`` once, in the order it
        is first encountered; names that are only stored to are not recorded.
        """
        if isinstance(node, ast.Name):
            if isinstance(node.ctx, ast.Load) and node.id not in self.names:
                self.names.append(node.id)
        super().generic_visit(node)
# __builtins__ may be bound to either the builtins module or its namespace
# dict; normalise to the dict so names can be looked up with `in` / `[]`.
builtins = __builtins__
if not isinstance(builtins, dict):
    builtins = builtins.__dict__
def get_ast_names(astnode):
    """Return symbol Names from an AST node.

    The result is the list collected by a NameFinder run over ``astnode``:
    the identifiers that are loaded, each listed once in first-seen order.
    """
    finder = NameFinder()
    finder.generic_visit(astnode)
    return finder.names
def make_symbol_table(use_numpy=True, **kws):
    """Create a default symboltable, taking dict of user-defined symbols.

    Arguments
    ---------
    use_numpy : bool, optional
       whether to include symbols from numpy (only possible when numpy
       could be imported)
    kws :  optional
       additional symbol name, value pairs to include in symbol table

    Returns
    --------
    symbol_table : dict
       a symbol table that can be used in `asteval.Interpreter`

    Symbols are added in the order builtins, math, numpy, numpy renames,
    LOCALFUNCS and finally ``kws``; a later source replaces an earlier
    entry of the same name.
    """
    symtable = {}
    missing = object()  # sentinel: attribute is not provided by the module

    for sym in FROM_PY:
        if sym in builtins:
            symtable[sym] = builtins[sym]

    for sym in FROM_MATH:
        value = getattr(math, sym, missing)
        if value is not missing:
            symtable[sym] = value

    if HAS_NUMPY and use_numpy:
        # a single getattr per name: names the installed numpy does not
        # provide are skipped without a second attribute lookup
        for sym in FROM_NUMPY:
            value = getattr(numpy, sym, missing)
            if value is not missing:
                symtable[sym] = value
        for name, sym in NUMPY_RENAMES.items():
            value = getattr(numpy, sym, missing)
            if value is not missing:
                symtable[name] = value

    symtable.update(LOCALFUNCS)
    symtable.update(kws)

    return symtable