Coverage for /Users/sebastiana/Documents/Sugarpills/confidence/spotify_confidence/analysis/bayesian/bayesian_base.py: 24%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

201 statements  

1# Copyright 2017-2020 Spotify AB 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15from abc import ABCMeta, abstractmethod 

16from functools import wraps 

17import types 

18import warnings 

19 

20import chartify 

21import numpy as np 

22import pandas as pd 

23 

24from spotify_confidence.options import options 

25from spotify_confidence.chartgrid import ChartGrid 

26 

27# warnings.simplefilter("once") 

28 

# Snapshot, taken once when this module is imported, of the first element of
# the key array held in NumPy's global (legacy) random state. It is used in
# this module as the fallback seed and is quoted in the reproducibility
# warning when no `randomization_seed` option has been set.
# NOTE(review): presumably this only equals "the seed" when NumPy was seeded
# with an integer beforehand; re-seeding with it may not reproduce the state
# NumPy had at import time -- confirm.
INITIAL_RANDOMIZATION_SEED = np.random.get_state()[1][0]

30 

31 

def axis_format_precision(max_value, min_value, absolute):
    """Build an axis tick-format string sized to the spread of the values.

    The number of decimal places grows as the distance between ``max_value``
    and ``min_value`` shrinks, so that small ranges are still readable.

    Args:
        max_value (number): One end of the plotted range.
        min_value (number): The other end of the plotted range. Only the
            absolute distance between the two is used, so their order does
            not matter.
        absolute (bool): If True, return a plain decimal format with two
            extra decimal places. Otherwise return a percent format.

    Returns:
        str: ``"0.<zeros>"`` when ``absolute`` is True,
        ``"0.<zeros>%"`` otherwise.
    """
    extra_zeros = 2 if absolute else 0
    value_range = abs(max_value - min_value)

    # log10(0) is -inf and int(-inf) raises OverflowError (NaN raises
    # ValueError), so a degenerate range falls back to the base precision.
    if value_range > 0 and np.isfinite(value_range):
        precision = -int(np.log10(value_range)) + extra_zeros
    else:
        precision = extra_zeros

    # A negative precision (range of 10 or more) yields no decimal places.
    zeros = "0" * max(precision, 0)
    return "0.{}{}".format(zeros, "" if absolute else "%")

37 

38 

def add_color_column(df, cols):
    """Add a ``color`` column that joins the values of ``cols`` with spaces.

    The frame is modified in place. When ``cols`` is empty no ``color``
    column is written.

    Args:
        df (pd.DataFrame): Frame to annotate.
        cols (iterable of str): Column names whose values are combined, in
            order, into the ``color`` column.

    Returns:
        pd.DataFrame: The same ``df`` object that was passed in.
    """
    label = None
    for name in cols:
        # The first column is taken as-is; later ones are appended after a
        # single space.
        label = df[name] if label is None else label + " " + df[name]
        df["color"] = label
    return df

46 

47 

def randomization_warning_decorator(f):
    """Seed numpy's global random state before calling ``f``.

    The seed comes from the ``randomization_seed`` option. When that option
    is unset, a warning is emitted and the seed captured at import time
    (``INITIAL_RANDOMIZATION_SEED``) is used instead.

    Note to developers:
        Do not compare random variables that have been
        sampled from the same seed. It will lead to incorrect results.
        To avoid this situation it's best to apply this decorator to
        public methods that involve randomization.
    """

    @wraps(f)
    def wrapper(*args, **kwargs):
        seed = options.get_option("randomization_seed")
        fallback_seed = INITIAL_RANDOMIZATION_SEED

        seed_not_configured = seed is None and seed != fallback_seed
        if seed_not_configured:
            warnings.warn(
                """
            Your analysis will not be reproducible!
            Using a method that involves randomization without setting a seed.
            Please run the following and add it to the top of your script or
            notebook after you import confidence:

            confidence.options.set_option('randomization_seed', {})

            """.format(
                    INITIAL_RANDOMIZATION_SEED
                )
            )
            seed = fallback_seed

        # Re-seed on every call so the wrapped method starts from a known state.
        np.random.seed(seed)
        return f(*args, **kwargs)

    return wrapper

81 

82 

