Coverage for /Users/sebastiana/Documents/Sugarpills/confidence/spotify_confidence/samplesize/sample_size_calculator.py: 72%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

206 statements  

1# Copyright 2017-2020 Spotify AB 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15from ipywidgets import widgets 

16from IPython.display import display 

17import scipy.stats as st 

18import numpy as np 

19import math 

20 

21 

class SampleSize(object):
    """Frequentist sample size calculations.

    See: Duflo, E., Glennerster, R., & Kremer, M. (2007). Using
    randomization in development economics research: A toolkit.
    Handbook of Development Economics, 4, 3895–3962. pp28-31.

    Methods
    -------
    binomial()
        Calculate the required sample size for a binomial metric.
    binomial_interactive()
        Interactive version of the binomial() function for notebook use.
    continuous()
        Calculate the required sample size for a continuous metric.
    continuous_interactive()
        Interactive version of the continuous() function for notebook use.
    achieved_power()
        TODO: Calculate achieved power given reached sample size.

    """

    # Shared defaults for binomial() and continuous().
    default_alpha = 0.05  # Type I error rate (false-positive probability)
    default_power = 0.85  # 1 - beta, probability of detecting a true effect
    default_treatments = 2  # number of variants, including control
    default_comparisons = "control_vs_all"  # which treatment pairs are tested
    default_treatment_costs = None  # None -> equal relative cost per variant
    default_treatment_allocations = None  # None -> automatic optimal allocation
    default_bonferroni = False  # no multiple-comparison correction by default

51 

52 @staticmethod 

53 def continuous( 

54 average_absolute_mde, 

55 baseline_variance, 

56 alpha=default_alpha, 

57 power=default_power, 

58 treatments=default_treatments, 

59 comparisons=default_comparisons, 

60 treatment_costs=default_treatment_costs, 

61 treatment_allocations=default_treatment_allocations, 

62 bonferroni_correction=default_bonferroni, 

63 ): 

64 """Calculate the required sample size for a binomial metric. 

65 

66 Args: 

67 average_absolute_mde (float): Average absolute minimal detectable 

68 effect size (mean difference) across all tests. 

69 baseline_variance (float): Baseline metric variance in 

70 target population. 

71 alpha (float, optional): Probability of Type I error 

72 (false positive). Defaults to 0.05. 

73 power (float, optional): 1 - B, where B is the probability of 

74 Type II error (false negative). Defaults to 0.85. 

75 treatments (int, optional): Number of treatment variants 

76 in the a/b test, including control. Defaults to 2. 

77 comparisons ({'control_vs_all', 'all_vs_all'}, optional): Which 

78 treatments to compare. Defaults to 'control_vs_all'. 

79 treatment_costs (numpy.ndarray, optional): Array with same length 

80 as the number of treatments containing positive floats 

81 specifying the treatments' relative costs. Defaults to equal 

82 cost for all treatments. 

83 treatment_allocations (numpy.ndarray, optional): Array with same 

84 length as the number of treatments containing proportion of 

85 sample allocated to each treatment. If not specified defaults 

86 to automatic allocation. 

87 bonferroni_correction (bool): Whether Bonferroni correction should 

88 be applied to control the false positive rate across all 

89 comparisons. Defaults to false. 

90 

91 Returns: 

92 int: Total required sample size across all treatments. 

93 list of int: Required sample size for each treatment. 

94 list of float: Proportion of total sample allocated 

95 to each treatment. 

96 

97 Raises: 

98 ValueError: If `power` is less than or equal to`alpha`. 

99 

100 """ 

101 mde = SampleSize._clean_continuous_mde(average_absolute_mde) 

102 baseline_variance = SampleSize._validate_positive(baseline_variance) 

103 

104 return SampleSize._calculate_samplesize( 

105 mde, 

106 baseline_variance, 

107 alpha, 

108 power, 

109 treatments, 

110 comparisons, 

111 treatment_costs, 

112 treatment_allocations, 

113 bonferroni_correction, 

114 ) 

