Coverage for agentos/tools/memory_optimizer.py: 97%
213 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 20:14 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 20:14 +0800
1"""
2Memory Optimization Tools for AgentOS.
3Object pooling, LRU caching, memory monitoring, and smart caching with TTL.
4"""
6import threading
7import time
8import weakref
9from collections import OrderedDict
10from dataclasses import dataclass, field
11from typing import Any, Callable, Dict, Generic, List, Optional, TypeVar
# Type variable for the generic containers in this module (ObjectPool,
# LRUCache, SmartCache) and their factory helpers.
T = TypeVar("T")
16# ============================================================================
17# ObjectPool
18# ============================================================================
class ObjectPool(Generic[T]):
    """Thread-safe object pool with idle-timeout expiry and size limits.

    Reuses previously released objects instead of creating/destroying them
    repeatedly. ``acquire`` hands out the most recently released object first
    and falls back to ``factory`` when no idle object is available.

    Args:
        factory: Zero-argument callable that builds a new object. It is called
            while the pool's non-reentrant lock is held, so it must not call
            ``acquire``/``release``/``stats`` on this pool.
        max_size: Soft cap on pool-owned objects. ``acquire`` never blocks or
            fails: once ``max_size`` objects have been counted as created,
            further objects are still built on demand but treated as
            temporary and not counted in ``stats["created"]``.
        max_idle: Maximum number of released objects kept for reuse; objects
            released while the pool already holds this many are discarded.
        idle_timeout: Seconds a released object may stay idle before it is
            dropped from the pool.
    """

    def __init__(
        self,
        factory: Callable[[], T],
        max_size: int = 100,
        max_idle: int = 30,
        idle_timeout: float = 300.0,
    ):
        self._factory = factory
        self._max_size = max_size
        self._max_idle = max_idle
        self._idle_timeout = idle_timeout
        # Idle objects in release order (oldest first); acquire() pops the end.
        self._pool: List[_PooledItem[T]] = []
        self._lock = threading.Lock()
        self._created: int = 0
        self._borrowed: int = 0
        self._returned: int = 0

    def acquire(self) -> T:
        """Borrow an object from the pool, creating a new one if none is idle.

        Returns:
            The most recently released idle object, or a freshly built one.

        Raises:
            Whatever ``factory`` raises; the pool counters are left unchanged
            in that case.
        """
        with self._lock:
            self._evict_expired(time.monotonic())

            if self._pool:
                item = self._pool.pop()  # LIFO: reuse the freshest idle object
                item.idle = False
                self._borrowed += 1
                return item.obj

            # Build the object before touching any counter so that a failing
            # factory does not leave a phantom "created"/"active" object behind
            # (which would also permanently consume a max_size slot).
            obj = self._factory()
            if self._created < self._max_size:
                self._created += 1
            # else: pool fully allocated; this is a temporary object that is
            # not counted in ``created``.
            self._borrowed += 1
            return obj

    def release(self, obj: T) -> None:
        """Return *obj* to the pool for reuse.

        The object is kept only while fewer than ``max_idle`` objects are idle;
        otherwise it is discarded.
        """
        with self._lock:
            now = time.monotonic()
            self._returned += 1
            self._evict_expired(now)

            if len(self._pool) < self._max_idle:
                # ``acquired_at`` holds the moment the object became idle; the
                # idle timeout is measured from it.
                self._pool.append(_PooledItem(obj=obj, idle=True, acquired_at=now))
            # else: discard the excess object.

    def _evict_expired(self, now: float) -> None:
        """Drop idle objects that have been idle for ``idle_timeout`` seconds or more.

        The caller must already hold ``self._lock``.
        """
        self._pool[:] = [item for item in self._pool if now - item.acquired_at < self._idle_timeout]

    @property
    def stats(self) -> Dict[str, Any]:
        """Snapshot of the pool counters.

        ``active`` is ``borrowed - returned``; it assumes every released object
        was previously acquired from this pool.
        """
        with self._lock:
            return {
                "created": self._created,
                "borrowed": self._borrowed,
                "returned": self._returned,
                "idle": len(self._pool),
                "active": self._borrowed - self._returned,
            }

    def __len__(self) -> int:
        """Number of idle objects currently held (read without taking the lock)."""
        return len(self._pool)
