Coverage for session_buddy / tools / validated_memory_tools.py: 77.55%

291 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-04 00:43 -0800

1#!/usr/bin/env python3 

2"""Example integration of parameter validation models with MCP tools. 

3 

4This module demonstrates how to integrate Pydantic parameter validation 

5models with existing MCP tools for improved type safety and error handling. 

6 

7Following crackerjack patterns: 

8- EVERY LINE IS A LIABILITY: Clean, focused tool implementations 

9- DRY: Reusable validation across all tools 

10- KISS: Simple integration without over-engineering 

11 

12Refactored to use utility modules for reduced code duplication. 

13""" 

14 

15from __future__ import annotations 

16 

17# ============================================================================ 

18# Helper Functions 

19# ============================================================================ 

20from contextlib import suppress 

21from datetime import datetime 

22from typing import TYPE_CHECKING, Any 

23 

24from session_buddy.adapters.reflection_adapter import ReflectionDatabaseAdapter 

25from session_buddy.parameter_models import ( 

26 ConceptSearchParams, 

27 FileSearchParams, 

28 ReflectionStoreParams, 

29 SearchQueryParams, 

30 validate_mcp_params, 

31) 

32from session_buddy.reflection_tools import ReflectionDatabase 

33from session_buddy.utils.error_handlers import ValidationError, _get_logger 

34from session_buddy.utils.tool_wrapper import execute_database_tool 

35 

# Define type alias for backward compatibility during migration
# NOTE: This is a runtime assignment, not an annotation, so
# 'from __future__ import annotations' does not defer it: both classes must be
# the real imported objects, and `X | Y` between classes requires Python 3.10+.
ReflectionDatabaseType = ReflectionDatabaseAdapter | ReflectionDatabase

39 

40 

async def _get_reflection_database() -> ReflectionDatabaseType:
    """Return the database produced by ``_get_reflection_database_async``.

    Raises:
        ImportError: If the async getter yields ``None`` (or raises it itself).

    """
    database = await _get_reflection_database_async()
    if database is not None:
        return database
    raise ImportError("Reflection tools not available")

48 

49 

def _format_result_item(res: dict[str, Any], index: int) -> list[str]:
    """Render one search hit as display lines.

    Args:
        res: Result mapping; ``content`` is required, while ``project``,
            ``score`` and ``timestamp`` are optional.
        index: Position shown in front of the content preview.

    Returns:
        The preview line (content cut to 200 characters) followed by one line
        per metadata field that is present.

    """
    rendered = [f"\n{index}. 📝 {res['content'][:200]}..."]

    project = res.get("project")
    if project:
        rendered.append(f" 📁 Project: {project}")

    # A score of 0 is still a score, so only None is skipped.
    score = res.get("score")
    if score is not None:
        rendered.append(f" ⭐ Relevance: {score:.2f}")

    timestamp = res.get("timestamp")
    if timestamp:
        rendered.append(f" 📅 Date: {timestamp}")

    return rendered

60 

61 

def _format_search_results(results: list[dict[str, Any]]) -> list[str]:
    """Render a list of search hits, or a two-line "nothing found" notice.

    Args:
        results: Result mappings, each formatted by ``_format_result_item``.

    Returns:
        A count header followed by the lines of every numbered result, or the
        fixed no-results message when ``results`` is empty.

    """
    if results:
        output = [f"📈 Found {len(results)} relevant conversations:"]
        for position, entry in enumerate(results, start=1):
            output.extend(_format_result_item(entry, position))
        return output

    return [
        "🔍 No conversations found about this file",
        "💡 The file might not have been discussed in previous sessions",
    ]

74 

75 

