Coverage for session_buddy / utils / tool_wrapper.py: 33.33%

105 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-04 00:43 -0800

1#!/usr/bin/env python3 

2"""Tool wrapper utilities for MCP tools. 

3 

4This module provides high-level wrappers that combine error handling, database resolution, 

5and message formatting to eliminate repetitive patterns in tool implementations. 

6""" 

7 

8from __future__ import annotations 

9 

10from typing import TYPE_CHECKING, Any, TypeVar 

11 

12from session_buddy.utils.database_helpers import ( 

13 require_reflection_database, 

14) 

15from session_buddy.utils.error_handlers import ( 

16 DatabaseUnavailableError, 

17 ValidationError, 

18 _get_logger, 

19) 

20from session_buddy.utils.messages import ToolMessages 

21 

22if TYPE_CHECKING: 

23 from collections.abc import Awaitable, Callable 

24 

25 from session_buddy.adapters.reflection_adapter import ReflectionDatabaseAdapter 

26 

# Type of the raw result an operation produces; execute_database_tool and
# execute_no_database_tool pass that result unchanged to the matching formatter.
T = TypeVar("T")

28 

29 

async def execute_database_tool(
    operation: Callable[[ReflectionDatabaseAdapter], Awaitable[T]],
    formatter: Callable[[T], str],
    operation_name: str,
    validator: Callable[[], None] | None = None,
) -> str:
    """Run a database-backed tool operation and render its outcome as text.

    Bundles the steps every reflection-database tool repeats: optional input
    validation, resolving the reflection database, awaiting the operation,
    formatting its result, and turning failures into user-facing messages.

    Args:
        operation: Coroutine function that receives the resolved database
            adapter and produces a raw result.
        formatter: Callable that turns the raw result into a string.
        operation_name: Human-readable name used in error messages and logs.
        validator: Optional zero-argument callable, run before the database
            is resolved; it signals bad input by raising ``ValidationError``.

    Returns:
        The formatter's output on success. Otherwise, the message from:
        ``ToolMessages.validation_error`` for a ``ValidationError``,
        ``ToolMessages.not_available`` for a ``DatabaseUnavailableError``,
        or ``ToolMessages.operation_failed`` for any other exception.

    Example:
        >>> async def search_op(db):
        ...     return await db.search_reflections("test")
        >>> def format_results(results):
        ...     return f"Found {len(results)} results"
        >>> result = await execute_database_tool(
        ...     search_op, format_results, "Search reflections"
        ... )

    """
    try:
        if validator:
            validator()
        database = await require_reflection_database()
        raw_result = await operation(database)
        return formatter(raw_result)
    except ValidationError as exc:
        return ToolMessages.validation_error(operation_name, str(exc))
    except DatabaseUnavailableError as exc:
        return ToolMessages.not_available(operation_name, str(exc))
    except Exception as exc:
        # Tool boundary: log with traceback, then return a message instead of
        # propagating, so the caller always receives a string.
        _get_logger().exception(f"Error in {operation_name}: {exc}")
        return ToolMessages.operation_failed(operation_name, exc)

93 

94 

async def execute_simple_database_tool(
    operation: Callable[[ReflectionDatabaseAdapter], Awaitable[str]],
    operation_name: str,
) -> str:
    """Run a database-backed operation that already produces its own text.

    A lighter variant of ``execute_database_tool`` for operations that
    format their own output, so no separate formatter or validator is taken.

    Args:
        operation: Coroutine function that receives the resolved database
            adapter and returns the final string.
        operation_name: Human-readable name used in error messages and logs.

    Returns:
        The operation's string on success. A ``DatabaseUnavailableError``
        yields ``ToolMessages.not_available``. Any other exception is logged
        and yields ``ToolMessages.operation_failed``; unlike the other
        wrappers, there is no dedicated ``ValidationError`` branch here.

    Example:
        >>> async def search_op(db):
        ...     results = await db.search_reflections("test")
        ...     return f"Found {len(results)} results"
        >>> result = await execute_simple_database_tool(search_op, "Search")

    """
    try:
        database = await require_reflection_database()
        return await operation(database)
    except DatabaseUnavailableError as exc:
        return ToolMessages.not_available(operation_name, str(exc))
    except Exception as exc:
        # Tool boundary: log with traceback and report instead of raising.
        _get_logger().exception(f"Error in {operation_name}: {exc}")
        return ToolMessages.operation_failed(operation_name, exc)

126 

127 