class BaseTest(object, metaclass=ABCMeta):
    """Base test class that provides abstract methods
    to ensure consistency across test classes.

    Concrete subclasses must implement ``_categorical_summary_plot``,
    ``difference``, ``_categorical_difference_plot``,
    ``multiple_difference`` and ``_categorical_multiple_difference_plot``.
    The summary methods also use ``self._interval``, which is not defined in
    this class and is expected to be supplied by the subclass.
    """

    def __init__(
        self,
        data_frame,
        categorical_group_columns,
        ordinal_group_column,
        numerator_column,
        denominator_column,
        interval_size,
    ):
        """Store the inputs, normalise the grouping columns and validate.

        Args:
            data_frame (pd.DataFrame): Input data.
            categorical_group_columns (str, list or None): Column name(s)
                for categorical groupings.
            ordinal_group_column (str or None): Column name for the ordinal
                grouping (number or datetime values).
            numerator_column (str): Column name for numerator column.
            denominator_column (str): Column name for denominator column.
            interval_size: Stored as ``self._interval_size``; not used by
                this base class itself.

        Raises:
            ValueError: If no grouping column is given or a grouping has
                more than one row.
            TypeError: If the ordinal column is neither numeric nor datetime.
        """
        self._data_frame = data_frame
        self._numerator_column = numerator_column
        self._denominator_column = denominator_column
        self._interval_size = interval_size

        # A single name (or None) is wrapped so the attribute is always a list.
        categorical_string_or_none = isinstance(categorical_group_columns, str) or categorical_group_columns is None
        self._categorical_group_columns = (
            [categorical_group_columns] if categorical_string_or_none else categorical_group_columns
        )
        self._ordinal_group_column = ordinal_group_column

        # Categorical columns first, ordinal last; unspecified (None) dropped.
        candidate_columns = self._categorical_group_columns + [self._ordinal_group_column]
        self._all_group_columns = [column for column in candidate_columns if column is not None]
        self._validate_data()

    def _validate_data(self):
        """Integrity check input dataframe.

        Raises:
            ValueError: If no grouping column was specified, or if any
                combination of grouping values occurs in more than one row.
            TypeError: If the ordinal grouping column is not a number or
                datetime dtype.
        """
        if not self._all_group_columns:
            raise ValueError(
                """At least one of `categorical_group_columns`
                                or `ordinal_group_column` must be specified."""
            )

        # Ensure there's at most 1 observation per grouping.
        group_sizes = self._data_frame.groupby(self._all_group_columns).size()
        if not (group_sizes <= 1).all():
            raise ValueError("""Each grouping should have at most 1 observation.""")

        # `is not None` (rather than truthiness) matches how the ordinal
        # column is tested everywhere else in this class.
        if self._ordinal_group_column is not None:
            ordinal_column_type = self._data_frame[self._ordinal_group_column].dtype.type
            is_number = np.issubdtype(ordinal_column_type, np.number)
            is_datetime = issubclass(ordinal_column_type, np.datetime64)
            if not is_number and not is_datetime:
                raise TypeError(
                    """`ordinal_group_column` is type `{}`.
        Must be number or datetime type.""".format(
                        ordinal_column_type
                    )
                )

    @classmethod
    def as_cumulative(
        cls, data_frame, numerator_column, denominator_column, ordinal_group_column, categorical_group_columns=None
    ):
        """
        Instantiate the class with a cumulative representation of the dataframe.
        Sorts by the ordinal variable and calculates the cumulative sum.
        May be used to visualize the difference between groups as a
        time series.

        Args:
            data_frame (pd.DataFrame): DataFrame
            numerator_column (str): Column name for numerator column.
            denominator_column (str): Column name for denominator column.
            ordinal_group_column (str): Column name for ordinal grouping
                (e.g. numeric or date values).
            categorical_group_columns (str or list),
                Optional: Column names for categorical groupings.

        Returns:
            An instance of ``cls`` built from the cumulative data.
        """
        # sort_values returns a new frame, so the caller's data is untouched.
        sorted_df = data_frame.sort_values(by=ordinal_group_column)
        cumsum_cols = [numerator_column, denominator_column]
        if categorical_group_columns:
            # Accumulate separately within each categorical group.
            sorted_df[cumsum_cols] = sorted_df.groupby(by=categorical_group_columns)[cumsum_cols].cumsum()
        else:
            sorted_df[cumsum_cols] = sorted_df[cumsum_cols].cumsum()

        # NOTE(review): this positional order differs from BaseTest.__init__;
        # presumably it matches the constructors of the concrete subclasses
        # this is called on -- confirm before changing.
        return cls(sorted_df, numerator_column, denominator_column, categorical_group_columns, ordinal_group_column)

    def summary(self):
        """Return Pandas DataFrame with summary statistics."""
        return self._summary(self._data_frame, self._interval)

    def _summary(self, data_frame, ci_function):
        """Return the input dataframe with added columns:
        - Lower & upper bounds of
            Bayesian: credible interval
            Frequentist: confidence interval
        - Additional summary stats
            (e.g. probability in the case of Binomial data)

        Args:
            data_frame (pd.DataFrame): Rows to summarise.
            ci_function (callable): Applied to every row of ``data_frame``;
                its result is expanded into ``ci_lower`` and ``ci_upper``.

        Returns:
            pd.DataFrame: A copy of the grouping, numerator and denominator
            columns plus ``point_estimate``, ``ci_lower`` and ``ci_upper``.
        """
        kept_columns = self._all_group_columns + [self._numerator_column, self._denominator_column]
        summary_df = data_frame[kept_columns].copy()

        summary_df["point_estimate"] = summary_df[self._numerator_column] * 1.0 / summary_df[self._denominator_column]
        summary_df[["ci_lower", "ci_upper"]] = data_frame.apply(ci_function, axis=1, result_type="expand")

        return summary_df

    def summary_plot(self, groupby=None):
        """Plot for each group in the data_frame:

        if ordinal level exists:
            Frequentist: line graph with area to represent confidence interval
            Bayesian: line graph with area to represent credible interval
        if categorical levels:
            Bayesian: KDE plot of posterior distributions by group
            Frequentist: Interval plots of confidence intervals by group

        Args:
            groupby (str): Name of column.
                If specified, will plot a separate chart for each level of the
                grouping.

        Returns:
            ChartGrid object.
        """
        return self._iterate_groupby_to_chartgrid(self._summary_plot, groupby=groupby)

    def _summary_plot(self, level_name, level_df, remaining_groups, groupby):
        """Pick the ordinal or categorical summary plot for one level.

        The ordinal plot is used only when the ordinal column is among the
        columns that were not grouped away.
        """
        if self._ordinal_group_column is not None and self._ordinal_group_column in remaining_groups:
            return self._ordinal_summary_plot(level_name, level_df, remaining_groups, groupby)
        return self._categorical_summary_plot(level_name, level_df, remaining_groups, groupby)

    def _ordinal_summary_plot(self, level_name, level_df, remaining_groups, groupby):
        """Plot the point estimate and interval along the ordinal column."""
        remaining_groups = self._remaining_categorical_groups(remaining_groups)
        df = self._summary(level_df, self._interval)
        title = "Estimate of {} / {}".format(self._numerator_column, self._denominator_column)
        y_axis_label = "{} / {}".format(self._numerator_column, self._denominator_column)
        return self._ordinal_plot(
            "point_estimate",
            df,
            groupby,
            level_name,
            remaining_groups,
            absolute=True,
            title=title,
            y_axis_label=y_axis_label,
        )

    def _ordinal_plot(self, center_name, df, groupby, level_name, remaining_groups, absolute, title, y_axis_label):
        """Draw a line for ``center_name`` with a ``ci_lower``/``ci_upper`` band.

        Args:
            center_name (str): Column plotted as the line.
            df (pd.DataFrame): Data holding ``center_name``, ``ci_lower``,
                ``ci_upper`` and the ordinal column. A ``color`` column is
                added to it in place when ``remaining_groups`` is non-empty.
            groupby: Grouping used for the subtitle; falsy means no subtitle.
            level_name: Level shown in the subtitle.
            remaining_groups (list): Columns combined into the color key.
            absolute (bool): Selects the y-axis tick format.
            title (str): Chart title.
            y_axis_label (str): Label for the y axis.

        Returns:
            The chartify chart.
        """
        df = add_color_column(df, remaining_groups)
        colors = "color" if remaining_groups else None
        # Both layers plot the same ordering, so sort once.
        sorted_df = df.sort_values(self._ordinal_group_column)

        ch = chartify.Chart(x_axis_type=self._ordinal_type())
        ch.plot.line(
            data_frame=sorted_df,
            x_column=self._ordinal_group_column,
            y_column=center_name,
            color_column=colors,
        )
        # Restart the palette so each band gets the color of its line.
        ch.style.color_palette.reset_palette_order()
        ch.plot.area(
            data_frame=sorted_df,
            x_column=self._ordinal_group_column,
            y_column="ci_lower",
            second_y_column="ci_upper",
            color_column=colors,
        )
        ch.axes.set_yaxis_label(y_axis_label)
        ch.axes.set_xaxis_label(self._ordinal_group_column)
        ch.set_source_label("")
        axis_format = axis_format_precision(df["ci_upper"].max(), df["ci_lower"].min(), absolute)
        ch.axes.set_yaxis_tick_format(axis_format)
        subtitle = "" if not groupby else "{}: {}".format(groupby, level_name)
        ch.set_subtitle(subtitle)
        ch.set_title(title)
        if colors:
            ch.set_legend_location("outside_bottom")
        return ch

    def _remaining_categorical_groups(self, remaining_groups):
        """Return ``remaining_groups`` as a list without the ordinal column.

        Args:
            remaining_groups (str or list): Column name(s).

        Returns:
            list: The given names, minus the ordinal group column.
        """
        remaining_groups_list = [remaining_groups] if isinstance(remaining_groups, str) else remaining_groups
        return [group_name for group_name in remaining_groups_list if group_name != self._ordinal_group_column]

    def _ordinal_type(self):
        """Return ``"datetime"`` for a datetime ordinal column, else ``"linear"``."""
        ordinal_column_type = self._data_frame[self._ordinal_group_column].dtype.type
        return "datetime" if issubclass(ordinal_column_type, np.datetime64) else "linear"

    @abstractmethod
    def _categorical_summary_plot(self, level_name, level_df, remaining_groups, groupby):
        """Summary plot for one level when no ordinal axis is used."""
        pass

    @abstractmethod
    def difference(self, level_1, level_2, absolute=True, groupby=None):
        """Return dataframe containing the difference in means between
        group 1 and 2 and the appropriate test statistics.
        Frequentist:
            - Calculate one of the following tests depending of the
            response variable type.
                - Binomial: Chisq / fisher exact test
                - Gaussian: t-test / z-test
            Return the p-value.
        Bayesian:
            - Calculate the posterior distribution of the difference in means.
            Return the
                - probability that group 2 > group 1.
                - Expected loss
                - Expected change
                - Expected gain
                - 95% CI interval
        """
        pass

    def difference_plot(self, level_1, level_2, absolute=True, groupby=None):
        """Plot representing the difference between group 1 and 2.
        - Difference in means or proportions, depending
            on the response variable type.

        Frequentist:
            - Plot interval plot with confidence interval of the
            difference between groups

        Bayesian:
            - Plot KDE representing the posterior distribution of the difference.
            - Probability that group2 > group1
            - Mean difference
            - 95% interval.

        Args:
            level_1 (str, tuple of str): Name of first level.
            level_2 (str, tuple of str): Name of second level.
            absolute (bool): If True then return the absolute
                difference (level2 - level1)
                otherwise return the relative difference (level2 / level1 - 1)
            groupby (str): Name of column, or list of columns.
                If specified, will return an interval for each level
                of the grouped dimension, or a confidence band if the
                grouped dimension is ordinal

        Returns:
            GroupedChart object.
        """
        if self._use_ordinal_axis(groupby):
            ch = self._ordinal_difference_plot(level_1, level_2, absolute, groupby)
            chart_grid = ChartGrid()
            chart_grid.charts.append(ch)
        else:
            chart_grid = self._categorical_difference_plot(level_1, level_2, absolute, groupby)

        return chart_grid

    def _use_ordinal_axis(self, groupby):
        """Return True when ``groupby`` includes the ordinal group column.

        Args:
            groupby (str, list or None): Column name(s) being grouped on.

        Returns:
            bool: False when ``groupby`` or the ordinal column is None.
        """
        if groupby is None or self._ordinal_group_column is None:
            return False
        # A bare string must be compared as a whole name; `in` on a str would
        # be a substring test ("day" in "weekday").
        groupby_columns = [groupby] if isinstance(groupby, str) else groupby
        return self._ordinal_group_column in groupby_columns

    def _ordinal_difference_plot(self, level_1, level_2, absolute, groupby):
        """Plot the difference between two levels along the ordinal column."""
        difference_df = self.difference(level_1, level_2, absolute, groupby)
        remaining_groups = self._remaining_categorical_groups(groupby)
        title = "Change from {} to {}".format(level_1, level_2)
        y_axis_label = self.get_difference_plot_label(absolute)
        ch = self._ordinal_plot(
            "difference",
            difference_df,
            groupby=None,
            level_name="",
            remaining_groups=remaining_groups,
            absolute=absolute,
            title=title,
            y_axis_label=y_axis_label,
        )
        # Reference line at zero, i.e. no difference.
        ch.callout.line(0)

        return ch

    def get_difference_plot_label(self, absolute):
        """Return the y-axis label for a difference plot.

        Args:
            absolute (bool): True for "Absolute", False for "Relative".

        Returns:
            str: e.g. ``"Absolute change in <numerator> / <denominator>"``.
        """
        change_type = "Absolute" if absolute else "Relative"
        return change_type + " change in {} / {}".format(self._numerator_column, self._denominator_column)

    @abstractmethod
    def _categorical_difference_plot(self, level_1, level_2, absolute, groupby):
        """Difference plot used when no ordinal axis applies."""
        pass

    @abstractmethod
    def multiple_difference(self, level, absolute=True, groupby=None, level_as_reference=False):
        """The pairwise probability that the specific group
        is greater than all other groups.
        """
        pass

    def multiple_difference_plot(self, level, absolute=True, groupby=None, level_as_reference=False):
        """Compare level to all other groups or, if level_as_reference = True,
        all other groups to level.

        Args:
            level (str, tuple of str): Name of level.
            absolute (bool): If True then return the absolute
                difference (level2 - level1)
                otherwise return the relative difference (level2 / level1 - 1)
            groupby (str): Name of column, or list of columns.
                If specified, will return an interval for each level
                of the grouped dimension, or a confidence band if the
                grouped dimension is ordinal
            level_as_reference: If false (default), compare level to all other
                groups. If true, compare all other groups to level.

        Returns:
            ChartGrid object, or whatever the subclass's categorical plot
            returns when no ordinal axis is used.
        """
        if self._use_ordinal_axis(groupby):
            ch = self._ordinal_multiple_difference_plot(level, absolute, groupby, level_as_reference)
            chart_grid = ChartGrid()
            chart_grid.charts.append(ch)
        else:
            chart_grid = self._categorical_multiple_difference_plot(level, absolute, groupby, level_as_reference)

        return chart_grid

    def _ordinal_multiple_difference_plot(self, level, absolute, groupby, level_as_reference):
        """Plot the differences against ``level`` along the ordinal column."""
        difference_df = self.multiple_difference(level, absolute, groupby, level_as_reference)
        remaining_groups = self._remaining_categorical_groups(groupby)
        groupby_columns = self._add_level_column(remaining_groups, level_as_reference)
        title = "Comparison to {}".format(level)
        y_axis_label = self.get_difference_plot_label(absolute)
        ch = self._ordinal_plot(
            "difference",
            difference_df,
            groupby=None,
            level_name="",
            remaining_groups=groupby_columns,
            absolute=absolute,
            title=title,
            y_axis_label=y_axis_label,
        )
        # Reference line at zero, i.e. no difference.
        ch.callout.line(0)
        return ch

    def _add_level_column(self, groupby, level_as_reference):
        """Append the compared-level column name to ``groupby``.

        Args:
            groupby (str, list or None): Existing grouping column(s).
            level_as_reference (bool): Selects ``"level_2"`` when True and
                ``"level_1"`` otherwise.

        Returns:
            str or list: The bare level column name when ``groupby`` is
            None, otherwise a new list ending with the level column.
        """
        level_column = "level_2" if level_as_reference else "level_1"
        if groupby is None:
            return level_column
        if isinstance(groupby, str):
            return [groupby, level_column]
        return groupby + [level_column]

    @abstractmethod
    def _categorical_multiple_difference_plot(self, level, absolute, groupby, level_as_reference):
        """Multiple-difference plot used when no ordinal axis applies."""
        pass

    @staticmethod
    def _validate_levels(level_df, remaining_groups, level):
        """Check that ``level`` is a group key of ``level_df``.

        Args:
            level_df (pd.DataFrame): Data to look the level up in.
            remaining_groups (str or list): Columns the frame is grouped by.
            level: Group key to look up.

        Raises:
            ValueError: If ``level`` is not a valid group key; the message
                lists the valid keys.
        """
        try:
            level_df.groupby(remaining_groups).get_group(level)
        except (KeyError, ValueError) as err:
            raise ValueError(
                """
                    Invalid level: '{}'
                    Must supply a level within the ungrouped dimensions: {}
                    Valid levels:
                    {}
                """.format(
                    level, remaining_groups, list(level_df.groupby(remaining_groups).groups.keys())
                )
            ) from err

    def _groupby_iterator(self, input_function, groupby, **kwargs):
        """Yield ``input_function`` results for every level of ``groupby``.

        Args:
            input_function (callable): Called as
                ``input_function(level_name, level_df, remaining_groups,
                groupby, **kwargs)``.
            groupby (str, list or None): Column(s) to split the data on.
                None is passed on to ``input_function`` as an empty list.
            **kwargs: Forwarded to ``input_function``.

        Yields:
            The return value of ``input_function`` for each level.
        """
        groupby = [] if groupby is None else groupby
        # Will group over the whole dataframe if groupby is None
        level_groups = groupby if groupby else np.ones(len(self._data_frame))

        # Compare whole column names: `in` on a bare string is a substring
        # test and would drop column "a" when grouping by "ab".
        grouped_columns = [groupby] if isinstance(groupby, str) else groupby
        remaining_groups = [
            group for group in self._all_group_columns if group not in grouped_columns and group is not None
        ]

        for level_name, level_df in self._data_frame.groupby(level_groups):
            yield input_function(level_name, level_df, remaining_groups, groupby, **kwargs)

    def _iterate_groupby_to_chartgrid(self, input_function, groupby, **kwargs):
        """Iterate through groups in the test and apply the input function.

        Returns ChartGrid"""
        chart_grid = ChartGrid()
        chart_grid.charts = list(self._groupby_iterator(input_function, groupby, **kwargs))
        return chart_grid

    def _iterate_groupby_to_dataframe(self, input_function, groupby, **kwargs):
        """Iterate through groups in the test and apply the input function.

        Returns pd.DataFrame"""
        results = list(self._groupby_iterator(input_function, groupby, **kwargs))

        # Flatten any nested generators.
        if isinstance(results[0], types.GeneratorType):
            results = [group for generator in results for group in generator]

        results_data_frame = pd.concat(results, axis=0)
        return results_data_frame.reset_index(drop=True)

    def _all_groups(self):
        """Return a list of all group keys.

        Returns: list"""
        return list(self._data_frame.groupby(self._all_group_columns).groups.keys())

    def _add_group_by_columns(self, difference_df, groupby, level_name):
        """Insert the groupby level value(s) as leading columns, in place.

        Args:
            difference_df (pd.DataFrame): Frame to modify.
            groupby (str or list): Column name(s) that were grouped on.
                Nothing happens when this is empty or None.
            level_name: Group key for this level; a scalar for a single
                column or a tuple with one value per column.

        Note:
            Each column is inserted at position 0, so several groupby
            columns end up in reverse order at the front of the frame.
        """
        if not groupby:
            return

        if isinstance(groupby, str):
            columns, values = [groupby], [level_name]
        elif len(groupby) == 1:
            # pandas >= 2.0 yields a 1-tuple key when grouping by a list of
            # length 1; unwrap it so the scalar is broadcast down the column.
            if isinstance(level_name, tuple) and len(level_name) == 1:
                level_name = level_name[0]
            columns, values = list(groupby), [level_name]
        else:
            columns, values = groupby, level_name

        for col, val in zip(columns, values):
            difference_df.insert(0, column=col, value=val)

