Coverage for mcpgateway/services/gateway_service.py: 44%
420 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-09 11:03 +0100
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-09 11:03 +0100
1# -*- coding: utf-8 -*-
2"""Gateway Service Implementation.
4Copyright 2025
5SPDX-License-Identifier: Apache-2.0
6Authors: Mihai Criveti
8This module implements gateway federation according to the MCP specification.
9It handles:
10- Gateway discovery and registration
11- Request forwarding
12- Capability aggregation
13- Health monitoring
14- Active/inactive gateway management
15"""
17# Standard
18import asyncio
19from datetime import datetime, timezone
20import logging
21import os
22import tempfile
23from typing import Any, AsyncGenerator, Dict, List, Optional, Set
24import uuid
26# Third-Party
27from filelock import FileLock, Timeout
28import httpx
29from mcp import ClientSession
30from mcp.client.sse import sse_client
31from mcp.client.streamable_http import streamablehttp_client
32from sqlalchemy import select
33from sqlalchemy.orm import Session
35# First-Party
36from mcpgateway.config import settings
37from mcpgateway.db import Gateway as DbGateway
38from mcpgateway.db import SessionLocal
39from mcpgateway.db import Tool as DbTool
40from mcpgateway.schemas import GatewayCreate, GatewayRead, GatewayUpdate, ToolCreate
41from mcpgateway.services.tool_service import ToolService
42from mcpgateway.utils.create_slug import slugify
43from mcpgateway.utils.services_auth import decode_auth
45try:
46 # Third-Party
47 import redis
49 REDIS_AVAILABLE = True
50except ImportError:
51 REDIS_AVAILABLE = False
52 logging.info("Redis is not utilized in this environment.")
# Module logger. (To quieten httpx's own request logging during health checks,
# raise logging.getLogger("httpx") to WARNING.)
logger = logging.getLogger(__name__)

# Health-check tuning, read from settings once at import time.
# GW_FAILURE_THRESHOLD == -1 disables the failure action in _handle_gateway_failure.
GW_FAILURE_THRESHOLD = settings.unhealthy_threshold
GW_HEALTH_CHECK_INTERVAL = settings.health_check_interval
class GatewayError(Exception):
    """Root of the gateway-service exception hierarchy; catch it to handle any gateway failure."""
class GatewayNotFoundError(GatewayError):
    """Signals that no gateway matches the requested ID, or that the match exists but is inactive."""
class GatewayNameConflictError(GatewayError):
    """Signals that a gateway name is already taken by another (active or inactive) gateway."""

    def __init__(self, name: str, enabled: bool = True, gateway_id: Optional[int] = None):
        """Record the clashing gateway and build a human-readable message.

        Args:
            name: The gateway name that is already in use
            enabled: Whether the gateway currently holding the name is enabled
            gateway_id: ID of the gateway currently holding the name, if known
        """
        self.name = name
        self.enabled = enabled
        self.gateway_id = gateway_id
        # Point the user at the inactive record so they can re-enable it instead of re-registering.
        suffix = "" if enabled else f" (currently inactive, ID: {gateway_id})"
        super().__init__(f"Gateway already exists with name: {name}{suffix}")
class GatewayConnectionError(GatewayError):
    """Signals a failure to reach, initialize, or forward a request to a remote gateway."""
class GatewayService:
    """Service for managing federated gateways.

    Responsibilities:
    - Gateway registration and health checks
    - Request forwarding
    - Capability negotiation
    - Federation events
    - Active/inactive status management
    """

    def __init__(self) -> None:
        """Create the HTTP client, in-memory tracking state and the leader-election backend.

        Which worker runs health checks is decided by:
        - Redis, when ``cache_type`` is ``"redis"``, ``redis_url`` is set and the
          ``redis`` package imported successfully;
        - nothing at all when ``cache_type`` is ``"none"`` (single-worker mode);
        - otherwise a file lock placed under the system temp directory.
        """
        self._event_subscribers: List[asyncio.Queue] = []
        self._http_client = httpx.AsyncClient(
            timeout=settings.federation_timeout,
            verify=not settings.skip_ssl_verify,
        )
        self._health_check_interval = GW_HEALTH_CHECK_INTERVAL
        self._health_check_task: Optional[asyncio.Task] = None
        self._active_gateways: Set[str] = set()  # URLs of gateways currently considered active
        self._stream_response = None
        self._pending_responses = {}
        self.tool_service = ToolService()
        self._gateway_failure_counts: dict[str, int] = {}

        # Leader election for health checks.
        self.redis_url = settings.redis_url if settings.cache_type == "redis" else None
        self._redis_client = None

        if self.redis_url and REDIS_AVAILABLE:
            self._redis_client = redis.from_url(self.redis_url)
            self._instance_id = str(uuid.uuid4())  # distinguishes this process as leader
            self._leader_key = "gateway_service_leader"
            self._leader_ttl = 40  # seconds
            return

        if settings.cache_type == "none":
            # Single worker: no lock needed.
            return

        # Fallback: file-based lock. An absolute path from settings is made relative to its
        # drive root so the lock file always ends up inside the temp directory.
        lock_name = os.path.normpath(settings.filelock_name)
        if os.path.isabs(lock_name):
            drive_root = os.path.splitdrive(lock_name)[0] + os.sep
            lock_name = os.path.relpath(lock_name, start=drive_root)
        self._lock_path = os.path.join(tempfile.gettempdir(), lock_name).replace("\\", "/")
        self._file_lock = FileLock(self._lock_path)
