Coverage for /Users/sebastiana/Documents/Sugarpills/confidence/spotify_confidence/analysis/frequentist/chartify_grapher.py: 15%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

213 statements  

1# Copyright 2017-2020 Spotify AB 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

from typing import Union, Iterable, Tuple

import numpy as np
from bokeh.models import tools
from chartify import Chart
from pandas import DataFrame, concat
from pandas.api.types import is_datetime64_any_dtype

from ..abstract_base_classes.confidence_grapher_abc import ConfidenceGrapherABC
from ..confidence_utils import (
    axis_format_precision,
    add_color_column,
    get_remaning_groups,
    get_all_group_columns,
    listify,
    level2str,
    to_finite,
)
from ..constants import (
    POINT_ESTIMATE,
    DIFFERENCE,
    CI_LOWER,
    CI_UPPER,
    P_VALUE,
    ADJUSTED_LOWER,
    ADJUSTED_UPPER,
    ADJUSTED_P,
    NULL_HYPOTHESIS,
    NIM,
    NIM_TYPE,
)
from ...chartgrid import ChartGrid

46 

47 

48class ChartifyGrapher(ConfidenceGrapherABC): 

def __init__(
    self,
    data_frame: DataFrame,
    numerator_column: str,
    denominator_column: str,
    categorical_group_columns: str,
    ordinal_group_column: str,
):
    """Remember the data frame and the column names the plots are built from.

    Args:
        data_frame: Frame stored as ``self._df``; ``_ordinal_type`` reads the
            ordinal column's dtype from it.
        numerator_column: Name used in titles and axis labels.
        denominator_column: Name used in titles and axis labels.
        categorical_group_columns: Categorical grouping column name(s).
        ordinal_group_column: Ordinal grouping column name; other methods
            compare it against ``None``.
    """
    self._df = data_frame
    self._numerator, self._denominator = numerator_column, denominator_column
    self._categorical_group_columns = categorical_group_columns
    self._ordinal_group_column = ordinal_group_column
    # Combined grouping columns, as computed by the shared utility.
    self._all_group_columns = get_all_group_columns(categorical_group_columns, ordinal_group_column)

64 

def plot_summary(self, summary_df: DataFrame, groupby: Union[str, Iterable]) -> ChartGrid:
    """Build one summary chart per ``groupby`` level and collect them in a grid.

    Args:
        summary_df: Frame holding the summary rows to plot.
        groupby: Column(s) to split ``summary_df`` on. ``None`` yields a
            single chart over the whole frame.

    Returns:
        A ``ChartGrid`` whose ``charts`` list holds the created charts.
    """
    grid = ChartGrid()
    if groupby is None:
        # No split requested: treat the whole frame as one unnamed level.
        levels = [(None, summary_df)]
    else:
        levels = summary_df.groupby(groupby)
    for name, frame in levels:
        grid.charts.append(self._summary_plot(level_name=name, level_df=frame, groupby=groupby))
    return grid

74 

def plot_difference(
    self, difference_df, absolute, groupby, nims: NIM_TYPE, use_adjusted_intervals: bool
) -> ChartGrid:
    """Plot a single difference, as a line chart or an interval chart.

    The ordinal (line) variant is used when the ordinal group column is one
    of the ``groupby`` columns; otherwise the categorical variant is used.
    ``nims`` is accepted but not read by this method.

    Returns:
        A ``ChartGrid`` containing the resulting chart(s).
    """
    if self._ordinal_group_column not in listify(groupby):
        return self._categorical_difference_plot(difference_df, absolute, groupby, use_adjusted_intervals)

    line_chart = self._ordinal_difference_plot(difference_df, absolute, groupby, use_adjusted_intervals)
    return ChartGrid([line_chart])

84 

def plot_differences(
    self, difference_df, absolute, groupby, nims: NIM_TYPE, use_adjusted_intervals: bool
) -> ChartGrid:
    """Plot several differences, adding ``level_1``/``level_2`` to the grouping.

    The ordinal column is removed from ``groupby`` and the two level columns
    are appended before delegating to the ordinal or categorical plot.
    ``nims`` is accepted but not read by this method.

    Returns:
        A ``ChartGrid`` containing the resulting chart(s).
    """
    non_ordinal_groups = get_remaning_groups(groupby, self._ordinal_group_column)
    columns_with_levels = self._add_level_columns(non_ordinal_groups)

    if self._ordinal_group_column not in listify(groupby):
        return self._categorical_difference_plot(
            difference_df, absolute, columns_with_levels, use_adjusted_intervals
        )

    line_chart = self._ordinal_difference_plot(
        difference_df, absolute, columns_with_levels, use_adjusted_intervals
    )
    return ChartGrid([line_chart])

