Coverage for src / kemi / memory_formation.py: 96%

137 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-06-05 15:47 +0000

"""MemoryFormation – turn conversations into structured memories.

Extracts atomic facts and events from chat histories using a pluggable
extractor, deduplicates them against existing memories, and returns
:class:`MemoryObject` instances ready for persistence.

When no extractor is supplied, :func:`extract_memories` falls back to the
dependency-free :class:`RegexMemoryExtractor`; :class:`OpenAIMemoryExtractor`
is an LLM-backed alternative. :func:`remember_from_conversation` additionally
persists the accepted memories through a :class:`kemi.core.Memory` instance.

Example::

    from kemi.core import Memory
    from kemi.memory_formation import remember_from_conversation, OpenAIMemoryExtractor

    mem = Memory()
    conversation = [
        {"role": "user", "content": "I love hiking in the Alps."},
        {"role": "assistant", "content": "That sounds amazing!"},
        {"role": "user", "content": "My favourite trail is the Tour du Mont Blanc."},
    ]
    ids = remember_from_conversation(
        mem, conversation, user_id="alice", extractor=OpenAIMemoryExtractor()
    )
"""

22 

23from __future__ import annotations 

24 

25import json 

26import logging 

27import re 

28import uuid 

29from dataclasses import dataclass, field 

30from datetime import datetime, timezone 

31from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable 

32 

33from kemi import dedup 

34from kemi.adapters.base import EmbeddingAdapter, StorageAdapter 

35from kemi.models import ( 

36 LifecycleState, 

37 MemoryConfig, 

38 MemoryObject, 

39 MemorySource, 

40 MemoryType, 

41) 

42 

43if TYPE_CHECKING: 

44 from kemi.core import Memory 

45 

# Module logger: extraction failures and skipped duplicates are reported here.
logger = logging.getLogger(__name__)

47 

48 

49# --------------------------------------------------------------------------- 

50# Protocol 

51# --------------------------------------------------------------------------- 

52 

@runtime_checkable
class LLMMemoryExtractor(Protocol):
    """Structural interface for objects that turn a chat log into memory candidates.

    No inheritance is needed: any object exposing a compatible ``extract``
    method satisfies this protocol and can be handed to
    :func:`extract_memories` or :func:`remember_from_conversation`. Because the
    protocol is ``runtime_checkable``, an ``isinstance`` check only confirms
    that an ``extract`` attribute exists, not that its signature matches.
    """

    def extract(
        self,
        conversation: list[dict[str, Any]],
        *,
        user_id: str,
        session_id: str | None = None,
    ) -> list[CandidateMemory]:
        """Return the candidate memories found in ``conversation``.

        Args:
            conversation: Ordered chat messages. Each one is expected to carry
                at least ``role`` (``"user"``, ``"assistant"`` or ``"system"``)
                and ``content`` (``str``); a ``timestamp`` key is optional.
            user_id: Owner of the conversation.
            session_id: Optional session the conversation belongs to.

        Returns:
            Candidates that still need deduplication and storage.
        """

82 

83 

84# --------------------------------------------------------------------------- 

85# Data model 

86# --------------------------------------------------------------------------- 

87 

@dataclass
class CandidateMemory:
    """A memory candidate produced by an extractor, before embedding and storage.

    :func:`extract_memories` copies ``tags`` and ``metadata`` when it turns a
    candidate into a :class:`MemoryObject`, so a candidate may be reused safely.
    """

    # Text of the memory; this is the string that gets embedded.
    content: str
    # Extractor-assigned importance; extract_memories clamps it to [0.0, 1.0].
    importance: float = 0.5
    # Memory category, e.g. MemoryType.EPISODIC or MemoryType.SEMANTIC.
    memory_type: MemoryType = MemoryType.EPISODIC
    # Free-form labels such as "preference" or "goal".
    tags: list[str] = field(default_factory=list)
    # Extractor-specific extras (e.g. session id, source message timestamp).
    metadata: dict[str, Any] = field(default_factory=dict)

