Coverage for src/tracekit/core/numba_backend.py: 0%
119 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
"""Numba JIT compilation backend for performance-critical code paths.

This module provides a unified interface for Numba JIT compilation with graceful
fallback when Numba is not available. When Numba is installed it can give a
10-50x speedup for numerical loops that cannot be fully vectorized; without it,
the decorators are pure-Python stand-ins and code runs at interpreter speed.

Usage:
    from tracekit.core.numba_backend import njit, prange, HAS_NUMBA

    @njit(parallel=True, cache=True)
    def fast_function(data):
        result = np.zeros_like(data)
        for i in prange(len(data)):
            result[i] = expensive_computation(data[i])
        return result

Performance characteristics (with Numba installed):
    - First call: Compilation overhead (~100-500ms)
    - Subsequent calls: 10-50x faster than Python loops
    - Parallel execution: Additional speedup on multi-core systems
    - Cache: With ``cache=True``, compilation results are cached between runs

Example:
    >>> from tracekit.core.numba_backend import njit, HAS_NUMBA
    >>> import numpy as np
    >>>
    >>> @njit(cache=True)
    ... def sum_of_squares(arr):
    ...     total = 0.0
    ...     for i in range(len(arr)):
    ...         total += arr[i] ** 2
    ...     return total
    >>>
    >>> data = np.random.randn(1_000_000)
    >>> result = sum_of_squares(data)  # first call compiles; later calls reuse it
"""
38from __future__ import annotations
40import functools
41from collections.abc import Callable
42from typing import Any, TypeVar
44import numpy as np
# Try to import Numba. Only ImportError is treated as "not installed"; any other
# failure raised while importing Numba propagates instead of being hidden.
try:
    from numba import guvectorize as _numba_guvectorize  # type: ignore[import-untyped]
    from numba import jit as _numba_jit  # type: ignore[import-untyped]
    from numba import njit as _numba_njit  # type: ignore[import-untyped]
    from numba import prange as _numba_prange  # type: ignore[import-untyped]
    from numba import vectorize as _numba_vectorize  # type: ignore[import-untyped]

    HAS_NUMBA = True
except ImportError:
    HAS_NUMBA = False

# Type variable for generic functions
F = TypeVar("F", bound=Callable[..., Any])

if HAS_NUMBA:
    # Numba is available - re-export the real implementations unchanged.
    njit = _numba_njit
    prange = _numba_prange
    vectorize = _numba_vectorize
    guvectorize = _numba_guvectorize
    jit = _numba_jit

else:
    # Numba not available - provide pure-Python stand-ins with the same call
    # shapes so decorated code still runs (just without the speedup).

    def _apply_or_return(args: tuple[Any, ...], decorator: Callable[[F], F]) -> Any:
        """Support both the bare-function and the decorator-factory call forms.

        Numba's ``jit``/``njit``/``vectorize`` accept the target function as
        their first positional argument (with or without keyword options), and
        return a decorator when given only signatures/options. Mirror that:
        decorate immediately when the first positional argument is callable,
        otherwise hand back the decorator.

        Args:
            args: Positional arguments the public decorator was called with.
            decorator: Function that turns a kernel into its fallback version.

        Returns:
            The decorated function, or ``decorator`` itself.
        """
        if args and callable(args[0]):
            return decorator(args[0])
        return decorator

    def _call_through(func: F) -> F:
        """Wrap ``func`` in a thin call-through that preserves its metadata."""

        @functools.wraps(func)
        def wrapper(*call_args: Any, **call_kwargs: Any) -> Any:
            return func(*call_args, **call_kwargs)

        return wrapper  # type: ignore[return-value]

    def _elementwise(func: F) -> F:
        """Broadcast a scalar kernel over array arguments, like a Numba ufunc."""
        vectorized = np.vectorize(func)

        @functools.wraps(func)
        def ufunc_like(*call_args: Any, **call_kwargs: Any) -> Any:
            # All-scalar calls go straight to the kernel so they return a plain
            # scalar (np.vectorize would return a 0-d array). np.vectorize also
            # cannot infer an output dtype from zero-size inputs, so those keep
            # the previous direct-call behaviour.
            if all(np.ndim(a) == 0 for a in call_args) or any(
                np.size(a) == 0 for a in call_args
            ):
                return func(*call_args, **call_kwargs)
            return vectorized(*call_args, **call_kwargs)

        return ufunc_like  # type: ignore[return-value]

    def _unchanged(func: F) -> F:
        """Return ``func`` as-is."""
        return func

    def njit(*args: Any, **kwargs: Any) -> Callable[[F], F]:
        """No-op stand-in for ``numba.njit`` when Numba is not available.

        Supports ``@njit``, ``@njit(...)`` and ``njit(func, **options)``; all
        options are ignored and the function runs as plain Python.

        Args:
            *args: Positional arguments (a function, or ignored signatures).
            **kwargs: Keyword arguments (ignored).

        Returns:
            Decorator function or decorated function.
        """
        return _apply_or_return(args, _call_through)  # type: ignore[no-any-return]

    def prange(*args: Any, **kwargs: Any) -> range:
        """Fallback to regular range when Numba is not available.

        Args:
            *args: Same as range().
            **kwargs: Same as range().

        Returns:
            Standard Python range object.
        """
        return range(*args, **kwargs)

    def vectorize(*args: Any, **kwargs: Any) -> Callable[[F], F]:
        """Pure-Python stand-in for ``numba.vectorize``.

        The scalar kernel is broadcast over array arguments with
        ``np.vectorize`` so that, like a Numba ufunc, it can be called on
        arrays even when it contains scalar-only control flow.

        Args:
            *args: Positional arguments (a function, or ignored signatures).
            **kwargs: Keyword arguments (ignored).

        Returns:
            Decorator function or decorated function.
        """
        return _apply_or_return(args, _elementwise)  # type: ignore[no-any-return]

    def guvectorize(*args: Any, **kwargs: Any) -> Callable[[F], F]:
        """No-op decorator when Numba is not available.

        NOTE: gufunc calling conventions are not emulated; the kernel is
        returned unchanged, so callers must pass every parameter the kernel
        declares (including any output arrays).

        Args:
            *args: Positional arguments (ignored).
            **kwargs: Keyword arguments (ignored).

        Returns:
            Decorator that returns the original function.
        """
        return _apply_or_return(args, _unchanged)  # type: ignore[no-any-return]

    def jit(*args: Any, **kwargs: Any) -> Callable[[F], F]:
        """No-op stand-in for ``numba.jit`` when Numba is not available.

        Supports ``@jit``, ``@jit(...)`` and ``jit(func, **options)``; all
        options are ignored and the function runs as plain Python.

        Args:
            *args: Positional arguments (a function, or ignored signatures).
            **kwargs: Keyword arguments (ignored).

        Returns:
            Decorator function or decorated function.
        """
        return _apply_or_return(args, _call_through)  # type: ignore[no-any-return]