115 

    @staticmethod
    def continuous_interactive():
        """Launch the interactive notebook calculator for continuous metrics.

        Renders ipywidgets controls and live results; intended for use
        inside a Jupyter notebook only.
        """
        SampleSize._calculate_sample_size_interactive("continuous")

119 

120 @staticmethod 

121 def binomial( 

122 absolute_percentage_mde, 

123 baseline_proportion, 

124 alpha=default_alpha, 

125 power=default_power, 

126 treatments=default_treatments, 

127 comparisons=default_comparisons, 

128 treatment_costs=default_treatment_costs, 

129 treatment_allocations=default_treatment_allocations, 

130 bonferroni_correction=default_bonferroni, 

131 ): 

132 """Calculate the required sample size for a binomial metric. 

133 

134 Args: 

135 absolute_percentage_mde (float): Average absolute minimal 

136 detectable effect size across all tests. 

137 baseline_proportion (float): Baseline metric proportion in 

138 target population. 

139 alpha (float, optional): Probability of Type I error 

140 (false positive). Defaults to 0.05. 

141 power (float, optional): 1 - B, where B is the probability of 

142 Type II error (false negative). Defaults to 0.85. 

143 treatments (int, optional): Number of treatment variants 

144 in the a/b test, including control. Defaults to 2. 

145 comparisons ({'control_vs_all', 'all_vs_all'}, optional): Which 

146 treatments to compare. Defaults to 'control_vs_all'. 

147 treatment_costs (numpy.ndarray, optional): Array with same length 

148 as the number of treatments containing positive floats 

149 specifying the treatments' relative costs. Defaults to equal 

150 cost for all treatments. 

151 treatment_allocations (numpy.ndarray, optional): Array with same 

152 length as the number of treatments containing proportion of 

153 sample allocated to each treatment. If not specified defaults 

154 to automatic allocation. 

155 bonferroni_correction (bool): Whether Bonferroni correction should 

156 be applied to control the false positive rate across all 

157 comparisons. Defaults to false. 

158 

159 Returns: 

160 int: Total required sample size across all treatments. 

161 list of int: Required sample size for each treatment. 

162 list of float: Proportion of total sample allocated 

163 to each treatment. 

164 

165 Raises: 

166 ValueError: If `power` is less than or equal to`alpha`. 

167 ValueError: If `baseline_proportion` - `absolute_percentage_mde` 

168 < 0 and `baseline_proportion` + `absolute_percentage_mde` > 1. 

169 I.e. if the mde always implies a non-valid percentage. 

170 

171 """ 

172 baseline = SampleSize._validate_percentage(baseline_proportion) 

173 mde = SampleSize._clean_binomial_mde(absolute_percentage_mde, baseline) 

174 baseline_variance = baseline * (1 - baseline) 

175 

176 return SampleSize._calculate_samplesize( 

177 mde, 

178 baseline_variance, 

179 alpha, 

180 power, 

181 treatments, 

182 comparisons, 

183 treatment_costs, 

184 treatment_allocations, 

185 bonferroni_correction, 

186 ) 

187 

    @staticmethod
    def binomial_interactive():
        """Launch the interactive notebook calculator for binomial metrics.

        Renders ipywidgets controls and live results; intended for use
        inside a Jupyter notebook only.
        """
        SampleSize._calculate_sample_size_interactive("binomial")

