Coverage for src / kemi / operations / _query_cache.py: 92%
25 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-05 15:47 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-05 15:47 +0000
1"""LRU cache for `Memory.recall()` results.
3Kept as a small private class because the cache is tightly coupled to
4`MemoryObject` semantics (returns shallow copies to prevent cache corruption
5when callers mutate the returned list).
6"""
8from __future__ import annotations
10from collections import OrderedDict
11from typing import Any
13from kemi.models import LifecycleState, MemoryObject
16class _QueryCache:
17 """Simple LRU cache for `recall()` query results.
19 Caches lists of `MemoryObject`s keyed by query parameters.
20 Returns a *shallow copy* of the cached list so callers can
21 safely mutate the returned result without corrupting the cache.
22 """
24 def __init__(self, max_size: int = 128) -> None:
25 self._max_size = max_size
26 self._cache: OrderedDict[str, list[MemoryObject]] = OrderedDict()
28 def _make_key(
29 self,
30 user_id: str,
31 query: str,
32 top_k: int,
33 max_tokens: int | None,
34 lifecycle_filter: list[LifecycleState] | None,
35 hybrid_search: bool | None,
36 namespace: str,
37 session_id: str | None,
38 metadata_filter: dict[str, Any] | None,
39 ) -> str:
40 """Build a stable string key from query parameters."""
41 lf = tuple(sorted(s.value for s in lifecycle_filter)) if lifecycle_filter else ()
42 mf = tuple(sorted((k, v) for k, v in (metadata_filter or {}).items()))
43 return "|".join(
44 [
45 user_id,
46 query,
47 str(top_k),
48 str(max_tokens),
49 str(lf),
50 str(hybrid_search),
51 namespace,
52 str(session_id),
53 str(mf),
54 ]
55 )
57 def _copy_memories(self, memories: list[MemoryObject]) -> list[MemoryObject]:
58 """Return a list of MemoryObject copies with mutable fields duplicated."""
59 return [
60 MemoryObject(
61 memory_id=m.memory_id,
62 user_id=m.user_id,
63 content=m.content,
64 embedding=m.embedding,
65 score=m.score,
66 created_at=m.created_at,
67 last_accessed_at=m.last_accessed_at,
68 source=m.source,
69 importance=m.importance,
70 lifecycle_state=m.lifecycle_state,
71 metadata=m.metadata.copy(),
72 embedding_dim=m.embedding_dim,
73 tags=list(m.tags),
74 confidence=m.confidence,
75 memory_type=m.memory_type,
76 session_id=m.session_id,
77 namespace=m.namespace,
78 version=m.version,
79 agent_id=m.agent_id,
80 run_id=m.run_id,
81 app_id=m.app_id,
82 )
83 for m in memories
84 ]
86 def get(self, key: str) -> list[MemoryObject] | None:
87 if key in self._cache:
88 # Move to end (most recently used)
89 self._cache.move_to_end(key)
90 # Return copies so callers cannot mutate the cached objects.
91 return self._copy_memories(self._cache[key])
92 return None
94 def put(self, key: str, value: list[MemoryObject]) -> None:
95 if key in self._cache:
96 self._cache.move_to_end(key)
97 # Store copies so internal mutations (e.g., lifecycle updates on
98 # the result list returned by recall) don't corrupt the cache.
99 self._cache[key] = self._copy_memories(value)
100 while len(self._cache) > self._max_size:
101 self._cache.popitem(last=False)