Coverage for agentos/tools/request_deduplicator.py: 0%

141 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-03 07:39 +0800

1""" 

RequestDeduplicator — fingerprint-based concurrent request deduplication.

Supports:
  - Fingerprint generation from request parameters
  - In-flight deduplication (same fingerprint → wait for existing result)
  - Result caching with TTL (return cached result on duplicate)
  - Thread-safe (built on ``threading`` primitives; waiting blocks the calling
    thread, so it is not suitable for use inside an asyncio event loop)
  - Cleanup of expired entries (performed lazily, on lookup)
  - Configurable max cache size

11""" 

12 

13from __future__ import annotations 

14 

15import hashlib 

16import json 

17import threading 

18import time 

19from enum import Enum 

20from typing import Any, Callable, Dict, Optional, Tuple 

21 

22 

23# ============================================================================ 

24# Result 

25# ============================================================================ 

26 

class ResultStatus(Enum):
    """Outcome recorded for a deduplicated request."""

    # The request finished and its return value was stored.
    COMPLETED = "completed"
    # The request failed; the stored value is the error that was reported.
    ERROR = "error"

30 

31 

class DedupResult:
    """A request outcome plus the wall-clock time at which it was recorded.

    Attributes:
        status: whether the request completed or failed.
        value: the returned value, or the reported error for failures.
        timestamp: ``time.time()`` at construction; used for TTL checks.
    """

    __slots__ = ("status", "value", "timestamp")

    def __init__(self, status: "ResultStatus", value: Any):
        self.status, self.value = status, value
        # Stamp the entry at creation so its age can be compared against a TTL.
        self.timestamp = time.time()

39 

40 

41# ============================================================================ 

42# RequestDeduplicator 

43# ============================================================================ 

44 

class _InFlight:
    """State shared between the owner of an in-progress request and its waiters."""

    __slots__ = ("event", "result")

    def __init__(self) -> None:
        # Set once the owner has finished or the request has been abandoned.
        self.event = threading.Event()
        # Outcome published by the owner just before ``event`` is set; stays
        # None when the request is abandoned without producing an outcome.
        self.result: Optional[DedupResult] = None


