Coverage for session_buddy / parameter_models.py: 73.33%

315 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-04 00:43 -0800

1#!/usr/bin/env python3 

2"""Pydantic parameter validation models for MCP tools. 

3 

4This module provides reusable parameter validation models that can be integrated 

5with FastMCP @mcp.tool() decorators to ensure type safety and consistent 

6validation across all MCP tools. 

7 

8Following crackerjack patterns: 

9- EVERY LINE IS A LIABILITY: Focused, single-responsibility models 

10- DRY: Reusable validation patterns across tools 

11- KISS: Simple, clear validation without over-engineering 

12""" 

13 

14from __future__ import annotations 

15 

16import os 

17from pathlib import Path 

18from typing import TYPE_CHECKING, Any, Literal, NamedTuple 

19 

20from pydantic import BaseModel, Field, field_validator 

21 

22if TYPE_CHECKING: 

23 from collections.abc import Callable 

24 

25 from pydantic import ValidationError 

26 

27 

class ValidationResponse(NamedTuple):
    """Outcome of validating tool parameters against a Pydantic model.

    On success ``is_valid`` is True and ``params`` holds the validated model;
    on failure ``is_valid`` is False and ``errors`` holds a readable message.
    """

    # Whether validation succeeded.
    is_valid: bool
    # The validated model instance (None when validation failed).
    params: BaseModel | None = None
    # Human-readable error summary (None when validation succeeded).
    errors: str | None = None

34 

35 

36# Helper functions for common validation patterns 

def validate_non_empty_string(v: Any, field_name: str) -> str:
    """Return *v* with surrounding whitespace removed, insisting it is non-blank.

    Args:
        v: Candidate value; must be a ``str``.
        field_name: Field name used in error messages.

    Returns:
        The stripped string.

    Raises:
        TypeError: if *v* is not a string.
        ValueError: if *v* is empty or whitespace-only.
    """
    if isinstance(v, str):
        if stripped := v.strip():
            return stripped
        msg = f"{field_name} cannot be empty"
        raise ValueError(msg)
    msg = f"{field_name} must be a string"
    raise TypeError(msg)

47 

48 

def validate_and_expand_path(v: Any, field_name: str) -> str:
    """Validate a path-like string and expand it according to its field name.

    Only fields whose name ends in ``_path`` or ``_directory`` are transformed;
    for any other field name *v* is returned unchanged. For path fields the
    value is stripped and ``~`` is expanded. ``_directory`` fields are also made
    absolute by resolving them against the current working directory. A blank
    (empty or whitespace-only) value is returned unchanged.

    Args:
        v: Candidate value; must be a ``str``.
        field_name: Name of the field being validated; its suffix selects the
            transformation.

    Returns:
        The expanded path, or *v* unchanged when no transformation applies.

    Raises:
        TypeError: if *v* is not a string.
    """
    if not isinstance(v, str):
        msg = f"{field_name} must be a string"
        raise TypeError(msg)
    if not field_name.endswith(("_path", "_directory")):
        return v
    stripped = v.strip()
    if not stripped:
        # Nothing to expand. Never resolve blank input: Path("   ").resolve()
        # would yield "<cwd>/   ", a bogus directory path.
        return v
    expanded = os.path.expanduser(stripped)
    if field_name.endswith("_directory") and not Path(expanded).is_absolute():
        # Directory fields are normalized to absolute paths.
        expanded = str(Path(expanded).resolve())
    return expanded

65 

66 

67# Core parameter models for common patterns 

class WorkingDirectoryParams(BaseModel):
    """Optional working-directory override shared by session-oriented tools."""

    working_directory: str | None = Field(
        None,
        examples=[".", "/Users/username/project", "~/Projects/my-app"],
        description="Optional working directory override (defaults to PWD environment variable or current directory)",
    )

    @field_validator("working_directory")
    @classmethod
    def validate_working_directory(cls, v: str | None) -> str | None:
        """Strip and ``~``-expand the path, requiring an existing directory.

        ``None`` and whitespace-only input both yield ``None``.

        Raises:
            ValueError: if the expanded path does not exist or is not a directory.
        """
        if v is None:
            return None
        candidate = v.strip()
        if not candidate:
            return None
        expanded = os.path.expanduser(candidate)
        target = Path(expanded)
        if not target.exists():
            msg = f"Working directory does not exist: {expanded}"
            raise ValueError(msg)
        if not target.is_dir():
            msg = f"Working directory is not a directory: {expanded}"
            raise ValueError(msg)
        return expanded

