# This file is part of Patsy
# Copyright (C) 2011 Nathaniel Smith <njs@pobox.com>
# See file LICENSE.txt for license information.

# Utilities that require an over-intimate knowledge of Python's execution
# environment.

# NB: if you add any __future__ imports to this file then you'll have to
# adjust the tests that deal with checking the caller's execution environment
# for __future__ flags!

# These are made available in the patsy.* namespace
__all__ = ["EvalEnvironment", "EvalFactor"]

import sys
import __future__
import inspect
import tokenize
import ast
import numbers
import six
from patsy import PatsyError
from patsy.util import PushbackAdapter, no_pickling, assert_no_pickling
from patsy.tokens import (pretty_untokenize, normalize_token_spacing,
                          python_tokenize)
from patsy.compat import call_and_wrap_exc

def _all_future_flags():
    flags = 0
    for feature_name in __future__.all_feature_names:
        feature = getattr(__future__, feature_name)
        # Only collect flags for features that are still optional on this
        # interpreter; mandatory features need no flag.
        if feature.getMandatoryRelease() > sys.version_info:
            flags |= feature.compiler_flag
    return flags

_ALL_FUTURE_FLAGS = _all_future_flags()
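
# Editor's illustrative sketch (not from the patsy sources): the flags
# collected above are the same bit values that the built-in compile() accepts,
# which is how EvalEnvironment can re-apply a caller's __future__ features to
# a single expression. The feature used here is an assumption, chosen because
# it exists on both Python 2 and 3.
def _example_future_compiler_flag():  # pragma: no cover
    # __future__.division is mandatory on Python 3, but its compiler_flag is
    # still accepted by compile() and is harmlessly redundant there.
    flag = __future__.division.compiler_flag
    code = compile("3 / 2", "<demo>", "eval", flag, True)
    assert eval(code) == 1.5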

# This is just a minimal dict-like object that does lookup in a 'stack' of
# dicts -- first it checks the first, then the second, etc. Assignments go
# into an internal, zeroth dict.
class VarLookupDict(object):
    def __init__(self, dicts):
        self._dicts = [{}] + list(dicts)

    def __getitem__(self, key):
        for d in self._dicts:
            try:
                return d[key]
            except KeyError:
                pass
        raise KeyError(key)

    def __setitem__(self, key, value):
        self._dicts[0][key] = value

    def __contains__(self, key):
        try:
            self[key]
        except KeyError:
            return False
        else:
            return True

    def get(self, key, default=None):
        try:
            return self[key]
        except KeyError:
            return default

    def __repr__(self):
        return "%s(%r)" % (self.__class__.__name__, self._dicts)

    __getstate__ = no_pickling

def test_VarLookupDict():
    d1 = {"a": 1}
    d2 = {"a": 2, "b": 3}
    ds = VarLookupDict([d1, d2])
    assert ds["a"] == 1
    assert ds["b"] == 3
    assert "a" in ds
    assert "c" not in ds
    from nose.tools import assert_raises
    assert_raises(KeyError, ds.__getitem__, "c")
    ds["a"] = 10
    assert ds["a"] == 10
    assert d1["a"] == 1
    assert ds.get("c") is None
    assert isinstance(repr(ds), six.string_types)

    assert_no_pickling(ds)

def ast_names(code):
    """Iterator that yields all the (ast) names in a Python expression.

    :arg code: A string containing a Python expression.
    """
    # Syntax that allows new name bindings to be introduced is tricky to
    # handle here, so we just refuse to do so.
    disallowed_ast_nodes = (ast.Lambda, ast.ListComp, ast.GeneratorExp)
    if sys.version_info >= (2, 7):
        disallowed_ast_nodes += (ast.DictComp, ast.SetComp)

    for node in ast.walk(ast.parse(code)):
        if isinstance(node, disallowed_ast_nodes):
            raise PatsyError("Lambda, list/dict/set comprehension, generator "
                             "expression in patsy formula not currently supported.")
        if isinstance(node, ast.Name):
            yield node.id

def test_ast_names():
    test_data = [('np.log(x)', ['np', 'x']),
                 ('x', ['x']),
                 ('center(x + 1)', ['center', 'x']),
                 ('dt.date.dt.month', ['dt'])]
    for code, expected in test_data:
        assert set(ast_names(code)) == set(expected)

def test_ast_names_disallowed_nodes():
    from nose.tools import assert_raises
    def list_ast_names(code):
        return list(ast_names(code))
    assert_raises(PatsyError, list_ast_names, "lambda x: x + y")
    assert_raises(PatsyError, list_ast_names, "[x + 1 for x in range(10)]")
    assert_raises(PatsyError, list_ast_names, "(x + 1 for x in range(10))")
    if sys.version_info >= (2, 7):
        assert_raises(PatsyError, list_ast_names, "{x: True for x in range(10)}")
        assert_raises(PatsyError, list_ast_names, "{x + 1 for x in range(10)}")

class EvalEnvironment(object):
    """Represents a Python execution environment.

    Encapsulates a namespace for variable lookup and set of __future__
    flags."""
    def __init__(self, namespaces, flags=0):
        assert not flags & ~_ALL_FUTURE_FLAGS
        self._namespaces = list(namespaces)
        self.flags = flags

    @property
    def namespace(self):
        """A dict-like object that can be used to look up variables accessible
        from the encapsulated environment."""
        return VarLookupDict(self._namespaces)

    def with_outer_namespace(self, outer_namespace):
        """Return a new EvalEnvironment with an extra namespace added.

        This namespace will be used only for variables that are not found in
        any existing namespace, i.e., it is "outside" them all."""
        return self.__class__(self._namespaces + [outer_namespace],
                              self.flags)

    def eval(self, expr, source_name="<string>", inner_namespace={}):
        """Evaluate some Python code in the encapsulated environment.

        :arg expr: A string containing a Python expression.
        :arg source_name: A name for this string, for use in tracebacks.
        :arg inner_namespace: A dict-like object that will be checked first
          when `expr` attempts to access any variables.
        :returns: The value of `expr`.
        """
        code = compile(expr, source_name, "eval", self.flags, False)
        return eval(code, {}, VarLookupDict([inner_namespace]
                                            + self._namespaces))

    @classmethod
    def capture(cls, eval_env=0, reference=0):
        """Capture an execution environment from the stack.

        If `eval_env` is already an :class:`EvalEnvironment`, it is returned
        unchanged. Otherwise, we walk up the stack by ``eval_env + reference``
        steps and capture that function's evaluation environment.

        For ``eval_env=0`` and ``reference=0``, the default, this captures the
        stack frame of the function that calls :meth:`capture`. If ``eval_env
        + reference`` is 1, then we capture that function's caller, etc.

        This somewhat complicated calling convention is designed to be
        convenient for functions which want to capture their caller's
        environment by default, but also allow explicit environments to be
        specified. See the second example.

        Example::

          x = 1
          this_env = EvalEnvironment.capture()
          assert this_env.namespace["x"] == 1
          def child_func():
              return EvalEnvironment.capture(1)
          this_env_from_child = child_func()
          assert this_env_from_child.namespace["x"] == 1

        Example::

          # This function can be used like:
          #   my_model(formula_like, data)
          #     -> evaluates formula_like in caller's environment
          #   my_model(formula_like, data, eval_env=1)
          #     -> evaluates formula_like in caller's caller's environment
          #   my_model(formula_like, data, eval_env=my_env)
          #     -> evaluates formula_like in environment 'my_env'
          def my_model(formula_like, data, eval_env=0):
              eval_env = EvalEnvironment.capture(eval_env, reference=1)
              return model_setup_helper(formula_like, data, eval_env)

        This is how :func:`dmatrix` works.

        .. versionadded:: 0.2.0
           The ``reference`` argument.
        """
        if isinstance(eval_env, cls):
            return eval_env
        elif isinstance(eval_env, numbers.Integral):
            depth = eval_env + reference
        else:
            raise TypeError("Parameter 'eval_env' must be either an integer "
                            "or an instance of patsy.EvalEnvironment.")
        frame = inspect.currentframe()
        try:
            for i in range(depth + 1):
                if frame is None:
                    raise ValueError("call-stack is not that deep!")
                frame = frame.f_back
            return cls([frame.f_locals, frame.f_globals],
                       frame.f_code.co_flags & _ALL_FUTURE_FLAGS)
        # The try/finally is important to avoid a potential reference cycle --
        # any exception traceback will carry a reference to *our* frame, which
        # contains a reference to our local variables, which would otherwise
        # carry a reference to some parent frame, where the exception was
        # caught...:
        finally:
            del frame

    def subset(self, names):
        """Creates a new, flat EvalEnvironment that contains only
        the variables specified."""
        vld = VarLookupDict(self._namespaces)
        new_ns = dict((name, vld[name]) for name in names)
        return EvalEnvironment([new_ns], self.flags)

    def _namespace_ids(self):
        return [id(n) for n in self._namespaces]

    def __eq__(self, other):
        return (isinstance(other, EvalEnvironment)
                and self.flags == other.flags
                and self._namespace_ids() == other._namespace_ids())

    def __ne__(self, other):
        return not self == other

    def __hash__(self):
        return hash((EvalEnvironment,
                     self.flags,
                     tuple(self._namespace_ids())))

    __getstate__ = no_pickling
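
# Editor's illustrative sketch (not from the patsy sources): a typical round
# trip through capture() -> subset() -> eval(). The function name and the
# variables 'weight' and 'x' are made up for illustration.
def _example_eval_environment_usage():  # pragma: no cover
    weight = 2
    env = EvalEnvironment.capture()   # captures this function's frame
    env = env.subset(["weight"])      # keep only the names we need
    # inner_namespace is consulted before the captured namespaces, so data
    # columns can be supplied (or shadow locals) at evaluation time.
    assert env.eval("weight * x", inner_namespace={"x": 10}) == 20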

def _a():  # pragma: no cover
    _a = 1
    return _b()

def _b():  # pragma: no cover
    _b = 1
    return _c()

def _c():  # pragma: no cover
    _c = 1
    return [EvalEnvironment.capture(),
            EvalEnvironment.capture(0),
            EvalEnvironment.capture(1),
            EvalEnvironment.capture(0, reference=1),
            EvalEnvironment.capture(2),
            EvalEnvironment.capture(0, 2),
            ]

def test_EvalEnvironment_capture_namespace():
    c0, c, b1, b2, a1, a2 = _a()
    assert "test_EvalEnvironment_capture_namespace" in c0.namespace
    assert "test_EvalEnvironment_capture_namespace" in c.namespace
    assert "test_EvalEnvironment_capture_namespace" in b1.namespace
    assert "test_EvalEnvironment_capture_namespace" in b2.namespace
    assert "test_EvalEnvironment_capture_namespace" in a1.namespace
    assert "test_EvalEnvironment_capture_namespace" in a2.namespace
    assert c0.namespace["_c"] == 1
    assert c.namespace["_c"] == 1
    assert b1.namespace["_b"] == 1
    assert b2.namespace["_b"] == 1
    assert a1.namespace["_a"] == 1
    assert a2.namespace["_a"] == 1
    assert b1.namespace["_c"] is _c
    assert b2.namespace["_c"] is _c
    from nose.tools import assert_raises
    assert_raises(ValueError, EvalEnvironment.capture, 10 ** 6)

    assert EvalEnvironment.capture(b1) is b1

    assert_raises(TypeError, EvalEnvironment.capture, 1.2)

    assert_no_pickling(EvalEnvironment.capture())

def test_EvalEnvironment_capture_flags():
    if sys.version_info >= (3,):
        # This is the only __future__ feature currently usable in Python
        # 3... fortunately it is probably not going anywhere.
        TEST_FEATURE = "barry_as_FLUFL"
    else:
        TEST_FEATURE = "division"
    test_flag = getattr(__future__, TEST_FEATURE).compiler_flag
    assert test_flag & _ALL_FUTURE_FLAGS
    source = ("def f():\n"
              "    in_f = 'hi from f'\n"
              "    global RETURN_INNER, RETURN_OUTER, RETURN_INNER_FROM_OUTER\n"
              "    RETURN_INNER = EvalEnvironment.capture(0)\n"
              "    RETURN_OUTER = call_capture_0()\n"
              "    RETURN_INNER_FROM_OUTER = call_capture_1()\n"
              "f()\n")
    code = compile(source, "<test string>", "exec", 0, 1)
    env = {"EvalEnvironment": EvalEnvironment,
           "call_capture_0": lambda: EvalEnvironment.capture(0),
           "call_capture_1": lambda: EvalEnvironment.capture(1),
           }
    env2 = dict(env)
    six.exec_(code, env)
    assert env["RETURN_INNER"].namespace["in_f"] == "hi from f"
    assert env["RETURN_INNER_FROM_OUTER"].namespace["in_f"] == "hi from f"
    assert "in_f" not in env["RETURN_OUTER"].namespace
    assert env["RETURN_INNER"].flags & _ALL_FUTURE_FLAGS == 0
    assert env["RETURN_OUTER"].flags & _ALL_FUTURE_FLAGS == 0
    assert env["RETURN_INNER_FROM_OUTER"].flags & _ALL_FUTURE_FLAGS == 0

    code2 = compile(("from __future__ import %s\n" % (TEST_FEATURE,))
                    + source,
                    "<test string 2>", "exec", 0, 1)
    six.exec_(code2, env2)
    assert env2["RETURN_INNER"].namespace["in_f"] == "hi from f"
    assert env2["RETURN_INNER_FROM_OUTER"].namespace["in_f"] == "hi from f"
    assert "in_f" not in env2["RETURN_OUTER"].namespace
    assert env2["RETURN_INNER"].flags & _ALL_FUTURE_FLAGS == test_flag
    assert env2["RETURN_OUTER"].flags & _ALL_FUTURE_FLAGS == 0
    assert env2["RETURN_INNER_FROM_OUTER"].flags & _ALL_FUTURE_FLAGS == test_flag

def test_EvalEnvironment_eval_namespace():
    env = EvalEnvironment([{"a": 1}])
    assert env.eval("2 * a") == 2
    assert env.eval("2 * a", inner_namespace={"a": 2}) == 4
    from nose.tools import assert_raises
    assert_raises(NameError, env.eval, "2 * b")
    a = 3
    env2 = EvalEnvironment.capture(0)
    assert env2.eval("2 * a") == 6

    env3 = env.with_outer_namespace({"a": 10, "b": 3})
    assert env3.eval("2 * a") == 2
    assert env3.eval("2 * b") == 6

def test_EvalEnvironment_eval_flags():
    from nose.tools import assert_raises
    if sys.version_info >= (3,):
        # This joke __future__ statement replaces "!=" with "<>":
        #   http://www.python.org/dev/peps/pep-0401/
        test_flag = __future__.barry_as_FLUFL.compiler_flag
        assert test_flag & _ALL_FUTURE_FLAGS

        env = EvalEnvironment([{"a": 11}], flags=0)
        assert env.eval("a != 0") == True
        assert_raises(SyntaxError, env.eval, "a <> 0")
        assert env.subset(["a"]).flags == 0
        assert env.with_outer_namespace({"b": 10}).flags == 0

        env2 = EvalEnvironment([{"a": 11}], flags=test_flag)
        assert env2.eval("a <> 0") == True
        assert_raises(SyntaxError, env2.eval, "a != 0")
        assert env2.subset(["a"]).flags == test_flag
        assert env2.with_outer_namespace({"b": 10}).flags == test_flag
    else:
        test_flag = __future__.division.compiler_flag
        assert test_flag & _ALL_FUTURE_FLAGS

        env = EvalEnvironment([{"a": 11}], flags=0)
        assert env.eval("a / 2") == 11 // 2 == 5
        assert env.subset(["a"]).flags == 0
        assert env.with_outer_namespace({"b": 10}).flags == 0

        env2 = EvalEnvironment([{"a": 11}], flags=test_flag)
        assert env2.eval("a / 2") == 11 * 1. / 2 != 5
        # The flags must survive subsetting as well.
        assert env2.subset(["a"]).flags == test_flag
        assert env2.with_outer_namespace({"b": 10}).flags == test_flag

def test_EvalEnvironment_subset():
    env = EvalEnvironment([{"a": 1}, {"b": 2}, {"c": 3}])

    subset_a = env.subset(["a"])
    assert subset_a.eval("a") == 1
    from nose.tools import assert_raises
    assert_raises(NameError, subset_a.eval, "b")
    assert_raises(NameError, subset_a.eval, "c")

    subset_bc = env.subset(["b", "c"])
    assert subset_bc.eval("b * c") == 6
    assert_raises(NameError, subset_bc.eval, "a")

def test_EvalEnvironment_eq():
    # Two environments are eq only if they refer to exactly the same
    # global/local dicts
    env1 = EvalEnvironment.capture(0)
    env2 = EvalEnvironment.capture(0)
    assert env1 == env2
    assert hash(env1) == hash(env2)
    capture_local_env = lambda: EvalEnvironment.capture(0)
    env3 = capture_local_env()
    env4 = capture_local_env()
    assert env3 != env4

_builtins_dict = {}
six.exec_("from patsy.builtins import *", {}, _builtins_dict)
# This is purely to make the existence of patsy.builtins visible to systems
# like py2app and py2exe. It's basically free, since the above line guarantees
# that patsy.builtins will be present in sys.modules in any case.
import patsy.builtins

class EvalFactor(object):
    def __init__(self, code, origin=None):
        """A factor class that executes arbitrary Python code and supports
        stateful transforms.

        :arg code: A string containing a Python expression, that will be
          evaluated to produce this factor's value.

        This is the standard factor class that is used when parsing formula
        strings and implements the standard stateful transform processing. See
        :ref:`stateful-transforms` and :ref:`expert-model-specification`.

        Two EvalFactors are considered equal (e.g., for purposes of
        redundancy detection) if they contain the same token stream. Basically
        this means that the source code must be identical except for
        whitespace::

          assert EvalFactor("a + b") == EvalFactor("a+b")
          assert EvalFactor("a + b") != EvalFactor("b + a")
        """

        # For parsed formulas, the code will already have been normalized by
        # the parser. But let's normalize anyway, so we can be sure of having
        # consistent semantics for __eq__ and __hash__.
        self.code = normalize_token_spacing(code)
        self.origin = origin

    def name(self):
        return self.code

    def __repr__(self):
        return "%s(%r)" % (self.__class__.__name__, self.code)

    def __eq__(self, other):
        return (isinstance(other, EvalFactor)
                and self.code == other.code)

    def __ne__(self, other):
        return not self == other

    def __hash__(self):
        return hash((EvalFactor, self.code))

    def memorize_passes_needed(self, state, eval_env):
        # 'state' is just an empty dict which we can do whatever we want with,
        # and that will be passed back to later memorize functions
        state["transforms"] = {}

        eval_env = eval_env.with_outer_namespace(_builtins_dict)
        env_namespace = eval_env.namespace
        subset_names = [name for name in ast_names(self.code)
                        if name in env_namespace]
        eval_env = eval_env.subset(subset_names)
        state["eval_env"] = eval_env

        # example code: == "2 * center(x)"
        i = [0]
        def new_name_maker(token):
            value = eval_env.namespace.get(token)
            if hasattr(value, "__patsy_stateful_transform__"):
                obj_name = "_patsy_stobj%s__%s__" % (i[0], token)
                i[0] += 1
                obj = value.__patsy_stateful_transform__()
                state["transforms"][obj_name] = obj
                return obj_name + ".transform"
            else:
                return token
        # example eval_code: == "2 * _patsy_stobj0__center__.transform(x)"
        eval_code = replace_bare_funcalls(self.code, new_name_maker)
        state["eval_code"] = eval_code
        # paranoia: verify that none of our new names appeared anywhere in the
        # original code
        if has_bare_variable_reference(state["transforms"], self.code):
            raise PatsyError("names of this form are reserved for "
                             "internal use (%s)" % (self.code,), self)
        # Pull out all the '_patsy_stobj0__center__.transform(x)' pieces
        # to make '_patsy_stobj0__center__.memorize_chunk(x)' pieces
        state["memorize_code"] = {}
        for obj_name in state["transforms"]:
            transform_calls = capture_obj_method_calls(obj_name, eval_code)
            assert len(transform_calls) == 1
            transform_call = transform_calls[0]
            transform_call_name, transform_call_code = transform_call
            assert transform_call_name == obj_name + ".transform"
            assert transform_call_code.startswith(transform_call_name + "(")
            memorize_code = (obj_name
                             + ".memorize_chunk"
                             + transform_call_code[len(transform_call_name):])
            state["memorize_code"][obj_name] = memorize_code
        # Then sort the codes into bins, so that every item in bin number i
        # depends only on items in bin (i-1) or less. (By 'depends', we mean
        # that in something like:
        #   spline(center(x))
        # we have to first run:
        #   center.memorize_chunk(x)
        # then
        #   center.memorize_finish(x)
        # and only then can we run:
        #   spline.memorize_chunk(center.transform(x))
        # Since all of our objects have unique names, figuring out who
        # depends on who is pretty easy -- we just check whether the
        # memorization code for spline:
        #   spline.memorize_chunk(center.transform(x))
        # mentions the variable 'center' (which in the example, of course, it
        # does).
        pass_bins = []
        unsorted = set(state["transforms"])
        while unsorted:
            pass_bin = set()
            for obj_name in unsorted:
                other_objs = unsorted.difference([obj_name])
                memorize_code = state["memorize_code"][obj_name]
                if not has_bare_variable_reference(other_objs, memorize_code):
                    pass_bin.add(obj_name)
            assert pass_bin
            unsorted.difference_update(pass_bin)
            pass_bins.append(pass_bin)
        state["pass_bins"] = pass_bins

        return len(pass_bins)

    def _eval(self, code, memorize_state, data):
        inner_namespace = VarLookupDict([data, memorize_state["transforms"]])
        return call_and_wrap_exc("Error evaluating factor",
                                 self,
                                 memorize_state["eval_env"].eval,
                                 code,
                                 inner_namespace=inner_namespace)

    def memorize_chunk(self, state, which_pass, data):
        for obj_name in state["pass_bins"][which_pass]:
            self._eval(state["memorize_code"][obj_name],
                       state,
                       data)

    def memorize_finish(self, state, which_pass):
        for obj_name in state["pass_bins"][which_pass]:
            state["transforms"][obj_name].memorize_finish()

    def eval(self, memorize_state, data):
        return self._eval(memorize_state["eval_code"],
                          memorize_state,
                          data)

    __getstate__ = no_pickling
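
# Editor's illustrative sketch (not from the patsy sources): the pass-bin loop
# in memorize_passes_needed() is a small topological sort keyed on "does this
# object's memorize code mention another not-yet-scheduled object?". The same
# loop, run standalone on hand-written memorize code for the hypothetical
# factor "spline(center(x))":
def _example_pass_binning():  # pragma: no cover
    memorize_code = {
        "_stobj0__spline__":
            "_stobj0__spline__.memorize_chunk(_stobj1__center__.transform(x))",
        "_stobj1__center__": "_stobj1__center__.memorize_chunk(x)",
    }
    pass_bins = []
    unsorted = set(memorize_code)
    while unsorted:
        pass_bin = set(obj_name for obj_name in unsorted
                       if not has_bare_variable_reference(
                           unsorted.difference([obj_name]),
                           memorize_code[obj_name]))
        assert pass_bin
        unsorted.difference_update(pass_bin)
        pass_bins.append(pass_bin)
    # center must be fully memorized (bin 0) before spline can see its output.
    assert pass_bins == [set(["_stobj1__center__"]), set(["_stobj0__spline__"])]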

def test_EvalFactor_basics():
    e = EvalFactor("a+b")
    assert e.code == "a + b"
    assert e.name() == "a + b"
    e2 = EvalFactor("a +b", origin="asdf")
    assert e == e2
    assert hash(e) == hash(e2)
    assert e.origin is None
    assert e2.origin == "asdf"

    assert_no_pickling(e)

def test_EvalFactor_memorize_passes_needed():
    from patsy.state import stateful_transform
    foo = stateful_transform(lambda: "FOO-OBJ")
    bar = stateful_transform(lambda: "BAR-OBJ")
    quux = stateful_transform(lambda: "QUUX-OBJ")
    e = EvalFactor("foo(x) + bar(foo(y)) + quux(z, w)")

    state = {}
    eval_env = EvalEnvironment.capture(0)
    passes = e.memorize_passes_needed(state, eval_env)
    print(passes)
    print(state)
    assert passes == 2
    for name in ["foo", "bar", "quux"]:
        assert state["eval_env"].namespace[name] is locals()[name]
    for name in ["w", "x", "y", "z", "e", "state"]:
        assert name not in state["eval_env"].namespace
    assert state["transforms"] == {"_patsy_stobj0__foo__": "FOO-OBJ",
                                   "_patsy_stobj1__bar__": "BAR-OBJ",
                                   "_patsy_stobj2__foo__": "FOO-OBJ",
                                   "_patsy_stobj3__quux__": "QUUX-OBJ"}
    assert (state["eval_code"]
            == "_patsy_stobj0__foo__.transform(x)"
               " + _patsy_stobj1__bar__.transform("
               "_patsy_stobj2__foo__.transform(y))"
               " + _patsy_stobj3__quux__.transform(z, w)")

    assert (state["memorize_code"]
            == {"_patsy_stobj0__foo__":
                    "_patsy_stobj0__foo__.memorize_chunk(x)",
                "_patsy_stobj1__bar__":
                    "_patsy_stobj1__bar__.memorize_chunk(_patsy_stobj2__foo__.transform(y))",
                "_patsy_stobj2__foo__":
                    "_patsy_stobj2__foo__.memorize_chunk(y)",
                "_patsy_stobj3__quux__":
                    "_patsy_stobj3__quux__.memorize_chunk(z, w)",
                })
    assert state["pass_bins"] == [set(["_patsy_stobj0__foo__",
                                       "_patsy_stobj2__foo__",
                                       "_patsy_stobj3__quux__"]),
                                  set(["_patsy_stobj1__bar__"])]

class _MockTransform(object):
    # Adds up all memorized data, then subtracts that sum from each datum
    def __init__(self):
        self._sum = 0
        self._memorize_chunk_called = 0
        self._memorize_finish_called = 0

    def memorize_chunk(self, data):
        self._memorize_chunk_called += 1
        import numpy as np
        self._sum += np.sum(data)

    def memorize_finish(self):
        self._memorize_finish_called += 1

    def transform(self, data):
        return data - self._sum

def test_EvalFactor_end_to_end():
    from patsy.state import stateful_transform
    foo = stateful_transform(_MockTransform)
    e = EvalFactor("foo(x) + foo(foo(y))")
    state = {}
    eval_env = EvalEnvironment.capture(0)
    passes = e.memorize_passes_needed(state, eval_env)
    print(passes)
    print(state)
    assert passes == 2
    assert state["eval_env"].namespace["foo"] is foo
    for name in ["x", "y", "e", "state"]:
        assert name not in state["eval_env"].namespace
    import numpy as np
    e.memorize_chunk(state, 0,
                     {"x": np.array([1, 2]),
                      "y": np.array([10, 11])})
    assert state["transforms"]["_patsy_stobj0__foo__"]._memorize_chunk_called == 1
    assert state["transforms"]["_patsy_stobj2__foo__"]._memorize_chunk_called == 1
    e.memorize_chunk(state, 0, {"x": np.array([12, -10]),
                                "y": np.array([100, 3])})
    assert state["transforms"]["_patsy_stobj0__foo__"]._memorize_chunk_called == 2
    assert state["transforms"]["_patsy_stobj2__foo__"]._memorize_chunk_called == 2
    assert state["transforms"]["_patsy_stobj0__foo__"]._memorize_finish_called == 0
    assert state["transforms"]["_patsy_stobj2__foo__"]._memorize_finish_called == 0
    e.memorize_finish(state, 0)
    assert state["transforms"]["_patsy_stobj0__foo__"]._memorize_finish_called == 1
    assert state["transforms"]["_patsy_stobj2__foo__"]._memorize_finish_called == 1
    assert state["transforms"]["_patsy_stobj1__foo__"]._memorize_chunk_called == 0
    assert state["transforms"]["_patsy_stobj1__foo__"]._memorize_finish_called == 0
    e.memorize_chunk(state, 1, {"x": np.array([1, 2]),
                                "y": np.array([10, 11])})
    e.memorize_chunk(state, 1, {"x": np.array([12, -10]),
                                "y": np.array([100, 3])})
    e.memorize_finish(state, 1)
    for transform in six.itervalues(state["transforms"]):
        assert transform._memorize_chunk_called == 2
        assert transform._memorize_finish_called == 1
    # sums memorized by each transform object:
    #   stobj0 (foo(x)):      1 + 2 + 12 + -10 == 5
    #   stobj2 (foo(y)):      10 + 11 + 100 + 3 == 124
    #   stobj1 (foo(foo(y))): (10 - 124) + (11 - 124)
    #                          + (100 - 124) + (3 - 124) == -372
    # transform results:
    #   stobj0: -4, -3, 7, -15
    #   stobj2: -114, -113, -24, -121
    #   stobj1: 258, 259, 348, 251
    #   stobj0 + stobj1 (the factor's value): 254, 256, 355, 236
    assert np.all(e.eval(state,
                         {"x": np.array([1, 2, 12, -10]),
                          "y": np.array([10, 11, 100, 3])})
                  == [254, 256, 355, 236])

# Tag each token with whether it is a "bare" variable reference (a NAME that
# is not part of an attribute lookup) and whether it is a bare function call.
def annotated_tokens(code):
    prev_was_dot = False
    it = PushbackAdapter(python_tokenize(code))
    for (token_type, token, origin) in it:
        props = {}
        props["bare_ref"] = (not prev_was_dot and token_type == tokenize.NAME)
        props["bare_funcall"] = (props["bare_ref"]
                                 and it.has_more() and it.peek()[1] == "(")
        yield (token_type, token, origin, props)
        prev_was_dot = (token == ".")

def test_annotated_tokens():
    tokens_without_origins = [(token_type, token, props)
                              for (token_type, token, origin, props)
                              in (annotated_tokens("a(b) + c.d"))]
    assert (tokens_without_origins
            == [(tokenize.NAME, "a", {"bare_ref": True, "bare_funcall": True}),
                (tokenize.OP, "(", {"bare_ref": False, "bare_funcall": False}),
                (tokenize.NAME, "b", {"bare_ref": True, "bare_funcall": False}),
                (tokenize.OP, ")", {"bare_ref": False, "bare_funcall": False}),
                (tokenize.OP, "+", {"bare_ref": False, "bare_funcall": False}),
                (tokenize.NAME, "c", {"bare_ref": True, "bare_funcall": False}),
                (tokenize.OP, ".", {"bare_ref": False, "bare_funcall": False}),
                (tokenize.NAME, "d",
                 {"bare_ref": False, "bare_funcall": False}),
                ])

    # This was a bug:
    assert len(list(annotated_tokens("x"))) == 1

def has_bare_variable_reference(names, code):
    for (_, token, _, props) in annotated_tokens(code):
        if props["bare_ref"] and token in names:
            return True
    return False
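
# Editor's illustrative sketch (not from the patsy sources): "bare" means the
# name is not reached through attribute access, so 'obj.x' does not count as a
# reference to 'x'.
def _example_has_bare_variable_reference():  # pragma: no cover
    assert has_bare_variable_reference(["x"], "x + 1")
    assert not has_bare_variable_reference(["x"], "obj.x + 1")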

def replace_bare_funcalls(code, replacer):
    tokens = []
    for (token_type, token, origin, props) in annotated_tokens(code):
        if props["bare_ref"] and props["bare_funcall"]:
            token = replacer(token)
        tokens.append((token_type, token))
    return pretty_untokenize(tokens)

def test_replace_bare_funcalls():
    def replacer1(token):
        return {"a": "b", "foo": "_internal.foo.process"}.get(token, token)
    def t1(code, expected):
        replaced = replace_bare_funcalls(code, replacer1)
        print("%r -> %r" % (code, replaced))
        print("(wanted %r)" % (expected,))
        assert replaced == expected
    t1("foobar()", "foobar()")
    t1("a()", "b()")
    t1("foobar.a()", "foobar.a()")
    t1("foo()", "_internal.foo.process()")
    t1("a + 1", "a + 1")
    t1("b() + a() * x[foo(2 ** 3)]",
       "b() + b() * x[_internal.foo.process(2 ** 3)]")

class _FuncallCapturer(object):
    # captures the next funcall
    def __init__(self, start_token_type, start_token):
        self.func = [start_token]
        self.tokens = [(start_token_type, start_token)]
        self.paren_depth = 0
        self.started = False
        self.done = False

    def add_token(self, token_type, token):
        if self.done:
            return
        self.tokens.append((token_type, token))
        if token in ["(", "{", "["]:
            self.paren_depth += 1
        if token in [")", "}", "]"]:
            self.paren_depth -= 1
        assert self.paren_depth >= 0
        if not self.started:
            if token == "(":
                self.started = True
            else:
                assert token_type == tokenize.NAME or token == "."
                self.func.append(token)
        if self.started and self.paren_depth == 0:
            self.done = True

# This is not a very general function -- it assumes that all references to the
# given object are of the form '<obj_name>.something(method call)'.
def capture_obj_method_calls(obj_name, code):
    capturers = []
    for (token_type, token, origin, props) in annotated_tokens(code):
        for capturer in capturers:
            capturer.add_token(token_type, token)
        if props["bare_ref"] and token == obj_name:
            capturers.append(_FuncallCapturer(token_type, token))
    return [("".join(capturer.func), pretty_untokenize(capturer.tokens))
            for capturer in capturers]

def test_capture_obj_method_calls():
    assert (capture_obj_method_calls("foo", "a + foo.baz(bar) + b.c(d)")
            == [("foo.baz", "foo.baz(bar)")])
    assert (capture_obj_method_calls("b", "a + foo.baz(bar) + b.c(d)")
            == [("b.c", "b.c(d)")])
    assert (capture_obj_method_calls("foo", "foo.bar(foo.baz(quux))")
            == [("foo.bar", "foo.bar(foo.baz(quux))"),
                ("foo.baz", "foo.baz(quux)")])
    assert (capture_obj_method_calls("bar", "foo[bar.baz(x(z[asdf])) ** 2]")
            == [("bar.baz", "bar.baz(x(z[asdf]))")])