100 

def plot_multiple_difference(
    self,
    difference_df,
    absolute,
    groupby,
    level_as_reference,
    nims: NIM_TYPE,
    use_adjusted_intervals: bool,
) -> ChartGrid:
    """Plot comparisons of several levels against one reference level.

    Delegates to the ordinal variant when the ordinal group column is part of
    ``groupby`` and to the categorical variant otherwise. ``nims`` is accepted
    but not read by this method.

    Returns:
        A ``ChartGrid`` containing the resulting chart(s).
    """
    if self._ordinal_group_column not in listify(groupby):
        return self._categorical_multiple_difference_plot(
            difference_df, absolute, groupby, level_as_reference, use_adjusted_intervals
        )

    line_chart = self._ordinal_multiple_difference_plot(
        difference_df, absolute, groupby, level_as_reference, use_adjusted_intervals
    )
    return ChartGrid([line_chart])

120 

def _ordinal_difference_plot(
    self, difference_df: DataFrame, absolute: bool, groupby: Union[str, Iterable], use_adjusted_intervals: bool
) -> Chart:
    """Draw the difference between two levels as a line over the ordinal column.

    Args:
        difference_df: Frame with the difference rows; ``level_1`` and
            ``level_2`` are read for the title.
        absolute: Selects the "Absolute"/"Relative" axis label and is
            forwarded to ``_ordinal_plot``.
        groupby: Grouping column(s); the ordinal column is removed before the
            remaining ones are used for colouring.
        use_adjusted_intervals: Plot adjusted rather than plain intervals.

    Returns:
        The chart, with a callout line drawn at 0.
    """
    remaining_groups = get_remaning_groups(groupby, self._ordinal_group_column)

    if "level_1" in groupby and "level_2" in groupby:
        # Levels vary across the plotted rows, so use a generic title.
        title = "Change from level_1 to level_2"
    else:
        title = "Change from {} to {}".format(
            difference_df["level_1"].values[0], difference_df["level_2"].values[0]
        )

    y_axis_label = self._get_difference_plot_label(absolute)
    ch = self._ordinal_plot(
        # Use the shared constant (not a string literal) so this stays in step
        # with the checks against DIFFERENCE made when the hover tools are added.
        DIFFERENCE,
        difference_df,
        groupby=None,
        level_name="",
        remaining_groups=remaining_groups,
        absolute=absolute,
        title=title,
        y_axis_label=y_axis_label,
        use_adjusted_intervals=use_adjusted_intervals,
    )
    ch.callout.line(0)

    return ch

148 

149 def _get_difference_plot_label(self, absolute): 

150 change_type = "Absolute" if absolute else "Relative" 

151 return change_type + " change in {} / {}".format(self._numerator, self._denominator) 

152 

def _categorical_difference_plot(
    self, difference_df: DataFrame, absolute: bool, groupby: Union[str, Iterable], use_adjusted_intervals: bool
) -> ChartGrid:
    """Draw differences as an interval chart over categorical groups.

    Args:
        difference_df: Frame with the difference rows; ``level_1`` and
            ``level_2`` are read for the title.
        absolute: Forwarded to the chart builder (axis label and formatting).
        groupby: Column(s) forming the categorical x axis. ``None`` plots a
            single category labelled "Difference".
        use_adjusted_intervals: Plot adjusted rather than plain intervals.

    Returns:
        The ``ChartGrid`` produced by ``_categorical_difference_chart``.
    """
    # The x label must be derived from the caller's groupby *before* the
    # placeholder column is substituted; otherwise the placeholder name
    # would always be shown on the axis.
    x_label = "" if groupby is None else "{}".format(groupby)

    if groupby is None:
        groupby = "dummy_groupby"
        # assign() returns a new frame, so the caller's DataFrame is not
        # modified by adding the placeholder column.
        difference_df = difference_df.assign(**{groupby: "Difference"})

    if "level_1" in groupby and "level_2" in groupby:
        # Levels vary across the plotted rows, so use a generic title.
        title = "Change from level_1 to level_2"
    else:
        title = "Change from {} to {}".format(
            difference_df["level_1"].values[0], difference_df["level_2"].values[0]
        )

    chart_grid = self._categorical_difference_chart(
        absolute, difference_df, groupby, title, x_label, use_adjusted_intervals
    )

    return chart_grid

