# Copyright 2017-2020 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15from collections import OrderedDict
16from concurrent.futures.thread import ThreadPoolExecutor
17from typing import Union, Iterable, Tuple, List
18
19import numpy as np
20from pandas import DataFrame, concat, Series
21from scipy.stats import norm
22
23from spotify_confidence.analysis.constants import (
24 SFX1,
25 SFX2,
26)
27
28
def groupbyApplyParallel(dfGrouped, func_to_apply):
    """Apply ``func_to_apply`` to each group of ``dfGrouped`` on a thread pool.

    Groups are dispatched to up to 32 worker threads and the per-group
    results are concatenated in group order.
    """
    groups = (group for _, group in dfGrouped)
    with ThreadPoolExecutor(max_workers=32, thread_name_prefix="groupbyApplyParallel") as pool:
        pieces = list(pool.map(func_to_apply, groups))
    return concat(pieces)
36
37
def applyParallel(df, func_to_apply, splits=32):
    """Split ``df`` row-wise into chunks and apply ``func_to_apply`` to each in parallel.

    Parameters
    ----------
    df : DataFrame to process.
    func_to_apply : callable mapping a DataFrame chunk to a DataFrame.
    splits : maximum number of chunks / worker threads (default 32).

    Returns the concatenation of the per-chunk results. An empty ``df`` is
    returned unchanged: ``np.array_split`` rejects zero sections and
    ``concat`` rejects an empty list, so the previous code raised on it.
    """
    if len(df) == 0:
        return df
    with ThreadPoolExecutor(max_workers=splits, thread_name_prefix="applyParallel") as p:
        ret_list = p.map(
            func_to_apply,
            np.array_split(df, min(splits, len(df))),
        )
    return concat(ret_list)
45
46
def get_all_group_columns(categorical_columns: Iterable, additional_column: str) -> Iterable:
    """Concatenate the categorical columns and the additional column, de-duplicated, order kept."""
    combined = [*listify(categorical_columns), *listify(additional_column)]
    return list(OrderedDict.fromkeys(combined))
50
51
def remove_group_columns(categorical_columns: Iterable, additional_column: str) -> Iterable:
    """De-duplicate ``categorical_columns`` (order kept) and drop ``additional_column`` when given.

    Propagates ``KeyError`` if ``additional_column`` is not among the columns.
    """
    remaining = OrderedDict.fromkeys(categorical_columns)
    if additional_column is not None:
        # KeyError here signals a caller bug (removing a column that was never present).
        del remaining[additional_column]
    return list(remaining)
57
58
def validate_categorical_columns(categorical_group_columns: Union[str, Iterable]) -> Iterable:
    """Raise ``TypeError`` unless the argument is a string or an iterable of columns.

    Purely a validation helper; returns ``None`` on success. The previous
    version spelled the accepted cases as two empty ``pass`` branches.
    """
    if not isinstance(categorical_group_columns, (str, Iterable)):
        raise TypeError(
            """categorical_group_columns must be string or
                iterable (list of columns) and you must
                provide at least one"""
        )
70
71
def listify(column_s: Union[str, Iterable, None]) -> List:
    """Normalize a column specification to a list.

    - ``None``     -> ``[]``
    - a string     -> one-element list
    - any iterable -> ``list(iterable)``

    Raises ``TypeError`` for anything else; the previous version silently
    fell through and returned ``None`` for e.g. an int argument.
    """
    if column_s is None:
        return []
    if isinstance(column_s, str):
        return [column_s]
    if isinstance(column_s, Iterable):
        return list(column_s)
    raise TypeError(f"Expected str, iterable or None, got {type(column_s).__name__}")
79
80
def get_remaning_groups(all_groups: Iterable, some_groups: Iterable) -> Iterable:
    """Return the entries of ``all_groups`` that are not in ``some_groups``.

    ``some_groups is None`` means "exclude nothing": ``all_groups`` is
    returned untouched. ``None`` entries never appear in the filtered result.
    """
    if some_groups is None:
        return all_groups
    return [g for g in all_groups if g is not None and g not in some_groups]
