Coverage for src / kemi / memory_formation.py: 96%
137 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-05 15:47 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-05 15:47 +0000
"""MemoryFormation – turn conversations into structured memories.

Extracts atomic facts and events from chat histories using a pluggable
LLM extractor, deduplicates them against existing memories, and returns
:class:`MemoryObject` instances ready for persistence.

Example::

    from kemi.core import Memory
    from kemi.memory_formation import remember_from_conversation, OpenAIMemoryExtractor

    mem = Memory()
    conversation = [
        {"role": "user", "content": "I love hiking in the Alps."},
        {"role": "assistant", "content": "That sounds amazing!"},
        {"role": "user", "content": "My favourite trail is the Tour du Mont Blanc."},
    ]
    ids = remember_from_conversation(
        mem, conversation, user_id="alice", extractor=OpenAIMemoryExtractor()
    )
"""
23from __future__ import annotations
25import json
26import logging
27import re
28import uuid
29from dataclasses import dataclass, field
30from datetime import datetime, timezone
31from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
33from kemi import dedup
34from kemi.adapters.base import EmbeddingAdapter, StorageAdapter
35from kemi.models import (
36 LifecycleState,
37 MemoryConfig,
38 MemoryObject,
39 MemorySource,
40 MemoryType,
41)
43if TYPE_CHECKING:
44 from kemi.core import Memory
46logger = logging.getLogger(__name__)
49# ---------------------------------------------------------------------------
50# Protocol
51# ---------------------------------------------------------------------------
@runtime_checkable
class LLMMemoryExtractor(Protocol):
    """Structural interface for conversation-to-memory extractors.

    Any object that provides a compatible ``extract`` method satisfies this
    protocol and can be passed to :func:`extract_memories` or
    :func:`remember_from_conversation`.
    """

    def extract(
        self,
        conversation: list[dict[str, Any]],
        *,
        user_id: str,
        session_id: str | None = None,
    ) -> list[CandidateMemory]:
        """Turn one conversation into candidate memories.

        Args:
            conversation: The messages to analyse. Every item is expected to
                carry at least ``role`` (``"user"``, ``"assistant"`` or
                ``"system"``) and ``content`` (``str``); a ``timestamp`` key
                is optional.
            user_id: Identifier of the user who owns the conversation.
            session_id: Identifier of the session, when there is one.

        Returns:
            The candidates found, ready for deduplication and storage.
        """
        ...
84# ---------------------------------------------------------------------------
85# Data model
86# ---------------------------------------------------------------------------
@dataclass
class CandidateMemory:
    """One extractor result, held before it is embedded and stored."""

    # Text of the extracted fact or event; the only required field.
    content: str
    # Score assigned by the extractor.
    importance: float = 0.5
    memory_type: MemoryType = MemoryType.EPISODIC
    # Factories give every instance its own fresh list / dict.
    tags: list[str] = field(default_factory=list)
    metadata: dict[str, Any] = field(default_factory=dict)
