Coverage for src / kemi / _memory_impl.py: 84%

783 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-06-05 15:47 +0000

1from __future__ import annotations 

2 

3import logging 

4import os 

5import uuid 

6from collections.abc import Callable 

7from datetime import datetime, timedelta, timezone 

8from typing import TYPE_CHECKING, Any 

9 

10if TYPE_CHECKING: 

11 from kemi.encryption import EncryptionConfig 

12 

13from kemi.versions import ( 

14 DiffResult, 

15 MemoryVersionStore, 

16 RollbackResult, 

17 VersionSnapshot, 

18) 

19from kemi.webhooks import WebhookDispatcher, WebhookEventType 

20 

21 

def _memory_to_dict(memory: "MemoryObject") -> dict[str, Any]:
    """Flatten a MemoryObject into a plain dict suitable for webhook payloads.

    Enum fields are reduced to their ``.value`` and datetimes to ISO-8601
    strings; falsy enum/datetime fields become ``None``. Anything that is not
    a ``MemoryObject`` yields an empty dict.
    """
    if not isinstance(memory, MemoryObject):
        return {}

    def _enum_value(member: Any) -> Any:
        return member.value if member else None

    def _iso(stamp: Any) -> str | None:
        return stamp.isoformat() if stamp else None

    payload: dict[str, Any] = {}
    payload["memory_id"] = memory.memory_id
    payload["content"] = memory.content
    payload["importance"] = memory.importance
    payload["confidence"] = memory.confidence
    payload["lifecycle_state"] = _enum_value(memory.lifecycle_state)
    payload["memory_type"] = _enum_value(memory.memory_type)
    payload["source"] = _enum_value(memory.source)
    payload["tags"] = memory.tags
    payload["namespace"] = memory.namespace
    payload["session_id"] = memory.session_id
    payload["version"] = memory.version
    payload["created_at"] = _iso(memory.created_at)
    payload["last_accessed_at"] = _iso(memory.last_accessed_at)
    payload["metadata"] = memory.metadata
    payload["agent_id"] = memory.agent_id
    payload["run_id"] = memory.run_id
    payload["app_id"] = memory.app_id
    return payload

45 

46from kemi import lifecycle, sanitize, scoring 

47from kemi.adapters.base import EmbeddingAdapter, StorageAdapter 

48from kemi.entities import EntityLinker, NoopEntityLinker, RegexEntityLinker 

49from kemi.models import ( 

50 LifecycleState, 

51 MemoryConfig, 

52 MemoryObject, 

53 MemorySource, 

54 MemoryType, 

55) 

56 

57logger = logging.getLogger(__name__) 

58 

59 

60class Memory: 

def __init__(
    self,
    embed: EmbeddingAdapter | None = None,
    store: StorageAdapter | None = None,
    config: MemoryConfig | None = None,
    encryption: "EncryptionConfig | None" = None,
    entity_linker: "EntityLinker | None" = None,
) -> None:
    """Wire up embedding, storage, configuration and optional subsystems.

    Args:
        embed: Embedding adapter. Defaults to ``FastEmbedAdapter`` (needs the
            ``kemi[local]`` extra).
        store: Storage adapter. Defaults to a SQLite database at
            ``$KEMI_DB_PATH`` (``~`` expanded) or ``~/.kemi/memories.db``. The
            sqlite-vec adapter is used when its extension is available, and
            the plain SQLite adapter otherwise.
        config: Memory configuration; ``MemoryConfig()`` when omitted.
        encryption: Encryption settings for the *default* store only. When
            omitted, ``EncryptionConfig.from_env()`` is tried; any failure
            leaves encryption disabled. Ignored when ``store`` is supplied.
        entity_linker: Entity linker. Defaults to ``RegexEntityLinker`` when
            ``config.enable_entity_boost`` is set, else ``NoopEntityLinker``.

    Raises:
        ImportError: ``embed`` is omitted and fastembed is not installed.
    """
    # Lazy import to avoid circular dependencies
    from kemi.encryption import EncryptionConfig

    if encryption is None:
        try:
            encryption = EncryptionConfig.from_env()
        except Exception:
            # Broad catch intentional: EncryptionConfig.from_env() reads
            # env vars and can fail for many reasons (missing key, bad
            # Fernet format, base64 decode errors). Encryption is opt-in;
            # fall back to disabled rather than blocking Memory init.
            encryption = None

    if embed is None:
        try:  # pragma: no cover
            from kemi.adapters.embedding.fastembed import FastEmbedAdapter

            self._embed: EmbeddingAdapter = FastEmbedAdapter()
        except ImportError as e:
            raise ImportError(
                "No embedding adapter provided and fastembed is not installed. "
                "Install with: pip install kemi[local] or provide your own: "
                "Memory(embed=YourAdapter())"
            ) from e
    else:
        self._embed = embed

    if store is None:
        # Honor an explicit KEMI_DB_PATH so background-task workers and
        # ad-hoc Memory() instantiations point at the same database as
        # the main app, rather than silently defaulting to ~/.kemi.
        env_path = os.environ.get("KEMI_DB_PATH")
        if env_path:
            default_db_path = os.path.expanduser(env_path)
        else:
            default_db_path = os.path.join(os.path.expanduser("~"), ".kemi", "memories.db")

        # A bare filename (KEMI_DB_PATH=memories.db) or ":memory:" has no
        # directory part, and os.makedirs("") raises FileNotFoundError.
        db_dir = os.path.dirname(default_db_path)
        if db_dir:
            os.makedirs(db_dir, exist_ok=True)

        # Only hand an encryption config to the adapter when it is enabled.
        store_encryption = (
            encryption
            if isinstance(encryption, EncryptionConfig) and encryption.enabled
            else None
        )

        try:
            from kemi.adapters.storage.sqlite_vec import SQLiteVecStorageAdapter

            if SQLiteVecStorageAdapter.is_vec_available():
                embedding_dim = self._embed.dimension()
                self._store: StorageAdapter = SQLiteVecStorageAdapter(
                    db_path=default_db_path,
                    embedding_dim=embedding_dim,
                    encryption=store_encryption,
                )
                logger.info(
                    "Using SQLiteVecStorageAdapter with ANN vector search "
                    "(sqlite-vec installed)"
                )
            else:
                from kemi.adapters.storage.sqlite import SQLiteStorageAdapter

                self._store = SQLiteStorageAdapter(
                    db_path=default_db_path,
                    encryption=store_encryption,
                )
        except ImportError:  # pragma: no cover
            from kemi.adapters.storage.sqlite import SQLiteStorageAdapter

            self._store = SQLiteStorageAdapter(
                db_path=default_db_path,
                encryption=store_encryption,
            )  # pragma: no cover
    else:
        self._store = store

    if config is None:
        self._config: MemoryConfig = MemoryConfig()
    else:
        self._config = config

    # Optional observability
    self._metrics: Any | None = None
    try:
        from kemi.observability import get_metrics_collector

        self._metrics = get_metrics_collector()
    except ImportError:
        pass

    self._audit_trail: Any | None = None
    self._adaptive_retriever: Any | None = None
    self._event_hooks: dict[str, list[Callable[..., Any]]] = {"pre": [], "post": []}
    self._query_cache: _QueryCache | None = None
    self._version_store: MemoryVersionStore | None = None
    self._max_versions_per_memory: int = 50
    self._auto_prune_versions: bool = True
    self._webhook_dispatcher: WebhookDispatcher | None = None

    if entity_linker is not None:
        self._entity_linker: EntityLinker = entity_linker
    elif self._config.enable_entity_boost:
        self._entity_linker = RegexEntityLinker()
    else:
        self._entity_linker = NoopEntityLinker()

166 

def _latency_tracker(self, operation: str) -> Any:
    """Build a latency-tracking context manager for ``operation``.

    Delegates to ``kemi.operations._ops_metrics.latency_tracker``, which
    receives this Memory instance along with the operation name.
    """
    from kemi.operations import _ops_metrics as ops_metrics

    tracker = ops_metrics.latency_tracker(self, operation)
    return tracker

171 

def remember(
    self,
    user_id: str,
    content: str,
    importance: float = 0.5,
    source: MemorySource = MemorySource.USER_STATED,
    metadata: dict[str, Any] | None = None,
    sanitize_input: bool = False,
    tags: list[str] | None = None,
    namespace: str = "default",
    session_id: str | None = None,
    memory_type: MemoryType = MemoryType.EPISODIC,
    confidence: float = 1.0,
    agent_id: str | None = None,
    run_id: str | None = None,
    app_id: str | None = None,
    ttl_seconds: int | None = None,
) -> str:
    """Store a new memory and return its ``memory_id``.

    Steps:
    1. Validate the inputs.
    2. Fire the ``pre``/``remember`` hooks.
    3. Optionally sanitise the text, then embed it.
    4. Build an ACTIVE memory object and pass it to
       :class:`kemi.pipeline.ingestion.IngestionPipeline`.
    5. Fire the ``post``/``remember`` hooks.

    Raises:
        ValueError, TypeError: from input validation.
        Exception: whatever the embedding adapter raises; it is counted as an
            embed error and then re-raised unchanged.
    """
    from kemi.pipeline.ingestion import IngestionPipeline

    self._validate_remember_inputs(user_id, content, importance, ttl_seconds)

    # Hooks receive the raw content and run before any embedding work.
    self._run_hooks("pre", "remember", user_id=user_id, content=content, namespace=namespace)

    with self._latency_tracker("remember"):
        text = (
            sanitize.sanitize(content, strict=self._config.sanitize)
            if sanitize_input
            else content
        )

        try:
            vector = self._embed.embed_single(text)
        except Exception:
            # Adapters raise backend-specific errors; count it, then propagate as-is.
            self._record_embed_error()
            raise

        candidate = self._build_memory_object(
            user_id=user_id,
            content=text,
            embedding=vector,
            importance=importance,
            source=source,
            metadata=metadata,
            tags=tags,
            namespace=namespace,
            session_id=session_id,
            memory_type=memory_type,
            confidence=confidence,
            agent_id=agent_id,
            run_id=run_id,
            app_id=app_id,
            ttl_seconds=ttl_seconds,
        )
        ingested = IngestionPipeline(self._build_ingestion_context()).ingest(candidate)

        self._run_hooks(
            "post",
            "remember",
            user_id=user_id,
            memory_id=ingested.memory_id,
            namespace=namespace,
        )
        return ingested.memory_id

246 

def _build_ingestion_context(self) -> "IngestionContext":
    """Bundle the collaborators the ingestion pipeline needs.

    The pipeline receives the store, config, entity linker and metrics, plus
    bound methods of this instance as callbacks. It does not receive the
    Memory object itself.
    """
    from kemi.pipeline.ingestion import IngestionContext

    collaborators: dict[str, Any] = {
        "store": self._store,
        "config": self._config,
        "entity_linker": self._entity_linker,
        "metrics": self._metrics,
        "record_store_error": self._record_store_error,
        "dispatch_webhook": self._dispatch_webhook_event,
        "track_operation": self._track_operation,
        "get_version_store": self._get_version_store,
        "auto_prune_versions": self._auto_prune_versions_for_memory,
    }
    return IngestionContext(**collaborators)

262 

263 @staticmethod 

264 def _validate_remember_inputs( 

265 user_id: str, 

266 content: str, 

267 importance: float, 

268 ttl_seconds: int | None, 

269 ) -> None: 

270 if not user_id or not user_id.strip(): 

271 raise ValueError("user_id cannot be empty") 

272 if not content or not content.strip(): 

273 raise ValueError("content cannot be empty — there is nothing to remember") 

274 if not isinstance(importance, (int, float)): 

275 raise TypeError( 

276 f"importance must be a number between 0.0 and 1.0, got {type(importance).__name__}" 

277 ) 

278 if ttl_seconds is not None and ( 

279 not isinstance(ttl_seconds, int) or ttl_seconds <= 0 

280 ): 

281 raise ValueError(f"ttl_seconds must be a positive integer, got {ttl_seconds}") 

282 

