Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1from typing import TYPE_CHECKING, Callable, Dict, List, Tuple, Union 

2 

3import numpy as np 

4 

5from pandas.util._decorators import Appender, Substitution 

6 

7from pandas.core.dtypes.cast import maybe_downcast_to_dtype 

8from pandas.core.dtypes.common import is_integer_dtype, is_list_like, is_scalar 

9from pandas.core.dtypes.generic import ABCDataFrame, ABCSeries 

10 

11import pandas.core.common as com 

12from pandas.core.frame import _shared_docs 

13from pandas.core.groupby import Grouper 

14from pandas.core.indexes.api import Index, MultiIndex, get_objs_combined_axis 

15from pandas.core.reshape.concat import concat 

16from pandas.core.reshape.util import cartesian_product 

17from pandas.core.series import Series 

18 

19if TYPE_CHECKING: 

20 from pandas import DataFrame 

21 

22 

23# Note: We need to make sure `frame` is imported before `pivot`, otherwise 

24# _shared_docs['pivot_table'] will not yet exist. TODO: Fix this dependency 

@Substitution("\ndata : DataFrame")
@Appender(_shared_docs["pivot_table"], indents=1)
def pivot_table(
    data,
    values=None,
    index=None,
    columns=None,
    aggfunc="mean",
    fill_value=None,
    margins=False,
    dropna=True,
    margins_name="All",
    observed=False,
) -> "DataFrame":
    # No docstring here on purpose: the public documentation is assembled by
    # @Appender from _shared_docs["pivot_table"].
    index = _convert_by(index)
    columns = _convert_by(columns)

    # A list of aggfuncs: build one pivot table per function and concatenate
    # them side by side, keyed by each function's name (or the function
    # object itself when it has no __name__).
    if isinstance(aggfunc, list):
        pieces: List[DataFrame] = []
        keys = []
        for func in aggfunc:
            table = pivot_table(
                data,
                values=values,
                index=index,
                columns=columns,
                fill_value=fill_value,
                aggfunc=func,
                margins=margins,
                dropna=dropna,
                margins_name=margins_name,
                observed=observed,
            )
            pieces.append(table)
            keys.append(getattr(func, "__name__", func))

        return concat(pieces, keys=keys, axis=1)

    keys = index + columns

    values_passed = values is not None
    if values_passed:
        if is_list_like(values):
            values_multi = True
            values = list(values)
        else:
            values_multi = False
            values = [values]

        # GH14938 Make sure value labels are in data
        for i in values:
            if i not in data:
                raise KeyError(i)

        # Restrict `data` to the columns actually used as keys or values.
        # Grouper objects contribute their `key` attribute; the TypeError
        # guard covers unhashable entries on the `in data` test.
        to_filter = []
        for x in keys + values:
            if isinstance(x, Grouper):
                x = x.key
            try:
                if x in data:
                    to_filter.append(x)
            except TypeError:
                pass
        if len(to_filter) < len(data.columns):
            data = data[to_filter]

    else:
        # No explicit values: aggregate every column that is not a key.
        values = data.columns
        for key in keys:
            try:
                values = values.drop(key)
            except (TypeError, ValueError, KeyError):
                pass
        values = list(values)

    grouped = data.groupby(keys, observed=observed)
    agged = grouped.agg(aggfunc)
    if dropna and isinstance(agged, ABCDataFrame) and len(agged.columns):
        agged = agged.dropna(how="all")

        # gh-21133
        # we want to down cast if
        # the original values are ints
        # as we grouped with a NaN value
        # and then dropped, coercing to floats
        for v in values:
            if (
                v in data
                and is_integer_dtype(data[v])
                and v in agged
                and not is_integer_dtype(agged[v])
            ):
                agged[v] = maybe_downcast_to_dtype(agged[v], data[v].dtype)

    table = agged
    if table.index.nlevels > 1:
        # Related GH #17123
        # If index_names are integers, determine whether the integers refer
        # to the level position or name.
        index_names = agged.index.names[: len(index)]
        to_unstack = []
        for i in range(len(index), len(keys)):
            name = agged.index.names[i]
            if name is None or name in index_names:
                to_unstack.append(i)
            else:
                to_unstack.append(name)
        table = agged.unstack(to_unstack)

    if not dropna:
        # Reindex to the full cartesian product of the index/column levels so
        # that unobserved combinations show up (as NaN) instead of being
        # silently dropped.
        if table.index.nlevels > 1:
            m = MultiIndex.from_arrays(
                cartesian_product(table.index.levels), names=table.index.names
            )
            table = table.reindex(m, axis=0)

        if table.columns.nlevels > 1:
            m = MultiIndex.from_arrays(
                cartesian_product(table.columns.levels), names=table.columns.names
            )
            table = table.reindex(m, axis=1)

    if isinstance(table, ABCDataFrame):
        table = table.sort_index(axis=1)

    if fill_value is not None:
        _table = table.fillna(fill_value, downcast="infer")
        assert _table is not None  # needed for mypy
        table = _table

    if margins:
        # Margins are computed from rows with no missing keys when
        # dropna=True, matching the aggregated body of the table.
        if dropna:
            data = data[data.notna().all(axis=1)]
        table = _add_margins(
            table,
            data,
            values,
            rows=index,
            cols=columns,
            aggfunc=aggfunc,
            observed=dropna,
            margins_name=margins_name,
            fill_value=fill_value,
        )

    # discard the top level
    if (
        values_passed
        and not values_multi
        and not table.empty
        and (table.columns.nlevels > 1)
    ):
        table = table[values[0]]

    if len(index) == 0 and len(columns) > 0:
        table = table.T

    # GH 15193 Make sure empty columns are removed if dropna=True
    if isinstance(table, ABCDataFrame) and dropna:
        table = table.dropna(how="all", axis=1)

    return table

