Coverage for session_buddy / utils / tool_wrapper.py: 33.33%
105 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-04 00:43 -0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-04 00:43 -0800
1#!/usr/bin/env python3
2"""Tool wrapper utilities for MCP tools.
4This module provides high-level wrappers that combine error handling, database resolution,
5and message formatting to eliminate repetitive patterns in tool implementations.
6"""
8from __future__ import annotations
10from typing import TYPE_CHECKING, Any, TypeVar
12from session_buddy.utils.database_helpers import (
13 require_reflection_database,
14)
15from session_buddy.utils.error_handlers import (
16 DatabaseUnavailableError,
17 ValidationError,
18 _get_logger,
19)
20from session_buddy.utils.messages import ToolMessages
22if TYPE_CHECKING:
23 from collections.abc import Awaitable, Callable
25 from session_buddy.adapters.reflection_adapter import ReflectionDatabaseAdapter
27T = TypeVar("T")
async def execute_database_tool(
    operation: Callable[[ReflectionDatabaseAdapter], Awaitable[T]],
    formatter: Callable[[T], str],
    operation_name: str,
    validator: Callable[[], None] | None = None,
) -> str:
    """Run a reflection-database operation and render its result as text.

    Folds the usual tool boilerplate into a single call: optional input
    validation, resolving the reflection database, awaiting the operation,
    formatting its result, and converting any failure into a message string
    rather than letting the exception escape.

    Args:
        operation: Coroutine function that receives the resolved database
            adapter and produces a raw result.
        formatter: Converts the raw result into the final string.
        operation_name: Human-readable name used in error messages and logs.
        validator: Optional zero-argument callable run before anything else;
            it signals bad input by raising ``ValidationError``.

    Returns:
        The formatter's output on success. Otherwise one of
        ``ToolMessages.validation_error`` (``ValidationError``),
        ``ToolMessages.not_available`` (``DatabaseUnavailableError``) or
        ``ToolMessages.operation_failed`` (any other ``Exception``, which is
        also logged with its traceback).

    Example::

        async def search_op(db):
            return await db.search_reflections("test")

        def format_results(results):
            return f"Found {len(results)} results"

        text = await execute_database_tool(
            search_op, format_results, "Search reflections"
        )
    """
    try:
        if validator:
            validator()
        database = await require_reflection_database()
        raw = await operation(database)
        rendered = formatter(raw)
    except ValidationError as exc:
        return ToolMessages.validation_error(operation_name, str(exc))
    except DatabaseUnavailableError as exc:
        return ToolMessages.not_available(operation_name, str(exc))
    except Exception as exc:
        # Tool boundary: record the traceback and report instead of raising.
        _get_logger().exception(f"Error in {operation_name}: {exc}")
        return ToolMessages.operation_failed(operation_name, exc)
    return rendered
async def execute_simple_database_tool(
    operation: Callable[[ReflectionDatabaseAdapter], Awaitable[str]],
    operation_name: str,
) -> str:
    """Simplified wrapper for database tools that return strings.

    Use this when your operation already returns a formatted string; it is
    simpler than ``execute_database_tool`` when no separate formatter is
    needed. Errors are reported the same way as in ``execute_database_tool``.

    Args:
        operation: Async function that takes the database adapter and returns
            the final string result.
        operation_name: Name of the operation, used in error messages and logs.

    Returns:
        The operation's string on success. Otherwise one of:

        - ``ToolMessages.validation_error`` if a ``ValidationError`` is raised;
        - ``ToolMessages.not_available`` for ``DatabaseUnavailableError``;
        - ``ToolMessages.operation_failed`` for any other ``Exception``, which
          is also logged with its traceback.

    Example::

        async def search_op(db):
            results = await db.search_reflections("test")
            return f"Found {len(results)} results"

        text = await execute_simple_database_tool(search_op, "Search")
    """
    try:
        db = await require_reflection_database()
        return await operation(db)
    except ValidationError as e:
        # Match execute_database_tool: input problems raised by the operation
        # are reported as validation errors, not logged as crashes.
        return ToolMessages.validation_error(operation_name, str(e))
    except DatabaseUnavailableError as e:
        return ToolMessages.not_available(operation_name, str(e))
    except Exception as e:
        _get_logger().exception(f"Error in {operation_name}: {e}")
        return ToolMessages.operation_failed(operation_name, e)
