Coverage for /home/martinb/.local/share/virtualenvs/camcops/lib/python3.6/site-packages/pandas/core/computation/pytables.py : 25%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1""" manage PyTables query interface via Expressions """
3import ast
4from functools import partial
5from typing import Any, Dict, Optional, Tuple
7import numpy as np
9from pandas._libs.tslibs import Timedelta, Timestamp
10from pandas.compat.chainmap import DeepChainMap
12from pandas.core.dtypes.common import is_list_like
14import pandas as pd
15import pandas.core.common as com
16from pandas.core.computation import expr, ops, scope as _scope
17from pandas.core.computation.common import _ensure_decoded
18from pandas.core.computation.expr import BaseExprVisitor
19from pandas.core.computation.ops import UndefinedVariableError, is_term
21from pandas.io.formats.printing import pprint_thing, pprint_thing_encoded
class PyTablesScope(_scope.Scope):
    """
    Scope used while evaluating PyTables ``where`` expressions.

    Besides the regular scope lookup it carries ``queryables``: a mapping
    from the names allowed on the left-hand side of a comparison to
    per-column information.
    """

    __slots__ = ("queryables",)

    queryables: Dict[str, Any]

    def __init__(
        self,
        level: int,
        global_dict=None,
        local_dict=None,
        queryables: Optional[Dict[str, Any]] = None,
    ):
        # NOTE(review): ``level + 1`` presumably accounts for this
        # constructor's own frame; Scope's use of ``level`` is defined elsewhere.
        super().__init__(level + 1, global_dict=global_dict, local_dict=local_dict)
        # None or an empty mapping is replaced by a fresh dict.
        self.queryables = dict() if not queryables else queryables
class Term(ops.Term):
    """
    A name or literal appearing in a PyTables ``where`` expression.

    ``__new__`` turns any non-string ``name`` into a ``Constant``; string
    names are resolved against the scope.
    """

    env: PyTablesScope

    def __new__(cls, name, env, side=None, encoding=None):
        # Non-strings are literal values, represented by Constant.
        target = cls if isinstance(name, str) else Constant
        return object.__new__(target)

    def __init__(self, name, env: PyTablesScope, side=None, encoding=None):
        super().__init__(name, env, side=side, encoding=encoding)

    def _resolve_name(self):
        if self.side != "left":
            # Right-hand side: look the name up in the scope, falling back to
            # the bare name when it is not defined there (None is allowed).
            try:
                return self.env.resolve(self.name, is_local=False)
            except UndefinedVariableError:
                return self.name

        # Left-hand side: must be one of the queryables. ``__new__``
        # guarantees that ``self.name`` is a str on this path.
        if self.name in self.env.queryables:
            return self.name
        raise NameError(f"name {repr(self.name)} is not defined")

    # Read-only property overriding the base class's read/write property.
    @property  # type: ignore
    def value(self):
        return self._value
class Constant(Term):
    """A literal value in a PyTables expression; no scope lookup is done."""

    def __init__(self, value, env: PyTablesScope, side=None, encoding=None):
        # Internal invariant: constants are always built with a PyTables scope.
        assert isinstance(env, PyTablesScope), type(env)
        super().__init__(value, env, side=side, encoding=encoding)

    def _resolve_name(self):
        # Return the stored ``_name`` unchanged rather than resolving it.
        return self._name
