Coverage for agentos/tools/serial_cache.py: 0%

197 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-03 06:58 +0800

1""" 

2Serialization & Caching for AgentOS. 

3 

4Serializer — adaptive serializer with JSON/msgpack/pickle auto-detection. 

5TTLCache — thread-safe time-to-live cache with LRU/LFU eviction. 

6SmartCache — compute-on-miss cache wrapping TTL cache with serializer. 

7""" 

8 

9import json 

10import pickle 

11import threading 

12import time 

13from collections import OrderedDict 

14from dataclasses import dataclass, field 

15from enum import Enum 

16from typing import Any, Callable, Dict, Generic, List, Optional, Tuple, TypeVar 

17 

# Type parameter for the values held by the generic cache classes in this module.
T = TypeVar("T")

19 

20 

21# ============================================================================ 

22# Serializer 

23# ============================================================================ 

24 

class SerialFormat(Enum):
    """Wire formats understood by Serializer; AUTO means "sniff the bytes"."""

    JSON = "json"
    PICKLE = "pickle"
    MSGPACK = "msgpack"
    AUTO = "auto"

    def detect(self, data: bytes) -> "SerialFormat":
        """Resolve the concrete format of ``data``.

        A concrete member returns itself without inspecting ``data``.
        ``AUTO`` sniffs the payload, checking in this order:

        1. Pickle: protocols 2-5 start with the PROTO opcode ``0x80``
           followed by the protocol number.
        2. JSON object/array: first byte is ``{`` or ``[``.
        3. msgpack: first byte >= 0x80 and ``msgpack.unpackb`` succeeds.
           This check is skipped silently when msgpack is not installed.
        4. Any other document that decodes as UTF-8 JSON, e.g. the scalars
           ``"abc"``, ``42``, ``true`` or ``null`` that ``Serializer.dumps``
           produces for non-container values.

        Note: a single-byte msgpack positive fixint equal to an ASCII digit
        (0x30-0x39) is indistinguishable from a one-digit JSON number. It is
        reported as JSON.

        Raises:
            ValueError: If ``data`` matches none of the formats above.
        """
        if self is not SerialFormat.AUTO:
            return self

        if data[:2] in (b'\x80\x02', b'\x80\x03', b'\x80\x04', b'\x80\x05'):
            return SerialFormat.PICKLE

        if data[:1] in (b'{', b'['):
            return SerialFormat.JSON

        # msgpack headers (fixmap 0x80-0x8f, fixarray 0x90-0x9f, fixstr,
        # nil/bool, 0xdc-0xdf containers, ...) all have the high bit set.
        if len(data) > 0 and data[0] in range(0x80, 0x100):
            try:
                import msgpack
                msgpack.unpackb(data)
            except Exception:
                # Best effort: msgpack is an optional dependency, and any
                # unpack error simply means "not msgpack".
                pass
            else:
                return SerialFormat.MSGPACK

        # Fallback for JSON that does not start with '{' / '[' (scalars,
        # leading whitespace). UnicodeDecodeError and JSONDecodeError are
        # both ValueError subclasses.
        try:
            json.loads(data.decode('utf-8'))
        except ValueError:
            pass
        else:
            return SerialFormat.JSON

        raise ValueError("Cannot auto-detect serialization format")

47 

48 

