Coverage for session_buddy / search_enhanced.py: 14.12%
236 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-04 00:43 -0800
1#!/usr/bin/env python3
2"""Enhanced Search Capabilities for Session Management MCP Server.
4Provides multi-modal search including code snippets, error patterns, and time-based queries.
5"""
7import ast
8import contextlib
9from datetime import datetime, timedelta
10from typing import TYPE_CHECKING, Any, cast
12if TYPE_CHECKING:
13 from dateutil.parser import parse as parse_date
14 from dateutil.relativedelta import relativedelta
16try:
17 from dateutil.parser import parse as parse_date
18 from dateutil.relativedelta import relativedelta
20 DATEUTIL_AVAILABLE = True
21except ImportError:
22 DATEUTIL_AVAILABLE = False
25import operator
27from .reflection_tools import ReflectionDatabase
28from .session_types import TimeRange
29from .utils.regex_patterns import SAFE_PATTERNS
32class CodeSearcher:
33 """AST-based code search for Python code snippets."""
35 def __init__(self) -> None:
36 self.search_types: dict[str, type[ast.AST] | tuple[type[ast.AST], ...]] = {
37 "function": ast.FunctionDef,
38 "class": ast.ClassDef,
39 "import": (ast.Import, ast.ImportFrom),
40 "assignment": ast.Assign,
41 "call": ast.Call,
42 "loop": (ast.For, ast.While),
43 "conditional": ast.If,
44 "try": ast.Try,
45 "async": (ast.AsyncFunctionDef, ast.AsyncWith, ast.AsyncFor),
46 }
48 def _extract_pattern_info(
49 self,
50 node: ast.AST,
51 pattern_type: str,
52 code: str,
53 block_index: int,
54 ) -> dict[str, Any]:
55 """Extract pattern information from AST node."""
56 pattern_info = {
57 "type": pattern_type,
58 "content": code,
59 "block_index": block_index,
60 "line_number": getattr(node, "lineno", 0),
61 }
63 # Extract specific information based on node type
64 if isinstance(node, ast.FunctionDef):
65 pattern_info["name"] = node.name
66 pattern_info["args"] = [arg.arg for arg in node.args.args]
67 elif isinstance(node, ast.ClassDef):
68 pattern_info["name"] = node.name
69 elif isinstance(node, ast.Import | ast.ImportFrom):
70 if isinstance(node, ast.Import):
71 pattern_info["modules"] = [alias.name for alias in node.names]
72 else:
73 pattern_info["module"] = node.module
74 pattern_info["names"] = [alias.name for alias in node.names]
76 return pattern_info
78 def _process_code_block(self, code: str, block_index: int) -> list[dict[str, Any]]:
79 """Process a single code block and extract patterns."""
80 patterns = []
81 with contextlib.suppress(SyntaxError, ValueError):
82 # Not valid Python code, skip
83 tree = ast.parse(code)
84 for node in ast.walk(tree):
85 for pattern_type, node_types in self.search_types.items():
86 # Handle both single classes and tuples of classes
87 type_check = (
88 node_types if isinstance(node_types, tuple) else (node_types,)
89 )
90 if isinstance(node, type_check):
91 pattern_info = self._extract_pattern_info(
92 node,
93 pattern_type,
94 code,
95 block_index,
96 )
97 patterns.append(pattern_info)
98 return patterns
100 def extract_code_patterns(self, content: str) -> list[dict[str, Any]]:
101 """Extract code patterns from conversation content."""
102 patterns = []
104 # Extract Python code blocks using validated patterns
105 python_code_blocks = SAFE_PATTERNS["python_code_block"].findall(content)
106 generic_code_blocks = SAFE_PATTERNS["generic_code_block"].findall(content)
107 code_blocks = python_code_blocks + generic_code_blocks
109 for i, code in enumerate(code_blocks):
110 block_patterns = self._process_code_block(code, i)
111 patterns.extend(block_patterns)
113 return patterns
class ErrorPatternMatcher:
    """Pattern matching for error messages and debugging contexts."""

    def __init__(self) -> None:
        # Error subtype -> key into the project's validated SAFE_PATTERNS.
        self.error_patterns = {
            "python_traceback": "python_traceback",
            "python_exception": "python_exception",
            "javascript_error": "javascript_error",
            "compile_error": "compile_error",
            "warning": "warning_pattern",
            "assertion": "assertion_error",
            "import_error": "import_error",
            "module_not_found": "module_not_found",
            "file_not_found": "file_not_found",
            "permission_denied": "permission_denied",
            "network_error": "network_error",
        }

        # Context subtype -> key into the validated SAFE_PATTERNS.
        self.context_patterns = {
            "debugging": "debugging_context",
            "testing": "testing_context",
            "error_handling": "error_handling_context",
            "performance": "performance_context",
            "security": "security_context",
        }

    def extract_error_patterns(self, content: str) -> list[dict[str, Any]]:
        """Extract error patterns and debugging context from content.

        Returns one "error" record per error family that occurs (only the
        first occurrence per family is captured) plus one "context" record
        per context pattern that tests positive for the whole content.
        """
        found: list[dict[str, Any]] = []

        # Error families: search() gives the first match with position info.
        for subtype, pattern_key in self.error_patterns.items():
            hit = SAFE_PATTERNS[pattern_key].search(content)
            if hit is None:
                continue
            found.append(
                {
                    "type": "error",
                    "subtype": subtype,
                    "content": hit.group(0),
                    "start": hit.start(),
                    "end": hit.end(),
                    "groups": hit.groups() or [],
                },
            )

        # Context tags are boolean hits over the full content.
        for subtype, pattern_key in self.context_patterns.items():
            if not SAFE_PATTERNS[pattern_key].test(content):
                continue
            score = (
                "high" if subtype in {"debugging", "error_handling"} else "medium"
            )
            found.append(
                {
                    "type": "context",
                    "subtype": subtype,
                    "content": content,
                    "relevance": score,
                },
            )

        return found