class BinOp(ops.BinOp):
    """
    Binary operation in a PyTables ``where`` expression.

    ``prune`` turns an expression tree into specialised ``ConditionBinOp``
    or ``FilterBinOp`` nodes.  The remaining helpers do two things:

    - describe the left-hand side column (looked up in ``queryables``);
    - convert right-hand side values into a form PyTables accepts.
    """

    # Largest number of values in an ``==``/``!=`` comparison that is
    # rendered as an explicit condition; longer lists use a filter instead.
    _max_selectors = 31

    op: str
    queryables: Dict[str, Any]

    def __init__(self, op: str, lhs, rhs, queryables: Dict[str, Any], encoding):
        super().__init__(op, lhs, rhs)
        self.queryables = queryables
        self.encoding = encoding
        self.condition = None

    def _disallow_scalar_only_bool_ops(self):
        # Deliberately a no-op (overrides the base-class hook of this name).
        pass

    def prune(self, klass):
        """
        Collapse this node into a specialised ``klass`` node.

        Parameters
        ----------
        klass : type
            ``ConditionBinOp`` or ``FilterBinOp``.

        Returns
        -------
        The evaluated specialised node, one of the operands, or None when
        nothing applies for ``klass``.
        """

        def pr(left, right):
            """ create and return a new specialized BinOp from myself """
            if left is None:
                return right
            elif right is None:
                return left

            k = klass
            if isinstance(left, ConditionBinOp):
                if isinstance(right, ConditionBinOp):
                    k = JointConditionBinOp
                elif isinstance(left, k):
                    return left
                elif isinstance(right, k):
                    return right

            elif isinstance(left, FilterBinOp):
                if isinstance(right, FilterBinOp):
                    k = JointFilterBinOp
                elif isinstance(left, k):
                    return left
                elif isinstance(right, k):
                    return right

            return k(
                self.op, left, right, queryables=self.queryables, encoding=self.encoding
            ).evaluate()

        # Terms contribute their resolved value; sub-expressions are pruned
        # recursively (left side first).
        left = self.lhs.value if is_term(self.lhs) else self.lhs.prune(klass)
        right = self.rhs.value if is_term(self.rhs) else self.rhs.prune(klass)
        return pr(left, right)

    def conform(self, rhs):
        """
        Return ``rhs`` as a list-like of values (a new object, not in place).

        Scalars are wrapped in a one-element list and ndarrays are flattened.
        """
        if not is_list_like(rhs):
            rhs = [rhs]
        if isinstance(rhs, np.ndarray):
            rhs = rhs.ravel()
        return rhs

    @property
    def is_valid(self) -> bool:
        """ return True if this is a valid field """
        return self.lhs in self.queryables

    @property
    def is_in_table(self) -> bool:
        """ return True if this is a valid column name for generation (e.g. an
        actual column in the table) """
        return self.queryables.get(self.lhs) is not None

    @property
    def kind(self):
        """ the kind of my field """
        return getattr(self.queryables.get(self.lhs), "kind", None)

    @property
    def meta(self):
        """ the meta of my field """
        return getattr(self.queryables.get(self.lhs), "meta", None)

    @property
    def metadata(self):
        """ the metadata of my field """
        return getattr(self.queryables.get(self.lhs), "metadata", None)

    def generate(self, v) -> str:
        """ create and return the op string for this TermValue """
        val = v.tostring(self.encoding)
        return f"({self.lhs} {self.op} {val})"

    def convert_value(self, v) -> "TermValue":
        """
        Convert one right-hand side value into something PyTables accepts.

        Parameters
        ----------
        v : scalar
            A single value from the conformed right-hand side.

        Returns
        -------
        TermValue
            The value, its converted representation, and the kind used
            to render it.

        Raises
        ------
        TypeError
            If ``v`` cannot be compared with a column of this kind.
        """

        def stringify(value):
            if self.encoding is not None:
                encoder = partial(pprint_thing_encoded, encoding=self.encoding)
            else:
                encoder = pprint_thing
            return encoder(value)

        kind = _ensure_decoded(self.kind)
        meta = _ensure_decoded(self.meta)
        if kind == "datetime64" or kind == "datetime":
            if isinstance(v, (int, float)):
                v = stringify(v)
            v = _ensure_decoded(v)
            v = Timestamp(v)
            if v.tz is not None:
                v = v.tz_convert("UTC")
            return TermValue(v, v.value, kind)
        elif kind == "timedelta64" or kind == "timedelta":
            # Strings carry their own units (e.g. "1 day"); recent pandas
            # versions raise if ``unit`` is passed together with a string.
            if isinstance(v, str):
                v = Timedelta(v).value
            else:
                v = Timedelta(v, unit="s").value
            return TermValue(int(v), v, kind)
        elif meta == "category":
            metadata = com.values_from_object(self.metadata)
            # Use the exact position of ``v`` among the categories (-1 when
            # absent).  ``searchsorted`` would return an insertion point for
            # an absent value, i.e. the code of a *different* category, and
            # it assumes the categories are sorted, which is not guaranteed.
            positions = np.flatnonzero(metadata == v)
            result = int(positions[0]) if len(positions) else -1
            return TermValue(result, result, "integer")
        elif kind == "integer":
            # Convert directly when possible: going through ``float`` first
            # silently loses precision for integers larger than 2**53.
            try:
                v = int(v)
            except (TypeError, ValueError):
                # e.g. strings such as "3.0" or "1e3"
                v = int(float(v))
            return TermValue(v, v, kind)
        elif kind == "float":
            v = float(v)
            return TermValue(v, v, kind)
        elif kind == "bool":
            if isinstance(v, str):
                v = v.strip().lower() not in [
                    "false",
                    "f",
                    "no",
                    "n",
                    "none",
                    "0",
                    "[]",
                    "{}",
                    "",
                ]
            else:
                v = bool(v)
            return TermValue(v, v, kind)
        elif isinstance(v, str):
            # string quoting
            return TermValue(v, stringify(v), "string")
        else:
            raise TypeError(f"Cannot compare {v} of type {type(v)} to {kind} column")

    def convert_values(self):
        pass
