Coverage for session_buddy / tools / validated_memory_tools.py: 77.55%
291 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-04 00:43 -0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-04 00:43 -0800
1#!/usr/bin/env python3
2"""Example integration of parameter validation models with MCP tools.
4This module demonstrates how to integrate Pydantic parameter validation
5models with existing MCP tools for improved type safety and error handling.
7Following crackerjack patterns:
8- EVERY LINE IS A LIABILITY: Clean, focused tool implementations
9- DRY: Reusable validation across all tools
10- KISS: Simple integration without over-engineering
12Refactored to use utility modules for reduced code duplication.
13"""
15from __future__ import annotations
17# ============================================================================
18# Helper Functions
19# ============================================================================
20from contextlib import suppress
21from datetime import datetime
22from typing import TYPE_CHECKING, Any
24from session_buddy.adapters.reflection_adapter import ReflectionDatabaseAdapter
25from session_buddy.parameter_models import (
26 ConceptSearchParams,
27 FileSearchParams,
28 ReflectionStoreParams,
29 SearchQueryParams,
30 validate_mcp_params,
31)
32from session_buddy.reflection_tools import ReflectionDatabase
33from session_buddy.utils.error_handlers import ValidationError, _get_logger
34from session_buddy.utils.tool_wrapper import execute_database_tool
# Type alias covering both reflection database implementations, kept for
# backward compatibility while code migrates between them.
# NOTE: this is a runtime assignment, not an annotation, so
# ``from __future__ import annotations`` does not defer it; the ``X | Y``
# union of classes is evaluated at import time (requires Python 3.10+).
ReflectionDatabaseType = ReflectionDatabaseAdapter | ReflectionDatabase
async def _get_reflection_database() -> ReflectionDatabaseType:
    """Return the reflection database, treating a missing instance as an error.

    Thin wrapper over ``_get_reflection_database_async`` that narrows its
    optional return type by converting a ``None`` result into an exception.

    Raises:
        ImportError: If no database instance could be obtained.
    """
    database = await _get_reflection_database_async()
    if database is not None:
        return database
    raise ImportError("Reflection tools not available")
def _format_result_item(res: dict[str, Any], index: int) -> list[str]:
    """Render one search hit as display lines.

    Args:
        res: Hit dict. ``content`` is required; ``project``, ``score`` and
            ``timestamp`` are shown only when present.
        index: 1-based position printed before the content preview.

    Returns:
        The preview line followed by any detail lines. The preview line is the
        first 200 characters of ``content``, prefixed by a newline and always
        suffixed with "...".
    """
    preview = res["content"][:200]
    details: list[str] = []

    project = res.get("project")
    if project:
        details.append(f" 📁 Project: {project}")

    score = res.get("score")
    if score is not None:
        details.append(f" ⭐ Relevance: {score:.2f}")

    timestamp = res.get("timestamp")
    if timestamp:
        details.append(f" 📅 Date: {timestamp}")

    return [f"\n{index}. 📝 {preview}...", *details]
def _format_search_results(results: list[dict[str, Any]]) -> list[str]:
    """Render a list of search hits, or a two-line hint when there are none.

    Args:
        results: Hits to display; each is rendered by ``_format_result_item``.

    Returns:
        A count header followed by every hit's lines, or the "no conversations
        found" hint when ``results`` is empty.
    """
    if results:
        header = f"📈 Found {len(results)} relevant conversations:"
        body = [
            line
            for position, item in enumerate(results, start=1)
            for line in _format_result_item(item, position)
        ]
        return [header, *body]

    return [
        "🔍 No conversations found about this file",
        "💡 The file might not have been discussed in previous sessions",
    ]