class RequestDeduplicator:
    """Fingerprint-based request deduplication with result caching.

    The first caller for a key becomes the *owner* and performs the work;
    concurrent callers for the same key wait for the owner's outcome instead
    of repeating it. Successful results are cached for ``ttl`` seconds.

    Usage:
        dedup = RequestDeduplicator(ttl=30.0)

        # Option A: manual
        key = dedup.create_key(method="POST", path="/api/users", body={"name": "Alice"})
        result = dedup.get(key)
        if result:
            return result.value

        if not dedup.mark_in_flight(key):
            return dedup.wait_in_flight(key, timeout=30.0)
        try:
            response = do_request(...)
            dedup.complete(key, response)
        except Exception as e:
            dedup.error(key, e)
            raise

        # Option B: decorator
        @dedup.deduplicate(key_fn=lambda *a, **kw: f"{a[0]}_{a[1]}")
        def fetch(user_id, query):
            return api_call(user_id, query)
    """

    def __init__(
        self,
        ttl: float = 60.0,
        max_entries: int = 10000,
        key_prefix: str = "dedup:",
    ):
        """Create a deduplicator.

        Args:
            ttl: seconds a cached result stays valid.
            max_entries: cache size above which the oldest entries are evicted.
            key_prefix: string prepended to every key built by ``create_key``.
        """
        self._ttl = ttl
        self._max_entries = max_entries
        self._key_prefix = key_prefix
        self._cache: Dict[str, DedupResult] = {}
        # One record per request currently being processed; removed as soon as
        # the owner reports an outcome so the key can be claimed again.
        self._in_flight: Dict[str, _InFlight] = {}
        # Re-entrant: helpers that take the lock call each other.
        self._lock = threading.RLock()
        self._last_cleanup = time.time()

    # ---------- key generation ----------

    def create_key(self, *args: Any, **kwargs: Any) -> str:
        """Generate a fingerprint key from args/kwargs.

        Positional arguments are hashed in order; keyword arguments are hashed
        independently of the order they were passed in. Values that JSON cannot
        encode are stringified with ``str()``, so objects without a stable
        ``str()`` do not produce stable keys.

        Returns:
            ``key_prefix`` followed by the first 16 hex digits of a SHA-256.
        """
        # sort_keys=True orders kwargs (and any nested dicts) deterministically.
        payload: Dict[str, Any] = {"args": args, "kwargs": kwargs}
        raw = json.dumps(payload, sort_keys=True, default=str)
        digest = hashlib.sha256(raw.encode()).hexdigest()[:16]
        return f"{self._key_prefix}{digest}"

    # ---------- lookup ----------

    def get(self, key: str) -> "Optional[DedupResult]":
        """Return the cached entry for ``key``, or None if absent or expired.

        The returned object exposes ``status``, ``value`` and ``timestamp``.
        """
        self._maybe_cleanup()
        with self._lock:
            entry = self._cache.get(key)
            if entry is None:
                return None
            if time.time() - entry.timestamp > self._ttl:
                del self._cache[key]
                return None
            return entry

    def get_or_none(self, key: str) -> Optional[Any]:
        """Same as get() but returns the raw cached value, or None if not cached."""
        entry = self.get(key)
        return None if entry is None else entry.value

    # ---------- in-flight management ----------

    def mark_in_flight(self, key: str) -> bool:
        """Mark key as in-flight.

        Returns:
            True if the caller is the first and should do the work (and later
            call ``complete`` or ``error``); False if another caller is already
            processing the key, in which case the caller should wait.
        """
        with self._lock:
            if key in self._in_flight:
                return False
            self._in_flight[key] = _InFlight()
            return True

    def wait_in_flight(self, key: str, timeout: Optional[float] = None) -> Optional[Any]:
        """Wait for an in-flight request to finish and return its value.

        Every waiter receives the outcome. For a request that finished via
        ``error`` the value is the reported exception object (it is returned,
        not raised).

        Returns:
            The outcome's value, or None on timeout, when the request was
            abandoned by ``clear()``, or when nothing is in flight for ``key``
            and no fresh cached result exists.
        """
        with self._lock:
            pending = self._in_flight.get(key)
        if pending is None:
            # The owner may have finished between the caller's failed
            # mark_in_flight() and this call; a successful result is then
            # already in the cache.
            return self.get_or_none(key)
        if not pending.event.wait(timeout=timeout):
            return None
        outcome = pending.result
        return None if outcome is None else outcome.value

    def complete(self, key: str, result: Any) -> None:
        """Cache ``result`` for ``key`` and wake every waiter with it."""
        self._settle(key, DedupResult(ResultStatus.COMPLETED, result), cache=True)

    def error(self, key: str, error: Exception) -> None:
        """Report failure of an in-flight request and wake every waiter.

        The error is not cached, so the next caller may retry.
        """
        self._settle(key, DedupResult(ResultStatus.ERROR, error), cache=False)

    def _settle(self, key: str, entry: "Optional[DedupResult]", cache: bool) -> None:
        """Release the in-flight record for ``key`` and publish ``entry`` to waiters.

        ``entry`` is None when the request is abandoned without an outcome.
        """
        with self._lock:
            if cache and entry is not None:
                self._cache[key] = entry
            pending = self._in_flight.pop(key, None)
        # Publish and signal outside the lock; the result is written before
        # the event is set so woken waiters always observe it.
        if pending is not None:
            pending.result = entry
            pending.event.set()
        if cache:
            self._evict_if_needed()

    @staticmethod
    def _unwrap(entry: "DedupResult") -> Any:
        """Return the value of a completed entry, or raise the error of a failed one."""
        if entry.status is not ResultStatus.ERROR:
            return entry.value
        failure = entry.value
        raise failure if isinstance(failure, Exception) else Exception(str(failure))

    # ---------- decorator ----------

    def deduplicate(
        self,
        key_fn: Callable[..., str],
        wait_timeout: Optional[float] = 30.0,
        cache_errors: bool = False,
    ) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
        """Decorator: deduplicate concurrent calls with same fingerprint.

        Args:
            key_fn: function(*args, **kwargs) → key string
            wait_timeout: max seconds to wait for an in-flight request
                (None waits indefinitely)
            cache_errors: if True, a raised exception is cached for the TTL and
                re-raised on later calls instead of re-executing

        The wrapper raises ``TimeoutError`` when the in-flight request does not
        finish within ``wait_timeout`` (or is abandoned by ``clear()``), and
        re-raises the owner's exception in every waiter when the call fails.
        """
        from functools import wraps

        def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
            @wraps(func)
            def wrapper(*args: Any, **kwargs: Any) -> Any:
                key = key_fn(*args, **kwargs)

                # Check the cache and claim/join the in-flight record in one
                # critical section so an owner finishing in between is not missed.
                with self._lock:
                    cached = self.get(key)
                    if (
                        cached is not None
                        and cached.status is ResultStatus.ERROR
                        and not cache_errors
                    ):
                        # A cached failure this wrapper is not configured to
                        # reuse: ignore it and re-execute.
                        cached = None
                    pending = None
                    if cached is None:
                        pending = self._in_flight.get(key)
                        if pending is None:
                            self._in_flight[key] = _InFlight()

                if cached is not None:
                    return self._unwrap(cached)

                if pending is not None:
                    # Another caller is processing — wait for its outcome.
                    finished = pending.event.wait(timeout=wait_timeout)
                    outcome = pending.result if finished else None
                    if outcome is None:
                        raise TimeoutError(f"Timeout waiting for deduplicated request: {key}")
                    return self._unwrap(outcome)

                # We own the request.
                try:
                    result = func(*args, **kwargs)
                except Exception as exc:
                    failure = DedupResult(ResultStatus.ERROR, exc)
                    self._settle(key, failure, cache=cache_errors)
                    raise
                except BaseException:
                    # e.g. KeyboardInterrupt: release the key without an
                    # outcome so it is not stuck in flight forever.
                    self._settle(key, None, cache=False)
                    raise
                self.complete(key, result)
                return result

            return wrapper

        return decorator

    # ---------- cache maintenance ----------

    def _maybe_cleanup(self) -> None:
        """Purge expired cache entries, at most once per TTL interval."""
        now = time.time()
        with self._lock:
            if now - self._last_cleanup < self._ttl:
                return
            self._last_cleanup = now
            expired = [
                k for k, v in self._cache.items()
                if now - v.timestamp > self._ttl
            ]
            for k in expired:
                del self._cache[k]

    def _evict_if_needed(self) -> None:
        """Drop the oldest cache entries until at most ``max_entries`` remain."""
        with self._lock:
            excess = len(self._cache) - self._max_entries
            if excess <= 0:
                return
            # Evict oldest entries
            sorted_by_age = sorted(self._cache.items(), key=lambda x: x[1].timestamp)
            for k, _ in sorted_by_age[:excess]:
                del self._cache[k]

    def clear(self) -> None:
        """Drop all cached results and abandon all in-flight requests.

        Waiters of abandoned requests are woken without a result.
        """
        with self._lock:
            abandoned = list(self._in_flight.values())
            self._cache.clear()
            self._in_flight.clear()
        # Wake waiters outside the lock so none of them blocks forever.
        for pending in abandoned:
            pending.event.set()

    @property
    def cache_size(self) -> int:
        """Number of entries currently held in the result cache."""
        with self._lock:
            return len(self._cache)

    @property
    def in_flight_count(self) -> int:
        """Number of requests currently being processed."""
        with self._lock:
            return len(self._in_flight)