def _format_concept_results(
    results: list[dict[str, Any]], include_files: bool
) -> list[str]:
    """Render concept-search hits, optionally listing related files.

    Args:
        results: Result mappings; ``content`` is required, the other keys
            (``project``, ``score``, ``timestamp``, ``files``) are optional.
        include_files: When true, up to three entries of ``files`` are shown.

    Returns:
        A count header plus the lines for each numbered result, or the fixed
        no-results message when ``results`` is empty.

    """
    if not results:
        return [
            "🔍 No conversations found about this concept",
            "💡 Try related terms or broader concepts",
        ]

    def render(position: int, entry: dict[str, Any]) -> list[str]:
        # Preview is cut to 250 characters; the ellipsis is always appended.
        block_lines = [f"\n{position}. 📝 {entry['content'][:250]}..."]
        if entry.get("project"):
            block_lines.append(f" 📁 Project: {entry['project']}")
        if entry.get("score") is not None:
            block_lines.append(f" ⭐ Relevance: {entry['score']:.2f}")
        if entry.get("timestamp"):
            block_lines.append(f" 📅 Date: {entry['timestamp']}")
        if include_files and entry.get("files"):
            shown = entry["files"][:3]
            if shown:
                block_lines.append(f" 📄 Files: {', '.join(shown)}")
        return block_lines

    output = [f"📈 Found {len(results)} related conversations:"]
    for position, entry in enumerate(results, 1):
        output.extend(render(position, entry))
    return output

101 

102 

103# ============================================================================ 

104# Validated Tool Implementations 

105# ============================================================================ 

106 

107 

def _validate_reflection_params(**params: Any) -> ReflectionStoreParams | str:
    """Run the raw MCP arguments through ``ReflectionStoreParams`` validation.

    Args:
        **params: Raw parameters from the MCP call.

    Returns:
        The validated parameter object on success; otherwise a
        "Parameter validation error: ..." message string.

    """
    from typing import cast

    try:
        outcome = validate_mcp_params(ReflectionStoreParams, **params)
        if outcome.is_valid:
            return cast("ReflectionStoreParams", outcome.params)
        return f"Parameter validation error: {outcome.errors}"
    except ValidationError as exc:
        return f"Parameter validation error: {exc}"

127 

128 

async def _execute_store_reflection(
    params_obj: ReflectionStoreParams, db: Any
) -> dict[str, Any]:
    """Store the reflection through ``db`` and describe the outcome.

    Args:
        params_obj: Validated parameters; ``content`` and ``tags`` are read.
        db: Object exposing an awaitable ``store_reflection(content, tags=...)``.

    Returns:
        Dictionary with ``success``, ``id``, ``content``, ``tags`` and an
        ISO-formatted ``timestamp`` taken after the store call returns.

    """
    stored_id = await db.store_reflection(
        params_obj.content,
        tags=params_obj.tags or [],
    )

    # The store call signals failure by returning None or False.
    succeeded = stored_id not in (None, False)
    outcome: dict[str, Any] = {
        "success": succeeded,
        "id": stored_id,
        "content": params_obj.content,
        "tags": params_obj.tags or [],
        "timestamp": datetime.now().isoformat(),
    }
    return outcome

154 

155 

def _format_reflection_result(result: dict[str, Any]) -> str:
    """Build the user-facing confirmation for a stored reflection and log it.

    Args:
        result: Operation result with ``id``, ``content``, ``tags`` and
            ``timestamp`` keys.

    Returns:
        Newline-joined message; the content preview is cut to 100 characters
        and the tags line is included only when tags are present.

    """
    reflection_id = result["id"]
    content = result["content"]
    tags = result["tags"]

    message = [
        "💾 Reflection stored successfully!",
        f"🆔 ID: {reflection_id}",
        f"📝 Content: {content[:100]}...",
    ]
    if tags:
        message.append(f"🏷️ Tags: {', '.join(tags)}")
    message.append(f"📅 Stored: {result['timestamp']}")

    tags_count = len(tags) if tags else 0
    _get_logger().info(
        f"Validated reflection stored | Context: {{'reflection_id': '{reflection_id}', 'content_length': {len(content)}, 'tags_count': {tags_count}}}"
    )
    return "\n".join(message)

179 

180 