def _format_concept_results(
    results: list[dict[str, Any]], include_files: bool
) -> list[str]:
    """Render concept-search hits, optionally listing associated files.

    Args:
        results: Hits to display. Each needs ``content`` and may carry
            ``project``, ``score``, ``timestamp`` and ``files``.
        include_files: When true, up to three entries of ``files`` are shown.

    Returns:
        A count header plus each hit's lines, or a two-line hint when
        ``results`` is empty.
    """
    if not results:
        return [
            "🔍 No conversations found about this concept",
            "💡 Try related terms or broader concepts",
        ]

    def render(entry: dict[str, Any], number: int) -> list[str]:
        # Lines for a single hit; the content preview is cut at 250 chars.
        chunk = [f"\n{number}. 📝 {entry['content'][:250]}..."]
        if entry.get("project"):
            chunk.append(f" 📁 Project: {entry['project']}")
        if entry.get("score") is not None:
            chunk.append(f" ⭐ Relevance: {entry['score']:.2f}")
        if entry.get("timestamp"):
            chunk.append(f" 📅 Date: {entry['timestamp']}")
        if include_files and entry.get("files"):
            shown = entry["files"][:3]
            if shown:
                chunk.append(f" 📄 Files: {', '.join(shown)}")
        return chunk

    header = f"📈 Found {len(results)} related conversations:"
    return [header] + [
        line
        for number, entry in enumerate(results, 1)
        for line in render(entry, number)
    ]
103# ============================================================================
104# Validated Tool Implementations
105# ============================================================================
def _validate_reflection_params(**params: Any) -> ReflectionStoreParams | str:
    """Validate raw MCP arguments for storing a reflection.

    Args:
        **params: Keyword arguments exactly as received from the MCP call.

    Returns:
        The validated ``ReflectionStoreParams`` on success. Otherwise a
        "Parameter validation error: ..." string describing the failure.
    """
    from typing import cast

    try:
        outcome = validate_mcp_params(ReflectionStoreParams, **params)
        if outcome.is_valid:
            return cast("ReflectionStoreParams", outcome.params)
        return f"Parameter validation error: {outcome.errors}"
    except ValidationError as exc:
        return f"Parameter validation error: {exc}"
async def _execute_store_reflection(
    params_obj: ReflectionStoreParams, db: Any
) -> dict[str, Any]:
    """Persist a validated reflection and describe the outcome.

    Args:
        params_obj: Validated parameters providing ``content`` and ``tags``.
        db: Object exposing an awaitable
            ``store_reflection(content, tags=...)``.

    Returns:
        A dict with these keys:

        - ``success``: false when the returned id is ``None`` or ``False``.
          An id of ``0`` also counts as failure, because ``0 == False``.
        - ``id``: the returned id.
        - ``content``: the stored content.
        - ``tags``: the tags, with ``[]`` substituted for a falsy value.
        - ``timestamp``: an ISO-8601 local timestamp.
    """
    new_id = await db.store_reflection(
        params_obj.content,
        tags=params_obj.tags or [],
    )
    stored_ok = new_id not in (None, False)

    return {
        "success": stored_ok,
        "id": new_id,
        "content": params_obj.content,
        "tags": params_obj.tags or [],
        "timestamp": datetime.now().isoformat(),
    }
def _format_reflection_result(result: dict[str, Any]) -> str:
    """Build the confirmation message shown after a reflection is stored.

    Also emits an info log entry containing the reflection id, the content
    length and the tag count.

    Args:
        result: Dict with ``id``, ``content``, ``tags`` and ``timestamp`` keys.

    Returns:
        The confirmation lines joined with newlines.
    """
    tags = result["tags"]
    parts = [
        "💾 Reflection stored successfully!",
        f"🆔 ID: {result['id']}",
        f"📝 Content: {result['content'][:100]}...",
    ]
    if tags:
        parts.append(f"🏷️ Tags: {', '.join(tags)}")
    parts.append(f"📅 Stored: {result['timestamp']}")

    tag_total = len(tags) if tags else 0
    context = (
        f"{{'reflection_id': '{result['id']}', "
        f"'content_length': {len(result['content'])}, "
        f"'tags_count': {tag_total}}}"
    )
    _get_logger().info(f"Validated reflection stored | Context: {context}")
    return "\n".join(parts)