class Serializer:
    """Adaptive serializer with format auto-detection.

    Objects are encoded as JSON (the default), pickle, or msgpack. A
    serializer configured with ``AUTO`` encodes as JSON. ``loads`` sniffs
    the format of its input unless one is passed explicitly. Nothing is
    compressed. msgpack is imported lazily, only when that format is used.

    ``stats`` reports running counts of successful encode/decode calls.
    """

    def __init__(self, fmt: SerialFormat = SerialFormat.JSON):
        self._fmt = fmt
        self._total_serialized: int = 0
        self._total_deserialized: int = 0

    def dumps(self, obj: Any, use_msgpack: bool = False) -> bytes:
        """Encode ``obj`` to bytes.

        Args:
            obj: Value to encode. The JSON and msgpack encoders fall back to
                ``str()`` for values they cannot represent natively.
            use_msgpack: Force msgpack regardless of the configured format.

        Raises:
            ValueError: If the configured format is not a supported member.
        """
        if use_msgpack:
            target = SerialFormat.MSGPACK
        elif self._fmt == SerialFormat.AUTO:
            # AUTO only has meaning when decoding; encode as JSON.
            target = SerialFormat.JSON
        else:
            target = self._fmt

        if target == SerialFormat.JSON:
            text = json.dumps(obj, ensure_ascii=False, default=str)
            self._total_serialized += 1
            return text.encode('utf-8')

        if target == SerialFormat.PICKLE:
            payload = pickle.dumps(obj)
        elif target == SerialFormat.MSGPACK:
            import msgpack
            payload = msgpack.packb(obj, default=str)
        else:
            raise ValueError(f"Unsupported format: {target}")

        self._total_serialized += 1
        return payload

    def loads(self, data: bytes, fmt: Optional[SerialFormat] = None) -> Any:
        """Decode ``data``, auto-detecting its format when ``fmt`` is None.

        Security: pickle-looking input is unpickled, so never pass bytes
        from an untrusted source.

        Raises:
            ValueError: If the format cannot be detected or is unsupported.
        """
        resolved = (SerialFormat.AUTO if fmt is None else fmt).detect(data)

        if resolved == SerialFormat.JSON:
            decoded = json.loads(data.decode('utf-8'))
        elif resolved == SerialFormat.PICKLE:
            decoded = pickle.loads(data)
        elif resolved == SerialFormat.MSGPACK:
            import msgpack
            decoded = msgpack.unpackb(data)
        else:
            raise ValueError(f"Unsupported format: {resolved}")

        self._total_deserialized += 1
        return decoded

    @property
    def stats(self) -> Dict[str, Any]:
        """Configured format label plus success counters."""
        label = self._fmt.value if isinstance(self._fmt, SerialFormat) else self._fmt
        return {
            "format": label,
            "total_serialized": self._total_serialized,
            "total_deserialized": self._total_deserialized,
        }

111 

112 

113# ============================================================================ 

114# EvictionPolicy 

115# ============================================================================ 

116 

class EvictionPolicy(Enum):
    """Strategy a TTLCache uses to choose a victim when it is full."""

    LRU = "lru"  # least recently used
    LFU = "lfu"  # least frequently used (fewest hits)
    TTL_ONLY = "ttl_only"  # expiry-driven; overflow drops the front of the cache's order

121 

122 

@dataclass
class _CacheEntry(Generic[T]):
    """A cached value together with its expiry and usage bookkeeping.

    ``last_access`` defaults to a ``time.monotonic()`` reading taken at
    construction.
    """

    value: T
    expires_at: float  # deadline compared against time.monotonic() by TTLCache
    access_count: int = 0  # bumped by TTLCache.get on every hit
    last_access: float = field(default_factory=time.monotonic)

129 

130 

131# ============================================================================ 

132# TTLCache 

133# ============================================================================ 

134 

