Coverage for /home/martinb/.local/share/virtualenvs/camcops/lib/python3.6/site-packages/patsy/design_info.py : 10%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of Patsy
2# Copyright (C) 2011-2015 Nathaniel Smith <njs@pobox.com>
3# See file LICENSE.txt for license information.
5# This file defines the main class for storing metadata about a model
6# design. It also defines a 'value-added' design matrix type -- a subclass of
7# ndarray that represents a design matrix and holds metadata about its
8# columns. The intent is that these are useful and usable data structures
9# even if you're not using *any* of the rest of patsy to actually build
10# your matrices.
13# XX TMP TODO:
14#
15# - update design_matrix_builders and build_design_matrices docs
16# - add tests and docs for new design info stuff
17# - consider renaming design_matrix_builders (and I guess
18# build_design_matrices too). Ditto for highlevel dbuilder functions.
21from __future__ import print_function
23# These are made available in the patsy.* namespace
24__all__ = ["DesignInfo", "FactorInfo", "SubtermInfo", "DesignMatrix"]
26import warnings
27import numbers
28import six
29import numpy as np
30from patsy import PatsyError
31from patsy.util import atleast_2d_column_default
32from patsy.compat import OrderedDict
33from patsy.util import (repr_pretty_delegate, repr_pretty_impl,
34 safe_issubdtype,
35 no_pickling, assert_no_pickling)
36from patsy.constraint import linear_constraint
37from patsy.contrasts import ContrastMatrix
38from patsy.desc import ModelDesc, Term
class FactorInfo(object):
    """Metadata describing the role a single factor plays within a model.

    :attr:`DesignInfo.factor_infos` is a dictionary which maps each factor
    object in the model to the FactorInfo object describing it.

    .. versionadded:: 0.4.0

    Attributes:

    .. attribute:: factor

       The factor object being described.

    .. attribute:: type

       The type of the factor -- either the string ``"numerical"`` or the
       string ``"categorical"``.

    .. attribute:: state

       An opaque object which holds the state needed to evaluate this
       factor on new data (e.g., for prediction). See
       :meth:`factor_protocol.eval`.

    .. attribute:: num_columns

       For numerical factors, the number of columns this factor produces. For
       categorical factors, this attribute will always be ``None``.

    .. attribute:: categories

       For categorical factors, a tuple of the possible categories this factor
       takes on, in order. For numerical factors, this attribute will always be
       ``None``.
    """

    def __init__(self, factor, type, state,
                 num_columns=None, categories=None):
        """Store and validate the metadata for one factor.

        Raises ValueError if `type` is not one of the two known strings, or
        if `num_columns` / `categories` are inconsistent with `type`.
        A categorical factor whose `categories` is not iterable raises
        TypeError (from the ``tuple()`` conversion).
        """
        self.factor = factor
        self.type = type
        if type not in ("numerical", "categorical"):
            raise ValueError("FactorInfo.type must be "
                             "'numerical' or 'categorical', not %r"
                             % (self.type,))
        self.state = state

        if type == "numerical":
            # Numerical factors carry a column count and no categories.
            if not isinstance(num_columns, six.integer_types):
                raise ValueError("For numerical factors, num_columns "
                                 "must be an integer")
            if categories is not None:
                raise ValueError("For numerical factors, categories "
                                 "must be None")
        else:
            # `type` was validated above, so this is the categorical case:
            # categories are required, a column count is forbidden.
            if num_columns is not None:
                raise ValueError("For categorical factors, num_columns "
                                 "must be None")
            categories = tuple(categories)

        self.num_columns = num_columns
        self.categories = categories

    __repr__ = repr_pretty_delegate

    def _repr_pretty_(self, p, cycle):
        """Pretty-printer hook; shows a placeholder instead of the state."""
        assert not cycle

        class _HiddenState(object):
            # Don't put the state in people's faces, it will just encourage
            # them to pay attention to the contents :-). Plus it's a bunch of
            # gobbledygook they don't care about. They can always look at
            # self.state if they want to know...
            def __repr__(self):
                return "<factor state>"

        if self.type == "numerical":
            type_specific = ("num_columns", self.num_columns)
        else:
            type_specific = ("categories", self.categories)

        repr_pretty_impl(p, self, [],
                         [("factor", self.factor),
                          ("type", self.type),
                          ("state", _HiddenState()),
                          type_specific])

    __getstate__ = no_pickling
def test_FactorInfo():
    """Check FactorInfo attributes, repr, and argument validation."""
    from nose.tools import assert_raises

    numerical = FactorInfo("asdf", "numerical", {"a": 1}, num_columns=10)
    assert numerical.factor == "asdf"
    assert numerical.state == {"a": 1}
    assert numerical.type == "numerical"
    assert numerical.num_columns == 10
    assert numerical.categories is None
    # smoke test
    repr(numerical)

    categorical = FactorInfo("asdf", "categorical", {"a": 2},
                             categories=["z", "j"])
    assert categorical.factor == "asdf"
    assert categorical.state == {"a": 2}
    assert categorical.type == "categorical"
    assert categorical.num_columns is None
    assert categorical.categories == ("z", "j")
    # smoke test
    repr(categorical)

    # (expected exception, positional args, keyword args)
    rejected = [
        # unknown factor type
        (ValueError, ("asdf", "non-numerical", {}), {}),
        # numerical: num_columns is required and must be an integer
        (ValueError, ("asdf", "numerical", {}), {}),
        (ValueError, ("asdf", "numerical", {}), {"num_columns": "asdf"}),
        # numerical: categories must stay None
        (ValueError, ("asdf", "numerical", {}),
         {"num_columns": 1, "categories": 1}),
        # categorical: categories must be iterable, num_columns must be None
        (TypeError, ("asdf", "categorical", {}), {}),
        (ValueError, ("asdf", "categorical", {}), {"num_columns": 1}),
        (TypeError, ("asdf", "categorical", {}), {"categories": 1}),
    ]
    for exc_type, args, kwargs in rejected:
        assert_raises(exc_type, FactorInfo, *args, **kwargs)

    # Make sure longs are legal for num_columns
    # (Important on python2+win64, where array shapes are tuples-of-longs)
    if not six.PY3:
        fi_long = FactorInfo("asdf", "numerical", {"a": 1},
                             num_columns=long(10))
        assert fi_long.num_columns == 10