async def _store_reflection_validated_impl(**params: Any) -> str:
    """Validate arguments and store a reflection, returning a user message.

    Every failure is reported as a returned string rather than raised:
    missing reflection tooling, invalid parameters, an unavailable database,
    a failed store, or any unexpected exception. Database, import and
    unexpected failures are also logged as errors.

    Args:
        **params: Raw MCP arguments, validated against ``ReflectionStoreParams``.

    Returns:
        The success confirmation or an error message.
    """
    if not _check_reflection_tools_available():
        return "❌ Reflection tools not available. Install with: `uv sync --extra embeddings`\n💡 This enables conversation memory and semantic search capabilities."

    checked = _validate_reflection_params(**params)
    if isinstance(checked, str):
        # Validation failed; ``checked`` already holds the error text.
        return checked

    try:
        database = await _get_reflection_database_async()
        if not database:
            return "❌ Failed to connect to reflection database"

        outcome = await _execute_store_reflection(checked, database)
        if outcome["success"]:
            return _format_reflection_result(outcome)

        failure = f"Failed to store reflection: {outcome['id']}"
        _get_logger().error(failure)
        return failure
    except ValidationError as exc:
        return f"Parameter validation failed: {exc}"
    except ImportError:
        failure = "Failed to connect to reflection database: Import error"
        _get_logger().error(failure)
        return failure
    except Exception as exc:
        failure = f"Failed to store reflection: {exc}"
        _get_logger().error(failure)
        return failure
async def _quick_search_validated_impl(**params: Any) -> str:
    """Validate arguments and show the single best-matching reflection.

    Args:
        **params: Raw MCP arguments, validated against ``SearchQueryParams``.

    Returns:
        Formatted text for the top hit, or a no-results hint. Parameter
        errors, missing reflection tooling, and database or search failures
        are all returned as strings. Import and unexpected errors are also
        logged.
    """
    from typing import cast

    # Parameters are checked before tool availability, so invalid input is
    # reported even when reflection tooling is missing.
    try:
        checked = validate_mcp_params(SearchQueryParams, **params)
        if not checked.is_valid:
            return f"Parameter validation error: {checked.errors}"
        search = cast("SearchQueryParams", checked.params)
    except ValidationError as exc:
        return f"Parameter validation error: {exc}"

    if not _check_reflection_tools_available():
        return "❌ Reflection tools not available. Install with: `uv sync --extra embeddings`\n💡 This enables conversation memory and semantic search capabilities."

    try:
        database = await _get_reflection_database_async()
        if not database:
            return "❌ Failed to connect to reflection database"

        # Only the top hit is displayed, so a single result is requested.
        hits = await database.search_reflections(
            search.query,
            limit=1,
            min_score=search.min_score,
        )
        query = search.query
        hit_count = len(hits)

        output = [f"🔍 Quick search for: '{query}'"]
        if hits:
            output.extend(_format_top_result(hits[0]))
        else:
            output.extend(
                [
                    "🔍 No results found",
                    "💡 Try adjusting your search terms or lowering min_score",
                ]
            )

        _get_logger().info(
            f"Validated quick search executed | Context: {{'query': '{query}', 'results_count': {hit_count}}}"
        )
        return "\n".join(output)
    except ValidationError as exc:
        return f"Parameter validation failed: {exc}"
    except ImportError:
        failure = "Failed to connect to reflection database: Import error"
        _get_logger().error(failure)
        return failure
    except Exception as exc:
        failure = f"Failed to perform quick search: {exc}"
        _get_logger().error(failure)
        return failure
def _format_top_result(top_result: dict[str, Any]) -> list[str]:
    """Render the best search hit as un-indented display lines.

    Args:
        top_result: Hit dict. ``content`` is required; ``project``, ``score``
            and ``timestamp`` are shown only when present.

    Returns:
        The lines, in order:

        - a "showing top 1" banner;
        - a 150-character content preview;
        - any available detail lines.
    """
    lines = [
        "📊 Found results (showing top 1)",
        f"📝 {top_result['content'][:150]}...",
    ]
    if project := top_result.get("project"):
        lines.append(f"📁 Project: {project}")
    if (score := top_result.get("score")) is not None:
        lines.append(f"⭐ Relevance: {score:.2f}")
    if stamp := top_result.get("timestamp"):
        lines.append(f"📅 Date: {stamp}")
    return lines