def get_optimal_numba_config(
    parallel: bool = False,
    cache: bool = True,
    fastmath: bool = False,
    nogil: bool = False,
) -> dict[str, Any]:
    """Build the keyword options to pass to ``njit`` for the requested features.

    Args:
        parallel: Enable parallel execution using prange.
        cache: Enable compilation caching for faster subsequent runs.
        fastmath: Enable fast math optimizations (may reduce precision).
        nogil: Release GIL during execution (useful for threading).

    Returns:
        A dict with the keys ``parallel``, ``cache``, ``fastmath`` and
        ``nogil`` when Numba is installed; an empty dict otherwise, so that
        ``njit(**config)`` remains a valid call with the fallback decorators.

    Example:
        >>> config = get_optimal_numba_config(parallel=True, cache=True)
        >>> @njit(**config)
        ... def my_function(data):
        ...     return data
    """
    options = {
        "parallel": parallel,
        "cache": cache,
        "fastmath": fastmath,
        "nogil": nogil,
    }
    # Without Numba there is nothing to configure.
    return options if HAS_NUMBA else {}
# Example Numba-optimized functions for common operations


@njit(cache=True)  # type: ignore[misc,untyped-decorator]
def find_crossings_numba(
    data: np.ndarray,  # type: ignore[type-arg]
    threshold: float,
    direction: int = 0,
) -> np.ndarray:  # type: ignore[type-arg]
    """Locate the samples at which ``data`` crosses ``threshold``.

    A rising crossing is reported at index ``k`` when
    ``data[k - 1] < threshold <= data[k]``; a falling crossing when
    ``data[k - 1] > threshold >= data[k]``.

    Args:
        data: Input signal data.
        threshold: Level whose crossings are detected.
        direction: 0 reports both kinds, any positive value rising only,
            any negative value falling only.

    Returns:
        int64 array of the indices ``k`` (each >= 1) at which a crossing
        completes, in ascending order.
    """
    keep_rising = direction >= 0
    keep_falling = direction <= 0

    hits = []
    for k in range(1, len(data)):
        before = data[k - 1]
        after = data[k]
        # The two conditions are mutually exclusive, so each k is added at most once.
        rose = before < threshold and threshold <= after
        fell = before > threshold and threshold >= after
        if (keep_rising and rose) or (keep_falling and fell):
            hits.append(k)

    return np.array(hits, dtype=np.int64)