95 

96 

class ProjectContextParams(BaseModel):
    """Optional project identifier used to scope an operation to one project."""

    project: str | None = Field(
        None,
        min_length=1,
        max_length=200,
        description="Optional project identifier for scoped operations",
        examples=["my-app", "session-mgmt-mcp", "microservice-auth"],
    )

    @field_validator("project")
    @classmethod
    def validate_project(cls, v: str | None) -> str | None:
        """Trim the identifier; a whitespace-only value collapses to ``None``."""
        if v is None:
            return None
        return v.strip() or None

117 

118 

class SearchLimitParams(BaseModel):
    """Pagination controls for search-style tools: page size and offset."""

    # Page size, bounded to 1..1000.
    limit: int = Field(
        10,
        description="Maximum number of results to return",
        ge=1,
        le=1000,
    )

    # Number of leading results to skip; never negative.
    offset: int = Field(
        0,
        description="Number of results to skip for pagination",
        ge=0,
    )

134 

135 

class TimeRangeParams(BaseModel):
    """Look-back window, in days, for time-filtered queries."""

    # Bounded to 1..3650 days (roughly ten years at most).
    days: int = Field(
        7,
        description="Number of days to look back",
        ge=1,
        le=3650,
    )

145 

146 

class ScoreThresholdParams(BaseModel):
    """Minimum relevance score below which results are discarded."""

    # Inclusive bounds keep the threshold within the 0.0..1.0 score scale.
    min_score: float = Field(
        0.7,
        description="Minimum relevance score threshold (0.0-1.0)",
        ge=0.0,
        le=1.0,
    )

156 

157 

class TagParams(BaseModel):
    """Validated, normalized list of categorization tags.

    Each tag is stripped and lower-cased, and blank tags are dropped. Every
    remaining tag must be at most 50 characters and consist of alphanumerics,
    hyphens, and underscores. A list that normalizes to nothing becomes ``None``.
    """

    tags: list[str] | None = Field(
        None,
        examples=[["python", "async"], ["bug", "critical"], ["feature", "ui"]],
        description="Optional list of tags for categorization",
    )

    @field_validator("tags")
    @classmethod
    def validate_tags(cls, v: list[str] | None) -> list[str] | None:
        """Normalize every tag, skip blanks, and return ``None`` if none remain.

        Raises:
            TypeError: if *v* is not a list or contains a non-string.
            ValueError: if a tag is too long or has disallowed characters.
        """
        if v is None:
            return None
        cls._validate_tags_type(v)
        cleaned: list[str] = []
        for raw_tag in v:
            tag = cls._process_single_tag(raw_tag)
            if tag is not None:
                cleaned.append(tag)
        return cleaned if cleaned else None

    @classmethod
    def _validate_tags_type(cls, tags: Any) -> None:
        """Reject anything that is not a ``list``."""
        if isinstance(tags, list):
            return
        msg = "Tags must be a list of strings"
        raise TypeError(msg)

    @classmethod
    def _process_single_tag(cls, tag: Any) -> str | None:
        """Return the normalized tag, or ``None`` for a blank tag to be skipped."""
        if not isinstance(tag, str):
            msg = "Each tag must be a string"
            raise TypeError(msg)
        candidate = tag.strip().lower()
        if candidate:
            cls._validate_tag_length(candidate)
            cls._validate_tag_format(candidate)
            return candidate
        return None

    @classmethod
    def _validate_tag_length(cls, tag: str) -> None:
        """Enforce the 50-character maximum."""
        if len(tag) <= 50:
            return
        msg = f"Tag too long (max 50 chars): {tag}"
        raise ValueError(msg)

    @classmethod
    def _validate_tag_format(cls, tag: str) -> None:
        """Allow only alphanumerics plus ``-`` and ``_``."""
        core = tag
        for separator in ("-", "_"):
            core = core.replace(separator, "")
        # An all-separator tag leaves "" behind, which isalnum() rejects.
        if core.isalnum():
            return
        msg = f"Tags must contain only letters, numbers, hyphens, and underscores: {tag}"
        raise ValueError(msg)

