Coverage for session_buddy/memory/entity_extractor.py: 66.21% (125 statements)
Report generated by coverage.py v7.13.1, created at 2026-01-04 00:43 -0800
"""
LLM-Powered Entity Extraction - Memori pattern with multi-provider cascade.

Primary: OpenAI → Anthropic → Gemini → Pattern-based fallback.

Uses Pydantic models for typed outputs. Providers are optional; cascade
skips any unavailable provider gracefully and falls back to patterns.
"""
10import asyncio
11import logging
12from dataclasses import dataclass
13from datetime import datetime
14from typing import Any
16from pydantic import BaseModel, Field
18logger = logging.getLogger(__name__)
21# Pydantic models for structured extraction (Memori pattern)
class ExtractedEntity(BaseModel):
    """Single extracted entity with type and confidence."""

    # Kind of entity found; expected values are listed in the description.
    entity_type: str = Field(
        description="Type of entity: person, technology, file, concept, organization"
    )
    # The literal text of the extracted entity.
    entity_value: str = Field(description="The actual entity value")
    # Extraction confidence; validated by Pydantic to lie within [0.0, 1.0].
    confidence: float = Field(
        default=1.0, ge=0.0, le=1.0, description="Confidence score 0.0-1.0"
    )
class EntityRelationship(BaseModel):
    """Relationship between two entities."""

    # Entity values (not ids) identify the endpoints of the relationship.
    from_entity: str = Field(description="Source entity value")
    to_entity: str = Field(description="Target entity value")
    # Relationship kind; expected values are listed in the description.
    relationship_type: str = Field(
        description="Type: uses, extends, references, related_to, depends_on"
    )
    # Strength of the link; validated by Pydantic to lie within [0.0, 1.0].
    strength: float = Field(
        default=1.0, ge=0.0, le=1.0, description="Relationship strength"
    )
class ProcessedMemory(BaseModel):
    """
    Complete processed memory structure - Memori pattern.

    This is the output from LLM-powered analysis of conversations.
    """

    # Categorization (Memori's 5 categories)
    category: str = Field(
        description="Memory category: facts, preferences, skills, rules, context"
    )
    subcategory: str | None = Field(
        default=None, description="Optional subcategory for finer granularity"
    )

    # Importance scoring
    # NOTE: no default — every producer (LLM or fallback) must supply a score.
    importance_score: float = Field(
        ge=0.0,
        le=1.0,
        description="Importance score 0.0-1.0 based on relevance and utility",
    )

    # Content processing
    summary: str = Field(
        description="Concise summary of the conversation (1-2 sentences)"
    )
    searchable_content: str = Field(
        description="Optimized content for search/retrieval"
    )
    reasoning: str = Field(description="Why this memory is important and how to use it")

    # Entity extraction
    entities: list[ExtractedEntity] = Field(
        default_factory=list, description="Extracted entities from conversation"
    )
    relationships: list[EntityRelationship] = Field(
        default_factory=list, description="Relationships between entities"
    )

    # Metadata
    suggested_tier: str = Field(
        default="long_term",
        description="Suggested memory tier: working, short_term, long_term",
    )
    tags: list[str] = Field(
        default_factory=list, description="Relevant tags for categorization"
    )
@dataclass
class EntityExtractionResult:
    """Result of entity extraction operation."""

    processed_memory: ProcessedMemory  # the fully processed memory payload
    entities_count: int  # len(processed_memory.entities), cached for convenience
    relationships_count: int  # len(processed_memory.relationships)
    extraction_time_ms: float  # wall-clock duration of the extraction call
    llm_provider: str  # provider that produced the result ("pattern" = fallback)