async def _search_by_file_validated_impl(**params: Any) -> str:
    """Validate arguments and list conversations that mention a file.

    Args:
        **params: Raw MCP arguments, validated against ``FileSearchParams``.

    Returns:
        A header plus formatted hits, or a no-results hint. Parameter errors,
        missing reflection tooling, and database or search failures are all
        returned as strings. Import and unexpected errors are also logged.
    """
    from typing import cast

    # Parameters are checked before tool availability, so invalid input is
    # reported even when reflection tooling is missing.
    try:
        checked = validate_mcp_params(FileSearchParams, **params)
        if not checked.is_valid:
            return f"Parameter validation error: {checked.errors}"
        search = cast("FileSearchParams", checked.params)
    except ValidationError as exc:
        return f"Parameter validation error: {exc}"

    if not _check_reflection_tools_available():
        return "❌ Reflection tools not available. Install with: `uv sync --extra embeddings`\n💡 This enables conversation memory and semantic search capabilities."

    try:
        database = await _get_reflection_database_async()
        if not database:
            return "❌ Failed to connect to reflection database"

        # The file path itself is used as the search query text.
        hits = await database.search_reflections(
            search.file_path,
            limit=search.limit,
            min_score=search.min_score,
        )
        file_path = search.file_path

        output = _format_file_search_header(file_path) + _format_search_results(hits)

        _get_logger().info(
            f"Validated file search executed | Context: {{'file_path': '{file_path}', 'results_count': {len(hits)}}}"
        )
        return "\n".join(output)
    except ValidationError as exc:
        return f"Parameter validation failed: {exc}"
    except ImportError:
        failure = "Failed to connect to reflection database: Import error"
        _get_logger().error(failure)
        return failure
    except Exception as exc:
        failure = f"Failed to perform file search: {exc}"
        _get_logger().error(failure)
        return failure
async def _search_by_concept_validated_impl(**params: Any) -> str:
    """Validate arguments and list conversations related to a concept.

    Args:
        **params: Raw MCP arguments, validated against ``ConceptSearchParams``.

    Returns:
        A header plus formatted hits, or a no-results hint. Parameter errors,
        missing reflection tooling, and database or search failures are all
        returned as strings. Import and unexpected errors are also logged.
    """
    from typing import cast

    # Parameters are checked before tool availability, so invalid input is
    # reported even when reflection tooling is missing.
    try:
        checked = validate_mcp_params(ConceptSearchParams, **params)
        if not checked.is_valid:
            return f"Parameter validation error: {checked.errors}"
        search = cast("ConceptSearchParams", checked.params)
    except ValidationError as exc:
        return f"Parameter validation error: {exc}"

    if not _check_reflection_tools_available():
        return "❌ Reflection tools not available. Install with: `uv sync --extra embeddings`\n💡 This enables conversation memory and semantic search capabilities."

    try:
        database = await _get_reflection_database_async()
        if not database:
            return "❌ Failed to connect to reflection database"

        hits = await database.search_reflections(
            search.concept,
            limit=search.limit,
            min_score=search.min_score,
        )
        concept = search.concept
        show_files = search.include_files

        output = [f"🧠 Searching for concept: '{concept}'", "=" * 50]
        output += _format_concept_results(hits, show_files)

        _get_logger().info(
            f"Validated concept search executed | Context: {{'concept': '{concept}', 'results_count': {len(hits)}}}"
        )
        return "\n".join(output)
    except ValidationError as exc:
        return f"Parameter validation failed: {exc}"
    except ImportError:
        failure = "Failed to connect to reflection database: Import error"
        _get_logger().error(failure)
        return failure
    except Exception as exc:
        failure = f"Failed to perform concept search: {exc}"
        _get_logger().error(failure)
        return failure
def _format_file_search_header(file_path: str) -> list[str]:
    """Return the title line and a 50-character rule for a file search."""
    title = f"📁 Searching conversations about: {file_path}"
    rule = "=" * 50
    return [title, rule]