class SubtermInfo(object):
    """A SubtermInfo object is a simple metadata container describing a single
    primitive interaction and how it is coded in our design matrix. Our final
    design matrix is produced by coding each primitive interaction in order
    from left to right, and then stacking the resulting columns. For each
    :class:`Term`, we have one or more of these objects which describe how
    that term is encoded. :attr:`DesignInfo.term_codings` is a dictionary
    which maps term objects to lists of SubtermInfo objects.

    To code a primitive interaction, the following steps are performed:

    * Evaluate each factor on the provided data.
    * Encode each factor into one or more proto-columns. For numerical
      factors, these proto-columns are identical to whatever the factor
      evaluates to; for categorical factors, they are encoded using a
      specified contrast matrix.
    * Form all pairwise, elementwise products between proto-columns generated
      by different factors. (For example, if factor 1 generated proto-columns
      A and B, and factor 2 generated proto-columns C and D, then our final
      columns are ``A * C``, ``B * C``, ``A * D``, ``B * D``.)
    * The resulting columns are stored directly into the final design matrix.

    Sometimes multiple primitive interactions are needed to encode a single
    term; this occurs, for example, in the formula ``"1 + a:b"`` when ``a``
    and ``b`` are categorical. See :ref:`formulas-building` for full details.

    .. versionadded:: 0.4.0

    Attributes:

    .. attribute:: factors

       The factors which appear in this subterm's interaction.

    .. attribute:: contrast_matrices

       A dict mapping factor objects to :class:`ContrastMatrix` objects,
       describing how each categorical factor in this interaction is coded.

    .. attribute:: num_columns

       The number of design matrix columns which this interaction generates.

    """

    def __init__(self, factors, contrast_matrices, num_columns):
        """Store and validate the description of one primitive interaction.

        :arg factors: An iterable of factor objects; stored as a tuple.
        :arg contrast_matrices: A dict mapping factors (each of which must
          appear in `factors`) to :class:`ContrastMatrix` objects.
        :arg num_columns: An integer column count.

        Raises TypeError if `factors` is not iterable, and ValueError for
        any other invalid argument.
        """
        self.factors = tuple(factors)
        # Build the lookup set from the materialized tuple, not from the raw
        # argument: if `factors` is a one-shot iterator it has already been
        # exhausted by tuple() above and would yield an empty set here.
        factor_set = frozenset(self.factors)
        if not isinstance(contrast_matrices, dict):
            raise ValueError("contrast_matrices must be dict")
        for factor, contrast_matrix in six.iteritems(contrast_matrices):
            if factor not in factor_set:
                raise ValueError("Unexpected factor in contrast_matrices dict")
            if not isinstance(contrast_matrix, ContrastMatrix):
                raise ValueError("Expected a ContrastMatrix, not %r"
                                 % (contrast_matrix,))
        self.contrast_matrices = contrast_matrices
        if not isinstance(num_columns, six.integer_types):
            raise ValueError("num_columns must be an integer")
        self.num_columns = num_columns

    __repr__ = repr_pretty_delegate

    def _repr_pretty_(self, p, cycle):
        """Pretty-printer hook listing the three public attributes."""
        assert not cycle
        repr_pretty_impl(p, self, [],
                         [("factors", self.factors),
                          ("contrast_matrices", self.contrast_matrices),
                          ("num_columns", self.num_columns)])

    __getstate__ = no_pickling
def test_SubtermInfo():
    """Check SubtermInfo attributes, repr, and argument validation."""
    from nose.tools import assert_raises

    cm = ContrastMatrix(np.ones((2, 2)), ["[1]", "[2]"])
    factors = ["a", "x"]

    s = SubtermInfo(factors, {"a": cm}, 4)
    assert s.factors == ("a", "x")
    assert s.contrast_matrices == {"a": cm}
    assert s.num_columns == 4

    # Make sure longs are accepted for num_columns
    if not six.PY3:
        s = SubtermInfo(factors, {"a": cm}, long(4))
        assert s.num_columns == 4

    # smoke test
    repr(s)

    # (expected exception, constructor arguments)
    rejected = [
        (TypeError, (1, {}, 1)),                # factors not iterable
        (ValueError, (factors, 1, 1)),          # contrast_matrices not a dict
        (ValueError, (factors, {"z": cm}, 1)),  # unknown factor key
        (ValueError, (factors, {"a": 1}, 1)),   # value not a ContrastMatrix
        (ValueError, (factors, {}, 1.5)),       # non-integer num_columns
    ]
    for exc_type, args in rejected:
        assert_raises(exc_type, SubtermInfo, *args)
