Coverage for session_mgmt_mcp/search_enhanced.py: 10.20%
204 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-01 05:22 -0700
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-01 05:22 -0700
1#!/usr/bin/env python3
2"""Enhanced Search Capabilities for Session Management MCP Server.
4Provides multi-modal search including code snippets, error patterns, and time-based queries.
5"""
7import ast
8import re
9from datetime import datetime, timedelta
10from typing import Any
12try:
13 from dateutil.parser import parse as parse_date
14 from dateutil.relativedelta import relativedelta
16 DATEUTIL_AVAILABLE = True
17except ImportError:
18 DATEUTIL_AVAILABLE = False
20from .reflection_tools import ReflectionDatabase
23class CodeSearcher:
24 """AST-based code search for Python code snippets."""
26 def __init__(self) -> None:
27 self.search_types = {
28 "function": ast.FunctionDef,
29 "class": ast.ClassDef,
30 "import": (ast.Import, ast.ImportFrom),
31 "assignment": ast.Assign,
32 "call": ast.Call,
33 "loop": (ast.For, ast.While),
34 "conditional": ast.If,
35 "try": ast.Try,
36 "async": (ast.AsyncFunctionDef, ast.AsyncWith, ast.AsyncFor),
37 }
39 def extract_code_patterns(self, content: str) -> list[dict[str, Any]]:
40 """Extract code patterns from conversation content."""
41 patterns = []
43 # Extract Python code blocks
44 code_blocks = re.findall(r"```python\n(.*?)\n```", content, re.DOTALL)
45 code_blocks.extend(re.findall(r"```\n(.*?)\n```", content, re.DOTALL))
47 for i, code in enumerate(code_blocks):
48 try:
49 tree = ast.parse(code)
50 for node in ast.walk(tree):
51 for pattern_type, node_types in self.search_types.items():
52 if isinstance(node, node_types):
53 pattern_info = {
54 "type": pattern_type,
55 "content": code,
56 "block_index": i,
57 "line_number": getattr(node, "lineno", 0),
58 }
60 # Extract specific information based on node type
61 if isinstance(node, ast.FunctionDef):
62 pattern_info["name"] = node.name
63 pattern_info["args"] = [
64 arg.arg for arg in node.args.args
65 ]
66 elif isinstance(node, ast.ClassDef):
67 pattern_info["name"] = node.name
68 elif isinstance(node, ast.Import | ast.ImportFrom):
69 if isinstance(node, ast.Import):
70 pattern_info["modules"] = [
71 alias.name for alias in node.names
72 ]
73 else:
74 pattern_info["module"] = node.module
75 pattern_info["names"] = [
76 alias.name for alias in node.names
77 ]
79 patterns.append(pattern_info)
81 except (SyntaxError, ValueError):
82 # Not valid Python code, skip
83 continue
85 return patterns
class ErrorPatternMatcher:
    """Pattern matching for error messages and debugging contexts."""

    def __init__(self) -> None:
        # Named regexes for concrete error messages (with capture groups).
        self.error_patterns = {
            "python_traceback": r"Traceback \(most recent call last\):.*?(?=\n\n|\Z)",
            "python_exception": r"(\w+Error): (.+)",
            "javascript_error": r"(Error|TypeError|ReferenceError): (.+)",
            "compile_error": r"(error|Error): (.+) at line (\d+)",
            "warning": r"(warning|Warning): (.+)",
            "assertion": r"AssertionError: (.+)",
            "import_error": r"ImportError: (.+)",
            "module_not_found": r"ModuleNotFoundError: (.+)",
            "file_not_found": r"FileNotFoundError: (.+)",
            "permission_denied": r"PermissionError: (.+)",
            "network_error": r"(ConnectionError|TimeoutError|HTTPError): (.+)",
        }

        # Looser keyword regexes hinting at the kind of work being discussed.
        self.context_patterns = {
            "debugging": r"(debug|debugging|breakpoint|pdb|print\()",
            "testing": r"(test|pytest|unittest|assert|mock)",
            "error_handling": r"(try|except|finally|raise|catch)",
            "performance": r"(slow|performance|benchmark|optimize|profil)",
            "security": r"(security|authentication|authorization|token|password)",
        }

    def extract_error_patterns(self, content: str) -> list[dict[str, Any]]:
        """Extract error patterns and debugging context from content."""
        found: list[dict[str, Any]] = []

        # Concrete error matches: one entry per regex hit, with its span.
        for subtype, pattern in self.error_patterns.items():
            found.extend(
                {
                    "type": "error",
                    "subtype": subtype,
                    "content": hit.group(0),
                    "start": hit.start(),
                    "end": hit.end(),
                    "groups": hit.groups() if hit.groups() else [],
                }
                for hit in re.finditer(pattern, content, re.MULTILINE | re.DOTALL)
            )

        # Context hints: at most one entry per matching context category;
        # debugging/error-handling contexts are ranked as highly relevant.
        high_value = ("debugging", "error_handling")
        for subtype, pattern in self.context_patterns.items():
            if re.search(pattern, content, re.IGNORECASE):
                found.append(
                    {
                        "type": "context",
                        "subtype": subtype,
                        "content": content,
                        "relevance": "high" if subtype in high_value else "medium",
                    },
                )

        return found
