Coverage for src / kemi / _memory_impl.py: 84%
783 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-05 15:47 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-05 15:47 +0000
1from __future__ import annotations
3import logging
4import os
5import uuid
6from collections.abc import Callable
7from datetime import datetime, timedelta, timezone
8from typing import TYPE_CHECKING, Any
10if TYPE_CHECKING:
11 from kemi.encryption import EncryptionConfig
13from kemi.versions import (
14 DiffResult,
15 MemoryVersionStore,
16 RollbackResult,
17 VersionSnapshot,
18)
19from kemi.webhooks import WebhookDispatcher, WebhookEventType
def _memory_to_dict(memory: "MemoryObject") -> dict[str, Any]:
    """Flatten a ``MemoryObject`` into a plain dict for webhook payloads.

    Anything that is not a ``MemoryObject`` yields an empty dict. Enum-valued
    fields are reduced to their ``.value`` and datetimes to ISO-8601 strings;
    a falsy enum or datetime field is emitted as ``None``.
    """
    if not isinstance(memory, MemoryObject):
        return {}

    def _enum_value(member: Any) -> Any:
        # Enum members carry the serialisable form in ``.value``.
        return member.value if member else None

    def _iso(moment: Any) -> Any:
        return moment.isoformat() if moment else None

    payload: dict[str, Any] = {}
    payload["memory_id"] = memory.memory_id
    payload["content"] = memory.content
    payload["importance"] = memory.importance
    payload["confidence"] = memory.confidence
    payload["lifecycle_state"] = _enum_value(memory.lifecycle_state)
    payload["memory_type"] = _enum_value(memory.memory_type)
    payload["source"] = _enum_value(memory.source)
    payload["tags"] = memory.tags
    payload["namespace"] = memory.namespace
    payload["session_id"] = memory.session_id
    payload["version"] = memory.version
    payload["created_at"] = _iso(memory.created_at)
    payload["last_accessed_at"] = _iso(memory.last_accessed_at)
    payload["metadata"] = memory.metadata
    payload["agent_id"] = memory.agent_id
    payload["run_id"] = memory.run_id
    payload["app_id"] = memory.app_id
    return payload
46from kemi import lifecycle, sanitize, scoring
47from kemi.adapters.base import EmbeddingAdapter, StorageAdapter
48from kemi.entities import EntityLinker, NoopEntityLinker, RegexEntityLinker
49from kemi.models import (
50 LifecycleState,
51 MemoryConfig,
52 MemoryObject,
53 MemorySource,
54 MemoryType,
55)
57logger = logging.getLogger(__name__)
60class Memory:
def __init__(
    self,
    embed: EmbeddingAdapter | None = None,
    store: StorageAdapter | None = None,
    config: MemoryConfig | None = None,
    encryption: "EncryptionConfig | None" = None,
    entity_linker: "EntityLinker | None" = None,
) -> None:
    """Wire up the embedding adapter, storage adapter and optional features.

    Args:
        embed: Embedding adapter. When omitted, ``FastEmbedAdapter`` is
            imported and instantiated.
        store: Storage adapter. When omitted, a SQLite-backed adapter is
            created at ``KEMI_DB_PATH`` (if set) or ``~/.kemi/memories.db``.
        config: Memory configuration; defaults to ``MemoryConfig()``.
        encryption: Encryption settings. When omitted they are read with
            ``EncryptionConfig.from_env()``; a failure there disables
            encryption instead of aborting construction.
        entity_linker: Entity linker. When omitted, ``RegexEntityLinker`` is
            used if ``config.enable_entity_boost`` is set, otherwise
            ``NoopEntityLinker``.

    Raises:
        ImportError: If no ``embed`` is given and the fastembed adapter
            cannot be imported.
    """
    # Lazy import to avoid circular dependencies
    from kemi.encryption import EncryptionConfig

    if encryption is None:
        try:
            encryption = EncryptionConfig.from_env()
        except Exception as exc:
            # Broad catch intentional: EncryptionConfig.from_env() reads
            # env vars and can fail for many reasons (missing key, bad
            # Fernet format, base64 decode errors). Encryption is opt-in;
            # fall back to disabled rather than blocking Memory init.
            # Only the exception type is logged so that no key material can
            # end up in log output.
            logger.warning(
                "Could not load encryption settings from the environment (%s); "
                "continuing with encryption disabled",
                type(exc).__name__,
            )
            encryption = None

    if embed is None:
        try:  # pragma: no cover
            from kemi.adapters.embedding.fastembed import FastEmbedAdapter

            self._embed: EmbeddingAdapter = FastEmbedAdapter()
        except ImportError as e:
            raise ImportError(
                "No embedding adapter provided and fastembed is not installed. "
                "Install with: pip install kemi[local] or provide your own: "
                "Memory(embed=YourAdapter())"
            ) from e
    else:
        self._embed = embed

    if store is None:
        # Honor an explicit KEMI_DB_PATH so background-task workers and
        # ad-hoc Memory() instantiations point at the same database as
        # the main app, rather than silently defaulting to ~/.kemi.
        env_path = os.environ.get("KEMI_DB_PATH")
        if env_path:
            default_db_path = os.path.expanduser(env_path)
        else:
            default_db_path = os.path.join(os.path.expanduser("~"), ".kemi", "memories.db")

        # A bare filename such as "memories.db" has an empty dirname, and
        # os.makedirs("") raises FileNotFoundError, so only create the parent
        # directory when there is one.
        db_dir = os.path.dirname(default_db_path)
        if db_dir:
            os.makedirs(db_dir, exist_ok=True)

        # Only hand an encryption config to the adapter when it is enabled.
        active_encryption = (
            encryption
            if isinstance(encryption, EncryptionConfig) and encryption.enabled
            else None
        )

        try:
            from kemi.adapters.storage.sqlite_vec import SQLiteVecStorageAdapter

            if SQLiteVecStorageAdapter.is_vec_available():
                embedding_dim = self._embed.dimension()
                self._store: StorageAdapter = SQLiteVecStorageAdapter(
                    db_path=default_db_path,
                    embedding_dim=embedding_dim,
                    encryption=active_encryption,
                )
                logger.info(
                    "Using SQLiteVecStorageAdapter with ANN vector search "
                    "(sqlite-vec installed)"
                )
            else:
                from kemi.adapters.storage.sqlite import SQLiteStorageAdapter

                self._store = SQLiteStorageAdapter(
                    db_path=default_db_path,
                    encryption=active_encryption,
                )
        except ImportError:  # pragma: no cover
            from kemi.adapters.storage.sqlite import SQLiteStorageAdapter

            self._store = SQLiteStorageAdapter(
                db_path=default_db_path,
                encryption=active_encryption,
            )  # pragma: no cover
    else:
        self._store = store

    if config is None:
        self._config: MemoryConfig = MemoryConfig()
    else:
        self._config = config

    # Optional observability
    self._metrics: Any | None = None
    try:
        from kemi.observability import get_metrics_collector

        self._metrics = get_metrics_collector()
    except ImportError:
        pass

    # Optional collaborators start unset in this constructor.
    self._audit_trail: Any | None = None
    self._adaptive_retriever: Any | None = None
    self._event_hooks: dict[str, list[Callable[..., Any]]] = {"pre": [], "post": []}
    self._query_cache: _QueryCache | None = None
    self._version_store: MemoryVersionStore | None = None
    self._max_versions_per_memory: int = 50
    self._auto_prune_versions: bool = True
    self._webhook_dispatcher: WebhookDispatcher | None = None

    if entity_linker is not None:
        self._entity_linker: EntityLinker = entity_linker
    elif self._config.enable_entity_boost:
        self._entity_linker = RegexEntityLinker()
    else:
        self._entity_linker = NoopEntityLinker()
def _latency_tracker(self, operation: str) -> Any:
    """Build the latency-tracking context manager for ``operation``.

    Delegates to ``kemi.operations._ops_metrics.latency_tracker``, passing
    this instance and the operation name.
    """
    from kemi.operations import _ops_metrics as ops_metrics

    tracker = ops_metrics.latency_tracker(self, operation)
    return tracker
def remember(
    self,
    user_id: str,
    content: str,
    importance: float = 0.5,
    source: MemorySource = MemorySource.USER_STATED,
    metadata: dict[str, Any] | None = None,
    sanitize_input: bool = False,
    tags: list[str] | None = None,
    namespace: str = "default",
    session_id: str | None = None,
    memory_type: MemoryType = MemoryType.EPISODIC,
    confidence: float = 1.0,
    agent_id: str | None = None,
    run_id: str | None = None,
    app_id: str | None = None,
    ttl_seconds: int | None = None,
) -> str:
    """Embed ``content`` and store it as a new memory for ``user_id``.

    The inputs are validated, "pre" hooks run, the (optionally sanitised)
    content is embedded, and the resulting ``MemoryObject`` is handed to
    :class:`kemi.pipeline.ingestion.IngestionPipeline`. "post" hooks run
    after ingestion.

    Returns:
        The ``memory_id`` of the object returned by the ingestion pipeline.

    Raises:
        ValueError, TypeError: From ``_validate_remember_inputs``.
        Exception: Whatever the embedding adapter raises is re-raised after
            the embed-error counter has been bumped.
    """
    from kemi.pipeline.ingestion import IngestionPipeline

    self._validate_remember_inputs(user_id, content, importance, ttl_seconds)

    # Pre-hooks run before embedding and receive the unsanitised content.
    self._run_hooks("pre", "remember", user_id=user_id, content=content, namespace=namespace)

    with self._latency_tracker("remember"):
        text = (
            sanitize.sanitize(content, strict=self._config.sanitize)
            if sanitize_input
            else content
        )

        try:
            vector = self._embed.embed_single(text)
        except Exception:
            # Adapters raise backend-specific exception types; count the
            # failure and let the original error propagate unchanged.
            self._record_embed_error()
            raise

        candidate = self._build_memory_object(
            user_id=user_id,
            content=text,
            embedding=vector,
            importance=importance,
            source=source,
            metadata=metadata,
            tags=tags,
            namespace=namespace,
            session_id=session_id,
            memory_type=memory_type,
            confidence=confidence,
            agent_id=agent_id,
            run_id=run_id,
            app_id=app_id,
            ttl_seconds=ttl_seconds,
        )

        pipeline = IngestionPipeline(self._build_ingestion_context())
        stored_id = pipeline.ingest(candidate).memory_id

        self._run_hooks(
            "post", "remember", user_id=user_id, memory_id=stored_id, namespace=namespace
        )
        return stored_id
def _build_ingestion_context(self) -> "IngestionContext":
    """Bundle the collaborators the ingestion pipeline needs.

    Returns an :class:`IngestionContext` populated from this instance's
    store, config, entity linker, metrics and callback methods.
    """
    from kemi.pipeline.ingestion import IngestionContext

    wiring: dict[str, Any] = {
        "store": self._store,
        "config": self._config,
        "entity_linker": self._entity_linker,
        "metrics": self._metrics,
        "record_store_error": self._record_store_error,
        "dispatch_webhook": self._dispatch_webhook_event,
        "track_operation": self._track_operation,
        "get_version_store": self._get_version_store,
        "auto_prune_versions": self._auto_prune_versions_for_memory,
    }
    return IngestionContext(**wiring)
@staticmethod
def _validate_remember_inputs(
    user_id: str,
    content: str,
    importance: float,
    ttl_seconds: int | None,
) -> None:
    """Reject unusable ``remember`` arguments before any work is done.

    Args:
        user_id: Must contain at least one non-whitespace character.
        content: Must contain at least one non-whitespace character.
        importance: Must be an ``int`` or ``float`` (not ``bool``, not NaN).
            The range is not enforced here.
        ttl_seconds: ``None`` or a positive ``int`` (not ``bool``).

    Raises:
        ValueError: Blank ``user_id``/``content``, NaN ``importance``, or an
            invalid ``ttl_seconds``.
        TypeError: ``importance`` is not a real number.
    """
    if not user_id or not user_id.strip():
        raise ValueError("user_id cannot be empty")
    if not content or not content.strip():
        raise ValueError("content cannot be empty — there is nothing to remember")
    # bool is a subclass of int, so it has to be excluded explicitly;
    # otherwise importance=True would silently be treated as 1.
    if isinstance(importance, bool) or not isinstance(importance, (int, float)):
        raise TypeError(
            f"importance must be a number between 0.0 and 1.0, got {type(importance).__name__}"
        )
    # NaN is the only value that differs from itself. It cannot be ordered,
    # so min/max-style clamping would turn it into an arbitrary bound.
    if importance != importance:
        raise ValueError("importance must be a number between 0.0 and 1.0, got nan")
    if ttl_seconds is not None and (
        isinstance(ttl_seconds, bool)  # True would otherwise mean a 1-second TTL
        or not isinstance(ttl_seconds, int)
        or ttl_seconds <= 0
    ):
        raise ValueError(f"ttl_seconds must be a positive integer, got {ttl_seconds}")
