Coverage for mcpgateway/services/completion_service.py: 92%
70 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-09 11:03 +0100
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-09 11:03 +0100
1# -*- coding: utf-8 -*-
2"""Completion Service Implementation.
4Copyright 2025
5SPDX-License-Identifier: Apache-2.0
6Authors: Mihai Criveti
8This module implements argument completion according to the MCP specification.
9It handles completion suggestions for prompt arguments and resource URIs.
10"""
12# Standard
13import logging
14from typing import Any, Dict, List
16# Third-Party
17from sqlalchemy import select
18from sqlalchemy.orm import Session
20# First-Party
21from mcpgateway.db import Prompt as DbPrompt
22from mcpgateway.db import Resource as DbResource
23from mcpgateway.models import CompleteResult
# Module logger used by the completion service for lifecycle and error messages.
logger: logging.Logger = logging.getLogger(name=__name__)
class CompletionError(Exception):
    """Error raised when an MCP completion request cannot be fulfilled.

    Root of the completion service's exception hierarchy; the message
    describes why the request failed.
    """
class CompletionService:
    """MCP completion service.

    Handles argument completion for:
    - Prompt arguments based on schema
    - Resource URIs with templates
    - Custom completion sources

    Every result is truncated to ``MAX_COMPLETION_VALUES`` suggestions;
    ``total`` reports the full match count and ``hasMore`` flags truncation.
    """

    # Maximum number of suggestions placed in a single completion result.
    MAX_COMPLETION_VALUES = 100

    def __init__(self):
        """Initialize completion service."""
        # Custom completion values keyed by argument name (see register_completions).
        self._custom_completions: Dict[str, List[str]] = {}

    async def initialize(self) -> None:
        """Initialize completion service."""
        logger.info("Initializing completion service")

    async def shutdown(self) -> None:
        """Shut down the completion service and drop all custom completions."""
        logger.info("Shutting down completion service")
        self._custom_completions.clear()

    async def handle_completion(self, db: Session, request: Dict[str, Any]) -> CompleteResult:
        """Handle completion request.

        Args:
            db: Database session
            request: Completion request. Read keys: ``ref`` (dict with ``type``,
                either ``"ref/prompt"`` or ``"ref/resource"``) and ``argument``
                (dict with ``name`` and optional ``value``).

        Returns:
            Completion result with suggestions

        Raises:
            CompletionError: If the request is malformed, the reference type is
                unsupported, or completion fails for any other reason.
        """
        try:
            # Get reference and argument info; tolerate explicit nulls.
            ref = request.get("ref") or {}
            ref_type = ref.get("type")
            arg = request.get("argument") or {}
            arg_name = arg.get("name")
            raw_value = arg.get("value")
            arg_value = "" if raw_value is None else str(raw_value)

            if not ref_type or not arg_name:
                raise CompletionError("Missing reference type or argument name")

            # Dispatch on reference type
            if ref_type == "ref/prompt":
                return await self._complete_prompt_argument(db, ref, arg_name, arg_value)
            if ref_type == "ref/resource":
                return await self._complete_resource_uri(db, ref, arg_value)
            raise CompletionError(f"Invalid reference type: {ref_type}")

        except CompletionError as e:
            # Already a domain error: log and re-raise with its original traceback.
            logger.error("Completion error: %s", e)
            raise
        except Exception as e:
            # Normalize unexpected failures (e.g. database errors) to CompletionError,
            # chaining the cause so it is not lost.
            logger.error("Completion error: %s", e)
            raise CompletionError(str(e)) from e

    async def _complete_prompt_argument(self, db: Session, ref: Dict[str, Any], arg_name: str, arg_value: str) -> CompleteResult:
        """Complete prompt argument value.

        Suggestions come from the argument schema's ``enum`` when present,
        otherwise from custom completions registered for ``arg_name``.
        Matching is a case-insensitive substring test against ``arg_value``.

        Args:
            db: Database session
            ref: Prompt reference (reads ``name``)
            arg_name: Argument name
            arg_value: Current argument value

        Returns:
            Completion suggestions

        Raises:
            CompletionError: If the prompt name is missing, no active prompt has
                that name, or the argument is not in the prompt's schema.
        """
        # Get prompt
        prompt_name = ref.get("name")
        if not prompt_name:
            raise CompletionError("Missing prompt name")

        prompt = db.execute(select(DbPrompt).where(DbPrompt.name == prompt_name).where(DbPrompt.is_active)).scalar_one_or_none()
        if not prompt:
            raise CompletionError(f"Prompt not found: {prompt_name}")

        # Find argument in schema; an empty schema ({}) still means "argument exists".
        arg_schema = self._find_argument_schema(prompt.argument_schema, arg_name)
        if arg_schema is None:
            raise CompletionError(f"Argument not found: {arg_name}")

        needle = arg_value.lower()

        # Enum values take precedence when defined
        if "enum" in arg_schema:
            return self._build_result([v for v in arg_schema["enum"] if needle in str(v).lower()])

        # Fall back to custom completions
        if arg_name in self._custom_completions:
            return self._build_result([v for v in self._custom_completions[arg_name] if needle in v.lower()])

        # No completions available
        return self._build_result([])

    @staticmethod
    def _find_argument_schema(argument_schema: Any, arg_name: str) -> Any:
        """Return the property schema dict for ``arg_name``, or ``None`` if absent.

        JSON Schema keys ``properties`` by property name, so the name is looked
        up directly first. For schemas stored with an explicit ``"name"`` field
        on each property, a scan over the property values is used as a fallback.

        Args:
            argument_schema: Prompt argument schema (may be ``None``).
            arg_name: Argument name to locate.

        Returns:
            The matching property schema dict, or ``None`` when not found.
        """
        properties = (argument_schema or {}).get("properties") or {}
        schema = properties.get(arg_name)
        if isinstance(schema, dict):
            return schema
        for candidate in properties.values():
            if isinstance(candidate, dict) and candidate.get("name") == arg_name:
                return candidate
        return None

    @classmethod
    def _build_result(cls, values: List[Any]) -> CompleteResult:
        """Wrap matched ``values`` in a ``CompleteResult`` truncated to the cap.

        Args:
            values: All matching suggestions, in order.

        Returns:
            Result whose ``values`` holds at most ``MAX_COMPLETION_VALUES`` items,
            with ``total`` set to the full count and ``hasMore`` set on truncation.
        """
        limit = cls.MAX_COMPLETION_VALUES
        return CompleteResult(
            completion={
                "values": values[:limit],
                "total": len(values),
                "hasMore": len(values) > limit,
            }
        )

    async def _complete_resource_uri(self, db: Session, ref: Dict[str, Any], arg_value: str) -> CompleteResult:
        """Complete resource URI.

        Returns the URIs of all active resources that contain ``arg_value``
        (case-insensitive substring match).

        Args:
            db: Database session
            ref: Resource reference (reads ``uri``)
            arg_value: Current URI value

        Returns:
            URI completion suggestions

        Raises:
            CompletionError: If URI template is missing
        """
        # Get base URI template
        uri_template = ref.get("uri")
        if not uri_template:
            raise CompletionError("Missing URI template")
        # NOTE(review): uri_template is only validated, not used to narrow the
        # candidates below — confirm whether template-based filtering is intended.

        # List active resources and filter by substring match
        resources = db.execute(select(DbResource).where(DbResource.is_active)).scalars().all()
        needle = arg_value.lower()
        matches = [resource.uri for resource in resources if needle in resource.uri.lower()]

        return self._build_result(matches)

    def register_completions(self, arg_name: str, values: List[str]) -> None:
        """Register custom completion values.

        A copy of ``values`` is stored, replacing any previous registration
        for the same argument name.

        Args:
            arg_name: Argument name
            values: Completion values
        """
        self._custom_completions[arg_name] = list(values)

    def unregister_completions(self, arg_name: str) -> None:
        """Unregister custom completion values.

        Silently does nothing if ``arg_name`` was never registered.

        Args:
            arg_name: Argument name
        """
        self._custom_completions.pop(arg_name, None)