async def _store_reflection_validated_impl(**params: Any) -> str:
    """Validate the arguments, store the reflection and return a display message.

    Args:
        **params: Raw parameters from the MCP call.

    Returns:
        The success message, or an error string. Failures are reported as
        text rather than raised.

    """
    # Availability is checked before validation for this tool.
    if not _check_reflection_tools_available():
        return "❌ Reflection tools not available. Install with: `uv sync --extra embeddings`\n💡 This enables conversation memory and semantic search capabilities."

    checked = _validate_reflection_params(**params)
    if isinstance(checked, str):
        # A string here is the validation error message.
        return checked

    try:
        database = await _get_reflection_database_async()
        if not database:
            return "❌ Failed to connect to reflection database"

        outcome = await _execute_store_reflection(checked, database)
        if outcome["success"]:
            return _format_reflection_result(outcome)

        failure = f"Failed to store reflection: {outcome['id']}"
        _get_logger().error(failure)
        return failure

    except ValidationError as exc:
        return f"Parameter validation failed: {exc}"
    except ImportError:
        failure = "Failed to connect to reflection database: Import error"
        _get_logger().error(failure)
        return failure
    except Exception as exc:
        failure = f"Failed to store reflection: {exc}"
        _get_logger().error(failure)
        return failure

218 

219 

async def _quick_search_validated_impl(**params: Any) -> str:
    """Validate the arguments, fetch the single best match and format it.

    Args:
        **params: Raw parameters from the MCP call, validated against
            ``SearchQueryParams``.

    Returns:
        The formatted top result (or a no-results hint), or an error string.
        Failures are reported as text rather than raised.

    """
    from typing import cast

    # Parameters are validated before the availability check for this tool.
    try:
        outcome = validate_mcp_params(SearchQueryParams, **params)
        if not outcome.is_valid:
            return f"Parameter validation error: {outcome.errors}"
        search_params = cast("SearchQueryParams", outcome.params)
    except ValidationError as exc:
        return f"Parameter validation error: {exc}"

    if not _check_reflection_tools_available():
        return "❌ Reflection tools not available. Install with: `uv sync --extra embeddings`\n💡 This enables conversation memory and semantic search capabilities."

    try:
        database = await _get_reflection_database_async()
        if not database:
            return "❌ Failed to connect to reflection database"

        # Quick search always asks the database for one result only.
        matches = await database.search_reflections(
            search_params.query,
            limit=1,
            min_score=search_params.min_score,
        )
        query_text = search_params.query
        match_count = len(matches)

        output = [f"🔍 Quick search for: '{query_text}'"]
        if matches:
            output.extend(_format_top_result(matches[0]))
        else:
            output.extend(
                [
                    "🔍 No results found",
                    "💡 Try adjusting your search terms or lowering min_score",
                ]
            )

        _get_logger().info(
            f"Validated quick search executed | Context: {{'query': '{query_text}', 'results_count': {match_count}}}"
        )
        return "\n".join(output)
    except ValidationError as exc:
        # Validation errors become strings instead of propagating.
        return f"Parameter validation failed: {exc}"
    except ImportError:
        # Raised by the database getter when the tools are unavailable.
        failure = "Failed to connect to reflection database: Import error"
        _get_logger().error(failure)
        return failure
    except Exception as exc:
        failure = f"Failed to perform quick search: {exc}"
        _get_logger().error(failure)
        return failure

290 

291 

def _format_top_result(top_result: dict[str, Any]) -> list[str]:
    """Render the single best match of a quick search.

    Args:
        top_result: Result mapping; ``content`` is required, while
            ``project``, ``score`` and ``timestamp`` are optional.

    Returns:
        A fixed header, the content preview (cut to 150 characters) and one
        line per metadata field that is present.

    """
    rendered = [
        "📊 Found results (showing top 1)",
        f"📝 {top_result['content'][:150]}...",
    ]

    project = top_result.get("project")
    if project:
        rendered.append(f"📁 Project: {project}")

    score = top_result.get("score")
    if score is not None:
        rendered.append(f"⭐ Relevance: {score:.2f}")

    timestamp = top_result.get("timestamp")
    if timestamp:
        rendered.append(f"📅 Date: {timestamp}")

    return rendered

306 

307 

