Coverage for session_buddy/search_enhanced.py: 14.12%

236 statements  

coverage.py v7.13.1, created at 2026-01-04 00:43 -0800

#!/usr/bin/env python3
"""Enhanced Search Capabilities for Session Management MCP Server.

Provides multi-modal search including code snippets, error patterns, and time-based queries.
"""

import ast
import contextlib
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, cast

if TYPE_CHECKING:
    from dateutil.parser import parse as parse_date
    from dateutil.relativedelta import relativedelta

try:
    from dateutil.parser import parse as parse_date
    from dateutil.relativedelta import relativedelta

    DATEUTIL_AVAILABLE = True
except ImportError:
    DATEUTIL_AVAILABLE = False


import operator

from .reflection_tools import ReflectionDatabase
from .session_types import TimeRange
from .utils.regex_patterns import SAFE_PATTERNS
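
# The SAFE_PATTERNS registry lives in .utils.regex_patterns and is not shown in
# this report. As a rough, hypothetical sketch (names and the regex below are
# illustrative assumptions, not the real implementation), each entry appears to
# wrap a compiled regex and expose the findall()/search()/test() calls used
# throughout this module:
import re


class _SafePatternSketch:
    """Hypothetical stand-in for a validated-pattern wrapper."""

    def __init__(self, pattern: str, flags: int = 0) -> None:
        self._regex = re.compile(pattern, flags)

    def findall(self, text: str) -> list[str]:
        return self._regex.findall(text)

    def search(self, text: str) -> re.Match[str] | None:
        return self._regex.search(text)

    def test(self, text: str) -> bool:
        return self._regex.search(text) is not None


# Example entry; the real "python_code_block" pattern may differ.
_SAFE_PATTERNS_SKETCH = {
    "python_code_block": _SafePatternSketch(r"```python\n(.*?)```", re.DOTALL),
}
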

class CodeSearcher:
    """AST-based code search for Python code snippets."""

    def __init__(self) -> None:
        self.search_types: dict[str, type[ast.AST] | tuple[type[ast.AST], ...]] = {
            "function": ast.FunctionDef,
            "class": ast.ClassDef,
            "import": (ast.Import, ast.ImportFrom),
            "assignment": ast.Assign,
            "call": ast.Call,
            "loop": (ast.For, ast.While),
            "conditional": ast.If,
            "try": ast.Try,
            "async": (ast.AsyncFunctionDef, ast.AsyncWith, ast.AsyncFor),
        }

    def _extract_pattern_info(
        self,
        node: ast.AST,
        pattern_type: str,
        code: str,
        block_index: int,
    ) -> dict[str, Any]:
        """Extract pattern information from AST node."""
        pattern_info = {
            "type": pattern_type,
            "content": code,
            "block_index": block_index,
            "line_number": getattr(node, "lineno", 0),
        }

        # Extract specific information based on node type
        if isinstance(node, ast.FunctionDef):
            pattern_info["name"] = node.name
            pattern_info["args"] = [arg.arg for arg in node.args.args]
        elif isinstance(node, ast.ClassDef):
            pattern_info["name"] = node.name
        elif isinstance(node, ast.Import | ast.ImportFrom):
            if isinstance(node, ast.Import):
                pattern_info["modules"] = [alias.name for alias in node.names]
            else:
                pattern_info["module"] = node.module
                pattern_info["names"] = [alias.name for alias in node.names]

        return pattern_info

    def _process_code_block(self, code: str, block_index: int) -> list[dict[str, Any]]:
        """Process a single code block and extract patterns."""
        patterns = []
        with contextlib.suppress(SyntaxError, ValueError):
            # Blocks that are not valid Python raise SyntaxError and are skipped
            tree = ast.parse(code)
            for node in ast.walk(tree):
                for pattern_type, node_types in self.search_types.items():
                    # Handle both single classes and tuples of classes
                    type_check = (
                        node_types if isinstance(node_types, tuple) else (node_types,)
                    )
                    if isinstance(node, type_check):
                        pattern_info = self._extract_pattern_info(
                            node,
                            pattern_type,
                            code,
                            block_index,
                        )
                        patterns.append(pattern_info)
        return patterns

    def extract_code_patterns(self, content: str) -> list[dict[str, Any]]:
        """Extract code patterns from conversation content."""
        patterns = []

        # Extract Python code blocks using validated patterns
        python_code_blocks = SAFE_PATTERNS["python_code_block"].findall(content)
        generic_code_blocks = SAFE_PATTERNS["generic_code_block"].findall(content)
        code_blocks = python_code_blocks + generic_code_blocks

        for i, code in enumerate(code_blocks):
            block_patterns = self._process_code_block(code, i)
            patterns.extend(block_patterns)

        return patterns
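
# Illustrative usage sketch (hypothetical input, not part of this module):
# extract_code_patterns() scans fenced code blocks in conversation text and
# returns one dict per recognized AST construct, assuming the
# "python_code_block" pattern captures the body of ```python fences.
def _example_code_searcher() -> None:
    searcher = CodeSearcher()
    conversation = (
        "Here is the helper we discussed:\n"
        "```python\n"
        "def add(a: int, b: int) -> int:\n"
        "    return a + b\n"
        "```\n"
    )
    for pattern in searcher.extract_code_patterns(conversation):
        # e.g. type='function', name='add', args=['a', 'b']
        print(pattern["type"], pattern.get("name"), pattern.get("args"))
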

class ErrorPatternMatcher:
    """Pattern matching for error messages and debugging contexts."""

    def __init__(self) -> None:
        # Map pattern names to our validated patterns
        self.error_patterns = {
            "python_traceback": "python_traceback",
            "python_exception": "python_exception",
            "javascript_error": "javascript_error",
            "compile_error": "compile_error",
            "warning": "warning_pattern",
            "assertion": "assertion_error",
            "import_error": "import_error",
            "module_not_found": "module_not_found",
            "file_not_found": "file_not_found",
            "permission_denied": "permission_denied",
            "network_error": "network_error",
        }

        # Map context pattern names to our validated patterns
        self.context_patterns = {
            "debugging": "debugging_context",
            "testing": "testing_context",
            "error_handling": "error_handling_context",
            "performance": "performance_context",
            "security": "security_context",
        }

    def extract_error_patterns(self, content: str) -> list[dict[str, Any]]:
        """Extract error patterns and debugging context from content."""
        patterns = []

        # Find error patterns using validated patterns
        for pattern_name, safe_pattern_key in self.error_patterns.items():
            safe_pattern = SAFE_PATTERNS[safe_pattern_key]
            # Use search() method to find matches with position info
            match = safe_pattern.search(content)
            if match:
                patterns.append(
                    {
                        "type": "error",
                        "subtype": pattern_name,
                        "content": match.group(0),
                        "start": match.start(),
                        "end": match.end(),
                        "groups": match.groups() or [],
                    },
                )

        # Find context patterns using validated patterns
        for context_name, safe_pattern_key in self.context_patterns.items():
            safe_pattern = SAFE_PATTERNS[safe_pattern_key]
            if safe_pattern.test(content):
                patterns.append(
                    {
                        "type": "context",
                        "subtype": context_name,
                        "content": content,
                        "relevance": "high"
                        if context_name in {"debugging", "error_handling"}
                        else "medium",
                    },
                )

        return patterns
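
# Illustrative usage sketch (hypothetical input, not part of this module):
# extract_error_patterns() returns "error" entries with character offsets and
# "context" entries with a coarse relevance label. The traceback text below is
# assumed to be matched by the referenced SAFE_PATTERNS entries.
def _example_error_matcher() -> None:
    matcher = ErrorPatternMatcher()
    content = (
        "While debugging the import I hit:\n"
        "Traceback (most recent call last):\n"
        '  File "app.py", line 3, in <module>\n'
        "ModuleNotFoundError: No module named 'dateutil'\n"
    )
    for pattern in matcher.extract_error_patterns(content):
        print(pattern["type"], pattern["subtype"], pattern.get("relevance"))
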

class TemporalSearchParser:
    """Parse natural language time expressions for conversation search."""

    def __init__(self) -> None:
        self.relative_patterns = {
            "today": timedelta(hours=0),
            "yesterday": timedelta(days=1),
            "this week": timedelta(weeks=1),
            "last week": timedelta(weeks=1, days=7),
            "this month": relativedelta(months=1)
            if DATEUTIL_AVAILABLE
            else timedelta(days=30),
            "last month": relativedelta(months=2)
            if DATEUTIL_AVAILABLE
            else timedelta(days=60),
            "this year": relativedelta(years=1)
            if DATEUTIL_AVAILABLE
            else timedelta(days=365),
        }

        # Map to validated time parsing patterns
        self.time_patterns = [
            "time_ago_pattern",
            "relative_time_pattern",
            "since_time_pattern",
            "last_duration_pattern",
            "iso_date_pattern",
            "us_date_pattern",
        ]

    def _calculate_delta(self, amount: int, unit: str) -> timedelta:
        """Calculate timedelta from amount and unit."""
        if unit == "minute":
            return timedelta(minutes=amount)
        if unit == "hour":
            return timedelta(hours=amount)
        if unit == "day":
            return timedelta(days=amount)
        if unit == "week":
            return timedelta(weeks=amount)
        if unit == "month":
            # Always use timedelta approximation for type safety
            return timedelta(days=amount * 30)
        if unit == "year":
            # Always use timedelta approximation for type safety
            return timedelta(days=amount * 365)
        return timedelta()

    def _parse_relative_patterns(
        self,
        expression: str,
        now: datetime,
    ) -> TimeRange:
        """Parse relative time patterns."""
        for pattern, delta in self.relative_patterns.items():
            if pattern in expression:
                if "last" in pattern or pattern == "yesterday":
                    end_time = now - delta
                    start_time = end_time - delta
                else:
                    start_time = now - delta
                    end_time = now
                return TimeRange(start=start_time, end=end_time)
        return TimeRange()

    def _parse_ago_pattern(
        self,
        expression: str,
        now: datetime,
    ) -> TimeRange:
        """Parse 'X time units ago' pattern."""
        match = SAFE_PATTERNS["time_ago_pattern"].search(expression)
        if match:
            amount = int(match.group(1))
            unit = match.group(2)
            delta = self._calculate_delta(amount, unit)
            start_time = now - delta
            return TimeRange(start=start_time, end=now)
        return TimeRange()

    def _parse_last_pattern(
        self,
        expression: str,
        now: datetime,
    ) -> TimeRange:
        """Parse 'in the last X units' pattern."""
        match = SAFE_PATTERNS["last_duration_pattern"].search(expression)
        if match:
            amount = int(match.group(1))
            unit = match.group(2)
            delta = self._calculate_delta(amount, unit)
            start_time = now - delta
            return TimeRange(start=start_time, end=now)
        return TimeRange()

    def _parse_absolute_date(
        self,
        expression: str,
    ) -> TimeRange:
        """Parse absolute date expressions."""
        if not DATEUTIL_AVAILABLE:
            return TimeRange()

        with contextlib.suppress(ValueError, TypeError):
            parsed_date = parse_date(expression)
            # Ensure parsed_date is a datetime object
            if isinstance(parsed_date, datetime):
                # Return day range (start of day to end of day)
                start_time = parsed_date.replace(
                    hour=0,
                    minute=0,
                    second=0,
                    microsecond=0,
                )
                end_time = start_time + timedelta(days=1)
                return TimeRange(start=start_time, end=end_time)
        return TimeRange()

    def parse_time_expression(
        self,
        expression: str,
    ) -> TimeRange:
        """Parse time expression into start and end datetime."""
        expression = expression.lower().strip()
        now = datetime.now()

        # Try different parsing strategies
        parsers = [
            self._parse_relative_patterns,
            self._parse_ago_pattern,
            self._parse_last_pattern,
            lambda expr, dt: self._parse_absolute_date(expr),
        ]

        for parser in parsers:
            result = parser(expression, now)
            if result.start is not None or result.end is not None:
                return result

        return TimeRange()
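
# Illustrative usage sketch (not part of this module): parse_time_expression()
# returns a TimeRange whose start/end are None when nothing matches, so callers
# can detect unparseable expressions. The example expressions assume the
# corresponding SAFE_PATTERNS entries ("time_ago_pattern",
# "last_duration_pattern") match them.
def _example_temporal_parser() -> None:
    parser = TemporalSearchParser()
    for expression in ("yesterday", "3 days ago", "in the last 2 weeks"):
        time_range = parser.parse_time_expression(expression)
        print(expression, "->", time_range.start, time_range.end)
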

class EnhancedSearchEngine:
    """Main search engine that combines all enhanced search capabilities."""

    def __init__(self, reflection_db: ReflectionDatabase) -> None:
        self.reflection_db = reflection_db
        self.code_searcher = CodeSearcher()
        self.error_matcher = ErrorPatternMatcher()
        self.temporal_parser = TemporalSearchParser()

    async def search_code_patterns(
        self,
        query: str,
        pattern_type: str | None = None,
        limit: int = 10,
    ) -> list[dict[str, Any]]:
        """Search for code patterns in conversations."""
        conversations = self._get_all_conversations()
        if not conversations:
            return []

        results = []
        for conv in conversations:
            conv_results = self._process_conversation_for_code_patterns(
                conv,
                query,
                pattern_type,
            )
            results.extend(conv_results)

        return self._sort_and_limit_results(results, limit)

    def _get_all_conversations(self) -> list[tuple[str, str, str, str, str]]:
        """Get all conversations from database."""
        if not hasattr(self.reflection_db, "conn") or not self.reflection_db.conn:
            return []

        cursor = self.reflection_db.conn.execute(
            "SELECT id, content, project, timestamp, metadata FROM conversations",
        )
        return cast("list[tuple[str, str, str, str, str]]", cursor.fetchall())

    def _process_conversation_for_code_patterns(
        self,
        conv: tuple[str, str, str, str, str],
        query: str,
        pattern_type: str | None,
    ) -> list[dict[str, Any]]:
        """Process a single conversation for code patterns."""
        conv_id, content, project, timestamp, _metadata = conv
        patterns = self.code_searcher.extract_code_patterns(content)
        results = []

        for pattern in patterns:
            if pattern_type and pattern["type"] != pattern_type:
                continue

            relevance = self._calculate_code_relevance(pattern, query)
            if relevance > 0.3:  # Threshold for relevance
                results.append(
                    {
                        "conversation_id": conv_id,
                        "project": project,
                        "timestamp": timestamp,
                        "pattern": pattern,
                        "relevance": relevance,
                        "snippet": content[:500] + "..."
                        if len(content) > 500
                        else content,
                    },
                )

        return results

    def _sort_and_limit_results(
        self,
        results: list[dict[str, Any]],
        limit: int,
    ) -> list[dict[str, Any]]:
        """Sort results by relevance and limit."""
        results.sort(key=operator.itemgetter("relevance"), reverse=True)
        return results[:limit]

    async def search_error_patterns(
        self,
        query: str,
        error_type: str | None = None,
        limit: int = 10,
    ) -> list[dict[str, Any]]:
        """Search for error patterns and debugging contexts."""
        conversations = self._get_all_conversations()
        if not conversations:
            return []

        results = []
        for conv in conversations:
            conv_results = self._process_conversation_for_error_patterns(
                conv,
                query,
                error_type,
            )
            results.extend(conv_results)

        return self._sort_and_limit_results(results, limit)

    def _process_conversation_for_error_patterns(
        self,
        conv: tuple[str, str, str, str, str],
        query: str,
        error_type: str | None,
    ) -> list[dict[str, Any]]:
        """Process a single conversation for error patterns."""
        conv_id, content, project, timestamp, _metadata = conv
        patterns = self.error_matcher.extract_error_patterns(content)
        results = []

        for pattern in patterns:
            if error_type and pattern["subtype"] != error_type:
                continue

            relevance = self._calculate_error_relevance(pattern, query)
            if relevance > 0.2:  # Lower threshold for errors
                results.append(
                    {
                        "conversation_id": conv_id,
                        "project": project,
                        "timestamp": timestamp,
                        "pattern": pattern,
                        "relevance": relevance,
                        "snippet": content[:500] + "..."
                        if len(content) > 500
                        else content,
                    },
                )

        return results

    async def search_temporal(
        self,
        time_expression: str,
        query: str | None = None,
        limit: int = 10,
    ) -> list[dict[str, Any]]:
        """Search conversations within a time range."""
        time_range = self.temporal_parser.parse_time_expression(
            time_expression,
        )

        if not time_range.start or not time_range.end:
            return [{"error": f"Could not parse time expression: {time_expression}"}]

        start_time = time_range.start
        end_time = time_range.end

        results = []

        if hasattr(self.reflection_db, "conn") and self.reflection_db.conn:
            # Convert to ISO format for database query
            start_iso = start_time.isoformat()
            end_iso = end_time.isoformat()

            sql_query = """
                SELECT id, content, project, timestamp, metadata
                FROM conversations
                WHERE timestamp BETWEEN ? AND ?
                ORDER BY timestamp DESC
            """

            cursor = self.reflection_db.conn.execute(sql_query, (start_iso, end_iso))
            conversations = cursor.fetchall()

            for conv in conversations:
                conv_id, content, project, timestamp, _metadata = conv

                # If query provided, filter by content relevance
                if query:
                    relevance = self._calculate_text_relevance(content, query)
                    if relevance < 0.3:
                        continue
                else:
                    relevance = 1.0

                results.append(
                    {
                        "conversation_id": conv_id,
                        "project": project,
                        "timestamp": timestamp,
                        "content": content[:500] + "..."
                        if len(content) > 500
                        else content,
                        "relevance": relevance,
                    },
                )

        return results[:limit]

    def _calculate_code_relevance(self, pattern: dict[str, Any], query: str) -> float:
        """Calculate relevance score for code patterns."""
        relevance = 0.0
        query_lower = query.lower()

        # Type matching
        if pattern["type"] in query_lower:
            relevance += 0.5

        # Name matching (for functions/classes)
        if "name" in pattern and pattern["name"].lower() in query_lower:
            relevance += 0.7

        # Content matching
        if query_lower in pattern["content"].lower():
            relevance += 0.4

        # Module/import matching
        if "modules" in pattern:
            for module in pattern["modules"]:
                if module.lower() in query_lower:
                    relevance += 0.3

        return min(relevance, 1.0)

    def _calculate_error_relevance(self, pattern: dict[str, Any], query: str) -> float:
        """Calculate relevance score for error patterns."""
        relevance = 0.0
        query_lower = query.lower()

        # Error type matching
        if pattern["subtype"] in query_lower:
            relevance += 0.6

        # Content matching
        if "content" in pattern and query_lower in pattern["content"].lower():
            relevance += 0.5

        # Context relevance boost
        if pattern["type"] == "context" and pattern.get("relevance") == "high":
            relevance += 0.3

        return min(relevance, 1.0)

    def _calculate_text_relevance(self, content: str, query: str) -> float:
        """Simple text relevance calculation."""
        query_lower = query.lower()
        content_lower = content.lower()

        # Simple keyword matching
        query_words = query_lower.split()
        content_words = content_lower.split()

        matches = sum(1 for word in query_words if word in content_words)
        return matches / len(query_words) if query_words else 0.0
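

# Illustrative end-to-end sketch (not part of this module): the search methods
# are coroutines, so they would be awaited from an event loop. The no-argument
# ReflectionDatabase() construction is an assumption for illustration; the real
# constructor and connection setup may differ.
async def _example_enhanced_search() -> None:
    engine = EnhancedSearchEngine(ReflectionDatabase())
    code_hits = await engine.search_code_patterns("parse function", pattern_type="function")
    error_hits = await engine.search_error_patterns("module not found", error_type="module_not_found")
    recent = await engine.search_temporal("last week", query="dateutil")
    print(len(code_hits), len(error_hits), len(recent))


# To run the sketch: import asyncio and call asyncio.run(_example_enhanced_search()).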