Coverage for agentos/tools/rate_limiter.py: 0%

108 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-03 07:44 +0800

1""" 

2RateLimiter — token bucket and sliding window rate limiters. 

3 

4Supports two algorithms: 

5 - TokenBucket: supports burst (tokens accumulate up to burst size), smooth refill 

6 - SlidingWindow: strict per-window limit, no burst 

7 

8Common interface: 

9 - try_acquire(key) → bool 

10 - acquire_or_wait(key, timeout) → bool (blocking with timeout) 

11 - reset(key) 

12 - stats() → dict 

13""" 

14 

from __future__ import annotations

import threading
import time
from collections import deque
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Dict, List, Optional

22 

23 

24# ============================================================================ 

25# Rate Limit Exceeded 

26# ============================================================================ 

27 

class RateLimitExceeded(Exception):
    """Exception describing a rate limit that was hit for a given key.

    The limiters in this module report denials by returning ``False``; this
    exception is presumably raised by callers that prefer exceptions — verify
    against call sites.

    Attributes:
        key: Identifier of the limited resource.
        limit: The configured limit, shown as ``limit/window s`` in the message.
        window: The window length in seconds, as shown in the message.
    """

    def __init__(self, key: str, limit: float, window: float):
        text = f"Rate limit exceeded for '{key}': {limit}/{window}s"
        super().__init__(text)
        self.key, self.limit, self.window = key, limit, window

34 

35 

36# ============================================================================ 

37# TokenBucket 

38# ============================================================================ 

39 

@dataclass
class _BucketState:
    """Mutable per-key state of a token bucket."""

    # Tokens currently available; fractional because refill is continuous.
    tokens: float
    # time.monotonic() reading taken at the most recent refill.
    last_refill: float