class DesignInfo(object):
    """A DesignInfo object holds metadata about a design matrix.

    This is the main object that Patsy uses to pass metadata about a design
    matrix to statistical libraries, in order to allow further downstream
    processing like intelligent tests, prediction on new data, etc. Usually
    encountered as the `.design_info` attribute on design matrices.

    """

    def __init__(self, column_names,
                 factor_infos=None, term_codings=None):
        """Build and validate the metadata for a design matrix.

        :arg column_names: A sequence of unique column names, in order.
        :arg factor_infos: Either None, or a dict mapping each factor used
          by the terms to its :class:`FactorInfo`.
        :arg term_codings: Either None, or an OrderedDict mapping
          :class:`Term` objects to lists of :class:`SubtermInfo` objects.

        `factor_infos` and `term_codings` must be given together or not at
        all. Any inconsistency between the arguments raises ValueError.
        """
        self.column_name_indexes = OrderedDict(zip(column_names,
                                                   range(len(column_names))))
        # Duplicate names silently collapse in the mapping above. Reject
        # them explicitly instead of relying on the internal asserts below,
        # which raise AssertionError and vanish under ``python -O``.
        if len(self.column_name_indexes) != len(column_names):
            raise ValueError("duplicate column names are not allowed")

        if (factor_infos is None) != (term_codings is None):
            raise ValueError("Must specify either both or neither of "
                             "factor_infos= and term_codings=")

        self.factor_infos = factor_infos
        self.term_codings = term_codings

        # factor_infos is a dict containing one entry for every factor
        # mentioned in our terms
        # and mapping each to FactorInfo object
        if self.factor_infos is not None:
            if not isinstance(self.factor_infos, dict):
                raise ValueError("factor_infos should be a dict")

            if not isinstance(self.term_codings, OrderedDict):
                raise ValueError("term_codings must be an OrderedDict")
            for term, subterms in six.iteritems(self.term_codings):
                if not isinstance(term, Term):
                    raise ValueError("expected a Term, not %r" % (term,))
                if not isinstance(subterms, list):
                    raise ValueError("term_codings must contain lists")
                term_factors = set(term.factors)
                for subterm in subterms:
                    if not isinstance(subterm, SubtermInfo):
                        raise ValueError("expected SubtermInfo, "
                                         "not %r" % (subterm,))
                    if not term_factors.issuperset(subterm.factors):
                        raise ValueError("unexpected factors in subterm")

            # The factors used by the terms and the keys of factor_infos
            # must be exactly the same set.
            all_factors = set()
            for term in self.term_codings:
                all_factors.update(term.factors)
            if all_factors != set(self.factor_infos):
                raise ValueError("Provided Term objects and factor_infos "
                                 "do not match")
            for factor, factor_info in six.iteritems(self.factor_infos):
                if not isinstance(factor_info, FactorInfo):
                    raise ValueError("expected FactorInfo object, not %r"
                                     % (factor_info,))
                if factor != factor_info.factor:
                    raise ValueError("mismatched factor_info.factor")

            # Each subterm's declared num_columns must equal the product of
            # the column counts contributed by its factors.
            for term, subterms in six.iteritems(self.term_codings):
                for subterm in subterms:
                    exp_cols = 1
                    cat_factors = set()
                    for factor in subterm.factors:
                        fi = self.factor_infos[factor]
                        if fi.type == "numerical":
                            exp_cols *= fi.num_columns
                        else:
                            assert fi.type == "categorical"
                            cm = subterm.contrast_matrices[factor].matrix
                            if cm.shape[0] != len(fi.categories):
                                raise ValueError("Mismatched contrast matrix "
                                                 "for factor %r" % (factor,))
                            cat_factors.add(factor)
                            exp_cols *= cm.shape[1]
                    if cat_factors != set(subterm.contrast_matrices):
                        raise ValueError("Mismatch between contrast_matrices "
                                         "and categorical factors")
                    if exp_cols != subterm.num_columns:
                        raise ValueError("Unexpected num_columns")

        if term_codings is None:
            # Need to invent term information
            self.term_slices = None
            # We invent one term per column, with the same name as the column
            term_names = column_names
            slices = [slice(i, i + 1) for i in range(len(column_names))]
            self.term_name_slices = OrderedDict(zip(term_names, slices))
        else:
            # Need to derive term information from term_codings
            self.term_slices = OrderedDict()
            idx = 0
            for term, subterm_infos in six.iteritems(self.term_codings):
                term_columns = sum(subterm_info.num_columns
                                   for subterm_info in subterm_infos)
                self.term_slices[term] = slice(idx, idx + term_columns)
                idx += term_columns
            if idx != len(self.column_names):
                raise ValueError("mismatch between column_names and columns "
                                 "coded by given terms")
            self.term_name_slices = OrderedDict(
                [(term.name(), slice_)
                 for (term, slice_) in six.iteritems(self.term_slices)])

        # Guarantees:
        #   term_name_slices is never None
        #   The slices in term_name_slices are in order and exactly cover the
        #     whole range of columns.
        #   term_slices may be None
        #   If term_slices is not None, then its slices match the ones in
        #     term_name_slices.
        assert self.term_name_slices is not None
        if self.term_slices is not None:
            assert (list(self.term_slices.values())
                    == list(self.term_name_slices.values()))
        # These checks probably aren't necessary anymore now that we always
        # generate the slices ourselves, but we'll leave them in just to be
        # safe.
        covered = 0
        for slice_ in six.itervalues(self.term_name_slices):
            start, stop, step = slice_.indices(len(column_names))
            assert start == covered
            assert step == 1
            covered = stop
        assert covered == len(column_names)
        # If there is any name overlap between terms and columns, they refer
        # to the same columns.
        for column_name, index in six.iteritems(self.column_name_indexes):
            if column_name in self.term_name_slices:
                slice_ = self.term_name_slices[column_name]
                if slice_ != slice(index, index + 1):
                    raise ValueError("term/column name collision")

    __repr__ = repr_pretty_delegate

    def _repr_pretty_(self, p, cycle):
        """Pretty-printer hook mirroring the constructor arguments."""
        assert not cycle
        repr_pretty_impl(p, self,
                         [self.column_names],
                         [("factor_infos", self.factor_infos),
                          ("term_codings", self.term_codings)])

    @property
    def column_names(self):
        "A list of the column names, in order."
        return list(self.column_name_indexes)

    @property
    def terms(self):
        "A list of :class:`Term` objects, in order, or else None."
        if self.term_slices is None:
            return None
        return list(self.term_slices)

    @property
    def term_names(self):
        "A list of term names, in order."
        return list(self.term_name_slices)

    @property
    def builder(self):
        ".. deprecated:: 0.4.0"
        warnings.warn(DeprecationWarning(
            "The DesignInfo.builder attribute is deprecated starting in "
            "patsy v0.4.0; distinct builder objects have been eliminated "
            "and design_info.builder is now just a long-winded way of "
            "writing 'design_info' (i.e. the .builder attribute just "
            "returns self)"), stacklevel=2)
        return self

    @property
    def design_info(self):
        ".. deprecated:: 0.4.0"
        warnings.warn(DeprecationWarning(
            "Starting in patsy v0.4.0, the DesignMatrixBuilder class has "
            "been merged into the DesignInfo class. So there's no need to "
            "use builder.design_info to access the DesignInfo; 'builder' "
            "already *is* a DesignInfo."), stacklevel=2)
        return self

    def slice(self, columns_specifier):
        """Locate a subset of design matrix columns, specified symbolically.

        A patsy design matrix has two levels of structure: the individual
        columns (which are named), and the :ref:`terms <formulas>` in
        the formula that generated those columns. This is a one-to-many
        relationship: a single term may span several columns. This method
        provides a user-friendly API for locating those columns.

        (While we talk about columns here, this is probably most useful for
        indexing into other arrays that are derived from the design matrix,
        such as regression coefficients or covariance matrices.)

        The `columns_specifier` argument can take a number of forms:

        * A term name
        * A column name
        * A :class:`Term` object
        * An integer giving a raw index
        * A raw slice object

        In all cases, a Python :func:`slice` object is returned, which can be
        used directly for indexing. An unrecognized specifier raises
        :class:`PatsyError`.

        Example::

          y, X = dmatrices("y ~ a", demo_data("y", "a", nlevels=3))
          betas = np.linalg.lstsq(X, y)[0]
          a_betas = betas[X.design_info.slice("a")]

        (If you want to look up a single individual column by name, use
        ``design_info.column_name_indexes[name]``.)
        """
        if isinstance(columns_specifier, slice):
            return columns_specifier
        # np.issubsctype was removed in NumPy 2.0; np.issubdtype performs the
        # same "is this an integer type" check on all supported versions.
        if np.issubdtype(type(columns_specifier), np.integer):
            return slice(columns_specifier, columns_specifier + 1)
        # Lookup order matters: Term objects first, then term names, then
        # individual column names.
        if (self.term_slices is not None
                and columns_specifier in self.term_slices):
            return self.term_slices[columns_specifier]
        if columns_specifier in self.term_name_slices:
            return self.term_name_slices[columns_specifier]
        if columns_specifier in self.column_name_indexes:
            idx = self.column_name_indexes[columns_specifier]
            return slice(idx, idx + 1)
        raise PatsyError("unknown column specified '%s'"
                         % (columns_specifier,))

    def linear_constraint(self, constraint_likes):
        """Construct a linear constraint in matrix form from a (possibly
        symbolic) description.

        Possible inputs:

        * A dictionary which is taken as a set of equality constraints. Keys
          can be either string column names, or integer column indexes.
        * A string giving an arithmetic expression referring to the matrix
          columns by name.
        * A list of such strings which are ANDed together.
        * A tuple (A, b) where A and b are array_likes, and the constraint is
          Ax = b. If necessary, these will be coerced to the proper
          dimensionality by appending dimensions with size 1.

        The string-based language has the standard arithmetic operators, / * +
        - and parentheses, plus "=" is used for equality and "," is used to
        AND together multiple constraint equations within a string.
        If no = appears in some expression, then that expression is assumed to
        be equal to zero. Division is always float-based, even if
        ``__future__.true_division`` isn't in effect.

        Returns a :class:`LinearConstraint` object.

        Examples::

          di = DesignInfo(["x1", "x2", "x3"])

          # Equivalent ways to write x1 == 0:
          di.linear_constraint({"x1": 0})  # by name
          di.linear_constraint({0: 0})  # by index
          di.linear_constraint("x1 = 0")  # string based
          di.linear_constraint("x1")  # can leave out "= 0"
          di.linear_constraint("2 * x1 = (x1 + 2 * x1) / 3")
          di.linear_constraint(([1, 0, 0], 0))  # constraint matrices

          # Equivalent ways to write x1 == 0 and x3 == 10
          di.linear_constraint({"x1": 0, "x3": 10})
          di.linear_constraint({0: 0, 2: 10})
          di.linear_constraint({0: 0, "x3": 10})
          di.linear_constraint("x1 = 0, x3 = 10")
          di.linear_constraint("x1, x3 = 10")
          di.linear_constraint(["x1", "x3 = 10"])  # list of strings
          di.linear_constraint("x1 = 0, x3 - 10 = x1")
          di.linear_constraint(([[1, 0, 0], [0, 0, 1]], [0, 10]))

          # You can also chain together equalities, just like Python:
          di.linear_constraint("x1 = x2 = 3")
        """
        return linear_constraint(constraint_likes, self.column_names)

    def describe(self):
        """Returns a human-readable string describing this design info.

        Example:

        .. ipython::

          In [1]: y, X = dmatrices("y ~ x1 + x2", demo_data("y", "x1", "x2"))

          In [2]: y.design_info.describe()
          Out[2]: 'y'

          In [3]: X.design_info.describe()
          Out[3]: '1 + x1 + x2'

        .. warning::

           There is no guarantee that the strings returned by this function
           can be parsed as formulas, or that if they can be parsed as a
           formula that they will produce a model equivalent to the one you
           started with. This function produces a best-effort description
           intended for humans to read.

        """
        # The intercept term is displayed the way it is written in formulas.
        names = ["1" if name == "Intercept" else name
                 for name in self.term_names]
        return " + ".join(names)

    def subset(self, which_terms):
        """Create a new :class:`DesignInfo` for design matrices that contain a
        subset of the terms that the current :class:`DesignInfo` does.

        For example, if ``design_info`` has terms ``x``, ``y``, and ``z``,
        then::

          design_info2 = design_info.subset(["x", "z"])

        will return a new DesignInfo that can be used to construct design
        matrices with only the columns corresponding to the terms ``x`` and
        ``z``. After we do this, then in general these two expressions will
        return the same thing (here we assume that ``x``, ``y``, and ``z``
        each generate a single column of the output)::

          build_design_matrices([design_info], data)[0][:, [0, 2]]
          build_design_matrices([design_info2], data)[0]

        However, a critical difference is that in the second case, ``data``
        need not contain any values for ``y``. This is very useful when doing
        prediction using a subset of a model, in which situation R usually
        forces you to specify dummy values for ``y``.

        If using a formula to specify the terms to include, remember that like
        any formula, the intercept term will be included by default, so use
        ``0`` or ``-1`` in your formula if you want to avoid this.

        This method can also be used to reorder the terms in your design
        matrix, in case you want to do that for some reason. I can't think of
        any.

        Note that this method will generally *not* produce the same result as
        creating a new model directly. Consider these DesignInfo objects::

            design1 = dmatrix("1 + C(a)", data)
            design2 = design1.subset("0 + C(a)")
            design3 = dmatrix("0 + C(a)", data)

        Here ``design2`` and ``design3`` will both produce design matrices
        that contain an encoding of ``C(a)`` without any intercept term. But
        ``design3`` uses a full-rank encoding for the categorical term
        ``C(a)``, while ``design2`` uses the same reduced-rank encoding as
        ``design1``.

        :arg which_terms: The terms which should be kept in the new
          :class:`DesignInfo`. If this is a string, then it is parsed
          as a formula, and then the names of the resulting terms are taken as
          the terms to keep. If it is a list, then it can contain a mixture of
          term names (as strings) and :class:`Term` objects.

        An unknown term raises KeyError; a formula with a left-hand side
        raises :class:`PatsyError`.

        .. versionadded:: 0.2.0
           New method on the class DesignMatrixBuilder.

        .. versionchanged:: 0.4.0
           Moved from DesignMatrixBuilder to DesignInfo, as part of the
           removal of DesignMatrixBuilder.

        """
        # six.string_types also covers ``unicode`` on Python 2, where a
        # unicode formula would otherwise be iterated character by character.
        if isinstance(which_terms, six.string_types):
            desc = ModelDesc.from_formula(which_terms)
            if desc.lhs_termlist:
                raise PatsyError("right-hand-side-only formula required")
            which_terms = [term.name() for term in desc.rhs_termlist]

        # column_names is a property that builds a fresh list; fetch it once.
        column_names = self.column_names

        if self.term_codings is None:
            # This is a minimal DesignInfo
            # If the name is unknown we just let the KeyError escape
            new_names = []
            for t in which_terms:
                new_names += column_names[self.term_name_slices[t]]
            return DesignInfo(new_names)
        else:
            term_name_to_term = {}
            for term in self.term_codings:
                term_name_to_term[term.name()] = term

            new_column_names = []
            new_factor_infos = {}
            new_term_codings = OrderedDict()
            for name_or_term in which_terms:
                # Names are translated to Term objects; anything else is
                # assumed to already be a Term.
                term = term_name_to_term.get(name_or_term, name_or_term)
                # If the name is unknown we just let the KeyError escape
                s = self.term_slices[term]
                new_column_names += column_names[s]
                for f in term.factors:
                    new_factor_infos[f] = self.factor_infos[f]
                new_term_codings[term] = self.term_codings[term]
            return DesignInfo(new_column_names,
                              factor_infos=new_factor_infos,
                              term_codings=new_term_codings)

    @classmethod
    def from_array(cls, array_like, default_column_prefix="column"):
        """Find or construct a DesignInfo appropriate for a given array_like.

        If the input `array_like` already has a ``.design_info``
        attribute, then it will be returned. Otherwise, a new DesignInfo
        object will be constructed, using names either taken from the
        `array_like` (e.g., for a pandas DataFrame with named columns), or
        constructed using `default_column_prefix`.

        This is how :func:`dmatrix` (for example) creates a DesignInfo object
        if an arbitrary matrix is passed in.

        :arg array_like: An ndarray or pandas container.
        :arg default_column_prefix: If it's necessary to invent column names,
          then this will be used to construct them.
        :returns: a DesignInfo object
        :raises ValueError: if the input has more than 2 dimensions.
        """
        # A .design_info attribute that is not an instance of this class is
        # ignored and fresh metadata is built instead.
        if hasattr(array_like, "design_info") and isinstance(array_like.design_info, cls):
            return array_like.design_info
        arr = atleast_2d_column_default(array_like, preserve_pandas=True)
        if arr.ndim > 2:
            raise ValueError("design matrix can't have >2 dimensions")
        columns = getattr(arr, "columns", range(arr.shape[1]))
        if (hasattr(columns, "dtype")
                and not safe_issubdtype(columns.dtype, np.integer)):
            # Non-integer column labels are used as names directly.
            column_names = [str(obj) for obj in columns]
        else:
            # Integer (or positional) labels get the prefix prepended.
            column_names = ["%s%s" % (default_column_prefix, i)
                            for i in columns]
        return DesignInfo(column_names)

    __getstate__ = no_pickling
