Coverage for mcpgateway/federation/forward.py: 53%
122 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-09 11:03 +0100
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-09 11:03 +0100
1# -*- coding: utf-8 -*-
2"""Federation Request Forwarding.
4Copyright 2025
5SPDX-License-Identifier: Apache-2.0
6Authors: Mihai Criveti
8This module implements request forwarding for federated MCP Gateways.
9It handles:
10- Request routing to appropriate gateways
11- Response aggregation
12- Error handling and retry logic
13- Request/response transformation
14"""
16# Standard
17import asyncio
18from datetime import datetime, timezone
19import logging
20from typing import Any, Dict, List, Optional, Set, Tuple, Union
22# Third-Party
23import httpx
24from sqlalchemy import select
25from sqlalchemy.orm import Session
27# First-Party
28from mcpgateway.config import settings
29from mcpgateway.db import Gateway as DbGateway
30from mcpgateway.db import Tool as DbTool
31from mcpgateway.models import ToolResult
logger = logging.getLogger(name=__name__)


class ForwardingError(Exception):
    """Signal that a request could not be forwarded to a federated gateway.

    This is the single exception type the forwarding service raises for
    lookup, rate-limit, transport, and gateway-reported failures.
    """
class ForwardingService:
    """Service for handling request forwarding across gateways.

    Handles:
    - Request routing (to one gateway or broadcast to all enabled gateways)
    - Response aggregation
    - Error handling and retry on timeouts
    - Request transformation into JSON-RPC 2.0 envelopes
    """

    def __init__(self):
        """Initialize forwarding service.

        A single ``httpx.AsyncClient`` is created here and reused for every
        forwarded request. Its timeout and TLS verification come from
        ``settings``. The client is closed in :meth:`stop`.
        """
        self._http_client = httpx.AsyncClient(timeout=settings.federation_timeout, verify=not settings.skip_ssl_verify)

        # Track active requests (request id -> task); cancelled by stop()
        self._active_requests: Dict[str, asyncio.Task] = {}

        # Per-gateway-URL request timestamps, used by _check_rate_limit's 60-second window
        self._request_history: Dict[str, List[datetime]] = {}

        # Cache gateway information (currently not read by any method of this class)
        self._gateway_tools: Dict[int, Set[str]] = {}

    async def start(self) -> None:
        """Start forwarding service."""
        logger.info("Request forwarding service started")

    async def stop(self) -> None:
        """Stop forwarding service.

        Cancels and awaits every tracked request, then closes the HTTP client.
        The client is closed even if awaiting a task raises, so the connection
        pool is not leaked on shutdown.
        """
        try:
            # Iterate over a snapshot so the dict can be modified while we await.
            for request_id, task in list(self._active_requests.items()):
                logger.info(f"Cancelling request {request_id}")
                task.cancel()
                try:
                    await task
                except asyncio.CancelledError:
                    pass
                except Exception as e:
                    # A task that had already failed re-raises its error when awaited;
                    # that must not abort shutdown before the client is closed.
                    logger.warning(f"Request {request_id} ended with error during shutdown: {str(e)}")
            self._active_requests.clear()
        finally:
            await self._http_client.aclose()
        logger.info("Request forwarding service stopped")

    async def forward_request(
        self,
        db: Session,
        method: str,
        params: Optional[Dict[str, Any]] = None,
        target_gateway_id: Optional[Union[int, str]] = None,
    ) -> Any:
        """Forward a request to gateway(s).

        Args:
            db: Database session
            method: RPC method name
            params: Optional method parameters
            target_gateway_id: Optional specific gateway ID. When None, the request
                is sent to every enabled gateway and a list of results is returned.

        Returns:
            Forwarded response (single gateway) or list of responses (broadcast)

        Raises:
            ForwardingError: If forwarding fails
        """
        try:
            # Compare against None so that a falsy-but-valid id such as 0 is still targeted.
            if target_gateway_id is not None:
                return await self._forward_to_gateway(db, target_gateway_id, method, params)

            return await self._forward_to_all(db, method, params)

        except Exception as e:
            raise ForwardingError(f"Forward request failed: {str(e)}") from e

    async def forward_tool_request(self, db: Session, tool_name: str, arguments: Dict[str, Any]) -> ToolResult:
        """Forward a tool invocation request.

        Looks up the enabled tool by name and sends ``tools/invoke`` to the
        gateway that owns it.

        Args:
            db: Database session
            tool_name: Tool to invoke
            arguments: Tool arguments

        Returns:
            Tool result built from the gateway's ``content`` and ``is_error`` fields

        Raises:
            ForwardingError: If the tool is unknown, not federated, the gateway
                response is not an object, or forwarding fails
        """
        try:
            tool = db.execute(select(DbTool).where(DbTool.name == tool_name).where(DbTool.enabled)).scalar_one_or_none()

            if not tool:
                raise ForwardingError(f"Tool not found: {tool_name}")

            if not tool.gateway_id:
                raise ForwardingError(f"Tool {tool_name} is not federated")

            result = await self._forward_to_gateway(
                db,
                tool.gateway_id,
                "tools/invoke",
                {"name": tool_name, "arguments": arguments},
            )

            # Fail clearly on a missing or non-object result instead of an AttributeError.
            if not isinstance(result, dict):
                raise ForwardingError("Invalid tool response format")

            return ToolResult(
                content=result.get("content", []),
                is_error=result.get("is_error", False),
            )

        except Exception as e:
            raise ForwardingError(f"Failed to forward tool request: {str(e)}") from e

    async def forward_resource_request(self, db: Session, uri: str) -> Tuple[Union[str, bytes], str]:
        """Forward a resource read request.

        Args:
            db: Database session
            uri: Resource URI

        Returns:
            Tuple of (content, mime_type). The MIME type is read from ``mime_type``
            or ``mimeType``. It defaults to ``text/plain`` for text payloads and
            ``application/octet-stream`` for blob payloads.

        Raises:
            ForwardingError: If no gateway hosts the resource, the response has
                neither ``text`` nor ``blob``, or forwarding fails
        """
        try:
            gateway = await self._find_resource_gateway(db, uri)
            if not gateway:
                raise ForwardingError(f"No gateway found for resource: {uri}")

            result = await self._forward_to_gateway(db, gateway.id, "resources/read", {"uri": uri})

            # Only a dict can carry text/blob; a string result would otherwise turn
            # `"text" in result` into a substring test.
            if isinstance(result, dict):
                if "text" in result:
                    return result["text"], self._resource_mime_type(result, "text/plain")
                if "blob" in result:
                    return result["blob"], self._resource_mime_type(result, "application/octet-stream")

            raise ForwardingError("Invalid resource response format")

        except Exception as e:
            raise ForwardingError(f"Failed to forward resource request: {str(e)}") from e

    @staticmethod
    def _resource_mime_type(result: Dict[str, Any], default: str) -> str:
        """Return the MIME type from a resource payload, accepting snake_case or camelCase keys.

        Args:
            result: Resource payload returned by a gateway
            default: Value to use when neither key is present

        Returns:
            The ``mime_type`` value if present, else ``mimeType``, else ``default``
        """
        for key in ("mime_type", "mimeType"):
            if key in result:
                return result[key]
        return default

    async def _forward_to_gateway(
        self,
        db: Session,
        gateway_id: Union[int, str],
        method: str,
        params: Optional[Dict[str, Any]] = None,
    ) -> Any:
        """Forward request to a specific gateway.

        Sends a JSON-RPC 2.0 request to ``{gateway.url}/rpc``. Timeouts are
        retried up to ``settings.max_tool_retries`` attempts, with at least one
        attempt always made. The backoff is linear: 1s, 2s, ...

        Args:
            db: Database session
            gateway_id: Gateway to forward to
            method: RPC method name
            params: Optional method parameters (omitted from the request when empty)

        Returns:
            The ``result`` member of the gateway's JSON-RPC response

        Raises:
            ForwardingError: If the gateway is missing or disabled, the rate limit
                is exceeded, the response is malformed or carries a JSON-RPC error,
                or the HTTP call fails. A timeout on the final attempt is also
                reported as ForwardingError.
        """
        gateway = db.get(DbGateway, gateway_id)
        if not gateway or not gateway.enabled:
            raise ForwardingError(f"Gateway not found: {gateway_id}")

        if not self._check_rate_limit(gateway.url):
            raise ForwardingError("Rate limit exceeded")

        request: Dict[str, Any] = {"jsonrpc": "2.0", "id": 1, "method": method}
        if params:
            request["params"] = params

        # A non-positive setting would skip the loop entirely and silently return None.
        attempts = max(1, settings.max_tool_retries)

        try:
            for attempt in range(attempts):
                try:
                    response = await self._http_client.post(
                        f"{gateway.url}/rpc",
                        json=request,
                        headers=self._get_auth_headers(),
                    )
                    response.raise_for_status()
                    payload = response.json()
                except httpx.TimeoutException:
                    if attempt == attempts - 1:
                        raise
                    await asyncio.sleep(1 * (attempt + 1))
                    continue

                # The gateway answered; record when it was last reachable.
                gateway.last_seen = datetime.now(timezone.utc)

                if not isinstance(payload, dict):
                    raise ForwardingError("Invalid JSON-RPC response")

                if "error" in payload:
                    error = payload["error"]
                    # JSON-RPC errors should be objects, but tolerate a bare value.
                    message = error.get("message") if isinstance(error, dict) else error
                    raise ForwardingError(f"Gateway error: {message}")

                return payload.get("result")

        except Exception as e:
            raise ForwardingError(f"Failed to forward to {gateway.name}: {str(e)}") from e

    async def _forward_to_all(self, db: Session, method: str, params: Optional[Dict[str, Any]] = None) -> List[Any]:
        """Forward request to all active gateways.

        Gateways are contacted one after another. A failing gateway is logged
        and skipped; the call fails only when no gateway succeeds.

        Args:
            db: Database session
            method: RPC method name
            params: Optional method parameters

        Returns:
            List of responses from the gateways that succeeded (empty if there are
            no enabled gateways)

        Raises:
            ForwardingError: If all forwards fail
        """
        gateways = db.execute(select(DbGateway).where(DbGateway.enabled)).scalars().all()

        results: List[Any] = []
        errors: List[str] = []

        for gateway in gateways:
            try:
                results.append(await self._forward_to_gateway(db, gateway.id, method, params))
            except Exception as e:
                # Partial failures would otherwise vanish when another gateway succeeds.
                logger.warning(f"Forwarding {method} to gateway {gateway.name} failed: {str(e)}")
                errors.append(str(e))

        if not results and errors:
            raise ForwardingError(f"All forwards failed: {'; '.join(errors)}")

        return results

    async def _find_resource_gateway(self, db: Session, uri: str) -> Optional[DbGateway]:
        """Find gateway hosting a resource.

        Queries ``resources/list`` on each enabled gateway in turn and returns
        the first one that lists a resource whose ``uri`` matches. Gateways that
        fail are logged and skipped.

        Args:
            db: Database session
            uri: Resource URI

        Returns:
            Gateway record or None
        """
        gateways = db.execute(select(DbGateway).where(DbGateway.enabled)).scalars().all()

        for gateway in gateways:
            try:
                listing = await self._forward_to_gateway(db, gateway.id, "resources/list")
                # Accept an MCP-style {"resources": [...]} envelope as well as a bare list.
                if isinstance(listing, dict):
                    listing = listing.get("resources")
                for resource in listing or []:
                    if isinstance(resource, dict) and resource.get("uri") == uri:
                        return gateway
            except Exception as e:
                logger.error(f"Failed to check gateway {gateway.name} for resource {uri}: {str(e)}")

        return None

    def _check_rate_limit(self, gateway_url: str) -> bool:
        """Check if gateway request is within rate limits.

        Uses a sliding 60-second window per gateway URL. Timestamps older than
        60 seconds are discarded. The request is allowed only if fewer than
        ``settings.tool_rate_limit`` requests remain in the window. Allowed
        requests are recorded; rejected ones are not.

        Args:
            gateway_url: Gateway URL

        Returns:
            True if request allowed
        """
        now = datetime.now(timezone.utc)

        recent = [t for t in self._request_history.get(gateway_url, []) if (now - t).total_seconds() < 60]
        self._request_history[gateway_url] = recent

        if len(recent) >= settings.tool_rate_limit:
            return False

        recent.append(now)
        return True

    def _get_auth_headers(self) -> Dict[str, str]:
        """
        Get headers for gateway authentication.

        ``Authorization`` uses the HTTP Basic scheme, which requires the
        ``user:password`` pair to be base64-encoded (RFC 7617). ``X-API-Key``
        carries the raw ``user:password`` string.

        Returns:
            dict: Authorization header dict
        """
        import base64

        api_key = f"{settings.basic_auth_user}:{settings.basic_auth_password}"
        encoded = base64.b64encode(api_key.encode("utf-8")).decode("ascii")
        return {"Authorization": f"Basic {encoded}", "X-API-Key": api_key}