Coverage for agentos/cache/llm_cache.py: 35%
154 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
1"""
2AgentOS v0.40 LLM Cache — 语义缓存减少API调用成本。
3支持:精确匹配缓存、语义相似度缓存、LRU淘汰、TTL过期。
4"""
6from __future__ import annotations
8import hashlib
9import json
10import time
11from collections import OrderedDict
12from dataclasses import dataclass, field
13from typing import Optional, Any
@dataclass
class CacheEntry:
    """One cached LLM response together with its bookkeeping metadata.

    Attributes:
        key: Lookup key of the entry (LLMCache derives it by hashing the
            prompt, model and extra keyword arguments).
        value: The cached response payload.
        tokens_saved: Tokens credited to the stats each time the entry is served.
        cost_saved: Cost credited to the stats each time the entry is served.
        created_at: Unix timestamp at which the entry was created.
        ttl: Lifetime in seconds, counted from ``created_at``.
        hit_count: Number of times the entry has been served from cache.
        tags: Labels matched by tag-based invalidation.
    """

    key: str
    value: Any
    tokens_saved: int = 0
    cost_saved: float = 0.0
    created_at: float = field(default_factory=time.time)
    ttl: float = 3600  # one hour by default
    hit_count: int = 0
    tags: list[str] = field(default_factory=list)

    @property
    def expired(self) -> bool:
        """Whether the current time is strictly past ``created_at + ttl``."""
        deadline = self.created_at + self.ttl
        return deadline < time.time()
class LRUCache:
    """In-memory cache with least-recently-used eviction and per-entry TTL.

    Entries live in an ``OrderedDict`` ordered from least to most recently
    used; ``get`` (on a hit) and ``put`` move the touched key to the
    most-recent end, and eviction removes from the least-recent end.
    """

    def __init__(self, max_size: int = 500):
        self._cache: OrderedDict[str, CacheEntry] = OrderedDict()
        # Maximum number of stored entries; a value <= 0 disables caching.
        self.max_size = max_size

    def get(self, key: str) -> Optional[CacheEntry]:
        """Return the live entry for ``key``, or None if missing or expired.

        An expired entry is deleted on access. A hit increments the entry's
        ``hit_count`` and marks it most recently used.
        """
        entry = self._cache.get(key)
        if entry is None:
            return None
        if entry.expired:
            del self._cache[key]
            return None
        self._cache.move_to_end(key)
        entry.hit_count += 1
        return entry

    def put(self, key: str, entry: CacheEntry) -> None:
        """Insert or replace ``key`` and evict LRU entries beyond ``max_size``.

        With ``max_size <= 0`` nothing is stored. The old code crashed with
        KeyError by calling ``popitem`` on an empty dict.
        """
        if self.max_size <= 0:
            return
        self._cache[key] = entry
        self._cache.move_to_end(key)
        # Use a loop rather than a single pop so the bound is restored even if
        # ``max_size`` was lowered after entries were inserted. The new key sits
        # at the most-recent end, so it is never the one evicted here.
        while len(self._cache) > self.max_size:
            self._cache.popitem(last=False)

    def invalidate(self, key: str | None = None, tag: str | None = None) -> None:
        """Remove the entry stored under ``key`` and/or every entry tagged ``tag``.

        Both filters are applied when both are given; empty values are ignored.
        """
        if key:
            self._cache.pop(key, None)
        if tag:
            tagged = [k for k, entry in self._cache.items() if tag in entry.tags]
            for k in tagged:
                del self._cache[k]

    def size(self) -> int:
        """Number of stored entries (expired ones count until they are accessed)."""
        return len(self._cache)

    def clear(self) -> None:
        """Drop every entry."""
        self._cache.clear()