def _format_file_search_result(res: dict[str, Any], index: int) -> list[str]:
    """Render one file-search hit as display lines.

    Args:
        res: Hit dict. ``content`` is required; ``timestamp``, ``project`` and
            ``score`` are shown only when present.
        index: 1-based position printed before the content preview.

    Returns:
        The 200-character content preview, then date, project and relevance
        lines in that order when available.
    """
    out = [f"{index}. 📝 {res['content'][:200]}..."]
    timestamp, project, score = res.get("timestamp"), res.get("project"), res.get("score")
    if timestamp:
        out.append(f" 📅 Date: {timestamp}")
    if project:
        out.append(f" 📁 Project: {project}")
    if score is not None:
        out.append(f" ⭐ Relevance: {score:.2f}")
    return out
def _format_file_search_results(results: list[dict[str, Any]], query: str) -> list[str]:
    """Render file-search hits with a header, or a not-found notice.

    Args:
        results: Hits to display; each is rendered by
            ``_format_file_search_result``.
        query: File path or query echoed in the header or not-found notice.

    Returns:
        For an empty ``results``, a three-line notice. Otherwise a title, a
        50-character rule, a count line, and then each hit's lines.
    """
    if not results:
        return [
            "No conversations found about this file",
            f"🔍 No conversations found discussing '{query}'",
            "💡 The file might not have been discussed in previous sessions",
        ]

    lines = [
        f"📁 Searching conversations about: {query}",
        "=" * 50,
        f"📈 Found {len(results)} relevant conversations:",
    ]
    # The per-hit formatter always returns a list of lines, so extend directly.
    for i, res in enumerate(results, 1):
        lines.extend(_format_file_search_result(res, i))
    return lines
def _format_validated_concept_result(
    res: dict[str, Any], index: int, include_files: bool = True
) -> list[str]:
    """Render one concept-search hit as display lines.

    Args:
        res: Hit dict. ``content`` is required; ``timestamp``, ``project``,
            ``score`` and ``files`` are optional.
        index: 1-based position printed before the content preview.
        include_files: When true, up to five entries of ``files`` are listed.

    Returns:
        Lines in the order: content preview, date, project, relevance, files.
    """
    rendered = [f"{index}. 🧠 Concept: {res['content'][:200]}..."]

    # Truthiness-gated fields share the same "only when set" rule.
    for key, template in (
        ("timestamp", " 📅 Date: {}"),
        ("project", " 📁 Project: {}"),
    ):
        if res.get(key):
            rendered.append(template.format(res[key]))

    if res.get("score") is not None:
        rendered.append(f" ⭐ Relevance: {res['score']:.2f}")

    if include_files and res.get("files"):
        rendered.append(f" 📄 Files: {', '.join(res['files'][:5])}")

    return rendered
514# Define missing classes for backward compatibility
class ValidationExamples:
    """Placeholder holding static example payloads, kept for backward compatibility."""

    def example_valid_calls(self) -> list[dict[str, Any]]:
        """Return example payloads of valid calls."""
        sample = {"query": "test query", "limit": 5}
        return [sample]

    def example_validation_errors(self) -> list[dict[str, str]]:
        """Return example field/error pairs from failed validation."""
        return [dict(field="query", error="Field required")]
class MigrationGuide:
    """Placeholder exposing short, static migration instructions."""

    @staticmethod
    def before_migration() -> str:
        """Return the instruction to follow before migrating."""
        message = "Before migrating, backup your data."
        return message

    @staticmethod
    def after_migration() -> str:
        """Return the instruction to follow after migrating."""
        message = "After migrating, verify your configurations."
        return message
# Memoised result of _check_reflection_tools_available(); None until the first
# check has run.
_reflection_tools_available: bool | None = None


def _check_reflection_tools_available() -> bool:
    """Report whether ``session_buddy.reflection_tools`` can be located.

    The first call probes with ``importlib.util.find_spec``. Any exception
    raised during the probe counts as "not available". The outcome is cached
    in the module-level ``_reflection_tools_available`` and returned
    unchanged on later calls.

    NOTE(review): ``find_spec`` only locates the module; it does not execute
    it, so missing optional dependencies of that module are not detected here.
    """
    global _reflection_tools_available

    if _reflection_tools_available is None:
        try:
            import importlib.util

            found = importlib.util.find_spec("session_buddy.reflection_tools") is not None
        except Exception:
            found = False
        _reflection_tools_available = found

    return _reflection_tools_available