class FilterBinOp(BinOp):
    """
    ``==`` / ``!=`` comparison applied as a filter on the selected data.

    Once evaluated, ``filter`` is ``(lhs, op, values)``.  ``op(axis, values)``
    builds a boolean mask with ``isin``, negated for ``!=`` or after
    inversion.
    """

    filter: Optional[Tuple[Any, Any, pd.Index]] = None
    # True while the mask function stored in ``filter`` is the negation of
    # the one implied by ``self.op``; toggled by ``invert``.
    _inverted = False

    def __repr__(self) -> str:
        if self.filter is None:
            return "Filter: Not Initialized"
        return pprint_thing(f"[Filter : [{self.filter[0]}] -> [{self.filter[1]}]")

    def invert(self):
        """
        Negate the filter in place (if one has been built) and return self.

        Each call toggles the negation, so inverting twice (``~~x``) gives
        back the original filter instead of staying negated.
        """
        if self.filter is not None:
            self._inverted = not self._inverted
            name, _, values = self.filter
            self.filter = (
                name,
                self.generate_filter_op(invert=self._inverted),
                values,
            )
        return self

    def format(self):
        """ return the actual filter format """
        return [self.filter]

    def evaluate(self):
        """
        Build the filter for this comparison.

        Returns
        -------
        self when a filter was built.  None for a table column that does
        not need one: too few values (it becomes a condition) or a
        non-equality comparison.

        Raises
        ------
        ValueError
            If the left-hand side is not one of the queryables.
        TypeError
            If a non-equality comparison targets a non-table indexer.
        """
        if not self.is_valid:
            raise ValueError(f"query term is not valid [{self}]")

        values = list(self.conform(self.rhs))
        is_equality = self.op in ["==", "!="]

        if self.is_in_table:
            # If there are too many values to create the expression, use a
            # filter instead.
            if is_equality and len(values) > self._max_selectors:
                self._set_filter(values)
                return self
            return None

        # Not a table column: only equality comparisons can be filtered.
        if not is_equality:
            raise TypeError(
                f"passing a filterable condition to a non-table indexer [{self}]"
            )
        self._set_filter(values)
        return self

    def _set_filter(self, values):
        # Store a fresh, non-inverted filter over ``values``.
        self._inverted = False
        self.filter = (self.lhs, self.generate_filter_op(), pd.Index(values))

    def generate_filter_op(self, invert: bool = False):
        """
        Return the mask function ``f(axis, vals)`` for this comparison.

        ``!=`` (or ``==`` with ``invert=True``) gives ``~axis.isin(vals)``;
        otherwise ``axis.isin(vals)``.
        """
        if (self.op == "!=" and not invert) or (self.op == "==" and invert):
            return lambda axis, vals: ~axis.isin(vals)
        else:
            return lambda axis, vals: axis.isin(vals)
