Coverage for /home/martinb/.local/share/virtualenvs/camcops/lib/python3.6/site-packages/pandas/core/window/common.py : 18%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""Common utility functions for rolling operations"""
2from collections import defaultdict
3from typing import Callable, Optional
4import warnings
6import numpy as np
8from pandas.core.dtypes.common import is_integer
9from pandas.core.dtypes.generic import ABCDataFrame, ABCSeries
11import pandas.core.common as com
12from pandas.core.generic import _shared_docs
13from pandas.core.groupby.base import GroupByMixin
14from pandas.core.indexes.api import MultiIndex
# Rebind the name to a new dict built from the imported mapping, so this
# module holds its own copy (presumably so additions made for the window
# docs do not alter the mapping owned by pandas.core.generic -- confirm).
_shared_docs = dict(**_shared_docs)

# %-style docstring fragment with a single ``name`` placeholder; it supplies
# the "Returns" and "See Also" sections.
# NOTE(review): the interior indentation of this literal is not recoverable
# from the line-numbered listing; it is laid out here with the 8/12-space
# indent of a method docstring -- confirm against the real module.
_doc_template = """
        Returns
        -------
        Series or DataFrame
            Return type is determined by the caller.

        See Also
        --------
        Series.%(name)s : Series %(name)s.
        DataFrame.%(name)s : DataFrame %(name)s.