async def execute_database_tool_with_dict(
    operation: Callable[[ReflectionDatabaseAdapter], Awaitable[dict[str, Any]]],
    operation_name: str,
    validator: Callable[[], None] | None = None,
) -> dict[str, Any]:
    """Run a database operation and wrap its outcome in a status dictionary.

    Intended for tools that hand structured data back to the caller instead
    of a pre-rendered string.

    Args:
        operation: Coroutine function that receives the database adapter and
            returns a dictionary payload.
        operation_name: Name used in error strings and log output.
        validator: Optional zero-argument callable executed first; raises
            ``ValidationError`` on bad input.

    Returns:
        ``{"success": True, "data": <payload>}`` on success. On failure, it
        returns ``{"success": False, "error": <text>}``, where the text is:

        - ``"<name> validation failed: ..."`` for ``ValidationError``;
        - the bare exception text for ``DatabaseUnavailableError``;
        - ``"<name> failed: ..."`` for any other ``Exception``, which is also
          logged.

    Example::

        async def search_op(db):
            results = await db.search_reflections("test")
            return {"results": results, "count": len(results)}

        outcome = await execute_database_tool_with_dict(search_op, "Search")
        if outcome.get("success"):
            print(outcome["data"]["count"])
    """
    try:
        if validator:
            validator()
        payload = await operation(await require_reflection_database())
    except ValidationError as exc:
        return {
            "success": False,
            "error": f"{operation_name} validation failed: {exc!s}",
        }
    except DatabaseUnavailableError as exc:
        return {"success": False, "error": str(exc)}
    except Exception as exc:
        _get_logger().exception(f"Error in {operation_name}: {exc}")
        return {"success": False, "error": f"{operation_name} failed: {exc!s}"}
    return {"success": True, "data": payload}
async def execute_no_database_tool(
    operation: Callable[..., Awaitable[T]],
    formatter: Callable[[T], str],
    operation_name: str,
    *args: Any,
    **kwargs: Any,
) -> str:
    """Run an operation that needs no database and render its result as text.

    Supplies the same catch-log-report error handling as the database
    wrappers, for operations that do not touch the reflection database.

    Args:
        operation: Coroutine function to await.
        formatter: Converts the operation's result to a string.
        operation_name: Name used in the failure message and log entry.
        *args: Positional arguments forwarded to ``operation``.
        **kwargs: Keyword arguments forwarded to ``operation``.

    Returns:
        The formatted result. If the operation or the formatter raises an
        ``Exception``, it is logged and ``ToolMessages.operation_failed``
        is returned instead.

    Example::

        async def validate_config():
            return {"valid": True, "version": "1.0"}

        def format_config(data):
            return f"Config valid: {data['version']}"

        text = await execute_no_database_tool(
            validate_config, format_config, "Validate configuration"
        )
    """
    try:
        outcome = await operation(*args, **kwargs)
        text = formatter(outcome)
    except Exception as exc:
        _get_logger().exception(f"Error in {operation_name}: {exc}")
        return ToolMessages.operation_failed(operation_name, exc)
    return text
def _validate_required_field(key: str, value: Any) -> None:
    """Apply a ``required_<name>`` rule by passing *value* to ``validate_required``.

    The field name reported to ``validate_required`` is *key* with its
    leading ``required_`` prefix removed.
    """
    from session_buddy.utils.error_handlers import validate_required

    prefix_len = len("required_")
    validate_required(value, key[prefix_len:])
def _validate_type_field(key: str, value: Any) -> None:
    """Apply a ``type_<name>_<type>`` rule produced by ``create_validator``.

    The rule value must be a ``(value, type)`` 2-tuple; any other shape is
    ignored. The expected type is chosen as follows:

    * If the key ends in one of the builtin names ``str``, ``int``, ``float``,
      ``bool``, ``list`` or ``dict``, that type is used. The field name is the
      part between ``type_`` and the suffix, e.g. ``type_max_results_int``
      checks field ``max_results`` against ``int``.
    * Otherwise, if the tuple's second element is a type, that type is used.
      This means rules such as ``type_items_tuple=(items, tuple)`` or
      ``type_limit=(limit, int)`` are enforced instead of being silently
      skipped. The suffix is stripped from the field name only when it equals
      the type's ``__name__``.
    * Otherwise the rule is ignored.

    Raises:
        Whatever ``validate_type`` raises for a mismatch (``create_validator``
        documents this as ``ValidationError``).
    """
    from session_buddy.utils.error_handlers import validate_type

    if not (isinstance(value, tuple) and len(value) == 2):
        return  # malformed rule: ignored, as before

    type_map = {
        "str": str,
        "int": int,
        "float": float,
        "bool": bool,
        "list": list,
        "dict": dict,
    }
    candidate, declared_type = value
    rest = key[len("type_"):]
    head, sep, suffix = rest.rpartition("_")

    if sep and suffix in type_map:
        # Original behaviour: the key suffix names the expected type.
        field_name, expected_type = head, type_map[suffix]
    elif isinstance(declared_type, type):
        # Suffix missing or not a known builtin name: honour the explicit type.
        field_name = head if sep and suffix == declared_type.__name__ else rest
        expected_type = declared_type
    else:
        return

    validate_type(candidate, expected_type, field_name)