97 

98 

99# --------------------------------------------------------------------------- 

100# Built-in extractors 

101# --------------------------------------------------------------------------- 

102 

class RegexMemoryExtractor:
    """Simple regex/heuristic extractor for tests and local use.

    Requires no external LLM. Matches common first-person patterns such as
    preferences, goals, and personal facts ("I like ...", "My name is ...").

    By default only ``user`` messages are scanned. First-person statements
    made by the assistant ("I am happy to help.") describe the assistant, not
    the user, and must not become memories about the user.
    """

    # Captured fact: everything after the trigger up to the first sentence
    # terminator (. ! ?), newline, or end of text. Excluding the terminators
    # from the capture also handles statements with no final punctuation.
    _FACT: str = r"\s+([^.!?\n]+)"

    # (regex, tags, memory_type, importance). The leading \b stops a trigger
    # from matching inside another word (e.g. "Hi like" -> "i like").
    _PATTERNS: list[tuple[str, list[str], MemoryType, float]] = [
        (r"\bI like" + _FACT, ["preference"], MemoryType.SEMANTIC, 0.6),
        (r"\bMy name is" + _FACT, ["identity"], MemoryType.SEMANTIC, 0.8),
        (r"\bI am" + _FACT, ["identity"], MemoryType.SEMANTIC, 0.7),
        (r"\bI want to" + _FACT, ["goal"], MemoryType.EPISODIC, 0.7),
        (r"\bI need to" + _FACT, ["goal"], MemoryType.EPISODIC, 0.7),
        (r"\bI prefer" + _FACT, ["preference"], MemoryType.SEMANTIC, 0.6),
        (r"\bI live in" + _FACT, ["location"], MemoryType.SEMANTIC, 0.7),
        (r"\bI work at" + _FACT, ["work"], MemoryType.SEMANTIC, 0.7),
        (r"\bRemember that" + _FACT, ["reminder"], MemoryType.EPISODIC, 0.8),
        (r"\bDon['’]t forget" + _FACT, ["reminder"], MemoryType.EPISODIC, 0.8),
    ]

    def __init__(self, roles: tuple[str, ...] | None = ("user",)) -> None:
        """Compile the patterns.

        Args:
            roles: Message roles to scan. Messages without a ``role`` key are
                treated as ``"user"``. Pass ``None`` to scan every message
                regardless of its role.
        """
        self._roles: frozenset[str] | None = None if roles is None else frozenset(roles)
        self._compiled = [
            (re.compile(pattern, re.IGNORECASE), tags, mtype, imp)
            for pattern, tags, mtype, imp in self._PATTERNS
        ]

    def extract(
        self,
        conversation: list[dict[str, Any]],
        *,
        user_id: str,
        session_id: str | None = None,
    ) -> list[CandidateMemory]:
        """Scan ``conversation`` for first-person statements.

        Args:
            conversation: Chat messages with ``role`` and ``content``. A
                ``timestamp`` key, when present, is copied into each
                candidate's metadata as ``extracted_from_timestamp``.
            user_id: Accepted for protocol compatibility; not used here.
            session_id: When given, stored in each candidate's metadata.

        Returns:
            One candidate per distinct (case-insensitive) captured phrase, in
            the order found.
        """
        candidates: list[CandidateMemory] = []
        seen: set[str] = set()

        for msg in conversation:
            if self._roles is not None and msg.get("role", "user") not in self._roles:
                continue
            content = msg.get("content", "")
            # Structured content (e.g. a list of content parts) cannot be scanned.
            if not isinstance(content, str) or not content:
                continue

            for pattern, tags, mtype, imp in self._compiled:
                for match in pattern.finditer(content):
                    fact = match.group(1).strip()
                    key = fact.lower()
                    if not fact or key in seen:
                        continue
                    seen.add(key)

                    meta: dict[str, Any] = {}
                    if "timestamp" in msg:
                        meta["extracted_from_timestamp"] = msg["timestamp"]
                    if session_id:
                        meta["session_id"] = session_id

                    candidates.append(
                        CandidateMemory(
                            content=fact,
                            importance=imp,
                            memory_type=mtype,
                            tags=list(tags),
                            metadata=meta,
                        )
                    )

        return candidates