87
88
def get_all_categorical_group_columns(
    categorical_columns: Union[str, Iterable, None],
    metric_column: Union[str, None],
    treatment_column: Union[str, None],
) -> Iterable:
    """Combine treatment, categorical and metric columns into one order-preserving, de-duplicated list."""
    combined = listify(treatment_column) + listify(categorical_columns) + listify(metric_column)
    seen = OrderedDict.fromkeys(combined)
    return list(seen)
96
97
def validate_levels(df: DataFrame, level_columns: Union[str, Iterable], levels: Iterable):
    """Check that every entry of ``levels`` names an existing group of ``df``.

    Raises ``ValueError`` — listing the valid group keys — for the first
    level that ``df.groupby(level_columns)`` cannot resolve.
    """
    grouped = df.groupby(level_columns)
    for level in levels:
        try:
            grouped.get_group(level)
        except (KeyError, ValueError):
            raise ValueError(
                """
            Invalid level: '{}'
            Must supply a level within the ungrouped dimensions: {}
            Valid levels:
            {}
            """.format(
                    level, level_columns, list(grouped.groups.keys())
                )
            )
113
114
def validate_and_rename_columns(df: DataFrame, columns: Iterable[str]) -> DataFrame:
    """Collapse per-level duplicated columns (``<col><SFX1>``/``<col><SFX2>``) into ``<col>``.

    For each named column present with both suffixes, verify the two suffixed
    columns agree (same NaN pattern and equal non-NaN values), then keep the
    SFX1 version under the plain name and drop the SFX2 version.

    Raises ``ValueError`` when the two levels disagree.
    """
    for column in columns:
        if column is None or column + SFX1 not in df.columns or column + SFX2 not in df.columns:
            continue

        left = df[column + SFX1]
        right = df[column + SFX2]
        # Bug fix: the original compared the SFX1 column against itself on
        # both sides, making the agreement check vacuously true; the intent
        # (see the error message) is to compare SFX1 against SFX2.
        same_na_pattern = (left.isna() == right.isna()).all()
        non_na = left.notna()
        same_values = (left[non_na] == right[non_na]).all()
        if same_na_pattern and same_values:
            df = df.rename(columns={column + SFX1: column}).drop(columns=[column + SFX2])
        else:
            raise ValueError(f"Values of {column} do not agree across levels: {df[[column + SFX1, column + SFX2]]}")
    return df
127
128
def drop_and_rename_columns(df: DataFrame, columns: Iterable[str]) -> DataFrame:
    """Keep the SFX1 variant of each column under its plain name and drop the SFX2 variant."""
    rename_map = {}
    drop_list = []
    for col in columns:
        rename_map[col + SFX1] = col
        drop_list.append(col + SFX2)
    return df.rename(columns=rename_map).drop(columns=drop_list)
132
133
def level2str(level: Union[str, Tuple]) -> str:
    """Render a group level as a string; tuple-like levels become ", "-joined parts."""
    if isinstance(level, Iterable) and not isinstance(level, str):
        return ", ".join(str(part) for part in level)
    return str(level)
139
140
def validate_data(df: DataFrame, columns_that_must_exist, group_columns: Iterable, ordinal_group_column: str):
    """Integrity check input dataframe.

    Verifies that required and grouping columns exist, that grouping columns
    were supplied at all, that each grouping holds at most one row, and that
    the ordinal column (when given) is numeric or datetime typed.
    """
    for required in columns_that_must_exist:
        _validate_column(df, required)

    if not group_columns:
        raise ValueError(
            """At least one of `categorical_group_columns`
                or `ordinal_group_column` must be specified."""
        )

    for group_col in group_columns:
        _validate_column(df, group_col)

    # Each grouping may contribute at most one observation.
    if (df.groupby(group_columns, sort=False).size() > 1).any():
        raise ValueError("""Each grouping should have at most 1 observation.""")

    if ordinal_group_column:
        ordinal_column_type = df[ordinal_group_column].dtype.type
        is_numeric = np.issubdtype(ordinal_column_type, np.number)
        is_datetime = issubclass(ordinal_column_type, np.datetime64)
        if not (is_numeric or is_datetime):
            raise TypeError(
                """`ordinal_group_column` is type `{}`.
                Must be number or datetime type.""".format(
                    ordinal_column_type
                )
            )