def test_DesignInfo():
    """Check DesignInfo attributes, slicing, and constructor validation.

    Every "bad" term_codings variant below is built by replacing an entry
    in a shallow copy, never by mutating a list in place: the copies share
    their list values with the ``term_codings`` fixture, so in-place edits
    would silently corrupt the fixture for the cases that follow.
    """
    from nose.tools import assert_raises

    class _MockFactor(object):
        def __init__(self, name):
            self._name = name

        def name(self):
            return self._name

    f_x = _MockFactor("x")
    f_y = _MockFactor("y")
    t_x = Term([f_x])
    t_y = Term([f_y])
    factor_infos = {f_x:
                      FactorInfo(f_x, "numerical", {}, num_columns=3),
                    f_y:
                      FactorInfo(f_y, "numerical", {}, num_columns=1),
                   }
    term_codings = OrderedDict([(t_x, [SubtermInfo([f_x], {}, 3)]),
                                (t_y, [SubtermInfo([f_y], {}, 1)])])
    di = DesignInfo(["x1", "x2", "x3", "y"], factor_infos, term_codings)
    assert di.column_names == ["x1", "x2", "x3", "y"]
    assert di.term_names == ["x", "y"]
    assert di.terms == [t_x, t_y]
    assert di.column_name_indexes == {"x1": 0, "x2": 1, "x3": 2, "y": 3}
    assert di.term_name_slices == {"x": slice(0, 3), "y": slice(3, 4)}
    assert di.term_slices == {t_x: slice(0, 3), t_y: slice(3, 4)}
    assert di.describe() == "x + y"

    assert di.slice(1) == slice(1, 2)
    assert di.slice("x1") == slice(0, 1)
    assert di.slice("x2") == slice(1, 2)
    assert di.slice("x3") == slice(2, 3)
    assert di.slice("x") == slice(0, 3)
    assert di.slice(t_x) == slice(0, 3)
    assert di.slice("y") == slice(3, 4)
    assert di.slice(t_y) == slice(3, 4)
    assert di.slice(slice(2, 4)) == slice(2, 4)
    assert_raises(PatsyError, di.slice, "asdf")

    # smoke test
    repr(di)

    assert_no_pickling(di)

    # One without term objects
    di = DesignInfo(["a1", "a2", "a3", "b"])
    assert di.column_names == ["a1", "a2", "a3", "b"]
    assert di.term_names == ["a1", "a2", "a3", "b"]
    assert di.terms is None
    assert di.column_name_indexes == {"a1": 0, "a2": 1, "a3": 2, "b": 3}
    assert di.term_name_slices == {"a1": slice(0, 1),
                                   "a2": slice(1, 2),
                                   "a3": slice(2, 3),
                                   "b": slice(3, 4)}
    assert di.term_slices is None
    assert di.describe() == "a1 + a2 + a3 + b"

    assert di.slice(1) == slice(1, 2)
    assert di.slice("a1") == slice(0, 1)
    assert di.slice("a2") == slice(1, 2)
    assert di.slice("a3") == slice(2, 3)
    assert di.slice("b") == slice(3, 4)

    # Check intercept handling in describe()
    assert DesignInfo(["Intercept", "a", "b"]).describe() == "1 + a + b"

    # Failure modes
    # must specify either both or neither of factor_infos and term_codings:
    assert_raises(ValueError, DesignInfo,
                  ["x1", "x2", "x3", "y"], factor_infos=factor_infos)
    assert_raises(ValueError, DesignInfo,
                  ["x1", "x2", "x3", "y"], term_codings=term_codings)
    # factor_infos must be a dict
    assert_raises(ValueError, DesignInfo,
                  ["x1", "x2", "x3", "y"], list(factor_infos), term_codings)
    # wrong number of column names:
    assert_raises(ValueError, DesignInfo,
                  ["x1", "x2", "x3", "y1", "y2"], factor_infos, term_codings)
    assert_raises(ValueError, DesignInfo,
                  ["x1", "x2", "x3"], factor_infos, term_codings)
    # name overlap problems
    assert_raises(ValueError, DesignInfo,
                  ["x1", "x2", "y", "y2"], factor_infos, term_codings)
    # duplicate name
    assert_raises(ValueError, DesignInfo,
                  ["x1", "x1", "x1", "y"], factor_infos, term_codings)

    # f_y is in factor_infos, but not mentioned in any term
    term_codings_x_only = OrderedDict(term_codings)
    del term_codings_x_only[t_y]
    assert_raises(ValueError, DesignInfo,
                  ["x1", "x2", "x3"], factor_infos, term_codings_x_only)

    # f_a is in a term, but not in factor_infos
    f_a = _MockFactor("a")
    t_a = Term([f_a])
    term_codings_with_a = OrderedDict(term_codings)
    term_codings_with_a[t_a] = [SubtermInfo([f_a], {}, 1)]
    assert_raises(ValueError, DesignInfo,
                  ["x1", "x2", "x3", "y", "a"],
                  factor_infos, term_codings_with_a)

    # bad factor_infos
    not_factor_infos = dict(factor_infos)
    not_factor_infos[f_x] = "what is this I don't even"
    assert_raises(ValueError, DesignInfo,
                  ["x1", "x2", "x3", "y"], not_factor_infos, term_codings)

    mismatch_factor_infos = dict(factor_infos)
    mismatch_factor_infos[f_x] = FactorInfo(f_a, "numerical", {}, num_columns=3)
    assert_raises(ValueError, DesignInfo,
                  ["x1", "x2", "x3", "y"], mismatch_factor_infos, term_codings)

    # bad term_codings
    assert_raises(ValueError, DesignInfo,
                  ["x1", "x2", "x3", "y"], factor_infos, dict(term_codings))

    not_term_codings = OrderedDict(term_codings)
    not_term_codings["this is a string"] = term_codings[t_x]
    assert_raises(ValueError, DesignInfo,
                  ["x1", "x2", "x3", "y"], factor_infos, not_term_codings)

    non_list_term_codings = OrderedDict(term_codings)
    non_list_term_codings[t_y] = tuple(term_codings[t_y])
    assert_raises(ValueError, DesignInfo,
                  ["x1", "x2", "x3", "y"], factor_infos, non_list_term_codings)

    non_subterm_term_codings = OrderedDict(term_codings)
    # Assign a fresh list; the copy shares its lists with term_codings.
    non_subterm_term_codings[t_y] = ["not a SubtermInfo"]
    assert_raises(ValueError, DesignInfo,
                  ["x1", "x2", "x3", "y"], factor_infos, non_subterm_term_codings)

    bad_subterm = OrderedDict(term_codings)
    # f_x is a factor in this model, but it is not a factor in t_y
    bad_subterm[t_y] = [SubtermInfo([f_x], {}, 1)]
    assert_raises(ValueError, DesignInfo,
                  ["x1", "x2", "x3", "y"], factor_infos, bad_subterm)

    # contrast matrix has wrong number of rows
    factor_codings_a = {f_a:
                          FactorInfo(f_a, "categorical", {},
                                     categories=["a1", "a2"])}
    term_codings_a_bad_rows = OrderedDict([
        (t_a,
         [SubtermInfo([f_a],
                      {f_a: ContrastMatrix(np.ones((3, 2)),
                                           ["[1]", "[2]"])},
                      2)])])
    assert_raises(ValueError, DesignInfo,
                  ["a[1]", "a[2]"],
                  factor_codings_a,
                  term_codings_a_bad_rows)

    # have a contrast matrix for a non-categorical factor
    t_ax = Term([f_a, f_x])
    factor_codings_ax = {f_a:
                           FactorInfo(f_a, "categorical", {},
                                      categories=["a1", "a2"]),
                         f_x:
                           FactorInfo(f_x, "numerical", {},
                                      num_columns=2)}
    term_codings_ax_extra_cm = OrderedDict([
        (t_ax,
         [SubtermInfo([f_a, f_x],
                      {f_a: ContrastMatrix(np.ones((2, 2)), ["[1]", "[2]"]),
                       f_x: ContrastMatrix(np.ones((2, 2)), ["[1]", "[2]"])},
                      4)])])
    assert_raises(ValueError, DesignInfo,
                  ["a[1]:x[1]", "a[2]:x[1]", "a[1]:x[2]", "a[2]:x[2]"],
                  factor_codings_ax,
                  term_codings_ax_extra_cm)

    # no contrast matrix for a categorical factor
    term_codings_ax_missing_cm = OrderedDict([
        (t_ax,
         [SubtermInfo([f_a, f_x],
                      {},
                      4)])])
    # This actually fails before it hits the relevant check with a KeyError,
    # but that's okay... the previous test still exercises the check.
    assert_raises((ValueError, KeyError), DesignInfo,
                  ["a[1]:x[1]", "a[2]:x[1]", "a[1]:x[2]", "a[2]:x[2]"],
                  factor_codings_ax,
                  term_codings_ax_missing_cm)

    # subterm num_columns doesn't match the value computed from the individual
    # factors
    term_codings_ax_wrong_subterm_columns = OrderedDict([
        (t_ax,
         [SubtermInfo([f_a, f_x],
                      {f_a: ContrastMatrix(np.ones((2, 3)),
                                           ["[1]", "[2]", "[3]"])},
                      # should be 2 * 3 = 6
                      5)])])
    assert_raises(ValueError, DesignInfo,
                  ["a[1]:x[1]", "a[2]:x[1]", "a[3]:x[1]",
                   "a[1]:x[2]", "a[2]:x[2]", "a[3]:x[2]"],
                  factor_codings_ax,
                  term_codings_ax_wrong_subterm_columns)