150class TemporalSearchParser:
151 """Parse natural language time expressions for conversation search."""
153 def __init__(self) -> None:
154 self.relative_patterns = {
155 "today": timedelta(hours=0),
156 "yesterday": timedelta(days=1),
157 "this week": timedelta(weeks=1),
158 "last week": timedelta(weeks=1, days=7),
159 "this month": relativedelta(months=1)
160 if DATEUTIL_AVAILABLE
161 else timedelta(days=30),
162 "last month": relativedelta(months=2)
163 if DATEUTIL_AVAILABLE
164 else timedelta(days=60),
165 "this year": relativedelta(years=1)
166 if DATEUTIL_AVAILABLE
167 else timedelta(days=365),
168 }
170 self.time_patterns = [
171 r"(\d+)\s+(minute|hour|day|week|month|year)s?\s+ago",
172 r"(today|yesterday|this\s+week|last\s+week|this\s+month|last\s+month)",
173 r"since\s+(today|yesterday|this\s+week|last\s+week)",
174 r"in\s+the\s+last\s+(\d+)\s+(minute|hour|day|week|month|year)s?",
175 r"(\d{4}-\d{2}-\d{2})", # ISO date
176 r"(\d{1,2}/\d{1,2}/\d{4})", # MM/DD/YYYY
177 ]
179 def _calculate_delta(self, amount: int, unit: str) -> timedelta:
180 """Calculate timedelta from amount and unit."""
181 if unit == "minute":
182 return timedelta(minutes=amount)
183 if unit == "hour":
184 return timedelta(hours=amount)
185 if unit == "day":
186 return timedelta(days=amount)
187 if unit == "week":
188 return timedelta(weeks=amount)
189 if unit == "month":
190 return (
191 relativedelta(months=amount)
192 if DATEUTIL_AVAILABLE
193 else timedelta(days=amount * 30)
194 )
195 if unit == "year":
196 return (
197 relativedelta(years=amount)
198 if DATEUTIL_AVAILABLE
199 else timedelta(days=amount * 365)
200 )
201 return timedelta()
203 def _parse_relative_patterns(
204 self,
205 expression: str,
206 now: datetime,
207 ) -> tuple[datetime | None, datetime | None]:
208 """Parse relative time patterns."""
209 for pattern, delta in self.relative_patterns.items():
210 if pattern in expression:
211 if "last" in pattern or pattern == "yesterday":
212 end_time = now - delta
213 start_time = end_time - delta
214 else:
215 start_time = now - delta
216 end_time = now
217 return start_time, end_time
218 return None, None
220 def _parse_ago_pattern(
221 self,
222 expression: str,
223 now: datetime,
224 ) -> tuple[datetime | None, datetime | None]:
225 """Parse 'X time units ago' pattern."""
226 match = re.search(
227 r"(\d+)\s+(minute|hour|day|week|month|year)s?\s+ago",
228 expression,
229 )
230 if match:
231 amount = int(match.group(1))
232 unit = match.group(2)
233 delta = self._calculate_delta(amount, unit)
234 end_time = now - delta
235 return end_time, now
236 return None, None
238 def _parse_last_pattern(
239 self,
240 expression: str,
241 now: datetime,
242 ) -> tuple[datetime | None, datetime | None]:
243 """Parse 'in the last X units' pattern."""
244 match = re.search(
245 r"in\s+the\s+last\s+(\d+)\s+(minute|hour|day|week|month|year)s?",
246 expression,
247 )
248 if match:
249 amount = int(match.group(1))
250 unit = match.group(2)
251 delta = self._calculate_delta(amount, unit)
252 start_time = now - delta
253 return start_time, now
254 return None, None
256 def _parse_absolute_date(
257 self,
258 expression: str,
259 ) -> tuple[datetime | None, datetime | None]:
260 """Parse absolute date expressions."""
261 if not DATEUTIL_AVAILABLE:
262 return None, None
264 try:
265 parsed_date = parse_date(expression)
266 # Return day range (start of day to end of day)
267 start_time = parsed_date.replace(hour=0, minute=0, second=0, microsecond=0)
268 end_time = start_time + timedelta(days=1)
269 return start_time, end_time
270 except (ValueError, TypeError):
271 return None, None
273 def parse_time_expression(
274 self,
275 expression: str,
276 ) -> tuple[datetime | None, datetime | None]:
277 """Parse time expression into start and end datetime."""
278 expression = expression.lower().strip()
279 now = datetime.now()
281 # Try different parsing strategies
282 parsers = [
283 self._parse_relative_patterns,
284 self._parse_ago_pattern,
285 self._parse_last_pattern,
286 lambda expr, dt: self._parse_absolute_date(expr),
287 ]
289 for parser in parsers:
290 result = parser(expression, now)
291 if result != (None, None):
292 return result
294 return None, None
class EnhancedSearchEngine:
    """Main search engine that combines all enhanced search capabilities.

    Layers code-pattern, error-pattern and temporal search on top of the
    conversations stored in a ReflectionDatabase.
    """

    # Conversations longer than this are truncated in result snippets.
    _SNIPPET_LIMIT = 500

    def __init__(self, reflection_db: ReflectionDatabase) -> None:
        self.reflection_db = reflection_db
        self.code_searcher = CodeSearcher()
        self.error_matcher = ErrorPatternMatcher()
        self.temporal_parser = TemporalSearchParser()

    def _snippet(self, content: str) -> str:
        """Return content truncated to the snippet limit with an ellipsis."""
        limit = self._SNIPPET_LIMIT
        return content[:limit] + "..." if len(content) > limit else content

    def _fetch_conversations(self) -> list[Any]:
        """Fetch all conversation rows; [] when no live DB connection exists.

        Each row is (id, content, project, timestamp, metadata).
        """
        if hasattr(self.reflection_db, "conn") and self.reflection_db.conn:
            cursor = self.reflection_db.conn.execute(
                "SELECT id, content, project, timestamp, metadata FROM conversations",
            )
            return cursor.fetchall()
        return []

    async def search_code_patterns(
        self,
        query: str,
        pattern_type: str | None = None,
        limit: int = 10,
    ) -> list[dict[str, Any]]:
        """Search for code patterns in conversations.

        Args:
            query: Free text matched against pattern type/name/content.
            pattern_type: Optional filter on the pattern's ``type`` key.
            limit: Maximum number of results returned.

        Returns:
            Result dicts sorted by descending relevance (threshold 0.3).
        """
        results: list[dict[str, Any]] = []

        for conv_id, content, project, timestamp, _metadata in self._fetch_conversations():
            for pattern in self.code_searcher.extract_code_patterns(content):
                if pattern_type and pattern["type"] != pattern_type:
                    continue

                # Calculate relevance based on query similarity
                relevance = self._calculate_code_relevance(pattern, query)

                if relevance > 0.3:  # Threshold for relevance
                    results.append(
                        {
                            "conversation_id": conv_id,
                            "project": project,
                            "timestamp": timestamp,
                            "pattern": pattern,
                            "relevance": relevance,
                            "snippet": self._snippet(content),
                        },
                    )

        # Sort by relevance and limit results
        results.sort(key=lambda x: x["relevance"], reverse=True)
        return results[:limit]

    async def search_error_patterns(
        self,
        query: str,
        error_type: str | None = None,
        limit: int = 10,
    ) -> list[dict[str, Any]]:
        """Search for error patterns and debugging contexts.

        Args:
            query: Free text matched against error subtype/content.
            error_type: Optional filter on the pattern's ``subtype`` key.
            limit: Maximum number of results returned.

        Returns:
            Result dicts sorted by descending relevance (threshold 0.2 —
            deliberately lower than code search to surface more errors).
        """
        results: list[dict[str, Any]] = []

        for conv_id, content, project, timestamp, _metadata in self._fetch_conversations():
            for pattern in self.error_matcher.extract_error_patterns(content):
                if error_type and pattern["subtype"] != error_type:
                    continue

                # Calculate relevance based on query similarity
                relevance = self._calculate_error_relevance(pattern, query)

                if relevance > 0.2:  # Lower threshold for errors
                    results.append(
                        {
                            "conversation_id": conv_id,
                            "project": project,
                            "timestamp": timestamp,
                            "pattern": pattern,
                            "relevance": relevance,
                            "snippet": self._snippet(content),
                        },
                    )

        # Sort by relevance and limit results
        results.sort(key=lambda x: x["relevance"], reverse=True)
        return results[:limit]

    async def search_temporal(
        self,
        time_expression: str,
        query: str | None = None,
        limit: int = 10,
    ) -> list[dict[str, Any]]:
        """Search conversations within a time range.

        Args:
            time_expression: Natural-language time phrase (e.g. "yesterday").
            query: Optional free text used to filter results by relevance.
            limit: Maximum number of results returned.

        Returns:
            Result dicts ordered newest-first, or a single ``{"error": ...}``
            entry when the time expression cannot be parsed.
        """
        start_time, end_time = self.temporal_parser.parse_time_expression(
            time_expression,
        )

        if not start_time or not end_time:
            return [{"error": f"Could not parse time expression: {time_expression}"}]

        results: list[dict[str, Any]] = []

        if hasattr(self.reflection_db, "conn") and self.reflection_db.conn:
            # Convert to ISO format for database query
            start_iso = start_time.isoformat()
            end_iso = end_time.isoformat()

            sql_query = """
                SELECT id, content, project, timestamp, metadata
                FROM conversations
                WHERE timestamp BETWEEN ? AND ?
                ORDER BY timestamp DESC
            """

            cursor = self.reflection_db.conn.execute(sql_query, (start_iso, end_iso))
            conversations = cursor.fetchall()

            for conv_id, content, project, timestamp, _metadata in conversations:
                # If query provided, filter by content relevance
                if query:
                    relevance = self._calculate_text_relevance(content, query)
                    if relevance < 0.3:
                        continue
                else:
                    relevance = 1.0

                results.append(
                    {
                        "conversation_id": conv_id,
                        "project": project,
                        "timestamp": timestamp,
                        "content": self._snippet(content),
                        "relevance": relevance,
                    },
                )

        return results[:limit]

    def _calculate_code_relevance(self, pattern: dict[str, Any], query: str) -> float:
        """Calculate relevance score (0.0-1.0) for a code pattern vs. query."""
        relevance = 0.0
        query_lower = query.lower()

        # Type matching
        if pattern["type"] in query_lower:
            relevance += 0.5

        # Name matching (for functions/classes)
        if "name" in pattern and pattern["name"].lower() in query_lower:
            relevance += 0.7

        # Content matching
        if query_lower in pattern["content"].lower():
            relevance += 0.4

        # Module/import matching
        if "modules" in pattern:
            for module in pattern["modules"]:
                if module.lower() in query_lower:
                    relevance += 0.3

        return min(relevance, 1.0)

    def _calculate_error_relevance(self, pattern: dict[str, Any], query: str) -> float:
        """Calculate relevance score (0.0-1.0) for an error pattern vs. query."""
        relevance = 0.0
        query_lower = query.lower()

        # Error type matching
        if pattern["subtype"] in query_lower:
            relevance += 0.6

        # Content matching
        if "content" in pattern and query_lower in pattern["content"].lower():
            relevance += 0.5

        # Context relevance boost
        if pattern["type"] == "context" and pattern.get("relevance") == "high":
            relevance += 0.3

        return min(relevance, 1.0)

    def _calculate_text_relevance(self, content: str, query: str) -> float:
        """Fraction of query words present in the content (0.0 when empty)."""
        query_lower = query.lower()
        content_lower = content.lower()

        query_words = query_lower.split()
        # Fix: use a set for O(1) membership instead of O(n) list scans
        # per query word.
        content_words = set(content_lower.split())

        matches = sum(1 for word in query_words if word in content_words)
        return matches / len(query_words) if query_words else 0.0