async def _search_by_file_validated_impl(**params: Any) -> str:
    """Validate the arguments, search reflections by file path and format the hits.

    Args:
        **params: Raw parameters from the MCP call, validated against
            ``FileSearchParams``.

    Returns:
        The formatted result listing, or an error string. Failures are
        reported as text rather than raised.

    """
    from typing import cast

    # Parameters are validated before the availability check for this tool.
    try:
        outcome = validate_mcp_params(FileSearchParams, **params)
        if not outcome.is_valid:
            return f"Parameter validation error: {outcome.errors}"
        search_params = cast("FileSearchParams", outcome.params)
    except ValidationError as exc:
        return f"Parameter validation error: {exc}"

    if not _check_reflection_tools_available():
        return "❌ Reflection tools not available. Install with: `uv sync --extra embeddings`\n💡 This enables conversation memory and semantic search capabilities."

    try:
        database = await _get_reflection_database_async()
        if not database:
            return "❌ Failed to connect to reflection database"

        # The file path itself is used as the search text.
        matches = await database.search_reflections(
            search_params.file_path,
            limit=search_params.limit,
            min_score=search_params.min_score,
        )
        target = search_params.file_path

        output = [f"📁 Searching conversations about: {target}", "=" * 50]
        output.extend(_format_search_results(matches))

        _get_logger().info(
            f"Validated file search executed | Context: {{'file_path': '{target}', 'results_count': {len(matches)}}}"
        )
        return "\n".join(output)
    except ValidationError as exc:
        # Validation errors become strings instead of propagating.
        return f"Parameter validation failed: {exc}"
    except ImportError:
        # Raised by the database getter when the tools are unavailable.
        failure = "Failed to connect to reflection database: Import error"
        _get_logger().error(failure)
        return failure
    except Exception as exc:
        failure = f"Failed to perform file search: {exc}"
        _get_logger().error(failure)
        return failure

371 

372 

async def _search_by_concept_validated_impl(**params: Any) -> str:
    """Validate the arguments, search reflections by concept and format the hits.

    Args:
        **params: Raw parameters from the MCP call, validated against
            ``ConceptSearchParams``.

    Returns:
        The formatted result listing, or an error string. Failures are
        reported as text rather than raised.

    """
    from typing import cast

    # Parameters are validated before the availability check for this tool.
    try:
        outcome = validate_mcp_params(ConceptSearchParams, **params)
        if not outcome.is_valid:
            return f"Parameter validation error: {outcome.errors}"
        search_params = cast("ConceptSearchParams", outcome.params)
    except ValidationError as exc:
        return f"Parameter validation error: {exc}"

    if not _check_reflection_tools_available():
        return "❌ Reflection tools not available. Install with: `uv sync --extra embeddings`\n💡 This enables conversation memory and semantic search capabilities."

    try:
        database = await _get_reflection_database_async()
        if not database:
            return "❌ Failed to connect to reflection database"

        matches = await database.search_reflections(
            search_params.concept,
            limit=search_params.limit,
            min_score=search_params.min_score,
        )
        concept = search_params.concept
        with_files = search_params.include_files

        output = [f"🧠 Searching for concept: '{concept}'", "=" * 50]
        output.extend(_format_concept_results(matches, with_files))

        _get_logger().info(
            f"Validated concept search executed | Context: {{'concept': '{concept}', 'results_count': {len(matches)}}}"
        )
        return "\n".join(output)
    except ValidationError as exc:
        # Validation errors become strings instead of propagating.
        return f"Parameter validation failed: {exc}"
    except ImportError:
        # Raised by the database getter when the tools are unavailable.
        failure = "Failed to connect to reflection database: Import error"
        _get_logger().error(failure)
        return failure
    except Exception as exc:
        failure = f"Failed to perform concept search: {exc}"
        _get_logger().error(failure)
        return failure

437 

438 

def _format_file_search_header(file_path: str) -> list[str]:
    """Return the two header lines for a file search: a title and a 50-character rule."""
    title = f"📁 Searching conversations about: {file_path}"
    rule = "=" * 50
    return [title, rule]