def test_DesignInfo_from_array():
    """Check how DesignInfo.from_array names columns and reuses metadata."""
    # Plain sequences get generated names built from the column prefix.
    plain_1d = DesignInfo.from_array([1, 2, 3])
    assert plain_1d.column_names == ["column0"]
    plain_2d = DesignInfo.from_array([[1, 2], [2, 3], [3, 4]])
    assert plain_2d.column_names == ["column0", "column1"]
    prefixed_1d = DesignInfo.from_array([1, 2, 3], default_column_prefix="x")
    assert prefixed_1d.column_names == ["x0"]
    prefixed_2d = DesignInfo.from_array([[1, 2], [2, 3], [3, 4]],
                                        default_column_prefix="x")
    assert prefixed_2d.column_names == ["x0", "x1"]

    # A matrix that already carries a DesignInfo hands back that same object.
    matrix = DesignMatrix([1, 2, 3], prefixed_1d)
    assert DesignInfo.from_array(matrix) is prefixed_1d
    # But weird objects are ignored
    matrix.design_info = "asdf"
    fallback = DesignInfo.from_array(matrix)
    assert fallback.column_names == ["column0"]

    # Arrays with more than two dimensions are rejected.
    from nose.tools import assert_raises
    assert_raises(ValueError, DesignInfo.from_array, np.ones((2, 2, 2)))

    from patsy.util import have_pandas
    if not have_pandas:
        return
    import pandas

    # with named columns
    named_frame = pandas.DataFrame([[1, 2]], columns=["a", "b"])
    from_named = DesignInfo.from_array(named_frame)
    assert from_named.column_names == ["a", "b"]

    # with irregularly numbered columns
    numbered_frame = pandas.DataFrame([[1, 2]], columns=[0, 10])
    from_numbered = DesignInfo.from_array(numbered_frame)
    assert from_numbered.column_names == ["column0", "column10"]

    # with .design_info attr
    tagged_frame = pandas.DataFrame([[1, 2]])
    tagged_frame.design_info = from_numbered
    assert DesignInfo.from_array(tagged_frame) is from_numbered