class TokenBucket:
    """Token bucket rate limiter with burst support.

    Each key owns an independent bucket. A bucket starts full (``burst``
    tokens) and refills continuously at ``rate`` tokens per second, capped at
    ``burst``. Buckets are created on first use and removed only by
    :meth:`reset` / :meth:`reset_all`, so memory grows with the number of
    distinct keys seen.

    All mutable state is accessed under a single re-entrant lock.

    Usage:
        limiter = TokenBucket(rate=10.0, burst=20.0)  # 10 tokens/sec, burst up to 20
        limiter.try_acquire("api:user:42")            # -> True/False
        limiter.try_acquire("api:user:42", tokens=5)  # consume 5 tokens
    """

    # Upper bound on one sleep in acquire_or_wait. It keeps a waiter responsive
    # to a concurrent reset() or to competing consumers.
    _POLL_INTERVAL = 0.01

    def __init__(self, rate: float, burst: Optional[float] = None):
        """Create a limiter.

        Args:
            rate: Refill speed in tokens per second; must be positive.
            burst: Bucket capacity; defaults to ``rate``. Must be positive.

        Raises:
            ValueError: if ``rate`` or the effective ``burst`` is not positive.
        """
        if rate <= 0:
            raise ValueError("rate must be positive")
        self._rate = rate
        self._burst = burst if burst is not None else rate
        if self._burst <= 0:
            # A bucket with no capacity could never grant anything.
            raise ValueError("burst must be positive")
        self._buckets: Dict[str, _BucketState] = {}
        self._lock = threading.RLock()
        self._total_acquired: int = 0
        self._total_rejected: int = 0

    @staticmethod
    def _check_tokens(tokens: float) -> None:
        """Reject negative requests, which would add tokens beyond ``burst``."""
        if tokens < 0:
            raise ValueError("tokens must be non-negative")

    def _take(self, key: str, tokens: float) -> Optional[float]:
        """Refill *key*'s bucket and try to remove *tokens* from it.

        The caller must hold ``self._lock``. Counters are not touched.

        Returns:
            ``None`` if the tokens were taken. Otherwise, the seconds until
            enough tokens will have accumulated, ignoring other consumers.
        """
        # Read the clock under the lock. If threads race, last_refill must never
        # move backwards; a stale, earlier `now` would yield a negative refill.
        now = time.monotonic()
        bucket = self._buckets.get(key)
        if bucket is None:
            bucket = self._buckets[key] = _BucketState(tokens=self._burst, last_refill=now)
        else:
            elapsed = now - bucket.last_refill
            bucket.tokens = min(self._burst, bucket.tokens + elapsed * self._rate)
            bucket.last_refill = now
        if bucket.tokens >= tokens:
            bucket.tokens -= tokens
            return None
        return (tokens - bucket.tokens) / self._rate

    def try_acquire(self, key: str, tokens: float = 1.0) -> bool:
        """Try to take *tokens* from *key*'s bucket without blocking.

        Returns:
            True if the tokens were taken, False otherwise. Each call counts
            once toward ``total_acquired`` or ``total_rejected``.

        Raises:
            ValueError: if ``tokens`` is negative.
        """
        self._check_tokens(tokens)
        with self._lock:
            if self._take(key, tokens) is None:
                self._total_acquired += 1
                return True
            self._total_rejected += 1
            return False

    def acquire_or_wait(self, key: str, timeout: Optional[float] = None, tokens: float = 1.0) -> bool:
        """Block until *tokens* can be taken from *key*'s bucket, or time out.

        Args:
            key: Bucket identifier.
            timeout: Maximum seconds to wait. ``None`` waits indefinitely.
                ``0`` or a negative value makes exactly one attempt.
            tokens: Number of tokens to take.

        Returns:
            True if the tokens were taken. False on timeout, or immediately
            when ``tokens`` exceeds ``burst``, since the bucket can never hold
            that many. A call counts once toward ``total_acquired`` or
            ``total_rejected``, however many times it polls.

        Raises:
            ValueError: if ``tokens`` is negative.
        """
        self._check_tokens(tokens)
        if tokens > self._burst:
            # Unsatisfiable. Waiting would spin forever when timeout is None.
            with self._lock:
                self._total_rejected += 1
            return False
        # Test for None explicitly: timeout=0 means "don't wait", not "wait forever".
        deadline = None if timeout is None else time.monotonic() + max(0.0, timeout)
        while True:
            with self._lock:
                wait = self._take(key, tokens)
                if wait is None:
                    self._total_acquired += 1
                    return True
                if deadline is not None:
                    remaining = deadline - time.monotonic()
                    if remaining <= 0:
                        self._total_rejected += 1
                        return False
                    # Never sleep past the deadline.
                    wait = min(wait, remaining)
            # Sleep outside the lock so other threads can make progress.
            time.sleep(min(wait, self._POLL_INTERVAL))

    def reset(self, key: str) -> None:
        """Forget *key*'s bucket; its next use starts with a full bucket."""
        with self._lock:
            self._buckets.pop(key, None)

    def reset_all(self) -> None:
        """Forget every bucket. The cumulative counters are not cleared."""
        with self._lock:
            self._buckets.clear()

    def stats(self) -> Dict[str, Any]:
        """Return a snapshot of the configuration and cumulative counters."""
        with self._lock:
            return {
                "rate": self._rate,
                "burst": self._burst,
                "active_keys": len(self._buckets),
                "total_acquired": self._total_acquired,
                "total_rejected": self._total_rejected,
            }

    @property
    def rate(self) -> float:
        """Refill speed in tokens per second."""
        return self._rate

118 

119 

120# ============================================================================ 

121# SlidingWindow 

122# ============================================================================ 

123 