99# ---------------------------------------------------------------------------
100# Built-in extractors
101# ---------------------------------------------------------------------------
class RegexMemoryExtractor:
    """Simple regex/heuristic extractor for tests and local use.

    Requires no external LLM. Matches common patterns such as preferences,
    goals, and personal facts.

    Behaviour worth knowing (all of it follows from ``_PATTERNS``):

    * Matching is case-insensitive.
    * A phrase is only captured when it is followed by ``.``, ``!`` or ``?``,
      and the capture ends at the first such character. A message without
      terminal punctuation therefore yields nothing.
    * Only the captured group becomes the memory content, e.g. ``"pizza"``
      for ``"I like pizza."``.
    * Every message is scanned, whatever its ``role``.
    """

    _PATTERNS: list[tuple[str, list[str], MemoryType, float]] = [
        # (regex, tags, memory_type, importance)
        (r"I like\s+(.+?)[.!?]", ["preference"], MemoryType.SEMANTIC, 0.6),
        (r"My name is\s+(.+?)[.!?]", ["identity"], MemoryType.SEMANTIC, 0.8),
        (r"I am\s+(.+?)[.!?]", ["identity"], MemoryType.SEMANTIC, 0.7),
        (r"I want to\s+(.+?)[.!?]", ["goal"], MemoryType.EPISODIC, 0.7),
        (r"I need to\s+(.+?)[.!?]", ["goal"], MemoryType.EPISODIC, 0.7),
        (r"I prefer\s+(.+?)[.!?]", ["preference"], MemoryType.SEMANTIC, 0.6),
        (r"I live in\s+(.+?)[.!?]", ["location"], MemoryType.SEMANTIC, 0.7),
        (r"I work at\s+(.+?)[.!?]", ["work"], MemoryType.SEMANTIC, 0.7),
        (r"Remember that\s+(.+?)[.!?]", ["reminder"], MemoryType.EPISODIC, 0.8),
        (r"Don't forget\s+(.+?)[.!?]", ["reminder"], MemoryType.EPISODIC, 0.8),
    ]

    def __init__(self) -> None:
        """Compile every pattern once so ``extract`` can reuse them."""
        self._compiled = [
            (re.compile(p, re.IGNORECASE), tags, mtype, imp)
            for p, tags, mtype, imp in self._PATTERNS
        ]

    def extract(
        self,
        conversation: list[dict[str, Any]],
        *,
        user_id: str,
        session_id: str | None = None,
    ) -> list[CandidateMemory]:
        """Return one candidate per distinct phrase matched in the conversation.

        Args:
            conversation: Messages to scan. Only the ``content`` and optional
                ``timestamp`` keys are read.
            user_id: Accepted for protocol compatibility; not used here.
            session_id: When truthy, recorded in each candidate's metadata
                under ``"session_id"``.

        Returns:
            Candidates in match order. Phrases are deduplicated
            case-insensitively across all messages and patterns.
        """
        candidates: list[CandidateMemory] = []
        seen: set[str] = set()

        for msg in conversation:
            content = msg.get("content", "")
            # Skip empty content, and also non-text content (e.g. a list of
            # multimodal parts), which the regex engine would reject with a
            # TypeError.
            if not content or not isinstance(content, str):
                continue

            # The metadata depends only on the message, so build it once here
            # and hand each candidate its own copy below.
            base_meta: dict[str, Any] = {}
            if "timestamp" in msg:
                base_meta["extracted_from_timestamp"] = msg["timestamp"]
            if session_id:
                base_meta["session_id"] = session_id

            for pattern, tags, mtype, imp in self._compiled:
                for match in pattern.finditer(content):
                    fact = match.group(1).strip()
                    key = fact.lower()
                    if not fact or key in seen:
                        continue
                    seen.add(key)

                    candidates.append(
                        CandidateMemory(
                            content=fact,
                            importance=imp,
                            memory_type=mtype,
                            tags=list(tags),
                            metadata=dict(base_meta),
                        )
                    )

        return candidates