219 

220 

class IDParams(BaseModel):
    """Generic identifier parameter shared by several entity types."""

    id: str = Field(
        description="Unique identifier",
        examples=["abc123", "session_20250106", "reflection-456"],
        min_length=1,
        max_length=100,
    )

    @field_validator("id")
    @classmethod
    def validate_id_format(cls, v: str) -> str:
        """Trim the ID and allow only alphanumerics plus ``-``, ``_`` and ``.``.

        Raises:
            ValueError: if the trimmed ID is empty or contains other characters.
        """
        candidate = v.strip()
        if not candidate:
            msg = "ID cannot be empty"
            raise ValueError(msg)
        # Remove the permitted separators; what is left must be purely alphanumeric.
        core = candidate
        for separator in ("-", "_", "."):
            core = core.replace(separator, "")
        if core.isalnum():
            return candidate
        msg = "ID must contain only letters, numbers, hyphens, underscores, and dots"
        raise ValueError(msg)

246 

247 

class FilePathParams(BaseModel):
    """A single file path argument; the file is not required to exist."""

    file_path: str = Field(
        description="Path to a file",
        examples=["README.md", "src/main.py", "/absolute/path/file.txt"],
        min_length=1,
    )

    @field_validator("file_path")
    @classmethod
    def validate_file_path(cls, v: str) -> str:
        """Trim the path and reject blank paths or embedded NUL characters.

        Existence is deliberately not checked, since the file may be created later.

        Raises:
            ValueError: if the path is blank or contains ``"\\x00"``.
        """
        path = v.strip()
        if not path:
            msg = "File path cannot be empty"
            raise ValueError(msg)
        if "\x00" in path:
            msg = "File path cannot contain null characters"
            raise ValueError(msg)
        return path

272 

273 

class CommandExecutionParams(BaseModel):
    """Command name, argument string, and timeout for running a command."""

    command: str = Field(
        description="Command to execute",
        examples=["lint", "test", "analyze"],
        min_length=1,
        max_length=1000,
    )

    args: str = Field(
        "",
        description="Command arguments as space-separated string",
        max_length=2000,
    )

    # Timeout bounded to 1..3600 seconds.
    timeout: int = Field(
        300,
        description="Command timeout in seconds",
        ge=1,
        le=3600,
    )

    @field_validator("command")
    @classmethod
    def validate_command(cls, v: str) -> str:
        """Trim the command and reject a blank result.

        Raises:
            ValueError: if the command is whitespace-only.
        """
        trimmed = v.strip()
        if trimmed:
            return trimmed
        msg = "Command cannot be empty"
        raise ValueError(msg)

306 

307 

class BooleanFlagParams(BaseModel):
    """Reusable switches (force, verbose, dry-run), all off by default."""

    force: bool = Field(False, description="Force operation, bypassing safety checks")

    verbose: bool = Field(False, description="Enable verbose output")

    dry_run: bool = Field(
        False,
        description="Show what would be done without executing",
    )

322 

323 

324# Specific MCP tool parameter models 

class SessionInitParams(WorkingDirectoryParams):
    """Parameters accepted by session initialization.

    Adds no fields; only the inherited optional ``working_directory`` applies.
    """

329 

330 

class SessionStatusParams(WorkingDirectoryParams):
    """Parameters accepted by the session status check.

    Adds no fields; only the inherited optional ``working_directory`` applies.
    """

335 

336 