187 

188 

def _add_margins(
    table: Union["Series", "DataFrame"],
    data,
    values,
    rows,
    cols,
    aggfunc,
    observed=None,
    margins_name: str = "All",
    fill_value=None,
):
    """
    Append margin (subtotal) rows/columns labelled ``margins_name``.

    Parameters
    ----------
    table : Series or DataFrame
        The pivoted result to which margins are added.
    data : DataFrame
        The original (pre-pivot) data, used to compute the margins.
    values, rows, cols, aggfunc, observed, margins_name, fill_value
        As in ``pivot_table``.

    Raises
    ------
    ValueError
        If ``margins_name`` is not a string, or it conflicts with an
        existing index/column label of ``table``.
    """
    if not isinstance(margins_name, str):
        raise ValueError("margins_name argument must be a string")

    msg = 'Conflicting name "{name}" in margins'.format(name=margins_name)
    for level in table.index.names:
        if margins_name in table.index.get_level_values(level):
            raise ValueError(msg)

    grand_margin = _compute_grand_margin(data, values, aggfunc, margins_name)

    if table.ndim == 2:
        # i.e. DataFrame
        for level in table.columns.names[1:]:
            if margins_name in table.columns.get_level_values(level):
                raise ValueError(msg)

    # Label for the margin row: padded with "" when the row index has
    # multiple levels so it aligns with the MultiIndex.
    key: Union[str, Tuple[str, ...]]
    if len(rows) > 1:
        key = (margins_name,) + ("",) * (len(rows) - 1)
    else:
        key = margins_name

    if not values and isinstance(table, ABCSeries):
        # If there are no values and the table is a series, then there is only
        # one column in the data. Compute grand margin and return it.
        return table.append(Series({key: grand_margin[margins_name]}))

    elif values:
        marginal_result_set = _generate_marginal_results(
            table,
            data,
            values,
            rows,
            cols,
            aggfunc,
            observed,
            grand_margin,
            margins_name,
        )
        # A non-tuple result means margins were fully assembled downstream.
        if not isinstance(marginal_result_set, tuple):
            return marginal_result_set
        result, margin_keys, row_margin = marginal_result_set
    else:
        # no values, and table is a DataFrame
        assert isinstance(table, ABCDataFrame)
        marginal_result_set = _generate_marginal_results_without_values(
            table, data, rows, cols, aggfunc, observed, margins_name
        )
        if not isinstance(marginal_result_set, tuple):
            return marginal_result_set
        result, margin_keys, row_margin = marginal_result_set

    row_margin = row_margin.reindex(result.columns, fill_value=fill_value)
    # populate grand margin
    for k in margin_keys:
        if isinstance(k, str):
            row_margin[k] = grand_margin[k]
        else:
            # Tuple keys: the first element is the value-column name.
            row_margin[k] = grand_margin[k[0]]

    from pandas import DataFrame

    margin_dummy = DataFrame(row_margin, columns=[key]).T

    row_names = result.index.names
    try:
        # check the result column and leave floats
        for dtype in set(result.dtypes):
            cols = result.select_dtypes([dtype]).columns
            margin_dummy[cols] = margin_dummy[cols].apply(
                maybe_downcast_to_dtype, args=(dtype,)
            )
        result = result.append(margin_dummy)
    except TypeError:

        # we cannot reshape, so coerce the axis
        result.index = result.index._to_safe_for_reshape()
        result = result.append(margin_dummy)
    result.index.names = row_names

    return result