173 

def _categorical_difference_chart(
    self,
    absolute: bool,
    difference_df: DataFrame,
    groupby_columns: Union[str, Iterable],
    title: str,
    x_label: str,
    use_adjusted_intervals: bool,
) -> ChartGrid:
    """Build the interval chart used by the categorical difference plots.

    Args:
        absolute: Forwarded to ``axis_format_precision`` and the axis label.
        difference_df: Frame with the interval bounds, ``DIFFERENCE``,
            ``level_1``/``level_2`` and, optionally, ``NULL_HYPOTHESIS``.
        groupby_columns: Column(s) forming the categorical x axis.
        title: Chart title.
        x_label: Label for the x axis.
        use_adjusted_intervals: Use the adjusted bounds instead of the plain
            confidence-interval bounds.

    Returns:
        A ``ChartGrid`` containing the single chart.
    """
    LOWER, UPPER = (ADJUSTED_LOWER, ADJUSTED_UPPER) if use_adjusted_intervals else (CI_LOWER, CI_UPPER)

    # Series.append was removed in pandas 2.0; concat is the replacement.
    plotted_values = [difference_df[LOWER], difference_df[DIFFERENCE], difference_df[UPPER]]
    if NULL_HYPOTHESIS in difference_df.columns:
        plotted_values.append(difference_df[NULL_HYPOTHESIS])
    axis_format, y_min, y_max = axis_format_precision(numbers=concat(plotted_values), absolute=absolute)

    df = (
        difference_df.assign(**{LOWER: to_finite(difference_df[LOWER], y_min)})
        .assign(**{UPPER: to_finite(difference_df[UPPER], y_max)})
        .assign(level_1=difference_df.level_1.map(level2str))
        .assign(level_2=difference_df.level_2.map(level2str))
        .set_index(groupby_columns)
        .assign(categorical_x=lambda df: df.index.to_numpy())
        .reset_index()
    )

    ch = Chart(x_axis_type="categorical")
    ch.plot.interval(
        data_frame=df.sort_values(groupby_columns),
        categorical_columns=groupby_columns,
        lower_bound_column=LOWER,
        upper_bound_column=UPPER,
        middle_column=DIFFERENCE,
        categorical_order_by="labels",
        categorical_order_ascending=False,
    )
    # Also plot transparent circles, just to be able to show hover box
    ch.style.color_palette.reset_palette_order()
    ch.figure.circle(
        source=df, x="categorical_x", y=DIFFERENCE, size=20, name="center", line_alpha=0, fill_alpha=0
    )
    if NULL_HYPOTHESIS in df.columns:
        ch.style.color_palette.reset_palette_order()
        nim_rows = df[~df[NIM].isna()]
        # Red when the null hypothesis lies strictly inside the interval,
        # green otherwise. Computed column-wise so that an empty selection
        # (no row has a NIM) still yields a valid, empty colour column.
        null_inside_interval = (nim_rows[LOWER] < nim_rows[NULL_HYPOTHESIS]) & (
            nim_rows[NULL_HYPOTHESIS] < nim_rows[UPPER]
        )
        dash_source = nim_rows.assign(color_column=np.where(null_inside_interval, "red", "green")).sort_values(
            groupby_columns
        )
        ch.figure.dash(
            source=dash_source,
            x="categorical_x",
            y=NULL_HYPOTHESIS,
            size=320 / len(df),
            line_width=3,
            name="nim",
            line_color="color_column",
        )
    ch.axes.set_yaxis_label(self._get_difference_plot_label(absolute))
    ch.set_source_label("")
    ch.callout.line(0)
    # Pad the y range by 5% of its span on each side.
    ch.axes.set_yaxis_range(y_min - 0.05 * (y_max - y_min), y_max + 0.05 * (y_max - y_min))
    ch.axes.set_yaxis_tick_format(axis_format)
    ch.set_title(title)
    ch.axes.set_xaxis_label(x_label)
    ch.set_subtitle("")

    # The hover data is taken from the unclipped input frame, not from `df`.
    self.add_tools(
        chart=ch,
        df=(
            difference_df.set_index(groupby_columns)
            .assign(categorical_x=lambda df: df.index.to_numpy())
            .reset_index()
        ),
        center_name=DIFFERENCE,
        absolute=absolute,
        ordinal=False,
        use_adjusted_intervals=use_adjusted_intervals,
    )

    chart_grid = ChartGrid()
    chart_grid.charts.append(ch)

    return chart_grid