class OpenAIMemoryExtractor:
    """OpenAI-powered memory extractor.

    Uses the Chat Completions API with a structured system prompt to turn
    a conversation into atomic memory candidates.

    The model's reply is untrusted input: anything that is not a JSON array
    of objects is ignored rather than allowed to raise, so a malformed reply
    degrades to fewer (or zero) candidates.
    """

    _SYSTEM_PROMPT = (
        "You are a memory extraction assistant. "
        "Given a conversation, extract atomic facts and events that would be "
        "useful to remember about the user. "
        "For each memory, provide: content (short sentence), importance (0.0-1.0), "
        "type ('episodic' or 'semantic'), tags (list of strings), and metadata (dict). "
        "Return ONLY a JSON array of objects. Do not include markdown formatting."
    )

    # Importance used when the model omits the field or sends an unusable value.
    _DEFAULT_IMPORTANCE = 0.5

    def __init__(
        self,
        model: str = "gpt-4o-mini",
        api_key: str | None = None,
        base_url: str | None = None,
    ) -> None:
        """Create the OpenAI client.

        Args:
            model: Chat model name sent with every request.
            api_key: Passed through to ``openai.OpenAI``.
            base_url: Passed through to ``openai.OpenAI``.

        Raises:
            ImportError: If the ``openai`` package is not installed.
        """
        try:
            import openai
        except ImportError as exc:
            raise ImportError(
                "OpenAIMemoryExtractor requires the 'openai' package. "
                "Install it with: pip install openai"
            ) from exc

        self._client = openai.OpenAI(api_key=api_key, base_url=base_url)
        self._model = model

    def extract(
        self,
        conversation: list[dict[str, Any]],
        *,
        user_id: str,
        session_id: str | None = None,
    ) -> list[CandidateMemory]:
        """Ask the model for memories and convert its reply into candidates.

        Args:
            conversation: Messages to forward; ``role`` defaults to ``"user"``
                and ``content`` to ``""`` when a key is missing.
            user_id: Accepted for protocol compatibility; not used here.
            session_id: When truthy, recorded in each candidate's metadata
                under ``"session_id"``.

        Returns:
            The candidates parsed from the reply. An empty list when the
            request fails, the reply is not valid JSON, or the JSON root is
            not an array (each case is logged).
        """
        messages: list[dict[str, str]] = [
            {"role": "system", "content": self._SYSTEM_PROMPT}
        ]
        for msg in conversation:
            messages.append(
                {"role": msg.get("role", "user"), "content": msg.get("content", "")}
            )

        # Best-effort boundary: a failed request or unparsable reply must not
        # take down the caller, so it is logged and yields no candidates.
        try:
            response = self._client.chat.completions.create(
                model=self._model,
                messages=messages,  # type: ignore[arg-type]
                temperature=0.3,
                max_tokens=1024,
            )
            raw = response.choices[0].message.content or "[]"
            parsed: Any = json.loads(self._strip_code_fence(raw))
        except Exception:
            logger.exception("OpenAI memory extraction failed")
            return []

        if not isinstance(parsed, list):
            logger.warning(
                "OpenAI memory extraction returned a JSON %s instead of an array",
                type(parsed).__name__,
            )
            return []

        candidates: list[CandidateMemory] = []
        for item in parsed:
            candidate = self._to_candidate(item, session_id)
            if candidate is not None:
                candidates.append(candidate)
        return candidates

    @staticmethod
    def _strip_code_fence(raw: str) -> str:
        """Return *raw* without a surrounding Markdown code fence, if any."""
        text = raw.strip()
        if not text.startswith("```"):
            return text
        # text starts with a fence, so index 1 always exists: it is whatever
        # sits between the opening fence and the next one (or the end).
        body = text.split("```")[1]
        # Drop an optional language tag. Slicing removes exactly the prefix,
        # whereas str.strip("json") would strip a *set of characters*.
        if body[:4].lower() == "json":
            body = body[4:]
        return body.strip()

    @classmethod
    def _coerce_importance(cls, value: Any) -> float:
        """Convert *value* to a float, falling back to the default score."""
        try:
            score = float(value)
        except (TypeError, ValueError):
            return cls._DEFAULT_IMPORTANCE
        # NaN is the only float that differs from itself; json.loads accepts
        # the literal ``NaN``, so it can show up here.
        if score != score:
            return cls._DEFAULT_IMPORTANCE
        return score

    @classmethod
    def _to_candidate(
        cls, item: Any, session_id: str | None
    ) -> CandidateMemory | None:
        """Build a candidate from one JSON element, or ``None`` to skip it."""
        if not isinstance(item, dict):
            return None

        raw_content = item.get("content")
        # A JSON null must not turn into the literal text "None".
        content = "" if raw_content is None else str(raw_content).strip()
        if not content:
            return None

        # Missing or null type means episodic; any value other than
        # "episodic" maps to semantic.
        mtype_str = str(item.get("type") or "episodic").strip().lower()
        mtype = MemoryType.EPISODIC if mtype_str == "episodic" else MemoryType.SEMANTIC

        raw_meta = item.get("metadata")
        meta: dict[str, Any] = dict(raw_meta) if isinstance(raw_meta, dict) else {}
        if session_id:
            meta["session_id"] = session_id

        raw_tags = item.get("tags")
        if isinstance(raw_tags, str):
            # A bare string is one tag, not a sequence of characters.
            tags = [raw_tags]
        elif isinstance(raw_tags, (list, tuple)):
            tags = [str(tag) for tag in raw_tags if tag is not None]
        else:
            tags = []

        return CandidateMemory(
            content=content,
            importance=cls._coerce_importance(
                item.get("importance", cls._DEFAULT_IMPORTANCE)
            ),
            memory_type=mtype,
            tags=tags,
            metadata=meta,
        )