445 

446 

def _format_file_search_result(res: dict[str, Any], index: int) -> list[str]:
    """Render one file-search hit as display lines.

    Args:
        res: Result mapping; ``content`` is required, while ``timestamp``,
            ``project`` and ``score`` are optional.
        index: Position shown in front of the content preview.

    Returns:
        The preview line (content cut to 200 characters) followed by the
        date, project and relevance lines, in that order, for each field
        that is present.

    """
    rendered = [f"{index}. 📝 {res['content'][:200]}..."]

    timestamp = res.get("timestamp")
    if timestamp:
        rendered.append(f" 📅 Date: {timestamp}")

    project = res.get("project")
    if project:
        rendered.append(f" 📁 Project: {project}")

    score = res.get("score")
    if score is not None:
        rendered.append(f" ⭐ Relevance: {score:.2f}")

    return rendered

463 

464 

def _format_file_search_results(results: list[dict[str, Any]], query: str) -> list[str]:
    """Format all file search results, including the header.

    Args:
        results: Result mappings, each rendered by ``_format_file_search_result``.
        query: The file path that was searched for; echoed in the output.

    Returns:
        For an empty ``results``, a three-line "nothing found" notice.
        Otherwise a title, a 50-character rule, a count line and the lines of
        every numbered result.

    """
    if not results:
        return [
            "No conversations found about this file",
            f"🔍 No conversations found discussing '{query}'",
            "💡 The file might not have been discussed in previous sessions",
        ]

    lines = [
        f"📁 Searching conversations about: {query}",
        "=" * 50,
        f"📈 Found {len(results)} relevant conversations:",
    ]

    # _format_file_search_result always returns a list, so its lines are
    # extended directly; the former non-list fallback could never run.
    for i, res in enumerate(results, 1):
        lines.extend(_format_file_search_result(res, i))

    return lines

488 

489 

def _format_validated_concept_result(
    res: dict[str, Any], index: int, include_files: bool = True
) -> list[str]:
    """Render one concept-search hit as display lines.

    Args:
        res: Result mapping; ``content`` is required, while ``timestamp``,
            ``project``, ``score`` and ``files`` are optional.
        index: Position shown in front of the content preview.
        include_files: When true, up to five entries of ``files`` are listed.

    Returns:
        The preview line (content cut to 200 characters) followed by one line
        per metadata field that is present.

    """
    rendered = [f"{index}. 🧠 Concept: {res['content'][:200]}..."]

    timestamp = res.get("timestamp")
    if timestamp:
        rendered.append(f" 📅 Date: {timestamp}")

    project = res.get("project")
    if project:
        rendered.append(f" 📁 Project: {project}")

    score = res.get("score")
    if score is not None:
        rendered.append(f" ⭐ Relevance: {score:.2f}")

    if include_files and res.get("files"):
        shown = res["files"][:5]  # Cap the listing at five files
        rendered.append(f" 📄 Files: {', '.join(shown)}")

    return rendered

512 

513 

514# Define missing classes for backward compatibility 

class ValidationExamples:
    """Placeholder kept for backward compatibility; returns fixed sample data."""

    def example_valid_calls(self) -> list[dict[str, Any]]:
        """Return a one-element list holding a sample valid call."""
        sample_call: dict[str, Any] = {"query": "test query", "limit": 5}
        return [sample_call]

    def example_validation_errors(self) -> list[dict[str, str]]:
        """Return a one-element list holding a sample validation error."""
        sample_error: dict[str, str] = {"field": "query", "error": "Field required"}
        return [sample_error]

525 

526 

class MigrationGuide:
    """Placeholder kept for backward compatibility; returns fixed guidance text."""

    @staticmethod
    def before_migration() -> str:
        """Return the pre-migration advice."""
        advice = "Before migrating, backup your data."
        return advice

    @staticmethod
    def after_migration() -> str:
        """Return the post-migration advice."""
        advice = "After migrating, verify your configurations."
        return advice

539 

540 

541# Global variable to cache reflection tools availability 

