Coverage for session_buddy / parameter_models.py: 73.33%
315 statements
coverage.py v7.13.1, created at 2026-01-04 00:43 -0800
#!/usr/bin/env python3
"""Pydantic parameter validation models for MCP tools.

This module provides reusable parameter validation models that can be integrated
with FastMCP @mcp.tool() decorators to ensure type safety and consistent
validation across all MCP tools.

Following crackerjack patterns:
- EVERY LINE IS A LIABILITY: Focused, single-responsibility models
- DRY: Reusable validation patterns across tools
- KISS: Simple, clear validation without over-engineering
"""
14from __future__ import annotations
16import os
17from pathlib import Path
18from typing import TYPE_CHECKING, Any, Literal, NamedTuple
20from pydantic import BaseModel, Field, field_validator
22if TYPE_CHECKING:
23 from collections.abc import Callable
25 from pydantic import ValidationError
class ValidationResponse(NamedTuple):
    """Outcome of validating tool parameters against a Pydantic model.

    Attributes:
        is_valid: True when validation succeeded.
        params: The validated model instance (set by ``validate_mcp_params``
            on success), otherwise None.
        errors: A human-readable error summary (set on failure), otherwise None.
    """

    is_valid: bool
    params: BaseModel | None = None  # validated model, success path only
    errors: str | None = None  # error summary, failure path only
# Helper functions for common validation patterns
def validate_non_empty_string(v: Any, field_name: str) -> str:
    """Return *v* with surrounding whitespace removed, rejecting blanks.

    Args:
        v: Candidate value; must be a ``str``.
        field_name: Name used in the error messages.

    Returns:
        The stripped string.

    Raises:
        TypeError: If *v* is not a string.
        ValueError: If *v* is empty or whitespace-only.
    """
    if isinstance(v, str):
        cleaned = v.strip()
        if cleaned:
            return cleaned
        raise ValueError(f"{field_name} cannot be empty")
    raise TypeError(f"{field_name} must be a string")
def validate_and_expand_path(v: Any, field_name: str) -> str:
    """Validate a string and expand it when the field holds a path.

    Only fields whose name ends in ``_path`` or ``_directory`` are expanded:
    the value is stripped and ``~`` is expanded. ``_directory`` fields are
    additionally resolved to an absolute path when they are relative. Any other
    field name returns *v* unchanged, and blank (empty or whitespace-only) input
    is also returned unchanged.

    Args:
        v: Candidate value; must be a ``str``.
        field_name: Name of the field; its suffix selects the behavior.

    Returns:
        The (possibly expanded) string.

    Raises:
        TypeError: If *v* is not a string.
    """
    if not isinstance(v, str):
        msg = f"{field_name} must be a string"
        raise TypeError(msg)
    if not field_name.endswith(("_path", "_directory")):
        return v

    stripped = v.strip()
    if not stripped:
        # Nothing to expand. Returning early also keeps whitespace-only input
        # from being resolved into "<cwd>/   " below.
        return v

    expanded = os.path.expanduser(stripped)
    if field_name.endswith("_directory") and not Path(expanded).is_absolute():
        # For directory fields, ensure absolute paths
        expanded = str(Path(expanded).resolve())
    return expanded
# Core parameter models for common patterns
class WorkingDirectoryParams(BaseModel):
    """Optional working-directory override shared by several tools."""

    working_directory: str | None = Field(
        default=None,
        description="Optional working directory override (defaults to PWD environment variable or current directory)",
        examples=[".", "/Users/username/project", "~/Projects/my-app"],
    )

    @field_validator("working_directory")
    @classmethod
    def validate_working_directory(cls, v: str | None) -> str | None:
        """Normalize the directory and make sure it exists on disk.

        Whitespace is stripped, a blank value becomes ``None``, and ``~`` is
        expanded. The path is not otherwise resolved.

        Raises:
            ValueError: If the path does not exist or is not a directory.
        """
        if v is None:
            return v
        trimmed = v.strip()
        if trimmed == "":
            return None
        candidate = os.path.expanduser(trimmed)
        target = Path(candidate)
        if not target.exists():
            raise ValueError(f"Working directory does not exist: {candidate}")
        if not target.is_dir():
            raise ValueError(f"Working directory is not a directory: {candidate}")
        return candidate
class ProjectContextParams(BaseModel):
    """Optional project identifier used to scope an operation."""

    project: str | None = Field(
        default=None,
        description="Optional project identifier for scoped operations",
        min_length=1,
        max_length=200,
        examples=["my-app", "session-mgmt-mcp", "microservice-auth"],
    )

    @field_validator("project")
    @classmethod
    def validate_project(cls, v: str | None) -> str | None:
        """Strip the identifier; a whitespace-only value collapses to ``None``."""
        if v is None:
            return None
        return v.strip() or None
