Coverage for mcpgateway/federation/forward.py: 53%

122 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-09 11:03 +0100

1# -*- coding: utf-8 -*- 

2"""Federation Request Forwarding. 

3 

4Copyright 2025 

5SPDX-License-Identifier: Apache-2.0 

6Authors: Mihai Criveti 

7 

8This module implements request forwarding for federated MCP Gateways. 

9It handles: 

10- Request routing to appropriate gateways 

11- Response aggregation 

12- Error handling and retry logic 

13- Request/response transformation 

14""" 

15 

16# Standard 

17import asyncio 

18from datetime import datetime, timezone 

19import logging 

20from typing import Any, Dict, List, Optional, Set, Tuple, Union 

21 

22# Third-Party 

23import httpx 

24from sqlalchemy import select 

25from sqlalchemy.orm import Session 

26 

27# First-Party 

28from mcpgateway.config import settings 

29from mcpgateway.db import Gateway as DbGateway 

30from mcpgateway.db import Tool as DbTool 

31from mcpgateway.models import ToolResult 

32 

33logger = logging.getLogger(__name__) 

34 

35 

36class ForwardingError(Exception): 

37 """Base class for forwarding-related errors.""" 

38 

39 

40class ForwardingService: 

41 """Service for handling request forwarding across gateways. 

42 

43 Handles: 

44 - Request routing 

45 - Response aggregation 

46 - Error handling 

47 - Request transformation 

48 """ 

49 

50 def __init__(self): 

51 """Initialize forwarding service.""" 

52 self._http_client = httpx.AsyncClient(timeout=settings.federation_timeout, verify=not settings.skip_ssl_verify) 

53 

54 # Track active requests 

55 self._active_requests: Dict[str, asyncio.Task] = {} 

56 

57 # Request history for rate limiting 

58 self._request_history: Dict[str, List[datetime]] = {} 

59 

60 # Cache gateway information 

61 self._gateway_tools: Dict[int, Set[str]] = {} 

62 

63 async def start(self) -> None: 

64 """Start forwarding service.""" 

65 logger.info("Request forwarding service started") 

66 

67 async def stop(self) -> None: 

68 """Stop forwarding service.""" 

69 # Cancel active requests 

70 for request_id, task in self._active_requests.items(): 70 ↛ 71line 70 didn't jump to line 71 because the loop on line 70 never started

71 logger.info(f"Cancelling request {request_id}") 

72 task.cancel() 

73 try: 

74 await task 

75 except asyncio.CancelledError: 

76 pass 

77 

78 await self._http_client.aclose() 

79 logger.info("Request forwarding service stopped") 

80 

81 async def forward_request( 

82 self, 

83 db: Session, 

84 method: str, 

85 params: Optional[Dict[str, Any]] = None, 

86 target_gateway_id: Optional[int] = None, 

87 ) -> Any: 

88 """Forward a request to gateway(s). 

89 

90 Args: 

91 db: Database session 

92 method: RPC method name 

93 params: Optional method parameters 

94 target_gateway_id: Optional specific gateway ID 

95 

96 Returns: 

97 Forwarded response(s) 

98 

99 Raises: 

100 ForwardingError: If forwarding fails 

101 """ 

102 try: 

103 if target_gateway_id: 

104 # Forward to specific gateway 

105 return await self._forward_to_gateway(db, target_gateway_id, method, params) 

106 

107 # Forward to all relevant gateways 

108 return await self._forward_to_all(db, method, params) 

109 

110 except Exception as e: 

111 raise ForwardingError(f"Forward request failed: {str(e)}") 

112 

113 async def forward_tool_request(self, db: Session, tool_name: str, arguments: Dict[str, Any]) -> ToolResult: 

114 """Forward a tool invocation request. 

115 

116 Args: 

117 db: Database session 

118 tool_name: Tool to invoke 

119 arguments: Tool arguments 

120 

121 Returns: 

122 Tool result 

123 

124 Raises: 

125 ForwardingError: If forwarding fails 

126 """ 