class SemanticCache:
    """Semantic cache: match queries to cached entries by embedding similarity.

    Each stored item is an ``(embedding, CacheEntry)`` pair. A lookup embeds
    the query and compares it with every non-expired stored embedding using
    cosine similarity. It returns the best entry if that similarity reaches
    ``threshold``.
    """

    # Length of the built-in fallback embedding (used when no embedder is injected).
    fallback_dim = 256

    def __init__(self, similarity_threshold: float = 0.92, embedder: Any = None):
        self.threshold = similarity_threshold
        self._entries: list[tuple[list[float], CacheEntry]] = []
        # Externally injected callable mapping text -> list[float]; None means use the fallback.
        self._embedder = embedder
        self.max_entries = 200

    def _embed(self, text: str) -> list[float]:
        """Embed ``text`` with the injected embedder, or a bag-of-words fallback.

        The fallback uses signed feature hashing. Each lower-cased
        whitespace token adds +/-1 to a bucket chosen by a stable hash.
        Cosine similarity between two fallback vectors therefore reflects
        shared vocabulary. The previous positional scheme scored any two
        single-word prompts as identical. ``blake2b`` is used instead of the
        built-in ``hash`` because string ``hash`` is randomised per process.
        """
        if self._embedder:
            return self._embedder(text)
        dim = self.fallback_dim
        vec = [0.0] * dim
        for token in text.lower().split():
            digest = hashlib.blake2b(token.encode("utf-8"), digest_size=8).digest()
            h = int.from_bytes(digest, "big")
            # Low bits select the bucket; the top bit selects the sign, which
            # makes bucket collisions cancel out instead of always adding.
            vec[h % dim] += -1.0 if h >> 63 else 1.0
        return vec

    @staticmethod
    def cosine_sim(a: list[float], b: list[float]) -> float:
        """Cosine similarity of ``a`` and ``b``; 0.0 if either is empty or all-zero.

        The dot product covers only the overlapping prefix (via ``zip``) when
        the lengths differ.
        """
        if not a or not b:
            return 0.0
        dot = sum(x * y for x, y in zip(a, b))
        norm_a = sum(x**2 for x in a) ** 0.5
        norm_b = sum(x**2 for x in b) ** 0.5
        if norm_a == 0 or norm_b == 0:
            return 0.0
        return dot / (norm_a * norm_b)

    def search(self, query: str) -> Optional[CacheEntry]:
        """Return the most similar non-expired entry if it meets the threshold.

        Ties resolve to the most recently added entry (newest data wins).
        A returned entry has its ``hit_count`` incremented.
        """
        query_vec = self._embed(query)
        best_sim = 0.0
        best_entry = None
        for cached_vec, entry in reversed(self._entries):
            if entry.expired:
                continue
            sim = self.cosine_sim(query_vec, cached_vec)
            if sim > best_sim:
                best_sim = sim
                best_entry = entry
        if best_entry is not None and best_sim >= self.threshold:
            best_entry.hit_count += 1
            return best_entry
        return None

    def add(self, query: str, entry: CacheEntry) -> None:
        """Store ``entry`` under the embedding of ``query``.

        Before appending, this drops expired items and any item with the same
        ``entry.key``, so a re-cached prompt cannot resurface its stale value.
        It then trims the oldest items down to ``max_entries``.
        """
        vec = self._embed(query)
        self._entries = [
            (v, e) for v, e in self._entries if not e.expired and e.key != entry.key
        ]
        self._entries.append((vec, entry))
        # Slicing with ``[-0:]`` would keep everything, so compute the excess explicitly.
        excess = len(self._entries) - max(self.max_entries, 0)
        if excess > 0:
            del self._entries[:excess]

    def clear(self) -> None:
        """Drop every stored item."""
        self._entries.clear()
@dataclass
class CacheStats:
    """Running counters describing how the cache has performed."""

    total_requests: int = 0
    hits: int = 0
    misses: int = 0
    tokens_saved: int = 0
    cost_saved: float = 0.0
    exact_hits: int = 0
    semantic_hits: int = 0

    @property
    def hit_rate(self) -> float:
        """Fraction of requests served from cache; 0.0 before any request."""
        total = self.total_requests
        return self.hits / total if total else 0.0