async def execute_database_tool_with_dict(
    operation: Callable[[ReflectionDatabaseAdapter], Awaitable[dict[str, Any]]],
    operation_name: str,
    validator: Callable[[], None] | None = None,
) -> dict[str, Any]:
    """Run a database-backed operation and wrap its outcome in a status dict.

    For tools that return structured data rather than text.

    Args:
        operation: Coroutine function that receives the resolved database
            adapter and returns a dictionary payload.
        operation_name: Human-readable name used in error messages and logs.
        validator: Optional zero-argument callable, run before the database
            is resolved; it signals bad input by raising ``ValidationError``.

    Returns:
        ``{"success": True, "data": <payload>}`` on success. Otherwise
        ``{"success": False, "error": <message>}``, where the message is:
        ``"<operation_name> validation failed: <detail>"`` for a
        ``ValidationError``, the bare exception text for a
        ``DatabaseUnavailableError``, and ``"<operation_name> failed:
        <detail>"`` for anything else (which is also logged).

    Example:
        >>> async def search_op(db):
        ...     results = await db.search_reflections("test")
        ...     return {"results": results, "count": len(results)}
        >>> result = await execute_database_tool_with_dict(search_op, "Search")
        >>> if result.get("success"):
        ...     print(result["data"]["count"])

    """
    try:
        if validator:
            validator()
        database = await require_reflection_database()
        payload = await operation(database)
    except ValidationError as exc:
        return {
            "success": False,
            "error": f"{operation_name} validation failed: {exc!s}",
        }
    except DatabaseUnavailableError as exc:
        return {"success": False, "error": str(exc)}
    except Exception as exc:
        # Tool boundary: log with traceback and report instead of raising.
        _get_logger().exception(f"Error in {operation_name}: {exc}")
        return {"success": False, "error": f"{operation_name} failed: {exc!s}"}
    else:
        return {"success": True, "data": payload}

173 

174 

async def execute_no_database_tool(
    operation: Callable[..., Awaitable[T]],
    formatter: Callable[[T], str],
    operation_name: str,
    *args: Any,
    **kwargs: Any,
) -> str:
    """Run an operation that needs no database and render its result as text.

    Adds the same logging and error-to-message conversion the database
    wrappers provide.

    Args:
        operation: Coroutine function to await.
        formatter: Callable that turns the operation's result into a string.
        operation_name: Human-readable name used in error messages and logs.
        *args: Positional arguments forwarded to ``operation``.
        **kwargs: Keyword arguments forwarded to ``operation``. These share a
            namespace with this wrapper's own parameters, so keywords named
            ``operation``, ``formatter`` or ``operation_name`` cannot be
            forwarded.

    Returns:
        The formatter's output, or ``ToolMessages.operation_failed`` if the
        operation or the formatter raises.

    Example:
        >>> async def validate_config():
        ...     return {"valid": True, "version": "1.0"}
        >>> def format_config(data):
        ...     return f"Config valid: {data['version']}"
        >>> result = await execute_no_database_tool(
        ...     validate_config, format_config, "Validate configuration"
        ... )

    """
    try:
        return formatter(await operation(*args, **kwargs))
    except Exception as exc:
        # Tool boundary: log with traceback and report instead of raising.
        _get_logger().exception(f"Error in {operation_name}: {exc}")
        return ToolMessages.operation_failed(operation_name, exc)

213 

214 

def _validate_required_field(key: str, value: Any) -> None:
    """Apply a ``required_<name>`` rule by delegating to ``validate_required``.

    The field name reported to ``validate_required`` is *key* with its first
    ``len("required_")`` characters removed; callers only route keys that
    start with that prefix here.
    """
    from session_buddy.utils.error_handlers import validate_required

    prefix_length = len("required_")
    validate_required(value, key[prefix_length:])

221 

222 

def _validate_type_field(key: str, value: Any) -> None:
    """Apply a ``type_<name>_<type>`` rule by delegating to ``validate_type``.

    *key* is split on underscores. The last segment names the expected type
    (one of ``str``, ``int``, ``float``, ``bool``, ``list``, ``dict``), and
    the segments between the leading ``type`` and that suffix, re-joined with
    underscores, form the field name. *value* must be a 2-tuple; only its
    first element is checked. The second element is not consulted, because
    the expected type comes solely from the key suffix.

    The rule is silently ignored when the key has fewer than three segments,
    the type suffix is unknown, or *value* is not a 2-tuple.
    """
    from session_buddy.utils.error_handlers import validate_type

    segments = key.split("_")
    if len(segments) < 3:
        return

    expected_type = {
        "str": str,
        "int": int,
        "float": float,
        "bool": bool,
        "list": list,
        "dict": dict,
    }.get(segments[-1])
    if expected_type is None:
        return
    if not isinstance(value, tuple) or len(value) != 2:
        return

    field_name = "_".join(segments[1:-1])
    validate_type(value[0], expected_type, field_name)

245 

246 

def _validate_range_field(key: str, value: Any) -> None:
    """Apply a ``range_<name>`` rule by delegating to ``validate_range``.

    *value* must be a 3-tuple ``(value, min, max)``; anything else is
    silently ignored. The reported field name is *key* without its
    ``"range_"`` prefix.
    """
    from session_buddy.utils.error_handlers import validate_range

    if not (isinstance(value, tuple) and len(value) == 3):
        return

    candidate, lower, upper = value
    validate_range(candidate, lower, upper, key[len("range_"):])