async def resolve_reflection_database() -> ReflectionDatabaseType | None:
    """Locate a reflection database, preferring the DI container.

    Resolution order:

    1. ``depends.get_sync(ReflectionDatabase)`` from
       ``session_buddy.di.container``, used if it returns a truthy instance.
    2. ``await get_reflection_database()`` from
       ``session_buddy.reflection_tools``.

    Any exception raised during either step is swallowed and the next step
    (or the ``None`` fallback) is used.

    Returns:
        The database instance, or ``None`` if neither step produced one.
    """
    from typing import cast

    try:
        from session_buddy.di.container import depends
        from session_buddy.reflection_tools import ReflectionDatabase

        injected = depends.get_sync(ReflectionDatabase)
        if injected:
            return cast("ReflectionDatabase", injected)
    except Exception:
        pass  # Best effort: fall through to the direct lookup below.

    try:
        from session_buddy.reflection_tools import get_reflection_database

        return await get_reflection_database()
    except Exception:
        return None
async def _get_reflection_database_async() -> ReflectionDatabaseType | None:
    """Return a reflection database instance, raising if none can be obtained.

    The ``| None`` in the return annotation is kept for compatibility only.
    Every failure path raises instead of returning ``None``.

    Raises:
        ImportError: In any of these cases:

            - the reflection tools module cannot be located;
            - ``resolve_reflection_database`` yields ``None``;
            - resolution fails with any other exception, which is chained
              as the cause.
    """
    unavailable = "Reflection tools not available"
    if not _check_reflection_tools_available():
        raise ImportError(unavailable)

    try:
        db = await resolve_reflection_database()
    except ImportError:
        # Already signals unavailability; propagate unchanged.
        raise
    except Exception as exc:
        # Normalise other failures into the ImportError callers handle, while
        # keeping the original error visible as the explicit cause.
        raise ImportError(unavailable) from exc

    if db is None:
        raise ImportError(unavailable)
    return db
608# ============================================================================
609# MCP Tool Registration
610# ============================================================================
def register_validated_memory_tools(mcp_server: Any) -> None:
    """Register all validated memory tools with the MCP server.

    These tools demonstrate parameter validation using Pydantic models
    while using the same utility-based refactoring patterns as other tools.

    Registers four tools through ``mcp_server.tool()``:

    - ``store_reflection_validated``
    - ``quick_search_validated``
    - ``search_by_file_validated``
    - ``search_by_concept_validated``

    Each tool forwards its keyword arguments unchanged to the matching
    ``_*_validated_impl`` coroutine and returns that coroutine's string.

    Args:
        mcp_server: Object providing a ``tool()`` decorator factory
            (presumably a FastMCP server; TODO confirm).
    """
    # NOTE(review): every tool accepts only ``**params``. Servers that build a
    # tool's input schema from its signature may reject or poorly describe
    # ``**kwargs`` tools; confirm the server in use supports this.
    # The inner docstrings are left verbatim because MCP servers may publish
    # them as tool descriptions.

    @mcp_server.tool()  # type: ignore[misc]
    async def store_reflection_validated(**params: Any) -> str:
        """Store a reflection with validated parameters.

        This demonstrates how to integrate Pydantic parameter validation
        with MCP tools for improved type safety.
        """
        return await _store_reflection_validated_impl(**params)

    @mcp_server.tool()  # type: ignore[misc]
    async def quick_search_validated(**params: Any) -> str:
        """Quick search with validated parameters."""
        return await _quick_search_validated_impl(**params)

    @mcp_server.tool()  # type: ignore[misc]
    async def search_by_file_validated(**params: Any) -> str:
        """Search by file with validated parameters."""
        return await _search_by_file_validated_impl(**params)

    @mcp_server.tool()  # type: ignore[misc]
    async def search_by_concept_validated(**params: Any) -> str:
        """Search by concept with validated parameters."""
        return await _search_by_concept_validated_impl(**params)