510 

511 

512# class BinomialResponse(BaseTest, metaclass=ABCMeta): 

513# """Binomial Response Variable. 

514# """ 

515 

516# class GaussianResponse(BaseTest, metaclass=ABCMeta): 

517# """Base class for tests of normal response variables 

518 

519# E.g. Revenue per user 

520# """ 

521 

522# pass 

523 

524 

525# class PoissonResponse(BaseTest, metaclass=ABCMeta): 

526# """Base class for tests of poisson response variables. 

527 

528# E.g. # of days active per user per month 

529# """ 

530# pass 

531 

532 

533# class MultinomialResponse(BaseTest, metaclass=ABCMeta): 

534# """Base class for tests of multinomial response variables. 

535 

536# E.g. single choice answer survey 

537# self. 

538# """ 

539 

540# def __init__(self, data_frame, categorical_group_columns, 

541# ordinal_group_column, category_column, value_column): 

542# self._category_column = category_column 

543# self._value_column = value_column 

544# super().__init__(data_frame, categorical_group_columns, 

545# ordinal_group_column) 

546 

547 

548# class CategoricalResponse(BaseTest, metaclass=ABCMeta): 

549# """Base class for tests of categorical response variables. 

550 

551# E.g. multiple choice answer survey 

552# """ 

553 

554# def __init__(self, data_frame, categorical_group_columns, 

555# ordinal_group_column, category_column, value_column): 

556# self._category_column = category_column 

557# self._value_column = value_column 

558# super().__init__(data_frame, categorical_group_columns, 

559# ordinal_group_column) 

560 

561# pass