254 

255 

def create_validator(**validations: Any) -> Callable[[], None]:
    """Build a zero-argument validator from keyword validation rules.

    Intended for the ``validator`` parameter of ``execute_database_tool``.
    Rules are not checked here. They are applied each time the returned
    function is called, in keyword order, and the first failing rule raises.

    Args:
        **validations: Validation rules as keyword arguments:
            - ``required_<name>``: value that must be non-empty
            - ``type_<name>_<type>``: ``(value, type)`` tuple to type-check
            - ``range_<name>``: ``(value, min, max)`` tuple to range-check
            Keys matching none of these prefixes are ignored.

    Returns:
        Validator function that raises ValidationError if a rule fails.

    Example:
        >>> validator = create_validator(
        ...     required_query="",
        ...     type_limit_int=(limit, int),
        ...     range_limit=(limit, 1, 100),
        ... )
        >>> validator()  # Raises ValidationError if invalid

    """

    def validator() -> None:
        for rule_key, rule_value in validations.items():
            if rule_key.startswith("required_"):
                _validate_required_field(rule_key, rule_value)
                continue
            if rule_key.startswith("type_"):
                _validate_type_field(rule_key, rule_value)
                continue
            if rule_key.startswith("range_"):
                _validate_range_field(rule_key, rule_value)

    return validator

290 

291 

def format_reflection_result(
    success: bool,
    content: str,
    tags: list[str] | None = None,
    timestamp: str | None = None,
) -> str:
    """Render the outcome of storing a reflection as a multi-line message.

    Args:
        success: Whether the store operation succeeded.
        content: Stored text; shown truncated to 100 characters via
            ``ToolMessages.truncate_text``.
        tags: Tags applied to the reflection; the line is omitted when empty.
        timestamp: When the reflection was stored; the line is omitted when
            falsy.

    Returns:
        Newline-joined confirmation lines, or
        ``ToolMessages.operation_failed`` output when *success* is false.

    Example:
        >>> format_reflection_result(
        ...     True,
        ...     "Important insight",
        ...     ["learning", "bug-fix"],
        ...     "2025-01-12 14:30:00",
        ... )

    """
    if not success:
        return ToolMessages.operation_failed("Store reflection", "Operation failed")

    preview = ToolMessages.truncate_text(content, 100)
    lines = [
        "💾 Reflection stored successfully!",
        f"📝 Content: {preview}",
    ]
    if tags:
        lines.append(f"🏷️ Tags: {', '.join(tags)}")
    if timestamp:
        lines.append(f"📅 Stored: {timestamp}")
    return "\n".join(lines)

331 

332 

def format_search_results(
    results: list[dict[str, Any]],
    query: str,
    show_details: bool = True,
    max_results: int = 10,
) -> str:
    """Format search results consistently.

    Produces a header line with the total hit count. When *show_details* is
    true, it also lists one numbered entry per result, up to *max_results*,
    followed by a trailer counting the results that were not listed.

    Args:
        results: Search result dictionaries. The ``content``, ``score`` and
            ``timestamp`` keys are read when present. Missing or ``None``
            content renders as an empty preview, and a missing or ``None``
            score or timestamp is skipped rather than rendered.
        query: Original search query, echoed in the header.
        show_details: Whether to list individual results.
        max_results: Maximum number of results to list. Negative values are
            treated as 0.

    Returns:
        Formatted search results, or the ``ToolMessages.empty_results``
        message when *results* is empty.

    Example:
        >>> results = [{"content": "test", "score": 0.95}]
        >>> text = format_search_results(results, "test query")

    """
    if not results:
        return ToolMessages.empty_results(
            f'Search for "{query}"', "Try different search terms"
        )

    count = len(results)
    lines = [
        f'🔍 Found {ToolMessages.format_count(count, "result")} for "{query}"',
    ]

    if show_details:
        # Clamp the limit: a negative value would turn the slice into
        # "all but the last N" and inflate the "more results" count below.
        limit = max(max_results, 0)
        for i, result in enumerate(results[:limit], 1):
            # `or ""` also covers a present-but-None content value.
            content = result.get("content") or ""
            lines.append(f"\n{i}. {ToolMessages.truncate_text(content, 80)}")

            # A None score would make the `.2f` format spec raise TypeError.
            score = result.get("score")
            if score is not None:
                lines.append(f" Relevance: {score:.2f}")

            timestamp = result.get("timestamp")
            if timestamp is not None:
                lines.append(f" Time: {timestamp}")

        # The trailer only makes sense relative to the entries listed above.
        if count > limit:
            lines.append(f"\n... and {count - limit} more results")

    return "\n".join(lines)