class SearchLimitParams(BaseModel):
    """Pagination controls: page size (``limit``) and start position (``offset``)."""

    # At least one and at most 1000 results per request.
    limit: int = Field(default=10, ge=1, le=1000, description="Maximum number of results to return")

    # Non-negative; no upper bound is enforced here.
    offset: int = Field(default=0, ge=0, description="Number of results to skip for pagination")
class TimeRangeParams(BaseModel):
    """Look-back window, in whole days, for time-based filtering."""

    # 3650 days is roughly ten years, the maximum look-back allowed.
    days: int = Field(default=7, ge=1, le=3650, description="Number of days to look back")
class ScoreThresholdParams(BaseModel):
    """Minimum relevance score a result must reach to be returned."""

    # Bounded to the closed interval [0.0, 1.0].
    min_score: float = Field(
        default=0.7, ge=0.0, le=1.0, description="Minimum relevance score threshold (0.0-1.0)"
    )
class TagParams(BaseModel):
    """Optional list of categorization tags, normalized to lowercase."""

    tags: list[str] | None = Field(
        default=None,
        description="Optional list of tags for categorization",
        examples=[["python", "async"], ["bug", "critical"], ["feature", "ui"]],
    )

    @field_validator("tags")
    @classmethod
    def validate_tags(cls, v: list[str] | None) -> list[str] | None:
        """Normalize each tag, drop blank ones, and return None if none remain."""
        if v is None:
            return None

        cls._validate_tags_type(v)
        cleaned: list[str] = []
        for raw_tag in v:
            tag = cls._process_single_tag(raw_tag)
            if tag is not None:
                cleaned.append(tag)
        return cleaned if cleaned else None

    @classmethod
    def _validate_tags_type(cls, tags: Any) -> None:
        """Raise TypeError unless *tags* is a list."""
        if isinstance(tags, list):
            return
        raise TypeError("Tags must be a list of strings")

    @classmethod
    def _process_single_tag(cls, tag: Any) -> str | None:
        """Return the stripped, lowercased tag, or None if it is blank.

        Raises:
            TypeError: If *tag* is not a string.
            ValueError: If the tag is too long or contains disallowed characters.
        """
        if not isinstance(tag, str):
            raise TypeError("Each tag must be a string")

        normalized = tag.strip().lower()
        if normalized == "":
            return None  # blank tags are skipped rather than rejected

        cls._validate_tag_length(normalized)
        cls._validate_tag_format(normalized)
        return normalized

    @classmethod
    def _validate_tag_length(cls, tag: str) -> None:
        """Reject tags longer than 50 characters."""
        if len(tag) <= 50:
            return
        raise ValueError(f"Tag too long (max 50 chars): {tag}")

    @classmethod
    def _validate_tag_format(cls, tag: str) -> None:
        """Allow only alphanumerics plus '-' and '_'."""
        # With the permitted separators removed, the rest must be alphanumeric.
        # A tag made only of separators leaves "", which is not alnum, so it fails.
        core = tag.replace("-", "").replace("_", "")
        if not core.isalnum():
            raise ValueError(
                f"Tags must contain only letters, numbers, hyphens, and underscores: {tag}"
            )
class IDParams(BaseModel):
    """Identifier for an arbitrary entity (session, reflection, ...)."""

    id: str = Field(
        description="Unique identifier",
        min_length=1,
        max_length=100,
        examples=["abc123", "session_20250106", "reflection-456"],
    )

    @field_validator("id")
    @classmethod
    def validate_id_format(cls, v: str) -> str:
        """Strip the ID and restrict it to alphanumerics, '-', '_' and '.'."""
        candidate = v.strip()
        if candidate == "":
            raise ValueError("ID cannot be empty")
        # Remove the permitted punctuation; what remains must be alphanumeric.
        remainder = candidate
        for allowed in ("-", "_", "."):
            remainder = remainder.replace(allowed, "")
        if remainder.isalnum():
            return candidate
        raise ValueError(
            "ID must contain only letters, numbers, hyphens, underscores, and dots"
        )
