Coverage for agentos/cache/llm_cache.py: 35%

154 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 09:59 +0800

1""" 

2AgentOS v0.40 LLM Cache — 语义缓存减少API调用成本。 

3支持:精确匹配缓存、语义相似度缓存、LRU淘汰、TTL过期。 

4""" 

5 

6from __future__ import annotations 

7 

8import hashlib 

9import json 

10import time 

11from collections import OrderedDict 

12from dataclasses import dataclass, field 

13from typing import Optional, Any 

14 

15 

@dataclass
class CacheEntry:
    """One cached LLM response together with its bookkeeping data.

    Attributes:
        key: Lookup key for the entry (typically a hash of prompt + model).
        value: The cached response content.
        tokens_saved: Tokens saved each time this entry is served.
        cost_saved: Estimated cost saved each time this entry is served.
        created_at: Unix timestamp taken when the entry was created.
        ttl: Time-to-live in seconds, counted from ``created_at``.
        hit_count: How many times the entry has been served.
        tags: Optional labels usable for tag-based invalidation.
    """

    key: str
    value: Any
    tokens_saved: int = 0
    cost_saved: float = 0.0
    created_at: float = field(default_factory=time.time)
    ttl: float = 3600  # one hour by default
    hit_count: int = 0
    tags: list[str] = field(default_factory=list)

    @property
    def expired(self) -> bool:
        """Return True once the current time is past ``created_at + ttl``."""
        deadline = self.created_at + self.ttl
        return time.time() > deadline

42 

43 

class LRUCache:
    """In-memory cache with least-recently-used eviction.

    Entries live in an ``OrderedDict`` ordered from least to most recently
    used. Expired entries are dropped lazily when they are looked up.
    """

    def __init__(self, max_size: int = 500):
        """Create an empty cache holding at most ``max_size`` entries."""
        self._cache: OrderedDict[str, CacheEntry] = OrderedDict()
        self.max_size = max_size

    def get(self, key: str) -> Optional[CacheEntry]:
        """Return the live entry stored under ``key``, or ``None``.

        A hit marks the entry as most recently used and increments its
        ``hit_count``. An expired entry is removed and reported as a miss.
        """
        entry = self._cache.get(key)
        if entry is None:
            return None
        if entry.expired:
            del self._cache[key]
            return None
        self._cache.move_to_end(key)
        entry.hit_count += 1
        return entry

    def put(self, key: str, entry: CacheEntry):
        """Store ``entry`` under ``key`` as the most recently used item.

        With a non-positive ``max_size`` the cache holds nothing: it is
        emptied and the entry is not stored (previously this raised
        ``KeyError`` from ``popitem`` on an empty dict). Otherwise, a new key
        evicts least recently used entries until there is room, which also
        handles ``max_size`` having been lowered after entries were added.
        """
        if self.max_size <= 0:
            self._cache.clear()
            return
        if key not in self._cache:
            # max_size >= 1 here, so this loop always terminates.
            while len(self._cache) >= self.max_size:
                self._cache.popitem(last=False)  # evict least recently used
        self._cache[key] = entry
        self._cache.move_to_end(key)

    def invalidate(self, key: str | None = None, tag: str | None = None):
        """Remove the entry stored under ``key``, or all entries tagged ``tag``.

        A truthy ``key`` takes precedence; ``tag`` is only considered when
        ``key`` is ``None`` or empty. With neither, nothing is removed.
        """
        if key:
            self._cache.pop(key, None)
        elif tag:
            # Collect first: the dict must not change size while iterating.
            tagged = [k for k, v in self._cache.items() if tag in v.tags]
            for k in tagged:
                del self._cache[k]

    def size(self) -> int:
        """Return the number of stored entries (expired ones included)."""
        return len(self._cache)

    def clear(self):
        """Remove every entry."""
        self._cache.clear()

81 

82 