191 

    @staticmethod
    def _calculate_samplesize(
        mde,
        baseline_variance,
        alpha,
        power,
        treatments,
        comparisons,
        treatment_costs,
        treatment_allocations,
        bonferroni,
    ):
        """Core sample-size computation shared by binomial() and continuous().

        Implements the multi-treatment formula from Duflo, Glennerster &
        Kremer (2007), pp. 28-31: the total sample is proportional to
        (z_alpha + z_power)^2 * variance / mde^2, scaled by a term that
        accounts for how the sample is split across compared groups.

        Args:
            mde (float): Average absolute minimal detectable effect size
                across all tests (already validated by the caller).
            baseline_variance (float): Metric variance in the control group.
            alpha (float): Probability of Type I error (false positive).
            power (float): 1 - B, where B is the probability of Type II error.
            treatments (int): Number of treatment variants, including control.
            comparisons ({'control_vs_all', 'all_vs_all'}): Which treatments
                to compare.
            treatment_costs (numpy.ndarray, None): Relative cost per
                treatment, or None for equal costs.
            treatment_allocations (numpy.ndarray, None): Fixed allocation
                proportions, or None for automatic allocation.
            bonferroni (bool): Whether to Bonferroni-correct `alpha`.

        Returns:
            int: Total required sample size across all treatments.
            numpy.ndarray of int: Required sample size for each treatment.
            numpy.ndarray of float: Proportion of total sample allocated
                to each treatment.

        """
        power = SampleSize._validate_percentage(power)
        treatments = SampleSize._clean_treatments(treatments)
        comparisons = SampleSize._clean_comparisons(comparisons)
        treatment_costs = SampleSize._clean_treatment_costs(treatments, treatment_costs)

        # Alpha may be shrunk per-comparison by the Bonferroni correction.
        alpha = SampleSize._get_alpha(alpha, power, bonferroni, treatments, comparisons)
        treatment_allocations = SampleSize._get_treatment_allocations(
            treatments, comparisons, treatment_costs, treatment_allocations
        )

        num_comparisons = SampleSize._num_comparisons(treatments, comparisons)
        comparison_matrix = SampleSize._get_comparison_matrix(treatments, comparisons)

        # Two-sided critical value and power quantile of the standard normal.
        z_alpha = st.norm.ppf(1 - alpha / 2)
        z_power = st.norm.ppf(power)

        # n_total = a * b * c * d where:
        #   a: 1 / (total detectable effect)^2 -- mde is the *average* mde,
        #      so it is multiplied by the number of comparisons
        #   b: (z_power + z_alpha)^2
        #   c: baseline variance
        #   d: allocation penalty, (sum over compared pairs (i, j) of
        #      sqrt(1/p_i + 1/p_j))^2, where p_k is group k's allocation
        a = np.power(1.0 / (num_comparisons * mde), 2)
        b = np.power(z_power + z_alpha, 2)
        c = baseline_variance
        d = 0
        for i in range(treatments):
            for j in range(treatments):
                # comparison_matrix is lower-triangular: each pair counted once.
                if comparison_matrix[i, j] > 0:
                    d += np.sqrt(1.0 / treatment_allocations[i] + 1.0 / treatment_allocations[j])
        d = np.power(d, 2)

        # Round totals up so the requested power is always reached.
        n_total = np.ceil(a * b * c * d).astype(int)
        n_allocation = np.ceil(treatment_allocations * n_total).astype(int)
        return n_total, n_allocation, treatment_allocations

