Coverage for src / kemi / adapters / storage / json.py: 80%

134 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-06-05 15:47 +0000

1from __future__ import annotations 

2 

3import json 

4from datetime import datetime 

5from pathlib import Path 

6from typing import TYPE_CHECKING, Any 

7 

8from kemi import scoring 

9from kemi.adapters.base import StorageAdapter 

10from kemi.models import LifecycleState, MemoryObject, MemorySource, MemoryType 

11 

12if TYPE_CHECKING: 

13 from kemi.encryption import EncryptionConfig 

14 

15 

class JSONStorageAdapter(StorageAdapter):
    """JSON file storage adapter.

    The whole store is read into memory on construction and the file is
    rewritten after every mutating call.

    Thread safety: NOT guaranteed. Do not use from multiple threads.
    Embedding stored as list of floats in JSON for readability.

    Encryption: When encryption config is provided, uses Fernet field-level
    encryption for content and metadata fields. Pass ``encryption=`` to
    ``__init__`` or set ``KEMI_ENCRYPTION_KEY`` / ``KEMI_ENCRYPTION_ENABLED``
    environment variables.
    """

    def __init__(self, path: str = "kemi.json", encryption: "EncryptionConfig | None" = None):
        """Open (or lazily create) the JSON store at ``path``.

        Args:
            path: Location of the JSON file. It is created on the first write
                if it does not exist yet.
            encryption: Explicit encryption configuration. When ``None`` the
                configuration is taken from ``EncryptionConfig.from_env()``.
        """
        self._path = Path(path)
        self._data = self._load()
        # Lazy import to avoid circular dependency
        from kemi.encryption import EncryptionConfig, FieldEncryptor

        if encryption is None:
            try:
                env_config = EncryptionConfig.from_env()
                self._field_encryptor = FieldEncryptor(env_config) if env_config.enabled else None
            except Exception:
                # Deliberate best effort: an unusable environment configuration
                # leaves encryption off instead of failing construction.
                # NOTE(review): this also means a misconfigured key silently
                # results in plaintext storage; consider surfacing it.
                self._field_encryptor = None
        else:
            self._field_encryptor = FieldEncryptor(encryption) if encryption.enabled else None

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def _load(self) -> dict[str, Any]:
        """Return the parsed file content, or an empty store if no file exists."""
        if self._path.exists():
            # Explicit encoding so the file reads the same on every platform.
            with open(self._path, encoding="utf-8") as f:
                return json.load(f)  # type: ignore[no-any-return]
        return {"memories": {}, "schema_version": 1}

    def _save(self) -> None:
        """Write the complete in-memory store back to the JSON file."""
        with open(self._path, "w", encoding="utf-8") as f:
            json.dump(self._data, f, indent=2)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _decrypt_value(self, field: str, value: Any) -> Any:
        """Return ``value`` decrypted if an encryptor is active and recognises it.

        Values the encryptor does not flag as encrypted (and every value when
        encryption is disabled) are returned unchanged.
        """
        encryptor = self._field_encryptor
        if encryptor is not None and encryptor._is_encrypted(value):
            return encryptor.decrypt_field(field, value)
        return value

    def _owner_of(self, row: dict[str, Any]) -> Any:
        """Return the plaintext ``user_id`` of a stored row.

        ``store`` may persist the user id encrypted, so the raw stored value
        must never be compared directly with a caller-supplied id.
        """
        return self._decrypt_value("user_id", row.get("user_id", ""))

    def _matching_rows(
        self,
        user_id: str,
        lifecycle_filter: list[LifecycleState] | None,
        namespace: str,
        session_id: str | None = None,
    ) -> list[dict[str, Any]]:
        """Return the raw rows visible to ``user_id`` under the given filters.

        ``lifecycle_filter`` defaults to ACTIVE and DECAYING. When
        ``session_id`` is given, rows bound to that session and rows without
        any session both match.
        """
        if lifecycle_filter is None:
            lifecycle_filter = [LifecycleState.ACTIVE, LifecycleState.DECAYING]
        states = {s.value for s in lifecycle_filter}

        rows = []
        for row in self._data["memories"].values():
            # Cheap plaintext checks first; the owner check may need decryption.
            if row.get("lifecycle_state") not in states:
                continue
            if row.get("namespace", "default") != namespace:
                continue
            if session_id is not None and row.get("session_id") not in (session_id, None):
                continue
            if self._owner_of(row) != user_id:
                continue
            rows.append(row)
        return rows

    @staticmethod
    def _paginate(items: list[Any], limit: int | None, offset: int | None) -> list[Any]:
        """Apply ``offset`` then ``limit`` to ``items`` using plain slicing."""
        if offset is not None:
            items = items[offset:]
        if limit is not None:
            items = items[:limit]
        return items

    def _row_to_memory(self, data: dict[str, Any]) -> MemoryObject:
        """Build a ``MemoryObject`` from a stored row, decrypting where needed."""
        content_val = self._decrypt_value("content", data.get("content", ""))
        metadata_val = self._decrypt_value("metadata", data.get("metadata", {}))
        user_id_val = self._owner_of(data)

        expires_at_raw = data.get("expires_at")
        expires_at = datetime.fromisoformat(expires_at_raw) if expires_at_raw else None
        return MemoryObject(
            memory_id=data["memory_id"],
            user_id=user_id_val,
            content=str(content_val),
            embedding=data.get("embedding"),
            score=0.0,
            created_at=datetime.fromisoformat(data["created_at"]),
            last_accessed_at=datetime.fromisoformat(data["last_accessed_at"]),
            source=MemorySource(data["source"]),
            importance=data["importance"],
            lifecycle_state=LifecycleState(data["lifecycle_state"]),
            # Anything that is not a dict after decryption is replaced by {}.
            metadata=metadata_val if isinstance(metadata_val, dict) else {},
            embedding_dim=data.get("embedding_dim"),
            tags=data.get("tags", []),
            confidence=data.get("confidence", 1.0),
            memory_type=MemoryType(data.get("memory_type", "episodic")),
            session_id=data.get("session_id"),
            namespace=data.get("namespace", "default"),
            version=data.get("version", 1),
            agent_id=data.get("agent_id"),
            run_id=data.get("run_id"),
            app_id=data.get("app_id"),
            expires_at=expires_at,
        )

    # ------------------------------------------------------------------
    # StorageAdapter API
    # ------------------------------------------------------------------

    def store(self, memory: MemoryObject) -> None:
        """Insert or overwrite ``memory`` (keyed by ``memory_id``) and persist."""
        content_val: Any = memory.content
        metadata_val: Any = memory.metadata
        user_id_val: Any = memory.user_id

        if self._field_encryptor is not None:
            content_val = self._field_encryptor.encrypt_field("content", memory.content)
            metadata_val = self._field_encryptor.encrypt_field("metadata", memory.metadata)
            # The user id is only encrypted when the encryptor is configured for it.
            if self._field_encryptor._encrypt_user_id:
                user_id_val = self._field_encryptor.encrypt_field("user_id", memory.user_id)

        self._data["memories"][memory.memory_id] = {
            "memory_id": memory.memory_id,
            "user_id": user_id_val,
            "content": content_val,
            "embedding": memory.embedding,
            "created_at": memory.created_at.isoformat(),
            "last_accessed_at": memory.last_accessed_at.isoformat(),
            "source": memory.source.value,
            "importance": memory.importance,
            "lifecycle_state": memory.lifecycle_state.value,
            "metadata": metadata_val,
            "embedding_dim": memory.embedding_dim,
            "tags": memory.tags,
            "confidence": memory.confidence,
            "memory_type": memory.memory_type.value,
            "session_id": memory.session_id,
            "namespace": memory.namespace,
            "version": memory.version,
            "agent_id": memory.agent_id,
            "run_id": memory.run_id,
            "app_id": memory.app_id,
            "expires_at": memory.expires_at.isoformat() if memory.expires_at else None,
        }
        self._save()

    def search(
        self,
        user_id: str,
        query_embedding: list[float],
        top_k: int = 10,
        lifecycle_filter: list[LifecycleState] | None = None,
        namespace: str = "default",
        session_id: str | None = None,
    ) -> list[MemoryObject]:
        """Return up to ``top_k`` memories ranked by embedding similarity.

        Cosine similarity is rescaled from [-1, 1] to [0, 1] and stored in
        ``score``; results are sorted by descending score.
        """
        memories = []
        for mem_data in self._matching_rows(user_id, lifecycle_filter, namespace, session_id):
            memory = self._row_to_memory(mem_data)
            if memory.embedding is not None:
                similarity = scoring.cosine_similarity(memory.embedding, query_embedding)
                memory.score = (similarity + 1.0) / 2.0
                # NOTE(review): assumes memories without an embedding are left
                # out of vector search results; confirm against the original
                # indentation of this append.
                memories.append(memory)

        memories.sort(key=lambda m: m.score, reverse=True)
        return memories[:top_k]

    def get(self, memory_id: str) -> MemoryObject | None:
        """Return the memory stored under ``memory_id``, or ``None`` if absent."""
        mem_data = self._data["memories"].get(memory_id)
        if mem_data:
            return self._row_to_memory(mem_data)
        return None

    def update(self, memory: MemoryObject) -> None:
        """Overwrite an existing memory; identical to ``store``."""
        self.store(memory)

    def delete_by_user(self, user_id: str) -> int:
        """Delete every memory owned by ``user_id`` and return how many were removed."""
        to_delete = [
            mid for mid, row in self._data["memories"].items() if self._owner_of(row) == user_id
        ]
        for mid in to_delete:
            del self._data["memories"][mid]
        # Skip the file rewrite when nothing changed.
        if to_delete:
            self._save()
        return len(to_delete)

    def delete_by_id(self, memory_id: str) -> bool:
        """Delete one memory; return ``True`` if it existed."""
        if memory_id in self._data["memories"]:
            del self._data["memories"][memory_id]
            self._save()
            return True
        return False

    def get_all_by_user(
        self,
        user_id: str,
        lifecycle_filter: list[LifecycleState] | None = None,
        namespace: str = "default",
        session_id: str | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> list[MemoryObject]:
        """Return the memories of ``user_id`` matching the filters, paginated."""
        rows = self._matching_rows(user_id, lifecycle_filter, namespace, session_id)
        results = [self._row_to_memory(row) for row in rows]
        return self._paginate(results, limit, offset)

    def count(self, user_id: str) -> int:
        """Return the number of stored memories owned by ``user_id`` (any state)."""
        return sum(1 for row in self._data["memories"].values() if self._owner_of(row) == user_id)

    def get_all(
        self,
        limit: int | None = None,
        offset: int | None = None,
    ) -> list[MemoryObject]:
        """Return every stored memory regardless of owner, paginated."""
        results = [self._row_to_memory(m) for m in self._data["memories"].values()]
        return self._paginate(results, limit, offset)

    def get_all_users(self) -> list[str]:
        """Return the distinct plaintext user ids present in the store (unordered)."""
        users = {self._owner_of(row) for row in self._data["memories"].values()}
        return list(users)

    def upgrade_schema(self, from_version: int, to_version: int) -> None:
        """Record ``to_version`` as the schema version and persist.

        ``from_version`` is accepted for interface compatibility but not used;
        no row data is transformed here.
        """
        self._data["schema_version"] = to_version
        self._save()

    def get_by_tag(
        self,
        user_id: str,
        tag: str,
        lifecycle_filter: list[LifecycleState] | None = None,
        namespace: str = "default",
    ) -> list[MemoryObject]:
        """Return the memories of ``user_id`` whose tag list contains ``tag``."""
        return [
            self._row_to_memory(row)
            for row in self._matching_rows(user_id, lifecycle_filter, namespace)
            if tag in row.get("tags", [])
        ]

    def search_by_content(
        self,
        user_id: str,
        query: str,
        top_k: int = 10,
        lifecycle_filter: list[LifecycleState] | None = None,
        namespace: str = "default",
        session_id: str | None = None,
    ) -> list[MemoryObject]:
        """Search for memories using keyword matching (no embeddings required).

        Matching is a case-insensitive substring test against the decrypted
        content, so it also works when field encryption is enabled.
        """
        query_lower = query.lower()

        candidates = []
        for row in self._matching_rows(user_id, lifecycle_filter, namespace, session_id):
            memory = self._row_to_memory(row)
            if query_lower not in memory.content.lower():
                continue
            # Simple scoring: the larger the share of the content covered by
            # the query, the higher the rank (shorter contents rank higher).
            memory.score = len(query) / max(len(memory.content), 1)
            candidates.append(memory)

        candidates.sort(key=lambda m: m.score, reverse=True)
        return candidates[:top_k]