class StaticMemoryExtractor:
    """Extractor that replays a preset list of candidates.

    The conversation is never inspected, which makes this handy for
    deterministic tests or as a do-nothing stand-in.
    """

    def __init__(self, candidates: list[CandidateMemory]) -> None:
        """Snapshot *candidates* so later changes to the caller's list are not seen."""
        self._candidates = [*candidates]

    def extract(
        self,
        conversation: list[dict[str, Any]],
        *,
        user_id: str,
        session_id: str | None = None,
    ) -> list[CandidateMemory]:
        """Return a new list holding the preset candidates; arguments are unused."""
        return self._candidates.copy()
280# ---------------------------------------------------------------------------
281# Public API
282# ---------------------------------------------------------------------------
# Threshold used when neither ``dedup_threshold`` nor ``config`` is supplied.
_DEFAULT_DEDUP_THRESHOLD = 0.85


def extract_memories(
    conversation: list[dict[str, Any]],
    *,
    user_id: str,
    session_id: str | None = None,
    extractor: LLMMemoryExtractor | None = None,
    embed: EmbeddingAdapter | None = None,
    store: StorageAdapter | None = None,
    config: MemoryConfig | None = None,
    dedup_threshold: float | None = None,
    namespace: str = "default",
) -> list[MemoryObject]:
    """Extract and deduplicate memories from a conversation.

    Steps:

    1. Runs the extractor over the conversation to get candidate memories.
    2. Batch-embeds all candidate contents.
    3. Creates :class:`MemoryObject` instances.
    4. Deduplicates against *existing* memories in ``store`` and against
       each other using the resolved threshold.
    5. Returns the accepted memories in extraction order.

    Args:
        conversation: List of messages. Each item should contain at
            least ``role`` and ``content``. An optional ``timestamp`` may
            be included.
        user_id: User the conversation belongs to.
        session_id: Optional session identifier.
        extractor: Pluggable extractor. Defaults to
            :class:`RegexMemoryExtractor` if not provided.
        embed: Embedding adapter (e.g. ``memory._embed``). Required
            whenever ``conversation`` is non-empty.
        store: Storage adapter (e.g. ``memory._store``). Used to fetch
            existing memories for deduplication; when ``None`` candidates
            are only deduplicated against each other.
        config: Optional :class:`MemoryConfig` supplying the threshold
            when ``dedup_threshold`` is not given.
        dedup_threshold: Similarity threshold handed to
            ``dedup.find_duplicates``. An explicit value always wins;
            otherwise ``config.dedup_threshold`` is used, and 0.85 when
            there is no config either.
        namespace: Memory namespace for extracted objects and for the
            dedup query against existing memories.

    Returns:
        List of :class:`MemoryObject` instances ready to persist.

    Raises:
        ValueError: If ``embed`` is ``None`` for a non-empty conversation,
            or the adapter returns fewer embeddings than candidates.
    """
    if not conversation:
        return []

    if embed is None:
        raise ValueError(
            "An embedding adapter is required. Pass embed=memory._embed "
            "or initialise a Memory instance first."
        )

    if extractor is None:
        extractor = RegexMemoryExtractor()

    candidates = extractor.extract(conversation, user_id=user_id, session_id=session_id)
    if not candidates:
        return []

    # Resolve threshold: explicit argument > config > built-in default.
    # Letting config win unconditionally would silently discard a value the
    # caller passed on purpose.
    if dedup_threshold is not None:
        threshold = dedup_threshold
    elif config is not None:
        threshold = config.dedup_threshold
    else:
        threshold = _DEFAULT_DEDUP_THRESHOLD

    embeddings = embed.embed([cand.content for cand in candidates])
    if len(embeddings) < len(candidates):
        # Fail with a clear message instead of an IndexError part-way through.
        raise ValueError(
            f"Embedding adapter returned {len(embeddings)} embeddings "
            f"for {len(candidates)} candidates"
        )

    # One timestamp for the whole batch keeps the objects consistent.
    now = datetime.now(timezone.utc)
    extracted_at = now.isoformat()
    memory_objects: list[MemoryObject] = [
        MemoryObject(
            memory_id=str(uuid.uuid4()),
            user_id=user_id,
            content=cand.content,
            embedding=vector,
            embedding_dim=len(vector),
            # Clamp extractor-supplied scores into [0.0, 1.0].
            importance=max(0.0, min(1.0, cand.importance)),
            memory_type=cand.memory_type,
            tags=list(cand.tags),
            metadata={
                **cand.metadata,
                "source": "memory_formation",
                "extracted_at": extracted_at,
            },
            session_id=session_id,
            namespace=namespace,
            source=MemorySource.AGENT_INFERRED,
            created_at=now,
            last_accessed_at=now,
        )
        for cand, vector in zip(candidates, embeddings)
    ]

    # Deduplicate against existing memories
    existing: list[MemoryObject] = []
    if store is not None:
        existing = store.get_all_by_user(
            user_id,
            lifecycle_filter=[
                LifecycleState.ACTIVE,
                LifecycleState.DECAYING,
                LifecycleState.ARCHIVED,
            ],
            namespace=namespace,
        )

    accepted: list[MemoryObject] = []
    for mem in memory_objects:
        # Check against existing memories
        dups = dedup.find_duplicates(mem, existing, threshold)
        if dups:
            logger.debug(
                "Skipping duplicate of existing memory %s: %s",
                dups[0].memory_id,
                mem.content[:50],
            )
            continue

        # Check against already-accepted candidates
        dups = dedup.find_duplicates(mem, accepted, threshold)
        if dups:
            logger.debug(
                "Skipping intra-conversation duplicate: %s", mem.content[:50]
            )
            continue

        accepted.append(mem)

    return accepted