class LLMEntityExtractor:
    """
    LLM-powered entity extraction using OpenAI Structured Outputs.

    Inspired by Memori's MemoryAgent pattern but adapted for session-mgmt-mcp's
    development workflow context. Only the "openai" provider is handled here;
    the multi-provider cascade lives in EntityExtractionEngine.
    """

    def __init__(
        self,
        llm_provider: str = "openai",
        model: str = "gpt-4o-mini",
        api_key: str | None = None,
    ) -> None:
        """
        Initialize entity extractor with LLM configuration.

        Args:
            llm_provider: LLM provider (openai, anthropic, etc.)
            model: Model name (gpt-4o-mini recommended for cost/performance)
            api_key: Optional API key (uses environment variable if not provided)
        """
        self.llm_provider = llm_provider
        self.model = model
        self.api_key = api_key
        self._client: Any = None  # created lazily by initialize()

    async def initialize(self) -> None:
        """
        Initialize LLM client (lazy, idempotent).

        Raises:
            ValueError: If the configured provider is not supported.
            ImportError: If the provider SDK is not installed.
        """
        if self._client is not None:
            return

        if self.llm_provider != "openai":
            msg = f"Unsupported LLM provider: {self.llm_provider}"
            raise ValueError(msg)

        try:
            from openai import AsyncOpenAI
        except ImportError:
            # Lazy %-args: no string formatting unless the record is emitted.
            logger.exception(
                "LLM provider '%s' not available. "
                "Install openai package: pip install openai",
                self.llm_provider,
            )
            raise

        self._client = AsyncOpenAI(api_key=self.api_key)
        logger.info("Initialized OpenAI client with model: %s", self.model)

    async def extract_entities(
        self,
        user_input: str,
        ai_output: str,
        context: dict[str, Any] | None = None,
    ) -> EntityExtractionResult:
        """
        Extract entities and categorize memory using LLM structured outputs.

        Args:
            user_input: User's input message
            ai_output: AI assistant's response
            context: Optional context (project, session_id, etc.); currently unused.

        Returns:
            EntityExtractionResult with processed memory. On any LLM failure a
            minimal default memory is returned instead of raising.
        """
        await self.initialize()

        start_time = datetime.now()

        try:
            if self.llm_provider != "openai":
                # Unsupported provider in this class; delegate to cascade engine.
                msg = "Unsupported provider in LLMEntityExtractor"
                raise RuntimeError(msg)
            processed_memory = await self._extract_via_openai(user_input, ai_output)
        except Exception:
            # Fall back to a minimal default to avoid hard failure, but keep the
            # traceback in the log so the root cause is not silently lost.
            logger.warning(
                "LLM extraction failed; falling back to default output",
                exc_info=True,
            )
            processed_memory = ProcessedMemory(
                category="context",
                importance_score=0.5,
                summary="Conversation recorded",
                searchable_content=f"{user_input} {ai_output}",
                reasoning="LLM extraction fallback",
            )

        extraction_time = (datetime.now() - start_time).total_seconds() * 1000
        return EntityExtractionResult(
            processed_memory=processed_memory,
            entities_count=len(processed_memory.entities),
            relationships_count=len(processed_memory.relationships),
            extraction_time_ms=extraction_time,
            llm_provider=self.llm_provider,
        )

    async def _extract_via_openai(
        self, user_input: str, ai_output: str
    ) -> ProcessedMemory:
        """Run one JSON-mode chat completion and validate it into ProcessedMemory."""
        client = self._client
        if client is None:
            # Defensive runtime check (asserts are stripped under -O);
            # initialize() normally guarantees a client before we get here.
            msg = "LLM client is not initialized"
            raise RuntimeError(msg)

        # Prompt requests JSON compatible with the ProcessedMemory schema.
        system = (
            "You are an information extraction assistant. Return ONLY valid JSON "
            "matching this schema keys: {category, subcategory, importance_score, "
            "summary, searchable_content, reasoning, entities, relationships, "
            "suggested_tier, tags}. Entities contain {entity_type, entity_value, confidence}. "
            "Relationships contain {from_entity, to_entity, relationship_type, strength}."
        )
        prompt = (
            f"User: {user_input}\nAssistant: {ai_output}\n"
            "Extract structured memory now."
        )

        response = await client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": system},
                {"role": "user", "content": prompt},
            ],
            response_format={"type": "json_object"},
        )
        content = response.choices[0].message.content or "{}"
        return ProcessedMemory.model_validate_json(content)