139 async def initialize(self) -> None:
140 """Initialize the service and start health check if this instance is the leader.
142 Raises:
143 ConnectionError: When redis ping fails
144 """
145 logger.info("Initializing gateway service")
147 if self._redis_client:
148 # Check if Redis is available
149 pong = self._redis_client.ping()
150 if not pong:
151 raise ConnectionError("Redis ping failed.")
153 is_leader = self._redis_client.set(self._leader_key, self._instance_id, ex=self._leader_ttl, nx=True)
154 if is_leader:
155 logger.info("Acquired Redis leadership. Starting health check task.")
156 self._health_check_task = asyncio.create_task(self._run_health_checks())
157 else:
158 # Always create the health check task in filelock mode; leader check is handled inside.
159 self._health_check_task = asyncio.create_task(self._run_health_checks())
161 async def shutdown(self) -> None:
162 """Shutdown the service."""
163 if self._health_check_task:
164 self._health_check_task.cancel()
165 try:
166 await self._health_check_task
167 except asyncio.CancelledError:
168 pass
170 await self._http_client.aclose()
171 self._event_subscribers.clear()
172 self._active_gateways.clear()
173 logger.info("Gateway service shutdown complete")
175 async def register_gateway(self, db: Session, gateway: GatewayCreate) -> GatewayRead:
176 """Register a new gateway.
178 Args:
179 db: Database session
180 gateway: Gateway creation schema
182 Returns:
183 Created gateway information
185 Raises:
186 GatewayNameConflictError: If gateway name already exists
187 []: When ExceptionGroup found
188 """
189 try:
190 # Check for name conflicts (both active and inactive)
191 existing_gateway = db.execute(select(DbGateway).where(DbGateway.name == gateway.name)).scalar_one_or_none()
193 if existing_gateway:
194 raise GatewayNameConflictError(
195 gateway.name,
196 enabled=existing_gateway.enabled,
197 gateway_id=existing_gateway.id,
198 )
200 auth_type = getattr(gateway, "auth_type", None)
201 auth_value = getattr(gateway, "auth_value", {})
203 capabilities, tools = await self._initialize_gateway(gateway.url, auth_value, gateway.transport)
205 tools = [
206 DbTool(
207 original_name=tool.name,
208 original_name_slug=slugify(tool.name),
209 url=gateway.url,
210 description=tool.description,
211 integration_type=tool.integration_type,
212 request_type=tool.request_type,
213 headers=tool.headers,
214 input_schema=tool.input_schema,
215 annotations=tool.annotations,
216 jsonpath_filter=tool.jsonpath_filter,
217 auth_type=auth_type,
218 auth_value=auth_value,
219 )
220 for tool in tools
221 ]
223 # Create DB model
224 db_gateway = DbGateway(
225 name=gateway.name,
226 slug=slugify(gateway.name),
227 url=gateway.url,
228 description=gateway.description,
229 transport=gateway.transport,
230 capabilities=capabilities,
231 last_seen=datetime.now(timezone.utc),
232 auth_type=auth_type,
233 auth_value=auth_value,
234 tools=tools,
235 )
237 # Add to DB
238 db.add(db_gateway)
239 db.commit()
240 db.refresh(db_gateway)
242 # Update tracking
243 self._active_gateways.add(db_gateway.url)
245 # Notify subscribers
246 await self._notify_gateway_added(db_gateway)
248 return GatewayRead.model_validate(gateway)
249 except* GatewayConnectionError as ge:
250 logger.error("GatewayConnectionError in group: %s", ge.exceptions)
251 raise ge.exceptions[0]
252 except* ValueError as ve:
253 logger.error("ValueErrors in group: %s", ve.exceptions)
254 raise ve.exceptions[0]
255 except* RuntimeError as re:
256 logger.error("RuntimeErrors in group: %s", re.exceptions)
257 raise re.exceptions[0]
258 except* BaseException as other: # catches every other sub-exception
259 logger.error("Other grouped errors: %s", other.exceptions)
260 raise other.exceptions[0]
262 async def list_gateways(self, db: Session, include_inactive: bool = False) -> List[GatewayRead]:
263 """List all registered gateways.
265 Args:
266 db: Database session
267 include_inactive: Whether to include inactive gateways
269 Returns:
270 List of registered gateways
271 """
272 query = select(DbGateway)
274 if not include_inactive: 274 ↛ 277line 274 didn't jump to line 277 because the condition on line 274 was always true
275 query = query.where(DbGateway.enabled)
277 gateways = db.execute(query).scalars().all()
278 return [GatewayRead.model_validate(g) for g in gateways]
280 async def update_gateway(self, db: Session, gateway_id: str, gateway_update: GatewayUpdate) -> GatewayRead:
281 """Update a gateway.
283 Args:
284 db: Database session
285 gateway_id: Gateway ID to update
286 gateway_update: Updated gateway data
288 Returns:
289 Updated gateway information
291 Raises:
292 GatewayNotFoundError: If gateway not found
293 GatewayError: For other update errors
294 GatewayNameConflictError: If gateway name conflict occurs
295 """
296 try:
297 # Find gateway
298 gateway = db.get(DbGateway, gateway_id)
299 if not gateway:
300 raise GatewayNotFoundError(f"Gateway not found: {gateway_id}")
302 if not gateway.enabled: 302 ↛ 303line 302 didn't jump to line 303 because the condition on line 302 was never true
303 raise GatewayNotFoundError(f"Gateway '{gateway.name}' exists but is inactive")
305 # Check for name conflicts if name is being changed
306 if gateway_update.name is not None and gateway_update.name != gateway.name: 306 ↛ 317line 306 didn't jump to line 317 because the condition on line 306 was always true
307 existing_gateway = db.execute(select(DbGateway).where(DbGateway.name == gateway_update.name).where(DbGateway.id != gateway_id)).scalar_one_or_none()
309 if existing_gateway:
310 raise GatewayNameConflictError(
311 gateway_update.name,
312 enabled=existing_gateway.enabled,
313 gateway_id=existing_gateway.id,
314 )
316 # Update fields if provided
317 if gateway_update.name is not None: 317 ↛ 320line 317 didn't jump to line 320 because the condition on line 317 was always true
318 gateway.name = gateway_update.name
319 gateway.slug = slugify(gateway_update.name)
320 if gateway_update.url is not None: 320 ↛ 322line 320 didn't jump to line 322 because the condition on line 320 was always true
321 gateway.url = gateway_update.url
322 if gateway_update.description is not None: 322 ↛ 324line 322 didn't jump to line 324 because the condition on line 322 was always true
323 gateway.description = gateway_update.description
324 if gateway_update.transport is not None: 324 ↛ 327line 324 didn't jump to line 327 because the condition on line 324 was always true
325 gateway.transport = gateway_update.transport
327 if getattr(gateway, "auth_type", None) is not None: 327 ↛ 335line 327 didn't jump to line 335 because the condition on line 327 was always true
328 gateway.auth_type = gateway_update.auth_type
330 # if auth_type is not None and only then check auth_value
331 if getattr(gateway, "auth_value", {}) != {}: 331 ↛ 332line 331 didn't jump to line 332 because the condition on line 331 was never true
332 gateway.auth_value = gateway_update.auth_value
334 # Try to reinitialize connection if URL changed
335 if gateway_update.url is not None: 335 ↛ 369line 335 didn't jump to line 369 because the condition on line 335 was always true
336 try:
337 capabilities, tools = await self._initialize_gateway(gateway.url, gateway.auth_value, gateway.transport)
338 new_tool_names = [tool.name for tool in tools]
340 for tool in tools: 340 ↛ 341line 340 didn't jump to line 341 because the loop on line 340 never started
341 existing_tool = db.execute(select(DbTool).where(DbTool.original_name == tool.name).where(DbTool.gateway_id == gateway_id)).scalar_one_or_none()
342 if not existing_tool:
343 gateway.tools.append(
344 DbTool(
345 original_name=tool.name,
346 original_name_slug=slugify(tool.name),
347 url=gateway.url,
348 description=tool.description,
349 integration_type=tool.integration_type,
350 request_type=tool.request_type,
351 headers=tool.headers,
352 input_schema=tool.input_schema,
353 jsonpath_filter=tool.jsonpath_filter,
354 auth_type=gateway.auth_type,
355 auth_value=gateway.auth_value,
356 )
357 )
359 gateway.capabilities = capabilities
360 gateway.tools = [tool for tool in gateway.tools if tool.original_name in new_tool_names] # keep only still-valid rows
361 gateway.last_seen = datetime.now(timezone.utc)
363 # Update tracking with new URL
364 self._active_gateways.discard(gateway.url)
365 self._active_gateways.add(gateway.url)
366 except Exception as e:
367 logger.warning(f"Failed to initialize updated gateway: {e}")
369 gateway.updated_at = datetime.now(timezone.utc)
370 db.commit()
371 db.refresh(gateway)
373 # Notify subscribers
374 await self._notify_gateway_updated(gateway)
376 logger.info(f"Updated gateway: {gateway.name}")
377 return GatewayRead.model_validate(gateway)
379 except Exception as e:
380 db.rollback()
381 raise GatewayError(f"Failed to update gateway: {str(e)}")
383 async def get_gateway(self, db: Session, gateway_id: str, include_inactive: bool = False) -> GatewayRead:
384 """Get a specific gateway by ID.
386 Args:
387 db: Database session
388 gateway_id: Gateway ID
389 include_inactive: Whether to include inactive gateways
391 Returns:
392 Gateway information
394 Raises:
395 GatewayNotFoundError: If gateway not found
396 """
397 gateway = db.get(DbGateway, gateway_id)
398 if not gateway:
399 raise GatewayNotFoundError(f"Gateway not found: {gateway_id}")
401 if not gateway.enabled and not include_inactive:
402 raise GatewayNotFoundError(f"Gateway '{gateway.name}' exists but is inactive")
404 return GatewayRead.model_validate(gateway)
406 async def toggle_gateway_status(self, db: Session, gateway_id: str, activate: bool, reachable: bool = True, only_update_reachable: bool = False) -> GatewayRead:
407 """Toggle gateway active status.
409 Args:
410 db: Database session
411 gateway_id: Gateway ID to toggle
412 activate: True to activate, False to deactivate
413 reachable: True if the gateway is reachable, False otherwise
414 only_update_reachable: If True, only updates reachable status without changing enabled status. Applicable for changing tool status. If the tool is manually deactivated, it will not be reactivated if reachable.
416 Returns:
417 Updated gateway information
419 Raises:
420 GatewayNotFoundError: If gateway not found
421 GatewayError: For other errors
422 """
423 try:
424 gateway = db.get(DbGateway, gateway_id)
425 if not gateway: 425 ↛ 426line 425 didn't jump to line 426 because the condition on line 425 was never true
426 raise GatewayNotFoundError(f"Gateway not found: {gateway_id}")
428 # Update status if it's different
429 if (gateway.enabled != activate) or (gateway.reachable != reachable): 429 ↛ 489line 429 didn't jump to line 489 because the condition on line 429 was always true
430 gateway.enabled = activate
431 gateway.reachable = reachable
432 gateway.updated_at = datetime.now(timezone.utc)
434 # Update tracking
435 if activate and reachable: 435 ↛ 436line 435 didn't jump to line 436 because the condition on line 435 was never true
436 self._active_gateways.add(gateway.url)
437 # Try to initialize if activating
438 try:
439 capabilities, tools = await self._initialize_gateway(gateway.url, gateway.auth_value, gateway.transport)
440 new_tool_names = [tool.name for tool in tools]
442 for tool in tools:
443 existing_tool = db.execute(select(DbTool).where(DbTool.original_name == tool.name).where(DbTool.gateway_id == gateway_id)).scalar_one_or_none()
444 if not existing_tool:
445 gateway.tools.append(
446 DbTool(
447 original_name=tool.name,
448 original_name_slug=slugify(tool.name),
449 url=gateway.url,
450 description=tool.description,
451 integration_type=tool.integration_type,
452 request_type=tool.request_type,
453 headers=tool.headers,
454 input_schema=tool.input_schema,
455 jsonpath_filter=tool.jsonpath_filter,
456 auth_type=gateway.auth_type,
457 auth_value=gateway.auth_value,
458 )
459 )
461 gateway.capabilities = capabilities
462 gateway.tools = [tool for tool in gateway.tools if tool.original_name in new_tool_names] # keep only still-valid rows
463 gateway.last_seen = datetime.now(timezone.utc)
464 except Exception as e:
465 logger.warning(f"Failed to initialize reactivated gateway: {e}")
466 else:
467 self._active_gateways.discard(gateway.url)
469 db.commit()
470 db.refresh(gateway)
472 tools = db.query(DbTool).filter(DbTool.gateway_id == gateway_id).all()
474 if only_update_reachable: 474 ↛ 475line 474 didn't jump to line 475 because the condition on line 474 was never true
475 for tool in tools:
476 await self.tool_service.toggle_tool_status(db, tool.id, tool.enabled, reachable)
477 else:
478 for tool in tools:
479 await self.tool_service.toggle_tool_status(db, tool.id, activate, reachable)
481 # Notify subscribers
482 if activate: 482 ↛ 483line 482 didn't jump to line 483 because the condition on line 482 was never true
483 await self._notify_gateway_activated(gateway)
484 else:
485 await self._notify_gateway_deactivated(gateway)
487 logger.info(f"Gateway status: {gateway.name} - {'enabled' if activate else 'disabled'} and {'accessible' if reachable else 'inaccessible'}")
489 return GatewayRead.model_validate(gateway)
491 except Exception as e:
492 db.rollback()
493 raise GatewayError(f"Failed to toggle gateway status: {str(e)}")
495 async def _notify_gateway_updated(self, gateway: DbGateway) -> None:
496 """
497 Notify subscribers of gateway update.
499 Args:
500 gateway: Gateway to update
501 """
502 event = {
503 "type": "gateway_updated",
504 "data": {
505 "id": gateway.id,
506 "name": gateway.name,
507 "url": gateway.url,
508 "description": gateway.description,
509 "enabled": gateway.enabled,
510 },
511 "timestamp": datetime.now(timezone.utc).isoformat(),
512 }
513 await self._publish_event(event)
515 async def delete_gateway(self, db: Session, gateway_id: str) -> None:
516 """Permanently delete a gateway.
518 Args:
519 db: Database session
520 gateway_id: Gateway ID to delete
522 Raises:
523 GatewayNotFoundError: If gateway not found
524 GatewayError: For other deletion errors
525 """
526 try:
527 # Find gateway
528 gateway = db.get(DbGateway, gateway_id)
529 if not gateway:
530 raise GatewayNotFoundError(f"Gateway not found: {gateway_id}")
532 # Store gateway info for notification before deletion
533 gateway_info = {"id": gateway.id, "name": gateway.name, "url": gateway.url}
535 # Hard delete gateway
536 db.delete(gateway)
537 db.commit()
539 # Update tracking
540 self._active_gateways.discard(gateway.url)
542 # Notify subscribers
543 await self._notify_gateway_deleted(gateway_info)
545 logger.info(f"Permanently deleted gateway: {gateway.name}")
547 except Exception as e:
548 db.rollback()
549 raise GatewayError(f"Failed to delete gateway: {str(e)}")
551 async def forward_request(self, gateway: DbGateway, method: str, params: Optional[Dict[str, Any]] = None) -> Any:
552 """Forward a request to a gateway.
554 Args:
555 gateway: Gateway to forward to
556 method: RPC method name
557 params: Optional method parameters
559 Returns:
560 Gateway response
562 Raises:
563 GatewayConnectionError: If forwarding fails
564 GatewayError: If gateway gave an error
565 """
566 if not gateway.enabled: 566 ↛ 567line 566 didn't jump to line 567 because the condition on line 566 was never true
567 raise GatewayConnectionError(f"Cannot forward request to inactive gateway: {gateway.name}")
569 try:
570 # Build RPC request
571 request = {"jsonrpc": "2.0", "id": 1, "method": method}
572 if params:
573 request["params"] = params
575 # Directly use the persistent HTTP client (no async with)
576 response = await self._http_client.post(f"{gateway.url}/rpc", json=request, headers=self._get_auth_headers())
577 response.raise_for_status()
578 result = response.json()
580 # Update last seen timestamp
581 gateway.last_seen = datetime.now(timezone.utc)
583 if "error" in result:
584 raise GatewayError(f"Gateway error: {result['error'].get('message')}")
585 return result.get("result")
587 except Exception as e:
588 raise GatewayConnectionError(f"Failed to forward request to {gateway.name}: {str(e)}")
590 async def _handle_gateway_failure(self, gateway: str) -> None:
591 """
592 Tracks and handles gateway failures during health checks.
593 If the failure count exceeds the threshold, the gateway is deactivated.
595 Args:
596 gateway (str): The gateway object that failed its health check.
598 Returns:
599 None
600 """
601 if GW_FAILURE_THRESHOLD == -1:
602 return # Gateway failure action disabled
604 if not gateway.enabled:
605 return # No action needed for inactive gateways
607 if not gateway.reachable:
608 return # No action needed for unreachable gateways
610 count = self._gateway_failure_counts.get(gateway.id, 0) + 1
611 self._gateway_failure_counts[gateway.id] = count
613 logger.warning(f"Gateway {gateway.name} failed health check {count} time(s).")
615 if count >= GW_FAILURE_THRESHOLD:
616 logger.error(f"Gateway {gateway.name} failed {GW_FAILURE_THRESHOLD} times. Deactivating...")
617 with SessionLocal() as db:
618 await self.toggle_gateway_status(db, gateway.id, activate=True, reachable=False, only_update_reachable=True)
619 self._gateway_failure_counts[gateway.id] = 0 # Reset after deactivation
621 async def check_health_of_gateways(self, gateways: List[DbGateway]) -> bool:
622 """Health check for a list of gateways.
624 Deactivates gateway if gateway is not healthy.
626 Args:
627 gateways (List[DbGateway]): List of gateways to check if healthy
629 Returns:
630 bool: True if all active gateways are healthy
631 """
632 # Reuse a single HTTP client for all requests
633 async with httpx.AsyncClient() as client:
634 for gateway in gateways:
635 logger.debug(f"Checking health of gateway: {gateway.name} ({gateway.url})")
636 try:
637 # Ensure auth_value is a dict
638 auth_data = gateway.auth_value or {}
639 headers = decode_auth(auth_data)
641 # Perform the GET and raise on 4xx/5xx
642 if (gateway.transport).lower() == "sse":
643 timeout = httpx.Timeout(settings.health_check_timeout)
644 async with client.stream("GET", gateway.url, headers=headers, timeout=timeout) as response:
645 # This will raise immediately if status is 4xx/5xx
646 response.raise_for_status()
647 elif (gateway.transport).lower() == "streamablehttp":
648 async with streamablehttp_client(url=gateway.url, headers=headers, timeout=settings.health_check_timeout) as (read_stream, write_stream, _get_session_id):
649 async with ClientSession(read_stream, write_stream) as session:
650 # Initialize the session
651 response = await session.initialize()
653 # Reactivate gateway if it was previously inactive and health check passed now
654 if gateway.enabled and not gateway.reachable:
655 with SessionLocal() as db:
656 logger.info(f"Reactivating gateway: {gateway.name}, as it is healthy now")
657 await self.toggle_gateway_status(db, gateway.id, activate=True, reachable=True, only_update_reachable=True)
659 # Mark successful check
660 gateway.last_seen = datetime.now(timezone.utc)
662 except Exception:
663 await self._handle_gateway_failure(gateway)
665 # All gateways passed
666 return True
668 async def aggregate_capabilities(self, db: Session) -> Dict[str, Any]:
669 """Aggregate capabilities from all gateways.
671 Args:
672 db: Database session
674 Returns:
675 Combined capabilities
676 """
677 capabilities = {
678 "prompts": {"listChanged": True},
679 "resources": {"subscribe": True, "listChanged": True},
680 "tools": {"listChanged": True},
681 "logging": {},
682 }
684 # Get all active gateways
685 gateways = db.execute(select(DbGateway).where(DbGateway.enabled)).scalars().all()
687 # Combine capabilities
688 for gateway in gateways:
689 if gateway.capabilities:
690 for key, value in gateway.capabilities.items():
691 if key not in capabilities:
692 capabilities[key] = value
693 elif isinstance(value, dict):
694 capabilities[key].update(value)
696 return capabilities
698 async def subscribe_events(self) -> AsyncGenerator[Dict[str, Any], None]:
699 """Subscribe to gateway events.
701 Yields:
702 Gateway event messages
703 """
704 queue: asyncio.Queue = asyncio.Queue()
705 self._event_subscribers.append(queue)
706 try:
707 while True:
708 event = await queue.get()
709 yield event
710 finally:
711 self._event_subscribers.remove(queue)
713 async def _initialize_gateway(self, url: str, authentication: Optional[Dict[str, str]] = None, transport: str = "SSE") -> Any:
714 """Initialize connection to a gateway and retrieve its capabilities.
716 Args:
717 url: Gateway URL
718 authentication: Optional authentication headers
719 transport: Transport type ("SSE" or "StreamableHTTP")
721 Returns:
722 Capabilities dictionary as provided by the gateway.
724 Raises:
725 GatewayConnectionError: If initialization fails.
726 """
727 try:
728 if authentication is None:
729 authentication = {}
731 async def connect_to_sse_server(server_url: str, authentication: Optional[Dict[str, str]] = None):
732 """
733 Connect to an MCP server running with SSE transport
735 Args:
736 server_url: URL to connect to the server
737 authentication: Authentication headers for connection to URL
739 Returns:
740 list, list: List of capabilities and tools
741 """
742 if authentication is None:
743 authentication = {}
744 # Store the context managers so they stay alive
745 decoded_auth = decode_auth(authentication)
747 # Use async with for both sse_client and ClientSession
748 async with sse_client(url=server_url, headers=decoded_auth) as streams:
749 async with ClientSession(*streams) as session:
750 # Initialize the session
751 response = await session.initialize()
752 capabilities = response.capabilities.model_dump(by_alias=True, exclude_none=True)
754 response = await session.list_tools()
755 tools = response.tools
756 tools = [tool.model_dump(by_alias=True, exclude_none=True) for tool in tools]
757 tools = [ToolCreate.model_validate(tool) for tool in tools]
759 return capabilities, tools
761 async def connect_to_streamablehttp_server(server_url: str, authentication: Optional[Dict[str, str]] = None):
762 """
763 Connect to an MCP server running with Streamable HTTP transport
765 Args:
766 server_url: URL to connect to the server
767 authentication: Authentication headers for connection to URL
769 Returns:
770 list, list: List of capabilities and tools
771 """
772 if authentication is None:
773 authentication = {}
774 # Store the context managers so they stay alive
775 decoded_auth = decode_auth(authentication)
777 # Use async with for both streamablehttp_client and ClientSession
778 async with streamablehttp_client(url=server_url, headers=decoded_auth) as (read_stream, write_stream, _get_session_id):
779 async with ClientSession(read_stream, write_stream) as session:
780 # Initialize the session
781 response = await session.initialize()
782 # if get_session_id:
783 # session_id = get_session_id()
784 # if session_id:
785 # print(f"Session ID: {session_id}")
786 capabilities = response.capabilities.model_dump(by_alias=True, exclude_none=True)
787 response = await session.list_tools()
788 tools = response.tools
789 tools = [tool.model_dump(by_alias=True, exclude_none=True) for tool in tools]
790 tools = [ToolCreate.model_validate(tool) for tool in tools]
791 for tool in tools:
792 tool.request_type = "STREAMABLEHTTP"
794 return capabilities, tools
796 capabilities = {}
797 tools = []
798 if transport.lower() == "sse":
799 capabilities, tools = await connect_to_sse_server(url, authentication)
800 elif transport.lower() == "streamablehttp":
801 capabilities, tools = await connect_to_streamablehttp_server(url, authentication)
803 return capabilities, tools
804 except Exception as e:
805 raise GatewayConnectionError(f"Failed to initialize gateway at {url}: {str(e)}")
807 def _get_gateways(self, include_inactive: bool = True) -> list[DbGateway]:
808 """Sync function for database operations (runs in thread).
810 Args:
811 include_inactive: Whether to include inactive gateways
813 Returns:
814 List[DbGateway]: List of active gateways
815 """
816 with SessionLocal() as db:
817 if include_inactive:
818 return db.execute(select(DbGateway)).scalars().all()
819 # Only return active gateways
820 return db.execute(select(DbGateway).where(DbGateway.enabled)).scalars().all()
822 async def _run_health_checks(self) -> None:
823 """Run health checks periodically,
824 Uses Redis or FileLock - for multiple workers.
825 Uses simple health check for single worker mode."""
827 while True:
828 try:
829 if self._redis_client and settings.cache_type == "redis":
830 # Redis-based leader check
831 current_leader = self._redis_client.get(self._leader_key)
832 if current_leader != self._instance_id.encode():
833 return
834 self._redis_client.expire(self._leader_key, self._leader_ttl)
836 # Run health checks
837 gateways = await asyncio.to_thread(self._get_gateways)
838 if gateways:
839 await self.check_health_of_gateways(gateways)
841 await asyncio.sleep(self._health_check_interval)
843 elif settings.cache_type == "none":
844 try:
845 # For single worker mode, run health checks directly
846 gateways = await asyncio.to_thread(self._get_gateways)
847 if gateways:
848 await self.check_health_of_gateways(gateways)
849 except Exception as e:
850 logger.error(f"Health check run failed: {str(e)}")
852 await asyncio.sleep(self._health_check_interval)
854 else:
855 # FileLock-based leader fallback
856 try:
857 self._file_lock.acquire(timeout=0)
858 logger.info("File lock acquired. Running health checks.")
860 while True:
861 gateways = await asyncio.to_thread(self._get_gateways)
862 if gateways:
863 await self.check_health_of_gateways(gateways)
864 await asyncio.sleep(self._health_check_interval)
866 except Timeout:
867 logger.debug("File lock already held. Retrying later.")
868 await asyncio.sleep(self._health_check_interval)
870 except Exception as e:
871 logger.error(f"FileLock health check failed: {str(e)}")
873 finally:
874 if self._file_lock.is_locked:
875 try:
876 self._file_lock.release()
877 logger.info("Released file lock.")
878 except Exception as e:
879 logger.warning(f"Failed to release file lock: {str(e)}")
881 except Exception as e:
882 logger.error(f"Unexpected error in health check loop: {str(e)}")
883 await asyncio.sleep(self._health_check_interval)
885 def _get_auth_headers(self) -> Dict[str, str]:
886 """
887 Get headers for gateway authentication.
889 Returns:
890 dict: Authorization header dict
891 """
892 api_key = f"{settings.basic_auth_user}:{settings.basic_auth_password}"
893 return {"Authorization": f"Basic {api_key}", "X-API-Key": api_key, "Content-Type": "application/json"}
895 async def _notify_gateway_added(self, gateway: DbGateway) -> None:
896 """
897 Notify subscribers of gateway addition.
899 Args:
900 gateway: Gateway to add
901 """
902 event = {
903 "type": "gateway_added",
904 "data": {
905 "id": gateway.id,
906 "name": gateway.name,
907 "url": gateway.url,
908 "description": gateway.description,
909 "enabled": gateway.enabled,
910 },
911 "timestamp": datetime.now(timezone.utc).isoformat(),
912 }
913 await self._publish_event(event)
915 async def _notify_gateway_activated(self, gateway: DbGateway) -> None:
916 """
917 Notify subscribers of gateway activation.
919 Args:
920 gateway: Gateway to activate
921 """
922 event = {
923 "type": "gateway_activated",
924 "data": {
925 "id": gateway.id,
926 "name": gateway.name,
927 "url": gateway.url,
928 "enabled": gateway.enabled,
929 },
930 "timestamp": datetime.now(timezone.utc).isoformat(),
931 }
932 await self._publish_event(event)
934 async def _notify_gateway_deactivated(self, gateway: DbGateway) -> None:
935 """
936 Notify subscribers of gateway deactivation.
938 Args:
939 gateway: Gateway database object
940 """
941 event = {
942 "type": "gateway_deactivated",
943 "data": {
944 "id": gateway.id,
945 "name": gateway.name,
946 "url": gateway.url,
947 "enabled": gateway.enabled,
948 },
949 "timestamp": datetime.now(timezone.utc).isoformat(),
950 }
951 await self._publish_event(event)
953 async def _notify_gateway_deleted(self, gateway_info: Dict[str, Any]) -> None:
954 """
955 Notify subscribers of gateway deletion.
957 Args:
958 gateway_info: Dict containing information about gateway to delete
959 """
960 event = {
961 "type": "gateway_deleted",
962 "data": gateway_info,
963 "timestamp": datetime.now(timezone.utc).isoformat(),
964 }
965 await self._publish_event(event)
967 async def _notify_gateway_removed(self, gateway: DbGateway) -> None:
968 """
969 Notify subscribers of gateway removal (deactivation).
971 Args:
972 gateway: Gateway to remove
973 """
974 event = {
975 "type": "gateway_removed",
976 "data": {"id": gateway.id, "name": gateway.name, "enabled": gateway.enabled},
977 "timestamp": datetime.now(timezone.utc).isoformat(),
978 }
979 await self._publish_event(event)
981 async def _publish_event(self, event: Dict[str, Any]) -> None:
982 """
983 Publish event to all subscribers.
985 Args:
986 event: event dictionary
987 """
988 for queue in self._event_subscribers:
989 await queue.put(event)