def remember_from_conversation(
    memory: "Memory",
    conversation: list[dict[str, Any]],
    *,
    user_id: str,
    session_id: str | None = None,
    extractor: LLMMemoryExtractor | None = None,
    dedup_threshold: float | None = None,
    namespace: str = "default",
) -> list[str]:
    """Extract memories from a conversation and persist them.

    This is a convenience wrapper around :func:`extract_memories` that
    handles batch embedding, deduplication, and storage via the internal
    ``_remember_with_embedding`` path for efficiency.

    Args:
        memory: A :class:`kemi.core.Memory` instance.
        conversation: List of messages with ``role``, ``content``, and
            optional ``timestamp``.
        user_id: User the conversation belongs to.
        session_id: Optional session identifier.
        extractor: Pluggable extractor (defaults to regex).
        dedup_threshold: Deduplication threshold. When ``None``,
            ``memory._config.dedup_threshold`` is used.
        namespace: Memory namespace for extracted objects and for the
            dedup query against existing memories.

    Returns:
        List of persisted memory IDs, in extraction order.

    Raises:
        TypeError: If ``memory`` lacks the ``_embed`` or ``_store``
            attribute.
    """
    if not hasattr(memory, "_embed") or not hasattr(memory, "_store"):
        raise TypeError(
            f"memory must be a kemi.core.Memory instance, got {type(memory).__name__}"
        )

    threshold = dedup_threshold
    if threshold is None:
        threshold = memory._config.dedup_threshold

    # The threshold is fully resolved above, so the config is deliberately
    # not forwarded: passing both could let the config value take precedence
    # over an explicit ``dedup_threshold`` from the caller.
    extracted = extract_memories(
        conversation,
        user_id=user_id,
        session_id=session_id,
        extractor=extractor,
        embed=memory._embed,
        store=memory._store,
        dedup_threshold=threshold,
        namespace=namespace,
    )

    memory_ids: list[str] = []
    for mem in extracted:
        mid = memory._remember_with_embedding(
            user_id=user_id,
            content=mem.content,
            # Explicit None check: avoids truth-testing the vector itself.
            embedding=mem.embedding if mem.embedding is not None else [],
            importance=mem.importance,
            source=mem.source,
            metadata=mem.metadata,
            tags=mem.tags,
            namespace=namespace,
            session_id=mem.session_id,
            memory_type=mem.memory_type,
            confidence=mem.confidence,
        )
        memory_ids.append(mid)

    return memory_ids