from typing import TYPE_CHECKING, Callable, Dict, List, Tuple, Union

import numpy as np

from pandas.util._decorators import Appender, Substitution

from pandas.core.dtypes.cast import maybe_downcast_to_dtype
from pandas.core.dtypes.common import is_integer_dtype, is_list_like, is_scalar
from pandas.core.dtypes.generic import ABCDataFrame, ABCSeries

import pandas.core.common as com
from pandas.core.frame import _shared_docs
from pandas.core.groupby import Grouper
from pandas.core.indexes.api import Index, MultiIndex, get_objs_combined_axis
from pandas.core.reshape.concat import concat
from pandas.core.reshape.util import cartesian_product
from pandas.core.series import Series

if TYPE_CHECKING:
    from pandas import DataFrame


# Note: We need to make sure `frame` is imported before `pivot`, otherwise
# _shared_docs['pivot_table'] will not yet exist.  TODO: Fix this dependency
@Substitution("\ndata : DataFrame")
@Appender(_shared_docs["pivot_table"], indents=1)
def pivot_table(
    data,
    values=None,
    index=None,
    columns=None,
    aggfunc="mean",
    fill_value=None,
    margins=False,
    dropna=True,
    margins_name="All",
    observed=False,
) -> "DataFrame":
    index = _convert_by(index)
    columns = _convert_by(columns)

    if isinstance(aggfunc, list):
        pieces: List[DataFrame] = []
        keys = []
        for func in aggfunc:
            table = pivot_table(
                data,
                values=values,
                index=index,
                columns=columns,
                fill_value=fill_value,
                aggfunc=func,
                margins=margins,
                dropna=dropna,
                margins_name=margins_name,
                observed=observed,
            )
            pieces.append(table)
            keys.append(getattr(func, "__name__", func))

        return concat(pieces, keys=keys, axis=1)

    keys = index + columns

    values_passed = values is not None
    if values_passed:
        if is_list_like(values):
            values_multi = True
            values = list(values)
        else:
            values_multi = False
            values = [values]

        # GH14938 Make sure value labels are in data
        for i in values:
            if i not in data:
                raise KeyError(i)

        to_filter = []
        for x in keys + values:
            if isinstance(x, Grouper):
                x = x.key
            try:
                if x in data:
                    to_filter.append(x)
            except TypeError:
                pass
        if len(to_filter) < len(data.columns):
            data = data[to_filter]

    else:
        values = data.columns
        for key in keys:
            try:
                values = values.drop(key)
            except (TypeError, ValueError, KeyError):
                pass
        values = list(values)

    grouped = data.groupby(keys, observed=observed)
    agged = grouped.agg(aggfunc)
    if dropna and isinstance(agged, ABCDataFrame) and len(agged.columns):
        agged = agged.dropna(how="all")

        # gh-21133
        # we want to down cast if
        # the original values are ints
        # as we grouped with a NaN value
        # and then dropped, coercing to floats
        for v in values:
            if (
                v in data
                and is_integer_dtype(data[v])
                and v in agged
                and not is_integer_dtype(agged[v])
            ):
                agged[v] = maybe_downcast_to_dtype(agged[v], data[v].dtype)

    table = agged
    if table.index.nlevels > 1:
        # Related GH #17123
        # If index_names are integers, determine whether the integers refer
        # to the level position or name.
        index_names = agged.index.names[: len(index)]
        to_unstack = []
        for i in range(len(index), len(keys)):
            name = agged.index.names[i]
            if name is None or name in index_names:
                to_unstack.append(i)
            else:
                to_unstack.append(name)
        table = agged.unstack(to_unstack)

    if not dropna:
        if table.index.nlevels > 1:
            m = MultiIndex.from_arrays(
                cartesian_product(table.index.levels), names=table.index.names
            )
            table = table.reindex(m, axis=0)

        if table.columns.nlevels > 1:
            m = MultiIndex.from_arrays(
                cartesian_product(table.columns.levels), names=table.columns.names
            )
            table = table.reindex(m, axis=1)

    if isinstance(table, ABCDataFrame):
        table = table.sort_index(axis=1)

    if fill_value is not None:
        _table = table.fillna(fill_value, downcast="infer")
        assert _table is not None  # needed for mypy
        table = _table

    if margins:
        if dropna:
            data = data[data.notna().all(axis=1)]
        table = _add_margins(
            table,
            data,
            values,
            rows=index,
            cols=columns,
            aggfunc=aggfunc,
            observed=dropna,
            margins_name=margins_name,
            fill_value=fill_value,
        )

    # discard the top level
    if (
        values_passed
        and not values_multi
        and not table.empty
        and (table.columns.nlevels > 1)
    ):
        table = table[values[0]]

    if len(index) == 0 and len(columns) > 0:
        table = table.T

    # GH 15193 Make sure empty columns are removed if dropna=True
    if isinstance(table, ABCDataFrame) and dropna:
        table = table.dropna(how="all", axis=1)

    return table
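

# Illustrative sketch (not part of the pandas source): the list-valued
# ``aggfunc`` branch above builds one pivot table per function and concatenates
# the pieces along the columns, with the function names as an extra top column
# level. The frame below is an assumed example and the printed output is
# approximate:
#
#   >>> import pandas as pd
#   >>> df = pd.DataFrame({"A": ["x", "x", "y"], "B": [1, 2, 3]})
#   >>> pd.pivot_table(df, values="B", index="A", aggfunc=["sum", "mean"])
#      sum mean
#        B    B
#   A
#   x    3  1.5
#   y    3  3.0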


def _add_margins(
    table: Union["Series", "DataFrame"],
    data,
    values,
    rows,
    cols,
    aggfunc,
    observed=None,
    margins_name: str = "All",
    fill_value=None,
):
    if not isinstance(margins_name, str):
        raise ValueError("margins_name argument must be a string")

    msg = 'Conflicting name "{name}" in margins'.format(name=margins_name)
    for level in table.index.names:
        if margins_name in table.index.get_level_values(level):
            raise ValueError(msg)

    grand_margin = _compute_grand_margin(data, values, aggfunc, margins_name)

    if table.ndim == 2:
        # i.e. DataFrame
        for level in table.columns.names[1:]:
            if margins_name in table.columns.get_level_values(level):
                raise ValueError(msg)

    key: Union[str, Tuple[str, ...]]
    if len(rows) > 1:
        key = (margins_name,) + ("",) * (len(rows) - 1)
    else:
        key = margins_name

    if not values and isinstance(table, ABCSeries):
        # If there are no values and the table is a series, then there is only
        # one column in the data. Compute grand margin and return it.
        return table.append(Series({key: grand_margin[margins_name]}))

    elif values:
        marginal_result_set = _generate_marginal_results(
            table,
            data,
            values,
            rows,
            cols,
            aggfunc,
            observed,
            grand_margin,
            margins_name,
        )
        if not isinstance(marginal_result_set, tuple):
            return marginal_result_set
        result, margin_keys, row_margin = marginal_result_set
    else:
        # no values, and table is a DataFrame
        assert isinstance(table, ABCDataFrame)
        marginal_result_set = _generate_marginal_results_without_values(
            table, data, rows, cols, aggfunc, observed, margins_name
        )
        if not isinstance(marginal_result_set, tuple):
            return marginal_result_set
        result, margin_keys, row_margin = marginal_result_set

    row_margin = row_margin.reindex(result.columns, fill_value=fill_value)
    # populate grand margin
    for k in margin_keys:
        if isinstance(k, str):
            row_margin[k] = grand_margin[k]
        else:
            row_margin[k] = grand_margin[k[0]]

    from pandas import DataFrame

    margin_dummy = DataFrame(row_margin, columns=[key]).T

    row_names = result.index.names
    try:
        # check the result column and leave floats
        for dtype in set(result.dtypes):
            cols = result.select_dtypes([dtype]).columns
            margin_dummy[cols] = margin_dummy[cols].apply(
                maybe_downcast_to_dtype, args=(dtype,)
            )
        result = result.append(margin_dummy)
    except TypeError:

        # we cannot reshape, so coerce the axis
        result.index = result.index._to_safe_for_reshape()
        result = result.append(margin_dummy)
    result.index.names = row_names

    return result
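

# Illustrative sketch (assumed example, not part of the pandas source): with
# ``margins=True`` the helper above appends one extra row and one extra column,
# both labelled by ``margins_name`` (default "All"), holding the aggregate over
# each axis:
#
#   >>> pd.pivot_table(df, values="B", index="A", columns="C",
#   ...                aggfunc="sum", margins=True)
#
# The resulting table's last row and last column are labelled "All".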


def _compute_grand_margin(data, values, aggfunc, margins_name: str = "All"):

    if values:
        grand_margin = {}
        for k, v in data[values].items():
            try:
                if isinstance(aggfunc, str):
                    grand_margin[k] = getattr(v, aggfunc)()
                elif isinstance(aggfunc, dict):
                    if isinstance(aggfunc[k], str):
                        grand_margin[k] = getattr(v, aggfunc[k])()
                    else:
                        grand_margin[k] = aggfunc[k](v)
                else:
                    grand_margin[k] = aggfunc(v)
            except TypeError:
                pass
        return grand_margin
    else:
        return {margins_name: aggfunc(data.index)}
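

# Illustrative note (derived from the code above): _compute_grand_margin
# reduces each value column with ``aggfunc`` and returns a plain dict, e.g.
# for values=["B"] and aggfunc="mean" it returns {"B": data["B"].mean()};
# columns that cannot be aggregated (TypeError) are silently skipped.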


def _generate_marginal_results(
    table,
    data,
    values,
    rows,
    cols,
    aggfunc,
    observed,
    grand_margin,
    margins_name: str = "All",
):
    if len(cols) > 0:
        # need to "interleave" the margins
        table_pieces = []
        margin_keys = []

        def _all_key(key):
            return (key, margins_name) + ("",) * (len(cols) - 1)

        if len(rows) > 0:
            margin = data[rows + values].groupby(rows, observed=observed).agg(aggfunc)
            cat_axis = 1

            for key, piece in table.groupby(level=0, axis=cat_axis, observed=observed):
                all_key = _all_key(key)

                # we are going to mutate this, so need to copy!
                piece = piece.copy()
                try:
                    piece[all_key] = margin[key]
                except TypeError:

                    # we cannot reshape, so coerce the axis
                    piece.set_axis(
                        piece._get_axis(cat_axis)._to_safe_for_reshape(),
                        axis=cat_axis,
                        inplace=True,
                    )
                    piece[all_key] = margin[key]

                table_pieces.append(piece)
                margin_keys.append(all_key)
        else:
            margin = grand_margin
            cat_axis = 0
            for key, piece in table.groupby(level=0, axis=cat_axis, observed=observed):
                all_key = _all_key(key)
                table_pieces.append(piece)
                table_pieces.append(Series(margin[key], index=[all_key]))
                margin_keys.append(all_key)

        result = concat(table_pieces, axis=cat_axis)

        if len(rows) == 0:
            return result
    else:
        result = table
        margin_keys = table.columns

    if len(cols) > 0:
        row_margin = data[cols + values].groupby(cols, observed=observed).agg(aggfunc)
        row_margin = row_margin.stack()

        # slight hack
        new_order = [len(cols)] + list(range(len(cols)))
        row_margin.index = row_margin.index.reorder_levels(new_order)
    else:
        row_margin = Series(np.nan, index=result.columns)

    return result, margin_keys, row_margin


def _generate_marginal_results_without_values(
    table: "DataFrame", data, rows, cols, aggfunc, observed, margins_name: str = "All"
):
    if len(cols) > 0:
        # need to "interleave" the margins
        margin_keys = []

        def _all_key():
            if len(cols) == 1:
                return margins_name
            return (margins_name,) + ("",) * (len(cols) - 1)

        if len(rows) > 0:
            margin = data[rows].groupby(rows, observed=observed).apply(aggfunc)
            all_key = _all_key()
            table[all_key] = margin
            result = table
            margin_keys.append(all_key)

        else:
            margin = data.groupby(level=0, axis=0, observed=observed).apply(aggfunc)
            all_key = _all_key()
            table[all_key] = margin
            result = table
            margin_keys.append(all_key)
            return result
    else:
        result = table
        margin_keys = table.columns

    if len(cols):
        row_margin = data[cols].groupby(cols, observed=observed).apply(aggfunc)
    else:
        row_margin = Series(np.nan, index=result.columns)

    return result, margin_keys, row_margin


def _convert_by(by):
    if by is None:
        by = []
    elif (
        is_scalar(by)
        or isinstance(by, (np.ndarray, Index, ABCSeries, Grouper))
        or hasattr(by, "__call__")
    ):
        by = [by]
    else:
        by = list(by)
    return by
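

# Illustrative behaviour of _convert_by (derived from the code above):
#   _convert_by(None)        -> []
#   _convert_by("a")         -> ["a"]        # scalars are wrapped in a list
#   _convert_by(len)         -> [len]        # callables are wrapped as well
#   _convert_by(["a", "b"])  -> ["a", "b"]   # other iterables become lists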


@Substitution("\ndata : DataFrame")
@Appender(_shared_docs["pivot"], indents=1)
def pivot(data: "DataFrame", index=None, columns=None, values=None) -> "DataFrame":
    if values is None:
        cols = [columns] if index is None else [index, columns]
        append = index is None
        indexed = data.set_index(cols, append=append)
    else:
        if index is None:
            index = data.index
        else:
            index = data[index]
        index = MultiIndex.from_arrays([index, data[columns]])

        if is_list_like(values) and not isinstance(values, tuple):
            # Exclude tuple because it is seen as a single column name
            indexed = data._constructor(
                data[values].values, index=index, columns=values
            )
        else:
            indexed = data._constructor_sliced(data[values].values, index=index)
    return indexed.unstack(columns)
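

# Illustrative sketch (assumed example, not part of the pandas source):
# ``pivot`` reshapes without aggregating -- it builds a (row, column)
# MultiIndex and unstacks it, so duplicate (index, columns) pairs raise
# instead of being combined. Output shown approximately:
#
#   >>> df = pd.DataFrame({"k": ["r1", "r1", "r2"],
#   ...                    "c": ["a", "b", "a"],
#   ...                    "v": [1, 2, 3]})
#   >>> df.pivot(index="k", columns="c", values="v")
#   c     a    b
#   k
#   r1  1.0  2.0
#   r2  3.0  NaN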


def crosstab(
    index,
    columns,
    values=None,
    rownames=None,
    colnames=None,
    aggfunc=None,
    margins=False,
    margins_name: str = "All",
    dropna: bool = True,
    normalize=False,
) -> "DataFrame":
    """
    Compute a simple cross tabulation of two (or more) factors. By default
    computes a frequency table of the factors unless an array of values and an
    aggregation function are passed.

    Parameters
    ----------
    index : array-like, Series, or list of arrays/Series
        Values to group by in the rows.
    columns : array-like, Series, or list of arrays/Series
        Values to group by in the columns.
    values : array-like, optional
        Array of values to aggregate according to the factors.
        Requires `aggfunc` be specified.
    rownames : sequence, default None
        If passed, must match number of row arrays passed.
    colnames : sequence, default None
        If passed, must match number of column arrays passed.
    aggfunc : function, optional
        If specified, requires `values` be specified as well.
    margins : bool, default False
        Add row/column margins (subtotals).
    margins_name : str, default 'All'
        Name of the row/column that will contain the totals
        when margins is True.

        .. versionadded:: 0.21.0

    dropna : bool, default True
        Do not include columns whose entries are all NaN.
    normalize : bool, {'all', 'index', 'columns'}, or {0,1}, default False
        Normalize by dividing all values by the sum of values.

        - If passed 'all' or `True`, will normalize over all values.
        - If passed 'index' will normalize over each row.
        - If passed 'columns' will normalize over each column.
        - If margins is `True`, will also normalize margin values.

    Returns
    -------
    DataFrame
        Cross tabulation of the data.

    See Also
    --------
    DataFrame.pivot : Reshape data based on column values.
    pivot_table : Create a pivot table as a DataFrame.

    Notes
    -----
    Any Series passed will have their name attributes used unless row or column
    names for the cross-tabulation are specified.

    Any input passed containing Categorical data will have **all** of its
    categories included in the cross-tabulation, even if the actual data does
    not contain any instances of a particular category.

    In the event that there aren't overlapping indexes an empty DataFrame will
    be returned.

    Examples
    --------
    >>> a = np.array(["foo", "foo", "foo", "foo", "bar", "bar",
    ...               "bar", "bar", "foo", "foo", "foo"], dtype=object)
    >>> b = np.array(["one", "one", "one", "two", "one", "one",
    ...               "one", "two", "two", "two", "one"], dtype=object)
    >>> c = np.array(["dull", "dull", "shiny", "dull", "dull", "shiny",
    ...               "shiny", "dull", "shiny", "shiny", "shiny"],
    ...              dtype=object)
    >>> pd.crosstab(a, [b, c], rownames=['a'], colnames=['b', 'c'])
    b   one        two
    c   dull shiny dull shiny
    a
    bar    1     2    1     0
    foo    2     2    1     2

    Here 'c' and 'f' are not represented in the data and will not be
    shown in the output because dropna is True by default. Set
    dropna=False to preserve categories with no data.

    >>> foo = pd.Categorical(['a', 'b'], categories=['a', 'b', 'c'])
    >>> bar = pd.Categorical(['d', 'e'], categories=['d', 'e', 'f'])
    >>> pd.crosstab(foo, bar)
    col_0  d  e
    row_0
    a      1  0
    b      0  1
    >>> pd.crosstab(foo, bar, dropna=False)
    col_0  d  e  f
    row_0
    a      1  0  0
    b      0  1  0
    c      0  0  0
    """

    index = com.maybe_make_list(index)
    columns = com.maybe_make_list(columns)

    rownames = _get_names(index, rownames, prefix="row")
    colnames = _get_names(columns, colnames, prefix="col")

    common_idx = None
    pass_objs = [x for x in index + columns if isinstance(x, (ABCSeries, ABCDataFrame))]
    if pass_objs:
        common_idx = get_objs_combined_axis(pass_objs, intersect=True, sort=False)

    data: Dict = {}
    data.update(zip(rownames, index))
    data.update(zip(colnames, columns))

    if values is None and aggfunc is not None:
        raise ValueError("aggfunc cannot be used without values.")

    if values is not None and aggfunc is None:
        raise ValueError("values cannot be used without an aggfunc.")

    from pandas import DataFrame

    df = DataFrame(data, index=common_idx)
    if values is None:
        df["__dummy__"] = 0
        kwargs = {"aggfunc": len, "fill_value": 0}
    else:
        df["__dummy__"] = values
        kwargs = {"aggfunc": aggfunc}

    table = df.pivot_table(
        "__dummy__",
        index=rownames,
        columns=colnames,
        margins=margins,
        margins_name=margins_name,
        dropna=dropna,
        **kwargs,
    )

    # Post-process
    if normalize is not False:
        table = _normalize(
            table, normalize=normalize, margins=margins, margins_name=margins_name
        )

    return table
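

# Illustrative note (not part of the pandas source): crosstab is a thin wrapper
# around pivot_table on a temporary "__dummy__" column.  A plain frequency
# table is roughly equivalent to the sketch below, where ``index_array`` and
# ``columns_array`` stand in for the caller's inputs:
#
#   >>> tmp = pd.DataFrame({"r": index_array, "c": columns_array, "__dummy__": 0})
#   >>> tmp.pivot_table("__dummy__", index="r", columns="c",
#   ...                 aggfunc=len, fill_value=0)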


def _normalize(table, normalize, margins: bool, margins_name="All"):

    if not isinstance(normalize, (bool, str)):
        axis_subs = {0: "index", 1: "columns"}
        try:
            normalize = axis_subs[normalize]
        except KeyError:
            raise ValueError("Not a valid normalize argument")

    if margins is False:

        # Actual Normalizations
        normalizers: Dict[Union[bool, str], Callable] = {
            "all": lambda x: x / x.sum(axis=1).sum(axis=0),
            "columns": lambda x: x / x.sum(),
            "index": lambda x: x.div(x.sum(axis=1), axis=0),
        }

        normalizers[True] = normalizers["all"]

        try:
            f = normalizers[normalize]
        except KeyError:
            raise ValueError("Not a valid normalize argument")

        table = f(table)
        table = table.fillna(0)

    elif margins is True:
        # keep index and column of pivoted table
        table_index = table.index
        table_columns = table.columns

        # check if margin name is in (for MI cases) or equal to last
        # index/column and save the column and index margin
        if (margins_name not in table.iloc[-1, :].name) | (
            margins_name != table.iloc[:, -1].name
        ):
            raise ValueError(
                "{mname} not in pivoted DataFrame".format(mname=margins_name)
            )
        column_margin = table.iloc[:-1, -1]
        index_margin = table.iloc[-1, :-1]

        # keep the core table
        table = table.iloc[:-1, :-1]

        # Normalize core
        table = _normalize(table, normalize=normalize, margins=False)

        # Fix Margins
        if normalize == "columns":
            column_margin = column_margin / column_margin.sum()
            table = concat([table, column_margin], axis=1)
            table = table.fillna(0)
            table.columns = table_columns

        elif normalize == "index":
            index_margin = index_margin / index_margin.sum()
            table = table.append(index_margin)
            table = table.fillna(0)
            table.index = table_index

        elif normalize == "all" or normalize is True:
            column_margin = column_margin / column_margin.sum()
            index_margin = index_margin / index_margin.sum()
            index_margin.loc[margins_name] = 1
            table = concat([table, column_margin], axis=1)
            table = table.append(index_margin)

            table = table.fillna(0)
            table.index = table_index
            table.columns = table_columns

        else:
            raise ValueError("Not a valid normalize argument")

    else:
        raise ValueError("Not a valid margins argument")

    return table
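

# Illustrative arithmetic (assumed example): with normalize="index" each row of
# the core table is divided by its row sum,
#
#       a  b                 a    b
#   r1  1  1    becomes  r1  0.5  0.5
#   r2  2  2             r2  0.5  0.5
#
# normalize="columns" divides by the column sums instead, and "all"/True
# divides every cell by the grand total of the table.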


def _get_names(arrs, names, prefix: str = "row"):
    if names is None:
        names = []
        for i, arr in enumerate(arrs):
            if isinstance(arr, ABCSeries) and arr.name is not None:
                names.append(arr.name)
            else:
                names.append("{prefix}_{i}".format(prefix=prefix, i=i))
    else:
        if len(names) != len(arrs):
            raise AssertionError("arrays and names must have the same length")
        if not isinstance(names, list):
            names = list(names)

    return names
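

# Illustrative behaviour of _get_names (derived from the code above): with
# names=None, a named Series contributes its own name and anything else gets a
# generated placeholder such as "row_1" (or "col_1" for prefix="col"); when
# ``names`` is supplied, its length must match the number of arrays.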