Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1from distutils.version import LooseVersion 

2import types 

3from typing import Any, Callable, Dict, Optional, Tuple 

4 

5import numpy as np 

6 

7from pandas._typing import Scalar 

8from pandas.compat._optional import import_optional_dependency 

9 

10 

def make_rolling_apply(
    func: Callable[..., Scalar],
    args: Tuple,
    nogil: bool,
    parallel: bool,
    nopython: bool,
):
    """
    Build a numba-compiled rolling apply that calls a JIT version of ``func``.

    Parameters
    ----------
    func : function
        User function evaluated on every window. A function that numba
        already reports as jitted is used unchanged. numpy functions and
        builtins are called directly, and anything else is wrapped with
        ``numba.jit``.
    args : tuple
        Extra positional arguments forwarded to ``func`` after the window.
    nogil : bool
        ``nogil`` option forwarded to numba.jit.
    parallel : bool
        ``parallel`` option forwarded to numba.jit. When True, the loop over
        windows uses ``numba.prange``.
    nopython : bool
        ``nopython`` option forwarded to numba.jit.

    Returns
    -------
    Numba function
        ``roll_apply(values, begin, end, minimum_periods)``, which returns a
        float64 array with one entry per window. A window is ``values`` sliced
        from ``begin[j]`` to ``end[j]``. If the window has fewer non-NaN values
        than ``minimum_periods``, its entry is NaN.
    """
    numba = import_optional_dependency("numba")

    # ``prange`` lets numba parallelise the per-window loop.
    loop_range = numba.prange if parallel else range

    # ``numba.extending.is_jitted`` is only used from numba 0.49.0 onward;
    # older releases are checked against the CPU dispatcher type instead.
    if LooseVersion(numba.__version__) < LooseVersion("0.49.0"):
        already_jitted = isinstance(func, numba.targets.registry.CPUDispatcher)
    else:
        already_jitted = numba.extending.is_jitted(func)

    if already_jitted:
        # A function the user jitted themselves is not re-jitted.
        kernel = func
    else:

        @numba.generated_jit(nopython=nopython, nogil=nogil, parallel=parallel)
        def kernel(arr, *extra):
            # numba runs this body at compile time. numpy functions and
            # builtins are used as-is, and any other function is jitted here.
            native = getattr(np, func.__name__, False) is func or isinstance(
                func, types.BuiltinFunctionType
            )
            target = (
                func if native else numba.jit(func, nopython=nopython, nogil=nogil)
            )

            # The returned implementation mirrors ``kernel``'s signature.
            def impl(arr, *extra):
                return target(arr, *extra)

            return impl

    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def roll_apply(
        values: np.ndarray, begin: np.ndarray, end: np.ndarray, minimum_periods: int,
    ) -> np.ndarray:
        n_windows = len(begin)
        out = np.empty(n_windows)
        for j in loop_range(n_windows):
            window = values[begin[j] : end[j]]
            # Only non-NaN observations count toward ``minimum_periods``.
            n_valid = len(window) - np.sum(np.isnan(window))
            if n_valid >= minimum_periods:
                out[j] = kernel(window, *args)
            else:
                out[j] = np.nan
        return out

    return roll_apply

87 

88 

def generate_numba_apply_func(
    args: Tuple,
    kwargs: Dict[str, Any],
    func: Callable[..., Scalar],
    engine_kwargs: Optional[Dict[str, bool]],
):
    """
    Generate a numba jitted apply function specified by values from engine_kwargs.

    1. jit the user's function
    2. Return a rolling apply function with the jitted function inline

    Configurations specified in engine_kwargs apply to both the user's
    function _AND_ the rolling apply function.

    Parameters
    ----------
    args : tuple
        *args to be passed into the function
    kwargs : dict
        **kwargs intended for the function. numba cannot forward keyword
        arguments, so these never reach ``func``. A non-empty ``kwargs``
        raises when ``nopython`` is True. Otherwise it is ignored with a
        ``UserWarning``.
    func : function
        function to be applied to each window and will be JITed
    engine_kwargs : dict or None
        dictionary of arguments to be passed into numba.jit. The keys read are
        ``nopython`` (default True), ``nogil`` (default False) and ``parallel``
        (default False). None means all defaults.

    Returns
    -------
    Numba function

    Raises
    ------
    ValueError
        If ``kwargs`` is non-empty and ``nopython`` is True.
    """
    import warnings

    if engine_kwargs is None:
        engine_kwargs = {}

    nopython = engine_kwargs.get("nopython", True)
    nogil = engine_kwargs.get("nogil", False)
    parallel = engine_kwargs.get("parallel", False)

    if kwargs and nopython:
        raise ValueError(
            "numba does not support kwargs with nopython=True: "
            "https://github.com/numba/numba/issues/2916"
        )
    if kwargs:
        # kwargs are not threaded through to the jitted apply (numba cannot
        # forward them), so make the drop visible instead of silently
        # ignoring the caller's arguments.
        warnings.warn(
            f"numba does not support kwargs; ignoring {sorted(kwargs)}",
            UserWarning,
            stacklevel=2,
        )

    return make_rolling_apply(func, args, nogil, parallel, nopython)

112 engine_kwargs : dict 

113 dictionary of arguments to be passed into numba.jit 

114 

115 Returns 

116 ------- 

117 Numba function 

118 """ 

119 

120 if engine_kwargs is None: 

121 engine_kwargs = {} 

122 

123 nopython = engine_kwargs.get("nopython", True) 

124 nogil = engine_kwargs.get("nogil", False) 

125 parallel = engine_kwargs.get("parallel", False) 

126 

127 if kwargs and nopython: 

128 raise ValueError( 

129 "numba does not support kwargs with nopython=True: " 

130 "https://github.com/numba/numba/issues/2916" 

131 ) 

132 

133 return make_rolling_apply(func, args, nogil, parallel, nopython)