127 try: 

128 # Find tool 

129 tool = db.execute(select(DbTool).where(DbTool.name == tool_name).where(DbTool.enabled)).scalar_one_or_none() 

130 

131 if not tool: 131 ↛ 132line 131 didn't jump to line 132 because the condition on line 131 was never true

132 raise ForwardingError(f"Tool not found: {tool_name}") 

133 

134 if not tool.gateway_id: 134 ↛ 135line 134 didn't jump to line 135 because the condition on line 134 was never true

135 raise ForwardingError(f"Tool {tool_name} is not federated") 

136 

137 # Forward to gateway 

138 result = await self._forward_to_gateway( 

139 db, 

140 tool.gateway_id, 

141 "tools/invoke", 

142 {"name": tool_name, "arguments": arguments}, 

143 ) 

144 

145 # Parse result 

146 return ToolResult( 

147 content=result.get("content", []), 

148 is_error=result.get("is_error", False), 

149 ) 

150 

151 except Exception as e: 

152 raise ForwardingError(f"Failed to forward tool request: {str(e)}") 

153 

154 async def forward_resource_request(self, db: Session, uri: str) -> Tuple[Union[str, bytes], str]: 

155 """Forward a resource read request. 

156 

157 Args: 

158 db: Database session 

159 uri: Resource URI 

160 

161 Returns: 

162 Tuple of (content, mime_type) 

163 

164 Raises: 

165 ForwardingError: If forwarding fails 

166 """ 

167 try: 

168 # Find gateway for resource 

169 gateway = await self._find_resource_gateway(db, uri) 

170 if not gateway: 

171 raise ForwardingError(f"No gateway found for resource: {uri}") 

172 

173 # Forward request 

174 result = await self._forward_to_gateway(db, gateway.id, "resources/read", {"uri": uri}) 

175 

176 # Parse result 

177 if "text" in result: 

178 return result["text"], result.get("mime_type", "text/plain") 

179 if "blob" in result: 

180 return result["blob"], result.get("mime_type", "application/octet-stream") 

181 

182 raise ForwardingError("Invalid resource response format") 

183 

184 except Exception as e: 

185 raise ForwardingError(f"Failed to forward resource request: {str(e)}") 

186 

187 async def _forward_to_gateway( 

188 self, 

189 db: Session, 

190 gateway_id: str, 

191 method: str, 

192 params: Optional[Dict[str, Any]] = None, 

193 ) -> Any: 

194 """Forward request to a specific gateway. 

195 

196 Args: 

197 db: Database session 

198 gateway_id: Gateway to forward to 

199 method: RPC method name 

200 params: Optional method parameters 

201 

202 Returns: 

203 Gateway response 

204 

205 Raises: 

206 ForwardingError: If forwarding fails 

207 httpx.TimeoutException: If unable to connect after retries 

208 """ 

209 # Get gateway 

210 gateway = db.get(DbGateway, gateway_id) 

211 if not gateway or not gateway.enabled: 211 ↛ 212line 211 didn't jump to line 212 because the condition on line 211 was never true

212 raise ForwardingError(f"Gateway not found: {gateway_id}") 

213 

214 # Check rate limits 

215 if not self._check_rate_limit(gateway.url): 215 ↛ 216line 215 didn't jump to line 216 because the condition on line 215 was never true

216 raise ForwardingError("Rate limit exceeded") 

217 

218 try: 

219 # Build request 

220 request = {"jsonrpc": "2.0", "id": 1, "method": method} 

221 if params: 221 ↛ 225line 221 didn't jump to line 225 because the condition on line 221 was always true

222 request["params"] = params 

223 

224 # Send request with retries using the persistent client directly 

225 for attempt in range(settings.max_tool_retries): 225 ↛ exitline 225 didn't return from function '_forward_to_gateway' because the loop on line 225 didn't complete

226 try: 