@staticmethod
def _build_memory_object(
    user_id: str,
    content: str,
    embedding: list[float],
    importance: float,
    source: MemorySource,
    metadata: dict[str, Any] | None,
    tags: list[str] | None,
    namespace: str,
    session_id: str | None,
    memory_type: MemoryType,
    confidence: float,
    agent_id: str | None,
    run_id: str | None,
    app_id: str | None,
    ttl_seconds: int | None,
) -> MemoryObject:
    """Construct a fresh ACTIVE ``MemoryObject`` from raw inputs.

    - ``importance`` and ``confidence`` are clamped into ``[0.0, 1.0]``.
    - ``version`` starts at 1 and ``embedding_dim`` is ``len(embedding)``.
    - A new UUID4 string is used as ``memory_id``.
    - With ``ttl_seconds``, ``expires_at`` is that many seconds after creation
      (creation time truncated to whole seconds); otherwise it is ``None``.

    The clock is read once, so ``created_at == last_accessed_at`` for a
    brand-new memory and ``expires_at`` is measured from that same instant.
    """
    now = datetime.now(timezone.utc)
    expires_at = (
        now.replace(microsecond=0) + timedelta(seconds=ttl_seconds)
        if ttl_seconds is not None
        else None
    )
    return MemoryObject(
        memory_id=str(uuid.uuid4()),
        user_id=user_id,
        content=content,
        embedding=embedding,
        score=0.0,
        created_at=now,
        last_accessed_at=now,
        source=source,
        importance=max(0.0, min(1.0, importance)),
        lifecycle_state=LifecycleState.ACTIVE,
        metadata=metadata or {},
        embedding_dim=len(embedding),
        tags=tags or [],
        confidence=max(0.0, min(1.0, confidence)),
        memory_type=memory_type,
        session_id=session_id,
        namespace=namespace,
        version=1,
        agent_id=agent_id,
        run_id=run_id,
        app_id=app_id,
        expires_at=expires_at,
    )

332 

def recall(
    self,
    user_id: str,
    query: str,
    top_k: int = 5,
    max_tokens: int | None = None,
    lifecycle_filter: list[LifecycleState] | None = None,
    hybrid_search: bool | None = None,
    namespace: str = "default",
    session_id: str | None = None,
    metadata_filter: dict[str, Any] | None = None,
) -> list[MemoryObject]:
    """Return up to ``top_k`` memories relevant to ``query``.

    Arguments are validated here. Retrieval, ranking and filtering are
    delegated to :class:`kemi.pipeline.retrieval.RetrievalPipeline`, and the
    call is timed under the ``"recall"`` operation.

    Raises:
        ValueError: empty ``user_id`` or ``query``, or ``top_k < 1``.
    """
    from kemi.pipeline.retrieval import RetrievalContext, RetrievalPipeline

    if not (user_id and user_id.strip()):
        raise ValueError("user_id cannot be empty")
    if not (query and query.strip()):
        raise ValueError("query cannot be empty — what should kemi search for?")
    if top_k < 1:
        raise ValueError(f"top_k must be at least 1, got {top_k}")

    with self._latency_tracker("recall"):
        pipeline = RetrievalPipeline(self._build_retrieval_context())
        found = pipeline.retrieve(
            user_id=user_id,
            query=query,
            top_k=top_k,
            max_tokens=max_tokens,
            lifecycle_filter=lifecycle_filter,
            hybrid_search=hybrid_search,
            namespace=namespace,
            session_id=session_id,
            metadata_filter=metadata_filter,
        )
    return found

367 

def _build_retrieval_context(self) -> "RetrievalContext":
    """Bundle the collaborators the retrieval pipeline needs.

    Passes the adapters, config, entity linker, optional query cache,
    metrics and adaptive retriever, plus the hook and operation-tracking
    callbacks.
    """
    from kemi.pipeline.retrieval import RetrievalContext

    collaborators: dict[str, Any] = {
        "store": self._store,
        "embed": self._embed,
        "config": self._config,
        "entity_linker": self._entity_linker,
        "query_cache": self._query_cache,
        "metrics": self._metrics,
        "adaptive_retriever": self._adaptive_retriever,
        "run_hooks": self._run_hooks,
        "track_operation": self._track_operation,
    }
    return RetrievalContext(**collaborators)

383 

384 def recall_many( 

385 self, 

386 user_ids: list[str], 

387 queries: list[str], 

388 top_k: int = 5, 

389 max_tokens: int | None = None, 

390 lifecycle_filter: list[LifecycleState] | None = None, 

391 hybrid_search: bool | None = None, 

392 namespace: str = "default", 

393 session_id: str | None = None, 

394 metadata_filter: dict[str, Any] | None = None, 

395 ) -> dict[str, list[MemoryObject]]: 

396 """Recall memories for multiple users and queries at once. 

397 

398 Args: 

399 user_ids: List of user IDs to recall for. 

400 queries: List of query strings (same length as user_ids). 

401 top_k: Max memories per result. 

402 max_tokens: Token budget. 

403 lifecycle_filter: Optional lifecycle state filter. 

404 hybrid_search: Override hybrid search. 

405 namespace: Memory namespace. 

406 session_id: Optional session ID filter. 

407 metadata_filter: Optional metadata key-value filter dict. 

408 

409 Returns: 

410 Dict mapping user_id -> list of MemoryObjects. 

411 """ 

412 if len(user_ids) != len(queries): 

413 raise ValueError("user_ids and queries must have the same length") 

414 results: dict[str, list[MemoryObject]] = {} 

415 for uid, q in zip(user_ids, queries, strict=True): 

416 results[uid] = self.recall( 

417 user_id=uid, 

418 query=q, 

419 top_k=top_k, 

420 max_tokens=max_tokens, 

421 lifecycle_filter=lifecycle_filter, 

422 hybrid_search=hybrid_search, 

423 namespace=namespace, 

424 session_id=session_id, 

425 metadata_filter=metadata_filter, 

426 ) 

427 return results 

428 

def update_many(
    self,
    memory_ids: list[str],
    content: str | None = None,
    importance: float | None = None,
    confidence: float | None = None,
    memory_type: MemoryType | None = None,
    metadata: dict[str, Any] | None = None,
) -> list[str]:
    """Apply the same :meth:`update` to every memory in ``memory_ids``.

    Args:
        memory_ids: IDs to update, processed in order.
        content: New content for all of them (if provided, will re-embed).
        importance: New importance for all of them.
        confidence: New confidence for all of them.
        memory_type: New memory type for all of them.
        metadata: New metadata for all of them.

    Returns:
        The IDs that were passed to :meth:`update`, in order. If an update
        raises, the exception propagates and nothing is returned.
    """
    changes: dict[str, Any] = {
        "content": content,
        "importance": importance,
        "confidence": confidence,
        "memory_type": memory_type,
        "metadata": metadata,
    }
    done: list[str] = []
    for memory_id in memory_ids:
        self.update(memory_id, **changes)
        done.append(memory_id)
    return done

462 

def forget_many(
    self,
    memory_ids: list[str],
) -> int:
    """Delete each memory in ``memory_ids`` straight from the store.

    Args:
        memory_ids: IDs to delete.

    Returns:
        How many deletions the store reported as successful (truthy result
        from ``delete_by_id``).

    Note:
        Unlike :meth:`forget`, this bypasses pre/post event hooks, DELETED
        webhooks and operation tracking for batch performance. It also
        performs no per-user ownership check.
    """
    delete = self._store.delete_by_id
    return sum(1 for memory_id in memory_ids if delete(memory_id))

485 

def forget(
    self,
    user_id: str,
    memory_id: str | None = None,
) -> int:
    """Delete one of ``user_id``'s memories, or all of them.

    Args:
        user_id: Owner whose memories are affected.
        memory_id: When given, delete only this memory, and only if it exists
            and belongs to ``user_id``. When ``None``, delete every memory of
            ``user_id``.

    Returns:
        The number of memories deleted (0 or 1 when ``memory_id`` is given).

    Pre/post ``forget`` hooks fire in both modes. A DELETED webhook is
    dispatched only when something was removed, and the operation is tracked.

    Raises:
        ValueError: ``user_id`` is empty or whitespace-only.
    """
    if not user_id or not user_id.strip():
        raise ValueError("user_id cannot be empty")

    with self._latency_tracker("forget"):
        self._run_hooks(
            "pre", "forget", user_id=user_id, memory_id=memory_id, namespace="default"
        )
        if memory_id is not None:
            # Check ownership first: deleting purely by ID would let any
            # user_id remove (and emit a DELETED webhook for) another user's
            # memory. Missing and foreign memories both report "not deleted"
            # so the caller cannot probe for other users' IDs.
            existing = self._store.get(memory_id)
            owned = existing is not None and existing.user_id == user_id
            deleted = self._store.delete_by_id(memory_id) if owned else False
            if deleted:
                self._dispatch_webhook_event(
                    WebhookEventType.DELETED,
                    memory_id=memory_id,
                    user_id=user_id,
                )
            self._run_hooks(
                "post",
                "forget",
                user_id=user_id,
                memory_id=memory_id,
                deleted=deleted,
                namespace="default",
            )
            self._track_operation(
                "forget",
                user_id,
                {"memory_id": memory_id, "deleted": deleted},
                memory_id,
                namespace="default",
            )
            return 1 if deleted else 0

        count = self._store.delete_by_user(user_id)
        if count:
            self._dispatch_webhook_event(
                WebhookEventType.DELETED,
                memory_id="batch",
                user_id=user_id,
                deleted_count=count,
            )
        self._run_hooks(
            "post", "forget", user_id=user_id, deleted_count=count, namespace="default"
        )
        self._track_operation(
            "forget", user_id, {"deleted_count": count}, namespace="default"
        )
        return count

538 

def context_block(
    self,
    user_id: str,
    query: str,
    top_k: int = 5,
    max_tokens: int = 1500,
    prefix: str = "Relevant context from memory:",
    namespace: str = "default",
    session_id: str | None = None,
) -> str:
    """Render recalled memories as a prompt-ready text block.

    Output is ``prefix`` followed by one ``"- <content>"`` line per recalled
    memory, joined with newlines. Returns ``""`` when nothing is recalled.
    """
    recalled = self.recall(
        user_id=user_id,
        query=query,
        top_k=top_k,
        max_tokens=max_tokens,
        namespace=namespace,
        session_id=session_id,
    )
    if not recalled:
        return ""
    return "\n".join([prefix, *(f"- {item.content}" for item in recalled)])

566 

async def aremember(
    self,
    user_id: str,
    content: str,
    importance: float = 0.5,
    source: MemorySource = MemorySource.USER_STATED,
    metadata: dict[str, Any] | None = None,
    sanitize_input: bool = False,
    tags: list[str] | None = None,
    namespace: str = "default",
    session_id: str | None = None,
    memory_type: MemoryType = MemoryType.EPISODIC,
    confidence: float = 1.0,
    agent_id: str | None = None,
    run_id: str | None = None,
    app_id: str | None = None,
    ttl_seconds: int | None = None,
) -> str:
    """Async variant of :meth:`remember`, executed in a worker thread.

    Accepts the same arguments as :meth:`remember`, including
    ``ttl_seconds``, which this wrapper previously could not forward.

    Returns:
        The new memory's ID.
    """
    import asyncio

    return await asyncio.to_thread(
        self.remember,
        user_id,
        content,
        importance,
        source,
        metadata,
        sanitize_input,
        tags,
        namespace,
        session_id,
        memory_type,
        confidence,
        agent_id,
        run_id,
        app_id,
        ttl_seconds=ttl_seconds,
    )

603 