169 

170 

class OpenAIMemoryExtractor:
    """OpenAI-powered memory extractor.

    Uses the Chat Completions API with a structured system prompt to turn a
    conversation into atomic memory candidates. The model's reply is treated
    as untrusted input: code fences, non-array payloads and badly typed fields
    are tolerated instead of aborting extraction with an exception.
    """

    def __init__(
        self,
        model: str = "gpt-4o-mini",
        api_key: str | None = None,
        base_url: str | None = None,
    ) -> None:
        """Create the OpenAI client.

        Args:
            model: Chat model name passed to ``chat.completions.create``.
            api_key: API key forwarded to ``openai.OpenAI``.
            base_url: Optional API base URL forwarded to ``openai.OpenAI``.

        Raises:
            ImportError: If the ``openai`` package is not installed.
        """
        try:
            import openai
        except ImportError as exc:
            raise ImportError(
                "OpenAIMemoryExtractor requires the 'openai' package. "
                "Install it with: pip install openai"
            ) from exc

        self._client = openai.OpenAI(api_key=api_key, base_url=base_url)
        self._model = model

    def extract(
        self,
        conversation: list[dict[str, Any]],
        *,
        user_id: str,
        session_id: str | None = None,
    ) -> list[CandidateMemory]:
        """Ask the model for memories and convert its JSON reply into candidates.

        Any API, transport or JSON-decoding error is logged and yields an
        empty list. So does a reply that is neither a JSON array nor an object
        wrapping one. Items that are not objects, or that have no ``content``,
        are skipped. Malformed ``importance``, ``tags`` or ``metadata`` fields
        fall back to defaults.

        Args:
            conversation: Chat messages with ``role`` and ``content``.
            user_id: Accepted for protocol compatibility; not sent to the model.
            session_id: When given, stored in each candidate's metadata.

        Returns:
            The parsed candidate memories (possibly empty).
        """
        system_prompt = (
            "You are a memory extraction assistant. "
            "Given a conversation, extract atomic facts and events that would be "
            "useful to remember about the user. "
            "For each memory, provide: content (short sentence), importance (0.0-1.0), "
            "type ('episodic' or 'semantic'), tags (list of strings), and metadata (dict). "
            "Return ONLY a JSON array of objects. Do not include markdown formatting."
        )

        messages: list[dict[str, str]] = [{"role": "system", "content": system_prompt}]
        for msg in conversation:
            messages.append(
                {"role": msg.get("role", "user"), "content": msg.get("content", "")}
            )

        try:
            response = self._client.chat.completions.create(
                model=self._model,
                messages=messages,  # type: ignore[arg-type]
                temperature=0.3,
                max_tokens=1024,
            )
            raw = response.choices[0].message.content or "[]"
            parsed: Any = json.loads(self._strip_code_fence(raw))
        except Exception:
            logger.exception("OpenAI memory extraction failed")
            return []

        items = self._as_item_list(parsed)
        if items is None:
            logger.warning(
                "OpenAI memory extraction returned non-array JSON (%s)",
                type(parsed).__name__,
            )
            return []

        candidates: list[CandidateMemory] = []
        for item in items:
            if not isinstance(item, dict):
                continue
            raw_content = item.get("content")
            # Guard against a JSON null becoming the literal string "None".
            content = "" if raw_content is None else str(raw_content).strip()
            if not content:
                continue

            # Anything other than "episodic" (case-insensitive) is semantic.
            mtype_str = str(item.get("type") or "episodic").strip().lower()
            mtype = MemoryType.EPISODIC if mtype_str == "episodic" else MemoryType.SEMANTIC

            meta = self._coerce_metadata(item.get("metadata"))
            if session_id:
                meta["session_id"] = session_id

            candidates.append(
                CandidateMemory(
                    content=content,
                    importance=self._coerce_importance(item.get("importance", 0.5)),
                    memory_type=mtype,
                    tags=self._coerce_tags(item.get("tags")),
                    metadata=meta,
                )
            )

        return candidates

    @staticmethod
    def _strip_code_fence(raw: str) -> str:
        """Return *raw* trimmed, without a surrounding Markdown code fence if present."""
        text = raw.strip()
        if not text.startswith("```"):
            return text
        body = text.split("```")[1]
        first_line, newline, rest = body.partition("\n")
        tag = first_line.strip()
        # A single word on the opening fence line ("json", "JSON", ...) is a
        # language tag, not payload. Drop it by prefix, not by character set.
        if newline and (not tag or tag.isalnum()):
            body = rest
        elif body[:4].lower() == "json":
            body = body[4:]
        return body.strip()

    @staticmethod
    def _as_item_list(parsed: Any) -> list[Any] | None:
        """Return the memory items in *parsed*, or ``None`` if it holds none.

        Accepts a JSON array, an object wrapping the array under ``"memories"``,
        or a single memory object carrying ``"content"``.
        """
        if isinstance(parsed, list):
            return parsed
        if isinstance(parsed, dict):
            wrapped = parsed.get("memories")
            if isinstance(wrapped, list):
                return wrapped
            if "content" in parsed:
                return [parsed]
        return None

    @staticmethod
    def _coerce_importance(value: Any, default: float = 0.5) -> float:
        """Convert *value* to float, falling back to *default* if it is not numeric."""
        try:
            return float(value)
        except (TypeError, ValueError):
            return default

    @staticmethod
    def _coerce_tags(value: Any) -> list[str]:
        """Normalise a model-supplied tags field into a list of non-empty strings."""
        if isinstance(value, str):
            # A bare string is one tag; list("hiking") would split it into letters.
            return [value.strip()] if value.strip() else []
        if isinstance(value, (list, tuple)):
            tags: list[str] = []
            for tag in value:
                if tag is None:
                    continue
                text = str(tag).strip()
                if text:
                    tags.append(text)
            return tags
        return []

    @staticmethod
    def _coerce_metadata(value: Any) -> dict[str, Any]:
        """Return a fresh dict copy of *value*, or ``{}`` if it is not a dict."""
        return dict(value) if isinstance(value, dict) else {}

