Coverage for session_mgmt_mcp/search_enhanced.py: 10.20%

204 statements  

coverage.py v7.10.6, created at 2025-09-01 05:22 -0700

#!/usr/bin/env python3
"""Enhanced Search Capabilities for Session Management MCP Server.

Provides multi-modal search including code snippets, error patterns, and time-based queries.
"""

import ast
import re
from datetime import datetime, timedelta
from typing import Any

try:
    from dateutil.parser import parse as parse_date
    from dateutil.relativedelta import relativedelta

    DATEUTIL_AVAILABLE = True
except ImportError:
    DATEUTIL_AVAILABLE = False

from .reflection_tools import ReflectionDatabase


class CodeSearcher:
    """AST-based code search for Python code snippets."""

    def __init__(self) -> None:
        self.search_types = {
            "function": ast.FunctionDef,
            "class": ast.ClassDef,
            "import": (ast.Import, ast.ImportFrom),
            "assignment": ast.Assign,
            "call": ast.Call,
            "loop": (ast.For, ast.While),
            "conditional": ast.If,
            "try": ast.Try,
            "async": (ast.AsyncFunctionDef, ast.AsyncWith, ast.AsyncFor),
        }

    def extract_code_patterns(self, content: str) -> list[dict[str, Any]]:
        """Extract code patterns from conversation content."""
        patterns = []

        # Extract Python code blocks
        code_blocks = re.findall(r"```python\n(.*?)\n```", content, re.DOTALL)
        code_blocks.extend(re.findall(r"```\n(.*?)\n```", content, re.DOTALL))

        for i, code in enumerate(code_blocks):
            try:
                tree = ast.parse(code)
                for node in ast.walk(tree):
                    for pattern_type, node_types in self.search_types.items():
                        if isinstance(node, node_types):
                            pattern_info = {
                                "type": pattern_type,
                                "content": code,
                                "block_index": i,
                                "line_number": getattr(node, "lineno", 0),
                            }

                            # Extract specific information based on node type
                            if isinstance(node, ast.FunctionDef):
                                pattern_info["name"] = node.name
                                pattern_info["args"] = [
                                    arg.arg for arg in node.args.args
                                ]
                            elif isinstance(node, ast.ClassDef):
                                pattern_info["name"] = node.name
                            elif isinstance(node, ast.Import | ast.ImportFrom):
                                if isinstance(node, ast.Import):
                                    pattern_info["modules"] = [
                                        alias.name for alias in node.names
                                    ]
                                else:
                                    pattern_info["module"] = node.module
                                    pattern_info["names"] = [
                                        alias.name for alias in node.names
                                    ]

                            patterns.append(pattern_info)

            except (SyntaxError, ValueError):
                # Not valid Python code, skip
                continue

        return patterns
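
# A minimal usage sketch for CodeSearcher, assuming the conversation text
# carries fenced ```python blocks as extract_code_patterns() expects. The
# sample snippet and the printed fields below are illustrative only.
def _demo_code_search() -> None:
    sample = (
        "We refactored the helper:\n"
        "```python\n"
        "def add(a, b):\n"
        "    return a + b\n"
        "```\n"
    )
    searcher = CodeSearcher()
    for pattern in searcher.extract_code_patterns(sample):
        # Expect one "function" pattern named "add" with args ["a", "b"].
        print(pattern["type"], pattern.get("name"), pattern.get("args"))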


class ErrorPatternMatcher:
    """Pattern matching for error messages and debugging contexts."""

    def __init__(self) -> None:
        self.error_patterns = {
            "python_traceback": r"Traceback \(most recent call last\):.*?(?=\n\n|\Z)",
            "python_exception": r"(\w+Error): (.+)",
            "javascript_error": r"(Error|TypeError|ReferenceError): (.+)",
            "compile_error": r"(error|Error): (.+) at line (\d+)",
            "warning": r"(warning|Warning): (.+)",
            "assertion": r"AssertionError: (.+)",
            "import_error": r"ImportError: (.+)",
            "module_not_found": r"ModuleNotFoundError: (.+)",
            "file_not_found": r"FileNotFoundError: (.+)",
            "permission_denied": r"PermissionError: (.+)",
            "network_error": r"(ConnectionError|TimeoutError|HTTPError): (.+)",
        }

        self.context_patterns = {
            "debugging": r"(debug|debugging|breakpoint|pdb|print\()",
            "testing": r"(test|pytest|unittest|assert|mock)",
            "error_handling": r"(try|except|finally|raise|catch)",
            "performance": r"(slow|performance|benchmark|optimize|profil)",
            "security": r"(security|authentication|authorization|token|password)",
        }

    def extract_error_patterns(self, content: str) -> list[dict[str, Any]]:
        """Extract error patterns and debugging context from content."""
        patterns = []

        # Find error patterns
        for pattern_name, regex in self.error_patterns.items():
            matches = re.finditer(regex, content, re.MULTILINE | re.DOTALL)
            for match in matches:
                patterns.append(
                    {
                        "type": "error",
                        "subtype": pattern_name,
                        "content": match.group(0),
                        "start": match.start(),
                        "end": match.end(),
                        "groups": match.groups() if match.groups() else [],
                    },
                )

        # Find context patterns
        for context_name, regex in self.context_patterns.items():
            if re.search(regex, content, re.IGNORECASE):
                patterns.append(
                    {
                        "type": "context",
                        "subtype": context_name,
                        "content": content,
                        "relevance": "high"
                        if context_name in ["debugging", "error_handling"]
                        else "medium",
                    },
                )

        return patterns
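
# A minimal usage sketch for ErrorPatternMatcher; the error text is a made-up
# example, and the subtypes it hits follow from the regexes defined above.
def _demo_error_patterns() -> None:
    sample = (
        "The run failed with:\n"
        "ModuleNotFoundError: No module named 'requests'\n"
        "so we need to fix the import in the test suite.\n"
    )
    matcher = ErrorPatternMatcher()
    for pattern in matcher.extract_error_patterns(sample):
        # Expect "error" hits (e.g. python_exception, module_not_found)
        # plus a "testing" context hit.
        print(pattern["type"], pattern["subtype"])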


class TemporalSearchParser:
    """Parse natural language time expressions for conversation search."""

    def __init__(self) -> None:
        # Each delta is one period; _parse_relative_patterns() shifts the
        # window back by one period for "last ..." and "yesterday".
        self.relative_patterns = {
            "today": timedelta(days=1),
            "yesterday": timedelta(days=1),
            "this week": timedelta(weeks=1),
            "last week": timedelta(weeks=1),
            "this month": relativedelta(months=1)
            if DATEUTIL_AVAILABLE
            else timedelta(days=30),
            "last month": relativedelta(months=1)
            if DATEUTIL_AVAILABLE
            else timedelta(days=30),
            "this year": relativedelta(years=1)
            if DATEUTIL_AVAILABLE
            else timedelta(days=365),
        }


        self.time_patterns = [
            r"(\d+)\s+(minute|hour|day|week|month|year)s?\s+ago",
            r"(today|yesterday|this\s+week|last\s+week|this\s+month|last\s+month)",
            r"since\s+(today|yesterday|this\s+week|last\s+week)",
            r"in\s+the\s+last\s+(\d+)\s+(minute|hour|day|week|month|year)s?",
            r"(\d{4}-\d{2}-\d{2})",  # ISO date
            r"(\d{1,2}/\d{1,2}/\d{4})",  # MM/DD/YYYY
        ]


    def _calculate_delta(self, amount: int, unit: str) -> Any:
        """Calculate a time delta from amount and unit.

        Returns a timedelta for fixed-size units, or a relativedelta for
        months/years when dateutil is available.
        """
        if unit == "minute":
            return timedelta(minutes=amount)
        if unit == "hour":
            return timedelta(hours=amount)
        if unit == "day":
            return timedelta(days=amount)
        if unit == "week":
            return timedelta(weeks=amount)
        if unit == "month":
            return (
                relativedelta(months=amount)
                if DATEUTIL_AVAILABLE
                else timedelta(days=amount * 30)
            )
        if unit == "year":
            return (
                relativedelta(years=amount)
                if DATEUTIL_AVAILABLE
                else timedelta(days=amount * 365)
            )
        return timedelta()


    def _parse_relative_patterns(
        self,
        expression: str,
        now: datetime,
    ) -> tuple[datetime | None, datetime | None]:
        """Parse relative time patterns."""
        for pattern, delta in self.relative_patterns.items():
            if pattern in expression:
                if "last" in pattern or pattern == "yesterday":
                    end_time = now - delta
                    start_time = end_time - delta
                else:
                    start_time = now - delta
                    end_time = now
                return start_time, end_time
        return None, None


    def _parse_ago_pattern(
        self,
        expression: str,
        now: datetime,
    ) -> tuple[datetime | None, datetime | None]:
        """Parse 'X time units ago' as a range from that point until now."""
        match = re.search(
            r"(\d+)\s+(minute|hour|day|week|month|year)s?\s+ago",
            expression,
        )
        if match:
            amount = int(match.group(1))
            unit = match.group(2)
            delta = self._calculate_delta(amount, unit)
            start_time = now - delta
            return start_time, now
        return None, None


    def _parse_last_pattern(
        self,
        expression: str,
        now: datetime,
    ) -> tuple[datetime | None, datetime | None]:
        """Parse 'in the last X units' pattern."""
        match = re.search(
            r"in\s+the\s+last\s+(\d+)\s+(minute|hour|day|week|month|year)s?",
            expression,
        )
        if match:
            amount = int(match.group(1))
            unit = match.group(2)
            delta = self._calculate_delta(amount, unit)
            start_time = now - delta
            return start_time, now
        return None, None


    def _parse_absolute_date(
        self,
        expression: str,
    ) -> tuple[datetime | None, datetime | None]:
        """Parse absolute date expressions."""
        if not DATEUTIL_AVAILABLE:
            return None, None

        try:
            parsed_date = parse_date(expression)
            # Return day range (start of day to end of day)
            start_time = parsed_date.replace(hour=0, minute=0, second=0, microsecond=0)
            end_time = start_time + timedelta(days=1)
            return start_time, end_time
        except (ValueError, TypeError):
            return None, None


    def parse_time_expression(
        self,
        expression: str,
    ) -> tuple[datetime | None, datetime | None]:
        """Parse time expression into start and end datetime."""
        expression = expression.lower().strip()
        now = datetime.now()

        # Try different parsing strategies
        parsers = [
            self._parse_relative_patterns,
            self._parse_ago_pattern,
            self._parse_last_pattern,
            lambda expr, dt: self._parse_absolute_date(expr),
        ]

        for parser in parsers:
            result = parser(expression, now)
            if result != (None, None):
                return result

        return None, None
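
# A minimal usage sketch for TemporalSearchParser; the expressions are
# illustrative, and the exact datetimes depend on when the demo runs.
def _demo_temporal_parsing() -> None:
    parser = TemporalSearchParser()
    for expression in ("2 days ago", "last week", "in the last 3 hours"):
        start, end = parser.parse_time_expression(expression)
        print(f"{expression!r} -> {start} .. {end}")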


class EnhancedSearchEngine:
    """Main search engine that combines all enhanced search capabilities."""

    def __init__(self, reflection_db: ReflectionDatabase) -> None:
        self.reflection_db = reflection_db
        self.code_searcher = CodeSearcher()
        self.error_matcher = ErrorPatternMatcher()
        self.temporal_parser = TemporalSearchParser()
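
    # Note: the search methods below assume ``reflection_db.conn`` exposes
    # DB-API style execute()/fetchall() over a ``conversations`` table with at
    # least (id, content, project, timestamp, metadata) columns; this is
    # inferred from the SELECT statements in this class, not from the
    # ReflectionDatabase implementation itself.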


    async def search_code_patterns(
        self,
        query: str,
        pattern_type: str | None = None,
        limit: int = 10,
    ) -> list[dict[str, Any]]:
        """Search for code patterns in conversations."""
        results = []

        # Get all conversations from database
        if hasattr(self.reflection_db, "conn") and self.reflection_db.conn:
            cursor = self.reflection_db.conn.execute(
                "SELECT id, content, project, timestamp, metadata FROM conversations",
            )
            conversations = cursor.fetchall()

            for conv in conversations:
                conv_id, content, project, timestamp, metadata = conv
                patterns = self.code_searcher.extract_code_patterns(content)

                for pattern in patterns:
                    if pattern_type and pattern["type"] != pattern_type:
                        continue

                    # Calculate relevance based on query similarity
                    relevance = self._calculate_code_relevance(pattern, query)

                    if relevance > 0.3:  # Threshold for relevance
                        results.append(
                            {
                                "conversation_id": conv_id,
                                "project": project,
                                "timestamp": timestamp,
                                "pattern": pattern,
                                "relevance": relevance,
                                "snippet": content[:500] + "..."
                                if len(content) > 500
                                else content,
                            },
                        )

        # Sort by relevance and limit results
        results.sort(key=lambda x: x["relevance"], reverse=True)
        return results[:limit]


    async def search_error_patterns(
        self,
        query: str,
        error_type: str | None = None,
        limit: int = 10,
    ) -> list[dict[str, Any]]:
        """Search for error patterns and debugging contexts."""
        results = []

        if hasattr(self.reflection_db, "conn") and self.reflection_db.conn:
            cursor = self.reflection_db.conn.execute(
                "SELECT id, content, project, timestamp, metadata FROM conversations",
            )
            conversations = cursor.fetchall()

            for conv in conversations:
                conv_id, content, project, timestamp, metadata = conv
                patterns = self.error_matcher.extract_error_patterns(content)

                for pattern in patterns:
                    if error_type and pattern["subtype"] != error_type:
                        continue

                    # Calculate relevance based on query similarity
                    relevance = self._calculate_error_relevance(pattern, query)

                    if relevance > 0.2:  # Lower threshold for errors
                        results.append(
                            {
                                "conversation_id": conv_id,
                                "project": project,
                                "timestamp": timestamp,
                                "pattern": pattern,
                                "relevance": relevance,
                                "snippet": content[:500] + "..."
                                if len(content) > 500
                                else content,
                            },
                        )

        # Sort by relevance and limit results
        results.sort(key=lambda x: x["relevance"], reverse=True)
        return results[:limit]


    async def search_temporal(
        self,
        time_expression: str,
        query: str | None = None,
        limit: int = 10,
    ) -> list[dict[str, Any]]:
        """Search conversations within a time range."""
        start_time, end_time = self.temporal_parser.parse_time_expression(
            time_expression,
        )

        if not start_time or not end_time:
            return [{"error": f"Could not parse time expression: {time_expression}"}]

        results = []

        if hasattr(self.reflection_db, "conn") and self.reflection_db.conn:
            # Convert to ISO format for database query
            start_iso = start_time.isoformat()
            end_iso = end_time.isoformat()

            sql_query = """
                SELECT id, content, project, timestamp, metadata
                FROM conversations
                WHERE timestamp BETWEEN ? AND ?
                ORDER BY timestamp DESC
            """

            cursor = self.reflection_db.conn.execute(sql_query, (start_iso, end_iso))
            conversations = cursor.fetchall()

            for conv in conversations:
                conv_id, content, project, timestamp, metadata = conv

                # If query provided, filter by content relevance
                if query:
                    relevance = self._calculate_text_relevance(content, query)
                    if relevance < 0.3:
                        continue
                else:
                    relevance = 1.0

                results.append(
                    {
                        "conversation_id": conv_id,
                        "project": project,
                        "timestamp": timestamp,
                        "content": content[:500] + "..."
                        if len(content) > 500
                        else content,
                        "relevance": relevance,
                    },
                )

        return results[:limit]


    def _calculate_code_relevance(self, pattern: dict[str, Any], query: str) -> float:
        """Calculate relevance score for code patterns."""
        relevance = 0.0
        query_lower = query.lower()

        # Type matching
        if pattern["type"] in query_lower:
            relevance += 0.5

        # Name matching (for functions/classes)
        if "name" in pattern and pattern["name"].lower() in query_lower:
            relevance += 0.7

        # Content matching
        if query_lower in pattern["content"].lower():
            relevance += 0.4

        # Module/import matching
        if "modules" in pattern:
            for module in pattern["modules"]:
                if module.lower() in query_lower:
                    relevance += 0.3

        return min(relevance, 1.0)
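
    # Worked example (illustrative): for query "import requests" and a pattern
    # {"type": "import", "modules": ["requests"], "content": "import requests"},
    # the score is 0.5 (type) + 0.4 (content) + 0.3 (module) = 1.2, capped at 1.0.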


    def _calculate_error_relevance(self, pattern: dict[str, Any], query: str) -> float:
        """Calculate relevance score for error patterns."""
        relevance = 0.0
        query_lower = query.lower()

        # Error type matching
        if pattern["subtype"] in query_lower:
            relevance += 0.6

        # Content matching
        if "content" in pattern and query_lower in pattern["content"].lower():
            relevance += 0.5

        # Context relevance boost
        if pattern["type"] == "context" and pattern.get("relevance") == "high":
            relevance += 0.3

        return min(relevance, 1.0)


    def _calculate_text_relevance(self, content: str, query: str) -> float:
        """Simple text relevance calculation."""
        query_lower = query.lower()
        content_lower = content.lower()

        # Simple keyword matching
        query_words = query_lower.split()
        content_words = content_lower.split()

        matches = sum(1 for word in query_words if word in content_words)
        return matches / len(query_words) if query_words else 0.0
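
# The demo sketches above are synchronous and self-contained. EnhancedSearchEngine
# is not exercised here because it needs a live ReflectionDatabase connection.
if __name__ == "__main__":
    _demo_code_search()
    _demo_error_patterns()
    _demo_temporal_parsing()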