1# Copyright 2017-2020 Spotify AB
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from abc import ABCMeta, abstractmethod
16from functools import wraps
17import types
18import warnings
19
20import chartify
21import numpy as np
22import pandas as pd
23
24from spotify_confidence.options import options
25from spotify_confidence.chartgrid import ChartGrid
26
27# warnings.simplefilter("once")
28
# Fallback seed used by randomization_warning_decorator when the
# `randomization_seed` option is None. It is the first element of the key
# array in the tuple returned by np.random.get_state(), captured once when
# this module is imported. NOTE(review): unless numpy's global RNG was seeded
# before this import, this value presumably differs between processes, which
# is why the decorator warns that results are not reproducible.
INITIAL_RANDOMIZATION_SEED = np.random.get_state()[1][0]
30
31
def axis_format_precision(max_value, min_value, absolute):
    """Return a tick-format string whose precision matches the plotted range.

    The number of decimals comes from the order of magnitude of
    ``|max_value - min_value|``. Absolute axes get two extra decimals;
    relative axes are rendered as percentages.

    Args:
        max_value (float): One end of the plotted range.
        min_value (float): The other end (order does not matter, the
            absolute difference is used).
        absolute (bool): If True format as a plain number, otherwise as
            a percentage.

    Returns:
        str: A format such as ``"0.000"`` or ``"0.0%"``.
    """
    extra_zeros = 2 if absolute else 0
    value_range = abs(max_value - min_value)
    if np.isfinite(value_range) and value_range > 0:
        magnitude = int(np.log10(value_range))
    else:
        # A zero or non-finite span (e.g. a single constant point) has no
        # order of magnitude: np.log10(0) is -inf and int(-inf) raises
        # OverflowError, so fall back to magnitude 0.
        magnitude = 0
    precision = -magnitude + extra_zeros
    # Negative precision (wide ranges) simply yields no decimal zeros.
    zeros = "0" * max(precision, 0)
    return "0.{}{}".format(zeros, "" if absolute else "%")
37
38
def add_color_column(df, cols):
    """Add a ``color`` column joining the values of ``cols`` with single spaces.

    Each column is converted to ``str`` first, so non-string grouping
    columns (e.g. integer ids) can be combined. ``df`` is modified in place
    and also returned. When ``cols`` is empty, ``df`` is returned unchanged
    and no ``color`` column is added.

    Args:
        df (pd.DataFrame): Frame to annotate.
        cols (list of str): Column names to combine, in order.

    Returns:
        pd.DataFrame: ``df`` with the added ``color`` column.
    """
    for i, column in enumerate(cols):
        as_text = df[column].astype(str)
        if i == 0:
            df["color"] = as_text
        else:
            df["color"] = df["color"] + " " + as_text
    return df
46
47
def randomization_warning_decorator(f):
    """Set numpy randomization seed and warn users if not fixed.

    Before every call of ``f`` the global numpy RNG is re-seeded with the
    ``randomization_seed`` option. If that option is ``None``, a warning is
    emitted and ``INITIAL_RANDOMIZATION_SEED`` is used instead.

    Note to developers:
    Do not compare random variables that have been
    sampled from the same seed. It will lead to incorrect results.
    To avoid this situation it's best to apply this decorator to
    public methods that involve randomization.
    """

    @wraps(f)
    def wrapper(*args, **kwargs):

        option_seed = options.get_option("randomization_seed")
        # Only None means "no seed configured". The former extra
        # `option_seed != INITIAL_RANDOMIZATION_SEED` test was always true in
        # that case, and raised ValueError for numpy-array seeds (which
        # np.random.seed accepts) because of the array truth-value check.
        if option_seed is None:
            randomization_warning_message = """
            Your analysis will not be reproducible!
            Using a method that involves randomization without setting a seed.
            Please run the following and add it to the top of your script or
            notebook after you import confidence:

            confidence.options.set_option('randomization_seed', {})

            """.format(
                INITIAL_RANDOMIZATION_SEED
            )
            warnings.warn(randomization_warning_message)
            option_seed = INITIAL_RANDOMIZATION_SEED
        # Re-seed on every call so each decorated call starts from the same state.
        np.random.seed(option_seed)
        return f(*args, **kwargs)

    return wrapper