@dataclass
class _PooledItem(Generic[T]):
    """Bookkeeping wrapper for an idle object held by ObjectPool."""

    obj: T  # the pooled object handed back by ObjectPool.acquire()
    idle: bool  # True when stored by release(); set to False when acquire() pops it
    # Despite its name, this is the time.monotonic() value taken when the object
    # was released into the pool; ObjectPool's idle timeout is measured from it.
    acquired_at: float
99# ============================================================================
100# LRUCache
101# ============================================================================
class LRUCache(Generic[T]):
    """Thread-safe LRU cache with a capacity limit and optional TTL.

    Args:
        capacity: Maximum number of entries. A capacity of 0 or less stores
            nothing (every ``put`` is a no-op), like
            ``functools.lru_cache(maxsize=0)``.
        ttl: Seconds after a ``put`` at which the entry is treated as absent.
            A falsy value (``None`` or ``0``) disables expiry.

    Expired entries are purged lazily by ``get``. Until then they still count
    towards ``len()`` and towards the capacity checked by ``put``.
    """

    def __init__(self, capacity: int = 1024, ttl: Optional[float] = None):
        self._capacity = capacity
        self._ttl = ttl
        # Recency order: least recently used first, most recently used last.
        self._cache: OrderedDict[str, _CacheEntry[T]] = OrderedDict()
        self._lock = threading.Lock()
        self._hits: int = 0
        self._misses: int = 0
        self._evictions: int = 0

    def _is_expired(self, entry: "_CacheEntry[T]", now: float) -> bool:
        """Return True if *entry* has outlived the TTL (never, when the TTL is falsy)."""
        return bool(self._ttl) and now - entry.timestamp > self._ttl

    def get(self, key: str) -> Optional[T]:
        """Return the value for *key* and mark it most recently used.

        Returns ``None`` and records a miss when the key is absent or expired.
        An expired entry is also removed and counted as an eviction.
        """
        with self._lock:
            entry = self._cache.get(key)
            if entry is None:
                self._misses += 1
                return None
            if self._is_expired(entry, time.monotonic()):
                del self._cache[key]
                self._misses += 1
                self._evictions += 1
                return None
            self._cache.move_to_end(key)
            self._hits += 1
            return entry.value

    def put(self, key: str, value: T) -> None:
        """Insert or overwrite *key*, evicting the least recently used entry when full."""
        with self._lock:
            if self._capacity <= 0:
                # Nothing can be stored; without this guard popitem() below
                # would raise KeyError on the empty dict.
                return
            if key in self._cache:
                self._cache.move_to_end(key)
            elif len(self._cache) >= self._capacity:
                self._cache.popitem(last=False)
                self._evictions += 1
            # Overwriting an existing key keeps its (now most recent) position.
            self._cache[key] = _CacheEntry(value=value, timestamp=time.monotonic())

    def remove(self, key: str) -> bool:
        """Delete *key*; return True if it was present (expired or not)."""
        with self._lock:
            if key in self._cache:
                del self._cache[key]
                return True
            return False

    def clear(self) -> None:
        """Drop all entries; hit/miss/eviction counters are kept."""
        with self._lock:
            self._cache.clear()

    @property
    def hit_rate(self) -> float:
        """Fraction of ``get`` calls that were hits, or 0.0 before any lookup."""
        # Read without the lock: ``stats`` calls this while already holding the
        # non-reentrant lock, so locking here would deadlock.
        total = self._hits + self._misses
        return self._hits / total if total > 0 else 0.0

    @property
    def stats(self) -> Dict[str, Any]:
        """Snapshot of size, capacity and hit/miss/eviction counters."""
        with self._lock:
            return {
                "size": len(self._cache),
                "capacity": self._capacity,
                "hits": self._hits,
                "misses": self._misses,
                "hit_rate": round(self.hit_rate, 4),
                "evictions": self._evictions,
            }

    def __len__(self) -> int:
        """Number of stored entries, including expired ones not yet purged."""
        return len(self._cache)

    def __contains__(self, key: str) -> bool:
        """True if *key* holds an unexpired entry, i.e. ``get`` would find it.

        Unlike ``get``, this does not update recency or counters and does not
        remove expired entries.
        """
        with self._lock:
            entry = self._cache.get(key)
            return entry is not None and not self._is_expired(entry, time.monotonic())