class FilePathParams(BaseModel):
    """A single file path; the file does not have to exist yet."""

    file_path: str = Field(
        description="Path to a file",
        min_length=1,
        examples=["README.md", "src/main.py", "/absolute/path/file.txt"],
    )

    @field_validator("file_path")
    @classmethod
    def validate_file_path(cls, v: str) -> str:
        """Strip the path and reject blank values or embedded NUL characters."""
        path_text = v.strip()
        if not path_text:
            raise ValueError("File path cannot be empty")
        # Purely syntactic: existence is deliberately not checked, because the
        # file may not have been created yet.
        if "\x00" in path_text:
            raise ValueError("File path cannot contain null characters")
        return path_text
class CommandExecutionParams(BaseModel):
    """Command name, its argument string, and an execution timeout."""

    command: str = Field(
        description="Command to execute",
        min_length=1,
        max_length=1000,
        examples=["lint", "test", "analyze"],
    )

    # Arguments travel as one space-separated string; no validator strips them.
    args: str = Field(
        default="",
        max_length=2000,
        description="Command arguments as space-separated string",
    )

    # Between one second and one hour.
    timeout: int = Field(
        default=300,
        ge=1,
        le=3600,
        description="Command timeout in seconds",
    )

    @field_validator("command")
    @classmethod
    def validate_command(cls, v: str) -> str:
        """Strip the command name and reject a blank result."""
        name = v.strip()
        if name:
            return name
        raise ValueError("Command cannot be empty")
class BooleanFlagParams(BaseModel):
    """Common on/off switches for tools; every flag defaults to False."""

    force: bool = Field(default=False, description="Force operation, bypassing safety checks")

    verbose: bool = Field(
        default=False,
        description="Enable verbose output",
    )

    dry_run: bool = Field(default=False, description="Show what would be done without executing")
# Specific MCP tool parameter models
class SessionInitParams(WorkingDirectoryParams):
    """Parameters for session initialization.

    Declares no fields of its own; ``working_directory`` is inherited.
    """
class SessionStatusParams(WorkingDirectoryParams):
    """Parameters for a session status check.

    Declares no fields of its own; ``working_directory`` is inherited.
    """
class ReflectionStoreParams(BaseModel):
    """Content and optional tags for a reflection to be stored."""

    content: str = Field(
        description="Content to store as reflection",
        min_length=1,
        max_length=50000,
        examples=["Learned that async/await patterns improve database performance"],
    )

    tags: list[str] | None = Field(
        default=None,
        description="Optional tags for categorization",
    )

    @field_validator("content")
    @classmethod
    def validate_content(cls, v: str) -> str:
        """Strip the content and reject a blank result."""
        body = v.strip()
        if body:
            return body
        raise ValueError("Content cannot be empty")

    @field_validator("tags")
    @classmethod
    def validate_tags(cls, v: list[str] | None) -> list[str] | None:
        """Normalize tags by running them through ``TagParams``."""
        # Reuse TagParams so both models apply one shared set of tag rules.
        normalized = TagParams(tags=v)
        return normalized.tags
class SearchQueryParams(ProjectContextParams, SearchLimitParams, ScoreThresholdParams):
    """Free-text search query with project scope, pagination, and score filter."""

    query: str = Field(
        description="Search query text",
        min_length=1,
        max_length=1000,
        examples=["python async patterns", "database migration", "error handling"],
    )

    @field_validator("query")
    @classmethod
    def validate_query(cls, v: str) -> str:
        """Strip the query text and reject a blank result."""
        text = v.strip()
        if text:
            return text
        raise ValueError("Query cannot be empty")
class FileSearchParams(SearchLimitParams, ProjectContextParams, ScoreThresholdParams):
    """Search conversations for mentions of a particular file path."""

    file_path: str = Field(
        description="File path to search for in conversations",
        min_length=1,
        examples=["src/main.py", "README.md", "config/database.yml"],
    )

    @field_validator("file_path")
    @classmethod
    def validate_file_path(cls, v: str) -> str:
        """Strip the path and reject a blank result (existence is not checked)."""
        cleaned_path = v.strip()
        if cleaned_path:
            return cleaned_path
        raise ValueError("File path cannot be empty")
class ConceptSearchParams(
    SearchLimitParams, ProjectContextParams, ScoreThresholdParams
):
    """Search by development concept (e.g. "caching"), optionally with files."""

    concept: str = Field(
        description="Development concept to search for",
        min_length=1,
        max_length=200,
        examples=["authentication", "caching", "error handling", "async patterns"],
    )

    include_files: bool = Field(
        default=True, description="Include related files in search results"
    )

    @field_validator("concept")
    @classmethod
    def validate_concept(cls, v: str) -> str:
        """Strip the concept and reject a blank result."""
        term = v.strip()
        if term:
            return term
        raise ValueError("Concept cannot be empty")