class TemporalSearchParser:
    """Parse natural language time expressions for conversation search."""

    def __init__(self) -> None:
        # Prefer calendar-accurate deltas when dateutil is installed; fall
        # back to coarse day-based approximations otherwise.
        if DATEUTIL_AVAILABLE:
            one_month = relativedelta(months=1)
            two_months = relativedelta(months=2)
            one_year = relativedelta(years=1)
        else:
            one_month = timedelta(days=30)
            two_months = timedelta(days=60)
            one_year = timedelta(days=365)

        # NOTE(review): "today" maps to a zero-width delta, so its range is
        # [now, now]; "last week" doubles to 14 days per step — confirm both
        # are the intended semantics.
        self.relative_patterns = {
            "today": timedelta(hours=0),
            "yesterday": timedelta(days=1),
            "this week": timedelta(weeks=1),
            "last week": timedelta(weeks=1, days=7),
            "this month": one_month,
            "last month": two_months,
            "this year": one_year,
        }

        # Names of the validated SAFE_PATTERNS used for structured parsing.
        self.time_patterns = [
            "time_ago_pattern",
            "relative_time_pattern",
            "since_time_pattern",
            "last_duration_pattern",
            "iso_date_pattern",
            "us_date_pattern",
        ]

    def _calculate_delta(self, amount: int, unit: str) -> timedelta:
        """Convert an (amount, unit) pair into a timedelta.

        Months and years are approximated as 30/365 days for type safety;
        unknown units yield a zero delta.
        """
        if unit in ("minute", "hour", "day", "week"):
            # timedelta keyword is simply the plural of the unit name.
            return timedelta(**{unit + "s": amount})
        approx_days = {"month": 30, "year": 365}
        if unit in approx_days:
            return timedelta(days=amount * approx_days[unit])
        return timedelta()

    def _parse_relative_patterns(
        self,
        expression: str,
        now: datetime,
    ) -> TimeRange:
        """Match fixed phrases like "yesterday" or "last month" to a range."""
        for phrase, delta in self.relative_patterns.items():
            if phrase not in expression:
                continue
            # "last ..." and "yesterday" denote the *previous* period, so the
            # window is shifted back by one delta; otherwise it ends at now.
            if "last" in phrase or phrase == "yesterday":
                end_time = now - delta
                start_time = end_time - delta
            else:
                start_time, end_time = now - delta, now
            return TimeRange(start=start_time, end=end_time)
        return TimeRange()

    def _parse_ago_pattern(
        self,
        expression: str,
        now: datetime,
    ) -> TimeRange:
        """Match the "X time units ago" form; range runs from then to now."""
        match = SAFE_PATTERNS["time_ago_pattern"].search(expression)
        if match is None:
            return TimeRange()
        delta = self._calculate_delta(int(match.group(1)), match.group(2))
        return TimeRange(start=now - delta, end=now)

    def _parse_last_pattern(
        self,
        expression: str,
        now: datetime,
    ) -> TimeRange:
        """Match the "in the last X units" form; range runs up to now."""
        match = SAFE_PATTERNS["last_duration_pattern"].search(expression)
        if match is None:
            return TimeRange()
        delta = self._calculate_delta(int(match.group(1)), match.group(2))
        return TimeRange(start=now - delta, end=now)

    def _parse_absolute_date(
        self,
        expression: str,
    ) -> TimeRange:
        """Parse an absolute date (requires dateutil) into a full-day range."""
        if not DATEUTIL_AVAILABLE:
            return TimeRange()

        with contextlib.suppress(ValueError, TypeError):
            parsed = parse_date(expression)
            # parse_date may hand back non-datetime values; only accept real
            # datetimes, then widen to [start of day, start of next day).
            if isinstance(parsed, datetime):
                day_start = parsed.replace(
                    hour=0,
                    minute=0,
                    second=0,
                    microsecond=0,
                )
                return TimeRange(start=day_start, end=day_start + timedelta(days=1))
        return TimeRange()

    def parse_time_expression(
        self,
        expression: str,
    ) -> TimeRange:
        """Parse a time expression into a (start, end) TimeRange.

        Tries each strategy in order and returns the first non-empty range;
        an empty TimeRange signals that nothing could be parsed.
        """
        expression = expression.lower().strip()
        now = datetime.now()

        strategies = (
            self._parse_relative_patterns,
            self._parse_ago_pattern,
            self._parse_last_pattern,
            lambda expr, _now: self._parse_absolute_date(expr),
        )

        for strategy in strategies:
            candidate = strategy(expression, now)
            if candidate.start is not None or candidate.end is not None:
                return candidate

        return TimeRange()