268 

def _summary_plot(self, level_name: Union[str, Tuple], level_df: DataFrame, groupby: Union[str, Iterable]):
    """Create the summary chart for one ``groupby`` level.

    The ordinal (line) plot is used when an ordinal group column is set and
    is still among the groups left after removing ``groupby``; otherwise the
    categorical (interval) plot is used.
    """
    remaining_groups = get_remaning_groups(self._all_group_columns, groupby)
    ordinal_column = self._ordinal_group_column

    if ordinal_column is None or ordinal_column not in remaining_groups:
        return self._categorical_summary_plot(level_name, level_df, remaining_groups, groupby)
    return self._ordinal_summary_plot(level_name, level_df, remaining_groups, groupby)

277 

def _ordinal_summary_plot(
    self,
    level_name: Union[str, Tuple],
    level_df: DataFrame,
    remaining_groups: Union[str, Iterable],
    groupby: Union[str, Iterable],
):
    """Plot the point estimate as a line over the ordinal column.

    The ordinal column is removed from ``remaining_groups`` before plotting;
    plain (unadjusted) intervals and absolute formatting are always used.
    """
    ratio_label = "{} / {}".format(self._numerator, self._denominator)
    color_groups = get_remaning_groups(remaining_groups, self._ordinal_group_column)
    return self._ordinal_plot(
        center_name=POINT_ESTIMATE,
        level_df=level_df,
        groupby=groupby,
        level_name=level_name,
        remaining_groups=color_groups,
        absolute=True,
        title="Estimate of " + ratio_label,
        y_axis_label=ratio_label,
        use_adjusted_intervals=False,
    )

299 