async def recall_stream(
    self,
    user_id: str,
    query: str,
    top_k: int = 5,
    max_tokens: int | None = None,
    lifecycle_filter: list[LifecycleState] | None = None,
    hybrid_search: bool | None = None,
    namespace: str = "default",
    session_id: str | None = None,
    metadata_filter: dict[str, Any] | None = None,
):
    """Stream recall results as an async generator.

    Scores all candidate memories, then yields each one as MMR reranking
    selects it, providing progressive delivery instead of waiting for
    full ranking.

    Blocking adapter calls run in worker threads via :func:`asyncio.to_thread`,
    so the event loop is not stalled while the consumer iterates. This covers
    embedding, store search, lifecycle write-backs and the final count.

    Args:
        Same as :meth:`recall`.

    Yields:
        MemoryObject instances in ranked order.

    Raises:
        ValueError: empty ``user_id``/``query``, ``top_k < 1``, or stored
            embeddings whose dimension differs from the current adapter's.
            Being an async generator, these surface on the first iteration
            rather than at call time.
    """
    if not user_id or not user_id.strip():
        raise ValueError("user_id cannot be empty")
    if not query or not query.strip():
        raise ValueError("query cannot be empty — what should kemi search for?")
    if top_k < 1:
        raise ValueError(f"top_k must be at least 1, got {top_k}")

    import asyncio

    if hybrid_search is None:
        hybrid_search = self._config.hybrid_search

    query_embedding = await asyncio.to_thread(self._embed.embed_single, query)

    if lifecycle_filter is None:
        lifecycle_filter = lifecycle.get_recall_filter()

    # Over-fetch (3x top_k) so metadata filtering and MMR have headroom.
    search_results = await asyncio.to_thread(
        self._store.search,
        user_id=user_id,
        query_embedding=query_embedding,
        top_k=top_k * 3,
        lifecycle_filter=lifecycle_filter,
        namespace=namespace,
        session_id=session_id,
    )

    if metadata_filter is not None:
        search_results = [
            m
            for m in search_results
            if all(m.metadata.get(k) == v for k, v in metadata_filter.items())
        ]

    current_dim = self._embed.dimension()
    if search_results:
        stored_dim = search_results[0].embedding_dim
        if stored_dim is not None and stored_dim != current_dim:
            raise ValueError(
                "Embedding dimension mismatch: stored memories use "
                f"{stored_dim} dimensions but current adapter produces "
                f"{current_dim} dimensions. Run memory.migrate(user_id, "
                "new_adapter) to re-embed your memories."
            )

    # Entity extraction for boost; prefer entities cached in metadata.
    query_entities_stream: set[str] | None = None
    memory_entities_map_stream: dict[str, set[str]] | None = None
    if self._config.enable_entity_boost:
        query_entities_stream = self._entity_linker.extract(query)
        memory_entities_map_stream = {}
        for m in search_results:
            cached = m.metadata.get("extracted_entities")
            if cached is not None:
                memory_entities_map_stream[m.memory_id] = set(cached)
            else:
                memory_entities_map_stream[m.memory_id] = self._entity_linker.extract(m.content)

    # Score all candidates (same as rank_memories)
    corpus = [m.content for m in search_results] if len(search_results) > 1 else None
    for memory in search_results:
        mem_entities = None
        if memory_entities_map_stream is not None:
            mem_entities = memory_entities_map_stream.get(memory.memory_id)
        memory.score = scoring.score_memory(
            memory,
            query_embedding,
            query,
            hybrid_search,
            corpus,
            self._config.weight_semantic,
            self._config.weight_recency,
            self._config.weight_bm25,
            self._config.weight_semantic_no_embed,
            self._config.weight_recency_no_embed,
            self._config.weight_importance,
            query_entities_stream,
            mem_entities,
            self._config.entity_boost_weight,
        )

    # Sort by score descending first so mmr_rerank_stream gets pre-sorted input
    search_results.sort(key=lambda m: m.score, reverse=True)

    # Truncate to token budget before MMR
    effective_max_tokens = (
        max_tokens if max_tokens is not None else self._config.max_tokens_default
    )

    if effective_max_tokens is not None:
        search_results = scoring.truncate_by_tokens(search_results, effective_max_tokens)

    # Apply MMR and yield progressively
    yielded_memories: list[MemoryObject] = []
    for memory in scoring.mmr_rerank_stream(
        search_results, query_embedding, top_k, lambda_param=0.7
    ):
        # Update lifecycle and access time
        memory.last_accessed_at = datetime.now(timezone.utc)
        new_state = lifecycle.evaluate_lifecycle(memory, self._config.decay_threshold_hours)
        if new_state != memory.lifecycle_state:
            updated = lifecycle.transition(memory, new_state)
            # Store writes are blocking I/O; keep them off the event loop,
            # like the embed/search calls above.
            await asyncio.to_thread(self._store.update, updated)
            if self._metrics is not None:
                self._metrics.lifecycle_transitions.inc(1)
        yielded_memories.append(memory)
        yield memory

    # Update metrics after all yielded
    if self._metrics is not None:
        total = await asyncio.to_thread(self._store.count, user_id)
        self._metrics.total_memories.set(total)
    self._run_hooks(
        "post",
        "recall",
        user_id=user_id,
        query=query,
        results=yielded_memories,
        namespace=namespace,
    )
    self._track_operation(
        "recall",
        user_id,
        {"query": query, "results_count": len(yielded_memories), "cache_hit": False, "stream": True},
        namespace=namespace,
    )
    if self._adaptive_retriever is not None:
        try:
            profile = self._adaptive_retriever.analyze_query(query)
            self._adaptive_retriever.record_feedback(user_id, query, profile)
        except Exception:
            logger.debug("Adaptive retrieval analysis failed", exc_info=True)

760 

async def arecall(
    self,
    user_id: str,
    query: str,
    top_k: int = 5,
    max_tokens: int | None = None,
    lifecycle_filter: list[LifecycleState] | None = None,
    hybrid_search: bool | None = None,
    namespace: str = "default",
    session_id: str | None = None,
    metadata_filter: dict[str, Any] | None = None,
    stream: bool = False,
):
    """Async variant of :meth:`recall`.

    With ``stream=True``, the awaited result is the async generator from
    :meth:`recall_stream`, which the caller then iterates. Otherwise
    :meth:`recall` runs in a worker thread and its list is returned.
    """
    import asyncio

    forwarded = (
        user_id,
        query,
        top_k,
        max_tokens,
        lifecycle_filter,
        hybrid_search,
        namespace,
        session_id,
        metadata_filter,
    )
    if stream:
        return self.recall_stream(*forwarded)
    return await asyncio.to_thread(self.recall, *forwarded)

801 

async def aforget(
    self,
    user_id: str,
    memory_id: str | None = None,
) -> int:
    """Async variant of :meth:`forget`, executed in a worker thread."""
    import asyncio

    removed = await asyncio.to_thread(self.forget, user_id, memory_id)
    return removed

810 

async def acontext_block(
    self,
    user_id: str,
    query: str,
    top_k: int = 5,
    max_tokens: int = 1500,
    prefix: str = "Relevant context from memory:",
    namespace: str = "default",
    session_id: str | None = None,
) -> str:
    """Async variant of :meth:`context_block`, executed in a worker thread."""
    import asyncio

    forwarded = (user_id, query, top_k, max_tokens, prefix, namespace, session_id)
    rendered = await asyncio.to_thread(self.context_block, *forwarded)
    return rendered

826 

def migrate(
    self,
    user_id: str,
    new_embed_fn: EmbeddingAdapter,
    batch_size: int = 100,
) -> int:
    """Re-embed a user's ACTIVE and DECAYING memories with a new adapter.

    Memories are re-embedded in batches of ``batch_size``, and each one is
    written back with its new embedding and ``embedding_dim``. Memories in
    other lifecycle states are left untouched.

    Args:
        user_id: Owner whose memories are migrated.
        new_embed_fn: Adapter that produces the new embeddings.
        batch_size: Number of contents embedded per adapter call.

    Returns:
        The number of memories updated.

    Raises:
        ValueError: empty ``user_id``, ``batch_size < 1``, or the adapter
            returned a different number of embeddings than it was given. In
            the last case, batches processed before the failure stay updated.
    """
    if not user_id or not user_id.strip():
        raise ValueError("user_id cannot be empty")
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")

    with self._latency_tracker("migrate"):
        memories = self._store.get_all_by_user(
            user_id,
            lifecycle_filter=[LifecycleState.ACTIVE, LifecycleState.DECAYING],
        )

        if not memories:
            return 0

        count = 0
        for start in range(0, len(memories), batch_size):
            batch = memories[start : start + batch_size]
            new_embeddings = new_embed_fn.embed([m.content for m in batch])
            # A short result would raise IndexError mid-batch and a long one
            # would be silently ignored; fail loudly before touching the batch.
            if len(new_embeddings) != len(batch):
                raise ValueError(
                    f"Embedding adapter returned {len(new_embeddings)} embeddings "
                    f"for a batch of {len(batch)} memories"
                )
            new_dim = new_embed_fn.dimension()

            for mem, vector in zip(batch, new_embeddings):
                mem.embedding = vector
                mem.embedding_dim = new_dim
                self._store.update(mem)
                count += 1

        logger.info("Migrated %d memories for user %s", count, user_id)
        self._track_operation("migrate", user_id, {"count": count})
        return count

864 

def export(self, file_path: str) -> int:
    """Export all memories to a JSON file.

    Every memory returned by ``self._store.get_all()`` is written as one JSON
    object, including its raw embedding. Enum fields are stored as their
    ``.value`` and timestamps as ISO-8601 strings (``None`` when unset).
    ``expires_at`` is included so TTL information survives an
    export/import round trip.

    Args:
        file_path: Destination path; overwritten if it exists.

    Returns:
        The number of memories written.
    """
    import json

    def _iso(stamp):
        return stamp.isoformat() if stamp else None

    def _value(member):
        return member.value if member else None

    memories_data = [
        {
            "memory_id": mem.memory_id,
            "user_id": mem.user_id,
            "content": mem.content,
            "embedding": mem.embedding,
            "score": mem.score,
            "created_at": _iso(mem.created_at),
            "last_accessed_at": _iso(mem.last_accessed_at),
            "source": _value(mem.source),
            "importance": mem.importance,
            "lifecycle_state": _value(mem.lifecycle_state),
            "metadata": mem.metadata,
            "embedding_dim": mem.embedding_dim,
            "tags": mem.tags,
            "confidence": mem.confidence,
            "memory_type": _value(mem.memory_type),
            "session_id": mem.session_id,
            "namespace": mem.namespace,
            "version": mem.version,
            "agent_id": mem.agent_id,
            "run_id": mem.run_id,
            "app_id": mem.app_id,
            "expires_at": _iso(getattr(mem, "expires_at", None)),
        }
        for mem in self._store.get_all()
    ]

    with open(file_path, "w", encoding="utf-8") as f:
        json.dump(memories_data, f, indent=2)

    logger.info("Exported %d memories to %s", len(memories_data), file_path)
    return len(memories_data)

905 

def import_from(self, file_path: str) -> int:
    """Import memories from a JSON file.

    The file is expected to contain a list of objects shaped like those
    written by :meth:`export`. Entries whose ``memory_id`` already exists in
    the store are skipped. Missing optional fields fall back to defaults:
    USER_STATED source, ACTIVE state, EPISODIC type, importance 0.5,
    confidence 1.0, namespace "default", version 1, and the current time for
    missing timestamps. ``expires_at`` is restored when present.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The number of memories stored.
    """
    import json

    def _parse_timestamp(raw):
        """Parse an ISO-8601 string; naive values are assumed to be UTC."""
        if not raw:
            return None
        parsed = datetime.fromisoformat(raw)
        # Every timestamp this class creates is UTC-aware; pin naive values
        # to UTC so later comparisons with aware datetimes cannot raise.
        return parsed if parsed.tzinfo is not None else parsed.replace(tzinfo=timezone.utc)

    with open(file_path, encoding="utf-8") as f:
        memories_data = json.load(f)

    imported_count = 0
    for mem_data in memories_data:
        if self._store.get(mem_data["memory_id"]) is not None:
            continue

        now = datetime.now(timezone.utc)
        raw_source = mem_data.get("source")
        raw_state = mem_data.get("lifecycle_state")
        raw_type = mem_data.get("memory_type")

        memory = MemoryObject(
            memory_id=mem_data["memory_id"],
            user_id=mem_data["user_id"],
            content=mem_data["content"],
            embedding=mem_data.get("embedding"),
            score=mem_data.get("score", 0.0),
            created_at=_parse_timestamp(mem_data.get("created_at")) or now,
            last_accessed_at=_parse_timestamp(mem_data.get("last_accessed_at")) or now,
            source=MemorySource(raw_source) if raw_source else MemorySource.USER_STATED,
            importance=mem_data.get("importance", 0.5),
            lifecycle_state=LifecycleState(raw_state) if raw_state else LifecycleState.ACTIVE,
            # ``or`` also covers explicit JSON nulls, not just missing keys.
            metadata=mem_data.get("metadata") or {},
            embedding_dim=mem_data.get("embedding_dim"),
            tags=mem_data.get("tags") or [],
            confidence=mem_data.get("confidence", 1.0),
            memory_type=MemoryType(raw_type) if raw_type else MemoryType.EPISODIC,
            session_id=mem_data.get("session_id"),
            namespace=mem_data.get("namespace", "default"),
            version=mem_data.get("version", 1),
            agent_id=mem_data.get("agent_id"),
            run_id=mem_data.get("run_id"),
            app_id=mem_data.get("app_id"),
            expires_at=_parse_timestamp(mem_data.get("expires_at")),
        )

        self._store.store(memory)
        imported_count += 1

    logger.info("Imported %d memories from %s", imported_count, file_path)
    return imported_count

969 