class SemanticCache:
    """Semantic cache that matches queries by embedding similarity.

    Each stored entry is paired with the embedding of the query it was added
    under. A lookup returns the live entry whose embedding has the highest
    cosine similarity to the query, provided it reaches the threshold.
    """

    # Dimension of the built-in fallback embedding.
    _FALLBACK_DIM = 64

    def __init__(self, similarity_threshold: float = 0.92, embedder: Any = None, max_entries: int = 200):
        """Create an empty semantic cache.

        Args:
            similarity_threshold: Minimum cosine similarity for a hit.
            embedder: Optional callable mapping text to a list of floats.
                When omitted, a built-in bag-of-words hashing embedding is used.
            max_entries: Maximum number of stored entries; the oldest are
                dropped first once the limit is exceeded.
        """
        self.threshold = similarity_threshold
        self._entries: list[tuple[list[float], CacheEntry]] = []
        self._embedder = embedder  # externally injected embedding function
        self.max_entries = max_entries

    def _embed(self, text: str) -> list[float]:
        """Return the embedding vector for ``text``.

        Uses the injected embedder when there is one. The fallback is a
        feature-hashing bag of words: each lower-cased whitespace token adds
        1.0 to a bucket chosen by a stable hash of the token. A given word
        therefore always lands in the same dimension, so texts with no words
        in common score near zero instead of matching by accident.
        """
        if self._embedder is not None:
            return self._embedder(text)
        dim = self._FALLBACK_DIM
        vec = [0.0] * dim
        for token in text.lower().split():
            # sha256 rather than built-in hash(): str hashes are salted per process.
            digest = hashlib.sha256(token.encode("utf-8")).digest()
            vec[int.from_bytes(digest[:8], "big") % dim] += 1.0
        return vec

    @staticmethod
    def cosine_sim(a: list[float], b: list[float]) -> float:
        """Return the cosine similarity of ``a`` and ``b``.

        Returns 0.0 when either vector is empty or has zero norm. If the
        lengths differ, the dot product covers only the common prefix while
        each norm covers the full vector.
        """
        if not a or not b:
            return 0.0
        dot = sum(x * y for x, y in zip(a, b))
        norm_a = sum(x ** 2 for x in a) ** 0.5
        norm_b = sum(x ** 2 for x in b) ** 0.5
        if norm_a == 0 or norm_b == 0:
            return 0.0
        return dot / (norm_a * norm_b)

    def search(self, query: str) -> Optional[CacheEntry]:
        """Return the most similar live entry for ``query``, or ``None``.

        Expired entries are discarded before the comparison. On a hit the
        entry's ``hit_count`` is incremented.
        """
        # Drop expired entries so they stop occupying slots and scan time.
        self._entries = [pair for pair in self._entries if not pair[1].expired]
        if not self._entries:
            return None
        query_vec = self._embed(query)
        best_sim = 0.0
        best_entry = None
        for cached_vec, entry in self._entries:
            sim = self.cosine_sim(query_vec, cached_vec)
            if sim > best_sim:
                best_sim = sim
                best_entry = entry
        if best_entry is not None and best_sim >= self.threshold:
            best_entry.hit_count += 1
            return best_entry
        return None

    def add(self, query: str, entry: CacheEntry):
        """Store ``entry`` under the embedding of ``query``.

        When the number of entries exceeds ``max_entries`` the oldest ones
        are removed. A ``max_entries`` of zero or less keeps nothing.
        """
        self._entries.append((self._embed(query), entry))
        # Trim by overflow count: a slice of [-0:] would keep the whole list.
        overflow = len(self._entries) - max(self.max_entries, 0)
        if overflow > 0:
            del self._entries[:overflow]

    def clear(self):
        """Remove every stored entry."""
        self._entries.clear()

138 

139 

@dataclass
class CacheStats:
    """Running counters describing cache usage."""

    total_requests: int = 0
    hits: int = 0
    misses: int = 0
    tokens_saved: int = 0
    cost_saved: float = 0.0
    exact_hits: int = 0
    semantic_hits: int = 0

    @property
    def hit_rate(self) -> float:
        """Return hits divided by total requests, or 0.0 before any request."""
        return self.hits / self.total_requests if self.total_requests else 0.0

156 

157 