259 

260 

class StaticMemoryExtractor:
    """Extractor that ignores its input and always hands back the same candidates.

    Handy for deterministic tests or as a do-nothing placeholder. Each call
    returns a new list, so callers may append to or reorder the result without
    affecting later calls. The candidate objects themselves are shared.
    """

    def __init__(self, candidates: list[CandidateMemory]) -> None:
        # Snapshot the sequence so later edits to the caller's list don't leak in.
        self._candidates = [*candidates]

    def extract(
        self,
        conversation: list[dict[str, Any]],
        *,
        user_id: str,
        session_id: str | None = None,
    ) -> list[CandidateMemory]:
        # The arguments exist only to satisfy the LLMMemoryExtractor protocol.
        return self._candidates.copy()

278 

279 

280# --------------------------------------------------------------------------- 

281# Public API 

282# --------------------------------------------------------------------------- 

283 

_DEFAULT_DEDUP_THRESHOLD = 0.85


def _resolve_dedup_threshold(
    explicit: float | None, config: MemoryConfig | None
) -> float:
    """Pick the dedup threshold: explicit argument, then config, then 0.85."""
    if explicit is not None:
        return explicit
    if config is not None:
        return config.dedup_threshold
    return _DEFAULT_DEDUP_THRESHOLD


def _build_memory_objects(
    candidates: list[CandidateMemory],
    embeddings: Any,
    *,
    user_id: str,
    session_id: str | None,
    namespace: str,
) -> list[MemoryObject]:
    """Pair each candidate with its embedding and wrap it in a MemoryObject."""
    # One timestamp for the whole batch, so every object shares created_at.
    now = datetime.now(timezone.utc)
    extracted_at = now.isoformat()
    return [
        MemoryObject(
            memory_id=str(uuid.uuid4()),
            user_id=user_id,
            content=cand.content,
            embedding=vector,
            embedding_dim=len(vector),
            importance=max(0.0, min(1.0, cand.importance)),
            memory_type=cand.memory_type,
            tags=list(cand.tags),
            metadata={
                **cand.metadata,
                "source": "memory_formation",
                "extracted_at": extracted_at,
            },
            session_id=session_id,
            namespace=namespace,
            source=MemorySource.AGENT_INFERRED,
            created_at=now,
            last_accessed_at=now,
        )
        for cand, vector in zip(candidates, embeddings)
    ]