@staticmethod
def _build_memory_object(
    user_id: str,
    content: str,
    embedding: list[float],
    importance: float,
    source: MemorySource,
    metadata: dict[str, Any] | None,
    tags: list[str] | None,
    namespace: str,
    session_id: str | None,
    memory_type: MemoryType,
    confidence: float,
    agent_id: str | None,
    run_id: str | None,
    app_id: str | None,
    ttl_seconds: int | None,
) -> MemoryObject:
    """Construct a fresh ACTIVE ``MemoryObject`` from raw inputs.

    ``importance`` and ``confidence`` are clamped to ``[0.0, 1.0]``. A new
    UUID4 is generated for ``memory_id`` and ``version`` starts at 1.
    ``created_at`` and ``last_accessed_at`` share one UTC timestamp. When
    ``ttl_seconds`` is given, ``expires_at`` is that timestamp truncated to
    whole seconds plus the TTL; otherwise it is ``None``.

    ``metadata`` and ``tags`` are shallow-copied so the new object never
    shares a mutable container with the caller (or with sibling memories
    created from the same arguments in a batch).
    """
    # Read the clock once so all derived timestamps agree.
    now = datetime.now(timezone.utc)

    expires_at = None
    if ttl_seconds is not None:
        expires_at = now.replace(microsecond=0) + timedelta(seconds=ttl_seconds)

    return MemoryObject(
        memory_id=str(uuid.uuid4()),
        user_id=user_id,
        content=content,
        embedding=embedding,
        score=0.0,
        created_at=now,
        last_accessed_at=now,
        source=source,
        importance=max(0.0, min(1.0, importance)),
        lifecycle_state=LifecycleState.ACTIVE,
        metadata=dict(metadata) if metadata else {},
        embedding_dim=len(embedding),
        tags=list(tags) if tags else [],
        confidence=max(0.0, min(1.0, confidence)),
        memory_type=memory_type,
        session_id=session_id,
        namespace=namespace,
        version=1,
        agent_id=agent_id,
        run_id=run_id,
        app_id=app_id,
        expires_at=expires_at,
    )
def recall(
    self,
    user_id: str,
    query: str,
    top_k: int = 5,
    max_tokens: int | None = None,
    lifecycle_filter: list[LifecycleState] | None = None,
    hybrid_search: bool | None = None,
    namespace: str = "default",
    session_id: str | None = None,
    metadata_filter: dict[str, Any] | None = None,
) -> list[MemoryObject]:
    """Recall memories matching ``query`` for ``user_id``.

    Arguments are validated here; the actual retrieval is delegated to
    :class:`kemi.pipeline.retrieval.RetrievalPipeline`, which receives every
    argument unchanged.

    Returns:
        Whatever ``RetrievalPipeline.retrieve`` returns.

    Raises:
        ValueError: If ``user_id`` or ``query`` is blank, or ``top_k < 1``.
    """
    # Validate before the lazy import so bad input fails fast.
    if not user_id or not user_id.strip():
        raise ValueError("user_id cannot be empty")
    if not query or not query.strip():
        raise ValueError("query cannot be empty — what should kemi search for?")
    if top_k < 1:
        raise ValueError(f"top_k must be at least 1, got {top_k}")

    from kemi.pipeline.retrieval import RetrievalPipeline

    with self._latency_tracker("recall"):
        return RetrievalPipeline(self._build_retrieval_context()).retrieve(
            user_id=user_id,
            query=query,
            top_k=top_k,
            max_tokens=max_tokens,
            lifecycle_filter=lifecycle_filter,
            hybrid_search=hybrid_search,
            namespace=namespace,
            session_id=session_id,
            metadata_filter=metadata_filter,
        )
def _build_retrieval_context(self) -> "RetrievalContext":
    """Bundle the collaborators the retrieval pipeline needs.

    Returns a :class:`RetrievalContext` populated from this instance's
    store, embedder, config, entity linker, query cache, metrics, adaptive
    retriever and callback methods.
    """
    from kemi.pipeline.retrieval import RetrievalContext

    wiring: dict[str, Any] = {
        "store": self._store,
        "embed": self._embed,
        "config": self._config,
        "entity_linker": self._entity_linker,
        "query_cache": self._query_cache,
        "metrics": self._metrics,
        "adaptive_retriever": self._adaptive_retriever,
        "run_hooks": self._run_hooks,
        "track_operation": self._track_operation,
    }
    return RetrievalContext(**wiring)
384 def recall_many(
385 self,
386 user_ids: list[str],
387 queries: list[str],
388 top_k: int = 5,
389 max_tokens: int | None = None,
390 lifecycle_filter: list[LifecycleState] | None = None,
391 hybrid_search: bool | None = None,
392 namespace: str = "default",
393 session_id: str | None = None,
394 metadata_filter: dict[str, Any] | None = None,
395 ) -> dict[str, list[MemoryObject]]:
396 """Recall memories for multiple users and queries at once.
398 Args:
399 user_ids: List of user IDs to recall for.
400 queries: List of query strings (same length as user_ids).
401 top_k: Max memories per result.
402 max_tokens: Token budget.
403 lifecycle_filter: Optional lifecycle state filter.
404 hybrid_search: Override hybrid search.
405 namespace: Memory namespace.
406 session_id: Optional session ID filter.
407 metadata_filter: Optional metadata key-value filter dict.
409 Returns:
410 Dict mapping user_id -> list of MemoryObjects.
411 """
412 if len(user_ids) != len(queries):
413 raise ValueError("user_ids and queries must have the same length")
414 results: dict[str, list[MemoryObject]] = {}
415 for uid, q in zip(user_ids, queries, strict=True):
416 results[uid] = self.recall(
417 user_id=uid,
418 query=q,
419 top_k=top_k,
420 max_tokens=max_tokens,
421 lifecycle_filter=lifecycle_filter,
422 hybrid_search=hybrid_search,
423 namespace=namespace,
424 session_id=session_id,
425 metadata_filter=metadata_filter,
426 )
427 return results
def update_many(
    self,
    memory_ids: list[str],
    content: str | None = None,
    importance: float | None = None,
    confidence: float | None = None,
    memory_type: MemoryType | None = None,
    metadata: dict[str, Any] | None = None,
) -> list[str]:
    """Apply the same field changes to every memory in ``memory_ids``.

    Each ID is passed to :meth:`update` in order, together with the same
    keyword arguments.

    Args:
        memory_ids: List of memory IDs to update.
        content: New content for all.
        importance: New importance for all.
        confidence: New confidence for all.
        memory_type: New memory type for all.
        metadata: Metadata forwarded to :meth:`update` for all.

    Returns:
        The IDs that were passed to :meth:`update`, in input order. An
        exception from :meth:`update` propagates and aborts the batch.
    """
    changes: dict[str, Any] = {
        "content": content,
        "importance": importance,
        "confidence": confidence,
        "memory_type": memory_type,
        "metadata": metadata,
    }
    processed: list[str] = []
    for memory_id in memory_ids:
        self.update(memory_id, **changes)
        processed += [memory_id]
    return processed
def forget_many(
    self,
    memory_ids: list[str],
) -> int:
    """Delete several memories by ID and report how many were removed.

    Args:
        memory_ids: List of memory IDs to delete.

    Returns:
        Number of IDs for which the store's ``delete_by_id`` returned a
        truthy value.

    Note:
        Unlike :meth:`forget`, this goes straight to the store: no pre/post
        event hooks are fired per delete. Hooks are intentionally skipped
        for batch performance.
    """
    return sum(1 for memory_id in memory_ids if self._store.delete_by_id(memory_id))
def forget(
    self,
    user_id: str,
    memory_id: str | None = None,
) -> int:
    """Delete one memory by ID, or every memory of ``user_id``.

    With ``memory_id`` the store's ``delete_by_id`` is used; without it,
    ``delete_by_user``. In both cases "pre" hooks run first, a DELETED
    webhook event is dispatched only when something was removed, and then
    "post" hooks and operation tracking run. Hooks and tracking always use
    the ``"default"`` namespace.

    Returns:
        ``1``/``0`` for a single delete, or the store's count for a
        per-user delete.

    Raises:
        ValueError: If ``user_id`` is blank.
    """
    if not (user_id and user_id.strip()):
        raise ValueError("user_id cannot be empty")

    with self._latency_tracker("forget"):
        self._run_hooks(
            "pre", "forget", user_id=user_id, memory_id=memory_id, namespace="default"
        )

        if memory_id is None:
            # No ID given: wipe everything stored for this user.
            removed = self._store.delete_by_user(user_id)
            if removed:
                self._dispatch_webhook_event(
                    WebhookEventType.DELETED,
                    memory_id="batch",
                    user_id=user_id,
                    deleted_count=removed,
                )
            self._run_hooks(
                "post", "forget", user_id=user_id, deleted_count=removed, namespace="default"
            )
            self._track_operation(
                "forget", user_id, {"deleted_count": removed}, namespace="default"
            )
            return removed

        # NOTE(review): the delete is keyed on memory_id alone; user_id is
        # only used for hooks, webhooks and tracking here.
        was_deleted = self._store.delete_by_id(memory_id)
        if was_deleted:
            self._dispatch_webhook_event(
                WebhookEventType.DELETED,
                memory_id=memory_id,
                user_id=user_id,
            )
        self._run_hooks(
            "post",
            "forget",
            user_id=user_id,
            memory_id=memory_id,
            deleted=was_deleted,
            namespace="default",
        )
        self._track_operation(
            "forget",
            user_id,
            {"memory_id": memory_id, "deleted": was_deleted},
            memory_id,
            namespace="default",
        )
        return 1 if was_deleted else 0
def context_block(
    self,
    user_id: str,
    query: str,
    top_k: int = 5,
    max_tokens: int = 1500,
    prefix: str = "Relevant context from memory:",
    namespace: str = "default",
    session_id: str | None = None,
) -> str:
    """Render recalled memories as a bulleted text block.

    Calls :meth:`recall` with the given arguments and formats the result as
    ``prefix`` followed by one ``"- <content>"`` line per memory.

    Returns:
        The newline-joined block, or ``""`` when nothing was recalled.
    """
    hits = self.recall(
        user_id=user_id,
        query=query,
        top_k=top_k,
        max_tokens=max_tokens,
        namespace=namespace,
        session_id=session_id,
    )
    if not hits:
        return ""

    bullets = (f"- {hit.content}" for hit in hits)
    return "\n".join([prefix, *bullets])
async def aremember(
    self,
    user_id: str,
    content: str,
    importance: float = 0.5,
    source: MemorySource = MemorySource.USER_STATED,
    metadata: dict[str, Any] | None = None,
    sanitize_input: bool = False,
    tags: list[str] | None = None,
    namespace: str = "default",
    session_id: str | None = None,
    memory_type: MemoryType = MemoryType.EPISODIC,
    confidence: float = 1.0,
    agent_id: str | None = None,
    run_id: str | None = None,
    app_id: str | None = None,
    ttl_seconds: int | None = None,
) -> str:
    """Async counterpart of :meth:`remember`.

    Runs the synchronous :meth:`remember` in a worker thread via
    ``asyncio.to_thread`` and returns its result (the new memory ID).
    Accepts the same arguments as :meth:`remember`, including
    ``ttl_seconds``.
    """
    import asyncio

    # Forward by keyword so the two signatures cannot drift out of
    # positional alignment.
    return await asyncio.to_thread(
        self.remember,
        user_id,
        content,
        importance=importance,
        source=source,
        metadata=metadata,
        sanitize_input=sanitize_input,
        tags=tags,
        namespace=namespace,
        session_id=session_id,
        memory_type=memory_type,
        confidence=confidence,
        agent_id=agent_id,
        run_id=run_id,
        app_id=app_id,
        ttl_seconds=ttl_seconds,
    )