281 

282 

283def _compute_grand_margin(data, values, aggfunc, margins_name: str = "All"): 

284 

285 if values: 

286 grand_margin = {} 

287 for k, v in data[values].items(): 

288 try: 

289 if isinstance(aggfunc, str): 

290 grand_margin[k] = getattr(v, aggfunc)() 

291 elif isinstance(aggfunc, dict): 

292 if isinstance(aggfunc[k], str): 

293 grand_margin[k] = getattr(v, aggfunc[k])() 

294 else: 

295 grand_margin[k] = aggfunc[k](v) 

296 else: 

297 grand_margin[k] = aggfunc(v) 

298 except TypeError: 

299 pass 

300 return grand_margin 

301 else: 

302 return {margins_name: aggfunc(data.index)} 

303 

304 

def _generate_marginal_results(
    table,
    data,
    values,
    rows,
    cols,
    aggfunc,
    observed,
    grand_margin,
    margins_name: str = "All",
):
    """
    Build the margin pieces for a pivot table that has value columns.

    Returns either the assembled result alone (when there are column
    groupers but no row groupers — margins are already interleaved), or a
    3-tuple ``(result, margin_keys, row_margin)`` for the caller
    (``_add_margins``) to finish assembling.
    """
    if len(cols) > 0:
        # need to "interleave" the margins
        table_pieces = []
        margin_keys = []

        def _all_key(key):
            # Column key for the margin under value-column `key`, padded
            # with "" to match the depth of the column MultiIndex.
            return (key, margins_name) + ("",) * (len(cols) - 1)

        if len(rows) > 0:
            margin = data[rows + values].groupby(rows, observed=observed).agg(aggfunc)
            cat_axis = 1

            for key, piece in table.groupby(level=0, axis=cat_axis, observed=observed):
                all_key = _all_key(key)

                # we are going to mutate this, so need to copy!
                piece = piece.copy()
                try:
                    piece[all_key] = margin[key]
                except TypeError:

                    # we cannot reshape, so coerce the axis
                    piece.set_axis(
                        piece._get_axis(cat_axis)._to_safe_for_reshape(),
                        axis=cat_axis,
                        inplace=True,
                    )
                    piece[all_key] = margin[key]

                table_pieces.append(piece)
                margin_keys.append(all_key)
        else:
            # No row groupers: use the precomputed grand margin values.
            margin = grand_margin
            cat_axis = 0
            for key, piece in table.groupby(level=0, axis=cat_axis, observed=observed):
                all_key = _all_key(key)
                table_pieces.append(piece)
                table_pieces.append(Series(margin[key], index=[all_key]))
                margin_keys.append(all_key)

        result = concat(table_pieces, axis=cat_axis)

        if len(rows) == 0:
            # Margins are already part of `result`; no row margin needed.
            return result
    else:
        result = table
        margin_keys = table.columns

    if len(cols) > 0:
        row_margin = data[cols + values].groupby(cols, observed=observed).agg(aggfunc)
        row_margin = row_margin.stack()

        # slight hack
        # Move the stacked (last) level to the front so the row margin's
        # index ordering matches the result's columns.
        new_order = [len(cols)] + list(range(len(cols)))
        row_margin.index = row_margin.index.reorder_levels(new_order)
    else:
        row_margin = Series(np.nan, index=result.columns)

    return result, margin_keys, row_margin

