Coverage for session_buddy / memory / entity_extractor.py: 66.21%

125 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-04 00:43 -0800

1""" 

2LLM-Powered Entity Extraction - Memori pattern with multi-provider cascade. 

3 

4Primary: OpenAI → Anthropic → Gemini → Pattern-based fallback. 

5 

6Uses Pydantic models for typed outputs. Providers are optional; cascade 

7skips any unavailable provider gracefully and falls back to patterns. 

8""" 

9 

10import asyncio 

11import logging 

12from dataclasses import dataclass 

13from datetime import datetime 

14from typing import Any 

15 

16from pydantic import BaseModel, Field 

17 

18logger = logging.getLogger(__name__) 

19 

20 

21# Pydantic models for structured extraction (Memori pattern) 

class ExtractedEntity(BaseModel):
    """Single extracted entity with type and confidence."""

    # Free-form label; expected values listed in the Field description.
    entity_type: str = Field(
        description="Type of entity: person, technology, file, concept, organization"
    )
    # The surface form of the entity as it appeared in the conversation.
    entity_value: str = Field(description="The actual entity value")
    # Bounded to [0.0, 1.0]; defaults to full confidence when the LLM omits it.
    confidence: float = Field(
        default=1.0,
        ge=0.0,
        le=1.0,
        description="Confidence score 0.0-1.0",
    )

32 

33 

class EntityRelationship(BaseModel):
    """Relationship between two entities."""

    # Both endpoints are referenced by entity *value*, not by index.
    from_entity: str = Field(description="Source entity value")
    to_entity: str = Field(description="Target entity value")
    # Free-form label; expected values listed in the Field description.
    relationship_type: str = Field(
        description="Type: uses, extends, references, related_to, depends_on"
    )
    # Bounded to [0.0, 1.0]; defaults to full strength when the LLM omits it.
    strength: float = Field(
        default=1.0,
        ge=0.0,
        le=1.0,
        description="Relationship strength",
    )

45 

46 

class ProcessedMemory(BaseModel):
    """
    Complete processed memory structure - Memori pattern.

    This is the output from LLM-powered analysis of conversations.
    """

    # --- Categorization (Memori's 5 categories) ---
    category: str = Field(
        description="Memory category: facts, preferences, skills, rules, context"
    )
    subcategory: str | None = Field(
        default=None,
        description="Optional subcategory for finer granularity",
    )

    # --- Importance scoring (required; no default) ---
    importance_score: float = Field(
        ge=0.0,
        le=1.0,
        description="Importance score 0.0-1.0 based on relevance and utility",
    )

    # --- Content processing ---
    summary: str = Field(
        description="Concise summary of the conversation (1-2 sentences)"
    )
    searchable_content: str = Field(
        description="Optimized content for search/retrieval"
    )
    reasoning: str = Field(description="Why this memory is important and how to use it")

    # --- Entity extraction (empty lists when the LLM finds nothing) ---
    entities: list[ExtractedEntity] = Field(
        default_factory=list,
        description="Extracted entities from conversation",
    )
    relationships: list[EntityRelationship] = Field(
        default_factory=list,
        description="Relationships between entities",
    )

    # --- Metadata ---
    suggested_tier: str = Field(
        default="long_term",
        description="Suggested memory tier: working, short_term, long_term",
    )
    tags: list[str] = Field(
        default_factory=list,
        description="Relevant tags for categorization",
    )

94 

95 

@dataclass
class EntityExtractionResult:
    """Result of entity extraction operation."""

    # Field order is part of the positional-init contract; do not reorder.
    processed_memory: ProcessedMemory  # structured memory produced by the extractor
    entities_count: int  # precomputed len(processed_memory.entities)
    relationships_count: int  # precomputed len(processed_memory.relationships)
    extraction_time_ms: float  # wall-clock duration of the extraction call
    llm_provider: str  # provider that produced the result, or "pattern" for fallback

105 

106 