async def aexport(self, file_path: str) -> int:
    """Async variant of :meth:`export`, executed in a worker thread."""
    import asyncio

    written = await asyncio.to_thread(self.export, file_path)
    return written

974 

async def aimport_from(self, file_path: str) -> int:
    """Async variant of :meth:`import_from`, executed in a worker thread."""
    import asyncio

    loaded = await asyncio.to_thread(self.import_from, file_path)
    return loaded

979 

def upgrade(self) -> None:
    """Ask the store to upgrade its schema (currently 1 -> 1), then log it."""
    schema_version = 1
    self._store.upgrade_schema(from_version=schema_version, to_version=schema_version)
    logger.info("Schema upgraded to version 1")

983 

def remember_many(
    self,
    user_id: str,
    contents: list[str],
    importance: float = 0.5,
    source: MemorySource = MemorySource.USER_STATED,
    tags: list[str] | None = None,
    namespace: str = "default",
    session_id: str | None = None,
    memory_type: MemoryType = MemoryType.EPISODIC,
    confidence: float = 1.0,
    agent_id: str | None = None,
    run_id: str | None = None,
    app_id: str | None = None,
    ttl_seconds: int | None = None,
) -> list[str]:
    """Store multiple memories at once.

    The whole batch is validated before anything is embedded or stored. An
    empty item anywhere in ``contents`` therefore rejects the batch up front.
    Previously, the items before it were already persisted, and the batch
    audit entries were never flushed.

    Args:
        user_id: User ID.
        contents: Content strings to remember, one memory per item.
        importance: Importance value (0.0-1.0) applied to every memory.
        source: Memory source applied to every memory.
        tags: Optional list of tags applied to every memory.
        namespace: Memory namespace for the whole batch.
        session_id: Optional session ID applied to every memory.
        memory_type: Memory type applied to every memory.
        confidence: Confidence value applied to every memory.
        agent_id: Optional agent ID applied to every memory.
        run_id: Optional run ID applied to every memory.
        app_id: Optional app ID applied to every memory.
        ttl_seconds: Optional time-to-live forwarded for every memory.

    Returns:
        Memory IDs in the same order as ``contents`` (empty list when
        ``contents`` is empty).

    Raises:
        ValueError: If any item in ``contents`` is empty or whitespace-only.
        RuntimeError: If the embedding adapter returns a different number of
            vectors than there are contents.

    Note:
        Fires pre/post ``remember`` event hooks for each item so
        batch behavior is consistent with :meth:`recall_many` and
        :meth:`update_many`.
    """
    if not contents:
        return []

    with self._latency_tracker("remember_many"):
        # Reject the batch before any side effects: items are stored one by
        # one below, so a late failure would leave a partial batch behind.
        for content in contents:
            if not content or not content.strip():
                raise ValueError("content cannot be empty — there is nothing to remember")

        # Batch embed all contents at once for performance
        embeddings = self._embed.embed(contents)
        if len(embeddings) != len(contents):
            raise RuntimeError(
                f"embedding adapter returned {len(embeddings)} vectors "
                f"for {len(contents)} contents"
            )

        memory_ids: list[str] = []
        # Shared list that collects per-item audit entries so they can be
        # written with a single log_operation_batch call at the end.
        audit_batch: list[dict[str, Any]] | None = [] if self._audit_trail is not None else None
        for content, embedding in zip(contents, embeddings):
            self._run_hooks(
                "pre", "remember", user_id=user_id, content=content, namespace=namespace
            )
            memory_id = self._remember_with_embedding(
                user_id=user_id,
                content=content,
                embedding=embedding,
                importance=importance,
                source=source,
                tags=tags,
                namespace=namespace,
                session_id=session_id,
                memory_type=memory_type,
                confidence=confidence,
                audit_batch=audit_batch,
                agent_id=agent_id,
                run_id=run_id,
                app_id=app_id,
                ttl_seconds=ttl_seconds,
            )
            self._run_hooks(
                "post", "remember", user_id=user_id, memory_id=memory_id, namespace=namespace
            )
            memory_ids.append(memory_id)

        if self._metrics is not None:
            self._metrics.remember_many_total.inc(1)
            self._metrics.total_memories.set(self._store.count(user_id))

        # Append batch-level audit entry and flush all entries in one transaction
        if audit_batch is not None:
            self._track_operation(
                "remember_many",
                user_id,
                {"count": len(memory_ids)},
                namespace=namespace,
                audit_batch=audit_batch,
            )
            if self._audit_trail is not None:
                try:
                    self._audit_trail.log_operation_batch(audit_batch)
                except Exception:
                    logger.warning("Audit log batch failed for remember_many", exc_info=True)
        return memory_ids

1073 

def _remember_with_embedding(
    self,
    user_id: str,
    content: str,
    embedding: list[float],
    importance: float,
    source: MemorySource,
    metadata: dict[str, Any] | None = None,
    tags: list[str] | None = None,
    namespace: str = "default",
    session_id: str | None = None,
    memory_type: MemoryType = MemoryType.EPISODIC,
    confidence: float = 1.0,
    audit_batch: list[dict[str, Any]] | None = None,
    agent_id: str | None = None,
    run_id: str | None = None,
    app_id: str | None = None,
    ttl_seconds: int | None = None,
) -> str:
    """Internal: persist one memory whose embedding was computed by the caller.

    Assembles the ``MemoryObject`` through ``_build_memory_object`` and then
    hands it to :class:`kemi.pipeline.ingestion.IngestionPipeline`, passing
    ``audit_batch`` through. It is called in a loop by :meth:`remember_many`.
    This helper runs no hooks itself; callers are responsible for that.

    Returns:
        The ``memory_id`` of the object returned by the pipeline's ``ingest``.
    """
    from kemi.pipeline.ingestion import IngestionPipeline

    new_memory = self._build_memory_object(
        user_id=user_id,
        content=content,
        embedding=embedding,
        importance=importance,
        source=source,
        metadata=metadata,
        tags=tags,
        namespace=namespace,
        session_id=session_id,
        memory_type=memory_type,
        confidence=confidence,
        agent_id=agent_id,
        run_id=run_id,
        app_id=app_id,
        ttl_seconds=ttl_seconds,
    )
    pipeline = IngestionPipeline(self._build_ingestion_context())
    persisted = pipeline.ingest(new_memory, audit_batch=audit_batch)
    return persisted.memory_id

1123 

def list_users(self) -> list[str]:
    """Return every distinct user ID that has memories in the store.

    Delegates directly to the storage adapter's ``get_all_users``.
    """
    backend = self._store
    user_ids = backend.get_all_users()
    return user_ids

1131 

def prune(
    self,
    user_id: str,
    max_age_days: float | None = None,
    min_importance: float | None = None,
    lifecycle_states: list[LifecycleState] | None = None,
    namespace: str = "default",
) -> int:
    """Delete old or low-importance memories for one user.

    A memory is deleted when it matches *either* criterion. Only memories in
    ``lifecycle_states`` are considered; when that is ``None`` or empty,
    only DECAYING memories are candidates. If neither ``max_age_days`` nor
    ``min_importance`` is given, nothing is deleted.

    Args:
        user_id: User ID to prune.
        max_age_days: Delete memories whose ``created_at`` is more than this
            many days ago. Memories without ``created_at`` are not aged out.
        min_importance: Delete memories with importance below this threshold.
        lifecycle_states: Only prune memories in these states
            (default: ``[LifecycleState.DECAYING]``).
        namespace: Memory namespace to prune.

    Returns:
        Number of memories the store confirmed as deleted.

    Raises:
        ValueError: If ``user_id`` is empty or whitespace-only.
    """
    if not user_id or not user_id.strip():
        raise ValueError("user_id cannot be empty")

    with self._latency_tracker("prune"):
        candidates = self._store.get_all_by_user(
            user_id,
            lifecycle_filter=lifecycle_states or [LifecycleState.DECAYING],
            namespace=namespace,
        )

        now = datetime.now(timezone.utc)
        to_delete: list[str] = []
        for mem in candidates:
            too_old = (
                max_age_days is not None
                and mem.created_at is not None
                and (now - mem.created_at).total_seconds() / 86400.0 > max_age_days
            )
            too_unimportant = min_importance is not None and mem.importance < min_importance
            if too_old or too_unimportant:
                to_delete.append(mem.memory_id)

        # Count only deletions the store confirms (as prune_expired does) so
        # the reported count is not inflated by ids that failed to delete.
        deleted = 0
        for mid in to_delete:
            if self._store.delete_by_id(mid):
                deleted += 1

        logger.info(f"Pruned {deleted} memories for user {user_id}")
        self._track_operation(
            "prune", user_id, {"deleted_count": deleted}, namespace=namespace
        )
        return deleted

1185 

def prune_expired(
    self,
    user_id: str | None = None,
    namespace: str | None = None,
) -> int:
    """Delete memories whose ``expires_at`` has passed.

    Scans ACTIVE and DECAYING memories. Every memory whose ``expires_at`` is
    set and not later than now (UTC) is removed through the store's
    ``delete_by_id``. This method performs no lifecycle transition itself.
    Called automatically by :meth:`run_maintenance`.

    Args:
        user_id: If provided, only sweep this user. If None, sweep every user
            returned by the store's ``get_all_users``.
        namespace: If provided, only sweep this namespace. If None, sweep the
            namespaces reported by :meth:`_known_namespaces` for each user.
            That set is best-effort: it always includes ``"default"`` but can
            miss namespaces outside its sample.

    Returns:
        Number of memories the store reported as deleted.
    """
    with self._latency_tracker("prune_expired"):
        now = datetime.now(timezone.utc)
        deleted = 0

        target_users = [user_id] if user_id is not None else self._store.get_all_users()
        for uid in target_users:
            # ``get_all_by_user`` is namespace-scoped, so an all-namespace
            # sweep has to query each known namespace explicitly.
            scopes = [namespace] if namespace is not None else self._known_namespaces(uid)
            for ns in scopes:
                candidates = self._store.get_all_by_user(
                    uid,
                    lifecycle_filter=[LifecycleState.ACTIVE, LifecycleState.DECAYING],
                    namespace=ns,
                )
                for mem in candidates:
                    if mem.expires_at is None or mem.expires_at > now:
                        continue
                    if self._store.delete_by_id(mem.memory_id):
                        deleted += 1

        if deleted > 0:
            logger.info(
                f"Pruned {deleted} expired memories"
                + (f" for user {user_id}" if user_id else "")
            )
        self._track_operation(
            "prune_expired",
            user_id or "all",
            {"deleted_count": deleted},
            namespace=namespace or "all",
        )
        return deleted

1256 

1257 def _known_namespaces(self, user_id: str) -> set[str]: 

1258 """Return the set of distinct namespaces holding memories for a user.""" 

1259 namespaces: set[str] = set() 

1260 # Default namespace is always present even if empty. 

1261 namespaces.add("default") 

1262 try: 

1263 # Use get_all_by_user with no namespace filter is not possible 

1264 # (the base API defaults to "default"), so we sample a small 

1265 # batch from get_all() and collect namespaces. 

1266 for mem in self._store.get_all(limit=1000): 

1267 if mem.user_id == user_id: 

1268 namespaces.add(mem.namespace) 

1269 except Exception: 

1270 pass 

1271 return namespaces 

1272 

def recall_between(
    self,
    user_id: str,
    query: str,
    start: datetime,
    end: datetime,
    top_k: int = 5,
    max_tokens: int | None = None,
    lifecycle_filter: list[LifecycleState] | None = None,
    namespace: str = "default",
    session_id: str | None = None,
) -> list[MemoryObject]:
    """Recall memories created within a specific date range.

    Runs :meth:`recall` for ``top_k * 3`` candidates and keeps those whose
    ``created_at`` lies in ``[start, end]``. Fewer than ``top_k`` results can
    be returned even if more in-range memories exist outside that candidate
    set. An inverted range (``start > end``) yields an empty list.

    Naive datetimes are interpreted as UTC. This applies to ``start``/``end``
    and to a memory's ``created_at``, so comparing against timezone-aware
    timestamps no longer raises ``TypeError``.

    Args:
        user_id: User ID to search for.
        query: Search query.
        start: Start datetime (inclusive).
        end: End datetime (inclusive).
        top_k: Maximum memories to return.
        max_tokens: Token budget, forwarded to :meth:`recall`.
        lifecycle_filter: Filter by lifecycle state.
        namespace: Memory namespace.
        session_id: Filter by session ID.

    Returns:
        List of MemoryObjects created within the date range, in recall order.

    Raises:
        ValueError: If ``user_id`` or ``query`` is empty, or ``top_k < 1``.
    """
    if not user_id or not user_id.strip():
        raise ValueError("user_id cannot be empty")
    if not query or not query.strip():
        raise ValueError("query cannot be empty")
    if top_k < 1:
        raise ValueError(f"top_k must be at least 1, got {top_k}")

    all_results = self.recall(
        user_id=user_id,
        query=query,
        top_k=top_k * 3,
        max_tokens=max_tokens,
        lifecycle_filter=lifecycle_filter,
        namespace=namespace,
        session_id=session_id,
    )

    def _as_utc(moment: datetime) -> datetime:
        # Naive and aware datetimes cannot be compared; assume naive values
        # are UTC, matching the datetime.now(timezone.utc) stamps used here.
        if moment.tzinfo is None or moment.utcoffset() is None:
            return moment.replace(tzinfo=timezone.utc)
        return moment

    window_start = _as_utc(start)
    window_end = _as_utc(end)
    filtered = [
        m
        for m in all_results
        if m.created_at and window_start <= _as_utc(m.created_at) <= window_end
    ]
    return filtered[:top_k]