81
82
class BaseTest(object, metaclass=ABCMeta):
    """Base test class that provides abstract methods
    to ensure consistency across test classes.

    Subclasses implement the response-specific statistics (``difference``,
    ``multiple_difference``) and the categorical plots. This class provides
    input validation, summary tables, ordinal (line + interval band) plots
    and the groupby iteration helpers they share.
    """

    def __init__(
        self,
        data_frame,
        categorical_group_columns,
        ordinal_group_column,
        numerator_column,
        denominator_column,
        interval_size,
    ):
        """Store the configuration and validate ``data_frame``.

        Args:
            data_frame (pd.DataFrame): Input data, at most one row per grouping.
            categorical_group_columns (str, iterable of str or None):
                Categorical grouping column(s).
            ordinal_group_column (str or None): Numeric or datetime grouping column.
            numerator_column (str): Column name for numerator column.
            denominator_column (str): Column name for denominator column.
            interval_size: Interval size, stored unchanged for subclasses.

        Raises:
            ValueError: If no grouping column is given, or a grouping occurs
                in more than one row.
            TypeError: If the ordinal column is neither numeric nor datetime.
        """

        self._data_frame = data_frame
        self._numerator_column = numerator_column
        self._denominator_column = denominator_column
        self._interval_size = interval_size

        # A single name (or None) is wrapped in a list; any other iterable
        # (list, tuple, ...) is copied into a list so it can be concatenated
        # below. None becomes [None] and is filtered out of _all_group_columns.
        if isinstance(categorical_group_columns, str) or not hasattr(categorical_group_columns, "__iter__"):
            self._categorical_group_columns = [categorical_group_columns]
        else:
            self._categorical_group_columns = list(categorical_group_columns)
        self._ordinal_group_column = ordinal_group_column

        self._all_group_columns = [
            column
            for column in self._categorical_group_columns + [self._ordinal_group_column]
            if column is not None
        ]
        self._validate_data()

    def _validate_data(self):
        """Integrity check input dataframe.

        Raises:
            ValueError: If no grouping column was given, or if any
                combination of grouping values occurs in more than one row.
            TypeError: If the ordinal column's dtype is neither a numpy
                number nor datetime64.
        """
        if not self._all_group_columns:
            raise ValueError(
                """At least one of `categorical_group_columns`
                or `ordinal_group_column` must be specified."""
            )

        # Ensure there's at most 1 observation per grouping.
        max_one_row_per_grouping = all(self._data_frame.groupby(self._all_group_columns).size() <= 1)
        if not max_one_row_per_grouping:
            raise ValueError("""Each grouping should have at most 1 observation.""")

        if self._ordinal_group_column:
            ordinal_column_type = self._data_frame[self._ordinal_group_column].dtype.type
            if not np.issubdtype(ordinal_column_type, np.number) and not issubclass(
                ordinal_column_type, np.datetime64
            ):
                raise TypeError(
                    """`ordinal_group_column` is type `{}`.
                    Must be number or datetime type.""".format(
                        ordinal_column_type
                    )
                )

    @classmethod
    def as_cumulative(
        cls, data_frame, numerator_column, denominator_column, ordinal_group_column, categorical_group_columns=None
    ):
        """
        Instantiate the class with a cumulative representation of the dataframe.
        Sorts by the ordinal variable and calculates the cumulative sum
        of the numerator and denominator (within each categorical group, if given).
        May be used to visualize the difference between groups as a
        time series.

        Args:
            data_frame (pd.DataFrame): DataFrame
            numerator_column (str): Column name for numerator column.
            denominator_column (str): Column name for denominator column.
            ordinal_group_column (str): Column name for ordinal grouping
                (e.g. numeric or date values).
            categorical_group_columns (str or list),
                Optional: Column names for categorical groupings.

        """

        sorted_df = data_frame.sort_values(by=ordinal_group_column)
        cumsum_cols = [numerator_column, denominator_column]
        if categorical_group_columns:
            sorted_df[cumsum_cols] = sorted_df.groupby(by=categorical_group_columns)[cumsum_cols].cumsum()
        else:
            sorted_df[cumsum_cols] = sorted_df[cumsum_cols].cumsum()

        # NOTE(review): arguments are passed positionally as
        # (data_frame, numerator, denominator, categorical, ordinal), which is
        # not BaseTest.__init__'s order; presumably this matches the concrete
        # subclasses' constructors — confirm before changing either side.
        return cls(sorted_df, numerator_column, denominator_column, categorical_group_columns, ordinal_group_column)

    def summary(self):
        """Return Pandas DataFrame with summary statistics.

        Uses ``self._interval`` as the per-row interval function; it is not
        defined in this class and is presumably provided by subclasses.
        """
        return self._summary(self._data_frame, self._interval)

    def _summary(self, data_frame, ci_function):
        """Return the input dataframe with added columns:
        - Lower & upper bounds of
            Bayesian: credible interval
            Frequentist: confidence interval
        - Additional summary stats
            (e.g. probability in the case of Binomial data)

        ``ci_function`` is applied row-wise and must return two values,
        which become ``ci_lower`` and ``ci_upper``. ``point_estimate`` is
        numerator / denominator.
        """
        summary_df = data_frame[self._all_group_columns + [self._numerator_column, self._denominator_column]].copy()

        summary_df["point_estimate"] = summary_df[self._numerator_column] * 1.0 / summary_df[self._denominator_column]
        summary_df[["ci_lower", "ci_upper"]] = data_frame.apply(ci_function, axis=1, result_type="expand")

        return summary_df

    def summary_plot(self, groupby=None):
        """Plot for each group in the data_frame:

        if ordinal level exists:
            Frequentist: line graph with area to represent confidence interval
            Bayesian: line graph with area to represent credible interval
        if categorical levels:
            Bayesian: KDE plot of posterior distributions by group
            Frequentist: Interval plots of confidence intervals by group

        Args:
            groupby (str): Name of column.
                If specified, will plot a separate chart for each level of the
                grouping.

        Returns:
            ChartGrid object.
        """
        chart_grid = self._iterate_groupby_to_chartgrid(self._summary_plot, groupby=groupby)
        return chart_grid

    def _summary_plot(self, level_name, level_df, remaining_groups, groupby):
        """Dispatch to the ordinal plot if the ordinal column is still ungrouped, else the categorical plot."""
        if self._ordinal_group_column is not None and self._ordinal_group_column in remaining_groups:

            ch = self._ordinal_summary_plot(level_name, level_df, remaining_groups, groupby)
        else:
            ch = self._categorical_summary_plot(level_name, level_df, remaining_groups, groupby)
        return ch

    def _ordinal_summary_plot(self, level_name, level_df, remaining_groups, groupby):
        """Plot the point estimate and interval of ``level_df`` over the ordinal column."""
        remaining_groups = self._remaining_categorical_groups(remaining_groups)
        df = self._summary(level_df, self._interval)
        title = "Estimate of {} / {}".format(self._numerator_column, self._denominator_column)
        y_axis_label = "{} / {}".format(self._numerator_column, self._denominator_column)
        return self._ordinal_plot(
            "point_estimate",
            df,
            groupby,
            level_name,
            remaining_groups,
            absolute=True,
            title=title,
            y_axis_label=y_axis_label,
        )

    def _ordinal_plot(self, center_name, df, groupby, level_name, remaining_groups, absolute, title, y_axis_label):
        """Draw ``center_name`` as a line plus a ci_lower/ci_upper band against the ordinal column.

        One colored series is drawn per combination of ``remaining_groups``.
        ``df`` gets a ``color`` column added in place via ``add_color_column``.
        """
        df = add_color_column(df, remaining_groups)
        colors = "color" if remaining_groups else None
        # Sort once; the line and the band use the same ordering.
        sorted_df = df.sort_values(self._ordinal_group_column)
        ch = chartify.Chart(x_axis_type=self._ordinal_type())
        ch.plot.line(
            data_frame=sorted_df,
            x_column=self._ordinal_group_column,
            y_column=center_name,
            color_column=colors,
        )
        # Presumably restarts the palette so each band gets its line's color.
        ch.style.color_palette.reset_palette_order()
        ch.plot.area(
            data_frame=sorted_df,
            x_column=self._ordinal_group_column,
            y_column="ci_lower",
            second_y_column="ci_upper",
            color_column=colors,
        )
        ch.axes.set_yaxis_label(y_axis_label)
        ch.axes.set_xaxis_label(self._ordinal_group_column)
        ch.set_source_label("")
        axis_format = axis_format_precision(df["ci_upper"].max(), df["ci_lower"].min(), absolute)
        ch.axes.set_yaxis_tick_format(axis_format)
        subtitle = "" if not groupby else "{}: {}".format(groupby, level_name)
        ch.set_subtitle(subtitle)
        ch.set_title(title)
        if colors:
            ch.set_legend_location("outside_bottom")
        return ch

    @staticmethod
    def _as_column_list(columns):
        """Normalize a column spec to a list of column names.

        ``None`` becomes ``[]``, a single name (a ``str`` or any
        non-iterable value) becomes a one-element list, and any other
        iterable is copied into a list. Membership tests against the result
        compare whole column names, never substrings of a single string.
        """
        if columns is None:
            return []
        if isinstance(columns, str) or not hasattr(columns, "__iter__"):
            return [columns]
        return list(columns)

    def _remaining_categorical_groups(self, remaining_groups):
        """Return ``remaining_groups`` as a list, without the ordinal column."""
        return [
            group_name
            for group_name in self._as_column_list(remaining_groups)
            if group_name != self._ordinal_group_column
        ]

    def _ordinal_type(self):
        """Return the chartify x-axis type: "datetime" for datetime64 ordinals, else "linear"."""
        ordinal_column_type = self._data_frame[self._ordinal_group_column].dtype.type
        axis_type = "datetime" if issubclass(ordinal_column_type, np.datetime64) else "linear"
        return axis_type

    @abstractmethod
    def _categorical_summary_plot(self, level_name, level_df, remaining_groups, groupby):
        pass

    @abstractmethod
    def difference(self, level_1, level_2, absolute=True, groupby=None):
        """Return dataframe containing the difference in means between
        group 1 and 2 and the appropriate test statistics.
        Frequentist:
            - Calculate one of the following tests depending on the
                response variable type.
                - Binomial: Chisq / fisher exact test
                - Gaussian: t-test / z-test
            Return the p-value.
        Bayesian:
            - Calculate the posterior distribution of the difference in means.
            Return the
                - probability that group 2 > group 1.
                - Expected loss
                - Expected change
                - Expected gain
                - 95% CI interval
        """
        pass

    def difference_plot(self, level_1, level_2, absolute=True, groupby=None):
        """Plot representing the difference between group 1 and 2.
        - Difference in means or proportions, depending
            on the response variable type.

        Frequentist:
        - Plot interval plot with confidence interval of the
            difference between groups

        Bayesian:
        - Plot KDE representing the posterior distribution of the difference.
        - Probability that group2 > group1
        - Mean difference
        - 95% interval.

        Args:
            level_1 (str, tuple of str): Name of first level.
            level_2 (str, tuple of str): Name of second level.
            absolute (bool): If True then return the absolute
                difference (level2 - level1)
                otherwise return the relative difference (level2 / level1 - 1)
            groupby (str): Name of column, or list of columns.
                If specified, will return an interval for each level
                of the grouped dimension, or a confidence band if the
                grouped dimension is ordinal

        Returns:
            ChartGrid object.
        """

        use_ordinal_axis = self._use_ordinal_axis(groupby)

        if use_ordinal_axis:
            ch = self._ordinal_difference_plot(level_1, level_2, absolute, groupby)
            chart_grid = ChartGrid()
            chart_grid.charts.append(ch)
        else:
            chart_grid = self._categorical_difference_plot(level_1, level_2, absolute, groupby)

        return chart_grid

    def _use_ordinal_axis(self, groupby):
        """Return True if the ordinal column is one of the ``groupby`` columns.

        ``groupby`` may be None, a single column name or a list of names. A
        single string is compared as a whole name, not as a substring
        (ordinal "day" does not match groupby "weekday").
        """
        return self._ordinal_group_column is not None and self._ordinal_group_column in self._as_column_list(groupby)

    def _ordinal_difference_plot(self, level_1, level_2, absolute, groupby):
        """Plot the difference from ``level_1`` to ``level_2`` over the ordinal column, with a zero line."""
        difference_df = self.difference(level_1, level_2, absolute, groupby)
        remaining_groups = self._remaining_categorical_groups(groupby)
        title = "Change from {} to {}".format(level_1, level_2)
        y_axis_label = self.get_difference_plot_label(absolute)
        ch = self._ordinal_plot(
            "difference",
            difference_df,
            groupby=None,
            level_name="",
            remaining_groups=remaining_groups,
            absolute=absolute,
            title=title,
            y_axis_label=y_axis_label,
        )
        ch.callout.line(0)

        return ch

    def get_difference_plot_label(self, absolute):
        """Return the y-axis label for difference plots, e.g. "Absolute change in n / d"."""
        change_type = "Absolute" if absolute else "Relative"
        return change_type + " change in {} / {}".format(self._numerator_column, self._denominator_column)

    @abstractmethod
    def _categorical_difference_plot(self, level_1, level_2, absolute, groupby):
        pass

    @abstractmethod
    def multiple_difference(self, level, absolute=True, groupby=None, level_as_reference=False):
        """The pairwise probability that the specific group
        is greater than all other groups.
        """
        pass

    def multiple_difference_plot(self, level, absolute=True, groupby=None, level_as_reference=False):
        """Compare level to all other groups or, if level_as_reference = True,
        all other groups to level.

        Args:
            level (str, tuple of str): Name of level.
            absolute (bool): If True then return the absolute
                difference (level2 - level1)
                otherwise return the relative difference (level2 / level1 - 1)
            groupby (str): Name of column, or list of columns.
                If specified, will return an interval for each level
                of the grouped dimension, or a confidence band if the
                grouped dimension is ordinal
            level_as_reference: If false (default), compare level to all other
                groups. If true, compare all other groups to level.

        Returns:
            ChartGrid object.
        """
        use_ordinal_axis = self._use_ordinal_axis(groupby)

        if use_ordinal_axis:
            ch = self._ordinal_multiple_difference_plot(level, absolute, groupby, level_as_reference)
            chart_grid = ChartGrid()
            chart_grid.charts.append(ch)
        else:
            chart_grid = self._categorical_multiple_difference_plot(level, absolute, groupby, level_as_reference)

        return chart_grid

    def _ordinal_multiple_difference_plot(self, level, absolute, groupby, level_as_reference):
        """Plot differences between ``level`` and every other level over the ordinal column."""
        difference_df = self.multiple_difference(level, absolute, groupby, level_as_reference)
        remaining_groups = self._remaining_categorical_groups(groupby)
        groupby_columns = self._add_level_column(remaining_groups, level_as_reference)
        title = "Comparison to {}".format(level)
        y_axis_label = self.get_difference_plot_label(absolute)
        ch = self._ordinal_plot(
            "difference",
            difference_df,
            groupby=None,
            level_name="",
            remaining_groups=groupby_columns,
            absolute=absolute,
            title=title,
            y_axis_label=y_axis_label,
        )
        ch.callout.line(0)
        return ch

    def _add_level_column(self, groupby, level_as_reference):
        """Append the compared-level column ("level_2" if ``level_as_reference`` else "level_1").

        Returns the bare column name (a str) when ``groupby`` is None,
        otherwise a new list of ``groupby`` columns followed by it.
        """
        level_column = "level_2" if level_as_reference else "level_1"
        if groupby is None:
            groupby_columns = level_column
        elif isinstance(groupby, str):
            groupby_columns = [groupby, level_column]
        else:
            groupby_columns = list(groupby) + [level_column]
        return groupby_columns

    @abstractmethod
    def _categorical_multiple_difference_plot(self, level, absolute, groupby, level_as_reference):
        pass

    @staticmethod
    def _validate_levels(level_df, remaining_groups, level):
        """Raise ValueError listing the valid levels if ``level`` is not a group of ``level_df``.

        Groups are formed by ``remaining_groups``. The original KeyError or
        ValueError from pandas is kept as the exception's cause.
        """
        try:
            level_df.groupby(remaining_groups).get_group(level)
        except (KeyError, ValueError) as err:
            raise ValueError(
                """
            Invalid level: '{}'
            Must supply a level within the ungrouped dimensions: {}
            Valid levels:
            {}
            """.format(
                    level, remaining_groups, list(level_df.groupby(remaining_groups).groups.keys())
                )
            ) from err

    def _groupby_iterator(self, input_function, groupby, **kwargs):
        """Yield ``input_function(level_name, level_df, remaining_groups, groupby, **kwargs)`` per group.

        ``remaining_groups`` are the group columns not in ``groupby``. With no
        ``groupby`` the whole dataframe forms a single group.
        """
        groupby = [] if groupby is None else groupby
        # Compare whole column names; `x in "country"` on a plain string
        # would be a substring test.
        groupby_columns = self._as_column_list(groupby)
        # Will group over the whole dataframe if groupby is empty.
        level_groups = groupby if groupby else np.ones(len(self._data_frame))

        remaining_groups = [
            group for group in self._all_group_columns if group not in groupby_columns and group is not None
        ]

        for level_name, level_df in self._data_frame.groupby(level_groups):
            yield input_function(level_name, level_df, remaining_groups, groupby, **kwargs)

    def _iterate_groupby_to_chartgrid(self, input_function, groupby, **kwargs):
        """Iterate through groups in the test and apply the input function.

        Returns ChartGrid"""
        chart_grid = ChartGrid()

        chart_grid.charts = list(self._groupby_iterator(input_function, groupby, **kwargs))

        return chart_grid

    def _iterate_groupby_to_dataframe(self, input_function, groupby, **kwargs):
        """Iterate through groups in the test and apply the input function.

        ``input_function`` may return a DataFrame or a generator of
        DataFrames per group; generators are flattened.

        Returns pd.DataFrame with a fresh RangeIndex."""
        frames = []
        for item in self._groupby_iterator(input_function, groupby, **kwargs):
            # Flatten nested generators item by item.
            if isinstance(item, types.GeneratorType):
                frames.extend(item)
            else:
                frames.append(item)

        results_data_frame = pd.concat(frames, axis=0)

        return results_data_frame.reset_index(drop=True)

    def _all_groups(self):
        """Return a list of all group keys.

        Returns: list"""
        groups = list(self._data_frame.groupby(self._all_group_columns).groups.keys())
        return groups

    def _add_group_by_columns(self, difference_df, groupby, level_name):
        """Insert the ``groupby`` column(s) holding ``level_name`` at the front of ``difference_df`` in place.

        With several columns each is inserted at position 0 in turn, so the
        last ``groupby`` column ends up first.
        """
        if not groupby:
            return
        if isinstance(groupby, str):
            difference_df.insert(0, column=groupby, value=level_name)
            return
        if len(groupby) == 1:
            # Recent pandas yields 1-tuples as group keys when grouping by a
            # length-1 list; unwrap so the scalar is broadcast to every row.
            if isinstance(level_name, tuple) and len(level_name) == 1:
                level_name = level_name[0]
            difference_df.insert(0, column=groupby[0], value=level_name)
            return
        for col, val in zip(groupby, level_name):
            difference_df.insert(0, column=col, value=val)