def _deduplicate(
    memory_objects: list[MemoryObject],
    existing: list[MemoryObject],
    threshold: float,
) -> list[MemoryObject]:
    """Drop objects that duplicate a stored memory or an earlier accepted one."""
    accepted: list[MemoryObject] = []
    for mem in memory_objects:
        dups = dedup.find_duplicates(mem, existing, threshold)
        if dups:
            logger.debug(
                "Skipping duplicate of existing memory %s: %s",
                dups[0].memory_id,
                mem.content[:50],
            )
            continue

        if dedup.find_duplicates(mem, accepted, threshold):
            logger.debug(
                "Skipping intra-conversation duplicate: %s", mem.content[:50]
            )
            continue

        accepted.append(mem)
    return accepted


def extract_memories(
    conversation: list[dict[str, Any]],
    *,
    user_id: str,
    session_id: str | None = None,
    extractor: LLMMemoryExtractor | None = None,
    embed: EmbeddingAdapter | None = None,
    store: StorageAdapter | None = None,
    config: MemoryConfig | None = None,
    dedup_threshold: float | None = None,
    namespace: str = "default",
) -> list[MemoryObject]:
    """Extract and deduplicate memories from a conversation.

    Steps:

    1. Runs the extractor over the conversation to get candidate memories,
       discarding any whose content is empty or whitespace-only.
    2. Batch-embeds all candidate contents in a single ``embed.embed`` call.
    3. Creates :class:`MemoryObject` instances.
    4. Deduplicates against *existing* memories in ``store`` (active,
       decaying and archived, within ``namespace``) and against each other
       using the resolved threshold.
    5. Returns the accepted memories in extraction order.

    Args:
        conversation: List of messages. Each item should contain at
            least ``role`` and ``content``. An optional ``timestamp`` may
            be included.
        user_id: User the conversation belongs to.
        session_id: Optional session identifier.
        extractor: Pluggable extractor. Defaults to
            :class:`RegexMemoryExtractor` if not provided.
        embed: Embedding adapter (e.g. ``memory._embed``). Required.
        store: Storage adapter (e.g. ``memory._store``). Used to fetch
            existing memories for deduplication.
        config: Optional :class:`MemoryConfig`; its ``dedup_threshold`` is
            used when ``dedup_threshold`` is not given.
        dedup_threshold: Cosine-similarity threshold above which a
            candidate is considered a duplicate. An explicit value always
            wins; if omitted, ``config.dedup_threshold`` is used, falling
            back to 0.85.
        namespace: Memory namespace for extracted objects and for the
            dedup query against existing memories.

    Returns:
        List of :class:`MemoryObject` instances ready to persist.

    Raises:
        ValueError: If ``embed`` is missing, or if it returns a different
            number of vectors than there are candidates.
    """
    if not conversation:
        return []

    if embed is None:
        raise ValueError(
            "An embedding adapter is required. Pass embed=memory._embed "
            "or initialise a Memory instance first."
        )

    if extractor is None:
        extractor = RegexMemoryExtractor()

    raw_candidates = extractor.extract(
        conversation, user_id=user_id, session_id=session_id
    ) or []
    # Blank content carries no information and would only waste an embedding.
    candidates = [c for c in raw_candidates if c.content and c.content.strip()]
    if not candidates:
        return []

    threshold = _resolve_dedup_threshold(dedup_threshold, config)

    embeddings = embed.embed([cand.content for cand in candidates])
    if len(embeddings) != len(candidates):
        raise ValueError(
            f"Embedding adapter returned {len(embeddings)} vectors for "
            f"{len(candidates)} candidates"
        )

    memory_objects = _build_memory_objects(
        candidates,
        embeddings,
        user_id=user_id,
        session_id=session_id,
        namespace=namespace,
    )

    existing: list[MemoryObject] = []
    if store is not None:
        existing = store.get_all_by_user(
            user_id,
            lifecycle_filter=[
                LifecycleState.ACTIVE,
                LifecycleState.DECAYING,
                LifecycleState.ARCHIVED,
            ],
            namespace=namespace,
        )

    return _deduplicate(memory_objects, existing, threshold)

