Coverage for src / tracekit / core / numba_backend.py: 0%

119 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-11 23:04 +0000

1"""Numba JIT compilation backend for performance-critical code paths. 

2 

3This module provides a unified interface for Numba JIT compilation with graceful 

4fallback when Numba is not available. Provides 10-50x speedup for numerical loops 

5that cannot be fully vectorized. 

6 

7Usage: 

8 from tracekit.core.numba_backend import njit, prange, HAS_NUMBA 

9 

10 @njit(parallel=True, cache=True) 

11 def fast_function(data): 

12 result = np.zeros_like(data) 

13 for i in prange(len(data)): 

14 result[i] = expensive_computation(data[i]) 

15 return result 

16 

17Performance characteristics: 

18 - First call: Compilation overhead (~100-500ms) 

19 - Subsequent calls: 10-50x faster than Python loops 

20 - Parallel execution: Additional speedup on multi-core systems 

21 - Cache: Compilation results cached between runs 

22 

23Example: 

24 >>> from tracekit.core.numba_backend import njit, HAS_NUMBA 

25 >>> import numpy as np 

26 >>> 

27 >>> @njit(cache=True) 

28 ... def sum_of_squares(arr):

29 ... total = 0.0 

30 ... for i in range(len(arr)): 

31 ... total += arr[i] ** 2 

32 ... return total 

33 >>> 

34 >>> data = np.random.randn(1_000_000) 

35 >>> result = sum_of_squares(data) # Fast on second call 

36""" 

37 

38from __future__ import annotations 

39 

40import functools 

41from collections.abc import Callable 

42from typing import Any, TypeVar 

43 

44import numpy as np 

45 

46# Try to import Numba 

47try: 

48 from numba import guvectorize as _numba_guvectorize # type: ignore[import-untyped] 

49 from numba import jit as _numba_jit # type: ignore[import-untyped] 

50 from numba import njit as _numba_njit # type: ignore[import-untyped] 

51 from numba import prange as _numba_prange # type: ignore[import-untyped] 

52 from numba import vectorize as _numba_vectorize # type: ignore[import-untyped] 

53 

54 HAS_NUMBA = True 

55except ImportError: 

56 HAS_NUMBA = False 

57 

58 

59# Type variable for generic functions 

60F = TypeVar("F", bound=Callable[..., Any]) 

61 

62 

63if HAS_NUMBA: 

64 # Numba is available - use real implementations 

65 njit = _numba_njit 

66 prange = _numba_prange 

67 vectorize = _numba_vectorize 

68 guvectorize = _numba_guvectorize 

69 jit = _numba_jit 

70 

71else: 

72 # Numba not available - provide fallback decorators that do nothing 

def njit(*args: Any, **kwargs: Any) -> Callable[[F], F]:
    """Stand-in for ``numba.njit`` used when Numba is not installed.

    Accepts the same call forms as the real decorator so that decorated code
    keeps running (uncompiled) without Numba:

    * ``@njit`` - bare decorator,
    * ``@njit(cache=True)`` or ``@njit("signature")`` - decorator factory,
    * ``njit(func, cache=True)`` - direct call with options.

    Args:
        *args: Either the function to decorate (as the sole positional
            argument) or positional Numba options, which are ignored.
        **kwargs: Numba compilation options (ignored).

    Returns:
        The wrapped function when a callable is the sole positional argument,
        otherwise a decorator that wraps whatever function it is given.
    """

    def decorator(func: F) -> F:
        @functools.wraps(func)
        def wrapper(*call_args: Any, **call_kwargs: Any) -> Any:
            return func(*call_args, **call_kwargs)

        return wrapper  # type: ignore[return-value]

    # A callable as the only positional argument means "decorate this now",
    # whether or not options accompany it.  Requiring empty kwargs here made
    # ``njit(func, cache=True)`` return the decorator itself instead of the
    # decorated function.
    if len(args) == 1 and callable(args[0]):
        return decorator(args[0])  # type: ignore[no-any-return]
    return decorator  # type: ignore[no-any-return]

98 

def prange(*args: Any, **kwargs: Any) -> range:
    """Serial replacement for ``numba.prange`` when Numba is missing.

    Args:
        *args: Forwarded unchanged to the built-in ``range``.
        **kwargs: Forwarded unchanged to the built-in ``range``.

    Returns:
        The ``range`` object built from the forwarded arguments.
    """
    serial_range = range(*args, **kwargs)
    return serial_range

110 