class ReflectionStoreParams(BaseModel):
    """Reflection text plus optional tags for storing a reflection."""

    content: str = Field(
        description="Content to store as reflection",
        examples=["Learned that async/await patterns improve database performance"],
        min_length=1,
        max_length=50000,
    )

    tags: list[str] | None = Field(
        None,
        description="Optional tags for categorization",
    )

    @field_validator("content")
    @classmethod
    def validate_content(cls, v: str) -> str:
        """Trim the content and reject a blank result.

        Raises:
            ValueError: if the content is whitespace-only.
        """
        text = v.strip()
        if text:
            return text
        msg = "Content cannot be empty"
        raise ValueError(msg)

    @field_validator("tags")
    @classmethod
    def validate_tags(cls, v: list[str] | None) -> list[str] | None:
        """Normalize tags with exactly the same rules as ``TagParams``.

        Delegates by constructing a throwaway ``TagParams``; any error it raises
        propagates out of this validator.
        """
        normalized = TagParams(tags=v)
        return normalized.tags

367 

368 

class SearchQueryParams(ProjectContextParams, SearchLimitParams, ScoreThresholdParams):
    """Free-text query combined with project scope, pagination, and a score cutoff."""

    query: str = Field(
        description="Search query text",
        examples=["python async patterns", "database migration", "error handling"],
        min_length=1,
        max_length=1000,
    )

    @field_validator("query")
    @classmethod
    def validate_query(cls, v: str) -> str:
        """Trim the query and reject a blank result.

        Raises:
            ValueError: if the query is whitespace-only.
        """
        cleaned = v.strip()
        if cleaned:
            return cleaned
        msg = "Query cannot be empty"
        raise ValueError(msg)

388 

389 

class FileSearchParams(SearchLimitParams, ProjectContextParams, ScoreThresholdParams):
    """Search for conversations mentioning a file, with paging, scope, and score cutoff."""

    file_path: str = Field(
        description="File path to search for in conversations",
        examples=["src/main.py", "README.md", "config/database.yml"],
        min_length=1,
    )

    @field_validator("file_path")
    @classmethod
    def validate_file_path(cls, v: str) -> str:
        """Trim the path and reject a blank result.

        Raises:
            ValueError: if the path is whitespace-only.
        """
        cleaned = v.strip()
        if cleaned:
            return cleaned
        msg = "File path cannot be empty"
        raise ValueError(msg)

408 

409 

class ConceptSearchParams(
    SearchLimitParams, ProjectContextParams, ScoreThresholdParams
):
    """Search by development concept, optionally including related files."""

    concept: str = Field(
        description="Development concept to search for",
        examples=["authentication", "caching", "error handling", "async patterns"],
        min_length=1,
        max_length=200,
    )

    include_files: bool = Field(
        True,
        description="Include related files in search results",
    )

    @field_validator("concept")
    @classmethod
    def validate_concept(cls, v: str) -> str:
        """Trim the concept and reject a blank result.

        Raises:
            ValueError: if the concept is whitespace-only.
        """
        cleaned = v.strip()
        if cleaned:
            return cleaned
        msg = "Concept cannot be empty"
        raise ValueError(msg)

436 

437 

class CrackerjackExecutionParams(CommandExecutionParams, WorkingDirectoryParams):
    """Crackerjack run: command/args/timeout and working directory, plus AI-agent toggle."""

    ai_agent_mode: bool = Field(
        False,
        description="Enable AI agent mode for autonomous fixing",
    )

445 

446 

class CrackerjackHistoryParams(TimeRangeParams, WorkingDirectoryParams):
    """Query past crackerjack runs by look-back window, directory, and command name."""

    # Empty string (the default) presumably means "no filter"; confirm in caller.
    command_filter: str = Field(
        "",
        description="Filter commands by name",
        max_length=100,
    )

455 

456 