class PatternBasedExtractor:
    """Regex/keyword-based extraction as a no-deps fallback."""

    # Ordered (category, keywords) pairs; the first category whose keyword
    # appears in the text wins, and "facts" is the default when none match.
    _CATEGORY_KEYWORDS: tuple[tuple[str, tuple[str, ...]], ...] = (
        ("preferences", ("prefer", "like", "avoid")),
        ("skills", ("skill", "learned", "expert")),
        ("rules", ("rule", "policy", "guideline")),
        ("context", ("context", "today", "currently", "now")),
    )

    def _categorize(self, text: str) -> str:
        """Return the first matching category for *text* ("facts" if none)."""
        lowered = text.lower()
        for category, keywords in self._CATEGORY_KEYWORDS:
            if any(keyword in lowered for keyword in keywords):
                return category
        return "facts"

    async def extract_entities(
        self, user_input: str, ai_output: str
    ) -> ProcessedMemory:
        """Build a minimal ProcessedMemory from keywords — no LLM involved."""
        combined = f"{user_input}\n{ai_output}"
        category = self._categorize(combined)
        return ProcessedMemory(
            category=category,
            importance_score=0.5,
            summary="Conversation recorded",
            searchable_content=combined,
            reasoning="Pattern-based extraction",
            tags=[category],
            suggested_tier="long_term",
        )
class EntityExtractionEngine:
    """Multi-provider extraction with cascade fallback.

    Tries OpenAI → Anthropic → Gemini in order (with per-provider timries/timeout
    from settings), then falls back to pattern-based extraction so callers
    always receive a result.
    """

    def __init__(self) -> None:
        """Wire up the LLM manager, fallback extractor, and retry settings."""
        # Project imports are kept local so importing this module stays cheap.
        from session_buddy.llm_providers import LLMManager, LLMMessage
        from session_buddy.settings import get_settings

        self._LLMMessage = LLMMessage  # reused by extract_entities
        self.manager = LLMManager()
        self.fallback_extractor = PatternBasedExtractor()
        settings = get_settings()
        self.timeout_s = settings.llm_extraction_timeout
        self.retries = settings.llm_extraction_retries

    async def extract_entities(
        self, user_input: str, ai_output: str
    ) -> EntityExtractionResult:
        """
        Extract structured memory via the provider cascade.

        Args:
            user_input: User's input message
            ai_output: AI assistant's response

        Returns:
            EntityExtractionResult from the first provider that succeeds,
            or from the pattern-based fallback (llm_provider="pattern").
        """
        system = (
            "You are an information extraction assistant. Return ONLY valid JSON "
            "for keys: category, subcategory, importance_score, summary, "
            "searchable_content, reasoning, entities, relationships, suggested_tier, tags."
        )
        # Use the message class captured in __init__ instead of re-importing it.
        LLMMessage = self._LLMMessage

        messages = [
            LLMMessage(role="system", content=system),
            LLMMessage(
                role="user",
                content=(
                    "Extract structured memory from the following.\n"
                    f"User: {user_input}\nAssistant: {ai_output}"
                ),
            ),
        ]

        providers = ["openai", "anthropic", "gemini"]
        start_time = datetime.now()

        for provider in providers:
            try:
                resp = await self._generate_with_retries(messages, provider)
                pm = ProcessedMemory.model_validate_json(resp.content)
            except Exception as e:
                # Any failure (timeout, missing SDK, bad JSON) → next provider.
                logger.warning("%s extraction failed: %s", provider, e)
                continue
            return self._build_result(pm, provider, start_time)

        # Final fallback: pattern-based (never raises, no external deps).
        pm = await self.fallback_extractor.extract_entities(user_input, ai_output)
        return self._build_result(pm, "pattern", start_time)

    async def _generate_with_retries(self, messages: list[Any], provider: str) -> Any:
        """Call one provider with a timeout, retrying up to self.retries times.

        Re-raises the final attempt's exception when all attempts fail.
        """
        for attempt in range(max(1, self.retries + 1)):
            try:
                return await asyncio.wait_for(
                    self.manager.generate(
                        messages, provider=provider, temperature=0.2
                    ),
                    timeout=self.timeout_s,
                )
            except Exception:
                if attempt >= self.retries:
                    raise
        # Unreachable: the last failed attempt re-raises above. Explicit raise
        # (instead of an assert, which -O strips) keeps type checkers satisfied.
        msg = "retry loop exited without a response"
        raise RuntimeError(msg)

    def _build_result(
        self, pm: ProcessedMemory, provider: str, start_time: datetime
    ) -> EntityExtractionResult:
        """Package a ProcessedMemory with counts and elapsed milliseconds."""
        extraction_time = (datetime.now() - start_time).total_seconds() * 1000
        return EntityExtractionResult(
            processed_memory=pm,
            entities_count=len(pm.entities),
            relationships_count=len(pm.relationships),
            extraction_time_ms=extraction_time,
            llm_provider=provider,
        )