414 

415 

def remember_from_conversation(
    memory: "Memory",
    conversation: list[dict[str, Any]],
    *,
    user_id: str,
    session_id: str | None = None,
    extractor: LLMMemoryExtractor | None = None,
    dedup_threshold: float | None = None,
    namespace: str = "default",
) -> list[str]:
    """Extract memories from a conversation and persist them.

    This is a convenience wrapper around :func:`extract_memories` that
    handles batch embedding, deduplication, and storage via the internal
    ``_remember_with_embedding`` path for efficiency.

    Args:
        memory: A :class:`kemi.core.Memory` instance.
        conversation: List of messages with ``role``, ``content``, and
            optional ``timestamp``.
        user_id: User the conversation belongs to.
        session_id: Optional session identifier.
        extractor: Pluggable extractor (defaults to regex).
        dedup_threshold: Deduplication threshold. Defaults to
            ``memory._config.dedup_threshold``; an explicit value always wins.
        namespace: Memory namespace for extracted objects and for the
            dedup query against existing memories.

    Returns:
        List of persisted memory IDs, in extraction order.

    Raises:
        TypeError: If ``memory`` lacks the Memory internals used here.
    """
    required = ("_embed", "_store", "_config", "_remember_with_embedding")
    if not all(hasattr(memory, attr) for attr in required):
        raise TypeError(
            f"memory must be a kemi.core.Memory instance, got {type(memory).__name__}"
        )

    threshold = (
        memory._config.dedup_threshold if dedup_threshold is None else dedup_threshold
    )

    # Only the resolved threshold is forwarded (not memory._config), so the
    # caller's explicit dedup_threshold can never be replaced by the config's.
    extracted = extract_memories(
        conversation,
        user_id=user_id,
        session_id=session_id,
        extractor=extractor,
        embed=memory._embed,
        store=memory._store,
        dedup_threshold=threshold,
        namespace=namespace,
    )

    memory_ids: list[str] = []
    for mem in extracted:
        mid = memory._remember_with_embedding(
            user_id=user_id,
            content=mem.content,
            # `is not None` instead of `or`: array-like embeddings have no
            # unambiguous truth value.
            embedding=mem.embedding if mem.embedding is not None else [],
            importance=mem.importance,
            source=mem.source,
            metadata=mem.metadata,
            tags=mem.tags,
            namespace=namespace,
            session_id=mem.session_id,
            memory_type=mem.memory_type,
            confidence=mem.confidence,
        )
        memory_ids.append(mid)

    return memory_ids