@njit(parallel=True, cache=True)  # type: ignore[misc,untyped-decorator]
def moving_average_numba(
    data: np.ndarray,  # type: ignore[type-arg]
    window_size: int,
) -> np.ndarray:  # type: ignore[type-arg]
    """Compute a "valid"-mode moving average with Numba parallel acceleration.

    Args:
        data: Input signal data.
        window_size: Number of consecutive samples averaged per output value.

    Returns:
        Float64 array of length ``max(len(data) - window_size + 1, 0)`` whose
        element ``i`` is the mean of ``data[i:i + window_size]``. It is empty
        when no complete window fits in ``data``.

    Raises:
        ValueError: If ``window_size`` is less than 1.
    """
    # A non-positive window would divide by zero (or by a negative count).
    if window_size < 1:
        raise ValueError("window_size must be >= 1")

    n = len(data)
    # No complete window fits: return an empty result rather than asking for
    # a negative-length output array.
    if window_size > n:
        return np.empty(0, dtype=np.float64)

    result = np.empty(n - window_size + 1, dtype=np.float64)
    # Each output element is independent, so prange can split them across threads.
    for i in prange(len(result)):
        total = 0.0
        for j in range(window_size):
            total += data[i + j]
        result[i] = total / window_size

    return result
@njit(cache=True)  # type: ignore[misc,untyped-decorator]
def argrelextrema_numba(
    data: np.ndarray,  # type: ignore[type-arg]
    comparator: int,
    order: int = 1,
) -> np.ndarray:  # type: ignore[type-arg]
    """Find relative extrema (peaks/valleys) with Numba acceleration.

    A sample is an extremum when it is strictly greater (maxima) or strictly
    smaller (minima) than every sample within ``order`` positions on both
    sides. Plateaus are therefore not reported. Samples closer than ``order``
    to either end are never reported. NaN samples never qualify, and neither
    does any sample with a NaN among its compared neighbours.

    Args:
        data: Input signal data.
        comparator: Positive for maxima (peaks); zero or negative for minima
            (valleys).
        order: How many points on each side to use for comparison.

    Returns:
        int64 array of indices where extrema occur, in ascending order.

    Raises:
        ValueError: If ``order`` is less than 1.
    """
    # order < 1 would either mark every sample (order == 0) or index outside
    # the array, which Numba does not bounds-check by default.
    if order < 1:
        raise ValueError("order must be >= 1")

    find_max = comparator > 0
    extrema = []
    n = len(data)

    for i in range(order, n - order):
        center = data[i]
        is_extremum = True

        for j in range(1, order + 1):
            # Positive (strict) comparisons make any NaN involved disqualify
            # the sample instead of vacuously passing a negated test.
            if find_max:
                beats_neighbours = center > data[i - j] and center > data[i + j]
            else:
                beats_neighbours = center < data[i - j] and center < data[i + j]
            if not beats_neighbours:
                is_extremum = False
                break

        if is_extremum:
            extrema.append(i)

    return np.array(extrema, dtype=np.int64)
@njit(cache=True)  # type: ignore[misc,untyped-decorator]
def interpolate_linear_numba(
    x: np.ndarray,  # type: ignore[type-arg]
    y: np.ndarray,  # type: ignore[type-arg]
    x_new: np.ndarray,  # type: ignore[type-arg]
) -> np.ndarray:  # type: ignore[type-arg]
    """Linear interpolation with Numba acceleration.

    Query points at or below ``x[0]`` get ``y[0]``; at or above ``x[-1]``
    they get ``y[-1]`` (the same clamping as ``np.interp``'s defaults).

    Args:
        x: Original x coordinates (must be sorted ascending).
        y: Original y values, same length as ``x``.
        x_new: New x coordinates to interpolate.

    Returns:
        Float64 array of interpolated y values at ``x_new``.

    Raises:
        ValueError: If ``x`` is empty or ``y`` has a different length.
    """
    n = len(x)
    # Guard the indexing below: Numba does not bounds-check by default, so an
    # empty or mismatched input would otherwise read outside the arrays.
    if n == 0:
        raise ValueError("x must contain at least one sample point")
    if len(y) != n:
        raise ValueError("x and y must have the same length")

    m = len(x_new)
    y_new = np.empty(m, dtype=np.float64)
    x_first = x[0]
    x_last = x[n - 1]
    y_first = y[0]
    y_last = y[n - 1]

    for i in range(m):
        xi = x_new[i]

        # Clamp outside the sampled range before searching.
        if xi <= x_first:
            y_new[i] = y_first
            continue
        if xi >= x_last:
            y_new[i] = y_last
            continue

        # Binary search keeping x[lo] <= xi < x[hi]; this guarantees
        # x[hi] > x[lo], so the division below cannot be by zero.
        lo = 0
        hi = n - 1
        while hi - lo > 1:
            mid = (lo + hi) // 2
            if x[mid] <= xi:
                lo = mid
            else:
                hi = mid

        t = (xi - x[lo]) / (x[hi] - x[lo])
        y_new[i] = y[lo] + t * (y[hi] - y[lo])

    return y_new
# Public API: the Numba availability flag, the Numba decorators (or their
# pure-Python fallbacks), and the example accelerated kernels defined above.
# Kept as a sorted literal list so static analysers can read the exports.
__all__ = [
    "HAS_NUMBA",
    "argrelextrema_numba",
    "find_crossings_numba",
    "get_optimal_numba_config",
    "guvectorize",
    "interpolate_linear_numba",
    "jit",
    "moving_average_numba",
    "njit",
    "prange",
    "vectorize",
]