1320 

def recall_user_profile(
    self,
    user_id: str,
    *,
    top_k: int = 20,
    namespace: str = "default",
) -> list[MemoryObject]:
    """Return a user's long-lived profile: their most important semantic facts.

    Loads the user's ACTIVE, DECAYING and ARCHIVED memories in *namespace*
    directly from the store (no query, no semantic ranking). It keeps only
    SEMANTIC ones and returns the ``top_k`` with the highest importance.
    Each returned memory gets ``last_accessed_at`` set to now and its
    lifecycle state re-evaluated. The memory is written back to the store
    only when that evaluation changes its state.

    Args:
        user_id: User whose profile to retrieve.
        top_k: Maximum number of profile facts to return.
        namespace: Memory namespace.

    Returns:
        List of MemoryObjects sorted by importance descending.

    Raises:
        ValueError: If ``user_id`` is empty or ``top_k < 1``.
    """
    if not user_id or not user_id.strip():
        raise ValueError("user_id cannot be empty")
    if top_k < 1:
        raise ValueError(f"top_k must be at least 1, got {top_k}")

    with self._latency_tracker("recall_user_profile"):
        self._run_hooks("pre", "recall_user_profile", user_id=user_id, namespace=namespace)

        candidates = self._store.get_all_by_user(
            user_id,
            lifecycle_filter=[
                LifecycleState.ACTIVE,
                LifecycleState.DECAYING,
                LifecycleState.ARCHIVED,
            ],
            namespace=namespace,
        )
        ranked = sorted(
            (mem for mem in candidates if mem.memory_type == MemoryType.SEMANTIC),
            key=lambda mem: mem.importance,
            reverse=True,
        )
        selected = ranked[:top_k]

        for mem in selected:
            mem.last_accessed_at = datetime.now(timezone.utc)
            refreshed = lifecycle.evaluate_lifecycle(mem, self._config.decay_threshold_hours)
            if refreshed == mem.lifecycle_state:
                continue
            self._store.update(lifecycle.transition(mem, refreshed))

        # Hooks get their own list so they cannot reshape the returned one.
        self._run_hooks(
            "post",
            "recall_user_profile",
            user_id=user_id,
            results=selected[:],
            namespace=namespace,
        )
        self._track_operation(
            "recall_user_profile",
            user_id,
            {"results_count": len(selected)},
            namespace=namespace,
        )
        return selected

1395 

def recall_session_context(
    self,
    user_id: str,
    session_id: str,
    *,
    top_k: int = 20,
    namespace: str = "default",
) -> list[MemoryObject]:
    """Recall recent episodic memories scoped to a specific session.

    Loads the user's ACTIVE, DECAYING and ARCHIVED memories for *session_id*
    in *namespace* directly from the store (no query, no semantic ranking).
    It keeps only EPISODIC ones and returns the ``top_k`` most recently
    created. Memories without ``created_at`` sort last.

    Each returned memory gets ``last_accessed_at`` set to now and its
    lifecycle state re-evaluated. The memory is written back to the store
    only when that evaluation changes its state.

    Roughly comparable to the following call, but without query-based
    ranking:

    .. code-block:: python

        memories = memory.recall(
            user_id="alice",
            query="session context",
            top_k=20,
            lifecycle_filter=[ACTIVE, DECAYING, ARCHIVED],
            session_id="sess_123",
        )

    Args:
        user_id: User whose session to retrieve.
        session_id: Session identifier.
        top_k: Maximum number of session memories to return.
        namespace: Memory namespace.

    Returns:
        List of MemoryObjects sorted by created_at descending.

    Raises:
        ValueError: If ``user_id`` or ``session_id`` is empty, or ``top_k < 1``.
    """
    if not user_id or not user_id.strip():
        raise ValueError("user_id cannot be empty")
    if not session_id or not session_id.strip():
        raise ValueError("session_id cannot be empty")
    if top_k < 1:
        raise ValueError(f"top_k must be at least 1, got {top_k}")

    with self._latency_tracker("recall_session_context"):
        self._run_hooks(
            "pre",
            "recall_session_context",
            user_id=user_id,
            session_id=session_id,
            namespace=namespace,
        )

        all_memories = self._store.get_all_by_user(
            user_id,
            lifecycle_filter=[
                LifecycleState.ACTIVE,
                LifecycleState.DECAYING,
                LifecycleState.ARCHIVED,
            ],
            namespace=namespace,
            session_id=session_id,
        )

        session_memories = [m for m in all_memories if m.memory_type == MemoryType.EPISODIC]
        # Fallback to an aware minimum so memories lacking created_at sort last.
        oldest = datetime.min.replace(tzinfo=timezone.utc)
        session_memories.sort(key=lambda m: m.created_at or oldest, reverse=True)
        top = session_memories[:top_k]

        for mem in top:
            mem.last_accessed_at = datetime.now(timezone.utc)
            new_state = lifecycle.evaluate_lifecycle(mem, self._config.decay_threshold_hours)
            if new_state != mem.lifecycle_state:
                self._store.update(lifecycle.transition(mem, new_state))

        # Hand hooks their own copy so they cannot alter the returned list.
        self._run_hooks(
            "post",
            "recall_session_context",
            user_id=user_id,
            session_id=session_id,
            results=list(top),
            namespace=namespace,
        )
        self._track_operation(
            "recall_session_context",
            user_id,
            {"results_count": len(top), "session_id": session_id},
            namespace=namespace,
        )
        return top

1479 

def recall_agent_knowledge(
    self,
    agent_id: str,
    *,
    namespace: str = "default",
    top_k: int = 50,
) -> list[MemoryObject]:
    """Return the most important memories recorded for one agent.

    Walks every user reported by the store's ``get_all_users`` and loads
    their ACTIVE, DECAYING and ARCHIVED memories in *namespace*. It keeps
    the memories whose ``agent_id`` matches and returns the ``top_k`` with
    the highest importance. Each returned memory gets ``last_accessed_at``
    set to now and its lifecycle state re-evaluated. The memory is persisted
    only when that evaluation changes its state.

    NOTE(review): every user's memories in the namespace are loaded, so the
    cost grows with total store size.

    Args:
        agent_id: Agent identifier to filter by.
        namespace: Memory namespace.
        top_k: Maximum number of memories to return.

    Returns:
        List of MemoryObjects sorted by importance descending.

    Raises:
        ValueError: If ``agent_id`` is empty or ``top_k < 1``.
    """
    if not agent_id or not agent_id.strip():
        raise ValueError("agent_id cannot be empty")
    if top_k < 1:
        raise ValueError(f"top_k must be at least 1, got {top_k}")

    with self._latency_tracker("recall_agent_knowledge"):
        self._run_hooks(
            "pre", "recall_agent_knowledge", agent_id=agent_id, namespace=namespace
        )

        matches = [
            mem
            for uid in self._store.get_all_users()
            for mem in self._store.get_all_by_user(
                uid,
                lifecycle_filter=[
                    LifecycleState.ACTIVE,
                    LifecycleState.DECAYING,
                    LifecycleState.ARCHIVED,
                ],
                namespace=namespace,
            )
            if mem.agent_id == agent_id
        ]
        matches.sort(key=lambda mem: mem.importance, reverse=True)
        selected = matches[:top_k]

        for mem in selected:
            mem.last_accessed_at = datetime.now(timezone.utc)
            refreshed = lifecycle.evaluate_lifecycle(mem, self._config.decay_threshold_hours)
            if refreshed == mem.lifecycle_state:
                continue
            self._store.update(lifecycle.transition(mem, refreshed))

        # Hooks receive their own list so they cannot reshape the returned one.
        self._run_hooks(
            "post",
            "recall_agent_knowledge",
            agent_id=agent_id,
            results=selected[:],
            namespace=namespace,
        )
        self._track_operation(
            "recall_agent_knowledge",
            "all",
            {"results_count": len(selected), "agent_id": agent_id},
            namespace=namespace,
        )
        return selected

1549 

def recall_explain(
    self,
    user_id: str,
    query: str,
    top_k: int = 5,
    namespace: str = "default",
    session_id: str | None = None,
) -> list[dict[str, Any]]:
    """Recall memories with detailed score breakdowns.

    Searches ``top_k * 3`` candidates, rescores each one here, and returns
    the best ``top_k``. Each entry's ``explanation`` dict contains:

    - semantic_score: cosine similarity mapped from [-1, 1] to [0, 1]
    - recency_score: temporal recency of ``last_accessed_at``
    - bm25_score: keyword match score (hybrid search only, else None)
    - importance_score: clamped importance (non-hybrid only, else None)
    - entity_score: Jaccard overlap of query and memory entities
      (0.0 when entity boost is disabled)
    - final_score: weighted sum of the above, rounded to 4 decimals
    - weights: the weights used for that sum

    Ranking uses the unrounded final score, so near-ties are ordered by
    their true score rather than by search order. Each returned memory's
    ``score`` attribute is set to that unrounded value.

    Args:
        user_id: User ID.
        query: Search query.
        top_k: Max results.
        namespace: Memory namespace.
        session_id: Filter by session.

    Returns:
        List of dicts with 'memory' (MemoryObject) and 'explanation'.

    Raises:
        ValueError: If ``user_id`` or ``query`` is empty, or ``top_k < 1``.
    """
    if not user_id or not user_id.strip():
        raise ValueError("user_id cannot be empty")
    if not query or not query.strip():
        raise ValueError("query cannot be empty")
    if top_k < 1:
        raise ValueError(f"top_k must be at least 1, got {top_k}")

    config = self._config
    query_embedding = self._embed.embed_single(query)

    search_results = self._store.search(
        user_id=user_id,
        query_embedding=query_embedding,
        top_k=top_k * 3,
        lifecycle_filter=lifecycle.get_recall_filter(),
        namespace=namespace,
        session_id=session_id,
    )

    # Corpus-level BM25 needs at least two documents; otherwise the
    # single-document scorer is used below.
    corpus = [m.content for m in search_results] if len(search_results) > 1 else None

    query_entities: set[str] | None = None
    memory_entities: dict[str, set[str]] | None = None
    if config.enable_entity_boost:
        query_entities = self._entity_linker.extract(query)
        memory_entities = {}
        for m in search_results:
            # Reuse entities cached in metadata instead of re-extracting.
            cached = m.metadata.get("extracted_entities")
            memory_entities[m.memory_id] = (
                set(cached) if cached is not None else self._entity_linker.extract(m.content)
            )

    scored: list[tuple[float, dict[str, Any]]] = []
    for memory in search_results:
        semantic = scoring.cosine_similarity(memory.embedding, query_embedding)
        semantic_norm = (semantic + 1.0) / 2.0
        recency = scoring.temporal_recency(memory.last_accessed_at)

        entity_score = 0.0
        if (
            config.enable_entity_boost
            and query_entities is not None
            and memory_entities is not None
        ):
            mem_entities = memory_entities.get(memory.memory_id, set())
            entity_score = scoring.jaccard_similarity(query_entities, mem_entities)

        if config.hybrid_search:
            if corpus is not None:
                bm25 = scoring.bm25_score_corpus(query, memory.content, corpus)
            else:
                bm25 = scoring.bm25_score(query, memory.content)
            final = (
                semantic_norm * config.weight_semantic
                + recency * config.weight_recency
                + bm25 * config.weight_bm25
                + entity_score * config.entity_boost_weight
            )
            explanation = {
                "semantic_score": round(semantic_norm, 4),
                "recency_score": round(recency, 4),
                "bm25_score": round(bm25, 4),
                "importance_score": None,
                "entity_score": round(entity_score, 4),
                "final_score": round(final, 4),
                "weights": {
                    "semantic": config.weight_semantic,
                    "recency": config.weight_recency,
                    "bm25": config.weight_bm25,
                    "entity": config.entity_boost_weight,
                },
            }
        else:
            importance = max(0.0, min(1.0, memory.importance))
            final = (
                semantic_norm * config.weight_semantic_no_embed
                + recency * config.weight_recency_no_embed
                + importance * config.weight_importance
                + entity_score * config.entity_boost_weight
            )
            explanation = {
                "semantic_score": round(semantic_norm, 4),
                "recency_score": round(recency, 4),
                "bm25_score": None,
                "importance_score": round(importance, 4),
                "entity_score": round(entity_score, 4),
                "final_score": round(final, 4),
                "weights": {
                    "semantic": config.weight_semantic_no_embed,
                    "recency": config.weight_recency_no_embed,
                    "importance": config.weight_importance,
                    "entity": config.entity_boost_weight,
                },
            }

        memory.score = final
        scored.append((final, {"memory": memory, "explanation": explanation}))

    # Rank on the unrounded score; the rounded explanation value is for
    # display and would turn near-ties into arbitrary order.
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [entry for _, entry in scored[:top_k]]