def _validate_range_field(key: str, value: Any) -> None:
    """Apply a ``range_<name>`` rule given as a ``(value, min, max)`` 3-tuple.

    Values that are not 3-tuples are ignored. The field name passed to
    ``validate_range`` is *key* with its leading ``range_`` prefix removed.
    """
    from session_buddy.utils.error_handlers import validate_range

    if not isinstance(value, tuple) or len(value) != 3:
        return
    candidate, lower, upper = value
    validate_range(candidate, lower, upper, key[len("range_"):])
def create_validator(**validations: Any) -> Callable[[], None]:
    """Build a zero-argument validator from keyword validation rules.

    The returned callable is meant to be passed as ``validator=`` to
    ``execute_database_tool``. Rules are checked in the order they were given,
    and each key is dispatched on its prefix:

    - ``required_<name>``: value that must be non-empty.
    - ``type_<name>_<type>``: ``(value, type)`` tuple whose type is checked.
    - ``range_<name>``: ``(value, min, max)`` tuple whose range is checked.

    Keys with any other prefix are ignored.

    Args:
        **validations: The rules described above.

    Returns:
        A function that runs every rule and raises ``ValidationError`` on the
        first failure.

    Example::

        validator = create_validator(
            required_query="",
            type_limit_int=(limit, int),
            range_limit=(limit, 1, 100),
        )
        validator()  # raises ValidationError if invalid
    """

    def validator() -> None:
        # Built per call so the helper functions are looked up at call time.
        dispatch = (
            ("required_", _validate_required_field),
            ("type_", _validate_type_field),
            ("range_", _validate_range_field),
        )
        for rule_key, rule_value in validations.items():
            for prefix, check in dispatch:
                if rule_key.startswith(prefix):
                    check(rule_key, rule_value)
                    break

    return validator
def format_reflection_result(
    success: bool,
    content: str,
    tags: list[str] | None = None,
    timestamp: str | None = None,
) -> str:
    """Render the outcome of storing a reflection as a multi-line message.

    Args:
        success: Whether the store operation succeeded.
        content: Stored text; shown truncated to 100 characters via
            ``ToolMessages.truncate_text``.
        tags: Tags applied to the reflection; the tags line is omitted when
            this is empty or ``None``.
        timestamp: Storage time; the time line is omitted when this is empty
            or ``None``.

    Returns:
        A newline-joined confirmation message, or
        ``ToolMessages.operation_failed`` output when ``success`` is false.

    Example::

        format_reflection_result(
            True,
            "Important insight",
            ["learning", "bug-fix"],
            "2025-01-12 14:30:00",
        )
    """
    if not success:
        return ToolMessages.operation_failed("Store reflection", "Operation failed")

    parts = [
        "💾 Reflection stored successfully!",
        f"📝 Content: {ToolMessages.truncate_text(content, 100)}",
    ]
    if tags:
        parts.append(f"🏷️ Tags: {', '.join(tags)}")
    if timestamp:
        parts.append(f"📅 Stored: {timestamp}")
    return "\n".join(parts)
def format_search_results(
    results: list[dict[str, Any]],
    query: str,
    show_details: bool = True,
    max_results: int = 10,
) -> str:
    """Format search results consistently.

    Args:
        results: Search result dictionaries. The ``content``, ``score`` and
            ``timestamp`` keys are rendered when present.
        query: Original search query, echoed in the header line.
        show_details: Whether to list individual results under the header.
        max_results: Maximum number of results to list. Negative values are
            treated as 0.

    Returns:
        The formatted search results. If *results* is empty, returns the
        ``ToolMessages.empty_results`` message instead.

    Example:
        >>> results = [{"content": "test", "score": 0.95}]
        >>> format_search_results(results, "test query")
    """
    if not results:
        return ToolMessages.empty_results(
            f'Search for "{query}"', "Try different search terms"
        )

    count = len(results)
    # Clamp the limit. A negative slice bound would list "all but the last N"
    # results and inflate the "more results" tally below.
    limit = max(max_results, 0)
    lines = [
        f'🔍 Found {ToolMessages.format_count(count, "result")} for "{query}"',
    ]

    if show_details:
        for i, result in enumerate(results[:limit], 1):
            lines.append(
                f"\n{i}. {ToolMessages.truncate_text(result.get('content', ''), 80)}"
            )
            # Skip a missing or None score instead of crashing on ``:.2f``.
            if result.get("score") is not None:
                lines.append(f" Relevance: {result['score']:.2f}")
            if "timestamp" in result:
                lines.append(f" Time: {result['timestamp']}")

    if count > limit:
        lines.append(f"\n... and {count - limit} more results")

    return "\n".join(lines)