class JointFilterBinOp(FilterBinOp):
    """Two filters joined by an operator; it cannot be collapsed into one."""

    def format(self):
        raise NotImplementedError("unable to collapse Joint Filters")

    def evaluate(self):
        # Nothing to compute: this node only combines its two operands.
        return self
class ConditionBinOp(BinOp):
    """Comparison rendered as a numexpr condition string."""

    def __repr__(self) -> str:
        return pprint_thing(f"[Condition : [{self.condition}]]")

    def invert(self):
        """Inversion is not supported for conditions passed to numexpr."""
        raise NotImplementedError(
            "cannot use an invert condition when passing to numexpr"
        )

    def format(self):
        """ return the actual ne format """
        return self.condition

    def evaluate(self):
        """
        Build ``self.condition`` for this comparison.

        Returns
        -------
        self when a condition was built.  None when the left-hand side is
        not a table column, or when an ``==``/``!=`` comparison has more
        than ``_max_selectors`` values (a filter is used instead).

        Raises
        ------
        ValueError
            If the left-hand side is not one of the queryables.
        """
        if not self.is_valid:
            raise ValueError(f"query term is not valid [{self}]")

        # convert values if we are in the table
        if not self.is_in_table:
            return None

        rhs = self.conform(self.rhs)
        values = [self.convert_value(v) for v in rhs]

        if self.op in ["==", "!="]:
            if len(values) > self._max_selectors:
                # too many values for the expression: filter after reading
                return None
            vs = [self.generate(v) for v in values]
            # ``x == [a, b]`` matches any of the values, so the terms are
            # OR-ed.  ``x != [a, b]`` must differ from all of them (the same
            # ``~isin`` semantics FilterBinOp uses for long lists), so the
            # terms are AND-ed; OR-ing them would make a tautology.
            joiner = " & " if self.op == "!=" else " | "
            self.condition = f"({joiner.join(vs)})"
        else:
            self.condition = self.generate(values[0])

        return self
class JointConditionBinOp(ConditionBinOp):
    """Two sub-conditions joined by this node's operator."""

    def evaluate(self):
        # Combine the already-built conditions of both operands.
        left_cond = self.lhs.condition
        right_cond = self.rhs.condition
        self.condition = f"({left_cond} {self.op} {right_cond})"
        return self
class UnaryOp(ops.UnaryOp):
    """Unary operation in a PyTables expression; only ``~`` is supported."""

    def prune(self, klass):
        """
        Prune the operand for ``klass`` and return its inversion.

        Returns None when the pruned operand is None, or when it has no
        condition (for ConditionBinOp) or filter (for FilterBinOp).

        Raises
        ------
        NotImplementedError
            For any operator other than ``~``.
        """
        if self.op != "~":
            raise NotImplementedError("UnaryOp only support invert type ops")

        pruned = self.operand.prune(klass)
        if pruned is None:
            return None

        if issubclass(klass, ConditionBinOp):
            payload = pruned.condition
        elif issubclass(klass, FilterBinOp):
            payload = pruned.filter
        else:
            return None

        return None if payload is None else pruned.invert()