1677 

def consolidate(
    self,
    user_id: str,
    namespace: str = "default",
    min_memories: int = 5,
    max_age_days: float = 30.0,
    with_llm_summary: bool = False,
) -> str | None:
    """Merge old episodic memories into a semantic summary memory.

    Delegates to ``kemi.consolidation.consolidate``, passing the store, the
    embedding adapter, and the summarizer settings from ``MemoryConfig``.
    By default the summary is extractive and local. With
    ``with_llm_summary=True`` it is abstractive via the configured LLM
    provider. When a summary is produced, a ``CONSOLIDATED`` webhook event
    is dispatched for it.

    Args:
        user_id: User to consolidate.
        namespace: Memory namespace.
        min_memories: Minimum memories needed to form a cluster.
        max_age_days: Only consider memories older than this.
        with_llm_summary: If True, use LLM-powered abstractive summary.

    Returns:
        Memory ID of the consolidated summary, or None when nothing was
        consolidated or the consolidation module cannot be imported.
    """
    try:
        from kemi import consolidation
    except ImportError:  # pragma: no cover
        logger.warning("consolidation module not available")
        return None

    cfg = self._config
    summary_id = consolidation.consolidate(
        store=self._store,
        embed=self._embed,
        user_id=user_id,
        namespace=namespace,
        min_memories=min_memories,
        max_age_days=max_age_days,
        with_llm_summary=with_llm_summary,
        summarizer_llm_provider=cfg.summarizer_llm_provider,
        summarizer_llm_model=cfg.summarizer_llm_model,
        summarizer_prompt_template=cfg.summarizer_prompt_template,
    )
    if summary_id is None:
        return None

    self._dispatch_webhook_event(
        WebhookEventType.CONSOLIDATED,
        memory_id=summary_id,
        user_id=user_id,
    )
    return summary_id

1730 

def cluster_topics(
    self,
    user_id: str,
    n_clusters: int = 3,
    namespace: str = "default",
) -> dict[str, list[MemoryObject]]:
    """Group a user's memories into topic clusters using their embeddings.

    Delegates to ``kemi.topics.cluster_memories``. Requires scikit-learn to
    be installed.

    Args:
        user_id: User ID.
        n_clusters: Number of topic clusters.
        namespace: Memory namespace.

    Returns:
        Dict mapping topic labels to lists of memories, or an empty dict
        when the topics module cannot be imported.
    """
    try:
        from kemi import topics
    except ImportError:  # pragma: no cover
        logger.warning("topics module not available")
        return {}

    clusters = topics.cluster_memories(
        store=self._store,
        user_id=user_id,
        n_clusters=n_clusters,
        namespace=namespace,
    )
    return clusters

1761 

def extract_entities(self, memory_id: str) -> list[dict[str, Any]]:
    """Return the named entities found in one memory's content.

    Delegates to ``kemi.graph.extract_entities`` (regex/heuristic based, no
    external NER model).

    Args:
        memory_id: Memory ID.

    Returns:
        List of entity dicts with 'text', 'label', 'start', 'end'; an empty
        list when the graph module cannot be imported.

    Raises:
        ValueError: If no memory with ``memory_id`` exists.
    """
    try:
        from kemi import graph
    except ImportError:  # pragma: no cover
        logger.warning("graph module not available")
        return []

    target = self._store.get(memory_id)
    if target is not None:
        return graph.extract_entities(target.content)
    raise ValueError(f"Memory not found: {memory_id}")

1784 

def get_memory_graph(
    self,
    user_id: str,
    namespace: str = "default",
) -> dict[str, Any]:
    """Build an entity/relation graph from a user's memories.

    Delegates to ``kemi.graph.build_memory_graph``.

    Args:
        user_id: User ID.
        namespace: Memory namespace.

    Returns:
        Dict with 'entities' (list) and 'relations' (list of dicts). Both
        lists are empty when the graph module cannot be imported.
    """
    try:
        from kemi import graph
    except ImportError:  # pragma: no cover
        logger.warning("graph module not available")
        empty_graph: dict[str, Any] = {"entities": [], "relations": []}
        return empty_graph

    memory_graph = graph.build_memory_graph(
        store=self._store,
        user_id=user_id,
        namespace=namespace,
    )
    return memory_graph

1810 

def stats(
    self,
    user_id: str,
    lifecycle_filter: list[LifecycleState] | None = None,
    session_id: str | None = None,
    namespace: str | None = None,
) -> dict[str, Any]:
    """Return health statistics for a user's memory store.

    Args:
        user_id: The user whose memories to analyze.
        lifecycle_filter: Optional list of lifecycle states to filter by.
        session_id: Optional session ID to filter by.
        namespace: Optional namespace to analyze. When None, no namespace
            is passed, so the storage adapter's own default applies (the
            "default" namespace for the base adapter).

    Returns a dict with these keys:
        total: int - total number of memories
        by_lifecycle: dict - count per lifecycle state
            e.g. {"active": 10, "decaying": 3, "archived": 1, "deleted": 0}
        by_source: dict - count per memory source
            e.g. {"user_stated": 8, "agent_inferred": 5}
        avg_importance: float - average importance score (0.0 if no memories)
        tag_counts: dict - how many memories each tag appears in (a tag
            repeated within one memory is counted once for it)
            e.g. {"food": 3, "work": 7}
        total_with_tags: int - number of memories that have at least one tag
        total_without_tags: int - number of memories with no tags

    Raises:
        ValueError: If ``user_id`` is empty or whitespace-only.
    """
    if not user_id or not user_id.strip():
        raise ValueError("user_id cannot be empty")

    query_kwargs: dict[str, Any] = {
        "lifecycle_filter": lifecycle_filter,
        "session_id": session_id,
    }
    # Only pass namespace when requested so the adapter default still applies.
    if namespace is not None:
        query_kwargs["namespace"] = namespace
    all_memories = self._store.get_all_by_user(user_id, **query_kwargs)

    by_lifecycle = {state.value: 0 for state in LifecycleState}
    by_source = {source.value: 0 for source in MemorySource}
    tag_counts: dict[str, int] = {}
    total_with_tags = 0
    total_importance = 0.0

    for mem in all_memories:
        by_lifecycle[mem.lifecycle_state.value] += 1
        by_source[mem.source.value] += 1
        total_importance += mem.importance

        if mem.tags:
            total_with_tags += 1
            # dict.fromkeys de-duplicates while keeping first-seen order, so
            # each tag counts once per memory as documented.
            for tag in dict.fromkeys(mem.tags):
                tag_counts[tag] = tag_counts.get(tag, 0) + 1

    total = len(all_memories)
    avg_importance = total_importance / total if total > 0 else 0.0

    return {
        "total": total,
        "by_lifecycle": by_lifecycle,
        "by_source": by_source,
        "avg_importance": avg_importance,
        "tag_counts": tag_counts,
        "total_with_tags": total_with_tags,
        "total_without_tags": total - total_with_tags,
    }

1872 

async def astats(
    self,
    user_id: str,
    lifecycle_filter: list[LifecycleState] | None = None,
    session_id: str | None = None,
) -> dict[str, Any]:
    """Async version of stats().

    Runs :meth:`stats` in a worker thread via ``asyncio.to_thread``. It
    forwards the same ``lifecycle_filter`` / ``session_id`` options that the
    sync method accepts, which the previous wrapper silently dropped.

    Args:
        user_id: The user whose memories to analyze.
        lifecycle_filter: Optional list of lifecycle states to filter by.
        session_id: Optional session ID to filter by.

    Returns:
        The dict returned by :meth:`stats`.
    """
    import asyncio

    return await asyncio.to_thread(
        self.stats,
        user_id,
        lifecycle_filter=lifecycle_filter,
        session_id=session_id,
    )

1878 

def recall_by_tag(
    self,
    user_id: str,
    tag: str,
    lifecycle_filter: list[LifecycleState] | None = None,
) -> list[MemoryObject]:
    """Return a user's memories that carry *tag*.

    Validates the inputs, then delegates to the store's ``get_by_tag``.

    Args:
        user_id: User ID to search for.
        tag: Tag to filter by.
        lifecycle_filter: Optional lifecycle states to restrict to.

    Returns:
        List of MemoryObjects with the specified tag.

    Raises:
        ValueError: If ``user_id`` or ``tag`` is empty or whitespace-only.
    """
    user_missing = not user_id or not user_id.strip()
    if user_missing:
        raise ValueError("user_id cannot be empty")
    if not (tag and tag.strip()):
        raise ValueError("tag cannot be empty")

    backend = self._store
    return backend.get_by_tag(user_id, tag, lifecycle_filter)

1901 

async def arecall_by_tag(
    self,
    user_id: str,
    tag: str,
    lifecycle_filter: list[LifecycleState] | None = None,
) -> list[MemoryObject]:
    """Async version of recall_by_tag().

    Runs :meth:`recall_by_tag` in a worker thread via ``asyncio.to_thread``
    and returns its result unchanged.
    """
    from asyncio import to_thread

    tagged = await to_thread(self.recall_by_tag, user_id, tag, lifecycle_filter)
    return tagged

1912 

