Coverage for agentos/queue/rate_limiter.py: 47%

117 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 09:59 +0800

1""" 

2AgentOS v0.60 Rate Limiter — 流量控制。 

3Token Bucket + Sliding Window + Concurrency Limiter + 多级配额。 

4""" 

5 

6from __future__ import annotations 

7 

8import asyncio 

9import time 

10import math 

11from dataclasses import dataclass, field 

12from enum import Enum 

13from typing import Optional 

14 

15 

class RateLimitStrategy(str, Enum):
    """Enumeration of the supported rate-limiting strategies.

    Members are also ``str`` instances, so each one compares equal to its
    plain string value (e.g. ``RateLimitStrategy.TOKEN_BUCKET == "token_bucket"``).
    """

    TOKEN_BUCKET = "token_bucket"
    SLIDING_WINDOW = "sliding_window"
    FIXED_WINDOW = "fixed_window"

23 

24 

@dataclass
class RateLimitConfig:
    """Rate-limiting configuration.

    Attributes:
        strategy: Rate-limiting algorithm to apply.
        max_requests: Maximum number of requests per time window.
        per_seconds: Length of the time window, in seconds.
        burst_size: Burst capacity (used by the token bucket only).
        max_concurrent: Maximum number of concurrent requests.
        queue_timeout: Queueing timeout (presumably seconds; confirm with
            the code that consumes it).
        retry_after_header: Whether a rejection should carry a Retry-After
            value.
    """

    strategy: RateLimitStrategy = RateLimitStrategy.TOKEN_BUCKET
    max_requests: int = 60
    per_seconds: float = 60.0
    burst_size: int = 10
    max_concurrent: int = 5
    queue_timeout: float = 30.0
    retry_after_header: bool = True

35 

36 

@dataclass
class RateLimitResult:
    """Outcome of a rate-limit check.

    Only ``allowed`` is mandatory; every other field defaults to a zero or
    empty value.
    """

    allowed: bool       # True when the request may proceed
    remaining: int = 0
    reset_at: float = 0.0
    retry_after: float = 0.0
    limit: int = 0
    reason: str = ""    # short machine-readable cause when rejected

46 

47 

class TokenBucket:
    """Token-bucket rate limiter.

    The bucket starts full, is topped up continuously at ``rate`` tokens per
    second and never holds more than ``capacity`` tokens.
    """

    def __init__(self, rate: float, capacity: int):
        """Create a full bucket.

        Args:
            rate: Refill speed, in tokens per second.
            capacity: Bucket size, i.e. the largest possible burst.
        """
        self.rate = rate
        self.capacity = capacity
        self.tokens = float(capacity)
        self.last_refill = time.monotonic()
        self._lock = asyncio.Lock()

    async def consume(self, tokens: int = 1) -> bool:
        """Take *tokens* from the bucket if that many are available.

        Returns:
            True when the tokens were deducted, False when the bucket did not
            hold enough (in which case nothing is deducted).
        """
        async with self._lock:
            self._refill()
            enough = self.tokens >= tokens
            if enough:
                self.tokens -= tokens
            return enough

    def _refill(self):
        """Credit the tokens earned since the previous refill, up to capacity."""
        current = time.monotonic()
        earned = (current - self.last_refill) * self.rate
        self.tokens = min(self.capacity, self.tokens + earned)
        self.last_refill = current

    @property
    def available(self) -> float:
        """Tokens currently in the bucket (refills before reporting)."""
        self._refill()
        return self.tokens

76 

77 

class SlidingWindow:
    """Sliding-window request counter.

    Remembers the monotonic timestamp of every admitted request and admits a
    new one only while fewer than ``max_requests`` timestamps fall inside the
    trailing ``window`` seconds.
    """

    def __init__(self, max_requests: int, window_seconds: float):
        """Create an empty window.

        Args:
            max_requests: Requests admitted per window.
            window_seconds: Window length in seconds.
        """
        self.max_requests = max_requests
        self.window = window_seconds
        self._timestamps: list[float] = []
        self._lock = asyncio.Lock()

    async def allow(self) -> bool:
        """Record and admit a request if the window still has room.

        Returns:
            True when the request was admitted (and recorded), else False.
        """
        async with self._lock:
            now = time.monotonic()
            oldest_valid = now - self.window
            # Drop timestamps that have slid out of the window.
            live = [stamp for stamp in self._timestamps if stamp > oldest_valid]
            self._timestamps = live
            if len(live) < self.max_requests:
                live.append(now)
                return True
            return False

    @property
    def current_count(self) -> int:
        """Number of recorded requests still inside the window."""
        oldest_valid = time.monotonic() - self.window
        return len([stamp for stamp in self._timestamps if stamp > oldest_valid])

101 

102 

class ConcurrencyLimiter:
    """Limit on the number of requests in flight at once.

    Wraps an ``asyncio.Semaphore`` holding ``max_concurrent`` permits.
    """

    def __init__(self, max_concurrent: int):
        """Create a limiter with *max_concurrent* free permits."""
        self._semaphore = asyncio.Semaphore(max_concurrent)
        self.max_concurrent = max_concurrent

    async def acquire(self) -> bool:
        """Wait for a free permit and take it.

        Returns:
            True once the permit has been acquired.
        """
        return await self._semaphore.acquire()

    def release(self):
        """Give back one permit.

        A release with no permit outstanding is ignored. ``asyncio.Semaphore``
        is unbounded, so an unmatched release would otherwise push the permit
        count above ``max_concurrent`` and silently raise the limit.
        """
        if self.available >= self.max_concurrent:
            return
        self._semaphore.release()

    @property
    def available(self) -> int:
        """Number of permits that are currently free."""
        # asyncio.Semaphore exposes no public counter; read the internal one.
        return self._semaphore._value