233 

    @staticmethod
    def _calculate_sample_size_interactive(metric):
        """Render an interactive sample-size calculator in a Jupyter notebook.

        Builds ipywidgets controls for the metric's inputs, recomputes the
        required sample size on every widget change via continuous() /
        binomial(), and displays the result together with a copy-pasteable
        code snippet reproducing the calculation.

        Args:
            metric ({'continuous', 'binomial'}): Which calculator to render.

        Raises:
            ValueError: If `metric` is not 'continuous' or 'binomial'.

        """
        style = {"description_width": "initial"}
        desc_layout = widgets.Layout(width="50%")
        # Metric-specific inputs: MDE and baseline widgets plus help texts.
        if metric == "continuous":
            mde_widget = widgets.FloatText(
                value=0.01,
                description="",
            )

            mde_desc = widgets.HTML(
                """
                <small>
                This is the smallest absolute difference in averages that
                any of your comparisons can detect at the given statistical
                rigour.
                </small>
                """,
                layout=desc_layout,
            )

            baseline_title = widgets.HTML("<strong>Baseline variance</strong>")
            baseline_widget = widgets.BoundedFloatText(
                value=1.0,
                min=0.00001,
                max=1000000000.0,
                description="",
            )
            baseline_desc = widgets.HTML(
                """
                <small>
                This is the expected variance of the metric among
                users in your control group.
                </small>
                """,
                layout=desc_layout,
            )

        elif metric == "binomial":
            mde_widget = widgets.FloatLogSlider(
                value=0.003, base=10, min=-4, max=np.log10(0.5), step=0.001, description="", readout_format=".4f"
            )

            mde_desc = widgets.HTML(
                """
                <small>
                This is the smallest absolute difference (percentage
                point / 100) that any of your comparisons can detect
                at the given statistical rigour.
                </small>
                """,
                layout=desc_layout,
            )

            baseline_title = widgets.HTML("<strong>Baseline " "proportion</strong>")
            baseline_widget = widgets.FloatSlider(value=0.5, min=0.00001, max=0.99999, step=0.01, description="")
            baseline_desc = widgets.HTML(
                """
                <small>
                This is the expected value of the metric among
                users in your control group.
                </small>
                """,
                layout=desc_layout,
            )

        else:
            raise ValueError("metric must be `continuous` or `binomial`")

        # Shared statistical-rigour controls.
        alpha_widget = widgets.FloatSlider(
            value=0.05, min=0.001, max=0.10, step=0.001, description=r"\(\alpha\)", readout_format=".3f"
        )

        power_widget = widgets.FloatSlider(
            value=0.85, min=0.8, max=0.99, step=0.01, description=r"Power, \( 1-\beta\)"
        )

        treatments_widget = widgets.IntSlider(
            value=2, min=2, max=20, step=1, description="Groups (including control)", style=style
        )

        comparisons_widget = widgets.RadioButtons(
            options=["Control vs. All", "All vs. All"],
            value="Control vs. All",
            description="Groups to compare",
            style=style,
        )

        # Log-scale multiplier used as the relative cost of every non-control
        # variant, which skews the optimal allocation towards the control.
        control_group_widget = widgets.FloatLogSlider(
            value=1,
            step=0.1,
            base=10,
            min=0,
            max=4,
            description="Control group advantage",
            readout=False,
            style=style,
        )
        control_group_description = widgets.HTML(
            """
            <small>
            Sometime we want the control group to be bigger than what is
            strictly optimal. This can be either because we can collect
            samples quickly enough anyway or because we believe the
            treatment variants are riskier. Boosting the size of the
            control group comes at the cost of an increased total
            required sample.
            </small>
            """,
            layout=desc_layout,
        )

        bonferroni_widget = widgets.Checkbox(value=False, description="Apply Bonferroni correction")

        risk_reset_btn = widgets.Button(
            description=" ",
            disabled=False,
            button_style="",
            tooltip="Reset variant risk",
            icon="repeat",
            layout=widgets.Layout(width="40px"),
        )

        def reset_widget(b):
            # Reset the control-group advantage back to "no advantage".
            control_group_widget.value = 1

        risk_reset_btn.on_click(reset_widget)

        # Assemble the input form.
        ui = widgets.VBox(
            [
                widgets.HTML("<h4>Target metric</h4>"),
                widgets.VBox(
                    children=[
                        widgets.HTML("<strong>Minimal Detectable Effect " "size</strong>"),
                        mde_widget,
                        mde_desc,
                    ],
                ),
                widgets.VBox(
                    children=[baseline_title, baseline_widget, baseline_desc],
                ),
                widgets.HTML("<h4>Statistical rigour</h4>"),
                alpha_widget,
                power_widget,
                bonferroni_widget,
                widgets.HTML("<h4>Treatment groups</h4>"),
                treatments_widget,
                comparisons_widget,
                widgets.VBox(
                    children=[widgets.HBox([control_group_widget, risk_reset_btn]), control_group_description]
                ),
            ]
        )

        def show_samplesize(
            mde, baseline, alpha, power, treatments, comparisons_readable, bonferroni_correction, relative_risk
        ):
            # Called by interactive_output on every widget change: recompute
            # the sample size and render results plus equivalent code.
            if comparisons_readable == "Control vs. All":
                comparisons = "control_vs_all"
            else:
                comparisons = "all_vs_all"

            # Every non-control variant carries the chosen relative risk/cost.
            treatment_costs = np.ones(treatments)
            treatment_costs[1:] = relative_risk
            treatment_allocations = None

            if metric == "continuous":
                # n_optimal: the cost-free optimum, used for comparison below.
                n_optimal, _, _ = SampleSize.continuous(
                    mde,
                    baseline,
                    alpha,
                    power,
                    treatments,
                    comparisons,
                    None,
                    treatment_allocations,
                    bonferroni_correction,
                )
                n_tot, n_cell, prop_cell = SampleSize.continuous(
                    mde,
                    baseline,
                    alpha,
                    power,
                    treatments,
                    comparisons,
                    treatment_costs,
                    treatment_allocations,
                    bonferroni_correction,
                )
                code_html = widgets.HTML(
                    "<pre><code>"
                    f"SampleSize.continuous(average_absolute_mde={ mde },\n"
                    f"                      baseline_variance={ baseline },\n"
                    f"                      alpha={ alpha },\n"
                    f"                      power={ power },\n"
                    f"                      treatments={ treatments },\n"
                    f"                      comparisons="
                    f"'{ comparisons }',\n"
                    f"                      treatment_costs="
                    f"{ list(treatment_costs) },\n"
                    f"                      treatment_allocations=None,\n"
                    f"                      bonferroni_correction="
                    f"{ bonferroni_correction })"
                    "<code></pre>"
                )
            else:
                n_tot, n_cell, prop_cell = SampleSize.binomial(
                    mde,
                    baseline,
                    alpha,
                    power,
                    treatments,
                    comparisons,
                    treatment_costs,
                    treatment_allocations,
                    bonferroni_correction,
                )
                n_optimal, _, _ = SampleSize.binomial(
                    mde,
                    baseline,
                    alpha,
                    power,
                    treatments,
                    comparisons,
                    None,
                    treatment_allocations,
                    bonferroni_correction,
                )
                code_html = widgets.HTML(
                    "<pre><code>"
                    f"SampleSize.binomial(absolute_percentage_mde={ mde },\n"
                    f"                    baseline_proportion="
                    f"{ baseline },\n"
                    f"                    alpha={ alpha },\n"
                    f"                    power={ power },\n"
                    f"                    treatments={ treatments },\n"
                    f"                    comparisons="
                    f"'{ comparisons }',\n"
                    f"                    treatment_costs="
                    f"{ list(treatment_costs) },\n"
                    f"                    treatment_allocations=None,\n"
                    f"                    bonferroni_correction="
                    f"{ bonferroni_correction })"
                    "<code></pre>"
                )

            def compare_against_optimal(current, optimal):
                # Annotate how far the costed total is from the optimum.
                if current == optimal:
                    return ""
                else:
                    return (
                        f"<br><small><em>{current/optimal:.1f}x "
                        f"optimal group allocation of {optimal:,}."
                        f"</em></small>"
                    )

            display(
                widgets.HTML(
                    f"<h4>Required sample size</h4>"
                    f"<strong>Total:</strong><br>{n_tot:,}"
                    f"{compare_against_optimal(n_tot, n_optimal)}"
                )
            )
            cell_str = "<strong>Sample size in each cell</strong>"
            for i in range(len(n_cell)):
                # Group 0 is the control by convention.
                if i == 0:
                    treatment = "Control"
                else:
                    treatment = "Variant " + str(i)

                cell_str += f"<br><em>{treatment}:</em> " f"{n_cell[i]:,} ({prop_cell[i]*100:.1f}%)"

            display(widgets.HTML(cell_str))
            display(code_html)

        # Re-render the output whenever any control changes.
        out = widgets.interactive_output(
            show_samplesize,
            {
                "mde": mde_widget,
                "baseline": baseline_widget,
                "alpha": alpha_widget,
                "power": power_widget,
                "treatments": treatments_widget,
                "comparisons_readable": comparisons_widget,
                "bonferroni_correction": bonferroni_widget,
                "relative_risk": control_group_widget,
            },
        )

        display(ui, out)