class TeamUserParams(BaseModel):
    """Identity, display name, role, and optional email of a team member."""

    user_id: str = Field(
        description="Unique user identifier",
        max_length=100,
        min_length=1,
    )

    username: str = Field(description="Display username", max_length=100, min_length=1)

    role: Literal["owner", "admin", "moderator", "contributor", "viewer"] = Field(
        "contributor",
        description="User role in the team",
    )

    email: str | None = Field(None, description="Optional email address")

    @field_validator("user_id", "username")
    @classmethod
    def validate_required_strings(cls, v: str) -> str:
        """Trim a required field and reject a blank result.

        Raises:
            ValueError: if the value is whitespace-only.
        """
        trimmed = v.strip()
        if trimmed:
            return trimmed
        msg = "Field cannot be empty"
        raise ValueError(msg)

    @field_validator("email")
    @classmethod
    def validate_email(cls, v: str | None) -> str | None:
        """Apply a lightweight structural email check.

        ``None`` and whitespace-only input yield ``None``. Otherwise the trimmed
        address must be at most 254 characters and contain exactly one ``@``
        with a non-empty local part. The domain must contain a dot and must not
        start or end with one. This is a sanity check, not full RFC parsing.

        Raises:
            ValueError: if the address is too long or structurally invalid.
        """
        if v is None:
            return None
        address = v.strip()
        if not address:
            return None

        if len(address) > 254:  # RFC 5321 limit
            msg = "Email address too long"
            raise ValueError(msg)

        parts = address.split("@")
        # Exactly one "@" means exactly two parts after splitting.
        well_formed = (
            len(parts) == 2
            and bool(parts[0])
            and "." in parts[1]
            and not parts[1].startswith(".")
            and not parts[1].endswith(".")
        )
        if not well_formed:
            msg = "Invalid email format"
            raise ValueError(msg)
        return address

524 

525 

class TeamCreationParams(BaseModel):
    """Identifier, name, description, and owner needed to create a team."""

    team_id: str = Field(
        description="Unique team identifier",
        max_length=100,
        min_length=1,
    )

    name: str = Field(description="Team display name", max_length=200, min_length=1)

    description: str = Field(
        description="Team description",
        max_length=1000,
        min_length=1,
    )

    owner_id: str = Field(
        description="User ID of the team owner",
        max_length=100,
        min_length=1,
    )

    @field_validator("team_id", "name", "description", "owner_id")
    @classmethod
    def validate_required_strings(cls, v: str) -> str:
        """Trim each required field and reject a blank result.

        Raises:
            ValueError: if the value is whitespace-only.
        """
        trimmed = v.strip()
        if trimmed:
            return trimmed
        msg = "Field cannot be empty"
        raise ValueError(msg)

558 

559 

class TeamReflectionParams(ReflectionStoreParams):
    """A reflection with an author, optional team/project scope, and access level."""

    author_id: str = Field(
        description="ID of the reflection author",
        max_length=100,
        min_length=1,
    )

    team_id: str | None = Field(
        None,
        description="Optional team ID for team-specific reflections",
        max_length=100,
        min_length=1,
    )

    project_id: str | None = Field(
        None,
        description="Optional project ID for project-specific reflections",
        max_length=100,
        min_length=1,
    )

    access_level: Literal["private", "team", "public"] = Field(
        "team",
        description="Access level for the reflection",
    )

    @field_validator("author_id")
    @classmethod
    def validate_author_id(cls, v: str) -> str:
        """Trim the author ID and reject a blank result.

        Raises:
            ValueError: if the ID is whitespace-only.
        """
        trimmed = v.strip()
        if trimmed:
            return trimmed
        msg = "Author ID cannot be empty"
        raise ValueError(msg)

    @field_validator("team_id", "project_id")
    @classmethod
    def validate_optional_ids(cls, v: str | None) -> str | None:
        """Trim an optional ID; whitespace-only input becomes ``None``."""
        if v is None:
            return None
        return v.strip() or None

607 

608 

class TeamSearchParams(SearchQueryParams):
    """Team knowledge search: a search query plus searcher ID and optional scope."""

    user_id: str = Field(
        description="ID of the user performing the search",
        max_length=100,
        min_length=1,
    )

    team_id: str | None = Field(
        None,
        description="Optional team ID to scope the search",
        max_length=100,
        min_length=1,
    )

    project_id: str | None = Field(
        None,
        description="Optional project ID to scope the search",
        max_length=100,
        min_length=1,
    )

    @field_validator("user_id")
    @classmethod
    def validate_user_id(cls, v: str) -> str:
        """Trim the user ID and reject a blank result.

        Raises:
            ValueError: if the ID is whitespace-only.
        """
        trimmed = v.strip()
        if trimmed:
            return trimmed
        msg = "User ID cannot be empty"
        raise ValueError(msg)

    @field_validator("team_id", "project_id")
    @classmethod
    def validate_optional_ids(cls, v: str | None) -> str | None:
        """Trim an optional ID; whitespace-only input becomes ``None``."""
        if v is None:
            return None
        return v.strip() or None