def _ordinal_plot(
    self,
    center_name: str,
    level_df: DataFrame,
    groupby: Union[str, Iterable],
    level_name: Union[str, Tuple],
    remaining_groups: Union[str, Iterable],
    absolute: bool,
    title: str,
    y_axis_label: str,
    use_adjusted_intervals: bool,
):
    """Draw a line with an interval band over the ordinal column.

    Args:
        center_name: Column plotted as the line.
        level_df: Rows to plot.
        groupby: Used only for the subtitle; falsy gives an empty subtitle.
        level_name: Level value shown in the subtitle.
        remaining_groups: Passed to ``add_color_column``; when truthy the
            "color" column is used to colour lines and areas.
        absolute: Forwarded to ``axis_format_precision`` and ``add_tools``.
        title: Chart title.
        y_axis_label: Label for the y axis.
        use_adjusted_intervals: Use the adjusted bounds instead of the plain
            confidence-interval bounds.

    Returns:
        The chart, with hover and zoom tools attached.
    """
    LOWER, UPPER = (ADJUSTED_LOWER, ADJUSTED_UPPER) if use_adjusted_intervals else (CI_LOWER, CI_UPPER)
    df = add_color_column(level_df, remaining_groups)
    colors = "color" if remaining_groups else None

    # Series.append was removed in pandas 2.0; concat is the replacement.
    plotted_values = [df[LOWER], df[center_name], df[UPPER]]
    if NULL_HYPOTHESIS in df.columns:
        plotted_values.append(df[NULL_HYPOTHESIS])
    axis_format, y_min, y_max = axis_format_precision(numbers=concat(plotted_values), absolute=absolute)

    ch = Chart(x_axis_type=self._ordinal_type())
    ch.plot.line(
        data_frame=df.sort_values(self._ordinal_group_column),
        x_column=self._ordinal_group_column,
        y_column=center_name,
        color_column=colors,
    )
    # Restart the palette so the band colours line up with the line colours.
    ch.style.color_palette.reset_palette_order()
    ch.plot.area(
        data_frame=(
            df.assign(**{LOWER: to_finite(df[LOWER], y_min)})
            .assign(**{UPPER: to_finite(df[UPPER], y_max)})
            .sort_values(self._ordinal_group_column)
        ),
        x_column=self._ordinal_group_column,
        y_column=LOWER,
        second_y_column=UPPER,
        color_column=colors,
    )
    if NULL_HYPOTHESIS in df.columns:
        ch.style.color_palette.reset_palette_order()
        ch.plot.line(
            data_frame=df.sort_values(self._ordinal_group_column),
            x_column=self._ordinal_group_column,
            y_column=NULL_HYPOTHESIS,
            color_column=colors,
            line_dash="dashed",
            line_width=1,
        )
    ch.axes.set_yaxis_label(y_axis_label)
    ch.axes.set_xaxis_label(self._ordinal_group_column)
    ch.set_source_label("")
    # Pad the y range by 5% of its span on each side.
    ch.axes.set_yaxis_range(y_min - 0.05 * (y_max - y_min), y_max + 0.05 * (y_max - y_min))
    ch.axes.set_yaxis_tick_format(axis_format)
    subtitle = "" if not groupby else "{}: {}".format(groupby, level_name)
    ch.set_subtitle(subtitle)
    ch.set_title(title)
    if colors:
        ch.set_legend_location("outside_bottom")
    self.add_tools(
        chart=ch,
        df=df,
        center_name=center_name,
        absolute=absolute,
        ordinal=True,
        use_adjusted_intervals=use_adjusted_intervals,
    )
    return ch

372 

def _categorical_summary_plot(self, level_name, summary_df, remaining_groups, groupby):
    """Draw point estimates with confidence intervals over categorical groups.

    Args:
        level_name: Level value shown in the subtitle.
        summary_df: Rows to plot; must hold ``POINT_ESTIMATE``, ``CI_LOWER``
            and ``CI_UPPER``.
        remaining_groups: Column(s) forming the x axis. When empty, the
            ``groupby`` columns are used instead.
        groupby: Used for the subtitle and as the x-axis fallback.

    Returns:
        The chart, with hover and zoom tools attached.
    """
    if not remaining_groups:
        remaining_groups = listify(groupby)
    df = summary_df.set_index(remaining_groups).assign(categorical_x=lambda df: df.index.to_numpy()).reset_index()

    # Series.append was removed in pandas 2.0; concat is the replacement.
    axis_format, y_min, y_max = axis_format_precision(
        numbers=concat([df[CI_LOWER], df[POINT_ESTIMATE], df[CI_UPPER]]), absolute=True
    )

    ch = Chart(x_axis_type="categorical")
    ch.plot.interval(
        (
            df.assign(**{CI_LOWER: to_finite(df[CI_LOWER], y_min)}).assign(
                **{CI_UPPER: to_finite(df[CI_UPPER], y_max)}
            )
        ),
        categorical_columns=remaining_groups,
        lower_bound_column=CI_LOWER,
        upper_bound_column=CI_UPPER,
        middle_column=POINT_ESTIMATE,
        categorical_order_by="labels",
        categorical_order_ascending=True,
    )
    # Also plot transparent circles, just to be able to show hover box
    ch.style.color_palette.reset_palette_order()
    ch.figure.circle(
        source=df, x="categorical_x", y=POINT_ESTIMATE, size=20, name="center", line_alpha=0, fill_alpha=0
    )
    ch.set_title("Estimate of {} / {}".format(self._numerator, self._denominator))
    if groupby:
        ch.set_subtitle("{}: {}".format(groupby, level_name))
    else:
        ch.set_subtitle("")
    ch.axes.set_xaxis_label("{}".format(", ".join(remaining_groups)))
    ch.axes.set_yaxis_label("{} / {}".format(self._numerator, self._denominator))
    ch.set_source_label("")
    ch.axes.set_yaxis_tick_format(axis_format)
    self.add_tools(
        chart=ch, df=df, center_name=POINT_ESTIMATE, absolute=True, ordinal=False, use_adjusted_intervals=False
    )
    return ch