async def recall_stream(
    self,
    user_id: str,
    query: str,
    top_k: int = 5,
    max_tokens: int | None = None,
    lifecycle_filter: list[LifecycleState] | None = None,
    hybrid_search: bool | None = None,
    namespace: str = "default",
    session_id: str | None = None,
    metadata_filter: dict[str, Any] | None = None,
):
    """Stream recall results as an async generator.

    Scores all candidate memories, then yields each one as MMR reranking
    selects it, providing progressive delivery instead of waiting for
    full ranking.

    Because this function contains ``yield`` it is an async generator:
    nothing below (including argument validation) runs until the caller
    starts iterating.

    Args:
        Same as :meth:`recall`.

    Yields:
        MemoryObject instances in ranked order.

    Raises:
        ValueError: If ``user_id`` or ``query`` is blank, ``top_k < 1``, or
            the first candidate's stored embedding dimension differs from
            the current adapter's dimension.
    """
    if not user_id or not user_id.strip():
        raise ValueError("user_id cannot be empty")
    if not query or not query.strip():
        raise ValueError("query cannot be empty — what should kemi search for?")
    if top_k < 1:
        raise ValueError(f"top_k must be at least 1, got {top_k}")

    import asyncio

    # Fall back to the configured default when the caller did not choose.
    if hybrid_search is None:
        hybrid_search = self._config.hybrid_search

    # The embedding call is synchronous, so it is pushed to a worker thread.
    query_embedding = await asyncio.to_thread(self._embed.embed_single, query)

    if lifecycle_filter is None:
        lifecycle_filter = lifecycle.get_recall_filter()

    # Run the synchronous store search in a thread
    # Three times top_k candidates are requested so that the later
    # filtering and reranking steps have a surplus to choose from.
    search_results = await asyncio.to_thread(
        self._store.search,
        user_id=user_id,
        query_embedding=query_embedding,
        top_k=top_k * 3,
        lifecycle_filter=lifecycle_filter,
        namespace=namespace,
        session_id=session_id,
    )

    # The metadata filter is applied here, after the store search: a
    # candidate is kept only if every filter key matches exactly.
    if metadata_filter is not None:
        search_results = [
            m
            for m in search_results
            if all(m.metadata.get(k) == v for k, v in metadata_filter.items())
        ]

    # Only the first candidate's stored dimension is compared against the
    # current adapter.
    current_dim = self._embed.dimension()
    if search_results:
        stored_dim = search_results[0].embedding_dim
        if stored_dim is not None and stored_dim != current_dim:
            raise ValueError(
                "Embedding dimension mismatch: stored memories use "
                f"{stored_dim} dimensions but current adapter produces "
                f"{current_dim} dimensions. Run memory.migrate(user_id, "
                "new_adapter) to re-embed your memories."
            )

    # Entity extraction for boost
    query_entities_stream: set[str] | None = None
    memory_entities_map_stream: dict[str, set[str]] | None = None
    if self._config.enable_entity_boost:
        query_entities_stream = self._entity_linker.extract(query)
        memory_entities_map_stream = {}
        for m in search_results:
            # Reuse entities cached in metadata when present; otherwise
            # extract them from the content now.
            cached = m.metadata.get("extracted_entities")
            if cached is not None:
                memory_entities_map_stream[m.memory_id] = set(cached)
            else:
                memory_entities_map_stream[m.memory_id] = self._entity_linker.extract(m.content)

    # Score all candidates (same as rank_memories)
    # The corpus of candidate contents is only built when there is more
    # than one candidate.
    corpus = [m.content for m in search_results] if len(search_results) > 1 else None
    for memory in search_results:
        mem_entities = None
        if memory_entities_map_stream is not None:
            mem_entities = memory_entities_map_stream.get(memory.memory_id)
        memory.score = scoring.score_memory(
            memory,
            query_embedding,
            query,
            hybrid_search,
            corpus,
            self._config.weight_semantic,
            self._config.weight_recency,
            self._config.weight_bm25,
            self._config.weight_semantic_no_embed,
            self._config.weight_recency_no_embed,
            self._config.weight_importance,
            query_entities_stream,
            mem_entities,
            self._config.entity_boost_weight,
        )

    # Sort by score descending first so mmr_rerank_stream gets pre-sorted input
    search_results.sort(key=lambda m: m.score, reverse=True)

    # Truncate to token budget before MMR
    effective_max_tokens = (
        max_tokens if max_tokens is not None else self._config.max_tokens_default
    )

    if effective_max_tokens is not None:
        search_results = scoring.truncate_by_tokens(search_results, effective_max_tokens)

    # Apply MMR and yield progressively
    yielded_memories: list[MemoryObject] = []
    for memory in scoring.mmr_rerank_stream(
        search_results, query_embedding, top_k, lambda_param=0.7
    ):
        # Update lifecycle and access time
        # NOTE(review): the new access time is set on the in-memory object;
        # the store is only written when the lifecycle state changes.
        memory.last_accessed_at = datetime.now(timezone.utc)
        new_state = lifecycle.evaluate_lifecycle(memory, self._config.decay_threshold_hours)
        if new_state != memory.lifecycle_state:
            updated = lifecycle.transition(memory, new_state)
            # NOTE(review): this store call is made directly, not through
            # asyncio.to_thread like the search above.
            self._store.update(updated)
            if self._metrics is not None:
                self._metrics.lifecycle_transitions.inc(1)
        yielded_memories.append(memory)
        yield memory

    # Everything below runs only after the rerank stream is exhausted.
    # Update metrics after all yielded
    if self._metrics is not None:
        self._metrics.total_memories.set(self._store.count(user_id))
    self._run_hooks(
        "post",
        "recall",
        user_id=user_id,
        query=query,
        results=yielded_memories,
        namespace=namespace,
    )
    self._track_operation(
        "recall",
        user_id,
        {"query": query, "results_count": len(yielded_memories), "cache_hit": False, "stream": True},
        namespace=namespace,
    )
    if self._adaptive_retriever is not None:
        # Best-effort: a failure in adaptive analysis is logged at debug
        # level and never reaches the consumer.
        try:
            profile = self._adaptive_retriever.analyze_query(query)
            self._adaptive_retriever.record_feedback(user_id, query, profile)
        except Exception:
            logger.debug("Adaptive retrieval analysis failed", exc_info=True)
async def arecall(
    self,
    user_id: str,
    query: str,
    top_k: int = 5,
    max_tokens: int | None = None,
    lifecycle_filter: list[LifecycleState] | None = None,
    hybrid_search: bool | None = None,
    namespace: str = "default",
    session_id: str | None = None,
    metadata_filter: dict[str, Any] | None = None,
    stream: bool = False,
):
    """Async counterpart of :meth:`recall`.

    With ``stream=False`` (the default) the synchronous :meth:`recall` runs
    in a worker thread and its result is returned. With ``stream=True`` the
    async generator from :meth:`recall_stream` is returned un-iterated, so
    the caller awaits this coroutine and then uses ``async for`` on the
    result.
    """
    import asyncio

    # Both targets take the same nine arguments in the same order.
    call_args = (
        user_id,
        query,
        top_k,
        max_tokens,
        lifecycle_filter,
        hybrid_search,
        namespace,
        session_id,
        metadata_filter,
    )
    if not stream:
        return await asyncio.to_thread(self.recall, *call_args)
    return self.recall_stream(*call_args)
async def aforget(
    self,
    user_id: str,
    memory_id: str | None = None,
) -> int:
    """Async counterpart of :meth:`forget`; runs it in a worker thread."""
    import asyncio

    deleted_count = await asyncio.to_thread(self.forget, user_id, memory_id)
    return deleted_count
async def acontext_block(
    self,
    user_id: str,
    query: str,
    top_k: int = 5,
    max_tokens: int = 1500,
    prefix: str = "Relevant context from memory:",
    namespace: str = "default",
    session_id: str | None = None,
) -> str:
    """Async counterpart of :meth:`context_block`; runs it in a worker thread."""
    import asyncio

    forwarded = (user_id, query, top_k, max_tokens, prefix, namespace, session_id)
    block_text = await asyncio.to_thread(self.context_block, *forwarded)
    return block_text
def migrate(
    self,
    user_id: str,
    new_embed_fn: EmbeddingAdapter,
    batch_size: int = 100,
) -> int:
    """Re-embed a user's ACTIVE and DECAYING memories with a new adapter.

    Memories are processed in slices of ``batch_size``: each slice's
    contents are embedded in one ``new_embed_fn.embed`` call, then every
    memory gets its new embedding and dimension and is written back through
    the store's ``update``.

    Returns:
        Number of memories updated (``0`` when the user has none).

    Raises:
        ValueError: If ``user_id`` is blank or ``batch_size < 1``.
    """
    if not (user_id and user_id.strip()):
        raise ValueError("user_id cannot be empty")
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")

    with self._latency_tracker("migrate"):
        memories = self._store.get_all_by_user(
            user_id,
            lifecycle_filter=[LifecycleState.ACTIVE, LifecycleState.DECAYING],
        )
        if not memories:
            return 0

        migrated = 0
        start = 0
        total = len(memories)
        while start < total:
            chunk = memories[start : start + batch_size]
            vectors = new_embed_fn.embed([item.content for item in chunk])
            # The dimension is queried after embedding, once per chunk.
            dim = new_embed_fn.dimension()

            for position, item in enumerate(chunk):
                item.embedding = vectors[position]
                item.embedding_dim = dim
                self._store.update(item)
                migrated += 1

            start += batch_size

        logger.info(f"Migrated {migrated} memories for user {user_id}")
        self._track_operation("migrate", user_id, {"count": migrated})
        return migrated
def export(self, file_path: str) -> int:
    """Export all memories to a JSON file.

    Every memory returned by the store's ``get_all`` is written as one JSON
    object. Datetimes are stored as ISO-8601 strings and enum fields as
    their ``.value``; a falsy datetime or enum is written as ``null``.
    ``expires_at`` is included so that a memory's expiry survives an
    export/import round trip.

    Args:
        file_path: Destination path; an existing file is overwritten.

    Returns:
        Number of memories written.
    """
    import json

    def _iso(moment: Any) -> Any:
        return moment.isoformat() if moment else None

    def _enum_value(member: Any) -> Any:
        return member.value if member else None

    memories_data = [
        {
            "memory_id": mem.memory_id,
            "user_id": mem.user_id,
            "content": mem.content,
            "embedding": mem.embedding,
            "score": mem.score,
            "created_at": _iso(mem.created_at),
            "last_accessed_at": _iso(mem.last_accessed_at),
            "source": _enum_value(mem.source),
            "importance": mem.importance,
            "lifecycle_state": _enum_value(mem.lifecycle_state),
            "metadata": mem.metadata,
            "embedding_dim": mem.embedding_dim,
            "tags": mem.tags,
            "confidence": mem.confidence,
            # Guarded like the other enum fields instead of assuming the
            # type is always set.
            "memory_type": _enum_value(mem.memory_type),
            "session_id": mem.session_id,
            "namespace": mem.namespace,
            "version": mem.version,
            "agent_id": mem.agent_id,
            "run_id": mem.run_id,
            "app_id": mem.app_id,
            # getattr keeps the export working for objects that predate
            # the expiry field.
            "expires_at": _iso(getattr(mem, "expires_at", None)),
        }
        for mem in self._store.get_all()
    ]

    # Explicit encoding so the file does not depend on the platform locale.
    with open(file_path, "w", encoding="utf-8") as f:
        json.dump(memories_data, f, indent=2)

    logger.info("Exported %d memories to %s", len(memories_data), file_path)
    return len(memories_data)
def import_from(self, file_path: str) -> int:
    """Import memories from a JSON file.

    The file is expected to hold a JSON list of objects using the keys
    written by :meth:`export`. An entry whose ``memory_id`` already exists
    in the store is skipped. Missing optional keys fall back to defaults:
    the current UTC time for timestamps, ``USER_STATED`` source, ``ACTIVE``
    lifecycle state, ``EPISODIC`` type, importance ``0.5``, confidence
    ``1.0``, namespace ``"default"``, version ``1``. ``expires_at`` is
    restored when present.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        Number of memories actually stored.

    Raises:
        KeyError: If an entry lacks ``memory_id``, ``user_id`` or
            ``content``.
    """
    import json

    with open(file_path, encoding="utf-8") as f:
        memories_data = json.load(f)

    def _parse_timestamp(raw: Any) -> Any:
        return datetime.fromisoformat(raw) if raw else None

    imported_count = 0
    for mem_data in memories_data:
        # Never overwrite a memory that is already stored.
        if self._store.get(mem_data["memory_id"]) is not None:
            continue

        raw_source = mem_data.get("source")
        raw_state = mem_data.get("lifecycle_state")
        raw_type = mem_data.get("memory_type")

        memory = MemoryObject(
            memory_id=mem_data["memory_id"],
            user_id=mem_data["user_id"],
            content=mem_data["content"],
            embedding=mem_data.get("embedding"),
            score=mem_data.get("score", 0.0),
            created_at=_parse_timestamp(mem_data.get("created_at"))
            or datetime.now(timezone.utc),
            last_accessed_at=_parse_timestamp(mem_data.get("last_accessed_at"))
            or datetime.now(timezone.utc),
            source=MemorySource(raw_source) if raw_source else MemorySource.USER_STATED,
            importance=mem_data.get("importance", 0.5),
            lifecycle_state=LifecycleState(raw_state) if raw_state else LifecycleState.ACTIVE,
            # ``or`` also covers an explicit JSON null, which would
            # otherwise leave the field as None instead of a container.
            metadata=mem_data.get("metadata") or {},
            embedding_dim=mem_data.get("embedding_dim"),
            tags=mem_data.get("tags") or [],
            confidence=mem_data.get("confidence", 1.0),
            memory_type=MemoryType(raw_type) if raw_type else MemoryType.EPISODIC,
            session_id=mem_data.get("session_id"),
            namespace=mem_data.get("namespace", "default"),
            version=mem_data.get("version", 1),
            agent_id=mem_data.get("agent_id"),
            run_id=mem_data.get("run_id"),
            app_id=mem_data.get("app_id"),
            # Absent in files that carry no expiry; then it stays None.
            expires_at=_parse_timestamp(mem_data.get("expires_at")),
        )

        self._store.store(memory)
        imported_count += 1

    logger.info("Imported %d memories from %s", imported_count, file_path)
    return imported_count
async def aexport(self, file_path: str) -> int:
    """Async counterpart of :meth:`export`; runs it in a worker thread."""
    import asyncio

    exported = await asyncio.to_thread(self.export, file_path)
    return exported
async def aimport_from(self, file_path: str) -> int:
    """Async counterpart of :meth:`import_from`; runs it in a worker thread."""
    import asyncio

    imported = await asyncio.to_thread(self.import_from, file_path)
    return imported