524 

525 @staticmethod 

526 def _clean_treatments(treatments): 

527 """Validate treatments input. 

528 

529 Args: 

530 treatments (int): Number of treatment variants in the a/b test, 

531 including control. Defaults to 2. 

532 

533 Returns: 

534 int: Number of treatment variants. 

535 

536 Raises: 

537 TypeError: If `treatments` is not a number. 

538 ValueError: If `treatments` is not an integer greater than or 

539 equal to two. 

540 

541 """ 

542 error_string = "Treatments must be a whole number " "greater than or equal to two" 

543 try: 

544 remainder = treatments % 1 

545 except TypeError: 

546 raise TypeError(error_string) 

547 

548 if remainder != 0: 

549 raise ValueError(error_string) 

550 elif treatments < 2: 

551 raise ValueError(error_string) 

552 else: 

553 return int(treatments) 

554 

555 @staticmethod 

556 def _clean_comparisons(comparisons): 

557 """Validate comparisons input. 

558 

559 Args: 

560 comparisons ({'control_vs_all', 'all_vs_all'}): Which treatments 

561 to compare. 

562 

563 Returns: 

564 str: Which treatments to compare. 

565 

566 Raises: 

567 ValueError: If `comparisons` is not one of 'control_vs_all' or 

568 'all_vs_all'. 

569 

570 """ 