class LLMEntityExtractor:
    """
    LLM-powered entity extraction using OpenAI Structured Outputs.

    Inspired by Memori's MemoryAgent pattern but adapted for session-mgmt-mcp's
    development workflow context.
    """

    def __init__(
        self,
        llm_provider: str = "openai",
        model: str = "gpt-4o-mini",
        api_key: str | None = None,
    ) -> None:
        """
        Initialize entity extractor with LLM configuration.

        Args:
            llm_provider: LLM provider (openai, anthropic, etc.)
            model: Model name (gpt-4o-mini recommended for cost/performance)
            api_key: Optional API key (uses environment variable if not provided)

        """
        self.llm_provider = llm_provider
        self.model = model
        self.api_key = api_key
        # Created lazily in initialize() so importing this module never
        # requires the openai package to be installed.
        self._client: Any = None

    async def initialize(self) -> None:
        """Initialize LLM client (lazy initialization; safe to call repeatedly).

        Raises:
            ImportError: If the provider SDK is not installed.
            ValueError: If ``llm_provider`` is not supported by this class.
        """
        if self._client is not None:
            return

        try:
            if self.llm_provider == "openai":
                from openai import AsyncOpenAI

                self._client = AsyncOpenAI(api_key=self.api_key)
                # Lazy %-args: avoid f-string interpolation in logging calls.
                logger.info("Initialized OpenAI client with model: %s", self.model)
            else:
                msg = f"Unsupported LLM provider: {self.llm_provider}"
                raise ValueError(msg)
        except ImportError:
            logger.exception(
                "LLM provider '%s' not available. "
                "Install openai package: pip install openai",
                self.llm_provider,
            )
            raise

    def _build_result(
        self, processed_memory: ProcessedMemory, start_time: datetime
    ) -> EntityExtractionResult:
        """Package a ProcessedMemory plus elapsed time into a result record."""
        extraction_time = (datetime.now() - start_time).total_seconds() * 1000
        return EntityExtractionResult(
            processed_memory=processed_memory,
            entities_count=len(processed_memory.entities),
            relationships_count=len(processed_memory.relationships),
            extraction_time_ms=extraction_time,
            llm_provider=self.llm_provider,
        )

    async def extract_entities(
        self,
        user_input: str,
        ai_output: str,
        context: dict[str, Any] | None = None,
    ) -> EntityExtractionResult:
        """
        Extract entities and categorize memory using LLM structured outputs.

        Never raises on extraction failure: any error (API, parsing,
        validation) falls back to a minimal default memory.

        Args:
            user_input: User's input message
            ai_output: AI assistant's response
            context: Optional context (project, session_id, etc.); currently unused

        Returns:
            EntityExtractionResult with processed memory

        """
        await self.initialize()

        start_time = datetime.now()

        # Build prompt requesting JSON compatible with ProcessedMemory
        system = (
            "You are an information extraction assistant. Return ONLY valid JSON "
            "matching this schema keys: {category, subcategory, importance_score, "
            "summary, searchable_content, reasoning, entities, relationships, "
            "suggested_tier, tags}. Entities contain {entity_type, entity_value, confidence}. "
            "Relationships contain {from_entity, to_entity, relationship_type, strength}."
        )
        prompt = (
            f"User: {user_input}\nAssistant: {ai_output}\n"
            "Extract structured memory now."
        )

        try:
            # Prefer OpenAI structured output when available
            if self.llm_provider == "openai":
                client = self._client
                assert client is not None
                response = await client.chat.completions.create(
                    model=self.model,
                    messages=[
                        {"role": "system", "content": system},
                        {"role": "user", "content": prompt},
                    ],
                    # json_object mode guarantees syntactically valid JSON output.
                    response_format={"type": "json_object"},
                )
                content = response.choices[0].message.content or "{}"
                # Raises pydantic.ValidationError on schema mismatch -> fallback path.
                processed_memory = ProcessedMemory.model_validate_json(content)
            else:
                # Unsupported provider in this class; delegate to cascade engine
                msg = "Unsupported provider in LLMEntityExtractor"
                raise RuntimeError(msg)

            return self._build_result(processed_memory, start_time)
        except Exception:
            # Fall back to a minimal default to avoid hard failure
            logger.info("LLM extraction failed; falling back to default output")
            processed_memory = ProcessedMemory(
                category="context",
                importance_score=0.5,
                summary="Conversation recorded",
                searchable_content=f"{user_input} {ai_output}",
                reasoning="LLM extraction fallback",
            )
            # Fallback memory has no entities/relationships, so counts are 0.
            return self._build_result(processed_memory, start_time)

238 

239 