"""
30def _dispatch(name: str, *args, **kwargs):
31 """
32 Dispatch to apply.
33 """
35 def outer(self, *args, **kwargs):
36 def f(x):
37 x = self._shallow_copy(x, groupby=self._groupby)
38 return getattr(x, name)(*args, **kwargs)
40 return self._groupby.apply(f)
42 outer.__name__ = name
43 return outer
class WindowGroupByMixin(GroupByMixin):
    """
    Provide the groupby facilities.
    """

    def __init__(self, obj, *args, **kwargs):
        # "parent" is accepted for the caller's convenience but not used here.
        kwargs.pop("parent", None)

        grouped = kwargs.pop("groupby", None)
        if grouped is None:
            # No explicit groupby was passed: ``obj`` is itself the grouped
            # object, and the data it wraps is found on its ``obj`` attribute.
            grouped = obj
            obj = grouped.obj

        self._groupby = grouped
        # Flag both the groupby object and its grouper as mutated.
        grouped.mutated = True
        grouped.grouper.mutated = True
        super().__init__(obj, *args, **kwargs)

    count = _dispatch("count")
    corr = _dispatch("corr", other=None, pairwise=None)
    cov = _dispatch("cov", other=None, pairwise=None)

    def _apply(
        self,
        func: Callable,
        center: bool,
        require_min_periods: int = 0,
        floor: int = 1,
        is_weighted: bool = False,
        name: Optional[str] = None,
        use_numba_cache: bool = False,
        **kwargs,
    ):
        """
        Dispatch to apply; the ``_apply``-specific arguments are dropped and
        the original function call is performed on each group instead.

        When ``name`` is a string, the method of that name is called on the
        re-wrapped group; otherwise ``name`` is handed to the group's
        ``apply``.  Remaining ``**kwargs`` are forwarded in both cases.
        """
        kwargs.pop("floor", None)

        # TODO: can we de-duplicate with _dispatch?
        def f(x, name=name, *args):
            rewrapped = self._shallow_copy(x)

            if not isinstance(name, str):
                return rewrapped.apply(name, *args, **kwargs)

            return getattr(rewrapped, name)(*args, **kwargs)

        return self._groupby.apply(f)
def _flex_binary_moment(arg1, arg2, f, pairwise=False):
    """
    Apply the binary function ``f`` to ``arg1`` and ``arg2``.

    How ``f`` is applied depends on the argument types:

    * both ndarray/Series: ``f`` is called once on the pair returned by
      ``prep_binary``;
    * DataFrame and DataFrame with ``pairwise=False``: ``f`` is called per
      matching column;
    * DataFrame and DataFrame with ``pairwise=True``: ``f`` is called for
      every pair of columns and the results are assembled into one frame
      with a MultiIndex;
    * DataFrame and ndarray/Series: ``f`` is called for every column of the
      frame against the other argument;
    * ndarray/Series and DataFrame: the arguments are swapped and the
      function calls itself.

    Raises
    ------
    TypeError
        If either argument is not an ndarray, Series or DataFrame.
    ValueError
        If two distinct DataFrames with non-unique columns are passed with
        ``pairwise=False``, or if ``pairwise`` is neither ``True`` nor
        ``False`` when both arguments are DataFrames.
    """

    if not (
        isinstance(arg1, (np.ndarray, ABCSeries, ABCDataFrame))
        and isinstance(arg2, (np.ndarray, ABCSeries, ABCDataFrame))
    ):
        raise TypeError(
            "arguments to moment function must be of type "
            "np.ndarray/Series/DataFrame"
        )

    if isinstance(arg1, (np.ndarray, ABCSeries)) and isinstance(
        arg2, (np.ndarray, ABCSeries)
    ):
        # Neither argument is a DataFrame: a single call to ``f`` suffices.
        X, Y = prep_binary(arg1, arg2)
        return f(X, Y)

    elif isinstance(arg1, ABCDataFrame):
        from pandas import DataFrame

        def dataframe_from_int_dict(data, frame_template):
            # ``data`` is keyed by integer column position; build the frame
            # on the template's index, then translate the integer labels
            # back into the template's column labels.
            result = DataFrame(data, index=frame_template.index)
            if len(result.columns) > 0:
                result.columns = frame_template.columns[result.columns]
            return result

        results = {}
        if isinstance(arg2, ABCDataFrame):
            # ``pairwise`` is compared by identity, so only the actual
            # booleans True/False are accepted here.
            if pairwise is False:
                if arg1 is arg2:
                    # special case in order to handle duplicate column names
                    for i, col in enumerate(arg1.columns):
                        results[i] = f(arg1.iloc[:, i], arg2.iloc[:, i])
                    return dataframe_from_int_dict(results, arg1)
                else:
                    # Columns are matched by label below, so labels must be
                    # unique on both sides.
                    if not arg1.columns.is_unique:
                        raise ValueError("'arg1' columns are not unique")
                    if not arg2.columns.is_unique:
                        raise ValueError("'arg2' columns are not unique")
                    with warnings.catch_warnings(record=True):
                        warnings.simplefilter("ignore", RuntimeWarning)
                        X, Y = arg1.align(arg2, join="outer")
                    # Same masking arithmetic as ``prep_binary``.
                    X = X + 0 * Y
                    Y = Y + 0 * X

                    with warnings.catch_warnings(record=True):
                        warnings.simplefilter("ignore", RuntimeWarning)
                        res_columns = arg1.columns.union(arg2.columns)
                    # Only columns present in both aligned frames get a
                    # computed result.
                    for col in res_columns:
                        if col in X and col in Y:
                            results[col] = f(X[col], Y[col])
                    return DataFrame(results, index=X.index, columns=res_columns)
            elif pairwise is True:
                # results[i][j] holds f(column i of arg1, column j of arg2).
                results = defaultdict(dict)
                for i, k1 in enumerate(arg1.columns):
                    for j, k2 in enumerate(arg2.columns):
                        if j < i and arg2 is arg1:
                            # Symmetric case
                            results[i][j] = results[j][i]
                        else:
                            results[i][j] = f(
                                *prep_binary(arg1.iloc[:, i], arg2.iloc[:, j])
                            )

                from pandas import concat

                result_index = arg1.index.union(arg2.index)
                if len(result_index):

                    # construct result frame
                    # (inner concat stacks the results for one arg1 column
                    # over all arg2 columns; outer concat places those
                    # stacks side by side, one per arg1 column)
                    result = concat(
                        [
                            concat(
                                [results[i][j] for j, c in enumerate(arg2.columns)],
                                ignore_index=True,
                            )
                            for i, c in enumerate(arg1.columns)
                        ],
                        ignore_index=True,
                        axis=1,
                    )
                    result.columns = arg1.columns

                    # set the index and reorder
                    if arg2.columns.nlevels > 1:
                        result.index = MultiIndex.from_product(
                            arg2.columns.levels + [result_index]
                        )
                        result = result.reorder_levels([2, 0, 1]).sort_index()
                    else:
                        # Index by integer positions first, swap and sort,
                        # then replace the positions with the real labels.
                        result.index = MultiIndex.from_product(
                            [range(len(arg2.columns)), range(len(result_index))]
                        )
                        result = result.swaplevel(1, 0).sort_index()
                        result.index = MultiIndex.from_product(
                            [result_index] + [arg2.columns]
                        )
                else:

                    # empty result
                    result = DataFrame(
                        index=MultiIndex(
                            levels=[arg1.index, arg2.columns], codes=[[], []]
                        ),
                        columns=arg2.columns,
                        dtype="float64",
                    )

                # reset our index names to arg1 names
                # reset our column names to arg2 names
                # careful not to mutate the original names
                result.columns = result.columns.set_names(arg1.columns.names)
                result.index = result.index.set_names(
                    result_index.names + arg2.columns.names
                )

                return result

            else:
                raise ValueError("'pairwise' is not True/False")
        else:
            # arg2 is an ndarray/Series: pair it with each column of arg1.
            results = {
                i: f(*prep_binary(arg1.iloc[:, i], arg2))
                for i, col in enumerate(arg1.columns)
            }
            return dataframe_from_int_dict(results, arg1)

    else:
        # arg1 is an ndarray/Series and arg2 a DataFrame: swap the arguments
        # and retry.  ``pairwise`` is not forwarded, so the recursive call
        # uses its default of False.
        return _flex_binary_moment(arg2, arg1, f)
225def _get_center_of_mass(comass, span, halflife, alpha):
226 valid_count = com.count_not_none(comass, span, halflife, alpha)
227 if valid_count > 1:
228 raise ValueError("comass, span, halflife, and alpha are mutually exclusive")
230 # Convert to center of mass; domain checks ensure 0 < alpha <= 1
231 if comass is not None:
232 if comass < 0:
233 raise ValueError("comass must satisfy: comass >= 0")
234 elif span is not None:
235 if span < 1:
236 raise ValueError("span must satisfy: span >= 1")
237 comass = (span - 1) / 2.0
238 elif halflife is not None:
239 if halflife <= 0:
240 raise ValueError("halflife must satisfy: halflife > 0")
241 decay = 1 - np.exp(np.log(0.5) / halflife)
242 comass = 1 / decay - 1
243 elif alpha is not None:
244 if alpha <= 0 or alpha > 1:
245 raise ValueError("alpha must satisfy: 0 < alpha <= 1")
246 comass = (1.0 - alpha) / alpha
247 else:
248 raise ValueError("Must pass one of comass, span, halflife, or alpha")
250 return float(comass)
def calculate_center_offset(window):
    """
    Return the offset of the center of a window.

    ``window`` is either an integer size or a sized object whose length is
    used as the size.  The result is ``(size - 1) / 2`` truncated to an int.
    """
    size = window if is_integer(window) else len(window)
    half_width = (size - 1) / 2.0
    return int(half_width)
def calculate_min_periods(
    window: int,
    min_periods: Optional[int],
    num_values: int,
    required_min_periods: int,
    floor: int,
) -> int:
    """
    Calculates final minimum periods value for rolling aggregations.

    Parameters
    ----------
    window : passed window value
    min_periods : passed min periods value; ``None`` means "use ``window``"
    num_values : total number of values
    required_min_periods : required min periods per aggregation function;
        a passed ``min_periods`` below this is raised to it
    floor : lower bound applied to the final result

    Returns
    -------
    min_periods : int

    Raises
    ------
    ValueError
        If the passed ``min_periods`` is negative, or if the effective
        ``min_periods`` exceeds ``window``.
    """
    if min_periods is None:
        min_periods = window
    else:
        # Validate the caller's value before clamping it: once it has been
        # raised to ``required_min_periods`` a negative input can no longer
        # be detected and would be accepted silently.
        if min_periods < 0:
            raise ValueError("min_periods must be >= 0")
        min_periods = max(required_min_periods, min_periods)
    if min_periods > window:
        raise ValueError(f"min_periods {min_periods} must be <= window {window}")
    elif min_periods > num_values:
        # More periods requested than there are values: use one more than
        # the number of available values.
        min_periods = num_values + 1
    elif min_periods < 0:
        # Still reachable when ``min_periods`` defaulted to a negative
        # ``window`` or ``required_min_periods`` is itself negative.
        raise ValueError("min_periods must be >= 0")
    return max(min_periods, floor)
def zsqrt(x):
    """
    Element-wise square root that yields 0 wherever ``x`` is negative.

    Floating-point warnings from numpy are suppressed while the root and
    the negative-value mask are computed; entries flagged by the mask are
    then overwritten with 0 in the result.
    """
    with np.errstate(all="ignore"):
        roots = np.sqrt(x)
        negative = x < 0

    # A DataFrame mask is reduced through its underlying values; any other
    # mask is reduced directly.
    if isinstance(x, ABCDataFrame):
        has_negative = negative.values.any()
    else:
        has_negative = negative.any()

    if has_negative:
        roots[negative] = 0

    return roots
def prep_binary(arg1, arg2):
    """
    Prepare two same-typed inputs for a binary moment function.

    Each input has zero times the other added to it, so both results share
    whatever the addition of the two inputs produces (see inline comment).

    Returns
    -------
    tuple
        ``(X, Y)`` where ``X = arg1 + 0 * arg2`` and ``Y = arg2 + 0 * arg1``.

    Raises
    ------
    TypeError
        If ``arg2`` is not an instance of ``type(arg1)``.
    """
    if not isinstance(arg2, type(arg1)):
        # TypeError (still an ``Exception`` subclass) names the actual
        # problem instead of raising the catch-all base class.
        raise TypeError("Input arrays must be of the same type!")

    # mask out values, this also makes a common index...
    X = arg1 + 0 * arg2
    Y = arg2 + 0 * arg1

    return X, Y
def get_weighted_roll_func(cfunc: Callable) -> Callable:
    """
    Wrap ``cfunc`` so that ``min_periods`` becomes optional.

    The returned function accepts ``(arg, window, min_periods=None)`` and
    calls ``cfunc(arg, window, min_periods)``, substituting ``len(window)``
    when ``min_periods`` is None.
    """

    def func(arg, window, min_periods=None):
        periods = len(window) if min_periods is None else min_periods
        return cfunc(arg, window, periods)

    return func