Coverage for src / kemi / operations / _query_cache.py: 92%

25 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-06-05 15:47 +0000

1"""LRU cache for `Memory.recall()` results. 

2 

3Kept as a small private class because the cache is tightly coupled to 

4`MemoryObject` semantics (returns shallow copies to prevent cache corruption 

5when callers mutate the returned list). 

6""" 

7 

8from __future__ import annotations 

9 

10from collections import OrderedDict 

11from typing import Any 

12 

13from kemi.models import LifecycleState, MemoryObject 

14 

15 

16class _QueryCache: 

17 """Simple LRU cache for `recall()` query results. 

18 

19 Caches lists of `MemoryObject`s keyed by query parameters. 

20 Returns a *shallow copy* of the cached list so callers can 

21 safely mutate the returned result without corrupting the cache. 

22 """ 

23 

24 def __init__(self, max_size: int = 128) -> None: 

25 self._max_size = max_size 

26 self._cache: OrderedDict[str, list[MemoryObject]] = OrderedDict() 

27 

28 def _make_key( 

29 self, 

30 user_id: str, 

31 query: str, 

32 top_k: int, 

33 max_tokens: int | None, 

34 lifecycle_filter: list[LifecycleState] | None, 

35 hybrid_search: bool | None, 

36 namespace: str, 

37 session_id: str | None, 

38 metadata_filter: dict[str, Any] | None, 

39 ) -> str: 

40 """Build a stable string key from query parameters.""" 

41 lf = tuple(sorted(s.value for s in lifecycle_filter)) if lifecycle_filter else () 

42 mf = tuple(sorted((k, v) for k, v in (metadata_filter or {}).items())) 

43 return "|".join( 

44 [ 

45 user_id, 

46 query, 

47 str(top_k), 

48 str(max_tokens), 

49 str(lf), 

50 str(hybrid_search), 

51 namespace, 

52 str(session_id), 

53 str(mf), 

54 ] 

55 ) 

56 

57 def _copy_memories(self, memories: list[MemoryObject]) -> list[MemoryObject]: 

58 """Return a list of MemoryObject copies with mutable fields duplicated.""" 

59 return [ 

60 MemoryObject( 

61 memory_id=m.memory_id, 

62 user_id=m.user_id, 

63 content=m.content, 

64 embedding=m.embedding, 

65 score=m.score, 

66 created_at=m.created_at, 

67 last_accessed_at=m.last_accessed_at, 

68 source=m.source, 

69 importance=m.importance, 

70 lifecycle_state=m.lifecycle_state, 

71 metadata=m.metadata.copy(), 

72 embedding_dim=m.embedding_dim, 

73 tags=list(m.tags), 

74 confidence=m.confidence, 

75 memory_type=m.memory_type, 

76 session_id=m.session_id, 

77 namespace=m.namespace, 

78 version=m.version, 

79 agent_id=m.agent_id, 

80 run_id=m.run_id, 

81 app_id=m.app_id, 

82 ) 

83 for m in memories 

84 ] 

85 

86 def get(self, key: str) -> list[MemoryObject] | None: 

87 if key in self._cache: 

88 # Move to end (most recently used) 

89 self._cache.move_to_end(key) 

90 # Return copies so callers cannot mutate the cached objects. 

91 return self._copy_memories(self._cache[key]) 

92 return None 

93 

94 def put(self, key: str, value: list[MemoryObject]) -> None: 

95 if key in self._cache: 

96 self._cache.move_to_end(key) 

97 # Store copies so internal mutations (e.g., lifecycle updates on 

98 # the result list returned by recall) don't corrupt the cache. 

99 self._cache[key] = self._copy_memories(value) 

100 while len(self._cache) > self._max_size: 

101 self._cache.popitem(last=False)