class PatternBasedExtractor:
    """Regex/keyword-based extraction as a no-deps fallback."""

    # Categorization rules checked in order; the first keyword hit wins.
    _CATEGORY_RULES: tuple[tuple[str, tuple[str, ...]], ...] = (
        ("preferences", ("prefer", "like", "avoid")),
        ("skills", ("skill", "learned", "expert")),
        ("rules", ("rule", "policy", "guideline")),
        ("context", ("context", "today", "currently", "now")),
    )

    def _categorize(self, text: str) -> str:
        """Return the first category whose keywords appear in *text* (default: facts)."""
        haystack = text.lower()
        for category, keywords in self._CATEGORY_RULES:
            if any(word in haystack for word in keywords):
                return category
        return "facts"

    async def extract_entities(
        self, user_input: str, ai_output: str
    ) -> ProcessedMemory:
        """Build a minimal ProcessedMemory from keyword matching, no LLM involved."""
        combined = f"{user_input}\n{ai_output}"
        label = self._categorize(combined)
        return ProcessedMemory(
            category=label,
            importance_score=0.5,
            summary="Conversation recorded",
            searchable_content=combined,
            reasoning="Pattern-based extraction",
            tags=[label],
            suggested_tier="long_term",
        )

269 

270 

class EntityExtractionEngine:
    """Multi-provider extraction with cascade fallback.

    Tries OpenAI -> Anthropic -> Gemini in order, each with timeout and
    retries; falls back to PatternBasedExtractor when all providers fail.
    """

    def __init__(self) -> None:
        # Imported lazily so module import does not pull in provider SDKs.
        from session_buddy.llm_providers import LLMManager, LLMMessage
        from session_buddy.settings import get_settings

        self._LLMMessage = LLMMessage
        self.manager = LLMManager()
        self.fallback_extractor = PatternBasedExtractor()
        settings = get_settings()
        self.timeout_s = settings.llm_extraction_timeout
        self.retries = settings.llm_extraction_retries

    async def _generate_with_retry(self, messages: list[Any], provider: str) -> Any:
        """Call *provider* with a per-attempt timeout, retrying up to self.retries times.

        Raises the last failure (including asyncio.TimeoutError) once retries
        are exhausted.
        """
        for attempt in range(max(1, self.retries + 1)):
            try:
                return await asyncio.wait_for(
                    self.manager.generate(
                        messages, provider=provider, temperature=0.2
                    ),
                    timeout=self.timeout_s,
                )
            except Exception:
                if attempt >= self.retries:
                    raise
        # Unreachable: the loop either returns a response or re-raises.
        msg = "retry loop exited without result"
        raise RuntimeError(msg)

    @staticmethod
    def _package(
        pm: ProcessedMemory, start_time: datetime, provider: str
    ) -> EntityExtractionResult:
        """Wrap a ProcessedMemory with entity counts and elapsed time."""
        extraction_time = (datetime.now() - start_time).total_seconds() * 1000
        return EntityExtractionResult(
            processed_memory=pm,
            entities_count=len(pm.entities),
            relationships_count=len(pm.relationships),
            extraction_time_ms=extraction_time,
            llm_provider=provider,
        )

    async def extract_entities(
        self, user_input: str, ai_output: str
    ) -> EntityExtractionResult:
        """
        Extract structured memory via the provider cascade.

        Args:
            user_input: User's input message
            ai_output: AI assistant's response

        Returns:
            EntityExtractionResult; ``llm_provider`` records which provider
            succeeded, or "pattern" when the keyword fallback was used.

        """
        system = (
            "You are an information extraction assistant. Return ONLY valid JSON "
            "for keys: category, subcategory, importance_score, summary, "
            "searchable_content, reasoning, entities, relationships, suggested_tier, tags."
        )
        # Reuse the message class captured in __init__ instead of re-importing.
        LLMMessage = self._LLMMessage
        messages = [
            LLMMessage(role="system", content=system),
            LLMMessage(
                role="user",
                content=(
                    "Extract structured memory from the following.\n"
                    f"User: {user_input}\nAssistant: {ai_output}"
                ),
            ),
        ]

        providers = ["openai", "anthropic", "gemini"]
        start_time = datetime.now()

        for provider in providers:
            try:
                resp = await self._generate_with_retry(messages, provider)
                # Raises pydantic.ValidationError on bad JSON -> next provider.
                pm = ProcessedMemory.model_validate_json(resp.content)
                return self._package(pm, start_time, provider)
            except Exception as e:
                logger.warning("%s extraction failed: %s", provider, e)
                continue

        # Final fallback: pattern-based
        pm = await self.fallback_extractor.extract_entities(user_input, ai_output)
        return self._package(pm, start_time, "pattern")