def test_DesignInfo_linear_constraint():
    """Check that DesignInfo.linear_constraint parses string constraints."""
    design = DesignInfo(["a1", "a2", "a3", "b"])
    constraint = design.linear_constraint(["2 * a1 = b + 1", "a3"])

    # The constraint is expressed over the design's column names.
    assert constraint.variable_names == ["a1", "a2", "a3", "b"]

    # One row per constraint string, one coefficient per column.
    expected_coefs = [[2, 0, 0, -1], [0, 0, 1, 0]]
    expected_constants = [[1], [0]]
    assert np.all(constraint.coefs == expected_coefs)
    assert np.all(constraint.constants == expected_constants)
def test_DesignInfo_deprecated_attributes():
    """Check that the deprecated aliases return the object itself and warn."""
    info = DesignInfo(["a1", "a2"])
    for attr in ("builder", "design_info"):
        # Record every warning so the deprecation can be inspected afterwards.
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            assert getattr(info, attr) is info
        # Exactly one warning, and it must be a DeprecationWarning.
        assert len(caught) == 1
        assert caught[0].category is DeprecationWarning
950# Idea: format with a reasonable amount of precision, then if that turns out
951# to be higher than necessary, remove as many zeros as we can. But only do
952# this while we can do it to *all* the ordinarily-formatted numbers, to keep
953# decimal points aligned.
954def _format_float_column(precision, col):
955 format_str = "%." + str(precision) + "f"
956 assert col.ndim == 1
957 # We don't want to look at numbers like "1e-5" or "nan" when stripping.
958 simple_float_chars = set("+-0123456789.")
959 col_strs = np.array([format_str % (x,) for x in col], dtype=object)
960 # Really every item should have a decimal, but just in case, we don't want
961 # to strip zeros off the end of "10" or something like that.
962 mask = np.array([simple_float_chars.issuperset(col_str) and "." in col_str
963 for col_str in col_strs])
964 mask_idxes = np.nonzero(mask)[0]
965 strip_char = "0"
966 if np.any(mask):
967 while True:
968 if np.all([s.endswith(strip_char) for s in col_strs[mask]]):
969 for idx in mask_idxes:
970 col_strs[idx] = col_strs[idx][:-1]
971 else:
972 if strip_char == "0":
973 strip_char = "."
974 else:
975 break
976 return col_strs
def test__format_float_column():
    """Check zero-stripping and alignment in _format_float_column."""
    # This acts weird on old python versions (e.g. it can be "-nan"), so don't
    # hardcode it:
    nan_string = "%.3f" % (np.nan,)

    # Each case is (precision, input values, expected strings).
    cases = [
        (3, [1, 2.1234, 2.1239, np.nan],
         ["1.000", "2.123", "2.124", nan_string]),
        (3, [1, 2, 3, np.nan], ["1", "2", "3", nan_string]),
        (3, [1.0001, 2, 3, np.nan], ["1", "2", "3", nan_string]),
        (4, [1.0001, 2, 3, np.nan],
         ["1.0001", "2.0000", "3.0000", nan_string]),
    ]
    for precision, values, expected in cases:
        got = _format_float_column(precision, np.asarray(values))
        print(got, expected)
        assert np.array_equal(got, expected)