def upgrade(self) -> None:
    """Ask the storage adapter to upgrade its schema.

    Both the source and the target version passed to ``upgrade_schema`` are
    currently ``1``.
    """
    schema_version = 1
    self._store.upgrade_schema(from_version=schema_version, to_version=schema_version)
    logger.info("Schema upgraded to version 1")
def remember_many(
    self,
    user_id: str,
    contents: list[str],
    importance: float = 0.5,
    source: MemorySource = MemorySource.USER_STATED,
    tags: list[str] | None = None,
    namespace: str = "default",
    session_id: str | None = None,
    memory_type: MemoryType = MemoryType.EPISODIC,
    confidence: float = 1.0,
    agent_id: str | None = None,
    run_id: str | None = None,
    app_id: str | None = None,
    ttl_seconds: int | None = None,
) -> list[str]:
    """Store multiple memories at once.

    All contents are embedded in a single adapter call and then stored one
    by one with the same attributes.

    Args:
        user_id: User ID.
        contents: List of content strings to remember.
        importance: Importance value (0.0-1.0) for all.
        source: Memory source.
        tags: Optional list of tags to apply to all memories.
        namespace: Namespace for all memories.
        session_id: Optional session ID for all memories.
        memory_type: Memory type for all memories.
        confidence: Confidence value for all memories.
        agent_id: Optional agent ID for all memories.
        run_id: Optional run ID for all memories.
        app_id: Optional app ID for all memories.
        ttl_seconds: Optional time-to-live in seconds for all memories.

    Returns:
        List of memory IDs, in the order of ``contents``. An empty
        ``contents`` returns ``[]`` without doing anything.

    Raises:
        ValueError, TypeError: From ``_validate_remember_inputs``. The
            whole batch is validated before anything is embedded or
            stored.

    Note:
        Fires pre/post ``remember`` event hooks for each item so
        batch behavior is consistent with :meth:`recall_many` and
        :meth:`update_many`.
    """
    if not contents:
        return []

    # Validate the whole batch up front, with the same rules as remember(),
    # so one bad item cannot leave the batch half-stored or waste an
    # embedding call.
    for content in contents:
        self._validate_remember_inputs(user_id, content, importance, ttl_seconds)

    with self._latency_tracker("remember_many"):
        # Batch embed all contents at once for performance
        try:
            embeddings = self._embed.embed(contents)
        except Exception:
            # Same policy as remember(): count the failure, then re-raise
            # the adapter's original error.
            self._record_embed_error()
            raise

        memory_ids: list[str] = []
        # Audit entries are collected only when an audit trail is attached.
        audit_batch: list[dict[str, Any]] | None = [] if self._audit_trail is not None else None
        for i, content in enumerate(contents):
            self._run_hooks(
                "pre", "remember", user_id=user_id, content=content, namespace=namespace
            )
            memory_id = self._remember_with_embedding(
                user_id=user_id,
                content=content,
                embedding=embeddings[i],
                importance=importance,
                source=source,
                tags=tags,
                namespace=namespace,
                session_id=session_id,
                memory_type=memory_type,
                confidence=confidence,
                audit_batch=audit_batch,
                agent_id=agent_id,
                run_id=run_id,
                app_id=app_id,
                ttl_seconds=ttl_seconds,
            )
            self._run_hooks(
                "post", "remember", user_id=user_id, memory_id=memory_id, namespace=namespace
            )
            memory_ids.append(memory_id)

        if self._metrics is not None:
            self._metrics.remember_many_total.inc(1)
            self._metrics.total_memories.set(self._store.count(user_id))

        # Append batch-level audit entry and flush all entries in one transaction
        if audit_batch is not None:
            self._track_operation(
                "remember_many",
                user_id,
                {"count": len(memory_ids)},
                namespace=namespace,
                audit_batch=audit_batch,
            )
            if self._audit_trail is not None:
                # Best-effort: an audit failure must not undo the stored
                # memories, so it is only logged.
                try:
                    self._audit_trail.log_operation_batch(audit_batch)
                except Exception:
                    logger.warning("Audit log batch failed for remember_many", exc_info=True)
        return memory_ids
def _remember_with_embedding(
    self,
    user_id: str,
    content: str,
    embedding: list[float],
    importance: float,
    source: MemorySource,
    metadata: dict[str, Any] | None = None,
    tags: list[str] | None = None,
    namespace: str = "default",
    session_id: str | None = None,
    memory_type: MemoryType = MemoryType.EPISODIC,
    confidence: float = 1.0,
    audit_batch: list[dict[str, Any]] | None = None,
    agent_id: str | None = None,
    run_id: str | None = None,
    app_id: str | None = None,
    ttl_seconds: int | None = None,
) -> str:
    """Internal: persist one memory whose embedding is already computed.

    Assembles a memory object via ``self._build_memory_object`` and hands it
    to :class:`kemi.pipeline.ingestion.IngestionPipeline`. No pre/post hooks
    are run here; callers that want hooks must fire them themselves.

    Args:
        audit_batch: Forwarded unchanged to the pipeline's ``ingest`` call.
        (All remaining arguments are forwarded unchanged to
        ``self._build_memory_object``.)

    Returns:
        The ``memory_id`` of the object returned by the ingestion pipeline.
    """
    # Imported lazily, inside the method, exactly as the original did.
    from kemi.pipeline.ingestion import IngestionPipeline

    memory_fields = {
        "user_id": user_id,
        "content": content,
        "embedding": embedding,
        "importance": importance,
        "source": source,
        "metadata": metadata,
        "tags": tags,
        "namespace": namespace,
        "session_id": session_id,
        "memory_type": memory_type,
        "confidence": confidence,
        "agent_id": agent_id,
        "run_id": run_id,
        "app_id": app_id,
        "ttl_seconds": ttl_seconds,
    }
    candidate = self._build_memory_object(**memory_fields)

    pipeline = IngestionPipeline(self._build_ingestion_context())
    persisted = pipeline.ingest(candidate, audit_batch=audit_batch)
    return persisted.memory_id
def list_users(self) -> list[str]:
    """Return the user IDs reported by the storage adapter.

    Returns:
        The result of ``self._store.get_all_users()``, unmodified.
    """
    known_users = self._store.get_all_users()
    return known_users
def prune(
    self,
    user_id: str,
    max_age_days: float | None = None,
    min_importance: float | None = None,
    lifecycle_states: list[LifecycleState] | None = None,
    namespace: str = "default",
) -> int:
    """Delete old and/or low-importance memories for one user.

    A memory is pruned when it is older than ``max_age_days`` *or* its
    importance is below ``min_importance``. When both thresholds are
    ``None`` nothing is deleted.

    Args:
        user_id: User whose memories are pruned.
        max_age_days: Prune memories whose ``created_at`` is more than this
            many days in the past. Memories without a ``created_at`` are
            never pruned by age.
        min_importance: Prune memories with importance below this value.
        lifecycle_states: Only memories in these states are considered;
            defaults to ``[LifecycleState.DECAYING]`` when omitted or empty.
        namespace: Memory namespace to prune.

    Returns:
        Number of memories the store reported as deleted.

    Raises:
        ValueError: If ``user_id`` is empty or whitespace.
    """
    if not user_id or not user_id.strip():
        raise ValueError("user_id cannot be empty")

    with self._latency_tracker("prune"):
        candidates = self._store.get_all_by_user(
            user_id,
            lifecycle_filter=lifecycle_states or [LifecycleState.DECAYING],
            namespace=namespace,
        )
        now = datetime.now(timezone.utc)

        def _is_prunable(mem) -> bool:
            # Age check first; a missing timestamp cannot be aged, so it
            # falls through to the importance check instead of raising.
            if max_age_days is not None and mem.created_at is not None:
                age_days = (now - mem.created_at).total_seconds() / 86400.0
                if age_days > max_age_days:
                    return True
            return min_importance is not None and mem.importance < min_importance

        # Select first, delete afterwards, so the candidate collection is
        # fully read before the store is modified.
        to_delete = [mem.memory_id for mem in candidates if _is_prunable(mem)]

        # Count only deletions the store confirms, matching prune_expired().
        deleted = 0
        for memory_id in to_delete:
            if self._store.delete_by_id(memory_id):
                deleted += 1

        logger.info("Pruned %d memories for user %s", deleted, user_id)
        self._track_operation(
            "prune", user_id, {"deleted_count": deleted}, namespace=namespace
        )
        return deleted
def prune_expired(
    self,
    user_id: str | None = None,
    namespace: str | None = None,
) -> int:
    """Remove memories whose ``expires_at`` is at or before the current time.

    Only memories fetched with an ACTIVE/DECAYING lifecycle filter are
    examined, and each expired one is removed with
    ``self._store.delete_by_id``.

    Args:
        user_id: Restrict the sweep to this user; ``None`` sweeps every user
            returned by ``self._store.get_all_users()``.
        namespace: Restrict the sweep to this namespace; ``None`` sweeps the
            namespaces returned by ``self._known_namespaces`` for each user.

    Returns:
        Number of memories the store reported as deleted.
    """
    with self._latency_tracker("prune_expired"):
        now = datetime.now(timezone.utc)
        removed = 0

        target_users = [user_id] if user_id is not None else self._store.get_all_users()

        for uid in target_users:
            if namespace is None:
                # The adapter's ``get_all_by_user`` is namespace-scoped, so
                # each discovered namespace is queried separately.
                candidates = []
                for ns in self._known_namespaces(uid):
                    batch = self._store.get_all_by_user(
                        uid,
                        lifecycle_filter=[
                            LifecycleState.ACTIVE,
                            LifecycleState.DECAYING,
                        ],
                        namespace=ns,
                    )
                    candidates.extend(batch)
            else:
                candidates = self._store.get_all_by_user(
                    uid,
                    lifecycle_filter=[LifecycleState.ACTIVE, LifecycleState.DECAYING],
                    namespace=namespace,
                )

            for mem in candidates:
                if mem.expires_at is None:
                    continue
                if mem.expires_at <= now and self._store.delete_by_id(mem.memory_id):
                    removed += 1

        if removed > 0:
            user_suffix = f" for user {user_id}" if user_id else ""
            logger.info(f"Pruned {removed} expired memories" + user_suffix)

        self._track_operation(
            "prune_expired",
            user_id or "all",
            {"deleted_count": removed},
            namespace=namespace or "all",
        )
        return removed
def _known_namespaces(self, user_id: str) -> set[str]:
    """Return the namespaces in which ``user_id`` is seen to hold memories.

    Discovery is best-effort: it inspects at most the 1000 memories returned
    by ``self._store.get_all(limit=1000)``, so namespaces that only occur
    outside that sample are not reported. ``"default"`` is always included.

    Args:
        user_id: User whose namespaces are collected.

    Returns:
        Set of namespace names, always containing ``"default"``.
    """
    namespaces: set[str] = {"default"}
    try:
        # ``get_all_by_user`` is namespace-scoped, so sample the global
        # listing and keep the namespaces of this user's memories.
        for mem in self._store.get_all(limit=1000):
            if mem.user_id == user_id:
                namespaces.add(mem.namespace)
    except Exception:
        # Stay best-effort (callers still get what was collected so far),
        # but leave a trace: a failed discovery means other namespaces are
        # silently left out of the caller's sweep.
        logger.warning(
            "Namespace discovery failed for user %s; using namespaces found so far",
            user_id,
            exc_info=True,
        )
    return namespaces
def recall_between(
    self,
    user_id: str,
    query: str,
    start: datetime,
    end: datetime,
    top_k: int = 5,
    max_tokens: int | None = None,
    lifecycle_filter: list[LifecycleState] | None = None,
    namespace: str = "default",
    session_id: str | None = None,
) -> list[MemoryObject]:
    """Recall memories whose ``created_at`` lies in ``[start, end]``.

    Delegates to :meth:`recall` asking for ``top_k * 3`` results, then keeps
    only those created inside the window and trims to ``top_k``. Because the
    date filter runs after the recall, fewer than ``top_k`` items may be
    returned even when more matching memories exist.

    Args:
        user_id: User ID to search for.
        query: Search query.
        start: Window start (inclusive).
        end: Window end (inclusive).
        top_k: Maximum memories to return.
        max_tokens: Token budget, forwarded to :meth:`recall`.
        lifecycle_filter: Lifecycle states, forwarded to :meth:`recall`.
        namespace: Memory namespace.
        session_id: Session filter, forwarded to :meth:`recall`.

    Returns:
        Up to ``top_k`` memories created within the window, in the order
        :meth:`recall` returned them.

    Raises:
        ValueError: If ``user_id`` or ``query`` is blank, or ``top_k < 1``.
    """
    if not user_id or not user_id.strip():
        raise ValueError("user_id cannot be empty")
    if not query or not query.strip():
        raise ValueError("query cannot be empty")
    if top_k < 1:
        raise ValueError(f"top_k must be at least 1, got {top_k}")

    candidates = self.recall(
        user_id=user_id,
        query=query,
        top_k=top_k * 3,
        max_tokens=max_tokens,
        lifecycle_filter=lifecycle_filter,
        namespace=namespace,
        session_id=session_id,
    )

    in_window = []
    for mem in candidates:
        # Memories without a creation timestamp can never match the window.
        if not mem.created_at:
            continue
        if start <= mem.created_at <= end:
            in_window.append(mem)
    return in_window[:top_k]