class TTLCache(Generic[T]):
    """Thread-safe TTL cache with a configurable eviction policy.

    Each entry expires ``ttl`` seconds (or a per-call override) after it was
    last ``set``. Expired entries are dropped lazily: by ``get`` on that key,
    by ``cleanup()``, or, under LFU, preferentially when room is needed.

    When a new key is inserted into a full cache, one entry is evicted
    according to ``policy``:

    * ``LRU``: the least recently set/read entry.
    * ``LFU``: an expired entry if any exists; otherwise the entry with the
      fewest ``get`` hits (ties go to the least recently used one).
    * ``TTL_ONLY``: the least recently *set* entry. Reads do not reorder.

    Every public operation holds an ``RLock``. ``get`` returns ``None`` both
    for a miss and for a stored ``None`` value.
    """

    def __init__(
        self,
        max_size: int = 1000,
        ttl: float = 300.0,
        policy: EvictionPolicy = EvictionPolicy.LRU,
    ):
        self._max_size = max_size
        self._ttl = ttl
        self._policy = policy
        # Iteration order is the eviction order for LRU / TTL_ONLY
        # (front = next victim).
        self._data: OrderedDict[str, _CacheEntry[T]] = OrderedDict()
        self._lock = threading.RLock()
        self._hits: int = 0
        self._misses: int = 0
        self._evictions: int = 0

    def get(self, key: str) -> Optional[T]:
        """Return the live value for ``key``, or ``None`` on a miss.

        An expired entry is removed and counted as both a miss and an
        eviction.
        """
        with self._lock:
            entry = self._data.get(key)
            if entry is None:
                self._misses += 1
                return None

            now = time.monotonic()
            if now > entry.expires_at:
                del self._data[key]
                self._misses += 1
                self._evictions += 1
                return None

            entry.access_count += 1
            entry.last_access = now
            if self._policy != EvictionPolicy.TTL_ONLY:
                # Maintains recency order for LRU (and LFU tie-breaking).
                # TTL_ONLY must keep insertion order so it evicts the
                # oldest-set entry rather than degenerating into LRU.
                self._data.move_to_end(key)
            self._hits += 1
            return entry.value

    def set(self, key: str, value: T, ttl: Optional[float] = None) -> None:
        """Store ``value`` under ``key`` with ``ttl`` (default: the cache TTL).

        Re-setting an existing key replaces it, resets its hit count, and
        never triggers an eviction.
        """
        with self._lock:
            # Removing first means a re-set frees its own slot and the key is
            # re-appended at the end of the order.
            self._data.pop(key, None)

            if len(self._data) >= self._max_size:
                self._evict_one()

            # A newly inserted key is already last in the OrderedDict.
            self._data[key] = _CacheEntry(
                value=value,
                expires_at=time.monotonic() + (ttl if ttl is not None else self._ttl),
            )

    def _evict_one(self) -> None:
        """Evict one entry according to the policy. Caller must hold the lock."""
        if not self._data:
            return

        if self._policy in (EvictionPolicy.LRU, EvictionPolicy.TTL_ONLY):
            # The front is the least recently used (LRU) or least recently
            # set (TTL_ONLY) entry.
            self._data.popitem(last=False)
        elif self._policy == EvictionPolicy.LFU:
            data = self._data
            now = time.monotonic()
            # Expired entries sort first (False < True), so stale entries
            # with many historic hits cannot pin live ones out of the cache.
            # min() keeps the first minimum, i.e. the least recently used.
            victim_key = min(
                data,
                key=lambda k: (now <= data[k].expires_at, data[k].access_count),
            )
            del data[victim_key]
        else:
            return

        self._evictions += 1

    def delete(self, key: str) -> bool:
        """Remove ``key``; return True if it was present (expired or not)."""
        with self._lock:
            if key in self._data:
                del self._data[key]
                return True
            return False

    def clear(self) -> None:
        """Drop all entries. Hit/miss/eviction counters are kept."""
        with self._lock:
            self._data.clear()

    def cleanup(self) -> int:
        """Remove all expired entries. Returns count removed."""
        with self._lock:
            now = time.monotonic()
            expired = [k for k, v in self._data.items() if now > v.expires_at]
            for k in expired:
                del self._data[k]
            self._evictions += len(expired)
            return len(expired)

    @property
    def size(self) -> int:
        """Number of stored entries, including expired ones not yet purged."""
        with self._lock:
            return len(self._data)

    @property
    def stats(self) -> Dict[str, Any]:
        """Snapshot of configuration and hit/miss/eviction counters."""
        with self._lock:
            return {
                "size": len(self._data),
                "max_size": self._max_size,
                "ttl": self._ttl,
                "policy": self._policy.value,
                "hits": self._hits,
                "misses": self._misses,
                "evictions": self._evictions,
                "hit_rate": round(self._hits / max(1, self._hits + self._misses), 3),
            }

253 

254 

255# ============================================================================ 

256# SmartCache 

257# ============================================================================ 

258 