571 if comparisons not in ("control_vs_all", "all_vs_all"): 

572 raise ValueError("comparisons must be either " '"control_vs_all" or "all_vs_all"') 

573 else: 

574 return comparisons 

575 

576 @staticmethod 

577 def _num_comparisons(treatments, comparisons): 

578 """Calculate the number of hypothesis tests. 

579 

580 When comparing all treatments against each other, calculating 

581 the number of hypothesis tests is an n-choose-k problem with 

582 n=treatments, and k=2: https://en.wikipedia.org/wiki/Combination. 

583 

584 Args: 

585 treatments (int): Number of treatment variants in the a/b test, 

586 including control. 

587 comparisons ({'control_vs_all', 'all_vs_all'}): Which treatments 

588 to compare. 

589 

590 Returns: 

591 int: Number of hypothesis tests to conduct. 

592 

593 """ 

594 treatments = SampleSize._clean_treatments(treatments) 

595 comparisons = SampleSize._clean_comparisons(comparisons) 

596 

597 if comparisons == "control_vs_all": 

598 num_comparisons = treatments - 1 

599 else: 

600 num_comparisons = math.factorial(treatments) / (2 * math.factorial(treatments - 2)) 

601 

602 return int(num_comparisons) 

603 

604 @staticmethod 

605 def _get_comparison_matrix(treatments, comparisons): 

606 """Transform categorical comparison to matrix. 

607 

608 Args: 

609 treatments (int): Number of treatment variants in the a/b test, 

610 including control. 

611 comparisons ({'control_vs_all', 'all_vs_all'}): Which treatments 

612 to compare. 

613 

614 Returns: 

615 numpy.ndarray: Lower triangular matrix of size 

616 `treatments x treatments` with 1 in position i, j 

617 if treatment i is to be compared with treatment j. 

618 

619 """ 

620 treatments = SampleSize._clean_treatments(treatments) 

621 comparisons = SampleSize._clean_comparisons(comparisons) 

622 

623 if comparisons == "control_vs_all": 

624 comparison_matrix = np.zeros((treatments, treatments)) 

625 comparison_matrix[1:, 0] = 1 

626 

627 else: 

628 comparison_matrix = np.ones((treatments, treatments)) 

629 comparison_matrix = np.tril(comparison_matrix, -1) 

630 

631 return comparison_matrix 

632 

633 @staticmethod 

634 def _clean_treatment_costs(treatments, treatment_costs): 

635 """Validate or generate treatment cost array. 

636 

637 Args: 

638 treatment_costs (numpy.ndarray, None): Array with same length as 

639 the number of treatments containing positive floats specifying 

640 the treatments' relative costs. None also accepted in which 

641 case equal relative costs are returned. 

642 treatments (int): Number of treatment variants in the a/b test, 

643 including control. 

644 

645 Returns: 

646 numpy.ndarray: Array with each treatment's cost. 

647 

648 Raises: 

649 TypeError: If `treatment_costs` is not None or a numpy.ndarray. 

650 TypeError: If the length of customs `treatment_costs` is not the 

651 same as the number of treatments. 

652 ValueError: If the values of custom `treatment_costs` are not all 

653 positive and sum to one. 

654 

655 """ 