class CrackerjackExecutionParams(CommandExecutionParams, WorkingDirectoryParams):
    """Crackerjack invocation: command, args, and timeout plus working directory."""

    ai_agent_mode: bool = Field(
        default=False, description="Enable AI agent mode for autonomous fixing"
    )
class CrackerjackHistoryParams(TimeRangeParams, WorkingDirectoryParams):
    """Query past crackerjack runs by look-back window, directory, and name."""

    # Defaults to "" — presumably meaning "no filter"; confirm in the caller.
    command_filter: str = Field(
        default="", max_length=100, description="Filter commands by name"
    )
class TeamUserParams(BaseModel):
    """Identity, role, and optional email address for a team member."""

    user_id: str = Field(
        description="Unique user identifier",
        min_length=1,
        max_length=100,
    )

    username: str = Field(description="Display username", min_length=1, max_length=100)

    role: Literal["owner", "admin", "moderator", "contributor", "viewer"] = Field(
        default="contributor",
        description="User role in the team",
    )

    email: str | None = Field(default=None, description="Optional email address")

    @field_validator("user_id", "username")
    @classmethod
    def validate_required_strings(cls, v: str) -> str:
        """Strip required text fields and reject blank values."""
        v = v.strip()
        if not v:
            msg = "Field cannot be empty"
            raise ValueError(msg)
        return v

    @field_validator("email")
    @classmethod
    def validate_email(cls, v: str | None) -> str | None:
        """Lightweight structural email check (not full RFC 5322 parsing).

        Surrounding whitespace is stripped and a blank value becomes ``None``.
        Otherwise the address must:
        - be at most 254 characters long,
        - contain no whitespace,
        - contain exactly one ``@``,
        - have a non-empty local part,
        - have a domain of two or more non-empty dot-separated labels.

        Raises:
            ValueError: If any of the checks above fails.
        """
        if v is None:
            return None

        v = v.strip()
        if not v:
            return None

        if len(v) > 254:  # RFC 5321 limit
            msg = "Email address too long"
            raise ValueError(msg)

        invalid = "Invalid email format"

        # Unquoted addresses cannot contain whitespace ("john doe@x.com").
        if any(ch.isspace() for ch in v):
            raise ValueError(invalid)

        # Must contain exactly one @ symbol
        if v.count("@") != 1:
            raise ValueError(invalid)

        local, domain = v.split("@")

        # Local part cannot be empty
        if not local:
            raise ValueError(invalid)

        # Every dot-separated domain label must be non-empty. This rejects a
        # missing dot, an empty domain, leading/trailing dots, and "a@b..com".
        labels = domain.split(".")
        if len(labels) < 2 or not all(labels):
            raise ValueError(invalid)

        return v
class TeamCreationParams(BaseModel):
    """Identity, display details, and owner for a new team."""

    team_id: str = Field(description="Unique team identifier", min_length=1, max_length=100)

    name: str = Field(
        description="Team display name",
        min_length=1,
        max_length=200,
    )

    description: str = Field(description="Team description", min_length=1, max_length=1000)

    owner_id: str = Field(description="User ID of the team owner", min_length=1, max_length=100)

    @field_validator("team_id", "name", "description", "owner_id")
    @classmethod
    def validate_required_strings(cls, v: str) -> str:
        """Strip each required text field and reject blank values."""
        cleaned = v.strip()
        if cleaned:
            return cleaned
        raise ValueError("Field cannot be empty")
class TeamReflectionParams(ReflectionStoreParams):
    """Reflection with an author, optional team/project scope, and visibility."""

    author_id: str = Field(
        description="ID of the reflection author",
        min_length=1,
        max_length=100,
    )

    team_id: str | None = Field(
        default=None,
        description="Optional team ID for team-specific reflections",
        min_length=1,
        max_length=100,
    )

    project_id: str | None = Field(
        default=None,
        description="Optional project ID for project-specific reflections",
        min_length=1,
        max_length=100,
    )

    access_level: Literal["private", "team", "public"] = Field(
        default="team",
        description="Access level for the reflection",
    )

    @field_validator("author_id")
    @classmethod
    def validate_author_id(cls, v: str) -> str:
        """Strip the author ID and reject a blank result."""
        author = v.strip()
        if author:
            return author
        raise ValueError("Author ID cannot be empty")

    @field_validator("team_id", "project_id")
    @classmethod
    def validate_optional_ids(cls, v: str | None) -> str | None:
        """Strip optional IDs; whitespace-only values become ``None``."""
        return None if v is None else (v.strip() or None)
