Coverage for /home/martinb/.local/share/virtualenvs/camcops/lib/python3.6/site-packages/pandas/core/window/numba_.py : 18%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1from distutils.version import LooseVersion
2import types
3from typing import Any, Callable, Dict, Optional, Tuple
5import numpy as np
7from pandas._typing import Scalar
8from pandas.compat._optional import import_optional_dependency
11def make_rolling_apply(
12 func: Callable[..., Scalar],
13 args: Tuple,
14 nogil: bool,
15 parallel: bool,
16 nopython: bool,
17):
18 """
19 Creates a JITted rolling apply function with a JITted version of
20 the user's function.
22 Parameters
23 ----------
24 func : function
25 function to be applied to each window and will be JITed
26 args : tuple
27 *args to be passed into the function
28 nogil : bool
29 nogil parameter from engine_kwargs for numba.jit
30 parallel : bool
31 parallel parameter from engine_kwargs for numba.jit
32 nopython : bool
33 nopython parameter from engine_kwargs for numba.jit
35 Returns
36 -------
37 Numba function
38 """
39 numba = import_optional_dependency("numba")
41 if parallel:
42 loop_range = numba.prange
43 else:
44 loop_range = range
46 if LooseVersion(numba.__version__) >= LooseVersion("0.49.0"):
47 is_jitted = numba.extending.is_jitted(func)
48 else:
49 is_jitted = isinstance(func, numba.targets.registry.CPUDispatcher)
51 if is_jitted:
52 # Don't jit a user passed jitted function
53 numba_func = func
54 else:
56 @numba.generated_jit(nopython=nopython, nogil=nogil, parallel=parallel)
57 def numba_func(window, *_args):
58 if getattr(np, func.__name__, False) is func or isinstance(
59 func, types.BuiltinFunctionType
60 ):
61 jf = func
62 else:
63 jf = numba.jit(func, nopython=nopython, nogil=nogil)
65 def impl(window, *_args):
66 return jf(window, *_args)
68 return impl
70 @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
71 def roll_apply(
72 values: np.ndarray, begin: np.ndarray, end: np.ndarray, minimum_periods: int,
73 ) -> np.ndarray:
74 result = np.empty(len(begin))
75 for i in loop_range(len(result)):
76 start = begin[i]
77 stop = end[i]
78 window = values[start:stop]
79 count_nan = np.sum(np.isnan(window))
80 if len(window) - count_nan >= minimum_periods:
81 result[i] = numba_func(window, *args)
82 else:
83 result[i] = np.nan
84 return result
86 return roll_apply
def generate_numba_apply_func(
    args: Tuple,
    kwargs: Dict[str, Any],
    func: Callable[..., Scalar],
    engine_kwargs: Optional[Dict[str, bool]],
):
    """
    Build a numba-jitted rolling apply function configured by engine_kwargs.

    The user's function is jitted and a rolling apply function that calls
    it is returned. The ``nopython``, ``nogil`` and ``parallel`` settings
    read from engine_kwargs are forwarded for both the user's function
    _AND_ the rolling apply function.

    Parameters
    ----------
    args : tuple
        *args to be passed into the function
    kwargs : dict
        **kwargs to be passed into the function; must be empty when
        nopython is enabled
    func : function
        function to be applied to each window and will be JITed
    engine_kwargs : dict
        dictionary of arguments to be passed into numba.jit; ``None`` is
        treated as an empty dictionary (nopython=True, nogil=False,
        parallel=False)

    Returns
    -------
    Numba function

    Raises
    ------
    ValueError
        If kwargs are supplied while nopython is enabled.
    """
    options = {} if engine_kwargs is None else engine_kwargs

    use_nopython = options.get("nopython", True)
    use_nogil = options.get("nogil", False)
    use_parallel = options.get("parallel", False)

    # Guard clause: keyword arguments cannot be forwarded in nopython mode.
    if use_nopython and kwargs:
        raise ValueError(
            "numba does not support kwargs with nopython=True: "
            "https://github.com/numba/numba/issues/2916"
        )

    return make_rolling_apply(func, args, use_nogil, use_parallel, use_nopython)