656 treatments = SampleSize._clean_treatments(treatments) 

657 

658 if treatment_costs is None: 

659 # Default equal cost of all cells 

660 return np.ones(treatments) 

661 

662 elif ( 

663 not (isinstance(treatment_costs, np.ndarray) or isinstance(treatment_costs, list)) 

664 or len(treatment_costs) != treatments 

665 ): 

666 raise TypeError( 

667 "treatment_costs must be a list or numpy array of" "the same length as the number of treatments" 

668 ) 

669 

670 try: 

671 treatment_costs = np.array(treatment_costs) 

672 if not (treatment_costs > 0).all(): 

673 raise ValueError("treatment_costs values must all be positive") 

674 

675 except TypeError: 

676 raise TypeError("treatment_costs array must only contain numbers") 

677 

678 return treatment_costs 

679 

    @staticmethod
    def _get_treatment_allocations(treatments, comparisons, treatment_costs, treatment_allocations):
        """Validate or generate treatment allocation array.

        See the footnote on page 31 of "Duflo, E., Glennerster, R., & Kremer,
        M. (2007). Using randomization in development economics research: A
        toolkit. Handbook of Development Economics, 4, 3895–3962." for math.

        Args:
            treatments (int, optional): Number of treatment variants in the a/b
                test, including control. Defaults to 2.
            comparisons ({'control_vs_all', 'all_vs_all'}, optional): Which
                treatments to compare. Defaults to 'control_vs_all'.
            treatment_costs (numpy.ndarray, optional): Array with same length
                as the number of treatments containing positive floats
                specifying the treatments' relative costs. Defaults to equal
                cost for all treatments.
            treatment_allocations (numpy.ndarray/list/tuple, optional): Array
                with same length as the number of treatments containing
                proportion of sample allocated to each treatment. If not
                specified defaults to automatic allocation.

        Returns:
            numpy.ndarray: Array with same length as the number of treatments
                containing proportion of sample allocated to each treatment.

        Raises:
            TypeError: If `treatment_allocations` is not None or a
                numpy.ndarray.
            TypeError: If the length of custom `treatment_allocations` is not
                the same as the number of treatments.
            ValueError: If the values of custom `treatment_allocations` are
                not all positive and sum to one.

        """
        treatments = SampleSize._clean_treatments(treatments)

        # A caller-supplied allocation is validated and returned as-is;
        # the automatic-allocation math below is skipped entirely.
        if treatment_allocations is not None:
            if isinstance(treatment_allocations, list) or isinstance(treatment_allocations, tuple):
                treatment_allocations = np.array(treatment_allocations)

            if not isinstance(treatment_allocations, np.ndarray) or len(treatment_allocations) != treatments:
                raise TypeError(
                    "treatment_allocations must be a numpy array "
                    "or list of the same length as the number of "
                    "treatments"
                )

            elif not (treatment_allocations > 0).all():
                raise ValueError("treatment_allocations values " "must all be positive")

            elif not math.isclose(treatment_allocations.sum(), 1.0):
                raise ValueError("treatment_allocations values " "must sum to one")

            else:
                return np.array(treatment_allocations)

        # Automatic allocation: weight each group by how often it takes part
        # in a comparison, discounted by the square root of its relative
        # cost (Duflo et al. 2007, p. 31 footnote).
        comparisons = SampleSize._get_comparison_matrix(treatments, comparisons)
        weighted_comparisons = comparisons / np.sum(comparisons)
        treatment_costs = SampleSize._clean_treatment_costs(treatments, treatment_costs)

        # ratios[i, j] is the optimal size ratio n_i / n_j between groups.
        ratios = np.zeros((treatments, treatments))
        for i in range(treatments):
            # Group i's total importance: all comparisons it appears in
            # (as either row or column of the lower-triangular matrix).
            sum_importance_i = np.sum(weighted_comparisons[:, i]) + np.sum(weighted_comparisons[i, :])
            for j in range(treatments):
                sum_importance_j = np.sum(weighted_comparisons[:, j]) + np.sum(weighted_comparisons[j, :])
                ratios[i, j] = sum_importance_i / sum_importance_j * np.sqrt(treatment_costs[j] / treatment_costs[i])

        # Normalise the ratios against group 0 into proportions summing to 1.
        treatment_allocations = ratios[:, 0] / np.sum(ratios[:, 0])

        return treatment_allocations