class LLMCache:
    """
    LLM response cache that reduces API call cost.

    Three-tier lookup:
      1. Exact match (LRU + TTL), keyed on a hash of prompt, model and kwargs.
      2. Semantic-similarity match on the prompt text.
      3. Pass-through: a miss, and the caller queries the model.
    """

    def __init__(self, lru_size: int = 500, semantic_threshold: float = 0.92, enable_semantic: bool = True):
        self.lru = LRUCache(max_size=lru_size)
        self.semantic = SemanticCache(similarity_threshold=semantic_threshold) if enable_semantic else None
        self.stats = CacheStats()

    @staticmethod
    def _hash_key(prompt: str, model: str = "", **kwargs) -> str:
        """Return a 32-hex-char SHA-256 prefix identifying (prompt, model, kwargs).

        The parts are serialized as a JSON array so field boundaries are
        unambiguous. Plain concatenation made ("ab", "c") and ("a", "bc")
        collide. ``sort_keys`` makes kwargs order-insensitive; kwargs must be
        JSON-serializable (``json.dumps`` raises TypeError otherwise).
        """
        payload = json.dumps([prompt, model, kwargs], sort_keys=True)
        return hashlib.sha256(payload.encode()).hexdigest()[:32]

    def _record_hit(self, entry: CacheEntry) -> Any:
        """Credit a cache hit's savings to ``self.stats`` and return the cached value."""
        self.stats.hits += 1
        self.stats.tokens_saved += entry.tokens_saved
        self.stats.cost_saved += entry.cost_saved
        return entry.value

    def get(self, prompt: str, model: str = "", **kwargs) -> Optional[Any]:
        """Return a cached response for the request, or None on a miss.

        Every call counts toward ``stats.total_requests``. Exact-match hits are
        tried first, then semantic hits.
        """
        self.stats.total_requests += 1

        # 1. Exact match
        entry = self.lru.get(self._hash_key(prompt, model, **kwargs))
        if entry is not None:
            self.stats.exact_hits += 1
            return self._record_hit(entry)

        # 2. Semantic match.
        # NOTE(review): the similarity search sees only the prompt text, so a
        # response cached for another model/kwargs can be returned here.
        if self.semantic is not None:
            entry = self.semantic.search(prompt)
            if entry is not None:
                self.stats.semantic_hits += 1
                return self._record_hit(entry)

        self.stats.misses += 1
        return None

    def set(
        self,
        prompt: str,
        value: Any,
        model: str = "",
        tokens: int = 0,
        cost: float = 0.0,
        ttl: float = 3600,
        tags: list[str] | None = None,
        **kwargs,
    ):
        """Cache ``value`` for the request in both the exact and semantic layers.

        ``tokens``/``cost`` are credited to the stats on every later hit, and
        ``ttl`` is the entry lifetime in seconds. ``tags`` label the entry for
        ``invalidate(tag=...)``. Remaining kwargs become part of the exact-match key.
        """
        exact_key = self._hash_key(prompt, model, **kwargs)
        entry = CacheEntry(
            key=exact_key,
            value=value,
            tokens_saved=tokens,
            cost_saved=cost,
            ttl=ttl,
            tags=list(tags) if tags else [],
        )
        self.lru.put(exact_key, entry)

        if self.semantic is not None:
            self.semantic.add(prompt, entry)

    def invalidate(self, key: str = "", tag: str = ""):
        """Remove entries by exact ``key`` and/or ``tag`` from BOTH layers.

        Previously only the LRU layer was purged, so invalidated responses were
        still served through semantic matching.
        """
        if key:
            self.lru.invalidate(key=key)
        if tag:
            self.lru.invalidate(tag=tag)
        if self.semantic is not None and (key or tag):
            # SemanticCache has no removal API; filter its (vector, entry) list
            # directly (snapshot() already reads this attribute).
            self.semantic._entries = [
                (vec, entry)
                for vec, entry in self.semantic._entries
                if not ((key and entry.key == key) or (tag and tag in entry.tags))
            ]

    def clear(self):
        """Empty both cache layers (statistics are kept)."""
        self.lru.clear()
        if self.semantic is not None:
            self.semantic.clear()

    def snapshot(self) -> dict:
        """Return a display-oriented summary of cache size and statistics."""
        return {
            "lru_entries": self.lru.size(),
            "semantic_entries": len(self.semantic._entries) if self.semantic else 0,
            "hit_rate": f"{self.stats.hit_rate:.1%}",
            "tokens_saved": self.stats.tokens_saved,
            "cost_saved": f"${self.stats.cost_saved:.4f}",
            "total_requests": self.stats.total_requests,
            "exact_hits": self.stats.exact_hits,
            "semantic_hits": self.stats.semantic_hits,
        }