def update(
    self,
    memory_id: str,
    content: str | None = None,
    importance: float | None = None,
    confidence: float | None = None,
    memory_type: MemoryType | None = None,
    metadata: dict[str, Any] | None = None,
    tags: list[str] | None = None,
) -> str:
    """Update an existing memory.

    Only arguments passed as non-None are applied. Changing ``content``
    re-embeds the memory, refreshes ``last_accessed_at`` and, when entity
    boost is enabled, recomputes ``metadata["extracted_entities"]``. Every
    effective update increments ``memory.version``, and an ``UPDATED``
    webhook is dispatched once the change has been written to the store.

    Args:
        memory_id: ID of memory to update.
        content: New content (if provided, will re-embed).
        importance: New importance value, clamped to [0.0, 1.0].
        confidence: New confidence value, clamped to [0.0, 1.0].
        memory_type: New memory type.
        metadata: Metadata dict to merge into existing metadata.
        tags: New tags to replace existing tags (the list is copied).

    Returns:
        The memory_id of updated memory.

    Raises:
        ValueError: If memory_id not found.
    """
    requested = (content, importance, confidence, memory_type, metadata, tags)
    if all(value is None for value in requested):
        # Nothing to change: skip the lookup, hooks, version bump and webhook.
        return memory_id

    with self._latency_tracker("update"):
        self._run_hooks("pre", "update", memory_id=memory_id)
        memory = self._store.get(memory_id)
        if memory is None:
            raise ValueError(f"Memory not found: {memory_id}")

        # Capture pre-update state BEFORE mutating the memory object.
        # record_before_update expects two separate objects (before, after).
        memory_before = MemoryObject(
            memory_id=memory.memory_id,
            user_id=memory.user_id,
            content=memory.content,
            embedding=memory.embedding,
            score=memory.score,
            created_at=memory.created_at,
            last_accessed_at=memory.last_accessed_at,
            source=memory.source,
            importance=memory.importance,
            lifecycle_state=memory.lifecycle_state,
            metadata=memory.metadata.copy() if memory.metadata else {},
            embedding_dim=memory.embedding_dim,
            tags=list(memory.tags) if memory.tags else [],
            confidence=memory.confidence,
            memory_type=memory.memory_type,
            session_id=memory.session_id,
            namespace=memory.namespace,
            version=memory.version,
            agent_id=memory.agent_id,
            run_id=memory.run_id,
            app_id=memory.app_id,
        )
        # The webhook's previous_state must describe the memory as it was,
        # not the object after the mutations below.
        previous_state = _memory_to_dict(memory_before)

        # Now apply all mutations
        if content is not None:
            memory.content = content
            memory.embedding = self._embed.embed_single(content)
            memory.embedding_dim = len(memory.embedding)
            memory.last_accessed_at = datetime.now(timezone.utc)
            if self._config.enable_entity_boost:
                memory.metadata["extracted_entities"] = list(
                    self._entity_linker.extract(content)
                )

        if importance is not None:
            memory.importance = max(0.0, min(1.0, importance))

        if confidence is not None:
            memory.confidence = max(0.0, min(1.0, confidence))

        if memory_type is not None:
            memory.memory_type = memory_type

        if metadata is not None:
            memory.metadata.update(metadata)

        if tags is not None:
            # Copy so later mutation of the caller's list cannot leak in.
            memory.tags = list(tags)

        # Record version BEFORE and AFTER the update (pre + post snapshot)
        try:
            vs = self._get_version_store()
            vs.record_before_update(memory_before, memory, changed_by="update")
            self._auto_prune_versions_for_memory(memory_id)
        except Exception:
            # Versioning is optional, but leave a trace instead of hiding it.
            logger.debug(
                "Version history not recorded for memory %s", memory_id, exc_info=True
            )

        memory.version += 1
        self._store.update(memory)

        # Announce the update only after it has been persisted.
        self._dispatch_webhook_event(
            WebhookEventType.UPDATED,
            memory_id=memory_id,
            user_id=memory.user_id,
            snapshot=_memory_to_dict(memory),
            previous_state=previous_state,
        )

        self._run_hooks("post", "update", memory_id=memory_id, version=memory.version)
        logger.info(f"Updated memory: {memory_id} (version {memory.version})")
        self._track_operation(
            "update",
            memory.user_id,
            {"memory_id": memory_id, "version": memory.version},
            memory_id,
            memory.namespace,
        )
    return memory_id

2039 

def recall_since(
    self,
    user_id: str,
    query: str,
    hours: float = 24.0,
    top_k: int = 5,
    max_tokens: int | None = None,
    lifecycle_filter: list[LifecycleState] | None = None,
) -> list[MemoryObject]:
    """Recall memories created within the last ``hours`` hours.

    Over-fetches ``3 * top_k`` candidates from :meth:`recall`, keeps those
    whose ``created_at`` is set and not older than the cutoff, and returns
    at most ``top_k`` of them in recall order.

    Args:
        user_id: User ID to search for.
        query: Search query.
        hours: Only return memories created in last N hours.
        top_k: Maximum memories to return.
        max_tokens: Token budget for context_block.
        lifecycle_filter: Filter by lifecycle state.

    Returns:
        List of MemoryObjects.
    """
    window_start = datetime.now(timezone.utc) - timedelta(hours=hours)

    candidates = self.recall(
        user_id=user_id,
        query=query,
        top_k=top_k * 3,
        max_tokens=max_tokens,
        lifecycle_filter=lifecycle_filter,
    )

    recent = []
    for candidate in candidates:
        stamp = candidate.created_at
        if stamp and stamp >= window_start:
            recent.append(candidate)
    return recent[:top_k]

2076 

async def alist_users(self) -> list[str]:
    """Asynchronous counterpart of :meth:`list_users`, run on a worker thread."""
    import asyncio

    users = await asyncio.to_thread(self.list_users)
    return users

2082 

async def aupdate(
    self,
    memory_id: str,
    content: str | None = None,
    importance: float | None = None,
    confidence: float | None = None,
    memory_type: MemoryType | None = None,
    metadata: dict[str, Any] | None = None,
    tags: list[str] | None = None,
) -> str:
    """Async version of update().

    Runs the blocking :meth:`update` on a worker thread. Accepts the same
    arguments as :meth:`update`, including ``tags`` (previously missing here,
    so async callers could not change tags).

    Returns:
        The memory_id of the updated memory.

    Raises:
        ValueError: Propagated from :meth:`update` if memory_id is not found.
    """
    import asyncio

    return await asyncio.to_thread(
        self.update,
        memory_id,
        content,
        importance,
        confidence,
        memory_type,
        metadata,
        tags=tags,
    )

2104 

async def aupdate_many(
    self,
    memory_ids: list[str],
    content: str | None = None,
    importance: float | None = None,
    confidence: float | None = None,
    memory_type: MemoryType | None = None,
    metadata: dict[str, Any] | None = None,
) -> list[str]:
    """Apply the same update to several memories from async code.

    Batches of more than 10 ids are handed to :meth:`update_many` on a single
    worker thread to avoid excessive thread overhead. Smaller batches run one
    :meth:`aupdate` per id concurrently via ``asyncio.gather``.

    Returns:
        The updated memory ids (empty list for an empty input).
    """
    import asyncio

    if not memory_ids:
        return []

    if len(memory_ids) > 10:
        return await asyncio.to_thread(
            self.update_many,
            memory_ids,
            content,
            importance,
            confidence,
            memory_type,
            metadata,
        )

    pending = (
        self.aupdate(
            each_id,
            content=content,
            importance=importance,
            confidence=confidence,
            memory_type=memory_type,
            metadata=metadata,
        )
        for each_id in memory_ids
    )
    return await asyncio.gather(*pending)

2148 

async def aforget_many(
    self,
    memory_ids: list[str],
) -> int:
    """Delete several memories concurrently and return how many were removed.

    Each id is passed to ``self._store.delete_by_id`` on its own worker
    thread; an id counts as removed when that call returns a truthy value.

    NOTE(review): although described as the async version of
    ``forget_many()``, this calls the storage adapter directly, so any hooks,
    webhooks, version records or metrics that ``forget_many()`` may emit are
    not produced here. Confirm this is intended.
    """
    import asyncio

    if not memory_ids:
        return 0

    outcomes = await asyncio.gather(
        *(asyncio.to_thread(self._store.delete_by_id, each_id) for each_id in memory_ids)
    )
    return sum(bool(outcome) for outcome in outcomes)

2162 

2163 async def arecall_many( 

2164 self, 

2165 user_ids: list[str], 

2166 queries: list[str], 

2167 top_k: int = 5, 

2168 max_tokens: int | None = None, 

2169 lifecycle_filter: list[LifecycleState] | None = None, 

2170 hybrid_search: bool | None = None, 

2171 namespace: str = "default", 

2172 session_id: str | None = None, 

2173 metadata_filter: dict[str, Any] | None = None, 

2174 ) -> dict[str, list[MemoryObject]]: 

2175 """Async version of recall_many() — runs individual recalls concurrently.""" 

2176 import asyncio 

2177 

2178 if len(user_ids) != len(queries): 

2179 raise ValueError("user_ids and queries must have the same length") 

2180 

2181 tasks = [ 

2182 self.arecall( 

2183 uid, 

2184 q, 

2185 top_k=top_k, 

2186 max_tokens=max_tokens, 

2187 lifecycle_filter=lifecycle_filter, 

2188 hybrid_search=hybrid_search, 

2189 namespace=namespace, 

2190 session_id=session_id, 

2191 metadata_filter=metadata_filter, 

2192 ) 

2193 for uid, q in zip(user_ids, queries, strict=True) 

2194 ] 

2195 results = await asyncio.gather(*tasks) 

2196 return {uid: res for uid, res in zip(user_ids, results, strict=True)} 

2197 

async def arecall_since(
    self,
    user_id: str,
    query: str,
    hours: float = 24.0,
    top_k: int = 5,
    max_tokens: int | None = None,
    lifecycle_filter: list[LifecycleState] | None = None,
) -> list[MemoryObject]:
    """Asynchronous counterpart of :meth:`recall_since`.

    All arguments are forwarded positionally to the blocking method, which
    runs on a worker thread via ``asyncio.to_thread``.
    """
    import asyncio

    forwarded = (user_id, query, hours, top_k, max_tokens, lifecycle_filter)
    return await asyncio.to_thread(self.recall_since, *forwarded)

2213 

async def aremember_many(
    self,
    user_id: str,
    contents: list[str],
    importance: float = 0.5,
    source: MemorySource = MemorySource.USER_STATED,
    tags: list[str] | None = None,
    namespace: str = "default",
    session_id: str | None = None,
    memory_type: MemoryType = MemoryType.EPISODIC,
    confidence: float = 1.0,
    agent_id: str | None = None,
    run_id: str | None = None,
    app_id: str | None = None,
) -> list[str]:
    """Asynchronous counterpart of :meth:`remember_many`.

    Every argument is forwarded positionally, in declaration order, to the
    blocking method, which runs on a worker thread.
    """
    import asyncio

    positional = (
        user_id,
        contents,
        importance,
        source,
        tags,
        namespace,
        session_id,
        memory_type,
        confidence,
        agent_id,
        run_id,
        app_id,
    )
    return await asyncio.to_thread(self.remember_many, *positional)

2247 

def feedback(
    self,
    user_id: str,
    memory_id: str,
    helpful: bool = True,
    namespace: str = "default",
) -> None:
    """Record whether a recalled memory was helpful and nudge its importance.

    A ``{"helpful", "timestamp"}`` entry is appended to
    ``memory.metadata["feedback"]``. Importance then moves by +0.05 (helpful)
    or -0.05 (not helpful), clamped to [0.0, 1.0], and the memory is saved.

    Args:
        user_id: User ID; must own the memory.
        memory_id: Memory ID that was recalled.
        helpful: Whether the memory was helpful.
        namespace: Memory namespace; must match the memory's namespace.

    Raises:
        ValueError: On empty ids, an unknown memory, or a user/namespace mismatch.
    """
    if not user_id or not user_id.strip():
        raise ValueError("user_id cannot be empty")
    if not memory_id or not memory_id.strip():
        raise ValueError("memory_id cannot be empty")

    target = self._store.get(memory_id)
    if target is None:
        raise ValueError(f"Memory not found: {memory_id}")
    if target.user_id != user_id:
        raise ValueError("Memory does not belong to this user")
    if target.namespace != namespace:
        raise ValueError("Memory does not belong to this namespace")

    stamp = datetime.now(timezone.utc).isoformat()
    history = target.metadata.setdefault("feedback", [])
    history.append({"helpful": helpful, "timestamp": stamp})

    delta = 0.05 if helpful else -0.05
    target.importance = max(0.0, min(1.0, target.importance + delta))

    self._store.update(target)
    logger.info(
        f"Feedback recorded for memory {memory_id}: helpful={helpful}, "
        f"new_importance={target.importance:.2f}"
    )
    self._track_operation(
        "feedback", user_id, {"memory_id": memory_id, "helpful": helpful}, memory_id, namespace
    )

2301 

def backfill_entities(
    self,
    user_id: str | None = None,
    namespace: str | None = None,
) -> int:
    """Populate missing ``metadata["extracted_entities"]`` on stored memories.

    Returns 0 immediately (after logging) when entity boost is disabled.
    Otherwise, for each selected user and namespace, it loads ACTIVE,
    DECAYING and ARCHIVED memories. Every memory that lacks the
    ``extracted_entities`` key gets the entity linker's output written there
    and is saved. Memories that already have the key are left untouched.

    Args:
        user_id: Restrict to this user; ``None`` means every user reported
            by the store.
        namespace: Restrict to this namespace; ``None`` means every namespace
            known for each user.

    Returns:
        Number of memories that were backfilled.
    """
    if not self._config.enable_entity_boost:
        logger.info("Entity boost is disabled; skipping entity backfill")
        return 0

    with self._latency_tracker("backfill_entities"):
        target_users = self._store.get_all_users() if user_id is None else [user_id]

        count = 0
        for uid in target_users:
            target_namespaces = (
                self._known_namespaces(uid) if namespace is None else [namespace]
            )
            for ns in target_namespaces:
                candidates = self._store.get_all_by_user(
                    uid,
                    lifecycle_filter=[
                        LifecycleState.ACTIVE,
                        LifecycleState.DECAYING,
                        LifecycleState.ARCHIVED,
                    ],
                    namespace=ns,
                )
                for mem in candidates:
                    if "extracted_entities" in mem.metadata:
                        continue  # already populated; never overwrite
                    mem.metadata["extracted_entities"] = list(
                        self._entity_linker.extract(mem.content)
                    )
                    self._store.update(mem)
                    count += 1

        if count > 0:
            logger.info(
                f"Backfilled extracted_entities for {count} memories"
                + (f" (user={user_id})" if user_id else "")
            )
        self._track_operation(
            "backfill_entities",
            user_id or "all",
            {"backfilled": count},
            namespace=namespace or "all",
        )
    return count