751 

752 @staticmethod 

753 def _get_alpha(alpha, power, bonferroni, treatments, comparisons): 

754 """Validate and potentially correct false positive rate. 

755 

756 Args: 

757 alpha (float): Probability of Type I error (false positive). 

758 bonferroni (bool): Whether Bonferroni correction should be applied 

759 to control the false positive rate across all comparisons. 

760 treatments (int): Number of treatment variants in the a/b test, 

761 including control. 

762 comparisons ({'control_vs_all', 'all_vs_all'}, optional): Which 

763 treatments to compare. 

764 

765 Returns: 

766 float: False positive rate, potentially Bonferroni corrected. 

767 

768 Raises: 

769 ValueError: If `power` is less than or equal to `alpha`. 

770 TypeError: If `bonferroni` is not a bool. 

771 

772 """ 

773 power = SampleSize._validate_percentage(power) 

774 alpha = SampleSize._validate_percentage(alpha) 

775 

776 if power <= alpha: 

777 raise ValueError("alpha must be less than power") 

778 elif not isinstance(bonferroni, bool): 

779 raise TypeError("bonferroni must be a bool") 

780 

781 num_comparisons = SampleSize._num_comparisons(treatments, comparisons) 

782 

783 if bonferroni: 

784 return alpha / num_comparisons 

785 else: 

786 return alpha 

787 

788 @staticmethod 

789 def _validate_percentage(num): 

790 """Validate that num is a percentage. 

791 

792 Args: 

793 num(float): Valid percentage. 

794 

795 Returns: 

796 float: Valid percentage. 

797 

798 Raises: 

799 TypeError: If `num` is not a float. 

800 ValueError: If `num` is not between zero and one. 

801 

802 """ 

803 if not isinstance(num, float): 

804 raise TypeError("num must be a float") 

805 elif not 0 < num < 1: 

806 raise ValueError("num must be between 0 and 1") 

807 else: 

808 return num 

809 

810 @staticmethod 

811 def _validate_positive(val): 

812 """Validate that val is positive. 

813 

814 Args: 

815 val (float): Value to validate. 

816 

817 Returns: 

818 float: Value. 

819 

820 Raises: 

821 ValueError: If value is non-positive. 

822 

823 """ 

824 if not val > 0: 

825 raise ValueError("value must be positive") 

826 else: 

827 return val 

828 

829 @staticmethod 

830 def _clean_continuous_mde(average_absolute_mde): 

831 """Validate that mde is not equal to zero. 

832 

833 Args: 

834 average_absolute_mde (float): Average absolute minimal detectable 

835 effect size (mean difference) across all tests. 

836 

837 Returns: 

838 float: Average absolute minimal detectable effect size. 

839 

840 Raises: 

841 ValueError: If `average_absolute_mde` is zero. 

842 

843 """ 

844 if math.isclose(average_absolute_mde, 0.0): 

845 raise ValueError("average_absolute_mde cannot be zero") 

846 else: 

847 return average_absolute_mde 

848 

849 @staticmethod 

850 def _clean_binomial_mde(absolute_percentage_mde, baseline_proportion): 

851 """Validate that mde is percentage and not too large. 

852 

853 Args: 

854 absolute_percentage_mde (float): Average absolute minimal 

855 detectable effect size across all tests. 

856 baseline_proportion (float): Baseline metric proportion in 

857 target population. 

858 

859 Returns: 

860 float: Average absolute minimal detectable effect size. 

861 

862 """ 

863 mde = SampleSize._validate_percentage(absolute_percentage_mde) 

864 baseline = SampleSize._validate_percentage(baseline_proportion) 

865 

866 if baseline - mde < 0 and baseline + mde > 1: 

867 raise ValueError("absolute_percentage_mde is too large " "given baseline_proportion") 

868 else: 

869 return mde