def vectorize(*args: Any, **kwargs: Any) -> Callable[[F], F]:
    """Pass-through replacement for ``numba.vectorize`` without Numba.

    Args:
        *args: Either the function being decorated, or positional options
            (such as a signature list), which are ignored.
        **kwargs: Keyword options (ignored).

    Returns:
        The function itself when a callable is the sole positional argument,
        otherwise a decorator that hands its argument back unchanged.
    """
    # Direct use: the single positional argument is the function to decorate.
    if len(args) == 1 and callable(args[0]):
        return args[0]  # type: ignore[no-any-return]

    def passthrough(func: F) -> F:
        return func

    return passthrough  # type: ignore[no-any-return]

128 

def guvectorize(*args: Any, **kwargs: Any) -> Callable[[F], F]:
    """Pass-through replacement for ``numba.guvectorize`` without Numba.

    Args:
        *args: Either the function being decorated, or positional options
            (such as type lists and a layout signature), which are ignored.
        **kwargs: Keyword options (ignored).

    Returns:
        The function itself when a callable is the sole positional argument,
        otherwise a decorator that hands its argument back unchanged.
    """
    # Direct use: the single positional argument is the function to decorate.
    if len(args) == 1 and callable(args[0]):
        return args[0]  # type: ignore[no-any-return]

    def passthrough(func: F) -> F:
        return func

    return passthrough  # type: ignore[no-any-return]

146 

def jit(*args: Any, **kwargs: Any) -> Callable[[F], F]:
    """Stand-in for ``numba.jit`` used when Numba is not installed.

    Accepts the same call forms as the real decorator so that decorated code
    keeps running (uncompiled) without Numba:

    * ``@jit`` - bare decorator,
    * ``@jit(nopython=True)`` or ``@jit("signature")`` - decorator factory,
    * ``jit(func, nopython=True)`` - direct call with options.

    Args:
        *args: Either the function to decorate (as the sole positional
            argument) or positional Numba options, which are ignored.
        **kwargs: Numba compilation options (ignored).

    Returns:
        The wrapped function when a callable is the sole positional argument,
        otherwise a decorator that wraps whatever function it is given.
    """

    def decorator(func: F) -> F:
        @functools.wraps(func)
        def wrapper(*call_args: Any, **call_kwargs: Any) -> Any:
            return func(*call_args, **call_kwargs)

        return wrapper  # type: ignore[return-value]

    # A callable as the only positional argument means "decorate this now",
    # whether or not options accompany it.  Requiring empty kwargs here made
    # ``jit(func, nopython=True)`` return the decorator itself instead of the
    # decorated function.
    if len(args) == 1 and callable(args[0]):
        return decorator(args[0])  # type: ignore[no-any-return]
    return decorator  # type: ignore[no-any-return]

168 

169 

def get_optimal_numba_config(
    parallel: bool = False,
    cache: bool = True,
    fastmath: bool = False,
    nogil: bool = False,
) -> dict[str, Any]:
    """Build the keyword arguments to pass to ``njit``/``jit``.

    Args:
        parallel: Value for Numba's ``parallel`` option.
        cache: Value for Numba's ``cache`` option.
        fastmath: Value for Numba's ``fastmath`` option.
        nogil: Value for Numba's ``nogil`` option.

    Returns:
        A dict holding the four options under their own names, or an empty
        dict when ``HAS_NUMBA`` is false.

    Example:
        >>> config = get_optimal_numba_config(parallel=True, cache=True)
        >>> @njit(**config)
        ... def my_function(data):
        ...     pass
    """
    if HAS_NUMBA:
        return dict(
            parallel=parallel,
            cache=cache,
            fastmath=fastmath,
            nogil=nogil,
        )

    # Without Numba there is nothing to configure.
    return {}

202 

203 

204# Example Numba-optimized functions for common operations 

205 

206 

@njit(cache=True)  # type: ignore[misc,untyped-decorator]
def find_crossings_numba(
    data: np.ndarray,  # type: ignore[type-arg]
    threshold: float,
    direction: int = 0,
) -> np.ndarray:  # type: ignore[type-arg]
    """Locate the samples at which the signal crosses ``threshold``.

    A rising crossing is a pair of consecutive samples with
    ``previous < threshold <= current``; a falling crossing is a pair with
    ``previous > threshold >= current``.  The index reported is that of the
    later sample of the pair.

    Args:
        data: Input signal data.
        threshold: Threshold value to detect crossings.
        direction: 0 reports both kinds, a positive value reports rising
            crossings only, a negative value reports falling crossings only.

    Returns:
        ``int64`` array of crossing indices in ascending order.
    """
    # The direction never changes inside the loop, so resolve it once.
    want_rising = direction >= 0
    want_falling = direction <= 0

    hits = []
    for idx in range(1, len(data)):
        before = data[idx - 1]
        after = data[idx]

        if want_rising and before < threshold <= after:
            hits.append(idx)
        if want_falling and before > threshold >= after:
            hits.append(idx)

    return np.array(hits, dtype=np.int64)