651 

652 

653# Validation helper functions 

def validate_mcp_params(
    model_class: type[BaseModel], **params: Any
) -> ValidationResponse:
    """Helper function to validate MCP tool parameters using a Pydantic model.

    Args:
        model_class: The Pydantic model class to use for validation
        **params: Parameter values to validate

    Returns:
        ValidationResponse object with is_valid, params, and errors attributes.
        On failure ``errors`` starts with "Parameter validation failed: ". For
        Pydantic errors it continues with "; "-separated "<location>: <message>"
        entries, where <location> is the dotted error path (e.g. "tags.0").
        Any other exception raised while building the model is also reported
        this way rather than propagated.

    Example:
        @mcp.tool()
        async def search_reflections(**params) -> str:
            validated = validate_mcp_params(SearchQueryParams, **params)
            if not validated.is_valid:
                return f"Validation failed: {validated.errors}"
            query = validated.params.query  # access Pydantic model attribute
            limit = validated.params.limit  # access Pydantic model attribute
            # ... rest of implementation

    """
    # Runtime import: the module-level import of ValidationError is
    # TYPE_CHECKING-only.
    from pydantic import ValidationError

    try:
        validated_model = model_class(**params)
    except ValidationError as exc:
        error_messages: list[str] = []
        for error in exc.errors():
            # Join the whole location path. A bare ``loc[-1]`` reports a bad
            # list item as just "0", and raises IndexError when an error
            # carries an empty location.
            loc = error.get("loc") or ()
            field = ".".join(str(part) for part in loc) or "unknown"
            msg = error.get("msg", "validation error")
            error_messages.append(f"{field}: {msg}")
        errors = f"Parameter validation failed: {'; '.join(error_messages)}"
        return ValidationResponse(is_valid=False, params=None, errors=errors)
    except Exception as exc:
        # Deliberately broad: any other failure while constructing the model
        # is reported to the caller instead of crashing the tool.
        errors = f"Parameter validation failed: {exc!s}"
        return ValidationResponse(is_valid=False, params=None, errors=errors)
    return ValidationResponse(is_valid=True, params=validated_model, errors=None)

699 

700 

def create_mcp_validator(
    model_class: type[BaseModel],
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Decorator factory to create MCP tool parameter validators.

    The returned decorator wraps an ``async`` tool function. On each call the
    keyword arguments are validated against *model_class* with
    ``validate_mcp_params``. On success the wrapped function is awaited with
    the validated model's ``model_dump()`` as keyword arguments.

    Args:
        model_class: The Pydantic model class to use for validation

    Returns:
        Decorator function that validates parameters before tool execution

    Raises:
        ValueError: raised by the wrapper at call time when validation fails;
            the message is the ``errors`` text from ``validate_mcp_params``.

    Example:
        @mcp.tool()
        @create_mcp_validator(SearchQueryParams)
        async def search_reflections(**params) -> str:
            # params are now validated and typed
            query = params['query']
            # ... implementation

    """
    from functools import wraps

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        # functools.wraps copies __name__, __qualname__, __module__, __doc__
        # and __annotations__, and sets __wrapped__, so introspection such as
        # inspect.signature reports the tool function rather than ``**params``.
        @wraps(func)
        async def wrapper(**params: Any) -> Any:
            validated_response = validate_mcp_params(model_class, **params)
            validated_model = validated_response.params
            if not validated_response.is_valid or validated_model is None:
                # ``errors`` already begins with "Parameter validation failed: ",
                # so it is raised verbatim to avoid a doubled prefix.
                msg = validated_response.errors or "Parameter validation failed"
                raise ValueError(msg)
            return await func(**validated_model.model_dump())

        return wrapper

    return decorator