Coverage for mcpgateway/services/completion_service.py: 92%

70 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-09 11:03 +0100

1# -*- coding: utf-8 -*- 

2"""Completion Service Implementation. 

3 

4Copyright 2025 

5SPDX-License-Identifier: Apache-2.0 

6Authors: Mihai Criveti 

7 

8This module implements argument completion according to the MCP specification. 

9It handles completion suggestions for prompt arguments and resource URIs. 

10""" 

11 

12# Standard 

13import logging 

14from typing import Any, Dict, List 

15 

16# Third-Party 

17from sqlalchemy import select 

18from sqlalchemy.orm import Session 

19 

20# First-Party 

21from mcpgateway.db import Prompt as DbPrompt 

22from mcpgateway.db import Resource as DbResource 

23from mcpgateway.models import CompleteResult 

24 

# Module-scoped logger, named after this module's import path.
logger = logging.getLogger(name=__name__)

26 

27 

class CompletionError(Exception):
    """Raised when a completion request cannot be served (malformed input, unknown target, or failure)."""

30 

31 

class CompletionService:
    """MCP completion service.

    Handles argument completion for:
    - Prompt arguments based on schema
    - Resource URIs with templates
    - Custom completion sources
    """

    # Maximum number of suggestions returned in a single response. Any further
    # matches are only reflected in the ``total`` count and the ``hasMore`` flag.
    MAX_COMPLETION_VALUES = 100

    def __init__(self):
        """Initialize completion service."""
        # Custom completion values keyed by argument name (see register_completions).
        self._custom_completions: Dict[str, List[str]] = {}

    async def initialize(self) -> None:
        """Initialize completion service."""
        logger.info("Initializing completion service")

    async def shutdown(self) -> None:
        """Shutdown completion service and drop all registered custom completions."""
        logger.info("Shutting down completion service")
        self._custom_completions.clear()

    async def handle_completion(self, db: Session, request: Dict[str, Any]) -> CompleteResult:
        """Handle completion request.

        Args:
            db: Database session
            request: Completion request. Read keys: ``ref`` (mapping with ``type``
                of ``"ref/prompt"`` or ``"ref/resource"``) and ``argument``
                (mapping with ``name`` and optional ``value``).

        Returns:
            Completion result with suggestions

        Raises:
            CompletionError: If the request is malformed (missing reference type,
                argument name, prompt name or URI template), names an unsupported
                reference type, refers to an unknown prompt or prompt argument, or
                if completion fails for any other reason. Unexpected errors are
                wrapped and chained as ``__cause__``.
        """
        try:
            # A null "ref"/"argument" is treated like an absent one.
            ref = request.get("ref") or {}
            ref_type = ref.get("type")
            arg = request.get("argument") or {}
            arg_name = arg.get("name")
            # An explicit ``"value": null`` means "no prefix typed yet"; coerce to str
            # so the case-insensitive matching below never fails on ``.lower()``.
            raw_value = arg.get("value")
            arg_value = "" if raw_value is None else str(raw_value)

            if not ref_type or not arg_name:
                raise CompletionError("Missing reference type or argument name")

            # Handle different reference types
            if ref_type == "ref/prompt":
                return await self._complete_prompt_argument(db, ref, arg_name, arg_value)
            if ref_type == "ref/resource":
                return await self._complete_resource_uri(db, ref, arg_value)
            raise CompletionError(f"Invalid reference type: {ref_type}")

        except CompletionError as e:
            # Already a domain error: log and propagate it unchanged rather than re-wrapping.
            logger.error("Completion error: %s", e)
            raise
        except Exception as e:
            logger.exception("Completion error: %s", e)
            raise CompletionError(str(e)) from e

    async def _complete_prompt_argument(self, db: Session, ref: Dict[str, Any], arg_name: str, arg_value: str) -> CompleteResult:
        """Complete prompt argument value.

        Suggestions come from the argument's ``enum`` in the prompt's argument
        schema when present, otherwise from custom completions registered for
        ``arg_name``.

        Args:
            db: Database session
            ref: Prompt reference; its ``name`` selects an active prompt
            arg_name: Argument name
            arg_value: Current argument value, matched case-insensitively as a substring

        Returns:
            Completion suggestions (empty when no completion source applies)

        Raises:
            CompletionError: If the prompt name is missing, no active prompt with
                that name exists, or the argument is not in the prompt's schema
        """
        prompt_name = ref.get("name")
        if not prompt_name:
            raise CompletionError("Missing prompt name")

        prompt = db.execute(select(DbPrompt).where(DbPrompt.name == prompt_name).where(DbPrompt.is_active)).scalar_one_or_none()
        if not prompt:
            raise CompletionError(f"Prompt not found: {prompt_name}")

        # JSON Schema keys ``properties`` by property (argument) name, so look the
        # argument up by key first. Property schemas that instead carry an explicit
        # ``name`` field (the only form the previous lookup recognised) still work.
        properties = (prompt.argument_schema or {}).get("properties") or {}
        arg_schema = properties.get(arg_name)
        if not isinstance(arg_schema, dict):
            arg_schema = next(
                (prop for prop in properties.values() if isinstance(prop, dict) and prop.get("name") == arg_name),
                None,
            )

        # ``is None`` rather than falsiness: an empty schema ``{}`` is a valid property.
        if arg_schema is None:
            raise CompletionError(f"Argument not found: {arg_name}")

        # Get enum values if defined
        if "enum" in arg_schema:
            return self._build_result(arg_schema["enum"], arg_value)

        # Check custom completions
        if arg_name in self._custom_completions:
            return self._build_result(self._custom_completions[arg_name], arg_value)

        # No completions available
        return self._build_result([], arg_value)

    async def _complete_resource_uri(self, db: Session, ref: Dict[str, Any], arg_value: str) -> CompleteResult:
        """Complete resource URI.

        ``ref["uri"]`` must be present but is not otherwise used for matching:
        suggestions are the URIs of all active resources that contain
        ``arg_value`` case-insensitively.

        Args:
            db: Database session
            ref: Resource reference
            arg_value: Current URI value

        Returns:
            URI completion suggestions

        Raises:
            CompletionError: If URI template is missing
        """
        # Get base URI template
        uri_template = ref.get("uri")
        if not uri_template:
            raise CompletionError("Missing URI template")

        # List active resources; substring filtering happens in _build_result.
        resources = db.execute(select(DbResource).where(DbResource.is_active)).scalars().all()
        return self._build_result([resource.uri for resource in resources], arg_value)

    def _build_result(self, candidates: List[Any], arg_value: str) -> CompleteResult:
        """Filter candidates by substring and wrap the matches in a ``CompleteResult``.

        Args:
            candidates: Candidate values; compared via ``str(value)`` but returned unchanged
            arg_value: Substring to match case-insensitively (empty matches everything)

        Returns:
            Result holding at most ``MAX_COMPLETION_VALUES`` matches, the total
            number of matches, and whether matches were truncated
        """
        needle = arg_value.lower()
        matches = [value for value in candidates if needle in str(value).lower()]
        limit = self.MAX_COMPLETION_VALUES
        return CompleteResult(
            completion={
                "values": matches[:limit],
                "total": len(matches),
                "hasMore": len(matches) > limit,
            }
        )

    def register_completions(self, arg_name: str, values: List[str]) -> None:
        """Register custom completion values.

        Replaces any values previously registered for ``arg_name``; the input is
        copied so later mutation by the caller does not affect stored values.

        Args:
            arg_name: Argument name
            values: Completion values
        """
        self._custom_completions[arg_name] = list(values)

    def unregister_completions(self, arg_name: str) -> None:
        """Unregister custom completion values (no-op if none are registered).

        Args:
            arg_name: Argument name
        """
        self._custom_completions.pop(arg_name, None)