2374 

async def abackfill_entities(
    self,
    user_id: str | None = None,
    namespace: str | None = None,
) -> int:
    """Asynchronous counterpart of :meth:`backfill_entities`, run on a worker thread."""
    import asyncio

    backfilled = await asyncio.to_thread(self.backfill_entities, user_id, namespace)
    return backfilled

2386 

def run_maintenance(
    self,
    user_id: str,
    namespace: str = "default",
    auto_prune: bool = True,
    auto_consolidate: bool = True,
    auto_backfill_entities: bool = False,
    prune_max_age_days: float = 90.0,
    prune_min_importance: float = 0.1,
    consolidate_min_memories: int = 5,
    consolidate_max_age_days: float = 30.0,
    auto_prune_expired: bool = True,
    consolidate_with_llm_summary: bool = False,
) -> dict[str, Any]:
    """Run a one-shot maintenance pass over one user's memories in a namespace.

    Enabled steps run in this order: entity backfill, prune, expired-TTL
    prune, consolidation. Schedule calls externally (cron, APScheduler, ...)
    for periodic maintenance.

    Args:
        user_id: User ID to maintain; must be non-blank.
        namespace: Memory namespace.
        auto_prune: Whether to prune old memories.
        auto_consolidate: Whether to consolidate old episodic memories.
        auto_backfill_entities: Whether to backfill missing
            ``extracted_entities`` metadata.
        prune_max_age_days: Delete DECAYING memories older than this.
        prune_min_importance: Delete memories below this importance.
        consolidate_min_memories: Minimum memories to consolidate.
        consolidate_max_age_days: Only consolidate memories older than this.
        auto_prune_expired: Whether to delete TTL-expired memories.
        consolidate_with_llm_summary: Use LLM-powered summarization.

    Returns:
        Dict with 'pruned' (int), 'expired' (int), 'consolidated'
        (str | None) and 'backfilled' (int); skipped steps keep their defaults.

    Raises:
        ValueError: If user_id is empty or whitespace.
    """
    if not user_id or not user_id.strip():
        raise ValueError("user_id cannot be empty")

    with self._latency_tracker("run_maintenance"):
        result: dict[str, Any] = dict(pruned=0, expired=0, consolidated=None, backfilled=0)

        if auto_backfill_entities:
            result["backfilled"] = self.backfill_entities(
                user_id=user_id, namespace=namespace
            )

        if auto_prune:
            result["pruned"] = self.prune(
                user_id=user_id,
                max_age_days=prune_max_age_days,
                min_importance=prune_min_importance,
                namespace=namespace,
            )

        if auto_prune_expired:
            result["expired"] = self.prune_expired(user_id=user_id, namespace=namespace)

        if auto_consolidate:
            result["consolidated"] = self.consolidate(
                user_id=user_id,
                namespace=namespace,
                min_memories=consolidate_min_memories,
                max_age_days=consolidate_max_age_days,
                with_llm_summary=consolidate_with_llm_summary,
            )

        logger.info(f"Maintenance complete for {user_id}: {result}")
        self._track_operation("run_maintenance", user_id, result, namespace=namespace)
    return result

2469 

# NOTE: get_metrics() is defined once, further down in this class, where it
# delegates to kemi.operations._ops_metrics.get_metrics. An earlier duplicate
# definition here was always replaced by that later one, so it never ran.

2475 

# NOTE: get_metrics_prometheus() is defined once, further down in this class,
# where it delegates to kemi.operations._ops_metrics.get_metrics_prometheus.
# An earlier duplicate definition here was always replaced, so it never ran.

2481 

# NOTE: enable_adaptive_retrieval() is defined once, further down in this
# class, where it delegates to
# kemi.operations._ops_metrics.enable_adaptive_retrieval. An earlier duplicate
# definition here was always replaced by that later one, so it never ran.

2493 

def _track_operation(
    self,
    operation: str,
    user_id: str,
    details: dict[str, Any] | None = None,
    memory_id: str | None = None,
    namespace: str = "default",
    status: str = "success",
    audit_batch: list[dict[str, Any]] | None = None,
) -> None:
    """Forward an operation record to ``_ops_metrics.track_operation_full``.

    All arguments are passed through positionally; that helper is responsible
    for updating metrics and the audit trail.
    """
    from kemi.operations import _ops_metrics as metrics_ops

    forwarded = (operation, user_id, details, memory_id, namespace, status, audit_batch)
    metrics_ops.track_operation_full(self, *forwarded)

2509 

def _record_embed_error(self) -> None:
    """Count an embedding failure via ``_ops_metrics.record_embed_error``."""
    from kemi.operations import _ops_metrics as metrics_ops

    metrics_ops.record_embed_error(self)

2514 

def _record_store_error(self) -> None:
    """Count a storage failure via ``_ops_metrics.record_store_error``."""
    from kemi.operations import _ops_metrics as metrics_ops

    metrics_ops.record_store_error(self)

2519 

def add_event_hook(self, phase: str, callback: Callable[..., Any]) -> None:
    """Register ``callback`` to run in the given hook ``phase``.

    Args:
        phase: "pre" or "post" — called before or after the operation.
        callback: Callable that receives (operation, **kwargs).
    """
    from kemi.operations import _ops_hooks as hooks_ops

    hooks_ops.add(self, phase, callback)

2529 

def remove_event_hook(self, phase: str, callback: Callable[..., Any]) -> bool:
    """Unregister a hook callback previously added for ``phase``.

    Returns:
        True if the callback was removed, False if it was not registered.
    """
    from kemi.operations import _ops_hooks as hooks_ops

    removed = hooks_ops.remove(self, phase, callback)
    return removed

2537 

def _run_hooks(
    self,
    phase: str,
    operation: str,
    *,
    raise_on_error: bool | None = None,
    **kwargs: Any,
) -> None:
    """Invoke every hook registered for ``phase`` around ``operation``.

    Args:
        phase: "pre" or "post".
        operation: Name of the operation triggering the hook.
        raise_on_error: True re-raises hook exceptions so that a failing
            pre-hook can abort the operation. None (the default) defers to
            ``self._config.hooks_raise_on_error``.
        **kwargs: Passed through to each callback.
    """
    from kemi.operations import _ops_hooks as hooks_ops

    hooks_ops.run(
        self,
        phase,
        operation,
        raise_on_error=raise_on_error,
        **kwargs,
    )

2558 

def enable_query_cache(self, max_size: int = 128) -> None:
    """Turn on an LRU cache for recall() results.

    Args:
        max_size: Upper bound on the number of cached query results.
    """
    from kemi.operations import _ops_metrics as metrics_ops

    metrics_ops.enable_query_cache(self, max_size)

2567 

def disable_query_cache(self) -> None:
    """Turn off the recall() query cache via ``_ops_metrics``."""
    from kemi.operations import _ops_metrics as metrics_ops

    metrics_ops.disable_query_cache(self)

2572 

def configure_versioning(
    self,
    db_path: str | None = None,
    max_versions_per_memory: int = 50,
    auto_prune_versions: bool = True,
) -> None:
    """Turn on memory version-history tracking.

    Args:
        db_path: SQLite database path; defaults to the store's db_path.
        max_versions_per_memory: Versions kept per memory before pruning.
        auto_prune_versions: Prune old versions automatically when the limit
            is exceeded.
    """
    from kemi.operations import _ops_versioning as versioning_ops

    settings = (db_path, max_versions_per_memory, auto_prune_versions)
    versioning_ops.configure(self, *settings)

2590 

def _get_version_store(self) -> MemoryVersionStore:
    """Return the version store provided by ``_ops_versioning.get_store``.

    According to the original documentation, the store is created lazily from
    the storage adapter's database, falling back to an in-memory SQLite
    database when the adapter has no ``_db_path`` (e.g. test mocks).
    """
    from kemi.operations import _ops_versioning as versioning_ops

    store = versioning_ops.get_store(self)
    return store

2600 

def get_history(
    self,
    memory_id: str,
    limit: int = 100,
) -> list["VersionSnapshot"]:
    """Return up to ``limit`` version snapshots of a memory, newest first."""
    from kemi.operations import _ops_versioning as versioning_ops

    snapshots = versioning_ops.get_history(self, memory_id, limit)
    return snapshots

2609 

def diff_versions(
    self,
    memory_id: str,
    from_version: int,
    to_version: int,
) -> "DiffResult | None":
    """Compare two versions of a memory field by field."""
    from kemi.operations import _ops_versioning as versioning_ops

    result = versioning_ops.diff(self, memory_id, from_version, to_version)
    return result

2619 

def rollback_memory(
    self,
    memory_id: str,
    target_version: int,
) -> "RollbackResult | None":
    """Restore a memory to ``target_version`` via ``_ops_versioning.rollback``."""
    from kemi.operations import _ops_versioning as versioning_ops

    outcome = versioning_ops.rollback(self, memory_id, target_version)
    return outcome

2628 

def _auto_prune_versions_for_memory(self, memory_id: str) -> None:
    """Trim old versions of ``memory_id`` via ``_ops_versioning.auto_prune``."""
    from kemi.operations import _ops_versioning as versioning_ops

    versioning_ops.auto_prune(self, memory_id)

2633 

def configure_webhooks(self, db_path: str | None = None) -> None:
    """Turn on webhook dispatch for memory lifecycle events.

    Args:
        db_path: SQLite database path for webhook configuration. Defaults to
            the path used by the storage adapter.
    """
    from kemi.operations import _ops_webhooks as webhook_ops

    webhook_ops.configure(self, db_path)

2643 

def _dispatch_webhook_event(
    self,
    event: WebhookEventType,
    memory_id: str,
    user_id: str,
    snapshot: dict[str, Any] | None = None,
    previous_state: dict[str, Any] | None = None,
    **extra: Any,
) -> None:
    """Hand a lifecycle event to ``_ops_webhooks.dispatch``.

    Whether anything is sent depends on that helper; per the original docs,
    nothing is dispatched unless a dispatcher has been configured.
    """
    from kemi.operations import _ops_webhooks as webhook_ops

    payload = (event, memory_id, user_id, snapshot, previous_state)
    webhook_ops.dispatch(self, *payload, **extra)

2658 

def enable_audit_trail(
    self,
    retention_days: int = 365,
    auto_purge: bool = True,
) -> None:
    """Turn on the compliance audit trail via ``_ops_metrics``.

    Args:
        retention_days: Retention period passed to the helper.
        auto_purge: Whether the helper should purge expired entries.
    """
    from kemi.operations import _ops_metrics as metrics_ops

    metrics_ops.enable_audit_trail(self, retention_days, auto_purge)

2667 

def get_metrics(self) -> dict[str, Any] | None:
    """Return the current metrics snapshot as a dict, or None when disabled."""
    from kemi.operations import _ops_metrics as metrics_ops

    snapshot = metrics_ops.get_metrics(self)
    return snapshot

2672 

def get_metrics_prometheus(self) -> str | None:
    """Return metrics in Prometheus text format, or None when disabled."""
    from kemi.operations import _ops_metrics as metrics_ops

    exposition = metrics_ops.get_metrics_prometheus(self)
    return exposition

2677 

def enable_adaptive_retrieval(self, enable: bool = True) -> None:
    """Switch adaptive retrieval on or off (re-weights hybrid scores per user)."""
    from kemi.operations import _ops_metrics as metrics_ops

    metrics_ops.enable_adaptive_retrieval(self, enable)

2682 

2683 

class _QueryCache:
    """DEPRECATED shim — moved to :mod:`kemi.operations._query_cache`.

    Kept as a re-export so existing imports of ``kemi.core._QueryCache``
    keep working. The canonical location is ``kemi.operations._query_cache._QueryCache``.
    Instantiating this shim returns an object built by the canonical class.
    """

    def __new__(cls, *args: Any, **kwargs: Any) -> Any:  # pragma: no cover
        # Imported lazily so loading this module does not require the
        # operations package up front.
        from kemi.operations._query_cache import _QueryCache as _Impl

        return _Impl(*args, **kwargs)

2697_QueryCache.__doc__ = _QueryCache.__doc__