class SlidingWindow:
    """Sliding window rate limiter — strict per-window limit, no burst.

    For each key, the limiter keeps the monotonic timestamps of requests
    granted within the last ``window`` seconds. A request is granted while
    fewer than ``limit`` such timestamps remain. Per-key history is dropped
    only by :meth:`reset` / :meth:`reset_all`, so memory grows with the number
    of distinct keys, each holding at most ``limit`` timestamps.

    All mutable state is accessed under a single re-entrant lock.

    Usage:
        limiter = SlidingWindow(limit=100, window=60.0)  # 100 req per 60s
        limiter.try_acquire("api:endpoint")              # -> True/False
    """

    # Upper bound on one sleep in acquire_or_wait. It keeps a waiter responsive
    # to a concurrent reset() or to competing callers.
    _POLL_INTERVAL = 0.02

    def __init__(self, limit: int, window: float = 60.0):
        """Create a limiter.

        Args:
            limit: Maximum granted requests per key within any window; must be positive.
            window: Window length in seconds; must be positive.

        Raises:
            ValueError: if ``limit`` or ``window`` is not positive.
        """
        if limit <= 0:
            raise ValueError("limit must be positive")
        if window <= 0:
            # A non-positive window evicts every entry and never limits anything.
            raise ValueError("window must be positive")
        self._limit = limit
        self._window = window
        # Timestamps are appended in non-decreasing order. A deque lets expired
        # entries be dropped from the left in O(1); list.pop(0) is O(n).
        self._windows: Dict[str, deque[float]] = {}
        self._lock = threading.RLock()
        self._total_acquired: int = 0
        self._total_rejected: int = 0

    def _take(self, key: str) -> Optional[float]:
        """Evict expired entries for *key* and try to record a new request.

        The caller must hold ``self._lock``. Counters are not touched.

        Returns:
            ``None`` if a slot was taken. Otherwise, the seconds until the
            oldest recorded request leaves the window.
        """
        # Read the clock under the lock so timestamps stay ordered. Eviction
        # relies on that ordering because it inspects only the left end.
        now = time.monotonic()
        timestamps = self._windows.get(key)
        if timestamps is None:
            timestamps = self._windows[key] = deque()
        cutoff = now - self._window
        while timestamps and timestamps[0] < cutoff:
            timestamps.popleft()
        if len(timestamps) < self._limit:
            timestamps.append(now)
            return None
        return timestamps[0] + self._window - now

    def try_acquire(self, key: str) -> bool:
        """Try to record a request for *key* without blocking.

        Returns:
            True if the request fits in the current window, False otherwise.
            Each call counts once toward ``total_acquired`` or ``total_rejected``.
        """
        with self._lock:
            if self._take(key) is None:
                self._total_acquired += 1
                return True
            self._total_rejected += 1
            return False

    def acquire_or_wait(self, key: str, timeout: Optional[float] = None) -> bool:
        """Block until a slot for *key* frees up, or time out.

        Args:
            key: Window identifier.
            timeout: Maximum seconds to wait. ``None`` waits indefinitely.
                ``0`` or a negative value makes exactly one attempt.

        Returns:
            True if a slot was taken, False on timeout. A call counts once
            toward ``total_acquired`` or ``total_rejected``, however many
            times it polls.
        """
        # Test for None explicitly: timeout=0 means "don't wait", not "wait forever".
        deadline = None if timeout is None else time.monotonic() + max(0.0, timeout)
        while True:
            with self._lock:
                wait = self._take(key)
                if wait is None:
                    self._total_acquired += 1
                    return True
                if deadline is not None:
                    remaining = deadline - time.monotonic()
                    if remaining <= 0:
                        self._total_rejected += 1
                        return False
                    # Never sleep past the deadline.
                    wait = min(wait, remaining)
            # Sleep outside the lock so other threads can make progress.
            time.sleep(max(0.0, min(wait, self._POLL_INTERVAL)))

    def reset(self, key: str) -> None:
        """Forget *key*'s request history."""
        with self._lock:
            self._windows.pop(key, None)

    def reset_all(self) -> None:
        """Forget all request history. The cumulative counters are not cleared."""
        with self._lock:
            self._windows.clear()

    def stats(self) -> Dict[str, Any]:
        """Return a snapshot of the configuration and cumulative counters."""
        with self._lock:
            return {
                "limit": self._limit,
                "window": self._window,
                "active_keys": len(self._windows),
                "total_acquired": self._total_acquired,
                "total_rejected": self._total_rejected,
            }

    @property
    def limit(self) -> int:
        """Maximum granted requests per key within any window."""
        return self._limit