class SmartCache(Generic[T]):
    """Compute-on-miss cache combining TTLCache with Serializer.

    ``get_or_compute()`` calls ``factory`` only when ``key`` has no live
    entry, then caches the result, including a ``None`` result. ``dump()``
    and ``load()`` snapshot and restore the cache contents using a
    default-configured Serializer (JSON), e.g. for persistence.
    """

    def __init__(
        self,
        max_size: int = 1000,
        ttl: float = 300.0,
        policy: EvictionPolicy = EvictionPolicy.LRU,
    ):
        self._cache = TTLCache[T](max_size=max_size, ttl=ttl, policy=policy)
        self._serializer = Serializer()

    def get(self, key: str) -> Optional[T]:
        """Return the cached value, or None on a miss (same as a stored None)."""
        return self._cache.get(key)

    def get_or_compute(self, key: str, factory: Callable[[], T], ttl: Optional[float] = None) -> T:
        """Get from cache or compute via factory and cache the result.

        ``factory`` runs outside the cache lock, so concurrent callers may
        compute the same key more than once.
        """
        cache = self._cache
        with cache._lock:
            value = cache.get(key)
            # TTLCache.get returns None for a miss, for an expired entry
            # (which it deletes), and for a stored None. Only in the last
            # case is the key still present.
            if value is not None or key in cache._data:
                return value
        value = factory()
        cache.set(key, value, ttl=ttl)
        return value

    def set(self, key: str, value: T, ttl: Optional[float] = None) -> None:
        """Store ``value`` under ``key`` (``ttl`` defaults to the cache TTL)."""
        self._cache.set(key, value, ttl=ttl)

    def delete(self, key: str) -> bool:
        """Remove ``key``; return True if it was present."""
        return self._cache.delete(key)

    def clear(self) -> None:
        """Drop all cached entries."""
        self._cache.clear()

    def dump(self) -> bytes:
        """Serialize the entire cache state, including not-yet-purged expired entries.

        Each record stores its expiry twice:

        * ``expires_at``: the in-process monotonic deadline (legacy field).
        * ``expires_at_wall``: the equivalent ``time.time()`` deadline,
          which ``load`` prefers because monotonic readings are not
          comparable across processes or reboots.

        Values the JSON encoder cannot represent are stored via ``str()``.
        """
        with self._cache._lock:
            mono_now = time.monotonic()
            wall_now = time.time()
            entries = {
                k: {
                    "value": v.value,
                    "expires_at": v.expires_at,
                    "expires_at_wall": wall_now + (v.expires_at - mono_now),
                    "access_count": v.access_count,
                    "last_access": v.last_access,
                }
                for k, v in self._cache._data.items()
            }
        return self._serializer.dumps(entries)

    def load(self, data: bytes) -> int:
        """Restore cache from serialized data. Returns number of entries loaded.

        Entries whose expiry has passed are skipped. Restored entries are
        inserted under the cache lock and respect ``max_size``, evicting per
        the cache policy when full. If a snapshot is larger than
        ``max_size``, some counted entries are evicted again during the same
        load. Snapshots without ``expires_at_wall`` (older dumps) fall back
        to their monotonic ``expires_at``.
        """
        # NOTE(security): Serializer.loads auto-detects the format and will
        # unpickle pickle-looking input. Only load snapshots from trusted
        # sources.
        entries = self._serializer.loads(data)
        cache = self._cache
        count = 0
        with cache._lock:
            mono_now = time.monotonic()
            wall_now = time.time()
            for k, v in entries.items():
                wall_expiry = v.get("expires_at_wall")
                if wall_expiry is not None:
                    # Re-base the wall-clock deadline onto this process's
                    # monotonic clock.
                    expires_at = mono_now + (wall_expiry - wall_now)
                else:
                    expires_at = v["expires_at"]
                if expires_at <= mono_now:
                    continue

                # Same slot handling as TTLCache.set: replace, then make room.
                cache._data.pop(k, None)
                if len(cache._data) >= cache._max_size:
                    cache._evict_one()
                cache._data[k] = _CacheEntry(
                    value=v["value"],
                    expires_at=expires_at,
                    access_count=v["access_count"],
                    last_access=v["last_access"],
                )
                count += 1
        return count

    @property
    def size(self) -> int:
        """Number of stored entries (including expired ones not yet purged)."""
        return self._cache.size

    @property
    def stats(self) -> Dict[str, Any]:
        """Statistics of the underlying TTLCache."""
        return self._cache.stats