@dataclass
class _CacheEntry(Generic[T]):
    """A value stored in LRUCache together with the time it was written."""

    value: T  # the cached value returned by LRUCache.get()
    timestamp: float  # time.monotonic() at put(); compared against the cache TTL
182# ============================================================================
183# SmartCache
184# ============================================================================
class SmartCache(Generic[T]):
    """Compute-on-miss cache backed by an ``LRUCache`` with a TTL.

    ``get`` returns the cached value when one is present. Otherwise it calls
    ``compute(key)`` and stores the result.

    Notes:
        * Computation runs under a single lock, so misses are computed one at
          a time, even for different keys.
        * A computed value of ``None`` is indistinguishable from a miss and is
          recomputed on every ``get``.
    """

    def __init__(self, compute: Callable[[str], T], capacity: int = 1024, ttl: float = 300.0):
        self._lru = LRUCache[T](capacity=capacity, ttl=ttl)
        self._compute = compute
        self._lock = threading.Lock()

    def get(self, key: str) -> T:
        """Return the cached value for *key*, computing and caching it on a miss.

        Raises:
            Whatever ``compute`` raises; nothing is cached in that case.
        """
        cached = self._lru.get(key)
        if cached is not None:
            return cached
        with self._lock:
            # Double-check: another thread may have filled the key while we
            # waited for the lock. Probe with ``in`` first so that a key which
            # is still absent does not go through LRUCache.get again and record
            # a second miss for this single lookup.
            if key in self._lru:
                cached = self._lru.get(key)
                if cached is not None:
                    return cached
            value = self._compute(key)
            self._lru.put(key, value)
            return value

    def prefetch(self, keys: List[str]) -> int:
        """Compute and cache each key that is not already present.

        Returns the number of keys for which ``get`` succeeded. A key whose
        ``compute`` raises is skipped and not counted.
        """
        count = 0
        for key in keys:
            if key in self._lru:
                continue
            try:
                self.get(key)
            except Exception:
                # Deliberately best-effort: one failing key must not abort the batch.
                continue
            count += 1
        return count

    def invalidate(self, key: str) -> bool:
        """Remove *key*; return True if it was cached."""
        return self._lru.remove(key)

    def invalidate_pattern(self, pattern: str) -> int:
        """Invalidate all keys containing the substring *pattern*; return the count removed."""
        # Snapshot keys under the LRU's own lock so concurrent put/get cannot
        # mutate the dict mid-iteration. remove() re-acquires that
        # (non-reentrant) lock, so the removals happen after it is released.
        with self._lru._lock:
            keys = list(self._lru._cache)
        count = 0
        for key in keys:
            if pattern in key and self._lru.remove(key):
                count += 1
        return count

    def clear(self) -> None:
        """Drop every cached value."""
        self._lru.clear()

    @property
    def stats(self) -> Dict[str, Any]:
        """Counters of the underlying LRU cache."""
        return self._lru.stats

    def __len__(self) -> int:
        return len(self._lru)