class LLMCache:
    """LLM response cache that reduces API call cost.

    Lookup strategy, in order:
      1. Exact-match cache (LRU + TTL), keyed by a hash of prompt, model and
         extra keyword arguments.
      2. Semantic-similarity cache (optional), matched on the prompt text
         only; the model and extra keyword arguments are not considered.
      3. Pass-through: a miss is reported by returning ``None``.

    Note: entries created by ``set`` carry no tags, so tag-based invalidation
    only affects entries whose ``tags`` were populated some other way.
    """

    def __init__(self, lru_size: int = 500, semantic_threshold: float = 0.92, enable_semantic: bool = True):
        """Build the cache layers.

        Args:
            lru_size: Capacity of the exact-match LRU layer.
            semantic_threshold: Minimum similarity for a semantic hit.
            enable_semantic: When False, the semantic layer is disabled.
        """
        self.lru = LRUCache(max_size=lru_size)
        self.semantic = SemanticCache(similarity_threshold=semantic_threshold) if enable_semantic else None
        self.stats = CacheStats()

    @staticmethod
    def _hash_key(prompt: str, model: str = "", **kwargs) -> str:
        """Return a 32-hex-character key for the given request.

        The parts are serialised as one JSON list before hashing so their
        boundaries stay unambiguous; plain concatenation would give
        ("ab", "c") and ("a", "bc") the same key.

        Raises:
            TypeError: If a keyword argument value is not JSON serialisable.
        """
        payload = json.dumps([prompt, model, kwargs], sort_keys=True)
        return hashlib.sha256(payload.encode("utf-8")).hexdigest()[:32]

    def _record_hit(self, entry: CacheEntry) -> Any:
        """Add ``entry``'s savings to the stats and return its cached value."""
        self.stats.hits += 1
        self.stats.tokens_saved += entry.tokens_saved
        self.stats.cost_saved += entry.cost_saved
        return entry.value

    def get(self, prompt: str, model: str = "", **kwargs) -> Optional[Any]:
        """Return the cached value for the request, or ``None`` on a miss.

        Every call counts as one request in ``stats``. The exact-match layer
        is consulted first, then the semantic layer if it is enabled.
        """
        self.stats.total_requests += 1

        # 1. Exact match.
        entry = self.lru.get(self._hash_key(prompt, model, **kwargs))
        if entry is not None:
            self.stats.exact_hits += 1
            return self._record_hit(entry)

        # 2. Semantic match.
        if self.semantic is not None:
            entry = self.semantic.search(prompt)
            if entry is not None:
                self.stats.semantic_hits += 1
                return self._record_hit(entry)

        self.stats.misses += 1
        return None

    def set(self, prompt: str, value: Any, model: str = "", tokens: int = 0, cost: float = 0.0, ttl: float = 3600, **kwargs):
        """Cache ``value`` for the request in every enabled layer.

        Args:
            prompt: Prompt text; also the text embedded by the semantic layer.
            value: Response to cache.
            model: Model identifier, part of the exact-match key.
            tokens: Tokens credited to the stats on each later hit.
            cost: Cost credited to the stats on each later hit.
            ttl: Time-to-live of the entry in seconds.
            **kwargs: Extra request parameters, part of the exact-match key.
        """
        exact_key = self._hash_key(prompt, model, **kwargs)
        entry = CacheEntry(key=exact_key, value=value, tokens_saved=tokens, cost_saved=cost, ttl=ttl)
        self.lru.put(exact_key, entry)

        if self.semantic is not None:
            self.semantic.add(prompt, entry)

    def invalidate(self, key: str = "", tag: str = ""):
        """Remove entries by exact key or by tag from every layer.

        ``key`` is the hashed key as produced by ``_hash_key``. A non-empty
        ``key`` takes precedence over ``tag``. The semantic layer is purged as
        well; otherwise it would keep serving the invalidated entry.
        """
        self.lru.invalidate(key=key or None, tag=tag or None)
        if self.semantic is None or not (key or tag):
            return
        if key:
            kept = [(vec, e) for vec, e in self.semantic._entries if e.key != key]
        else:
            kept = [(vec, e) for vec, e in self.semantic._entries if tag not in e.tags]
        self.semantic._entries = kept

    def clear(self):
        """Empty every cache layer; the statistics are left untouched."""
        self.lru.clear()
        if self.semantic is not None:
            self.semantic.clear()

    def snapshot(self) -> dict:
        """Return a summary of current cache sizes and statistics."""
        return {
            "lru_entries": self.lru.size(),
            "semantic_entries": len(self.semantic._entries) if self.semantic else 0,
            "hit_rate": f"{self.stats.hit_rate:.1%}",
            "tokens_saved": self.stats.tokens_saved,
            "cost_saved": f"${self.stats.cost_saved:.4f}",
            "total_requests": self.stats.total_requests,
            "exact_hits": self.stats.exact_hits,
            "semantic_hits": self.stats.semantic_hits,
        }