def recall_user_profile(
    self,
    user_id: str,
    *,
    top_k: int = 20,
    namespace: str = "default",
) -> list[MemoryObject]:
    """Return a user's most important SEMANTIC memories.

    Loads the user's ACTIVE, DECAYING and ARCHIVED memories for the given
    namespace, keeps the SEMANTIC ones, and orders them by importance
    (highest first). No query embedding is involved. Each returned memory
    gets its ``last_accessed_at`` refreshed in memory; the store is written
    only when the lifecycle evaluation yields a different state.

    Args:
        user_id: User whose profile to retrieve.
        top_k: Maximum number of memories to return.
        namespace: Memory namespace.

    Returns:
        Up to ``top_k`` memories sorted by importance descending.

    Raises:
        ValueError: If ``user_id`` is blank or ``top_k < 1``.
    """
    if not user_id or not user_id.strip():
        raise ValueError("user_id cannot be empty")
    if top_k < 1:
        raise ValueError(f"top_k must be at least 1, got {top_k}")

    with self._latency_tracker("recall_user_profile"):
        self._run_hooks("pre", "recall_user_profile", user_id=user_id, namespace=namespace)

        visible_states = [
            LifecycleState.ACTIVE,
            LifecycleState.DECAYING,
            LifecycleState.ARCHIVED,
        ]
        candidates = self._store.get_all_by_user(
            user_id,
            lifecycle_filter=visible_states,
            namespace=namespace,
        )

        ranked = sorted(
            (m for m in candidates if m.memory_type == MemoryType.SEMANTIC),
            key=lambda m: m.importance,
            reverse=True,
        )
        top = ranked[:top_k]

        for mem in top:
            mem.last_accessed_at = datetime.now(timezone.utc)
            next_state = lifecycle.evaluate_lifecycle(mem, self._config.decay_threshold_hours)
            if next_state == mem.lifecycle_state:
                continue
            self._store.update(lifecycle.transition(mem, next_state))

        # Hooks receive their own list so they cannot alter the return value.
        self._run_hooks(
            "post",
            "recall_user_profile",
            user_id=user_id,
            results=list(top),
            namespace=namespace,
        )
        self._track_operation(
            "recall_user_profile",
            user_id,
            {"results_count": len(top)},
            namespace=namespace,
        )
        return top
def recall_session_context(
    self,
    user_id: str,
    session_id: str,
    *,
    top_k: int = 20,
    namespace: str = "default",
) -> list[MemoryObject]:
    """Return the newest EPISODIC memories recorded for one session.

    Loads the user's ACTIVE, DECAYING and ARCHIVED memories for the given
    namespace and session, keeps the EPISODIC ones, and orders them by
    ``created_at`` (newest first; memories without a timestamp sort last).
    Each returned memory gets its ``last_accessed_at`` refreshed in memory;
    the store is written only when the lifecycle evaluation yields a
    different state.

    Args:
        user_id: User whose session to retrieve.
        session_id: Session identifier.
        top_k: Maximum number of memories to return.
        namespace: Memory namespace.

    Returns:
        Up to ``top_k`` memories sorted by ``created_at`` descending.

    Raises:
        ValueError: If ``user_id`` or ``session_id`` is blank, or
            ``top_k < 1``.
    """
    if not user_id or not user_id.strip():
        raise ValueError("user_id cannot be empty")
    if not session_id or not session_id.strip():
        raise ValueError("session_id cannot be empty")
    if top_k < 1:
        raise ValueError(f"top_k must be at least 1, got {top_k}")

    with self._latency_tracker("recall_session_context"):
        self._run_hooks(
            "pre",
            "recall_session_context",
            user_id=user_id,
            session_id=session_id,
            namespace=namespace,
        )

        visible_states = [
            LifecycleState.ACTIVE,
            LifecycleState.DECAYING,
            LifecycleState.ARCHIVED,
        ]
        candidates = self._store.get_all_by_user(
            user_id,
            lifecycle_filter=visible_states,
            namespace=namespace,
            session_id=session_id,
        )

        def _recency_key(mem):
            # Undated memories get the smallest aware datetime so they end
            # up last in the descending sort.
            return mem.created_at or datetime.min.replace(tzinfo=timezone.utc)

        ranked = sorted(
            (m for m in candidates if m.memory_type == MemoryType.EPISODIC),
            key=_recency_key,
            reverse=True,
        )
        top = ranked[:top_k]

        for mem in top:
            mem.last_accessed_at = datetime.now(timezone.utc)
            next_state = lifecycle.evaluate_lifecycle(mem, self._config.decay_threshold_hours)
            if next_state == mem.lifecycle_state:
                continue
            self._store.update(lifecycle.transition(mem, next_state))

        # Hooks receive their own list so they cannot alter the return value.
        self._run_hooks(
            "post",
            "recall_session_context",
            user_id=user_id,
            session_id=session_id,
            results=list(top),
            namespace=namespace,
        )
        self._track_operation(
            "recall_session_context",
            user_id,
            {"results_count": len(top), "session_id": session_id},
            namespace=namespace,
        )
        return top
def recall_agent_knowledge(
    self,
    agent_id: str,
    *,
    namespace: str = "default",
    top_k: int = 50,
) -> list[MemoryObject]:
    """Return the most important memories attributed to one agent.

    Walks every user returned by ``self._store.get_all_users()``, loads that
    user's ACTIVE, DECAYING and ARCHIVED memories in ``namespace``, keeps
    those whose ``agent_id`` matches, and orders them by importance (highest
    first). Each returned memory gets its ``last_accessed_at`` refreshed in
    memory; the store is written only when the lifecycle evaluation yields a
    different state.

    Args:
        agent_id: Agent identifier to filter by.
        namespace: Memory namespace.
        top_k: Maximum number of memories to return.

    Returns:
        Up to ``top_k`` memories sorted by importance descending.

    Raises:
        ValueError: If ``agent_id`` is blank or ``top_k < 1``.
    """
    if not agent_id or not agent_id.strip():
        raise ValueError("agent_id cannot be empty")
    if top_k < 1:
        raise ValueError(f"top_k must be at least 1, got {top_k}")

    with self._latency_tracker("recall_agent_knowledge"):
        self._run_hooks(
            "pre", "recall_agent_knowledge", agent_id=agent_id, namespace=namespace
        )

        owned: list[MemoryObject] = []
        for uid in self._store.get_all_users():
            per_user = self._store.get_all_by_user(
                uid,
                lifecycle_filter=[
                    LifecycleState.ACTIVE,
                    LifecycleState.DECAYING,
                    LifecycleState.ARCHIVED,
                ],
                namespace=namespace,
            )
            owned.extend(mem for mem in per_user if mem.agent_id == agent_id)

        owned.sort(key=lambda m: m.importance, reverse=True)
        top = owned[:top_k]

        for mem in top:
            mem.last_accessed_at = datetime.now(timezone.utc)
            next_state = lifecycle.evaluate_lifecycle(mem, self._config.decay_threshold_hours)
            if next_state == mem.lifecycle_state:
                continue
            self._store.update(lifecycle.transition(mem, next_state))

        # Hooks receive their own list so they cannot alter the return value.
        self._run_hooks(
            "post",
            "recall_agent_knowledge",
            agent_id=agent_id,
            results=list(top),
            namespace=namespace,
        )
        # The scan spans all users, so the operation is tracked under "all".
        self._track_operation(
            "recall_agent_knowledge",
            "all",
            {"results_count": len(top), "agent_id": agent_id},
            namespace=namespace,
        )
        return top
def recall_explain(
    self,
    user_id: str,
    query: str,
    top_k: int = 5,
    namespace: str = "default",
    session_id: str | None = None,
) -> list[dict[str, Any]]:
    """Recall memories together with a per-component score breakdown.

    Each result carries an ``explanation`` dict with:

    - ``semantic_score``: cosine similarity rescaled from [-1, 1] to [0, 1]
    - ``recency_score``: recency of ``last_accessed_at``
    - ``bm25_score``: keyword score (hybrid search only, otherwise ``None``)
    - ``importance_score``: clamped importance (non-hybrid only, otherwise
      ``None``)
    - ``entity_score``: Jaccard overlap of query and memory entities
      (``0.0`` when entity boost is disabled)
    - ``final_score``: weighted sum of the applicable components
    - ``weights``: the weights used for that sum

    Reported values are rounded to 4 decimals; ranking uses the unrounded
    final score. Each memory's ``score`` attribute is set to that unrounded
    value.

    Args:
        user_id: User ID.
        query: Search query.
        top_k: Max results.
        namespace: Memory namespace.
        session_id: Filter by session.

    Returns:
        Up to ``top_k`` dicts with ``'memory'`` (MemoryObject) and
        ``'explanation'``, best score first.

    Raises:
        ValueError: If ``user_id`` or ``query`` is blank, or ``top_k < 1``.
    """
    if not user_id or not user_id.strip():
        raise ValueError("user_id cannot be empty")
    if not query or not query.strip():
        raise ValueError("query cannot be empty")
    if top_k < 1:
        raise ValueError(f"top_k must be at least 1, got {top_k}")

    config = self._config
    query_embedding = self._embed.embed_single(query)

    # Fetch three times ``top_k`` candidates; they are re-scored below and
    # trimmed back to ``top_k`` after sorting.
    search_results = self._store.search(
        user_id=user_id,
        query_embedding=query_embedding,
        top_k=top_k * 3,
        lifecycle_filter=lifecycle.get_recall_filter(),
        namespace=namespace,
        session_id=session_id,
    )

    # Corpus-aware BM25 is only used when there is more than one document.
    corpus = [m.content for m in search_results] if len(search_results) > 1 else None

    query_entities: set[str] = set()
    memory_entities: dict[str, set[str]] = {}
    if config.enable_entity_boost:
        query_entities = self._entity_linker.extract(query)
        for m in search_results:
            # Prefer entities cached in metadata; tolerate metadata being None.
            cached = (m.metadata or {}).get("extracted_entities")
            if cached is not None:
                memory_entities[m.memory_id] = set(cached)
            else:
                memory_entities[m.memory_id] = self._entity_linker.extract(m.content)

    ranked: list[tuple[float, dict[str, Any]]] = []
    for memory in search_results:
        semantic = scoring.cosine_similarity(memory.embedding, query_embedding)
        semantic_norm = (semantic + 1.0) / 2.0
        recency = scoring.temporal_recency(memory.last_accessed_at)

        entity_score = 0.0
        if config.enable_entity_boost:
            entity_score = scoring.jaccard_similarity(
                query_entities, memory_entities.get(memory.memory_id, set())
            )

        if config.hybrid_search:
            if corpus:
                bm25 = scoring.bm25_score_corpus(query, memory.content, corpus)
            else:
                bm25 = scoring.bm25_score(query, memory.content)
            final = (
                semantic_norm * config.weight_semantic
                + recency * config.weight_recency
                + bm25 * config.weight_bm25
                + entity_score * config.entity_boost_weight
            )
            explanation = {
                "semantic_score": round(semantic_norm, 4),
                "recency_score": round(recency, 4),
                "bm25_score": round(bm25, 4),
                "importance_score": None,
                "entity_score": round(entity_score, 4),
                "final_score": round(final, 4),
                "weights": {
                    "semantic": config.weight_semantic,
                    "recency": config.weight_recency,
                    "bm25": config.weight_bm25,
                    "entity": config.entity_boost_weight,
                },
            }
        else:
            importance = max(0.0, min(1.0, memory.importance))
            final = (
                semantic_norm * config.weight_semantic_no_embed
                + recency * config.weight_recency_no_embed
                + importance * config.weight_importance
                + entity_score * config.entity_boost_weight
            )
            explanation = {
                "semantic_score": round(semantic_norm, 4),
                "recency_score": round(recency, 4),
                "bm25_score": None,
                "importance_score": round(importance, 4),
                "entity_score": round(entity_score, 4),
                "final_score": round(final, 4),
                "weights": {
                    "semantic": config.weight_semantic_no_embed,
                    "recency": config.weight_recency_no_embed,
                    "importance": config.weight_importance,
                    "entity": config.entity_boost_weight,
                },
            }

        memory.score = final
        ranked.append((final, {"memory": memory, "explanation": explanation}))

    # Rank on the unrounded score: sorting on the 4-decimal display value
    # would tie near-equal scores and fall back to store order.
    ranked.sort(key=lambda pair: pair[0], reverse=True)
    return [entry for _, entry in ranked[:top_k]]