243# ============================================================================
244# MemoryMonitor
245# ============================================================================
class MemoryMonitor:
    """Process-wide singleton tracking caller-reported memory usage per component.

    Sizes are whatever callers pass to ``record_alloc``/``record_free``;
    nothing is measured here. Calls for a name that was never registered with
    ``register`` are silently ignored.
    """

    _singleton = None
    # Class-level lock guarding singleton creation and one-time initialisation.
    # Each instance shadows it with its own ``_lock`` for the metrics dict.
    _lock = threading.Lock()

    def __new__(cls):
        # Double-checked locking: the unlocked test keeps the common path cheap.
        if cls._singleton is None:
            with cls._lock:
                if cls._singleton is None:
                    cls._singleton = super().__new__(cls)
                    cls._singleton._initialized = False
        return cls._singleton

    def __init__(self):
        # __init__ runs on every MemoryMonitor() call; only the first may set up state.
        if self._initialized:
            return
        # Re-check under the class lock. Without it, two threads constructing
        # the monitor concurrently could both initialise it: the second would
        # wipe components the first had already registered and swap out the
        # instance lock while it was in use.
        with type(self)._lock:
            if self._initialized:
                return
            self._components: Dict[str, _ComponentMetrics] = {}
            self._lock = threading.Lock()
            # Set last so the unlocked check above never sees a half-built monitor.
            self._initialized = True

    def register(self, name: str) -> None:
        """Start tracking component *name*; no-op if it is already registered."""
        with self._lock:
            if name not in self._components:
                self._components[name] = _ComponentMetrics(name=name)

    def record_alloc(self, name: str, size_bytes: int) -> None:
        """Add *size_bytes* to a registered component and update its peak."""
        with self._lock:
            comp = self._components.get(name)
            if comp:
                comp.current_bytes += size_bytes
                comp.total_allocations += 1
                comp.peak_bytes = max(comp.peak_bytes, comp.current_bytes)

    def record_free(self, name: str, size_bytes: int) -> None:
        """Subtract *size_bytes* from a registered component, clamping at zero."""
        with self._lock:
            comp = self._components.get(name)
            if comp:
                comp.current_bytes = max(0, comp.current_bytes - size_bytes)

    def snapshot(self) -> Dict[str, Dict[str, Any]]:
        """Return a fresh dict of per-component counters, keyed by component name."""
        with self._lock:
            return {name: comp.to_dict() for name, comp in self._components.items()}

    def alert(self, name: str, threshold_bytes: int) -> bool:
        """Return True if component *name* currently uses more than *threshold_bytes*.

        Unregistered components never alert.
        """
        with self._lock:
            comp = self._components.get(name)
            if comp:
                return comp.current_bytes > threshold_bytes
            return False

    @property
    def total_current(self) -> int:
        """Sum of ``current_bytes`` across all registered components."""
        with self._lock:
            return sum(c.current_bytes for c in self._components.values())
305@dataclass
306class _ComponentMetrics:
307 name: str
308 current_bytes: int = 0
309 peak_bytes: int = 0
310 total_allocations: int = 0
312 def to_dict(self) -> Dict[str, Any]:
313 return {
314 "name": self.name,
315 "current_bytes": self.current_bytes,
316 "peak_bytes": self.peak_bytes,
317 "total_allocations": self.total_allocations,
318 }
321# ============================================================================
322# Convenience Functions
323# ============================================================================
def create_object_pool(
    factory: Callable[[], T],
    max_size: int = 100,
    max_idle: int = 30,
    idle_timeout: float = 300.0,
) -> ObjectPool[T]:
    """Build a thread-safe ObjectPool around *factory* with the given limits."""
    return ObjectPool(
        factory=factory,
        idle_timeout=idle_timeout,
        max_idle=max_idle,
        max_size=max_size,
    )
def create_lru_cache(capacity: int = 1024, ttl: Optional[float] = None) -> LRUCache[Any]:
    """Build a thread-safe LRUCache holding up to *capacity* entries, with optional TTL."""
    cache: LRUCache[Any] = LRUCache(ttl=ttl, capacity=capacity)
    return cache
def create_smart_cache(
    compute: Callable[[str], T],
    capacity: int = 1024,
    ttl: float = 300.0,
) -> SmartCache[T]:
    """Build a SmartCache that calls *compute(key)* on each cache miss."""
    return SmartCache(compute=compute, ttl=ttl, capacity=capacity)
def get_memory_monitor() -> MemoryMonitor:
    """Return the process-wide MemoryMonitor instance."""
    # MemoryMonitor.__new__ always hands back the same singleton object.
    monitor = MemoryMonitor()
    return monitor