375 

376 

377def _generate_marginal_results_without_values( 

378 table: "DataFrame", data, rows, cols, aggfunc, observed, margins_name: str = "All" 

379): 

380 if len(cols) > 0: 

381 # need to "interleave" the margins 

382 margin_keys = [] 

383 

384 def _all_key(): 

385 if len(cols) == 1: 

386 return margins_name 

387 return (margins_name,) + ("",) * (len(cols) - 1) 

388 

389 if len(rows) > 0: 

390 margin = data[rows].groupby(rows, observed=observed).apply(aggfunc) 

391 all_key = _all_key() 

392 table[all_key] = margin 

393 result = table 

394 margin_keys.append(all_key) 

395 

396 else: 

397 margin = data.groupby(level=0, axis=0, observed=observed).apply(aggfunc) 

398 all_key = _all_key() 

399 table[all_key] = margin 

400 result = table 

401 margin_keys.append(all_key) 

402 return result 

403 else: 

404 result = table 

405 margin_keys = table.columns 

406 

407 if len(cols): 

408 row_margin = data[cols].groupby(cols, observed=observed).apply(aggfunc) 

409 else: 

410 row_margin = Series(np.nan, index=result.columns) 

411 

412 return result, margin_keys, row_margin 

413 

414 

415def _convert_by(by): 

416 if by is None: 

417 by = [] 

418 elif ( 

419 is_scalar(by) 

420 or isinstance(by, (np.ndarray, Index, ABCSeries, Grouper)) 

421 or hasattr(by, "__call__") 

422 ): 

423 by = [by] 

424 else: 

425 by = list(by) 

426 return by 

427 

428 

@Substitution("\ndata : DataFrame")
@Appender(_shared_docs["pivot"], indents=1)
def pivot(data: "DataFrame", index=None, columns=None, values=None) -> "DataFrame":
    # Public docstring comes from @Appender(_shared_docs["pivot"]).
    if values is None:
        # No explicit values: push the grouping columns into the index and
        # let unstack spread every remaining column.
        append_to_existing = index is None
        keys = [columns] if append_to_existing else [index, columns]
        indexed = data.set_index(keys, append=append_to_existing)
    else:
        # Build the (index, columns) MultiIndex by hand and attach the
        # requested values to it.
        index_values = data.index if index is None else data[index]
        combined = MultiIndex.from_arrays([index_values, data[columns]])

        if is_list_like(values) and not isinstance(values, tuple):
            # Exclude tuple because it is seen as a single column name
            indexed = data._constructor(
                data[values].values, index=combined, columns=values
            )
        else:
            indexed = data._constructor_sliced(data[values].values, index=combined)
    return indexed.unstack(columns)

451 

452 