414 

415 def _ordinal_type(self): 

416 ordinal_column_type = self._df[self._ordinal_group_column].dtype.type 

417 axis_type = "datetime" if issubclass(ordinal_column_type, np.datetime64) else "linear" 

418 return axis_type 

419 

def _ordinal_multiple_difference_plot(
    self,
    difference_df: DataFrame,
    absolute: bool,
    groupby: Union[str, Iterable],
    level_as_reference: bool,
    use_adjusted_intervals: bool,
):
    """Draw comparisons against a reference level as lines over the ordinal column.

    The ordinal column is removed from ``groupby`` and the non-reference
    level column is added, so each compared level gets its own colour.

    Returns:
        The chart, with a callout line drawn at 0.
    """
    non_ordinal_groups = get_remaning_groups(groupby, self._ordinal_group_column)
    color_groups = self._add_level_column(non_ordinal_groups, level_as_reference)

    chart = self._ordinal_plot(
        DIFFERENCE,
        difference_df,
        groupby=None,
        level_name="",
        remaining_groups=color_groups,
        absolute=absolute,
        title=self._get_multiple_difference_title(difference_df, level_as_reference),
        y_axis_label=self._get_difference_plot_label(absolute),
        use_adjusted_intervals=use_adjusted_intervals,
    )
    chart.callout.line(0)
    return chart

445 

def _categorical_multiple_difference_plot(
    self,
    difference_df: DataFrame,
    absolute: bool,
    groupby: Union[str, Iterable],
    level_as_reference: bool,
    use_adjusted_intervals: bool,
):
    """Draw comparisons against a reference level as a categorical interval chart.

    The non-reference level column is added to ``groupby`` to form the x
    axis; the x label is empty when ``groupby`` is ``None``.

    Returns:
        The ``ChartGrid`` produced by ``_categorical_difference_chart``.
    """
    columns_with_level = self._add_level_column(groupby, level_as_reference)
    chart_title = self._get_multiple_difference_title(difference_df, level_as_reference)
    if groupby is None:
        axis_label = ""
    else:
        axis_label = "{}".format(groupby)

    return self._categorical_difference_chart(
        absolute, difference_df, columns_with_level, chart_title, axis_label, use_adjusted_intervals
    )

462 

463 def _get_multiple_difference_title(self, difference_df, level_as_reference): 

464 reference_level = "level_1" if level_as_reference else "level_2" 

465 title = "Comparison to {}".format(difference_df[reference_level].values[0]) 

466 return title 

467 

468 def _add_level_column(self, groupby, level_as_reference): 

469 level_column = "level_2" if level_as_reference else "level_1" 

470 if groupby is None: 

471 groupby_columns = level_column 

472 else: 

473 if isinstance(groupby, str): 

474 groupby_columns = [groupby, level_column] 

475 else: 

476 groupby_columns = groupby + [level_column] 

477 return groupby_columns 

478 

479 def _add_level_columns(self, groupby): 

480 levels = ["level_1", "level_2"] 

481 if groupby is None: 

482 groupby_columns = levels 

483 else: 

484 if isinstance(groupby, str): 

485 groupby_columns = [groupby] + levels 

486 else: 

487 groupby_columns = groupby + levels 

488 return groupby_columns 

489 