class PyTablesExprVisitor(BaseExprVisitor):
    """
    Expression visitor for the ``pytables`` parser.

    Binary operators produce ``BinOp`` factories bound to the extra keyword
    arguments given to the constructor.  Names become ``Term`` and literals
    become ``Constant``.
    """

    const_type = Constant
    term_type = Term

    def __init__(self, env, engine, parser, **kwargs):
        super().__init__(env, engine, parser)
        # Install one ``visit_<Node>`` method per binary operator.  Each
        # returns a ``BinOp`` factory pre-bound to the operator and to
        # ``kwargs`` (e.g. ``queryables`` and ``encoding``).  ``bin_op`` is
        # bound as a default argument to avoid late binding in the loop.
        for bin_op in self.binary_ops:
            bin_node = self.binary_op_nodes_map[bin_op]
            setattr(
                self,
                f"visit_{bin_node}",
                lambda node, bin_op=bin_op: partial(BinOp, bin_op, **kwargs),
            )

    def visit_UnaryOp(self, node, **kwargs):
        # ``not`` and ``~`` are both treated as inversion.
        if isinstance(node.op, (ast.Not, ast.Invert)):
            return UnaryOp("~", self.visit(node.operand))
        elif isinstance(node.op, ast.USub):
            # Negation is folded into a new constant.
            return self.const_type(-self.visit(node.operand).value, self.env)
        elif isinstance(node.op, ast.UAdd):
            raise NotImplementedError("Unary addition not supported")

    def visit_Index(self, node, **kwargs):
        # Python >= 3.9 no longer emits ast.Index, so this is only reached
        # on older interpreters.
        return self.visit(node.value).value

    def visit_Assign(self, node, **kwargs):
        # ``a = b`` is interpreted as the comparison ``a == b``.
        cmpr = ast.Compare(
            ops=[ast.Eq()], left=node.targets[0], comparators=[node.value]
        )
        return self.visit(cmpr)

    def visit_Subscript(self, node, **kwargs):
        """Evaluate a simple subscript such as ``df.index[3]`` to a Constant."""
        value = self.visit(node.value)
        slobj = self.visit(node.slice)
        try:
            value = value.value
        except AttributeError:
            pass

        if isinstance(slobj, Term):
            # On Python >= 3.9 the subscript is not wrapped in ast.Index, so
            # visit_Index is bypassed and the index arrives as a Term;
            # unwrap it to the underlying value.
            slobj = slobj.value

        try:
            return self.const_type(value[slobj], self.env)
        except TypeError as err:
            raise ValueError(
                f"cannot subscript {repr(value)} with {repr(slobj)}"
            ) from err

    def visit_Attribute(self, node, **kwargs):
        attr = node.attr
        value = node.value

        ctx = type(node.ctx)
        if ctx == ast.Load:
            # resolve the value
            resolved = self.visit(value)

            # try to get the value to see if we are another expression
            try:
                resolved = resolved.value
            except AttributeError:
                pass

            try:
                return self.term_type(getattr(resolved, attr), self.env)
            except AttributeError:

                # something like datetime.datetime where scope is overridden
                if isinstance(value, ast.Name) and value.id == attr:
                    return resolved

        raise ValueError(f"Invalid Attribute context {ctx.__name__}")

    def translate_In(self, op):
        # ``in`` is evaluated as equality; other operators pass through.
        return ast.Eq() if isinstance(op, ast.In) else op

    def _rewrite_membership_op(self, node, left, right):
        # No rewriting: return the visited operator and operands unchanged.
        return self.visit(node.op), node.op, left, right
def _validate_where(w):
    """
    Validate that the where statement is of an accepted type.

    Accepted are a string, a PyTablesExpr, or anything ``is_list_like``
    accepts.  Elements of a list-like are not inspected here.

    Parameters
    ----------
    w : str, PyTablesExpr, or list-like of those

    Returns
    -------
    The original ``w``, unchanged, when the check passes.

    Raises
    ------
    TypeError
        If ``w`` is none of the accepted types (e.g. an int or None).
    """
    acceptable = isinstance(w, (PyTablesExpr, str)) or is_list_like(w)
    if acceptable:
        return w
    raise TypeError(
        "where must be passed as a string, PyTablesExpr, "
        "or list-like of PyTablesExpr"
    )