_reflection_tools_available: bool | None = None


def _check_reflection_tools_available() -> bool:
    """Report whether ``session_buddy.reflection_tools`` can be located.

    The answer is computed once and stored in the module-level
    ``_reflection_tools_available``; later calls return the stored value.

    Returns:
        True when an import spec for the module is found; False when it is
        not found or the lookup raises.

    """
    global _reflection_tools_available

    if _reflection_tools_available is None:
        try:
            # Only locate the module here; nothing is imported from it.
            import importlib.util

            found = importlib.util.find_spec("session_buddy.reflection_tools") is not None
        except Exception:
            found = False
        _reflection_tools_available = found

    return _reflection_tools_available

563 

564 

async def resolve_reflection_database() -> ReflectionDatabaseType | None:
    """Find a reflection database: DI container first, then a direct instance.

    Both attempts are best-effort; any ``Exception`` raised during an attempt
    is suppressed and the next option is tried.

    Returns:
        The database from whichever attempt succeeds first, or ``None`` when
        neither does.

    """
    # Attempt 1: ask the dependency-injection container.
    with suppress(Exception):
        from typing import cast

        from session_buddy.di.container import depends
        from session_buddy.reflection_tools import ReflectionDatabase

        injected = depends.get_sync(ReflectionDatabase)
        if injected:
            return cast("ReflectionDatabase", injected)

    # Attempt 2: obtain an instance directly from the reflection tools module.
    with suppress(Exception):
        from session_buddy.reflection_tools import get_reflection_database

        direct = await get_reflection_database()
        return direct

    return None

585 

586 

async def _get_reflection_database_async() -> ReflectionDatabaseType | None:
    """Get the reflection database instance, resolving it on demand.

    Returns:
        The database returned by ``resolve_reflection_database``. ``None`` is
        never returned in practice; that case is converted to ``ImportError``.

    Raises:
        ImportError: If the reflection tools are reported unavailable, if no
            database could be resolved, or if resolution raised any other
            exception (which is attached as the cause).

    """
    unavailable = "Reflection tools not available"

    if not _check_reflection_tools_available():
        raise ImportError(unavailable)

    try:
        db = await resolve_reflection_database()
    except ImportError:
        # Import errors already mean "unavailable"; let them through untouched.
        raise
    except Exception as exc:
        # Any other failure is reported as unavailability, keeping the
        # original error reachable through __cause__.
        raise ImportError(unavailable) from exc

    if db is None:
        raise ImportError(unavailable)
    return db

606 

607 

608# ============================================================================ 

609# MCP Tool Registration 

610# ============================================================================ 

611 

612 

def register_validated_memory_tools(mcp_server: Any) -> None:
    """Register all validated memory tools with the MCP server.

    These tools demonstrate parameter validation using Pydantic models
    while using the same utility-based refactoring patterns as other tools.

    Args:
        mcp_server: Object whose ``tool()`` method returns a decorator; it is
            applied to each of the four tool coroutines defined below.
    """

    # Each tool is a thin wrapper that forwards its keyword arguments to the
    # matching module-level ``_*_validated_impl`` coroutine.

    @mcp_server.tool()  # type: ignore[misc]
    async def store_reflection_validated(**params: Any) -> str:
        """Store a reflection with validated parameters.

        This demonstrates how to integrate Pydantic parameter validation
        with MCP tools for improved type safety.
        """
        return await _store_reflection_validated_impl(**params)

    @mcp_server.tool()  # type: ignore[misc]
    async def quick_search_validated(**params: Any) -> str:
        """Quick search with validated parameters."""
        return await _quick_search_validated_impl(**params)

    @mcp_server.tool()  # type: ignore[misc]
    async def search_by_file_validated(**params: Any) -> str:
        """Search by file with validated parameters."""
        return await _search_by_file_validated_impl(**params)

    @mcp_server.tool()  # type: ignore[misc]
    async def search_by_concept_validated(**params: Any) -> str:
        """Search by concept with validated parameters."""
        return await _search_by_concept_validated_impl(**params)