def consolidate(
    self,
    user_id: str,
    namespace: str = "default",
    min_memories: int = 5,
    max_age_days: float = 30.0,
    with_llm_summary: bool = False,
) -> str | None:
    """Run memory consolidation for a user and announce the result.

    The actual work is delegated to ``kemi.consolidation.consolidate``,
    which receives this instance's store, embedder and the summarizer
    settings from ``self._config``. When it returns a memory ID, a
    ``CONSOLIDATED`` webhook event is dispatched for that ID.

    Args:
        user_id: User to consolidate.
        namespace: Memory namespace.
        min_memories: Forwarded as the minimum cluster size.
        max_age_days: Forwarded as the age threshold in days.
        with_llm_summary: Forwarded; selects LLM-based summarization in the
            consolidation module.

    Returns:
        The memory ID returned by the consolidation module, or ``None`` when
        it returned ``None`` or the module could not be imported.
    """
    try:
        from kemi import consolidation
    except ImportError:  # pragma: no cover
        logger.warning("consolidation module not available")
        return None

    summary_id = consolidation.consolidate(
        store=self._store,
        embed=self._embed,
        user_id=user_id,
        namespace=namespace,
        min_memories=min_memories,
        max_age_days=max_age_days,
        with_llm_summary=with_llm_summary,
        summarizer_llm_provider=self._config.summarizer_llm_provider,
        summarizer_llm_model=self._config.summarizer_llm_model,
        summarizer_prompt_template=self._config.summarizer_prompt_template,
    )
    if summary_id is None:
        # Nothing was consolidated, so there is nothing to announce.
        return None

    self._dispatch_webhook_event(
        WebhookEventType.CONSOLIDATED,
        memory_id=summary_id,
        user_id=user_id,
    )
    return summary_id
def cluster_topics(
    self,
    user_id: str,
    n_clusters: int = 3,
    namespace: str = "default",
) -> dict[str, list[MemoryObject]]:
    """Group a user's memories into topic clusters.

    Delegates to ``kemi.topics.cluster_memories`` with this instance's store.

    Args:
        user_id: User ID.
        n_clusters: Number of topic clusters requested.
        namespace: Memory namespace.

    Returns:
        The mapping produced by ``topics.cluster_memories``, or an empty
        dict when the ``topics`` module cannot be imported.
    """
    try:
        from kemi import topics
    except ImportError:  # pragma: no cover
        logger.warning("topics module not available")
        return {}

    clusters = topics.cluster_memories(
        store=self._store,
        user_id=user_id,
        n_clusters=n_clusters,
        namespace=namespace,
    )
    return clusters
def extract_entities(self, memory_id: str) -> list[dict[str, Any]]:
    """Extract entities from the content of a stored memory.

    Looks the memory up in the store and passes its ``content`` to
    ``kemi.graph.extract_entities``.

    Args:
        memory_id: Memory ID.

    Returns:
        The list produced by ``graph.extract_entities``, or an empty list
        when the ``graph`` module cannot be imported.

    Raises:
        ValueError: If no memory with ``memory_id`` exists in the store.
    """
    try:
        from kemi import graph
    except ImportError:  # pragma: no cover
        logger.warning("graph module not available")
        return []

    record = self._store.get(memory_id)
    if record is not None:
        return graph.extract_entities(record.content)
    raise ValueError(f"Memory not found: {memory_id}")
def get_memory_graph(
    self,
    user_id: str,
    namespace: str = "default",
) -> dict[str, Any]:
    """Build the entity/relation graph for a user's memories.

    Delegates to ``kemi.graph.build_memory_graph`` with this instance's
    store.

    Args:
        user_id: User ID.
        namespace: Memory namespace.

    Returns:
        The dict produced by ``graph.build_memory_graph``. When the ``graph``
        module cannot be imported, a dict with empty ``'entities'`` and
        ``'relations'`` lists.
    """
    try:
        from kemi import graph
    except ImportError:  # pragma: no cover
        logger.warning("graph module not available")
        return {"entities": [], "relations": []}

    memory_graph = graph.build_memory_graph(
        store=self._store,
        user_id=user_id,
        namespace=namespace,
    )
    return memory_graph
def stats(
    self,
    user_id: str,
    lifecycle_filter: list[LifecycleState] | None = None,
    session_id: str | None = None,
) -> dict[str, Any]:
    """Return health statistics for a user's memory store.

    No namespace is passed to the store, so the adapter's own default
    namespace applies.

    Args:
        user_id: The user whose memories to analyze.
        lifecycle_filter: Optional list of lifecycle states to filter by.
        session_id: Optional session ID to filter by.

    Returns:
        A dict with these keys:
            total: int - total number of memories
            by_lifecycle: dict - count per lifecycle state
                e.g. {"active": 10, "decaying": 3, "archived": 1, "deleted": 0}
            by_source: dict - count per memory source
                e.g. {"user_stated": 8, "agent_inferred": 5}
            avg_importance: float - average importance (0.0 if no memories)
            tag_counts: dict - how many memories each tag appears in
                e.g. {"food": 3, "work": 7}
            total_with_tags: int - memories that have at least one tag
            total_without_tags: int - memories with no tags

    Raises:
        ValueError: If ``user_id`` is empty or whitespace.
    """
    if not user_id or not user_id.strip():
        raise ValueError("user_id cannot be empty")

    all_memories = self._store.get_all_by_user(
        user_id, lifecycle_filter=lifecycle_filter, session_id=session_id
    )

    # Pre-seed every enum value so absent states/sources report 0.
    by_lifecycle = {state.value: 0 for state in LifecycleState}
    by_source = {source.value: 0 for source in MemorySource}
    tag_counts: dict[str, int] = {}
    total_with_tags = 0
    total_importance = 0.0

    for mem in all_memories:
        by_lifecycle[mem.lifecycle_state.value] += 1
        by_source[mem.source.value] += 1
        total_importance += mem.importance

        if mem.tags:
            total_with_tags += 1
            # Count each tag once per memory (order-preserving de-dup) so a
            # repeated tag on one memory does not inflate "memories per tag".
            for tag in dict.fromkeys(mem.tags):
                tag_counts[tag] = tag_counts.get(tag, 0) + 1

    total = len(all_memories)
    avg_importance = total_importance / total if total > 0 else 0.0

    return {
        "total": total,
        "by_lifecycle": by_lifecycle,
        "by_source": by_source,
        "avg_importance": avg_importance,
        "tag_counts": tag_counts,
        "total_with_tags": total_with_tags,
        "total_without_tags": total - total_with_tags,
    }
async def astats(
    self,
    user_id: str,
    lifecycle_filter: list[LifecycleState] | None = None,
    session_id: str | None = None,
) -> dict[str, Any]:
    """Async version of stats(); runs it in a worker thread.

    Args:
        user_id: The user whose memories to analyze.
        lifecycle_filter: Optional lifecycle states, forwarded to stats().
        session_id: Optional session ID, forwarded to stats().

    Returns:
        The dict returned by stats().
    """
    import asyncio

    return await asyncio.to_thread(
        self.stats,
        user_id,
        lifecycle_filter=lifecycle_filter,
        session_id=session_id,
    )
def recall_by_tag(
    self,
    user_id: str,
    tag: str,
    lifecycle_filter: list[LifecycleState] | None = None,
) -> list[MemoryObject]:
    """Fetch a user's memories carrying a given tag.

    A thin validated wrapper around ``self._store.get_by_tag``.

    Args:
        user_id: User ID to search for.
        tag: Tag to filter by.
        lifecycle_filter: Lifecycle states, passed through to the store.

    Returns:
        Whatever ``self._store.get_by_tag`` returns for these arguments.

    Raises:
        ValueError: If ``user_id`` or ``tag`` is empty or whitespace.
    """
    user_missing = not user_id or not user_id.strip()
    if user_missing:
        raise ValueError("user_id cannot be empty")

    tag_missing = not tag or not tag.strip()
    if tag_missing:
        raise ValueError("tag cannot be empty")

    tagged = self._store.get_by_tag(user_id, tag, lifecycle_filter)
    return tagged
async def arecall_by_tag(
    self,
    user_id: str,
    tag: str,
    lifecycle_filter: list[LifecycleState] | None = None,
) -> list[MemoryObject]:
    """Async version of recall_by_tag(); runs it in a worker thread."""
    import asyncio

    pending = asyncio.to_thread(self.recall_by_tag, user_id, tag, lifecycle_filter)
    return await pending
def update(
    self,
    memory_id: str,
    content: str | None = None,
    importance: float | None = None,
    confidence: float | None = None,
    memory_type: MemoryType | None = None,
    metadata: dict[str, Any] | None = None,
    tags: list[str] | None = None,
) -> str:
    """Update an existing memory.

    When every optional argument is ``None`` the call is a no-op and returns
    immediately, without touching the store or running hooks.

    Args:
        memory_id: ID of memory to update.
        content: New content; the memory is re-embedded and its
            ``last_accessed_at`` refreshed.
        importance: New importance, clamped to 0.0-1.0.
        confidence: New confidence, clamped to 0.0-1.0.
        memory_type: New memory type.
        metadata: Dict merged into the existing metadata.
        tags: New tags; they replace the existing tags.

    Returns:
        The memory_id of the updated memory.

    Raises:
        ValueError: If memory_id is not found in the store.
    """
    requested = (content, importance, confidence, memory_type, metadata, tags)
    if all(value is None for value in requested):
        return memory_id

    with self._latency_tracker("update"):
        self._run_hooks("pre", "update", memory_id=memory_id)
        memory = self._store.get(memory_id)
        if memory is None:
            raise ValueError(f"Memory not found: {memory_id}")

        # Capture the pre-update state BEFORE mutating the memory object.
        # Both the version store and the webhook payload need a separate
        # "before" object.
        # NOTE(review): ``expires_at`` is not copied here — confirm whether
        # the snapshot is meant to carry it.
        memory_before = MemoryObject(
            memory_id=memory.memory_id,
            user_id=memory.user_id,
            content=memory.content,
            embedding=memory.embedding,
            score=memory.score,
            created_at=memory.created_at,
            last_accessed_at=memory.last_accessed_at,
            source=memory.source,
            importance=memory.importance,
            lifecycle_state=memory.lifecycle_state,
            metadata=memory.metadata.copy() if memory.metadata else {},
            embedding_dim=memory.embedding_dim,
            tags=list(memory.tags) if memory.tags else [],
            confidence=memory.confidence,
            memory_type=memory.memory_type,
            session_id=memory.session_id,
            namespace=memory.namespace,
            version=memory.version,
            agent_id=memory.agent_id,
            run_id=memory.run_id,
            app_id=memory.app_id,
        )

        # Now apply all mutations.
        if content is not None:
            memory.content = content
            memory.embedding = self._embed.embed_single(content)
            memory.embedding_dim = len(memory.embedding)
            memory.last_accessed_at = datetime.now(timezone.utc)
            if self._config.enable_entity_boost:
                if memory.metadata is None:
                    memory.metadata = {}
                # Refresh the cached entities so they match the new content.
                memory.metadata["extracted_entities"] = list(
                    self._entity_linker.extract(content)
                )

        if importance is not None:
            memory.importance = max(0.0, min(1.0, importance))

        if confidence is not None:
            memory.confidence = max(0.0, min(1.0, confidence))

        if memory_type is not None:
            memory.memory_type = memory_type

        if metadata is not None:
            if memory.metadata is None:
                memory.metadata = {}
            memory.metadata.update(metadata)

        if tags is not None:
            # Copy so later changes to the caller's list do not leak in.
            memory.tags = list(tags)

        # Record the before/after pair in the version store. Versioning is
        # optional, so a failure must not abort the update — but it is
        # logged instead of being silently discarded.
        try:
            vs = self._get_version_store()
            vs.record_before_update(memory_before, memory, changed_by="update")
            self._auto_prune_versions_for_memory(memory_id)
        except Exception:
            logger.debug(
                "Version snapshot skipped for memory %s", memory_id, exc_info=True
            )

        # The previous state must come from the untouched copy; ``memory``
        # already carries the new values at this point.
        previous_state = _memory_to_dict(memory_before)
        memory.version += 1

        snapshot = _memory_to_dict(memory)
        # NOTE(review): the webhook fires before the store write below; if
        # the write fails, subscribers have already been notified — confirm
        # this ordering is intended.
        self._dispatch_webhook_event(
            WebhookEventType.UPDATED,
            memory_id=memory_id,
            user_id=memory.user_id,
            snapshot=snapshot,
            previous_state=previous_state,
        )

        self._store.update(memory)
        self._run_hooks("post", "update", memory_id=memory_id, version=memory.version)
        logger.info(f"Updated memory: {memory_id} (version {memory.version})")
        self._track_operation(
            "update",
            memory.user_id,
            {"memory_id": memory_id, "version": memory.version},
            memory_id,
            memory.namespace,
        )
        return memory_id
def recall_since(
    self,
    user_id: str,
    query: str,
    hours: float = 24.0,
    top_k: int = 5,
    max_tokens: int | None = None,
    lifecycle_filter: list[LifecycleState] | None = None,
    namespace: str = "default",
    session_id: str | None = None,
) -> list[MemoryObject]:
    """Recall memories created in the last N hours.

    Delegates to :meth:`recall` asking for ``top_k * 3`` results, then keeps
    only those whose ``created_at`` is at or after the cutoff and trims to
    ``top_k``. Because the time filter runs after the recall, fewer than
    ``top_k`` items may be returned even when more recent memories exist.

    Args:
        user_id: User ID to search for.
        query: Search query.
        hours: Only return memories created in the last N hours.
        top_k: Maximum memories to return.
        max_tokens: Token budget, forwarded to :meth:`recall`.
        lifecycle_filter: Lifecycle states, forwarded to :meth:`recall`.
        namespace: Memory namespace, forwarded to :meth:`recall`.
        session_id: Session filter, forwarded to :meth:`recall`.

    Returns:
        Up to ``top_k`` memories created since the cutoff, in the order
        :meth:`recall` returned them.
    """
    cutoff = datetime.now(timezone.utc) - timedelta(hours=hours)

    all_results = self.recall(
        user_id=user_id,
        query=query,
        top_k=top_k * 3,
        max_tokens=max_tokens,
        lifecycle_filter=lifecycle_filter,
        namespace=namespace,
        session_id=session_id,
    )

    # Memories without a creation timestamp cannot be placed in the window.
    filtered = [m for m in all_results if m.created_at and m.created_at >= cutoff]
    return filtered[:top_k]
