Coverage for src / kemi / adapters / storage / json.py: 80%
134 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-05 15:47 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-05 15:47 +0000
1from __future__ import annotations
3import json
4from datetime import datetime
5from pathlib import Path
6from typing import TYPE_CHECKING, Any
8from kemi import scoring
9from kemi.adapters.base import StorageAdapter
10from kemi.models import LifecycleState, MemoryObject, MemorySource, MemoryType
12if TYPE_CHECKING:
13 from kemi.encryption import EncryptionConfig
class JSONStorageAdapter(StorageAdapter):
    """JSON file storage adapter.

    The whole store is kept in memory as a single dict and rewritten to
    ``path`` after every mutating call.

    Thread safety: NOT guaranteed. Do not use from multiple threads.
    Embedding stored as list of floats in JSON for readability.

    Encryption: When encryption config is provided, uses Fernet field-level
    encryption for content and metadata fields. Pass ``encryption=`` to
    ``__init__`` or set ``KEMI_ENCRYPTION_KEY`` / ``KEMI_ENCRYPTION_ENABLED``
    environment variables.
    """

    # Lifecycle states used when a caller passes ``lifecycle_filter=None``.
    _DEFAULT_STATES = (LifecycleState.ACTIVE, LifecycleState.DECAYING)

    def __init__(self, path: str = "kemi.json", encryption: "EncryptionConfig | None" = None):
        """Open (or lazily create) the JSON store at ``path``.

        Args:
            path: Location of the JSON file. It is only created on first write.
            encryption: Explicit encryption settings. When ``None`` the
                settings are read from the environment instead.
        """
        self._path = Path(path)
        self._data = self._load()
        # Lazy import to avoid circular dependency
        from kemi.encryption import EncryptionConfig, FieldEncryptor

        self._field_encryptor: Any = None

        if encryption is not None:
            if encryption.enabled:
                self._field_encryptor = FieldEncryptor(encryption)
            return

        try:
            env_config = EncryptionConfig.from_env()
            if env_config.enabled:
                self._field_encryptor = FieldEncryptor(env_config)
        except Exception:
            # Deliberately best-effort: a broken environment must not stop the
            # adapter from starting. It must not be silent either, because the
            # consequence is that data gets written in plaintext.
            import logging

            logging.getLogger(__name__).warning(
                "Could not set up encryption from environment; "
                "memories will be stored unencrypted.",
                exc_info=True,
            )
            self._field_encryptor = None

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def _load(self) -> dict[str, Any]:
        """Read the store from disk, or return a fresh empty store.

        A missing or empty file yields a fresh store. Missing top-level keys
        are filled in so the rest of the adapter can rely on them.

        Raises:
            ValueError: If the file holds JSON that is not an object.
            json.JSONDecodeError: If the file is not valid JSON.
        """
        try:
            text = self._path.read_text(encoding="utf-8")
        except FileNotFoundError:
            text = ""

        if not text.strip():
            return {"memories": {}, "schema_version": 1}

        data = json.loads(text)
        if not isinstance(data, dict):
            raise ValueError(f"{self._path} does not contain a JSON object")
        data.setdefault("memories", {})
        data.setdefault("schema_version", 1)
        return data

    def _save(self) -> None:
        """Write the in-memory store to disk without risking the old file.

        The data is serialised *before* any file is opened, so a value that
        is not JSON-serialisable cannot truncate the existing store. The
        payload is then written to a sibling temp file and renamed over the
        target, so an interrupted write leaves the previous file intact.
        """
        payload = json.dumps(self._data, indent=2)
        tmp_path = self._path.with_name(f"{self._path.name}.tmp")
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                f.write(payload)
            tmp_path.replace(self._path)
        except OSError:
            if tmp_path.exists():
                tmp_path.unlink()
            raise

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _decrypt(self, field: str, value: Any) -> Any:
        """Return ``value`` decrypted if an encryptor is set and it is encrypted."""
        encryptor = self._field_encryptor
        if encryptor is not None and encryptor._is_encrypted(value):
            return encryptor.decrypt_field(field, value)
        return value

    def _owner_of(self, record: dict[str, Any]) -> Any:
        """Return the plaintext ``user_id`` of a stored record.

        ``store`` may write ``user_id`` encrypted, so the raw stored value
        cannot be compared with a caller-supplied id directly.
        """
        return self._decrypt("user_id", record.get("user_id", ""))

    def _select(
        self,
        user_id: str,
        lifecycle_filter: list[LifecycleState] | None,
        namespace: str,
        session_id: str | None = None,
    ) -> list[dict[str, Any]]:
        """Return the raw stored records matching the common query filters.

        A record matches when it belongs to ``user_id``, its lifecycle state
        is in ``lifecycle_filter`` (ACTIVE and DECAYING when ``None``), its
        namespace equals ``namespace`` and, if ``session_id`` is given, its
        session is either that session or unset.
        """
        if lifecycle_filter is None:
            lifecycle_filter = list(self._DEFAULT_STATES)
        states = {state.value for state in lifecycle_filter}

        return [
            record
            for record in self._data["memories"].values()
            if record["lifecycle_state"] in states
            and record.get("namespace", "default") == namespace
            and (session_id is None or record.get("session_id") in (session_id, None))
            # Checked last: this is the only clause that may need decryption.
            and self._owner_of(record) == user_id
        ]

    @staticmethod
    def _paginate(
        items: list[MemoryObject],
        limit: int | None,
        offset: int | None,
    ) -> list[MemoryObject]:
        """Apply ``offset`` then ``limit`` to ``items``; ``None`` means no bound."""
        if offset is not None:
            items = items[offset:]
        if limit is not None:
            items = items[:limit]
        return items

    def _row_to_memory(self, data: dict[str, Any]) -> MemoryObject:
        """Build a ``MemoryObject`` from one stored record.

        Encrypted fields are decrypted when an encryptor is configured. List
        and dict fields are shallow-copied so that mutating the returned
        object cannot silently change the in-memory store.
        """
        content_val = self._decrypt("content", data.get("content", ""))
        metadata_val = self._decrypt("metadata", data.get("metadata", {}))
        user_id_val = self._decrypt("user_id", data.get("user_id", ""))

        embedding = data.get("embedding")
        expires_at_raw = data.get("expires_at")

        return MemoryObject(
            memory_id=data["memory_id"],
            user_id=user_id_val,
            content=str(content_val),
            embedding=list(embedding) if embedding is not None else None,
            score=0.0,
            created_at=datetime.fromisoformat(data["created_at"]),
            last_accessed_at=datetime.fromisoformat(data["last_accessed_at"]),
            source=MemorySource(data["source"]),
            importance=data["importance"],
            lifecycle_state=LifecycleState(data["lifecycle_state"]),
            metadata=dict(metadata_val) if isinstance(metadata_val, dict) else {},
            embedding_dim=data.get("embedding_dim"),
            tags=list(data.get("tags") or []),
            confidence=data.get("confidence", 1.0),
            memory_type=MemoryType(data.get("memory_type", "episodic")),
            session_id=data.get("session_id"),
            namespace=data.get("namespace", "default"),
            version=data.get("version", 1),
            agent_id=data.get("agent_id"),
            run_id=data.get("run_id"),
            app_id=data.get("app_id"),
            expires_at=datetime.fromisoformat(expires_at_raw) if expires_at_raw else None,
        )

    # ------------------------------------------------------------------
    # StorageAdapter API
    # ------------------------------------------------------------------

    def store(self, memory: MemoryObject) -> None:
        """Insert or overwrite ``memory`` (keyed by ``memory_id``) and persist.

        Content and metadata are encrypted when an encryptor is configured;
        ``user_id`` is encrypted only if the encryptor is set up to do so.
        """
        content_val: Any = memory.content
        metadata_val: Any = memory.metadata
        user_id_val: Any = memory.user_id

        encryptor = self._field_encryptor
        if encryptor is not None:
            content_val = encryptor.encrypt_field("content", memory.content)
            metadata_val = encryptor.encrypt_field("metadata", memory.metadata)
            if encryptor._encrypt_user_id:
                user_id_val = encryptor.encrypt_field("user_id", memory.user_id)
        elif isinstance(metadata_val, dict):
            # Shallow copy so later edits to ``memory.metadata`` don't leak
            # into the store without an explicit ``update``.
            metadata_val = dict(metadata_val)

        self._data["memories"][memory.memory_id] = {
            "memory_id": memory.memory_id,
            "user_id": user_id_val,
            "content": content_val,
            "embedding": list(memory.embedding) if memory.embedding is not None else None,
            "created_at": memory.created_at.isoformat(),
            "last_accessed_at": memory.last_accessed_at.isoformat(),
            "source": memory.source.value,
            "importance": memory.importance,
            "lifecycle_state": memory.lifecycle_state.value,
            "metadata": metadata_val,
            "embedding_dim": memory.embedding_dim,
            "tags": list(memory.tags) if memory.tags is not None else None,
            "confidence": memory.confidence,
            "memory_type": memory.memory_type.value,
            "session_id": memory.session_id,
            "namespace": memory.namespace,
            "version": memory.version,
            "agent_id": memory.agent_id,
            "run_id": memory.run_id,
            "app_id": memory.app_id,
            "expires_at": memory.expires_at.isoformat() if memory.expires_at else None,
        }
        self._save()

    def search(
        self,
        user_id: str,
        query_embedding: list[float],
        top_k: int = 10,
        lifecycle_filter: list[LifecycleState] | None = None,
        namespace: str = "default",
        session_id: str | None = None,
    ) -> list[MemoryObject]:
        """Return the ``top_k`` memories most similar to ``query_embedding``.

        Each result's ``score`` is the cosine similarity rescaled from
        ``[-1, 1]`` to ``[0, 1]``. Results are sorted best-first. A
        non-positive ``top_k`` yields an empty list.
        """
        if top_k <= 0:
            return []

        ranked: list[MemoryObject] = []
        for record in self._select(user_id, lifecycle_filter, namespace, session_id):
            memory = self._row_to_memory(record)
            # NOTE(review): memories without an embedding cannot be ranked and
            # are skipped here - confirm this matches the intended contract.
            if memory.embedding is None:
                continue
            similarity = scoring.cosine_similarity(memory.embedding, query_embedding)
            memory.score = (similarity + 1.0) / 2.0
            ranked.append(memory)

        ranked.sort(key=lambda m: m.score, reverse=True)
        return ranked[:top_k]

    def get(self, memory_id: str) -> MemoryObject | None:
        """Return the memory stored under ``memory_id``, or ``None`` if absent."""
        record = self._data["memories"].get(memory_id)
        if not record:
            return None
        return self._row_to_memory(record)

    def update(self, memory: MemoryObject) -> None:
        """Overwrite the stored copy of ``memory``; identical to ``store``."""
        self.store(memory)

    def delete_by_user(self, user_id: str) -> int:
        """Delete every memory owned by ``user_id`` and return how many were removed.

        The file is rewritten only when at least one memory was deleted.
        """
        memories = self._data["memories"]
        doomed = [mid for mid, record in memories.items() if self._owner_of(record) == user_id]
        for mid in doomed:
            del memories[mid]
        if doomed:
            self._save()
        return len(doomed)

    def delete_by_id(self, memory_id: str) -> bool:
        """Delete one memory; return ``True`` if it existed, ``False`` otherwise."""
        if memory_id not in self._data["memories"]:
            return False
        del self._data["memories"][memory_id]
        self._save()
        return True

    def get_all_by_user(
        self,
        user_id: str,
        lifecycle_filter: list[LifecycleState] | None = None,
        namespace: str = "default",
        session_id: str | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> list[MemoryObject]:
        """Return ``user_id``'s memories matching the filters, in storage order.

        ``offset`` is applied before ``limit``; either may be ``None``.
        """
        matches = [
            self._row_to_memory(record)
            for record in self._select(user_id, lifecycle_filter, namespace, session_id)
        ]
        return self._paginate(matches, limit, offset)

    def count(self, user_id: str) -> int:
        """Count all memories owned by ``user_id`` (any state, any namespace)."""
        return sum(
            1 for record in self._data["memories"].values() if self._owner_of(record) == user_id
        )

    def get_all(
        self,
        limit: int | None = None,
        offset: int | None = None,
    ) -> list[MemoryObject]:
        """Return every stored memory in storage order, optionally paginated."""
        everything = [self._row_to_memory(record) for record in self._data["memories"].values()]
        return self._paginate(everything, limit, offset)

    def get_all_users(self) -> list[str]:
        """Return the distinct (plaintext) user ids present in the store, unordered."""
        users = {self._owner_of(record) for record in self._data["memories"].values()}
        return list(users)

    def upgrade_schema(self, from_version: int, to_version: int) -> None:
        """Stamp the store with ``to_version`` and persist it.

        No data is migrated; ``from_version`` is accepted but not used.
        """
        self._data["schema_version"] = to_version
        self._save()

    def get_by_tag(
        self,
        user_id: str,
        tag: str,
        lifecycle_filter: list[LifecycleState] | None = None,
        namespace: str = "default",
    ) -> list[MemoryObject]:
        """Return ``user_id``'s memories in ``namespace`` that carry ``tag``."""
        return [
            self._row_to_memory(record)
            for record in self._select(user_id, lifecycle_filter, namespace)
            if tag in (record.get("tags") or [])
        ]

    def search_by_content(
        self,
        user_id: str,
        query: str,
        top_k: int = 10,
        lifecycle_filter: list[LifecycleState] | None = None,
        namespace: str = "default",
        session_id: str | None = None,
    ) -> list[MemoryObject]:
        """Search for memories using keyword matching (no embeddings required).

        Matching is a case-insensitive substring test against the *decrypted*
        content. A non-positive ``top_k`` yields an empty list.
        """
        if top_k <= 0:
            return []

        needle = query.lower()
        hits: list[MemoryObject] = []
        for record in self._select(user_id, lifecycle_filter, namespace, session_id):
            # Build the object first so the match runs on plaintext content.
            memory = self._row_to_memory(record)
            if needle not in memory.content.lower():
                continue
            # Simple scoring: the larger the share of the content the query
            # covers, the higher the rank (so shorter contents rank first).
            memory.score = len(query) / max(len(memory.content), 1)
            hits.append(memory)

        hits.sort(key=lambda m: m.score, reverse=True)
        return hits[:top_k]