227 response = await self._http_client.post( 

228 f"{gateway.url}/rpc", 

229 json=request, 

230 headers=self._get_auth_headers(), 

231 ) 

232 response.raise_for_status() 

233 result = response.json() 

234 

235 # Update last seen 

236 gateway.last_seen = datetime.now(timezone.utc) 

237 

238 # Handle response 

239 if "error" in result: 239 ↛ 240line 239 didn't jump to line 240 because the condition on line 239 was never true

240 raise ForwardingError(f"Gateway error: {result['error'].get('message')}") 

241 return result.get("result") 

242 

243 except httpx.TimeoutException: 

244 if attempt == settings.max_tool_retries - 1: 

245 raise 

246 await asyncio.sleep(1 * (attempt + 1)) 

247 

248 except Exception as e: 

249 raise ForwardingError(f"Failed to forward to {gateway.name}: {str(e)}") 

250 

251 async def _forward_to_all(self, db: Session, method: str, params: Optional[Dict[str, Any]] = None) -> List[Any]: 

252 """Forward request to all active gateways. 

253 

254 Args: 

255 db: Database session 

256 method: RPC method name 

257 params: Optional method parameters 

258 

259 Returns: 

260 List of responses 

261 

262 Raises: 

263 ForwardingError: If all forwards fail 

264 """ 

265 # Get active gateways 

266 gateways = db.execute(select(DbGateway).where(DbGateway.enabled)).scalars().all() 

267 

268 # Forward to each gateway 

269 results = [] 

270 errors = [] 

271 

272 for gateway in gateways: 

273 try: 

274 result = await self._forward_to_gateway(db, gateway.id, method, params) 

275 results.append(result) 

276 except Exception as e: 

277 errors.append(str(e)) 

278 

279 if not results and errors: 279 ↛ 280line 279 didn't jump to line 280 because the condition on line 279 was never true

280 raise ForwardingError(f"All forwards failed: {'; '.join(errors)}") 

281 

282 return results 

283 

284 async def _find_resource_gateway(self, db: Session, uri: str) -> Optional[DbGateway]: 

285 """Find gateway hosting a resource. 

286 

287 Args: 

288 db: Database session 

289 uri: Resource URI 

290 

291 Returns: 

292 Gateway record or None 

293 """ 

294 # Get active gateways 

295 gateways = db.execute(select(DbGateway).where(DbGateway.enabled)).scalars().all() 

296 

297 # Check each gateway 

298 for gateway in gateways: 

299 try: 

300 resources = await self._forward_to_gateway(db, gateway.id, "resources/list") 

301 for resource in resources: 

302 if resource.get("uri") == uri: 

303 return gateway 

304 except Exception as e: 

305 logger.error(f"Failed to check gateway {gateway.name} for resource {uri}: {str(e)}") 

306 continue 

307 

308 return None 

309 

310 def _check_rate_limit(self, gateway_url: str) -> bool: 

311 """Check if gateway request is within rate limits. 

312 

313 Args: 

314 gateway_url: Gateway URL 

315 

316 Returns: 

317 True if request allowed 

318 """ 

319 now = datetime.now(timezone.utc) 

320 

321 # Clean old history 

322 self._request_history[gateway_url] = [t for t in self._request_history.get(gateway_url, []) if (now - t).total_seconds() < 60] 

323 

324 # Check limit 

325 if len(self._request_history[gateway_url]) >= settings.tool_rate_limit: 

326 return False 

327 

328 # Record request 

329 self._request_history[gateway_url].append(now) 

330 return True 

331 

332 def _get_auth_headers(self) -> Dict[str, str]: 

333 """ 

334 Get headers for gateway authentication. 

335 

336 Returns: 

337 dict: Authorization header dict 

338 """ 

339 api_key = f"{settings.basic_auth_user}:{settings.basic_auth_password}" 

340 return {"Authorization": f"Basic {api_key}", "X-API-Key": api_key}