510
511
512# class BinomialResponse(BaseTest, metaclass=ABCMeta):
513# """Binomial Response Variable.
514# """
515
516# class GaussianResponse(BaseTest, metaclass=ABCMeta):
517# """Base class for tests of normal response variables
518
519# E.g. Revenue per user
520# """
521
522# pass
523
524
525# class PoissonResponse(BaseTest, metaclass=ABCMeta):
526# """Base class for tests of poisson response variables.
527
528# E.g. # of days active per user per month
529# """
530# pass
531
532
533# class MultinomialResponse(BaseTest, metaclass=ABCMeta):
534# """Base class for tests of multinomial response variables.
535
536# E.g. single choice answer survey
537# self.
538# """
539
540# def __init__(self, data_frame, categorical_group_columns,
541# ordinal_group_column, category_column, value_column):
542# self._category_column = category_column
543# self._value_column = value_column
544# super().__init__(data_frame, categorical_group_columns,
545# ordinal_group_column)
546
547
548# class CategoricalResponse(BaseTest, metaclass=ABCMeta):
549# """Base class for tests of categorical response variables.
550
551# E.g. multiple choice answer survey
552# """
553
554# def __init__(self, data_frame, categorical_group_columns,
555# ordinal_group_column, category_column, value_column):
556# self._category_column = category_column
557# self._value_column = value_column
558# super().__init__(data_frame, categorical_group_columns,
559# ordinal_group_column)
560
561# pass