327class EnhancedSearchEngine:
328 """Main search engine that combines all enhanced search capabilities."""
330 def __init__(self, reflection_db: ReflectionDatabase) -> None:
331 self.reflection_db = reflection_db
332 self.code_searcher = CodeSearcher()
333 self.error_matcher = ErrorPatternMatcher()
334 self.temporal_parser = TemporalSearchParser()
336 async def search_code_patterns(
337 self,
338 query: str,
339 pattern_type: str | None = None,
340 limit: int = 10,
341 ) -> list[dict[str, Any]]:
342 """Search for code patterns in conversations."""
343 conversations = self._get_all_conversations()
344 if not conversations:
345 return []
347 results = []
348 for conv in conversations:
349 conv_results = self._process_conversation_for_code_patterns(
350 conv,
351 query,
352 pattern_type,
353 )
354 results.extend(conv_results)
356 return self._sort_and_limit_results(results, limit)
358 def _get_all_conversations(self) -> list[tuple[str, str, str, str, str]]:
359 """Get all conversations from database."""
360 if not hasattr(self.reflection_db, "conn") or not self.reflection_db.conn:
361 return []
363 cursor = self.reflection_db.conn.execute(
364 "SELECT id, content, project, timestamp, metadata FROM conversations",
365 )
366 return cast("list[tuple[str, str, str, str, str]]", cursor.fetchall())
368 def _process_conversation_for_code_patterns(
369 self,
370 conv: tuple[str, str, str, str, str],
371 query: str,
372 pattern_type: str | None,
373 ) -> list[dict[str, Any]]:
374 """Process a single conversation for code patterns."""
375 conv_id, content, project, timestamp, _metadata = conv
376 patterns = self.code_searcher.extract_code_patterns(content)
377 results = []
379 for pattern in patterns:
380 if pattern_type and pattern["type"] != pattern_type:
381 continue
383 relevance = self._calculate_code_relevance(pattern, query)
384 if relevance > 0.3: # Threshold for relevance
385 results.append(
386 {
387 "conversation_id": conv_id,
388 "project": project,
389 "timestamp": timestamp,
390 "pattern": pattern,
391 "relevance": relevance,
392 "snippet": content[:500] + "..."
393 if len(content) > 500
394 else content,
395 },
396 )
398 return results
400 def _sort_and_limit_results(
401 self,
402 results: list[dict[str, Any]],
403 limit: int,
404 ) -> list[dict[str, Any]]:
405 """Sort results by relevance and limit."""
406 results.sort(key=operator.itemgetter("relevance"), reverse=True)
407 return results[:limit]
409 async def search_error_patterns(
410 self,
411 query: str,
412 error_type: str | None = None,
413 limit: int = 10,
414 ) -> list[dict[str, Any]]:
415 """Search for error patterns and debugging contexts."""
416 conversations = self._get_all_conversations()
417 if not conversations:
418 return []
420 results = []
421 for conv in conversations:
422 conv_results = self._process_conversation_for_error_patterns(
423 conv,
424 query,
425 error_type,
426 )
427 results.extend(conv_results)
429 return self._sort_and_limit_results(results, limit)
431 def _process_conversation_for_error_patterns(
432 self,
433 conv: tuple[str, str, str, str, str],
434 query: str,
435 error_type: str | None,
436 ) -> list[dict[str, Any]]:
437 """Process a single conversation for error patterns."""
438 conv_id, content, project, timestamp, _metadata = conv
439 patterns = self.error_matcher.extract_error_patterns(content)
440 results = []
442 for pattern in patterns:
443 if error_type and pattern["subtype"] != error_type:
444 continue
446 relevance = self._calculate_error_relevance(pattern, query)
447 if relevance > 0.2: # Lower threshold for errors
448 results.append(
449 {
450 "conversation_id": conv_id,
451 "project": project,
452 "timestamp": timestamp,
453 "pattern": pattern,
454 "relevance": relevance,
455 "snippet": content[:500] + "..."
456 if len(content) > 500
457 else content,
458 },
459 )
461 return results
463 async def search_temporal(
464 self,
465 time_expression: str,
466 query: str | None = None,
467 limit: int = 10,
468 ) -> list[dict[str, Any]]:
469 """Search conversations within a time range."""
470 time_range = self.temporal_parser.parse_time_expression(
471 time_expression,
472 )
474 if not time_range.start or not time_range.end:
475 return [{"error": f"Could not parse time expression: {time_expression}"}]
477 start_time = time_range.start
478 end_time = time_range.end
480 results = []
482 if hasattr(self.reflection_db, "conn") and self.reflection_db.conn:
483 # Convert to ISO format for database query
484 start_iso = start_time.isoformat()
485 end_iso = end_time.isoformat()
487 sql_query = """
488 SELECT id, content, project, timestamp, metadata
489 FROM conversations
490 WHERE timestamp BETWEEN ? AND ?
491 ORDER BY timestamp DESC
492 """
494 cursor = self.reflection_db.conn.execute(sql_query, (start_iso, end_iso))
495 conversations = cursor.fetchall()
497 for conv in conversations:
498 conv_id, content, project, timestamp, _metadata = conv
500 # If query provided, filter by content relevance
501 if query:
502 relevance = self._calculate_text_relevance(content, query)
503 if relevance < 0.3:
504 continue
505 else:
506 relevance = 1.0
508 results.append(
509 {
510 "conversation_id": conv_id,
511 "project": project,
512 "timestamp": timestamp,
513 "content": content[:500] + "..."
514 if len(content) > 500
515 else content,
516 "relevance": relevance,
517 },
518 )
520 return results[:limit]
522 def _calculate_code_relevance(self, pattern: dict[str, Any], query: str) -> float:
523 """Calculate relevance score for code patterns."""
524 relevance = 0.0
525 query_lower = query.lower()
527 # Type matching
528 if pattern["type"] in query_lower:
529 relevance += 0.5
531 # Name matching (for functions/classes)
532 if "name" in pattern and pattern["name"].lower() in query_lower:
533 relevance += 0.7
535 # Content matching
536 if query_lower in pattern["content"].lower():
537 relevance += 0.4
539 # Module/import matching
540 if "modules" in pattern:
541 for module in pattern["modules"]:
542 if module.lower() in query_lower:
543 relevance += 0.3
545 return min(relevance, 1.0)
547 def _calculate_error_relevance(self, pattern: dict[str, Any], query: str) -> float:
548 """Calculate relevance score for error patterns."""
549 relevance = 0.0
550 query_lower = query.lower()
552 # Error type matching
553 if pattern["subtype"] in query_lower:
554 relevance += 0.6
556 # Content matching
557 if "content" in pattern and query_lower in pattern["content"].lower():
558 relevance += 0.5
560 # Context relevance boost
561 if pattern["type"] == "context" and pattern.get("relevance") == "high":
562 relevance += 0.3
564 return min(relevance, 1.0)
566 def _calculate_text_relevance(self, content: str, query: str) -> float:
567 """Simple text relevance calculation."""
568 query_lower = query.lower()
569 content_lower = content.lower()
571 # Simple keyword matching
572 query_words = query_lower.split()
573 content_words = content_lower.split()
575 matches = sum(1 for word in query_words if word in content_words)
576 return matches / len(query_words) if query_words else 0.0