class TeamSearchParams(SearchQueryParams):
    """Team knowledge search: a query, the searching user, and optional scope."""

    user_id: str = Field(
        description="ID of the user performing the search",
        min_length=1,
        max_length=100,
    )

    team_id: str | None = Field(
        default=None,
        description="Optional team ID to scope the search",
        min_length=1,
        max_length=100,
    )

    project_id: str | None = Field(
        default=None,
        description="Optional project ID to scope the search",
        min_length=1,
        max_length=100,
    )

    @field_validator("user_id")
    @classmethod
    def validate_user_id(cls, v: str) -> str:
        """Strip the user ID and reject a blank result."""
        user = v.strip()
        if user:
            return user
        raise ValueError("User ID cannot be empty")

    @field_validator("team_id", "project_id")
    @classmethod
    def validate_optional_ids(cls, v: str | None) -> str | None:
        """Strip optional scope IDs; whitespace-only values become ``None``."""
        if v is None:
            return None
        stripped = v.strip()
        return stripped if stripped else None
# Validation helper functions
def validate_mcp_params(
    model_class: type[BaseModel], **params: Any
) -> ValidationResponse:
    """Validate MCP tool parameters against a Pydantic model.

    Args:
        model_class: The Pydantic model class to use for validation.
        **params: Parameter values to validate.

    Returns:
        A ``ValidationResponse``. On success, ``is_valid`` is True and ``params``
        holds the model instance. On failure, ``is_valid`` is False and
        ``errors`` holds a message prefixed with "Parameter validation failed: ".

    Example:
        @mcp.tool()
        async def search_reflections(**params) -> str:
            validated = validate_mcp_params(SearchQueryParams, **params)
            if not validated.is_valid:
                return f"Validation failed: {validated.errors}"
            query = validated.params.query  # access Pydantic model attribute
            limit = validated.params.limit  # access Pydantic model attribute
            # ... rest of implementation
    """
    # Runtime import: the module-level import exists only under TYPE_CHECKING.
    from pydantic import ValidationError

    try:
        validated_model = model_class(**params)
    except ValidationError as e:
        # Convert Pydantic's structured errors into "field: message" pairs.
        error_messages = []
        for error in e.errors():
            # Model-level errors carry an empty ``loc``; fall back to "unknown"
            # instead of indexing an empty tuple (which raised IndexError).
            loc = error.get("loc") or ("unknown",)
            msg = error.get("msg", "validation error")
            error_messages.append(f"{loc[-1]}: {msg}")
        errors = f"Parameter validation failed: {'; '.join(error_messages)}"
        return ValidationResponse(is_valid=False, params=None, errors=errors)
    except Exception as e:
        # A validator may raise an exception Pydantic does not wrap in
        # ValidationError (e.g. TypeError). Report it instead of propagating.
        errors = f"Parameter validation failed: {e!s}"
        return ValidationResponse(is_valid=False, params=None, errors=errors)

    return ValidationResponse(is_valid=True, params=validated_model, errors=None)
def create_mcp_validator(
    model_class: type[BaseModel],
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Decorator factory that validates a tool's keyword arguments first.

    The decorated tool is called with the validated model's fields (from
    ``model_dump()``) as keyword arguments. The wrapper is always a coroutine
    function, but it accepts both ``async def`` and plain tool functions: the
    tool's result is awaited only when it is awaitable.

    Args:
        model_class: The Pydantic model class to use for validation.

    Returns:
        Decorator function that validates parameters before tool execution.

    Raises:
        ValueError: Raised by the wrapper when validation fails. The message is
            the ``errors`` text produced by ``validate_mcp_params``.

    Example:
        @mcp.tool()
        @create_mcp_validator(SearchQueryParams)
        async def search_reflections(**params) -> str:
            # params are now validated and typed
            query = params['query']
            # ... implementation
    """
    import functools
    import inspect

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        # functools.wraps copies __name__, __qualname__, __doc__, __module__ and
        # __annotations__, and sets __wrapped__ for signature introspection.
        @functools.wraps(func)
        async def wrapper(**params: Any) -> Any:
            validated_response = validate_mcp_params(model_class, **params)
            if not validated_response.is_valid:
                # ``errors`` already starts with "Parameter validation failed:",
                # so raise it as-is rather than prefixing it a second time.
                raise ValueError(validated_response.errors)
            model = validated_response.params
            # Convert the Pydantic model to a dictionary for unpacking.
            kwargs = model.model_dump() if model is not None else {}
            result = func(**kwargs)
            if inspect.isawaitable(result):
                result = await result
            return result

        return wrapper

    return decorator