class PyTablesExpr(expr.Expr):
    """
    Hold a pytables-like expression, comprised of possibly multiple 'terms'.

    Parameters
    ----------
    where : string term expression, PyTablesExpr, or list-like of PyTablesExprs
    queryables : a "kinds" map (dict of column name -> kind), or None if column
        is non-indexable
    encoding : an encoding that will encode the query terms

    Returns
    -------
    a PyTablesExpr object

    Examples
    --------

    'index>=date'
    "columns=['A', 'D']"
    'columns=A'
    'columns==A'
    "~(columns=['A','B'])"
    'index>df.index[3] & string="bar"'
    '(index>df.index[3] & index<=df.index[6]) | string="bar'
    "ts>=Timestamp('2012-02-01')"
    "major_axis>=20130101"
    """

    _visitor: Optional[PyTablesExprVisitor]
    env: PyTablesScope

    def __init__(
        self,
        where,
        queryables: Optional[Dict[str, Any]] = None,
        encoding=None,
        scope_level: int = 0,
    ):

        where = _validate_where(where)

        self.encoding = encoding
        self.condition = None
        self.filter = None
        self.terms = None
        self._visitor = None

        # capture the environment if needed
        local_dict: DeepChainMap[Any, Any] = DeepChainMap()

        if isinstance(where, PyTablesExpr):
            local_dict = where.env.scope
            _where = where.expr

        elif is_list_like(where):
            # Accept every list-like that _validate_where accepts, not only
            # list/tuple; otherwise the value would be stored unparsed and
            # only fail later in ``evaluate``.
            where = list(where)
            for idx, w in enumerate(where):
                if isinstance(w, PyTablesExpr):
                    local_dict = w.env.scope
                else:
                    where[idx] = _validate_where(w)
            # All parts are AND-ed together.
            _where = " & ".join(f"({w})" for w in com.flatten(where))
        else:
            _where = where

        self.expr = _where
        self.env = PyTablesScope(scope_level + 1, local_dict=local_dict)

        # Terms are only parsed for string expressions with queryables.
        if queryables is not None and isinstance(self.expr, str):
            self.env.queryables.update(queryables)
            self._visitor = PyTablesExprVisitor(
                self.env,
                queryables=queryables,
                parser="pytables",
                engine="pytables",
                encoding=encoding,
            )
            self.terms = self.parse()

    def __repr__(self) -> str:
        if self.terms is not None:
            return pprint_thing(self.terms)
        return pprint_thing(self.expr)

    def evaluate(self):
        """
        Create and return the numexpr condition and the filter.

        Raises
        ------
        ValueError
            If the parsed terms cannot be pruned into a condition or a
            filter (including when no terms were parsed).
        """

        try:
            self.condition = self.terms.prune(ConditionBinOp)
        except AttributeError as err:
            raise ValueError(
                f"cannot process expression [{self.expr}], [{self}] "
                "is not a valid condition"
            ) from err
        try:
            self.filter = self.terms.prune(FilterBinOp)
        except AttributeError as err:
            raise ValueError(
                f"cannot process expression [{self.expr}], [{self}] "
                "is not a valid filter"
            ) from err

        return self.condition, self.filter
class TermValue:
    """
    A right-hand side value paired with its condition-string representation.

    ``value`` is the (converted) value and ``converted`` is what gets
    rendered into the condition.  ``kind`` selects the rendering.
    """

    def __init__(self, value, converted, kind: str):
        assert isinstance(kind, str), kind
        self.value, self.converted, self.kind = value, converted, kind

    def tostring(self, encoding) -> str:
        """
        Render ``converted`` for use in a condition string.

        - Strings are wrapped in double quotes unless an ``encoding`` is
          given, in which case they are returned as-is.
        - Floats use ``repr`` so the text round-trips.
        - Everything else uses ``str``.
        """
        kind = self.kind
        if kind == "float":
            return repr(self.converted)
        if kind != "string":
            return str(self.converted)
        if encoding is None:
            return f'"{self.converted}"'
        return str(self.converted)
def maybe_expression(s) -> bool:
    """
    Loosely check whether ``s`` looks like a PyTables expression.

    True only for strings containing at least one of the visitor's binary
    or unary operators, or ``=``.
    """
    if isinstance(s, str):
        candidates = (
            PyTablesExprVisitor.binary_ops
            + PyTablesExprVisitor.unary_ops
            + ("=",)
        )
        return any(token in s for token in candidates)
    return False