def add_ci_to_chart_datasources(
    self, chart: Chart, df: DataFrame, center_name: str, ordinal: bool, use_adjusted_intervals: bool
):
    """Copy interval, group and p-value columns from ``df`` into the chart's data.

    Every entry of ``chart.data`` that holds ``center_name`` or
    ``NULL_HYPOTHESIS`` gets the interval bounds and a "color" column, looked
    up in ``df`` through the entry's "index" values. Entries that also hold
    ``DIFFERENCE`` or ``NULL_HYPOTHESIS`` additionally get the difference,
    the p-values and, if ``df`` has it, the null hypothesis. The entries are
    modified in place and nothing is returned.
    """
    if use_adjusted_intervals:
        LOWER, UPPER = ADJUSTED_LOWER, ADJUSTED_UPPER
    else:
        LOWER, UPPER = CI_LOWER, CI_UPPER
    group_col = "color" if ordinal and "color" in df.columns else "categorical_x"

    for data in chart.data:
        has_null_hypothesis = NULL_HYPOTHESIS in data.keys()
        if not (center_name in data.keys() or has_null_hypothesis):
            continue

        index = data["index"]

        def lookup(column):
            # Values of `column` in df at this data source's row labels.
            return np.array(df[column][index])

        data[LOWER] = lookup(LOWER)
        data[UPPER] = lookup(UPPER)
        data["color"] = lookup(group_col)

        # NOTE(review): the source this was recovered from had lost its
        # indentation; this check is assumed to be nested inside the one
        # above - confirm against the original file.
        if not (DIFFERENCE in data.keys() or has_null_hypothesis):
            continue
        data[DIFFERENCE] = lookup(DIFFERENCE)
        data["p_value"] = lookup(P_VALUE)
        data["adjusted_p"] = lookup(ADJUSTED_P)
        if NULL_HYPOTHESIS in df.columns:
            data["null_hyp"] = lookup(NULL_HYPOTHESIS)

508 

def add_tools(
    self,
    chart: Chart,
    df: DataFrame,
    center_name: str,
    absolute: bool,
    ordinal: bool,
    use_adjusted_intervals: bool,
):
    """Attach hover, zoom, pan and reset tools to ``chart``.

    Args:
        chart: Chart to modify in place.
        df: Frame the hover values are taken from.
        center_name: Column shown as the central value; p-value tooltips are
            added only when it equals ``DIFFERENCE``.
        absolute: Forwarded to ``axis_format_precision`` for the tooltip
            number format.
        ordinal: When true the ordinal column is added to the tooltip and the
            hover tool is not restricted to named glyphs.
        use_adjusted_intervals: Show the adjusted bounds instead of the plain
            confidence-interval bounds.
    """
    self.add_ci_to_chart_datasources(chart, df, center_name, ordinal, use_adjusted_intervals)
    LOWER, UPPER = (ADJUSTED_LOWER, ADJUSTED_UPPER) if use_adjusted_intervals else (CI_LOWER, CI_UPPER)

    if len(chart.figure.legend) > 0:
        # Clicking a legend entry hides the corresponding glyphs.
        chart.figure.legend.click_policy = "hide"

    # Series.append was removed in pandas 2.0; concat is the replacement.
    hover_values = [df[LOWER], df[center_name], df[UPPER]]
    if NULL_HYPOTHESIS in df.columns:
        hover_values.append(df[NULL_HYPOTHESIS])
    # Only the number format is needed here; the y range is not used.
    axis_format, _, _ = axis_format_precision(
        numbers=concat(hover_values),
        absolute=absolute,
        extra_zeros=2,
    )

    ordinal_tool_tip = [] if not ordinal else [(self._ordinal_group_column, f"@{self._ordinal_group_column}")]
    p_value_tool_tip = (
        (
            [("p-value", "@p_value{0.0000}")]
            + ([("adjusted p-value", "@adjusted_p{0.0000}")] if len(df) > 1 else [])
        )
        if center_name == DIFFERENCE
        else []
    )
    nim_tool_tip = [("null hypothesis", f"@null_hyp{{{axis_format}}}")] if NULL_HYPOTHESIS in df.columns else []
    tooltips = (
        [("group", "@color")]
        + ordinal_tool_tip
        + [(f"{center_name}", f"@{center_name}{{{axis_format}}}")]
        + [
            (
                ("adjusted " if use_adjusted_intervals else "") + "confidence interval",
                f"(@{{{LOWER}}}{{{axis_format}}}," f" @{{{UPPER}}}{{{axis_format}}})",
            )
        ]
        + p_value_tool_tip
        + nim_tool_tip
    )
    # Categorical charts hover only on the glyphs named "center" and "nim".
    lines_with_hover = [] if ordinal else ["center", "nim"]
    hover = tools.HoverTool(tooltips=tooltips, names=lines_with_hover)
    box_zoom = tools.BoxZoomTool()

    chart.figure.add_tools(
        hover, tools.ZoomInTool(), tools.ZoomOutTool(), box_zoom, tools.PanTool(), tools.ResetTool()
    )
    chart.figure.toolbar.active_drag = box_zoom