async def alist_users(self) -> list[str]:
    """Async version of list_users(); runs it in a worker thread."""
    import asyncio

    users = await asyncio.to_thread(self.list_users)
    return users
async def aupdate(
    self,
    memory_id: str,
    content: str | None = None,
    importance: float | None = None,
    confidence: float | None = None,
    memory_type: MemoryType | None = None,
    metadata: dict[str, Any] | None = None,
    tags: list[str] | None = None,
) -> str:
    """Async version of update(); runs it in a worker thread.

    Args:
        memory_id: ID of memory to update.
        content: New content, forwarded to update().
        importance: New importance, forwarded to update().
        confidence: New confidence, forwarded to update().
        memory_type: New memory type, forwarded to update().
        metadata: Metadata to merge, forwarded to update().
        tags: Replacement tags, forwarded to update().

    Returns:
        The value returned by update().
    """
    import asyncio

    return await asyncio.to_thread(
        self.update,
        memory_id,
        content,
        importance,
        confidence,
        memory_type,
        metadata,
        tags,
    )
async def aupdate_many(
    self,
    memory_ids: list[str],
    content: str | None = None,
    importance: float | None = None,
    confidence: float | None = None,
    memory_type: MemoryType | None = None,
    metadata: dict[str, Any] | None = None,
) -> list[str]:
    """Async counterpart of update_many().

    Batches of at most 10 IDs are fanned out as individual :meth:`aupdate`
    calls awaited together with ``asyncio.gather``. Larger batches are
    handed to a single ``update_many`` call running in a worker thread.

    Args:
        memory_ids: IDs of the memories to update.
        content: New content applied to every memory.
        importance: New importance applied to every memory.
        confidence: New confidence applied to every memory.
        memory_type: New memory type applied to every memory.
        metadata: Metadata applied to every memory.

    Returns:
        An empty list for an empty input; otherwise the gathered
        :meth:`aupdate` results (small batches) or the result of
        ``update_many`` (large batches).
    """
    import asyncio

    if not memory_ids:
        return []

    if len(memory_ids) > 10:
        # Large batch: one worker thread runs the synchronous bulk update.
        return await asyncio.to_thread(
            self.update_many,
            memory_ids,
            content,
            importance,
            confidence,
            memory_type,
            metadata,
        )

    # Small batch: one concurrent aupdate per memory.
    shared_fields = {
        "content": content,
        "importance": importance,
        "confidence": confidence,
        "memory_type": memory_type,
        "metadata": metadata,
    }
    pending = [self.aupdate(mid, **shared_fields) for mid in memory_ids]
    return await asyncio.gather(*pending)
async def aforget_many(
    self,
    memory_ids: list[str],
) -> int:
    """Delete several memories concurrently and count the successful deletions.

    Each id is passed to ``self._store.delete_by_id`` through
    ``asyncio.to_thread`` and all calls are awaited together. The return
    value is the number of calls that produced a truthy result; an empty
    ``memory_ids`` returns 0 without touching the store.

    NOTE(review): this talks to the storage adapter directly rather than
    going through ``forget_many`` — confirm that skipping whatever
    ``forget_many`` does around the deletion is intended.
    """
    import asyncio

    if not memory_ids:
        return 0

    deletions = [asyncio.to_thread(self._store.delete_by_id, mid) for mid in memory_ids]
    outcomes = await asyncio.gather(*deletions)
    return sum(bool(flag) for flag in outcomes)
2163 async def arecall_many(
2164 self,
2165 user_ids: list[str],
2166 queries: list[str],
2167 top_k: int = 5,
2168 max_tokens: int | None = None,
2169 lifecycle_filter: list[LifecycleState] | None = None,
2170 hybrid_search: bool | None = None,
2171 namespace: str = "default",
2172 session_id: str | None = None,
2173 metadata_filter: dict[str, Any] | None = None,
2174 ) -> dict[str, list[MemoryObject]]:
2175 """Async version of recall_many() — runs individual recalls concurrently."""
2176 import asyncio
2178 if len(user_ids) != len(queries):
2179 raise ValueError("user_ids and queries must have the same length")
2181 tasks = [
2182 self.arecall(
2183 uid,
2184 q,
2185 top_k=top_k,
2186 max_tokens=max_tokens,
2187 lifecycle_filter=lifecycle_filter,
2188 hybrid_search=hybrid_search,
2189 namespace=namespace,
2190 session_id=session_id,
2191 metadata_filter=metadata_filter,
2192 )
2193 for uid, q in zip(user_ids, queries, strict=True)
2194 ]
2195 results = await asyncio.gather(*tasks)
2196 return {uid: res for uid, res in zip(user_ids, results, strict=True)}
async def arecall_since(
    self,
    user_id: str,
    query: str,
    hours: float = 24.0,
    top_k: int = 5,
    max_tokens: int | None = None,
    lifecycle_filter: list[LifecycleState] | None = None,
) -> list[MemoryObject]:
    """Async version of recall_since().

    Forwards all arguments positionally, in signature order, to the
    synchronous ``recall_since`` via ``asyncio.to_thread``.
    """
    import asyncio

    forwarded = (user_id, query, hours, top_k, max_tokens, lifecycle_filter)
    return await asyncio.to_thread(self.recall_since, *forwarded)
async def aremember_many(
    self,
    user_id: str,
    contents: list[str],
    importance: float = 0.5,
    source: MemorySource = MemorySource.USER_STATED,
    tags: list[str] | None = None,
    namespace: str = "default",
    session_id: str | None = None,
    memory_type: MemoryType = MemoryType.EPISODIC,
    confidence: float = 1.0,
    agent_id: str | None = None,
    run_id: str | None = None,
    app_id: str | None = None,
) -> list[str]:
    """Async version of remember_many().

    All arguments are forwarded positionally, in signature order, to the
    synchronous ``remember_many`` via ``asyncio.to_thread``, and its result
    is returned unchanged.
    """
    import asyncio

    forwarded = (
        user_id,
        contents,
        importance,
        source,
        tags,
        namespace,
        session_id,
        memory_type,
        confidence,
        agent_id,
        run_id,
        app_id,
    )
    return await asyncio.to_thread(self.remember_many, *forwarded)