991# http://docs.scipy.org/doc/numpy/user/basics.subclassing.html#slightly-more-realistic-example-attribute-added-to-existing-array
class DesignMatrix(np.ndarray):
    """An ndarray subclass that carries design matrix metadata with it.

    .. attribute:: design_info

       A :class:`DesignInfo` object holding the metadata for this design
       matrix.

    Apart from a fancier ``__repr__`` that labels the columns, instances
    behave exactly like regular numpy ndarrays.

    .. warning::

       Never use :func:`isinstance` to check for this class. Limitations of
       the numpy API make it impossible to prevent the creation of arrays
       whose type is DesignMatrix but which are not actually design matrices
       (such objects behave like regular ndarrays in every way). Check for
       the presence of a ``.design_info`` attribute instead -- only "real"
       DesignMatrix objects have one.
    """

    def __new__(cls, input_array, design_info=None,
                default_column_prefix="column"):
        """Create a DesignMatrix, or cast an existing matrix to a DesignMatrix.

        A call like::

          DesignMatrix(my_array)

        converts an arbitrary array_like object into a DesignMatrix.

        The result is guaranteed to be a two-dimensional ndarray with a
        real-valued floating point dtype, and to have a ``.design_info``
        attribute that matches its shape. When `design_info` is not given,
        one is created via :meth:`DesignInfo.from_array` using the given
        `default_column_prefix`.

        Depending on the input array, this may pass its input through
        unchanged, or create a view.

        :raises ValueError: If the input has more than 2 dimensions, if the
          number of column names does not match the number of columns, or if
          the data is not real-valued floating point (after integer input
          has been upcast).
        """
        # Pass through existing DesignMatrixes. Checking the type alone is
        # not enough: numpy cannot be stopped from turning non-design-matrix
        # arrays into DesignMatrix instances. (E.g., my_dm.diagonal() returns
        # a DesignMatrix object, but one without a design_info attribute.)
        already_built = (isinstance(input_array, DesignMatrix)
                         and hasattr(input_array, "design_info"))
        if already_built:
            return input_array

        matrix = atleast_2d_column_default(input_array).view(cls)
        # Upcast integer to floating point
        if safe_issubdtype(matrix.dtype, np.integer):
            matrix = np.asarray(matrix, dtype=float).view(cls)
        if matrix.ndim > 2:
            raise ValueError("DesignMatrix must be 2d")
        assert matrix.ndim == 2

        if design_info is None:
            design_info = DesignInfo.from_array(matrix, default_column_prefix)
        num_names = len(design_info.column_names)
        num_columns = matrix.shape[1]
        if num_names != num_columns:
            raise ValueError("wrong number of column names for design matrix "
                             "(got %s, wanted %s)"
                             % (num_names, num_columns))
        matrix.design_info = design_info

        if not safe_issubdtype(matrix.dtype, np.floating):
            raise ValueError("design matrix must be real-valued floating point")
        return matrix

    __repr__ = repr_pretty_delegate

    def _repr_pretty_(self, p, cycle):
        """Pretty-print the matrix with labeled columns and a list of terms.

        The numbers are shown (at most ``MAX_ROWS`` rows of them) only when
        the whole table fits within ``MAX_TOTAL_WIDTH`` characters; otherwise
        just the column names are listed. Objects that lack a
        ``design_info`` attribute are printed as plain ndarrays.

        :arg p: The pretty-printer object to write to.
        :arg cycle: Cycle flag passed in by the pretty-printer; asserted to
          be false.
        """
        if not hasattr(self, "design_info"):
            # Not a real DesignMatrix
            p.pretty(np.asarray(self))
            return
        assert not cycle

        # XX: could try calculating width of the current terminal window:
        #   http://stackoverflow.com/questions/566746/how-to-get-console-window-width-in-python
        # sadly it looks like ipython does not actually pass this information
        # in, even if we use _repr_pretty_ -- the pretty-printer object has a
        # fixed width it always uses. (As of IPython 0.12.)
        MAX_TOTAL_WIDTH = 78
        SEP = 2
        INDENT = 2
        MAX_ROWS = 30
        PRECISION = 5

        def newline():
            # Line break followed by the printer's *current* indentation.
            return "\n" + " " * p.indentation

        names = self.design_info.column_names
        name_widths = [len(name) for name in names]
        # Width taken up by the indent plus the gaps between the columns.
        fixed_width = INDENT + SEP * (self.shape[1] - 1)

        print_numbers = False
        if fixed_width + np.sum(name_widths) <= MAX_TOTAL_WIDTH:
            # The header fits; format the numbers and see if they fit too.
            shown = np.asarray(self)[:MAX_ROWS, :]
            formatted_cols = [_format_float_column(PRECISION, shown[:, j])
                              for j in range(self.shape[1])]
            number_widths = []
            for col in formatted_cols:
                assert col.ndim == 1
                # A column with no rows has no numbers to make room for.
                number_widths.append(
                    max([len(s) for s in col]) if col.shape[0] else 0)
            column_widths = [max(pair)
                             for pair in zip(name_widths, number_widths)]
            total_width = fixed_width + np.sum(column_widths)
            print_numbers = (total_width < MAX_TOTAL_WIDTH)

        p.begin_group(INDENT, "DesignMatrix with shape %s" % (self.shape,))
        p.breakable(newline())
        if print_numbers:
            # We can fit the numbers on the screen
            gap = " " * SEP
            # list() is for Py3 compatibility
            table_rows = [names] + list(zip(*formatted_cols))
            for row in table_rows:
                p.text(gap.join([cell.rjust(width)
                                 for (width, cell)
                                 in zip(column_widths, row)]))
                p.text(newline())
            if MAX_ROWS < self.shape[0]:
                p.text("[%s rows omitted]" % (self.shape[0] - MAX_ROWS,))
                p.text(newline())
        else:
            p.begin_group(2, "Columns:")
            p.breakable(newline())
            p.pretty(names)
            p.end_group(2, "")
            p.breakable(newline())

        p.begin_group(2, "Terms:")
        p.breakable(newline())
        term_slices = self.design_info.term_name_slices
        for term_name, span in six.iteritems(term_slices):
            # Separate every term but the one that starts at column 0.
            if span.start != 0:
                p.breakable(", ")
            p.pretty(term_name)
            if span.stop - span.start == 1:
                coltext = "column %s" % (span.start,)
            else:
                coltext = "columns %s:%s" % (span.start, span.stop)
            p.text(" (%s)" % (coltext,))
        p.end_group(2, "")

        some_data_hidden = (not print_numbers) or self.shape[0] > MAX_ROWS
        if some_data_hidden:
            p.breakable(newline())
            p.text("(to view full data, use np.asarray(this_obj))")

        p.end_group(INDENT, "")

    # No __array_finalize__ method, because we don't want slices of this
    # object to keep the design_info (they may have different columns!), or
    # anything fancy like that.

    __reduce__ = no_pickling