def crosstab(
    index,
    columns,
    values=None,
    rownames=None,
    colnames=None,
    aggfunc=None,
    margins=False,
    margins_name: str = "All",
    dropna: bool = True,
    normalize=False,
) -> "DataFrame":
    """
    Compute a simple cross tabulation of two (or more) factors. By default
    computes a frequency table of the factors unless an array of values and an
    aggregation function are passed.

    Parameters
    ----------
    index : array-like, Series, or list of arrays/Series
        Values to group by in the rows.
    columns : array-like, Series, or list of arrays/Series
        Values to group by in the columns.
    values : array-like, optional
        Array of values to aggregate according to the factors.
        Requires `aggfunc` be specified.
    rownames : sequence, default None
        If passed, must match number of row arrays passed.
    colnames : sequence, default None
        If passed, must match number of column arrays passed.
    aggfunc : function, optional
        If specified, requires `values` be specified as well.
    margins : bool, default False
        Add row/column margins (subtotals).
    margins_name : str, default 'All'
        Name of the row/column that will contain the totals
        when margins is True.

        .. versionadded:: 0.21.0

    dropna : bool, default True
        Do not include columns whose entries are all NaN.
    normalize : bool, {'all', 'index', 'columns'}, or {0,1}, default False
        Normalize by dividing all values by the sum of values.

        - If passed 'all' or `True`, will normalize over all values.
        - If passed 'index' will normalize over each row.
        - If passed 'columns' will normalize over each column.
        - If margins is `True`, will also normalize margin values.

    Returns
    -------
    DataFrame
        Cross tabulation of the data.

    See Also
    --------
    DataFrame.pivot : Reshape data based on column values.
    pivot_table : Create a pivot table as a DataFrame.

    Notes
    -----
    Any Series passed will have their name attributes used unless row or column
    names for the cross-tabulation are specified.

    Any input passed containing Categorical data will have **all** of its
    categories included in the cross-tabulation, even if the actual data does
    not contain any instances of a particular category.

    In the event that there aren't overlapping indexes an empty DataFrame will
    be returned.

    Examples
    --------
    >>> a = np.array(["foo", "foo", "foo", "foo", "bar", "bar",
    ...               "bar", "bar", "foo", "foo", "foo"], dtype=object)
    >>> b = np.array(["one", "one", "one", "two", "one", "one",
    ...               "one", "two", "two", "two", "one"], dtype=object)
    >>> c = np.array(["dull", "dull", "shiny", "dull", "dull", "shiny",
    ...               "shiny", "dull", "shiny", "shiny", "shiny"],
    ...              dtype=object)
    >>> pd.crosstab(a, [b, c], rownames=['a'], colnames=['b', 'c'])
    b   one        two
    c   dull shiny dull shiny
    a
    bar    1     2    1     0
    foo    2     2    1     2

    Here 'c' and 'f' are not represented in the data and will not be
    shown in the output because dropna is True by default. Set
    dropna=False to preserve categories with no data.

    >>> foo = pd.Categorical(['a', 'b'], categories=['a', 'b', 'c'])
    >>> bar = pd.Categorical(['d', 'e'], categories=['d', 'e', 'f'])
    >>> pd.crosstab(foo, bar)
    col_0  d  e
    row_0
    a      1  0
    b      0  1
    >>> pd.crosstab(foo, bar, dropna=False)
    col_0  d  e  f
    row_0
    a      1  0  0
    b      0  1  0
    c      0  0  0
    """
    index = com.maybe_make_list(index)
    columns = com.maybe_make_list(columns)

    rownames = _get_names(index, rownames, prefix="row")
    colnames = _get_names(columns, colnames, prefix="col")

    # When Series/DataFrames are passed, align everything on their
    # intersected index.
    common_idx = None
    pass_objs = [
        obj for obj in index + columns if isinstance(obj, (ABCSeries, ABCDataFrame))
    ]
    if pass_objs:
        common_idx = get_objs_combined_axis(pass_objs, intersect=True, sort=False)

    data: Dict = dict(zip(rownames, index))
    data.update(zip(colnames, columns))

    if values is None and aggfunc is not None:
        raise ValueError("aggfunc cannot be used without values.")

    if values is not None and aggfunc is None:
        raise ValueError("values cannot be used without an aggfunc.")

    from pandas import DataFrame

    df = DataFrame(data, index=common_idx)
    if values is None:
        # Frequency table: count rows via a constant dummy column.
        df["__dummy__"] = 0
        pivot_kwargs = {"aggfunc": len, "fill_value": 0}
    else:
        df["__dummy__"] = values
        pivot_kwargs = {"aggfunc": aggfunc}

    table = df.pivot_table(
        "__dummy__",
        index=rownames,
        columns=colnames,
        margins=margins,
        margins_name=margins_name,
        dropna=dropna,
        **pivot_kwargs,
    )

    # Post-process
    if normalize is not False:
        table = _normalize(
            table, normalize=normalize, margins=margins, margins_name=margins_name
        )

    return table