236 

237 

@njit(parallel=True, cache=True)  # type: ignore[misc,untyped-decorator]
def moving_average_numba(
    data: np.ndarray,  # type: ignore[type-arg]
    window_size: int,
) -> np.ndarray:  # type: ignore[type-arg]
    """Compute the moving average over every full window of ``data``.

    Element ``i`` of the result is the mean of
    ``data[i], ..., data[i + window_size - 1]``, so the result holds
    ``len(data) - window_size + 1`` values.

    Args:
        data: Input signal data.
        window_size: Number of consecutive samples averaged per output
            value; must be at least 1.

    Returns:
        ``float64`` array of moving averages.

    Raises:
        ValueError: If ``window_size`` is smaller than 1.
    """
    # A zero window divides by zero, and a negative one skips the inner loop
    # and silently yields an over-long array of meaningless values; reject
    # both up front.
    if window_size < 1:
        raise ValueError("window_size must be at least 1")

    n = len(data)
    result = np.zeros(n - window_size + 1, dtype=np.float64)

    # Each output element is independent of the others, so the outer loop is
    # expressed with prange.
    for i in prange(len(result)):
        total = 0.0
        for j in range(window_size):
            total += data[i + j]
        result[i] = total / window_size

    return result

262 

263 

@njit(cache=True)  # type: ignore[misc,untyped-decorator]
def argrelextrema_numba(
    data: np.ndarray,  # type: ignore[type-arg]
    comparator: int,
    order: int = 1,
) -> np.ndarray:  # type: ignore[type-arg]
    """Find indices of strict relative extrema in ``data``.

    A sample qualifies when it compares strictly greater (maxima) or
    strictly smaller (minima) than each of the ``order`` samples on either
    side.  Only indices from ``order`` up to ``len(data) - order - 1`` are
    examined.

    Args:
        data: Input signal data.
        comparator: A positive value searches for maxima (peaks); any other
            value searches for minima (valleys).
        order: How many points on each side to use for comparison.

    Returns:
        ``int64`` array of extremum indices in ascending order.
    """
    want_maxima = comparator > 0
    size = len(data)
    found = []

    for center in range(order, size - order):
        value = data[center]

        # Walk outwards from the candidate and stop at the first neighbour
        # pair that disqualifies it.
        offset = 1
        dominated = False
        while offset <= order and not dominated:
            left = data[center - offset]
            right = data[center + offset]
            if want_maxima:
                dominated = value <= left or value <= right
            else:
                dominated = value >= left or value >= right
            offset += 1

        if not dominated:
            found.append(center)

    return np.array(found, dtype=np.int64)

300 

301 

@njit(cache=True)  # type: ignore[misc,untyped-decorator]
def interpolate_linear_numba(
    x: np.ndarray,  # type: ignore[type-arg]
    y: np.ndarray,  # type: ignore[type-arg]
    x_new: np.ndarray,  # type: ignore[type-arg]
) -> np.ndarray:  # type: ignore[type-arg]
    """Linearly interpolate ``y`` over ``x`` at the positions ``x_new``.

    Query points at or below ``x[0]`` receive ``y[0]`` and points at or
    above the last ``x`` receive the last ``y``; nothing is extrapolated.

    Args:
        x: Original x coordinates (must be sorted).
        y: Original y values.
        x_new: New x coordinates to interpolate.

    Returns:
        ``float64`` array of interpolated values, one per entry of ``x_new``.
    """
    last = len(x) - 1
    out = np.zeros(len(x_new), dtype=np.float64)

    for k in range(len(x_new)):
        target = x_new[k]

        # Outside the sampled range: clamp to the nearest end value.
        if target <= x[0]:
            out[k] = y[0]
            continue
        if target >= x[last]:
            out[k] = y[last]
            continue

        # Bisect until lo and hi are adjacent indices bracketing the target.
        lo = 0
        hi = last
        while hi - lo > 1:
            middle = (lo + hi) // 2
            if x[middle] <= target:
                lo = middle
            else:
                hi = middle

        # Blend the two bracketing samples.
        x_lo, x_hi = x[lo], x[hi]
        y_lo, y_hi = y[lo], y[hi]
        frac = (target - x_lo) / (x_hi - x_lo)
        out[k] = y_lo + frac * (y_hi - y_lo)

    return out

348 

349 

350__all__ = [ 

351 "HAS_NUMBA", 

352 "argrelextrema_numba", 

353 "find_crossings_numba", 

354 "get_optimal_numba_config", 

355 "guvectorize", 

356 "interpolate_linear_numba", 

357 "jit", 

358 "moving_average_numba", 

359 "njit", 

360 "prange", 

361 "vectorize", 

362]