Coverage for spotify_confidence/analysis/confidence_utils.py: 42% (132 statements)


# Copyright 2017-2020 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict
from concurrent.futures.thread import ThreadPoolExecutor
from typing import Union, Iterable, Tuple, List

import numpy as np
from pandas import DataFrame, concat, Series
from scipy.stats import norm

from spotify_confidence.analysis.constants import (
    SFX1,
    SFX2,
)

def groupbyApplyParallel(dfGrouped, func_to_apply):
    with ThreadPoolExecutor(max_workers=32, thread_name_prefix="groupbyApplyParallel") as p:
        ret_list = p.map(
            func_to_apply,
            [group for name, group in dfGrouped],
        )
    return concat(ret_list)
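# Illustrative usage sketch (not part of the original module); the frame and
# column names "g" and "x" below are made up for the example:
# >>> toy = DataFrame({"g": ["a", "a", "b"], "x": [1.0, 2.0, 3.0]})
# >>> groupbyApplyParallel(toy.groupby("g"), lambda grp: grp.assign(x_sum=grp["x"].sum()))
# Behaves like toy.groupby("g").apply(...) for frame-returning functions, but
# each group is processed on a thread pool.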

def applyParallel(df, func_to_apply, splits=32):
    with ThreadPoolExecutor(max_workers=splits, thread_name_prefix="applyParallel") as p:
        ret_list = p.map(
            func_to_apply,
            np.array_split(df, min(splits, len(df))),
        )
    return concat(ret_list)
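# Illustrative usage sketch (not part of the original module): the frame is
# cut into at most `splits` chunks with np.array_split, so func_to_apply must
# be valid on any row-wise slice of the frame:
# >>> applyParallel(DataFrame({"x": range(100)}), lambda chunk: chunk * 2, splits=4)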

def get_all_group_columns(categorical_columns: Iterable, additional_column: str) -> Iterable:
    all_columns = listify(categorical_columns) + listify(additional_column)
    return list(OrderedDict.fromkeys(all_columns))
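# Illustrative usage sketch (not part of the original module); duplicates are
# dropped while preserving first-seen order:
# >>> get_all_group_columns(["country", "platform"], "date")
# ['country', 'platform', 'date']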

def remove_group_columns(categorical_columns: Iterable, additional_column: str) -> Iterable:
    od = OrderedDict.fromkeys(categorical_columns)
    if additional_column is not None:
        del od[additional_column]
    return list(od)

def validate_categorical_columns(categorical_group_columns: Union[str, Iterable]) -> None:
    if not isinstance(categorical_group_columns, (str, Iterable)):
        raise TypeError(
            """categorical_group_columns must be string or
            iterable (list of columns) and you must
            provide at least one"""
        )

def listify(column_s: Union[str, Iterable]) -> List:
    if column_s is None:
        return []
    if isinstance(column_s, str):
        return [column_s]
    if isinstance(column_s, Iterable):
        return list(column_s)
    # Wrap any other scalar so the function always returns a list.
    return [column_s]
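# Illustrative usage sketch (not part of the original module):
# >>> listify("country")
# ['country']
# >>> listify(["country", "platform"])
# ['country', 'platform']
# >>> listify(None)
# []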

def get_remaning_groups(all_groups: Iterable, some_groups: Iterable) -> Iterable:
    if some_groups is None:
        remaining_groups = all_groups
    else:
        remaining_groups = [group for group in all_groups if group not in some_groups and group is not None]
    return remaining_groups
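# Illustrative usage sketch (not part of the original module):
# >>> get_remaning_groups(["country", "date", "platform"], ["date"])
# ['country', 'platform']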

def get_all_categorical_group_columns(
    categorical_columns: Union[str, Iterable, None],
    metric_column: Union[str, None],
    treatment_column: Union[str, None],
) -> Iterable:
    all_columns = listify(treatment_column) + listify(categorical_columns) + listify(metric_column)
    return list(OrderedDict.fromkeys(all_columns))

def validate_levels(df: DataFrame, level_columns: Union[str, Iterable], levels: Iterable):
    for level in levels:
        try:
            df.groupby(level_columns).get_group(level)
        except (KeyError, ValueError):
            raise ValueError(
                """
                Invalid level: '{}'
                Must supply a level within the ungrouped dimensions: {}
                Valid levels:
                {}
                """.format(
                    level, level_columns, list(df.groupby(level_columns).groups.keys())
                )
            )

def validate_and_rename_columns(df: DataFrame, columns: Iterable[str]) -> DataFrame:
    for column in columns:
        if column is None or column + SFX1 not in df.columns or column + SFX2 not in df.columns:
            continue

        # The two suffixed copies must agree: identical NaN positions, and
        # equal values wherever both are non-null.
        if (df[column + SFX1].isna() == df[column + SFX2].isna()).all() and (
            df[column + SFX1][df[column + SFX1].notna()] == df[column + SFX2][df[column + SFX2].notna()]
        ).all():
            df = df.rename(columns={column + SFX1: column}).drop(columns=[column + SFX2])
        else:
            raise ValueError(f"Values of {column} do not agree across levels: {df[[column + SFX1, column + SFX2]]}")
    return df
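# Illustrative note (not part of the original module), assuming SFX1/SFX2 are
# the pair of column suffixes from constants (their concrete values are not
# shown here): after a self-join produces e.g. "metric<SFX1>" and
# "metric<SFX2>", identical pairs collapse to a single "metric" column, while
# disagreeing pairs raise ValueError.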

def drop_and_rename_columns(df: DataFrame, columns: Iterable[str]) -> DataFrame:
    columns_dict = {col + SFX1: col for col in columns}
    return df.rename(columns=columns_dict).drop(columns=[col + SFX2 for col in columns])

def level2str(level: Union[str, Tuple]) -> str:
    if isinstance(level, str) or not isinstance(level, Iterable):
        return str(level)
    else:
        return ", ".join([str(sub_level) for sub_level in level])
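# Illustrative usage sketch (not part of the original module):
# >>> level2str(("us", "ios"))
# 'us, ios'
# >>> level2str(3)
# '3'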

def validate_data(df: DataFrame, columns_that_must_exist, group_columns: Iterable, ordinal_group_column: str):
    """Integrity check input dataframe."""
    for col in columns_that_must_exist:
        _validate_column(df, col)

    if not group_columns:
        raise ValueError(
            """At least one of `categorical_group_columns`
            or `ordinal_group_column` must be specified."""
        )

    for col in group_columns:
        _validate_column(df, col)

    # Ensure there's at most 1 observation per grouping.
    max_one_row_per_grouping = all(df.groupby(group_columns, sort=False).size() <= 1)
    if not max_one_row_per_grouping:
        raise ValueError("""Each grouping should have at most 1 observation.""")

    if ordinal_group_column:
        ordinal_column_type = df[ordinal_group_column].dtype.type
        if not np.issubdtype(ordinal_column_type, np.number) and not issubclass(ordinal_column_type, np.datetime64):
            raise TypeError(
                """`ordinal_group_column` is type `{}`.
                Must be number or datetime type.""".format(
                    ordinal_column_type
                )
            )

def _validate_column(df: DataFrame, col: str):
    if col not in df.columns:
        raise ValueError(f"""Column {col} is not in dataframe""")

def is_non_inferiority(nim) -> bool:
    if isinstance(nim, float):
        return not np.isnan(nim)
    # Any non-float value counts as a margin iff it is not None.
    return nim is not None
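# Illustrative usage sketch (not part of the original module):
# >>> is_non_inferiority(float("nan"))
# False
# >>> is_non_inferiority(0.02)
# True
# >>> is_non_inferiority(None)
# False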

def reset_named_indices(df):
    named_indices = [name for name in df.index.names if name is not None]
    if len(named_indices) > 0:
        return df.reset_index(named_indices, drop=True).sort_index()
    else:
        return df

def _get_finite_bounds(numbers: Series) -> Tuple[float, float]:
    finite_numbers = numbers[numbers.abs() != float("inf")]
    return finite_numbers.min(), finite_numbers.max()

def axis_format_precision(numbers: Series, absolute: bool, extra_zeros: int = 0) -> Tuple[str, float, float]:
    min_value, max_value = _get_finite_bounds(numbers)

    if max_value == min_value:
        return "0.00", min_value, max_value

    extra_zeros += 2 if absolute else 0
    precision = -int(np.log10(abs(max_value - min_value))) + extra_zeros
    zeros = "".join(["0"] * precision)
    return "0.{}{}".format(zeros, "" if absolute else "%"), min_value, max_value
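# Illustrative usage sketch (not part of the original module): for values
# spanning 0.0-0.5, abs(max - min) = 0.5 gives precision
# -int(log10(0.5)) + 2 = 2 in absolute mode, hence two decimals:
# >>> axis_format_precision(Series([0.0, 0.2, 0.5]), absolute=True)
# ('0.00', 0.0, 0.5)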

def to_finite(s: Series, limit: float) -> Series:
    return s.clip(-100 * abs(limit), 100 * abs(limit))

def add_color_column(df: DataFrame, cols: Iterable) -> DataFrame:
    return df.assign(color=df[cols].agg(level2str, axis="columns"))

def power_calculation(mde: float, baseline_var: float, alpha: float, n1: int, n2: int) -> float:
    z_alpha = norm.ppf(1 - alpha / 2)
    a = abs(mde) / np.sqrt(baseline_var)
    b = np.sqrt(n1 * n2 / (n1 + n2))
    z_stat = a * b

    return norm.cdf(z_stat - z_alpha) + norm.cdf(-z_stat - z_alpha)
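# Illustrative worked example (not part of the original module): this is the
# standard two-sided z-test power, Phi(z_stat - z_alpha) + Phi(-z_stat - z_alpha)
# with z_stat = |mde| / sqrt(baseline_var) * sqrt(n1 * n2 / (n1 + n2)).
# For mde=0.1, baseline_var=1, alpha=0.05, n1=n2=1000: z_stat ~= 2.24,
# z_alpha ~= 1.96, so power ~= 0.61.
# >>> round(power_calculation(mde=0.1, baseline_var=1.0, alpha=0.05, n1=1000, n2=1000), 2)
# 0.61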

def unlist(x):
    x0 = x[0] if isinstance(x, list) else x
    x1 = np.atleast_2d(x0)
    if x1.shape[0] < x1.shape[1]:
        x1 = x1.transpose()
    return x1
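# Illustrative usage sketch (not part of the original module): scalars, row
# vectors, and single-element lists all come back as 2-D column vectors:
# >>> unlist([np.array([1, 2, 3])]).shape
# (3, 1)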

def dfmatmul(x, y, outer=True):
    x = np.atleast_2d(x)
    y = np.atleast_2d(y)
    if x.shape[0] < x.shape[1]:
        x = x.transpose()
    if y.shape[0] < y.shape[1]:
        y = y.transpose()

    if outer:
        out = np.matmul(x, np.transpose(y))
    else:
        out = np.matmul(np.transpose(x), y)

    if out.size == 1:
        out = out.item()
    return out
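# Illustrative usage sketch (not part of the original module): inputs are
# coerced to column vectors, so outer=True gives the outer product and
# outer=False the inner product (unwrapped to a scalar when 1x1):
# >>> dfmatmul([1, 2], [3, 4], outer=True)
# array([[3, 4],
#        [6, 8]])
# >>> dfmatmul([1, 2], [3, 4], outer=False)
# 11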