608 

609 

def _normalize(table, normalize, margins: bool, margins_name="All"):
    """
    Normalize a crosstab result according to ``normalize``.

    Parameters
    ----------
    table : DataFrame
        The pivoted frequency table.
    normalize : bool, {'all', 'index', 'columns'} or {0, 1}
        Normalization mode; 0/1 are mapped to 'index'/'columns'.
    margins : bool
        Whether ``table`` contains margin row/column, which are then
        normalized separately from the core table.
    margins_name : str, default "All"
        Label of the margin row/column.

    Raises
    ------
    ValueError
        For an unrecognized ``normalize``/``margins`` argument, or when the
        margin labels are missing from ``table``.
    """

    if not isinstance(normalize, (bool, str)):
        # Map axis numbers to their string aliases.
        axis_subs = {0: "index", 1: "columns"}
        try:
            normalize = axis_subs[normalize]
        except KeyError:
            raise ValueError("Not a valid normalize argument")

    if margins is False:

        # Actual Normalizations
        normalizers: Dict[Union[bool, str], Callable] = {
            "all": lambda x: x / x.sum(axis=1).sum(axis=0),
            "columns": lambda x: x / x.sum(),
            "index": lambda x: x.div(x.sum(axis=1), axis=0),
        }

        # True is an alias for "all".
        normalizers[True] = normalizers["all"]

        try:
            f = normalizers[normalize]
        except KeyError:
            raise ValueError("Not a valid normalize argument")

        table = f(table)
        table = table.fillna(0)

    elif margins is True:
        # keep index and column of pivoted table
        table_index = table.index
        table_columns = table.columns

        # check if margin name is in (for MI cases) or equal to last
        # index/column and save the column and index margin
        # NOTE(review): the row check uses `not in` while the column check
        # uses `!=`, joined by bitwise `|` -- looks asymmetric; presumably
        # this accommodates MultiIndex rows. Confirm before changing.
        if (margins_name not in table.iloc[-1, :].name) | (
            margins_name != table.iloc[:, -1].name
        ):
            raise ValueError(
                "{mname} not in pivoted DataFrame".format(mname=margins_name)
            )
        column_margin = table.iloc[:-1, -1]
        index_margin = table.iloc[-1, :-1]

        # keep the core table
        table = table.iloc[:-1, :-1]

        # Normalize core
        table = _normalize(table, normalize=normalize, margins=False)

        # Fix Margins
        if normalize == "columns":
            column_margin = column_margin / column_margin.sum()
            table = concat([table, column_margin], axis=1)
            table = table.fillna(0)
            table.columns = table_columns

        elif normalize == "index":
            index_margin = index_margin / index_margin.sum()
            table = table.append(index_margin)
            table = table.fillna(0)
            table.index = table_index

        elif normalize == "all" or normalize is True:
            column_margin = column_margin / column_margin.sum()
            index_margin = index_margin / index_margin.sum()
            # The grand-total cell normalizes to 1 by definition.
            index_margin.loc[margins_name] = 1
            table = concat([table, column_margin], axis=1)
            table = table.append(index_margin)

            table = table.fillna(0)
            table.index = table_index
            table.columns = table_columns

        else:
            raise ValueError("Not a valid normalize argument")

    else:
        raise ValueError("Not a valid margins argument")

    return table

691 

692 

693def _get_names(arrs, names, prefix: str = "row"): 

694 if names is None: 

695 names = [] 

696 for i, arr in enumerate(arrs): 

697 if isinstance(arr, ABCSeries) and arr.name is not None: 

698 names.append(arr.name) 

699 else: 

700 names.append("{prefix}_{i}".format(prefix=prefix, i=i)) 

701 else: 

702 if len(names) != len(arrs): 

703 raise AssertionError("arrays and names must have the same length") 

704 if not isinstance(names, list): 

705 names = list(names) 

706 

707 return names