169
170
171def _validate_column(df: DataFrame, col: str):
172 if col not in df.columns:
173 raise ValueError(f"""Column {col} is not in dataframe""")
174
175
def is_non_inferiority(nim) -> bool:
    """Return ``True`` when ``nim`` encodes a non-inferiority margin.

    ``None`` and ``float('nan')`` both mean "no margin" and yield ``False``;
    any other value (e.g. a finite float) yields ``True``. The previous
    version implicitly returned ``None`` for non-float, non-None input and
    used the convoluted ``return nim is not None`` inside the None branch.
    """
    if nim is None:
        return False
    if isinstance(nim, float):
        return not np.isnan(nim)
    return True
181
182
def reset_named_indices(df):
    """Drop all *named* index levels of ``df`` and sort by the remaining index.

    A frame without any named index level is returned unchanged (same object).
    """
    named_levels = [name for name in df.index.names if name is not None]
    if not named_levels:
        return df
    return df.reset_index(named_levels, drop=True).sort_index()
189
190
191def _get_finite_bounds(numbers: Series) -> Tuple[float, float]:
192 finite_numbers = numbers[numbers.abs() != float("inf")]
193 return finite_numbers.min(), finite_numbers.max()
194
195
def axis_format_precision(numbers: Series, absolute: bool, extra_zeros: int = 0) -> Tuple[str, float, float]:
    """Derive a numeric axis format string plus the finite (min, max) of ``numbers``.

    Relative (non-absolute) axes get a trailing ``%``; absolute axes get two
    extra decimal places. A constant series falls back to the "0.00" format.
    """
    min_value, max_value = _get_finite_bounds(numbers)

    if max_value == min_value:
        return "0.00", min_value, max_value

    if absolute:
        extra_zeros += 2
    # Number of decimals scales with the magnitude of the value range.
    precision = extra_zeros - int(np.log10(abs(max_value - min_value)))
    fmt = "0." + "0" * precision + ("" if absolute else "%")
    return fmt, min_value, max_value
206
207
def to_finite(s: Series, limit: float) -> Series:
    """Clamp ``s`` to +/-100x the magnitude of ``limit`` so infinities become large finite values."""
    bound = 100 * abs(limit)
    return s.clip(lower=-bound, upper=bound)
210
211
def add_color_column(df: DataFrame, cols: Iterable) -> DataFrame:
    """Add a ``color`` column built by joining the values of ``cols`` row-wise via ``level2str``."""
    color_values = df[cols].agg(level2str, axis="columns")
    return df.assign(color=color_values)
214
215
def power_calculation(mde: float, baseline_var: float, alpha: float, n1: int, n2: int) -> float:
    """Power of a two-sided z-test for effect ``mde`` given variance and group sizes.

    Power = P(Z > z_alpha - z) + P(Z < -z_alpha - z), where z is the
    standardized effect scaled by the effective sample-size term.
    """
    z_alpha = norm.ppf(1 - alpha / 2)
    standardized_effect = abs(mde) / np.sqrt(baseline_var)
    sample_term = np.sqrt(n1 * n2 / (n1 + n2))
    z_stat = standardized_effect * sample_term
    return norm.cdf(z_stat - z_alpha) + norm.cdf(-z_stat - z_alpha)
224
225
def unlist(x):
    """Coerce ``x`` (or its first element when given a list) to a 2-D column-oriented array.

    The array is transposed whenever it has more columns than rows, so row
    vectors come back as column vectors.
    """
    value = x[0] if isinstance(x, list) else x
    arr = np.atleast_2d(value)
    rows, cols = arr.shape
    return arr.T if rows < cols else arr
232
233
def dfmatmul(x, y, outer=True):
    """Column-orient ``x`` and ``y`` then return their outer (x @ y.T) or inner (x.T @ y) product.

    A 1x1 result is unwrapped to a Python scalar.
    """

    def _as_column(v):
        # Promote to 2-D and transpose row vectors into column vectors.
        arr = np.atleast_2d(v)
        return arr.T if arr.shape[0] < arr.shape[1] else arr

    x_col = _as_column(x)
    y_col = _as_column(y)

    product = x_col @ y_col.T if outer else x_col.T @ y_col

    return product.item() if product.size == 1 else product