1# Copyright 2017-2020 Spotify AB
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
from typing import Union, Iterable, Tuple

import numpy as np
import pandas as pd
from bokeh.models import tools
from chartify import Chart
from pandas import DataFrame

from ..abstract_base_classes.confidence_grapher_abc import ConfidenceGrapherABC
from ..confidence_utils import (
    axis_format_precision,
    add_color_column,
    get_remaning_groups,
    get_all_group_columns,
    listify,
    level2str,
    to_finite,
)
from ..constants import (
    POINT_ESTIMATE,
    DIFFERENCE,
    CI_LOWER,
    CI_UPPER,
    P_VALUE,
    ADJUSTED_LOWER,
    ADJUSTED_UPPER,
    ADJUSTED_P,
    NULL_HYPOTHESIS,
    NIM,
    NIM_TYPE,
)
from ...chartgrid import ChartGrid
46
47
class ChartifyGrapher(ConfidenceGrapherABC):
    """Draw summary and difference tables as Chartify charts.

    The grapher keeps the original data frame together with the names of the
    numerator, denominator and grouping columns. Those names are used for
    titles and axis labels, and the data frame is used to pick the x-axis type
    of ordinal plots.
    """

    def __init__(
        self,
        data_frame: DataFrame,
        numerator_column: str,
        denominator_column: str,
        categorical_group_columns: str,
        ordinal_group_column: str,
    ):
        """Store the data frame and the column names used by the plots."""
        self._df = data_frame
        self._numerator = numerator_column
        self._denominator = denominator_column
        self._categorical_group_columns = categorical_group_columns
        self._ordinal_group_column = ordinal_group_column
        self._all_group_columns = get_all_group_columns(self._categorical_group_columns, self._ordinal_group_column)

    # ------------------------------------------------------------------
    # Public plotting entry points
    # ------------------------------------------------------------------

    def plot_summary(self, summary_df: DataFrame, groupby: Union[str, Iterable]) -> ChartGrid:
        """Return a grid with one summary chart per ``groupby`` level.

        When ``groupby`` is None a single chart covering all of
        ``summary_df`` is produced.
        """
        ch = ChartGrid()
        if groupby is None:
            ch.charts.append(self._summary_plot(level_name=None, level_df=summary_df, groupby=groupby))
            return ch

        for level_name, level_df in summary_df.groupby(groupby):
            ch.charts.append(self._summary_plot(level_name=level_name, level_df=level_df, groupby=groupby))
        return ch

    def plot_difference(
        self, difference_df, absolute, groupby, nims: NIM_TYPE, use_adjusted_intervals: bool
    ) -> ChartGrid:
        """Plot one difference, ordinal if the ordinal column is in ``groupby``.

        ``nims`` is accepted for interface compatibility; it is not read here.
        """
        if self._ordinal_group_column in listify(groupby):
            ch = self._ordinal_difference_plot(difference_df, absolute, groupby, use_adjusted_intervals)
            return ChartGrid([ch])
        return self._categorical_difference_plot(difference_df, absolute, groupby, use_adjusted_intervals)

    def plot_differences(
        self, difference_df, absolute, groupby, nims: NIM_TYPE, use_adjusted_intervals: bool
    ) -> ChartGrid:
        """Plot several differences, adding ``level_1``/``level_2`` as groups.

        ``nims`` is accepted for interface compatibility; it is not read here.
        """
        remaining_groups = get_remaning_groups(groupby, self._ordinal_group_column)
        groupby_columns = self._add_level_columns(remaining_groups)

        if self._ordinal_group_column in listify(groupby):
            ch = self._ordinal_difference_plot(difference_df, absolute, groupby_columns, use_adjusted_intervals)
            return ChartGrid([ch])
        return self._categorical_difference_plot(difference_df, absolute, groupby_columns, use_adjusted_intervals)

    def plot_multiple_difference(
        self,
        difference_df,
        absolute,
        groupby,
        level_as_reference,
        nims: NIM_TYPE,
        use_adjusted_intervals: bool,
    ) -> ChartGrid:
        """Plot differences of several levels against one reference level.

        ``nims`` is accepted for interface compatibility; it is not read here.
        """
        if self._ordinal_group_column in listify(groupby):
            ch = self._ordinal_multiple_difference_plot(
                difference_df, absolute, groupby, level_as_reference, use_adjusted_intervals
            )
            return ChartGrid([ch])
        return self._categorical_multiple_difference_plot(
            difference_df, absolute, groupby, level_as_reference, use_adjusted_intervals
        )

    # ------------------------------------------------------------------
    # Small shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _stack_values(df: DataFrame, *columns: str, with_null_hypothesis: bool = True):
        """Concatenate the given columns of ``df`` into one Series.

        The NULL_HYPOTHESIS column is appended as well when it exists and
        ``with_null_hypothesis`` is true. ``pd.concat`` is used because
        ``Series.append`` no longer exists in pandas 2.0 and later.
        """
        stacked = [df[column] for column in columns]
        if with_null_hypothesis and NULL_HYPOTHESIS in df.columns:
            stacked.append(df[NULL_HYPOTHESIS])
        return pd.concat(stacked)

    @staticmethod
    def _with_categorical_x(df: DataFrame, columns) -> DataFrame:
        """Return ``df`` with a ``categorical_x`` column built from ``columns``.

        The column holds the values of the index obtained by indexing on
        ``columns``; the original columns are restored afterwards.
        """
        return df.set_index(columns).assign(categorical_x=lambda d: d.index.to_numpy()).reset_index()

    @staticmethod
    def _get_difference_title(difference_df: DataFrame, groupby) -> str:
        """Build the title used by the single-difference plots.

        A generic title is used when both level columns are part of
        ``groupby``; otherwise the first row's levels are named explicitly.
        """
        if groupby is not None and "level_1" in groupby and "level_2" in groupby:
            return "Change from level_1 to level_2"
        return "Change from {} to {}".format(difference_df["level_1"].values[0], difference_df["level_2"].values[0])

    def _get_difference_plot_label(self, absolute):
        """Return the y-axis label for a difference plot."""
        change_type = "Absolute" if absolute else "Relative"
        return change_type + " change in {} / {}".format(self._numerator, self._denominator)

    def _get_multiple_difference_title(self, difference_df, level_as_reference):
        """Return the title naming the reference level of the comparison."""
        reference_level = "level_1" if level_as_reference else "level_2"
        return "Comparison to {}".format(difference_df[reference_level].values[0])

    def _add_level_column(self, groupby, level_as_reference):
        """Append the non-reference level column to ``groupby``.

        Returns the bare column name when ``groupby`` is None, otherwise a
        new list. Any iterable of column names is accepted.
        """
        level_column = "level_2" if level_as_reference else "level_1"
        if groupby is None:
            return level_column
        if isinstance(groupby, str):
            return [groupby, level_column]
        # list(...) keeps the caller's object untouched and supports tuples.
        return list(groupby) + [level_column]

    def _add_level_columns(self, groupby):
        """Append both level columns to ``groupby`` and return a new list.

        Any iterable of column names is accepted.
        """
        levels = ["level_1", "level_2"]
        if groupby is None:
            return levels
        if isinstance(groupby, str):
            return [groupby] + levels
        return list(groupby) + levels

    def _ordinal_type(self):
        """Return ``"datetime"`` for a datetime64 ordinal column, else ``"linear"``."""
        ordinal_column_type = self._df[self._ordinal_group_column].dtype.type
        return "datetime" if issubclass(ordinal_column_type, np.datetime64) else "linear"

    # ------------------------------------------------------------------
    # Difference plots
    # ------------------------------------------------------------------

    def _ordinal_difference_plot(
        self, difference_df: DataFrame, absolute: bool, groupby: Union[str, Iterable], use_adjusted_intervals: bool
    ) -> Chart:
        """Plot a difference along the ordinal column, with a line at zero."""
        remaining_groups = get_remaning_groups(groupby, self._ordinal_group_column)
        title = self._get_difference_title(difference_df, groupby)
        y_axis_label = self._get_difference_plot_label(absolute)

        ch = self._ordinal_plot(
            # Use the shared constant, as the sibling multiple-difference plot
            # does, so add_tools recognises this as a difference plot.
            DIFFERENCE,
            difference_df,
            groupby=None,
            level_name="",
            remaining_groups=remaining_groups,
            absolute=absolute,
            title=title,
            y_axis_label=y_axis_label,
            use_adjusted_intervals=use_adjusted_intervals,
        )
        ch.callout.line(0)
        return ch

    def _categorical_difference_plot(
        self, difference_df: DataFrame, absolute: bool, groupby: Union[str, Iterable], use_adjusted_intervals: bool
    ) -> ChartGrid:
        """Plot a difference as intervals over categorical groups.

        When ``groupby`` is None a placeholder grouping column is added to a
        copy of ``difference_df`` so there is something to put on the x-axis;
        the caller's frame is not modified and the x-axis label stays empty.
        """
        # Decide the label before the placeholder replaces a missing groupby.
        x_label = "" if groupby is None else "{}".format(groupby)

        if groupby is None:
            groupby = "dummy_groupby"
            difference_df = difference_df.assign(**{groupby: "Difference"})

        title = self._get_difference_title(difference_df, groupby)
        return self._categorical_difference_chart(
            absolute, difference_df, groupby, title, x_label, use_adjusted_intervals
        )

    def _categorical_difference_chart(
        self,
        absolute: bool,
        difference_df: DataFrame,
        groupby_columns: Union[str, Iterable],
        title: str,
        x_label: str,
        use_adjusted_intervals: bool,
    ) -> ChartGrid:
        """Build the interval chart shared by the categorical difference plots.

        Returns a ChartGrid holding a single chart.
        """
        lower, upper = (ADJUSTED_LOWER, ADJUSTED_UPPER) if use_adjusted_intervals else (CI_LOWER, CI_UPPER)
        axis_format, y_min, y_max = axis_format_precision(
            numbers=self._stack_values(difference_df, lower, DIFFERENCE, upper),
            absolute=absolute,
        )

        # Pass the bounds through to_finite with the axis limits, and turn the
        # levels into strings, before drawing.
        clipped = (
            difference_df.assign(**{lower: to_finite(difference_df[lower], y_min)})
            .assign(**{upper: to_finite(difference_df[upper], y_max)})
            .assign(level_1=difference_df.level_1.map(level2str))
            .assign(level_2=difference_df.level_2.map(level2str))
        )
        df = self._with_categorical_x(clipped, groupby_columns)

        ch = Chart(x_axis_type="categorical")
        ch.plot.interval(
            data_frame=df.sort_values(groupby_columns),
            categorical_columns=groupby_columns,
            lower_bound_column=lower,
            upper_bound_column=upper,
            middle_column=DIFFERENCE,
            categorical_order_by="labels",
            categorical_order_ascending=False,
        )
        # Also plot transparent circles, just to be able to show hover box
        ch.style.color_palette.reset_palette_order()
        ch.figure.circle(
            source=df, x="categorical_x", y=DIFFERENCE, size=20, name="center", line_alpha=0, fill_alpha=0
        )

        if NULL_HYPOTHESIS in df.columns:
            ch.style.color_palette.reset_palette_order()

            def nim_color(row):
                # Red while the null hypothesis lies strictly inside the interval.
                inside = row[lower] < row[NULL_HYPOTHESIS] and row[NULL_HYPOTHESIS] < row[upper]
                return "red" if inside else "green"

            dash_source = (
                df[~df[NIM].isna()]
                .assign(color_column=lambda d: d.apply(nim_color, axis=1))
                .sort_values(groupby_columns)
            )
            ch.figure.dash(
                source=dash_source,
                x="categorical_x",
                y=NULL_HYPOTHESIS,
                size=320 / len(df),
                line_width=3,
                name="nim",
                line_color="color_column",
            )

        padding = 0.05 * (y_max - y_min)
        ch.axes.set_yaxis_label(self._get_difference_plot_label(absolute))
        ch.set_source_label("")
        ch.callout.line(0)
        ch.axes.set_yaxis_range(y_min - padding, y_max + padding)
        ch.axes.set_yaxis_tick_format(axis_format)
        ch.set_title(title)
        ch.axes.set_xaxis_label(x_label)
        ch.set_subtitle("")

        # The hover data is built from the frame as received, not from the
        # copy whose bounds went through to_finite.
        self.add_tools(
            chart=ch,
            df=self._with_categorical_x(difference_df, groupby_columns),
            center_name=DIFFERENCE,
            absolute=absolute,
            ordinal=False,
            use_adjusted_intervals=use_adjusted_intervals,
        )

        chart_grid = ChartGrid()
        chart_grid.charts.append(ch)
        return chart_grid

    def _ordinal_multiple_difference_plot(
        self,
        difference_df: DataFrame,
        absolute: bool,
        groupby: Union[str, Iterable],
        level_as_reference: bool,
        use_adjusted_intervals: bool,
    ):
        """Plot differences against a reference level along the ordinal column."""
        remaining_groups = get_remaning_groups(groupby, self._ordinal_group_column)
        groupby_columns = self._add_level_column(remaining_groups, level_as_reference)
        title = self._get_multiple_difference_title(difference_df, level_as_reference)
        y_axis_label = self._get_difference_plot_label(absolute)
        ch = self._ordinal_plot(
            DIFFERENCE,
            difference_df,
            groupby=None,
            level_name="",
            remaining_groups=groupby_columns,
            absolute=absolute,
            title=title,
            y_axis_label=y_axis_label,
            use_adjusted_intervals=use_adjusted_intervals,
        )
        ch.callout.line(0)
        return ch

    def _categorical_multiple_difference_plot(
        self,
        difference_df: DataFrame,
        absolute: bool,
        groupby: Union[str, Iterable],
        level_as_reference: bool,
        use_adjusted_intervals: bool,
    ):
        """Plot differences against a reference level over categorical groups."""
        groupby_columns = self._add_level_column(groupby, level_as_reference)
        title = self._get_multiple_difference_title(difference_df, level_as_reference)
        x_label = "" if groupby is None else "{}".format(groupby)
        return self._categorical_difference_chart(
            absolute, difference_df, groupby_columns, title, x_label, use_adjusted_intervals
        )

    # ------------------------------------------------------------------
    # Summary plots
    # ------------------------------------------------------------------

    def _summary_plot(self, level_name: Union[str, Tuple], level_df: DataFrame, groupby: Union[str, Iterable]):
        """Plot one summary chart, ordinal if the ordinal column is still free."""
        remaining_groups = get_remaning_groups(self._all_group_columns, groupby)
        has_ordinal_axis = self._ordinal_group_column is not None and self._ordinal_group_column in remaining_groups
        if has_ordinal_axis:
            return self._ordinal_summary_plot(level_name, level_df, remaining_groups, groupby)
        return self._categorical_summary_plot(level_name, level_df, remaining_groups, groupby)

    def _ordinal_summary_plot(
        self,
        level_name: Union[str, Tuple],
        level_df: DataFrame,
        remaining_groups: Union[str, Iterable],
        groupby: Union[str, Iterable],
    ):
        """Plot the point estimate along the ordinal column."""
        remaining_groups = get_remaning_groups(remaining_groups, self._ordinal_group_column)
        title = "Estimate of {} / {}".format(self._numerator, self._denominator)
        y_axis_label = "{} / {}".format(self._numerator, self._denominator)
        return self._ordinal_plot(
            POINT_ESTIMATE,
            level_df,
            groupby,
            level_name,
            remaining_groups,
            absolute=True,
            title=title,
            y_axis_label=y_axis_label,
            use_adjusted_intervals=False,
        )

    def _ordinal_plot(
        self,
        center_name: str,
        level_df: DataFrame,
        groupby: Union[str, Iterable],
        level_name: Union[str, Tuple],
        remaining_groups: Union[str, Iterable],
        absolute: bool,
        title: str,
        y_axis_label: str,
        use_adjusted_intervals: bool,
    ):
        """Draw a line for ``center_name`` with a shaded interval around it.

        The ordinal group column is the x-axis. A dashed line for the null
        hypothesis is added when that column exists, and lines are coloured
        by the ``color`` column when ``remaining_groups`` is non-empty.
        """
        lower, upper = (ADJUSTED_LOWER, ADJUSTED_UPPER) if use_adjusted_intervals else (CI_LOWER, CI_UPPER)
        df = add_color_column(level_df, remaining_groups)
        colors = "color" if remaining_groups else None
        axis_format, y_min, y_max = axis_format_precision(
            numbers=self._stack_values(df, lower, center_name, upper),
            absolute=absolute,
        )

        x_column = self._ordinal_group_column
        ch = Chart(x_axis_type=self._ordinal_type())
        ch.plot.line(
            data_frame=df.sort_values(x_column),
            x_column=x_column,
            y_column=center_name,
            color_column=colors,
        )

        # Reset the palette so the area reuses the colours of the lines.
        ch.style.color_palette.reset_palette_order()
        area_df = (
            df.assign(**{lower: to_finite(df[lower], y_min)})
            .assign(**{upper: to_finite(df[upper], y_max)})
            .sort_values(x_column)
        )
        ch.plot.area(
            data_frame=area_df,
            x_column=x_column,
            y_column=lower,
            second_y_column=upper,
            color_column=colors,
        )

        if NULL_HYPOTHESIS in df.columns:
            ch.style.color_palette.reset_palette_order()
            ch.plot.line(
                data_frame=df.sort_values(x_column),
                x_column=x_column,
                y_column=NULL_HYPOTHESIS,
                color_column=colors,
                line_dash="dashed",
                line_width=1,
            )

        padding = 0.05 * (y_max - y_min)
        ch.axes.set_yaxis_label(y_axis_label)
        ch.axes.set_xaxis_label(x_column)
        ch.set_source_label("")
        ch.axes.set_yaxis_range(y_min - padding, y_max + padding)
        ch.axes.set_yaxis_tick_format(axis_format)
        subtitle = "" if not groupby else "{}: {}".format(groupby, level_name)
        ch.set_subtitle(subtitle)
        ch.set_title(title)
        if colors:
            ch.set_legend_location("outside_bottom")

        self.add_tools(
            chart=ch,
            df=df,
            center_name=center_name,
            absolute=absolute,
            ordinal=True,
            use_adjusted_intervals=use_adjusted_intervals,
        )
        return ch

    def _categorical_summary_plot(self, level_name, summary_df, remaining_groups, groupby):
        """Plot point estimates with confidence intervals over categories.

        Falls back to the ``groupby`` columns for the x-axis when no other
        group columns remain.
        """
        if not remaining_groups:
            remaining_groups = listify(groupby)
        df = self._with_categorical_x(summary_df, remaining_groups)

        # The null hypothesis is deliberately left out of the summary range.
        axis_format, y_min, y_max = axis_format_precision(
            numbers=self._stack_values(df, CI_LOWER, POINT_ESTIMATE, CI_UPPER, with_null_hypothesis=False),
            absolute=True,
        )

        finite_df = df.assign(**{CI_LOWER: to_finite(df[CI_LOWER], y_min)}).assign(
            **{CI_UPPER: to_finite(df[CI_UPPER], y_max)}
        )
        ch = Chart(x_axis_type="categorical")
        ch.plot.interval(
            finite_df,
            categorical_columns=remaining_groups,
            lower_bound_column=CI_LOWER,
            upper_bound_column=CI_UPPER,
            middle_column=POINT_ESTIMATE,
            categorical_order_by="labels",
            categorical_order_ascending=True,
        )
        # Also plot transparent circles, just to be able to show hover box
        ch.style.color_palette.reset_palette_order()
        ch.figure.circle(
            source=df, x="categorical_x", y=POINT_ESTIMATE, size=20, name="center", line_alpha=0, fill_alpha=0
        )

        ch.set_title("Estimate of {} / {}".format(self._numerator, self._denominator))
        if groupby:
            ch.set_subtitle("{}: {}".format(groupby, level_name))
        else:
            ch.set_subtitle("")
        ch.axes.set_xaxis_label("{}".format(", ".join(remaining_groups)))
        ch.axes.set_yaxis_label("{} / {}".format(self._numerator, self._denominator))
        ch.set_source_label("")
        ch.axes.set_yaxis_tick_format(axis_format)
        self.add_tools(
            chart=ch, df=df, center_name=POINT_ESTIMATE, absolute=True, ordinal=False, use_adjusted_intervals=False
        )
        return ch

    # ------------------------------------------------------------------
    # Hover tooling
    # ------------------------------------------------------------------

    def add_ci_to_chart_datasources(
        self, chart: Chart, df: DataFrame, center_name: str, ordinal: bool, use_adjusted_intervals: bool
    ):
        """Copy interval and test columns from ``df`` into the chart's data.

        Each entry of ``chart.data`` that holds the centre or null-hypothesis
        column gets the interval bounds and a ``color`` entry; entries that
        hold the difference or null-hypothesis column also get the p-values.
        Rows are looked up in ``df`` through the entry's ``index`` values.
        """
        lower, upper = (ADJUSTED_LOWER, ADJUSTED_UPPER) if use_adjusted_intervals else (CI_LOWER, CI_UPPER)
        group_col = "color" if ordinal and "color" in df.columns else "categorical_x"

        for data in chart.data:
            has_null_hypothesis = NULL_HYPOTHESIS in data.keys()

            if center_name in data.keys() or has_null_hypothesis:
                index = data["index"]
                data[lower] = np.array(df[lower][index])
                data[upper] = np.array(df[upper][index])
                data["color"] = np.array(df[group_col][index])

            if DIFFERENCE in data.keys() or has_null_hypothesis:
                index = data["index"]
                data[DIFFERENCE] = np.array(df[DIFFERENCE][index])
                data["p_value"] = np.array(df[P_VALUE][index])
                data["adjusted_p"] = np.array(df[ADJUSTED_P][index])
                if NULL_HYPOTHESIS in df.columns:
                    data["null_hyp"] = np.array(df[NULL_HYPOTHESIS][index])

    def add_tools(
        self,
        chart: Chart,
        df: DataFrame,
        center_name: str,
        absolute: bool,
        ordinal: bool,
        use_adjusted_intervals: bool,
    ):
        """Attach hover, zoom, pan and reset tools to ``chart``.

        The hover box lists the group, the ordinal value (ordinal plots only),
        the centre value and its interval, the p-values (difference plots
        only) and the null hypothesis when that column exists in ``df``.
        """
        self.add_ci_to_chart_datasources(chart, df, center_name, ordinal, use_adjusted_intervals)
        lower, upper = (ADJUSTED_LOWER, ADJUSTED_UPPER) if use_adjusted_intervals else (CI_LOWER, CI_UPPER)

        if len(chart.figure.legend) > 0:
            chart.figure.legend.click_policy = "hide"

        axis_format, y_min, y_max = axis_format_precision(
            numbers=self._stack_values(df, lower, center_name, upper),
            absolute=absolute,
            extra_zeros=2,
        )

        ordinal_tool_tip = [] if not ordinal else [(self._ordinal_group_column, f"@{self._ordinal_group_column}")]

        p_value_tool_tip = []
        if center_name == DIFFERENCE:
            p_value_tool_tip = [("p-value", "@p_value{0.0000}")]
            # The adjusted p-value is only listed when df has more than one row.
            if len(df) > 1:
                p_value_tool_tip += [("adjusted p-value", "@adjusted_p{0.0000}")]

        nim_tool_tip = [("null hypothesis", f"@null_hyp{{{axis_format}}}")] if NULL_HYPOTHESIS in df.columns else []
        interval_tool_tip = [
            (
                ("adjusted " if use_adjusted_intervals else "") + "confidence interval",
                f"(@{{{lower}}}{{{axis_format}}}," f" @{{{upper}}}{{{axis_format}}})",
            )
        ]
        tooltips = (
            [("group", "@color")]
            + ordinal_tool_tip
            + [(f"{center_name}", f"@{center_name}{{{axis_format}}}")]
            + interval_tool_tip
            + p_value_tool_tip
            + nim_tool_tip
        )

        # Categorical charts only hover over the invisible centre circles and
        # the null-hypothesis dashes.
        lines_with_hover = [] if ordinal else ["center", "nim"]
        # NOTE(review): HoverTool's `names` argument appears to have been
        # removed in Bokeh 3.0 - confirm against the pinned Bokeh version.
        hover = tools.HoverTool(tooltips=tooltips, names=lines_with_hover)
        box_zoom = tools.BoxZoomTool()

        chart.figure.add_tools(
            hover, tools.ZoomInTool(), tools.ZoomOutTool(), box_zoom, tools.PanTool(), tools.ResetTool()
        )
        chart.figure.toolbar.active_drag = box_zoom