def test_design_matrix():
    """Exercise DesignMatrix construction, validation, pass-through and repr."""
    from nose.tools import assert_raises

    # An explicit DesignInfo is attached as given.
    di = DesignInfo(["a1", "a2", "a3", "b"])
    mm = DesignMatrix([[12, 14, 16, 18]], di)
    assert mm.design_info.column_names == ["a1", "a2", "a3", "b"]

    # A DesignInfo with the wrong number of columns is rejected.
    too_few_names = DesignInfo(["a1"])
    assert_raises(ValueError, DesignMatrix, [[12, 14, 16, 18]], too_few_names)

    # Without a DesignInfo, default column names are generated.
    defaulted = DesignMatrix([[12, 14, 16, 18]])
    expected_defaults = ["column0", "column1", "column2", "column3"]
    assert defaulted.design_info.column_names == expected_defaults

    # 1-d input becomes a single column.
    column = DesignMatrix([12, 14, 16, 18])
    assert column.shape == (4, 1)

    # DesignMatrix always has exactly 2 dimensions
    assert_raises(ValueError, DesignMatrix, [[[1]]])

    # DesignMatrix constructor passes through existing DesignMatrixes
    assert DesignMatrix(mm) is mm
    # But not if they are really slices:
    assert DesignMatrix(mm.diagonal()) is not mm

    prefixed = DesignMatrix([[12, 14, 16, 18]], default_column_prefix="x")
    assert prefixed.design_info.column_names == ["x0", "x1", "x2", "x3"]

    assert_no_pickling(prefixed)

    # Only real-valued matrices can be DesignMatrixs
    for bad_values in ([1, 2, 3j], ["a", "b", "c"], [1, 2, object()]):
        assert_raises(ValueError, DesignMatrix, bad_values)

    # Just smoke tests
    smoke_cases = [
        mm,
        DesignMatrix(np.arange(100)),
        DesignMatrix(np.arange(100) * 2.0),
        mm[1:, :],
        DesignMatrix(np.arange(100).reshape((1, 100))),
        DesignMatrix([np.nan, np.inf]),
        DesignMatrix([np.nan, 0, 1e20, 20.5]),
        # handling of zero-size matrices
        DesignMatrix(np.zeros((1, 0))),
        DesignMatrix(np.zeros((0, 1))),
        DesignMatrix(np.zeros((0, 0))),
    ]
    for case in smoke_cases:
        repr(case)