119 

120 

class RateLimiter:
    """Composite rate limiter.

    Combines a token bucket, a sliding window and a concurrency limiter, and
    exposes per-model quota presets.
    """

    def __init__(self, config: RateLimitConfig | None = None):
        """Build the underlying limiters.

        Args:
            config: Limiter settings; a default ``RateLimitConfig`` is used
                when omitted.
        """
        cfg = config or RateLimitConfig()
        self.config = cfg
        self._bucket = TokenBucket(
            rate=cfg.max_requests / cfg.per_seconds,
            # A burst_size of 0 falls back to max_requests as the capacity.
            capacity=cfg.burst_size or cfg.max_requests,
        )
        self._window = SlidingWindow(cfg.max_requests, cfg.per_seconds)
        self._concurrency = ConcurrencyLimiter(cfg.max_concurrent)

    async def acquire(self, weight: int = 1) -> RateLimitResult:
        """Try to obtain request quota under the configured strategy.

        Args:
            weight: Token cost of the request. Only the token-bucket strategy
                uses it; the sliding window counts each call as one request.

        Returns:
            A ``RateLimitResult`` whose ``allowed`` flag tells whether the
            request may proceed. Rejections carry ``retry_after`` and a
            ``reason`` of ``"rate_limit_exceeded"`` or ``"window_exceeded"``.
        """
        # NOTE(review): no concurrency permit is taken here, so max_concurrent
        # is not enforced by this method; only the rate strategy decides.
        limit = self.config.max_requests
        strategy = self.config.strategy

        if strategy == RateLimitStrategy.TOKEN_BUCKET:
            if await self._bucket.consume(weight):
                return RateLimitResult(
                    allowed=True,
                    remaining=max(0, int(self._bucket.available)),
                    limit=limit,
                )
            return RateLimitResult(
                allowed=False,
                remaining=0,
                retry_after=self._bucket_retry_after(weight),
                limit=limit,
                reason="rate_limit_exceeded",
            )

        if strategy == RateLimitStrategy.SLIDING_WINDOW:
            if await self._window.allow():
                return RateLimitResult(
                    allowed=True,
                    remaining=limit - self._window.current_count,
                    limit=limit,
                )
            return RateLimitResult(
                allowed=False,
                remaining=0,
                retry_after=self.config.per_seconds,
                limit=limit,
                reason="window_exceeded",
            )

        # Any other strategy (FIXED_WINDOW) has no dedicated implementation
        # here and is always allowed.
        return RateLimitResult(allowed=True, limit=limit)

    def _bucket_retry_after(self, weight: int) -> float:
        """Return the seconds until the bucket could hold *weight* tokens."""
        bucket = self._bucket
        deficit = weight - bucket.available
        if deficit <= 0:
            # The bucket refilled between the failed consume and this read.
            return 0.0
        if bucket.rate <= 0:
            # A bucket that never refills can never cover the deficit; avoid
            # dividing by zero and report an unbounded wait instead.
            return math.inf
        return deficit / bucket.rate

    async def release(self):
        """Give one permit back to the concurrency limiter.

        NOTE(review): ``acquire`` does not take a permit, so calls to this
        method are not paired with an acquisition made by this class.
        """
        self._concurrency.release()

    def model_quota(self, model: str) -> RateLimitConfig:
        """Return the quota configuration for a specific model.

        Args:
            model: Model identifier, e.g. ``"gpt-4o"``.

        Returns:
            A freshly built preset for known models, otherwise this limiter's
            own ``config``.
        """
        quotas = {
            "gpt-4o": RateLimitConfig(max_requests=50, per_seconds=60, burst_size=5),
            "gpt-4o-mini": RateLimitConfig(max_requests=200, per_seconds=60, burst_size=20),
            "claude-sonnet-4": RateLimitConfig(max_requests=40, per_seconds=60, burst_size=5),
            "deepseek-v3.1": RateLimitConfig(max_requests=100, per_seconds=60, burst_size=15),
        }
        return quotas.get(model, self.config)

186 

187 

class QuotaManager:
    """Multi-tenant quota management: one ``RateLimiter`` per key."""

    def __init__(self):
        """Start with no registered limiters."""
        self._limiters: dict[str, RateLimiter] = {}

    def get(self, key: str, config: RateLimitConfig | None = None) -> RateLimiter:
        """Return the limiter for *key*, creating it on first use.

        *config* is only consulted when the limiter has to be created; an
        already registered limiter is returned unchanged.
        """
        limiter = self._limiters.get(key)
        if limiter is None:
            limiter = RateLimiter(config)
            self._limiters[key] = limiter
        return limiter

    def add_quota(self, key: str, config: RateLimitConfig):
        """Register a new limiter for *key*, replacing any existing one."""
        self._limiters[key] = RateLimiter(config)

    def clear_expired(self, ttl: float = 3600):
        """Remove limiters unused for longer than *ttl* (reserved interface).

        Currently a placeholder that does nothing.
        """
        return None

203 """清除超过TTL未使用的限流器(预留接口)。""" 

204 pass