def feedback(
    self,
    user_id: str,
    memory_id: str,
    helpful: bool = True,
    namespace: str = "default",
) -> None:
    """Record whether a recalled memory was helpful and nudge its importance.

    A ``{"helpful", "timestamp"}`` entry (UTC, ISO-8601) is appended to
    ``memory.metadata["feedback"]``, and the memory's importance moves by
    0.05 — up when ``helpful`` is true, down otherwise — clamped to
    [0.0, 1.0]. The modified memory is written back with
    ``self._store.update`` and the operation is tracked.

    Args:
        user_id: Owner of the memory.
        memory_id: Id of the memory the feedback refers to.
        helpful: Whether the memory was helpful.
        namespace: Namespace the memory must belong to.

    Raises:
        ValueError: If ``user_id`` or ``memory_id`` is empty or whitespace,
            if no memory with that id exists, or if the memory belongs to a
            different user or namespace.
    """
    if not (user_id and user_id.strip()):
        raise ValueError("user_id cannot be empty")
    if not (memory_id and memory_id.strip()):
        raise ValueError("memory_id cannot be empty")

    memory = self._store.get(memory_id)
    if memory is None:
        raise ValueError(f"Memory not found: {memory_id}")
    if memory.user_id != user_id:
        raise ValueError("Memory does not belong to this user")
    if memory.namespace != namespace:
        raise ValueError("Memory does not belong to this namespace")

    # Append to the per-memory feedback history, creating it on first use.
    entry = {
        "helpful": helpful,
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
    meta = memory.metadata
    if "feedback" not in meta:
        meta["feedback"] = []
    meta["feedback"].append(entry)

    # Small fixed step in the direction of the feedback, kept within [0, 1].
    if helpful:
        delta = 0.05
    else:
        delta = -0.05
    memory.importance = max(0.0, min(1.0, memory.importance + delta))

    self._store.update(memory)
    logger.info(
        f"Feedback recorded for memory {memory_id}: helpful={helpful}, "
        f"new_importance={memory.importance:.2f}"
    )
    tracked_details = {"memory_id": memory_id, "helpful": helpful}
    self._track_operation("feedback", user_id, tracked_details, memory_id, namespace)
def backfill_entities(
    self,
    user_id: str | None = None,
    namespace: str | None = None,
) -> int:
    """Backfill ``extracted_entities`` for memories that don't have them yet.

    Iterates over memories (optionally filtered by *user_id* and
    *namespace*), extracts entities from their content using the
    configured :attr:`_entity_linker`, and persists the result in
    ``memory.metadata["extracted_entities"]``.

    This is useful after enabling entity boost on an existing store
    so that subsequent recall calls can read cached entities instead
    of falling back to on-the-fly extraction.

    Only memories in the ACTIVE, DECAYING, or ARCHIVED lifecycle states
    are fetched, and any memory whose metadata already contains an
    ``"extracted_entities"`` key is left untouched. When
    ``self._config.enable_entity_boost`` is false the method returns 0
    without reading the store.

    Args:
        user_id: If provided, only backfill this user's memories.
            If ``None``, backfill across all users.
        namespace: If provided, only backfill memories in this
            namespace. If ``None``, backfill across all namespaces.

    Returns:
        Number of memories that were backfilled.
    """
    if not self._config.enable_entity_boost:
        logger.info("Entity boost is disabled; skipping entity backfill")
        return 0

    with self._latency_tracker("backfill_entities"):
        # Resolve the set of users to scan: the one requested, or every known user.
        if user_id is not None:
            users = [user_id]
        else:
            users = self._store.get_all_users()

        backfilled = 0
        for uid in users:
            # Resolve the namespaces to scan for this particular user.
            if namespace is not None:
                namespaces = [namespace]
            else:
                namespaces = self._known_namespaces(uid)

            for ns in namespaces:
                memories = self._store.get_all_by_user(
                    uid,
                    lifecycle_filter=[
                        LifecycleState.ACTIVE,
                        LifecycleState.DECAYING,
                        LifecycleState.ARCHIVED,
                    ],
                    namespace=ns,
                )
                for mem in memories:
                    # Presence of the key (even with an empty list) marks the memory as done.
                    if "extracted_entities" in mem.metadata:
                        continue
                    entities = self._entity_linker.extract(mem.content)
                    # Stored as a plain list, whatever iterable the linker returned.
                    mem.metadata["extracted_entities"] = list(entities)
                    self._store.update(mem)
                    backfilled += 1

        if backfilled > 0:
            logger.info(
                f"Backfilled extracted_entities for {backfilled} memories"
                + (f" (user={user_id})" if user_id else "")
            )
        # "all" stands in for an unspecified user or namespace in the tracked record.
        self._track_operation(
            "backfill_entities",
            user_id or "all",
            {"backfilled": backfilled},
            namespace=namespace or "all",
        )
        return backfilled
async def abackfill_entities(
    self,
    user_id: str | None = None,
    namespace: str | None = None,
) -> int:
    """Async version of :meth:`backfill_entities`.

    Runs the synchronous method through ``asyncio.to_thread`` with
    ``user_id`` and ``namespace`` passed positionally, and returns its count.
    """
    import asyncio

    count = await asyncio.to_thread(self.backfill_entities, user_id, namespace)
    return count
def run_maintenance(
    self,
    user_id: str,
    namespace: str = "default",
    auto_prune: bool = True,
    auto_consolidate: bool = True,
    auto_backfill_entities: bool = False,
    prune_max_age_days: float = 90.0,
    prune_min_importance: float = 0.1,
    consolidate_min_memories: int = 5,
    consolidate_max_age_days: float = 30.0,
    auto_prune_expired: bool = True,
    consolidate_with_llm_summary: bool = False,
) -> dict[str, Any]:
    """Run the enabled maintenance steps once for one user and namespace.

    Steps run in a fixed order — entity backfill, prune, prune expired,
    consolidate — and each one is skipped when its ``auto_*`` flag is false.
    This is a single pass; for periodic maintenance, call it from a
    scheduler (e.g., cron, APScheduler).

    Args:
        user_id: User ID to maintain.
        namespace: Memory namespace.
        auto_prune: Whether to call :meth:`prune`.
        auto_consolidate: Whether to call :meth:`consolidate`.
        auto_backfill_entities: Whether to call :meth:`backfill_entities`.
        prune_max_age_days: Passed to ``prune`` as ``max_age_days``.
        prune_min_importance: Passed to ``prune`` as ``min_importance``.
        consolidate_min_memories: Passed to ``consolidate`` as ``min_memories``.
        consolidate_max_age_days: Passed to ``consolidate`` as ``max_age_days``.
        auto_prune_expired: Whether to call :meth:`prune_expired`.
        consolidate_with_llm_summary: Passed to ``consolidate`` as
            ``with_llm_summary``.

    Returns:
        Dict with keys ``'pruned'``, ``'expired'``, ``'consolidated'`` and
        ``'backfilled'``. Steps that were skipped keep their initial value
        (0, or ``None`` for ``'consolidated'``).

    Raises:
        ValueError: If ``user_id`` is empty or whitespace.
    """
    if not (user_id and user_id.strip()):
        raise ValueError("user_id cannot be empty")

    with self._latency_tracker("run_maintenance"):
        result: dict[str, Any] = {
            "pruned": 0,
            "expired": 0,
            "consolidated": None,
            "backfilled": 0,
        }

        if auto_backfill_entities:
            result["backfilled"] = self.backfill_entities(
                user_id=user_id, namespace=namespace
            )

        if auto_prune:
            result["pruned"] = self.prune(
                user_id=user_id,
                max_age_days=prune_max_age_days,
                min_importance=prune_min_importance,
                namespace=namespace,
            )

        if auto_prune_expired:
            result["expired"] = self.prune_expired(
                user_id=user_id, namespace=namespace
            )

        if auto_consolidate:
            result["consolidated"] = self.consolidate(
                user_id=user_id,
                namespace=namespace,
                min_memories=consolidate_min_memories,
                max_age_days=consolidate_max_age_days,
                with_llm_summary=consolidate_with_llm_summary,
            )

        logger.info(f"Maintenance complete for {user_id}: {result}")
        self._track_operation("run_maintenance", user_id, result, namespace=namespace)
        return result
def get_metrics(self) -> dict[str, Any] | None:
    """Return ``self._metrics.to_dict()``, or ``None`` when ``self._metrics`` is ``None``.

    NOTE(review): a second ``get_metrics`` definition appears later in this
    file; if both live in the same class the later one replaces this one —
    confirm and drop whichever is obsolete.
    """
    collector = self._metrics
    return None if collector is None else collector.to_dict()  # type: ignore[no-any-return]
def get_metrics_prometheus(self) -> str | None:
    """Return ``self._metrics.to_prometheus()``, or ``None`` when ``self._metrics`` is ``None``.

    NOTE(review): a second ``get_metrics_prometheus`` definition appears
    later in this file; if both live in the same class the later one
    replaces this one — confirm and drop whichever is obsolete.
    """
    collector = self._metrics
    return None if collector is None else collector.to_prometheus()  # type: ignore[no-any-return]
def enable_adaptive_retrieval(self, enable: bool = True) -> None:
    """Switch adaptive retrieval on or off.

    Disabling sets ``self._adaptive_retriever`` to ``None``. Enabling
    imports ``kemi.adaptive.AdaptiveRetriever`` and stores a fresh
    instance; if that raises ``ImportError`` a warning is logged and the
    attribute is left as it was.

    NOTE(review): a second ``enable_adaptive_retrieval`` definition appears
    later in this file; if both live in the same class the later one
    replaces this one — confirm and drop whichever is obsolete.
    """
    if not enable:
        self._adaptive_retriever = None
        return

    try:
        from kemi.adaptive import AdaptiveRetriever

        self._adaptive_retriever = AdaptiveRetriever()
    except ImportError:
        logger.warning("Adaptive retrieval module not available")
def _track_operation(
    self,
    operation: str,
    user_id: str,
    details: dict[str, Any] | None = None,
    memory_id: str | None = None,
    namespace: str = "default",
    status: str = "success",
    audit_batch: list[dict[str, Any]] | None = None,
) -> None:
    """Track an operation in metrics and audit trail.

    Thin delegator: ``self`` followed by every argument, in signature
    order, is passed positionally to ``_ops_metrics.track_operation_full``.
    """
    from kemi.operations import _ops_metrics as ops_metrics

    forwarded = (operation, user_id, details, memory_id, namespace, status, audit_batch)
    ops_metrics.track_operation_full(self, *forwarded)
def _record_embed_error(self) -> None:
    """Record an embedding error in metrics if available.

    Delegates to ``_ops_metrics.record_embed_error`` with this instance.
    """
    from kemi.operations import _ops_metrics as ops_metrics

    ops_metrics.record_embed_error(self)
def _record_store_error(self) -> None:
    """Record a storage error in metrics if available.

    Delegates to ``_ops_metrics.record_store_error`` with this instance.
    """
    from kemi.operations import _ops_metrics as ops_metrics

    ops_metrics.record_store_error(self)
def add_event_hook(self, phase: str, callback: Callable[..., Any]) -> None:
    """Register an event hook callback.

    Delegates to ``_ops_hooks.add`` with this instance.

    Args:
        phase: "pre" or "post" — called before or after the operation.
        callback: Callable that receives (operation, **kwargs).
    """
    from kemi.operations import _ops_hooks as hook_ops

    hook_ops.add(self, phase, callback)
def remove_event_hook(self, phase: str, callback: Callable[..., Any]) -> bool:
    """Remove a previously registered event hook callback.

    Delegates to ``_ops_hooks.remove`` and returns its result.

    Returns True if removed, False if not found.
    """
    from kemi.operations import _ops_hooks as hook_ops

    was_removed = hook_ops.remove(self, phase, callback)
    return was_removed
def _run_hooks(
    self,
    phase: str,
    operation: str,
    *,
    raise_on_error: bool | None = None,
    **kwargs: Any,
) -> None:
    """Run all hooks registered for a phase/operation.

    Delegates to ``_ops_hooks.run`` with this instance; ``raise_on_error``
    is forwarded as a keyword and ``**kwargs`` are passed through as-is.

    Args:
        phase: "pre" or "post".
        operation: Name of the operation triggering the hook.
        raise_on_error: If True, exceptions from hooks are re-raised so
            a failing pre-hook can abort the operation. If None (default),
            the value is taken from ``self._config.hooks_raise_on_error``.
        **kwargs: Passed through to each callback.
    """
    from kemi.operations import _ops_hooks as hook_ops

    hook_ops.run(self, phase, operation, raise_on_error=raise_on_error, **kwargs)
def enable_query_cache(self, max_size: int = 128) -> None:
    """Enable an LRU cache for recall() results.

    Delegates to ``_ops_metrics.enable_query_cache`` with this instance.

    Args:
        max_size: Maximum number of cached query results.
    """
    from kemi.operations import _ops_metrics as ops_metrics

    ops_metrics.enable_query_cache(self, max_size)
def disable_query_cache(self) -> None:
    """Disable the query cache.

    Delegates to ``_ops_metrics.disable_query_cache`` with this instance.
    """
    from kemi.operations import _ops_metrics as ops_metrics

    ops_metrics.disable_query_cache(self)
def configure_versioning(
    self,
    db_path: str | None = None,
    max_versions_per_memory: int = 50,
    auto_prune_versions: bool = True,
) -> None:
    """Enable memory version history tracking.

    Delegates to ``_ops_versioning.configure``, passing this instance and
    the three settings positionally.

    Args:
        db_path: Path to the SQLite database. Defaults to the store's db_path.
        max_versions_per_memory: Maximum versions to keep per memory before pruning.
        auto_prune_versions: If True, prune old versions when limits are exceeded.
    """
    from kemi.operations import _ops_versioning as versioning_ops

    settings = (db_path, max_versions_per_memory, auto_prune_versions)
    versioning_ops.configure(self, *settings)
def _get_version_store(self) -> MemoryVersionStore:
    """Get the version store, initialising it lazily from the storage adapter's db.

    Falls back to an in-memory SQLite database when the storage adapter
    does not expose a ``_db_path`` (e.g. mock adapters used in tests,
    in-memory backends).

    The lookup itself is performed by ``_ops_versioning.get_store``.
    """
    from kemi.operations import _ops_versioning as versioning_ops

    version_store = versioning_ops.get_store(self)
    return version_store
def get_history(
    self,
    memory_id: str,
    limit: int = 100,
) -> list["VersionSnapshot"]:
    """Return version history for a memory, newest first.

    Delegates to ``_ops_versioning.get_history`` with this instance,
    ``memory_id`` and ``limit``.
    """
    from kemi.operations import _ops_versioning as versioning_ops

    snapshots = versioning_ops.get_history(self, memory_id, limit)
    return snapshots
def diff_versions(
    self,
    memory_id: str,
    from_version: int,
    to_version: int,
) -> "DiffResult | None":
    """Show field-level differences between two versions of a memory.

    Delegates to ``_ops_versioning.diff`` and returns its result.
    """
    from kemi.operations import _ops_versioning as versioning_ops

    difference = versioning_ops.diff(self, memory_id, from_version, to_version)
    return difference
def rollback_memory(
    self,
    memory_id: str,
    target_version: int,
) -> "RollbackResult | None":
    """Roll a memory back to a previous version.

    Delegates to ``_ops_versioning.rollback`` and returns its result.
    """
    from kemi.operations import _ops_versioning as versioning_ops

    outcome = versioning_ops.rollback(self, memory_id, target_version)
    return outcome
def _auto_prune_versions_for_memory(self, memory_id: str) -> None:
    """Prune old versions for a memory, keeping only the most recent ones.

    Delegates to ``_ops_versioning.auto_prune`` with this instance.
    """
    from kemi.operations import _ops_versioning as versioning_ops

    versioning_ops.auto_prune(self, memory_id)
def configure_webhooks(self, db_path: str | None = None) -> None:
    """Enable webhook dispatch for memory lifecycle events.

    Delegates to ``_ops_webhooks.configure`` with this instance.

    Args:
        db_path: Path to SQLite database for webhook config storage.
            Defaults to the same path used by the storage adapter.
    """
    from kemi.operations import _ops_webhooks as webhook_ops

    webhook_ops.configure(self, db_path)
def _dispatch_webhook_event(
    self,
    event: WebhookEventType,
    memory_id: str,
    user_id: str,
    snapshot: dict[str, Any] | None = None,
    previous_state: dict[str, Any] | None = None,
    **extra: Any,
) -> None:
    """Dispatch a webhook event if a dispatcher is configured.

    Delegates to ``_ops_webhooks.dispatch``: the named arguments are passed
    positionally after this instance, and ``**extra`` is forwarded as-is.
    """
    from kemi.operations import _ops_webhooks as webhook_ops

    positional = (event, memory_id, user_id, snapshot, previous_state)
    webhook_ops.dispatch(self, *positional, **extra)
def enable_audit_trail(
    self,
    retention_days: int = 365,
    auto_purge: bool = True,
) -> None:
    """Enable the audit trail for compliance logging.

    Delegates to ``_ops_metrics.enable_audit_trail``, passing this instance,
    ``retention_days`` and ``auto_purge`` positionally.
    """
    from kemi.operations import _ops_metrics as ops_metrics

    ops_metrics.enable_audit_trail(self, retention_days, auto_purge)
def get_metrics(self) -> dict[str, Any] | None:
    """Return current metrics snapshot as a dict, or None if disabled.

    Delegates to ``_ops_metrics.get_metrics`` with this instance.

    NOTE(review): an earlier ``get_metrics`` definition appears above in
    this file; if both live in the same class this one is the effective
    implementation — confirm and remove the earlier duplicate.
    """
    from kemi.operations import _ops_metrics as ops_metrics

    snapshot = ops_metrics.get_metrics(self)
    return snapshot
def get_metrics_prometheus(self) -> str | None:
    """Return metrics in Prometheus text format, or None if disabled.

    Delegates to ``_ops_metrics.get_metrics_prometheus`` with this instance.

    NOTE(review): an earlier ``get_metrics_prometheus`` definition appears
    above in this file; if both live in the same class this one is the
    effective implementation — confirm and remove the earlier duplicate.
    """
    from kemi.operations import _ops_metrics as ops_metrics

    exposition = ops_metrics.get_metrics_prometheus(self)
    return exposition
def enable_adaptive_retrieval(self, enable: bool = True) -> None:
    """Enable or disable adaptive retrieval (re-weights hybrid scores per user).

    Delegates to ``_ops_metrics.enable_adaptive_retrieval`` with this instance.

    NOTE(review): an earlier ``enable_adaptive_retrieval`` definition
    appears above in this file; if both live in the same class this one is
    the effective implementation — confirm and remove the earlier duplicate.
    """
    from kemi.operations import _ops_metrics as ops_metrics

    ops_metrics.enable_adaptive_retrieval(self, enable)
class _QueryCache:
    """DEPRECATED shim — moved to :mod:`kemi.operations._query_cache`.

    Kept as a re-export so existing imports of ``kemi.core._QueryCache``
    keep working. The canonical location is ``kemi.operations._query_cache._QueryCache``.

    Calling this class constructs and returns an instance of the canonical
    implementation, not of this shim, because ``__new__`` hands the
    arguments straight to that class.
    """

    def __new__(cls, *args: Any, **kwargs: Any) -> Any:  # pragma: no cover
        # Imported at call time, so the canonical module is only loaded
        # when a cache is actually constructed.
        from kemi.operations._query_cache import _QueryCache as _Impl

        return _Impl(*args, **kwargs)