Coverage for mcpgateway/services/gateway_service.py: 44%

420 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-09 11:03 +0100

1# -*- coding: utf-8 -*- 

2"""Gateway Service Implementation. 

3 

4Copyright 2025 

5SPDX-License-Identifier: Apache-2.0 

6Authors: Mihai Criveti 

7 

8This module implements gateway federation according to the MCP specification. 

9It handles: 

10- Gateway discovery and registration 

11- Request forwarding 

12- Capability aggregation 

13- Health monitoring 

14- Active/inactive gateway management 

15""" 

16 

17# Standard 

18import asyncio 

19from datetime import datetime, timezone 

20import logging 

21import os 

22import tempfile 

23from typing import Any, AsyncGenerator, Dict, List, Optional, Set 

24import uuid 

25 

26# Third-Party 

27from filelock import FileLock, Timeout 

28import httpx 

29from mcp import ClientSession 

30from mcp.client.sse import sse_client 

31from mcp.client.streamable_http import streamablehttp_client 

32from sqlalchemy import select 

33from sqlalchemy.orm import Session 

34 

35# First-Party 

36from mcpgateway.config import settings 

37from mcpgateway.db import Gateway as DbGateway 

38from mcpgateway.db import SessionLocal 

39from mcpgateway.db import Tool as DbTool 

40from mcpgateway.schemas import GatewayCreate, GatewayRead, GatewayUpdate, ToolCreate 

41from mcpgateway.services.tool_service import ToolService 

42from mcpgateway.utils.create_slug import slugify 

43from mcpgateway.utils.services_auth import decode_auth 

44 

45try: 

46 # Third-Party 

47 import redis 

48 

49 REDIS_AVAILABLE = True 

50except ImportError: 

51 REDIS_AVAILABLE = False 

52 logging.info("Redis is not utilized in this environment.") 

53 

# Module logger. httpx request logging is deliberately left at its default
# level; raise it to WARNING if the periodic health-check requests get noisy.
logger = logging.getLogger(__name__)

# Health-check tuning, read once from settings at import time:
#   GW_FAILURE_THRESHOLD     - failed checks before a gateway is marked
#                              unreachable (-1 disables failure handling)
#   GW_HEALTH_CHECK_INTERVAL - pause between health-check rounds, in seconds
#                              (it is handed to asyncio.sleep)
GW_FAILURE_THRESHOLD = settings.unhealthy_threshold
GW_HEALTH_CHECK_INTERVAL = settings.health_check_interval

60 

61 

class GatewayError(Exception):
    """Base class for every gateway-related error raised by this module."""


class GatewayNotFoundError(GatewayError):
    """Raised when a requested gateway does not exist, or is inactive where an active one is required."""


class GatewayNameConflictError(GatewayError):
    """Raised when a gateway name conflicts with an existing (active or inactive) gateway.

    Attributes:
        name: The conflicting gateway name.
        enabled: Whether the existing gateway is enabled.
        gateway_id: ID of the existing gateway, if known.
    """

    def __init__(self, name: str, enabled: bool = True, gateway_id: Optional[str] = None):
        """Initialize the error with information about the conflicting gateway.

        Args:
            name: The conflicting gateway name
            enabled: Whether the existing gateway is enabled
            gateway_id: ID of the existing gateway if available (gateway IDs are
                passed around as ``str`` throughout this module)
        """
        self.name = name
        self.enabled = enabled
        self.gateway_id = gateway_id
        message = f"Gateway already exists with name: {name}"
        if not enabled:
            # Point the caller at the inactive record so it can be re-enabled rather than re-created.
            message += f" (currently inactive, ID: {gateway_id})"
        super().__init__(message)


class GatewayConnectionError(GatewayError):
    """Raised when connecting to, initializing, or forwarding a request to a gateway fails."""

94class GatewayService: 

95 """Service for managing federated gateways. 

96 

97 Handles: 

98 - Gateway registration and health checks 

99 - Request forwarding 

100 - Capability negotiation 

101 - Federation events 

102 - Active/inactive status management 

103 """ 

104 

def __init__(self) -> None:
    """Set up in-memory state, the shared HTTP client and the leader-election backend.

    Health-check leader election uses Redis when ``cache_type`` is ``"redis"`` and
    the ``redis`` package imported successfully. Otherwise, unless ``cache_type``
    is ``"none"``, it falls back to a file lock placed under the system temp
    directory.
    """
    # Event fan-out and the persistent outbound federation client.
    self._event_subscribers: List[asyncio.Queue] = []
    self._http_client = httpx.AsyncClient(
        timeout=settings.federation_timeout,
        verify=not settings.skip_ssl_verify,
    )

    # Health-check bookkeeping.
    self._health_check_interval = GW_HEALTH_CHECK_INTERVAL
    self._health_check_task: Optional[asyncio.Task] = None
    self._active_gateways: Set[str] = set()  # URLs of gateways currently tracked as active
    self._stream_response = None
    self._pending_responses = {}
    self.tool_service = ToolService()
    self._gateway_failure_counts: dict[str, int] = {}

    cache_type = settings.cache_type
    self.redis_url = settings.redis_url if cache_type == "redis" else None
    self._redis_client = None

    if self.redis_url and REDIS_AVAILABLE:
        self._redis_client = redis.from_url(self.redis_url)
        self._instance_id = str(uuid.uuid4())  # identifies this process as the leader-key owner
        self._leader_key = "gateway_service_leader"
        self._leader_ttl = 40  # seconds before an un-renewed leadership claim expires
        return

    if cache_type == "none":
        return  # single-worker mode: no leader election needed

    # File-lock fallback: resolve the configured lock name under the temp dir,
    # re-rooting absolute names there instead of using them verbatim.
    lock_name = os.path.normpath(settings.filelock_name)
    if os.path.isabs(lock_name):
        drive_root = os.path.splitdrive(lock_name)[0] + os.sep
        lock_name = os.path.relpath(lock_name, start=drive_root)
    self._lock_path = os.path.join(tempfile.gettempdir(), lock_name).replace("\\", "/")
    self._file_lock = FileLock(self._lock_path)

138 

async def initialize(self) -> None:
    """Start the periodic health-check task.

    With a Redis client, the task is started only if this process wins the
    leader key (``SET`` with ``nx=True`` and a TTL). Without Redis, the task is
    always started and any file-lock election happens inside
    ``_run_health_checks``.

    Raises:
        ConnectionError: If the Redis ping does not return a truthy reply.
    """
    logger.info("Initializing gateway service")

    if not self._redis_client:
        # File-lock or single-worker mode: the task decides leadership itself.
        self._health_check_task = asyncio.create_task(self._run_health_checks())
        return

    if not self._redis_client.ping():
        raise ConnectionError("Redis ping failed.")

    won = self._redis_client.set(self._leader_key, self._instance_id, ex=self._leader_ttl, nx=True)
    if won:
        logger.info("Acquired Redis leadership. Starting health check task.")
        self._health_check_task = asyncio.create_task(self._run_health_checks())

160 

async def shutdown(self) -> None:
    """Stop background work and release the resources held by the service.

    Cancels the health-check task, if one was started, and waits for it to
    finish. Then closes the shared HTTP client and clears the subscriber and
    active-gateway tracking.
    """
    task = self._health_check_task
    if task:
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass  # expected: the task was just cancelled

    await self._http_client.aclose()
    for tracked in (self._event_subscribers, self._active_gateways):
        tracked.clear()
    logger.info("Gateway service shutdown complete")

174 

175 async def register_gateway(self, db: Session, gateway: GatewayCreate) -> GatewayRead: 

176 """Register a new gateway. 

177 

178 Args: 

179 db: Database session 

180 gateway: Gateway creation schema 

181 

182 Returns: 

183 Created gateway information 

184 

185 Raises: 

186 GatewayNameConflictError: If gateway name already exists 

187 []: When ExceptionGroup found 

188 """ 

189 try: 

190 # Check for name conflicts (both active and inactive) 

191 existing_gateway = db.execute(select(DbGateway).where(DbGateway.name == gateway.name)).scalar_one_or_none() 

192 

193 if existing_gateway: 

194 raise GatewayNameConflictError( 

195 gateway.name, 

196 enabled=existing_gateway.enabled, 

197 gateway_id=existing_gateway.id, 

198 ) 

199 

200 auth_type = getattr(gateway, "auth_type", None) 

201 auth_value = getattr(gateway, "auth_value", {}) 

202 

203 capabilities, tools = await self._initialize_gateway(gateway.url, auth_value, gateway.transport) 

204 

205 tools = [ 

206 DbTool( 

207 original_name=tool.name, 

208 original_name_slug=slugify(tool.name), 

209 url=gateway.url, 

210 description=tool.description, 

211 integration_type=tool.integration_type, 

212 request_type=tool.request_type, 

213 headers=tool.headers, 

214 input_schema=tool.input_schema, 

215 annotations=tool.annotations, 

216 jsonpath_filter=tool.jsonpath_filter, 

217 auth_type=auth_type, 

218 auth_value=auth_value, 

219 ) 

220 for tool in tools 

221 ] 

222 

223 # Create DB model 

224 db_gateway = DbGateway( 

225 name=gateway.name, 

226 slug=slugify(gateway.name), 

227 url=gateway.url, 

228 description=gateway.description, 

229 transport=gateway.transport, 

230 capabilities=capabilities, 

231 last_seen=datetime.now(timezone.utc), 

232 auth_type=auth_type, 

233 auth_value=auth_value, 

234 tools=tools, 

235 ) 

236 

237 # Add to DB 

238 db.add(db_gateway) 

239 db.commit() 

240 db.refresh(db_gateway) 

241 

242 # Update tracking 

243 self._active_gateways.add(db_gateway.url) 

244 

245 # Notify subscribers 

246 await self._notify_gateway_added(db_gateway) 

247 

248 return GatewayRead.model_validate(gateway) 

249 except* GatewayConnectionError as ge: 

250 logger.error("GatewayConnectionError in group: %s", ge.exceptions) 

251 raise ge.exceptions[0] 

252 except* ValueError as ve: 

253 logger.error("ValueErrors in group: %s", ve.exceptions) 

254 raise ve.exceptions[0] 

255 except* RuntimeError as re: 

256 logger.error("RuntimeErrors in group: %s", re.exceptions) 

257 raise re.exceptions[0] 

258 except* BaseException as other: # catches every other sub-exception 

259 logger.error("Other grouped errors: %s", other.exceptions) 

260 raise other.exceptions[0] 

261 

async def list_gateways(self, db: Session, include_inactive: bool = False) -> List[GatewayRead]:
    """Return the registered gateways as read models.

    Args:
        db: Database session
        include_inactive: When False (the default), only enabled gateways are returned

    Returns:
        List of registered gateways
    """
    stmt = select(DbGateway) if include_inactive else select(DbGateway).where(DbGateway.enabled)
    rows = db.execute(stmt).scalars().all()
    return list(map(GatewayRead.model_validate, rows))

279 

async def update_gateway(self, db: Session, gateway_id: str, gateway_update: GatewayUpdate) -> GatewayRead:
    """Update a gateway and, if its URL is supplied, re-discover its tools.

    Re-initialization is best-effort. A failure is logged and the field
    updates are still committed.

    Args:
        db: Database session
        gateway_id: Gateway ID to update
        gateway_update: Updated gateway data

    Returns:
        Updated gateway information

    Raises:
        GatewayNotFoundError: If the gateway does not exist or is inactive.
        GatewayNameConflictError: If the new name is already used by another gateway.
        GatewayError: For any other update failure (the session is rolled back).
    """
    try:
        gateway = db.get(DbGateway, gateway_id)
        if not gateway:
            raise GatewayNotFoundError(f"Gateway not found: {gateway_id}")

        if not gateway.enabled:
            raise GatewayNotFoundError(f"Gateway '{gateway.name}' exists but is inactive")

        # Check for name conflicts only if the name is actually changing.
        if gateway_update.name is not None and gateway_update.name != gateway.name:
            existing_gateway = db.execute(select(DbGateway).where(DbGateway.name == gateway_update.name).where(DbGateway.id != gateway_id)).scalar_one_or_none()
            if existing_gateway:
                raise GatewayNameConflictError(
                    gateway_update.name,
                    enabled=existing_gateway.enabled,
                    gateway_id=existing_gateway.id,
                )

        # Remember the URL currently tracked so it can be swapped out below.
        previous_url = gateway.url

        if gateway_update.name is not None:
            gateway.name = gateway_update.name
            gateway.slug = slugify(gateway_update.name)
        if gateway_update.url is not None:
            gateway.url = gateway_update.url
        if gateway_update.description is not None:
            gateway.description = gateway_update.description
        if gateway_update.transport is not None:
            gateway.transport = gateway_update.transport

        # NOTE(review): these guards test the *stored* gateway's auth fields,
        # not the update's. If GatewayUpdate leaves auth_type as None when it is
        # omitted, an update clears the auth_type of a gateway that had one, and
        # auth can never be added to a gateway that had none. Kept unchanged
        # pending confirmation of GatewayUpdate's auth semantics.
        if getattr(gateway, "auth_type", None) is not None:
            gateway.auth_type = gateway_update.auth_type

            # if auth_type is not None and only then check auth_value
            if getattr(gateway, "auth_value", {}) != {}:
                gateway.auth_value = gateway_update.auth_value

        # Try to reinitialize the connection if a URL was supplied.
        if gateway_update.url is not None:
            try:
                capabilities, tools = await self._initialize_gateway(gateway.url, gateway.auth_value, gateway.transport)
                new_tool_names = {tool.name for tool in tools}

                for tool in tools:
                    existing_tool = db.execute(select(DbTool).where(DbTool.original_name == tool.name).where(DbTool.gateway_id == gateway_id)).scalar_one_or_none()
                    if not existing_tool:
                        gateway.tools.append(
                            DbTool(
                                original_name=tool.name,
                                original_name_slug=slugify(tool.name),
                                url=gateway.url,
                                description=tool.description,
                                integration_type=tool.integration_type,
                                request_type=tool.request_type,
                                headers=tool.headers,
                                input_schema=tool.input_schema,
                                annotations=tool.annotations,
                                jsonpath_filter=tool.jsonpath_filter,
                                auth_type=gateway.auth_type,
                                auth_value=gateway.auth_value,
                            )
                        )

                gateway.capabilities = capabilities
                # Keep only the tools the gateway still advertises.
                gateway.tools = [tool for tool in gateway.tools if tool.original_name in new_tool_names]
                gateway.last_seen = datetime.now(timezone.utc)

                # Swap tracking from the pre-update URL to the current one.
                self._active_gateways.discard(previous_url)
                self._active_gateways.add(gateway.url)
            except Exception as e:
                logger.warning(f"Failed to initialize updated gateway: {e}")

        gateway.updated_at = datetime.now(timezone.utc)
        db.commit()
        db.refresh(gateway)

        await self._notify_gateway_updated(gateway)

        logger.info(f"Updated gateway: {gateway.name}")
        return GatewayRead.model_validate(gateway)

    except (GatewayNotFoundError, GatewayNameConflictError):
        # Already specific, documented errors: propagate them unwrapped.
        db.rollback()
        raise
    except Exception as e:
        db.rollback()
        raise GatewayError(f"Failed to update gateway: {str(e)}") from e

382 

async def get_gateway(self, db: Session, gateway_id: str, include_inactive: bool = False) -> GatewayRead:
    """Fetch a single gateway by its primary key.

    Args:
        db: Database session
        gateway_id: Gateway ID
        include_inactive: Allow returning a disabled gateway

    Returns:
        Gateway information

    Raises:
        GatewayNotFoundError: If no such gateway exists, or it is disabled and
            ``include_inactive`` is False.
    """
    gateway = db.get(DbGateway, gateway_id)
    if not gateway:
        raise GatewayNotFoundError(f"Gateway not found: {gateway_id}")

    visible = gateway.enabled or include_inactive
    if not visible:
        raise GatewayNotFoundError(f"Gateway '{gateway.name}' exists but is inactive")

    return GatewayRead.model_validate(gateway)

405 

async def toggle_gateway_status(self, db: Session, gateway_id: str, activate: bool, reachable: bool = True, only_update_reachable: bool = False) -> GatewayRead:
    """Change a gateway's enabled/reachable status and cascade it to its tools.

    Nothing is written when both flags already match. When a gateway is
    activated and reachable, its capabilities and tool list are refreshed.
    The refresh is best-effort: a failure is logged and the status change
    still applies.

    Args:
        db: Database session
        gateway_id: Gateway ID to toggle
        activate: True to activate, False to deactivate
        reachable: True if the gateway is reachable, False otherwise
        only_update_reachable: If True, each tool keeps its own enabled flag and
            only its reachable flag is updated. A manually deactivated tool is
            therefore not re-enabled when the gateway becomes reachable again.

    Returns:
        Updated gateway information

    Raises:
        GatewayNotFoundError: If the gateway does not exist.
        GatewayError: For any other failure (the session is rolled back).
    """
    try:
        gateway = db.get(DbGateway, gateway_id)
        if not gateway:
            raise GatewayNotFoundError(f"Gateway not found: {gateway_id}")

        # Update status only if something actually changes.
        if (gateway.enabled != activate) or (gateway.reachable != reachable):
            # NOTE(review): gateway.enabled is overwritten even when
            # only_update_reachable is True. The callers in this module always
            # pass activate=True in that mode, so it is a no-op for them.
            gateway.enabled = activate
            gateway.reachable = reachable
            gateway.updated_at = datetime.now(timezone.utc)

            if activate and reachable:
                self._active_gateways.add(gateway.url)
                # Refresh capabilities and tools now that the gateway is usable again.
                try:
                    capabilities, tools = await self._initialize_gateway(gateway.url, gateway.auth_value, gateway.transport)
                    new_tool_names = {tool.name for tool in tools}

                    for tool in tools:
                        existing_tool = db.execute(select(DbTool).where(DbTool.original_name == tool.name).where(DbTool.gateway_id == gateway_id)).scalar_one_or_none()
                        if not existing_tool:
                            gateway.tools.append(
                                DbTool(
                                    original_name=tool.name,
                                    original_name_slug=slugify(tool.name),
                                    url=gateway.url,
                                    description=tool.description,
                                    integration_type=tool.integration_type,
                                    request_type=tool.request_type,
                                    headers=tool.headers,
                                    input_schema=tool.input_schema,
                                    annotations=tool.annotations,
                                    jsonpath_filter=tool.jsonpath_filter,
                                    auth_type=gateway.auth_type,
                                    auth_value=gateway.auth_value,
                                )
                            )

                    gateway.capabilities = capabilities
                    # Keep only the tools the gateway still advertises.
                    gateway.tools = [tool for tool in gateway.tools if tool.original_name in new_tool_names]
                    gateway.last_seen = datetime.now(timezone.utc)
                except Exception as e:
                    logger.warning(f"Failed to initialize reactivated gateway: {e}")
            else:
                self._active_gateways.discard(gateway.url)

            db.commit()
            db.refresh(gateway)

            gateway_tools = db.query(DbTool).filter(DbTool.gateway_id == gateway_id).all()
            for tool in gateway_tools:
                # In reachable-only mode each tool keeps its own enabled flag.
                tool_enabled = tool.enabled if only_update_reachable else activate
                await self.tool_service.toggle_tool_status(db, tool.id, tool_enabled, reachable)

            if activate:
                await self._notify_gateway_activated(gateway)
            else:
                await self._notify_gateway_deactivated(gateway)

            logger.info(f"Gateway status: {gateway.name} - {'enabled' if activate else 'disabled'} and {'accessible' if reachable else 'inaccessible'}")

        return GatewayRead.model_validate(gateway)

    except GatewayNotFoundError:
        # Documented, specific error: propagate it unwrapped.
        db.rollback()
        raise
    except Exception as e:
        db.rollback()
        raise GatewayError(f"Failed to toggle gateway status: {str(e)}") from e

494 

495 async def _notify_gateway_updated(self, gateway: DbGateway) -> None: 

496 """ 

497 Notify subscribers of gateway update. 

498 

499 Args: 

500 gateway: Gateway to update 

501 """ 

502 event = { 

503 "type": "gateway_updated", 

504 "data": { 

505 "id": gateway.id, 

506 "name": gateway.name, 

507 "url": gateway.url, 

508 "description": gateway.description, 

509 "enabled": gateway.enabled, 

510 }, 

511 "timestamp": datetime.now(timezone.utc).isoformat(), 

512 } 

513 await self._publish_event(event) 

514 

async def delete_gateway(self, db: Session, gateway_id: str) -> None:
    """Permanently delete a gateway.

    Args:
        db: Database session
        gateway_id: Gateway ID to delete

    Raises:
        GatewayNotFoundError: If the gateway does not exist.
        GatewayError: For any other deletion error (the session is rolled back).
    """
    try:
        gateway = db.get(DbGateway, gateway_id)
        if not gateway:
            raise GatewayNotFoundError(f"Gateway not found: {gateway_id}")

        # Snapshot everything needed after the delete. The ORM instance should
        # not be relied on once its deletion has been committed.
        gateway_info = {"id": gateway.id, "name": gateway.name, "url": gateway.url}

        # Hard delete
        db.delete(gateway)
        db.commit()

        self._active_gateways.discard(gateway_info["url"])
        await self._notify_gateway_deleted(gateway_info)

        logger.info(f"Permanently deleted gateway: {gateway_info['name']}")

    except GatewayNotFoundError:
        # Documented, specific error: propagate it unwrapped.
        db.rollback()
        raise
    except Exception as e:
        db.rollback()
        raise GatewayError(f"Failed to delete gateway: {str(e)}") from e

550 

async def forward_request(self, gateway: DbGateway, method: str, params: Optional[Dict[str, Any]] = None) -> Any:
    """Forward a JSON-RPC request to a gateway's ``/rpc`` endpoint.

    Args:
        gateway: Gateway to forward to
        method: RPC method name
        params: Optional method parameters (omitted from the request when empty)

    Returns:
        The ``result`` member of the gateway's JSON-RPC response.

    Raises:
        GatewayConnectionError: If the gateway is inactive, the HTTP call fails,
            or the reply is not a JSON object.
        GatewayError: If the gateway answered with a JSON-RPC error.
    """
    if not gateway.enabled:
        raise GatewayConnectionError(f"Cannot forward request to inactive gateway: {gateway.name}")

    request: Dict[str, Any] = {"jsonrpc": "2.0", "id": 1, "method": method}
    if params:
        request["params"] = params

    try:
        # Reuse the persistent HTTP client (closed in shutdown()).
        response = await self._http_client.post(f"{gateway.url}/rpc", json=request, headers=self._get_auth_headers())
        response.raise_for_status()
        result = response.json()
    except Exception as e:
        raise GatewayConnectionError(f"Failed to forward request to {gateway.name}: {str(e)}") from e

    # The gateway answered, so record it as seen even if the reply is an error.
    gateway.last_seen = datetime.now(timezone.utc)

    if not isinstance(result, dict):
        raise GatewayConnectionError(f"Failed to forward request to {gateway.name}: unexpected response type {type(result).__name__}")

    if "error" in result:
        error = result["error"]
        message = error.get("message") if isinstance(error, dict) else error
        # Raised outside the try above so it is not re-wrapped as a connection error.
        raise GatewayError(f"Gateway error: {message}")
    return result.get("result")

589 

async def _handle_gateway_failure(self, gateway: DbGateway) -> None:
    """Record a failed health check and mark the gateway unreachable at the threshold.

    Failures are counted only for gateways that are enabled and currently
    reachable. When a gateway's count reaches ``GW_FAILURE_THRESHOLD``, it is
    flagged unreachable through ``toggle_gateway_status``, which keeps it
    enabled, and its count is reset. A threshold of -1 disables this handling.

    Args:
        gateway: The gateway row that failed its health check.

    Returns:
        None
    """
    if GW_FAILURE_THRESHOLD == -1:
        return  # Gateway failure action disabled

    if not gateway.enabled:
        return  # No action needed for inactive gateways

    if not gateway.reachable:
        return  # No action needed for unreachable gateways

    count = self._gateway_failure_counts.get(gateway.id, 0) + 1
    self._gateway_failure_counts[gateway.id] = count

    logger.warning(f"Gateway {gateway.name} failed health check {count} time(s).")

    if count >= GW_FAILURE_THRESHOLD:
        logger.error(f"Gateway {gateway.name} failed {GW_FAILURE_THRESHOLD} times. Deactivating...")
        with SessionLocal() as db:
            # activate=True + only_update_reachable: stay enabled, only mark unreachable.
            await self.toggle_gateway_status(db, gateway.id, activate=True, reachable=False, only_update_reachable=True)
            self._gateway_failure_counts[gateway.id] = 0  # Reset after deactivation

620 

async def check_health_of_gateways(self, gateways: List[DbGateway]) -> bool:
    """Run one health-check round over the given gateways.

    Each gateway is probed according to its transport:

    * ``sse``: a streamed GET that must not return 4xx/5xx.
    * ``streamablehttp``: an MCP session initialize.

    On success, a gateway previously flagged unreachable is re-marked reachable
    and its failure count is cleared, so only consecutive failures count toward
    the threshold. On failure, the gateway is handed to
    ``_handle_gateway_failure``.

    Args:
        gateways (List[DbGateway]): Gateways to check.

    Returns:
        bool: True if every gateway passed its check, False otherwise.
    """
    all_healthy = True
    # One client for the whole round, honouring the same TLS setting as the
    # persistent federation client.
    async with httpx.AsyncClient(verify=not settings.skip_ssl_verify) as client:
        for gateway in gateways:
            logger.debug(f"Checking health of gateway: {gateway.name} ({gateway.url})")
            try:
                headers = decode_auth(gateway.auth_value or {})
                transport = gateway.transport.lower()

                if transport == "sse":
                    timeout = httpx.Timeout(settings.health_check_timeout)
                    async with client.stream("GET", gateway.url, headers=headers, timeout=timeout) as response:
                        # Raises immediately on a 4xx/5xx status.
                        response.raise_for_status()
                elif transport == "streamablehttp":
                    async with streamablehttp_client(url=gateway.url, headers=headers, timeout=settings.health_check_timeout) as (read_stream, write_stream, _get_session_id):
                        async with ClientSession(read_stream, write_stream) as session:
                            await session.initialize()

                # Passed: restore reachability if it had been revoked.
                if gateway.enabled and not gateway.reachable:
                    with SessionLocal() as db:
                        logger.info(f"Reactivating gateway: {gateway.name}, as it is healthy now")
                        await self.toggle_gateway_status(db, gateway.id, activate=True, reachable=True, only_update_reachable=True)

                gateway.last_seen = datetime.now(timezone.utc)
                self._gateway_failure_counts.pop(gateway.id, None)

            except Exception:
                all_healthy = False
                try:
                    await self._handle_gateway_failure(gateway)
                except Exception as e:
                    # Do not let one gateway's failure bookkeeping abort the rest of the round.
                    logger.error(f"Failed to handle health check failure for gateway {gateway.name}: {e}")

    return all_healthy

667 

async def aggregate_capabilities(self, db: Session) -> Dict[str, Any]:
    """Merge the capabilities advertised by all enabled gateways into the gateway's own.

    Keys absent from the base set are copied in. Dict-valued keys present on
    both sides are shallow-merged, with later gateways winning on conflicting
    entries. Non-dict values for keys that already exist are left unchanged.

    Args:
        db: Database session

    Returns:
        Combined capabilities
    """
    capabilities: Dict[str, Any] = {
        "prompts": {"listChanged": True},
        "resources": {"subscribe": True, "listChanged": True},
        "tools": {"listChanged": True},
        "logging": {},
    }

    gateways = db.execute(select(DbGateway).where(DbGateway.enabled)).scalars().all()

    for gateway in gateways:
        for key, value in (gateway.capabilities or {}).items():
            if key not in capabilities:
                # Copy dicts so later merges never mutate a gateway's stored capabilities.
                capabilities[key] = dict(value) if isinstance(value, dict) else value
                continue
            current = capabilities[key]
            if isinstance(value, dict) and isinstance(current, dict):
                current.update(value)

    return capabilities

697 

async def subscribe_events(self) -> AsyncGenerator[Dict[str, Any], None]:
    """Stream gateway events to the caller as they are published.

    A private queue is registered in ``self._event_subscribers`` for the
    lifetime of the generator. It is removed again when the generator is closed.

    Yields:
        Gateway event messages
    """
    inbox: asyncio.Queue = asyncio.Queue()
    self._event_subscribers.append(inbox)
    try:
        while True:
            yield await inbox.get()
    finally:
        self._event_subscribers.remove(inbox)

712 

async def _initialize_gateway(self, url: str, authentication: Optional[Dict[str, str]] = None, transport: str = "SSE") -> Any:
    """Connect to a gateway, perform the MCP handshake and list its tools.

    Args:
        url: Gateway URL
        authentication: Optional authentication data, turned into request
            headers by ``decode_auth``
        transport: Transport type ("SSE" or "StreamableHTTP"), case-insensitive

    Returns:
        Tuple ``(capabilities, tools)``: the capabilities dict reported by the
        gateway and a list of ``ToolCreate`` models. Both are empty for an
        unrecognised transport.

    Raises:
        GatewayConnectionError: If initialization fails.
    """

    async def _discover(session: ClientSession, request_type: Optional[str] = None):
        """Run initialize + list_tools on an open session, optionally forcing each tool's request_type."""
        init_result = await session.initialize()
        capabilities = init_result.capabilities.model_dump(by_alias=True, exclude_none=True)
        listed = await session.list_tools()
        tools = [ToolCreate.model_validate(tool.model_dump(by_alias=True, exclude_none=True)) for tool in listed.tools]
        if request_type is not None:
            for tool in tools:
                tool.request_type = request_type
        return capabilities, tools

    try:
        mode = transport.lower()
        if mode not in ("sse", "streamablehttp"):
            logger.warning(f"Unsupported transport '{transport}' for gateway at {url}; no capabilities or tools discovered")
            return {}, []

        headers = decode_auth({} if authentication is None else authentication)

        if mode == "sse":
            async with sse_client(url=url, headers=headers) as streams:
                async with ClientSession(*streams) as session:
                    return await _discover(session)

        async with streamablehttp_client(url=url, headers=headers) as (read_stream, write_stream, _get_session_id):
            async with ClientSession(read_stream, write_stream) as session:
                return await _discover(session, request_type="STREAMABLEHTTP")
    except Exception as e:
        raise GatewayConnectionError(f"Failed to initialize gateway at {url}: {str(e)}") from e

806 

def _get_gateways(self, include_inactive: bool = True) -> list[DbGateway]:
    """Load gateways in a short-lived session (synchronous; run via ``asyncio.to_thread``).

    Args:
        include_inactive: When True (the default), every gateway is returned;
            when False, only enabled ones are.

    Returns:
        List[DbGateway]: The matching gateway rows.
    """
    stmt = select(DbGateway)
    if not include_inactive:
        stmt = stmt.where(DbGateway.enabled)
    with SessionLocal() as db:
        return db.execute(stmt).scalars().all()

821 

async def _run_health_checks(self) -> None:
    """Run gateway health checks forever, pausing ``_health_check_interval`` seconds between rounds.

    The mode is chosen on every iteration:

    * Redis (a client exists and ``cache_type == "redis"``): only the process
      whose ID is stored under the leader key runs checks, and it renews the
      key's TTL each round. Any other process returns immediately.
    * ``cache_type == "none"``: single-worker mode; checks always run.
    * Otherwise: a non-blocking ``FileLock`` elects the leader. The holder keeps
      checking until an error occurs, then releases the lock.

    Errors are logged and the loop continues. ``asyncio.CancelledError`` is not
    caught by the ``except Exception`` handlers, so cancellation ends the loop.
    """

    async def run_round() -> None:
        # Load gateways off the event loop (sync DB access), then probe them.
        gateways = await asyncio.to_thread(self._get_gateways)
        if gateways:
            await self.check_health_of_gateways(gateways)

    while True:
        try:
            if self._redis_client and settings.cache_type == "redis":
                owner = self._redis_client.get(self._leader_key)
                if owner != self._instance_id.encode():
                    return  # another instance holds leadership
                self._redis_client.expire(self._leader_key, self._leader_ttl)  # renew our claim

                await run_round()
                await asyncio.sleep(self._health_check_interval)

            elif settings.cache_type == "none":
                # Single worker: no election, just check.
                try:
                    await run_round()
                except Exception as e:
                    logger.error(f"Health check run failed: {str(e)}")

                await asyncio.sleep(self._health_check_interval)

            else:
                try:
                    # Non-blocking attempt; raises Timeout if another process holds the lock.
                    self._file_lock.acquire(timeout=0)
                    logger.info("File lock acquired. Running health checks.")

                    while True:
                        await run_round()
                        await asyncio.sleep(self._health_check_interval)

                except Timeout:
                    logger.debug("File lock already held. Retrying later.")
                    await asyncio.sleep(self._health_check_interval)

                except Exception as e:
                    logger.error(f"FileLock health check failed: {str(e)}")

                finally:
                    if self._file_lock.is_locked:
                        try:
                            self._file_lock.release()
                            logger.info("Released file lock.")
                        except Exception as e:
                            logger.warning(f"Failed to release file lock: {str(e)}")

        except Exception as e:
            logger.error(f"Unexpected error in health check loop: {str(e)}")
            await asyncio.sleep(self._health_check_interval)

884 

def _get_auth_headers(self) -> Dict[str, str]:
    """
    Get headers for gateway authentication.

    ``Authorization`` uses the HTTP Basic scheme (RFC 7617): the
    ``user:password`` pair taken from ``settings`` is UTF-8 encoded and then
    base64-encoded. Previously the raw pair was sent, which standard Basic-auth
    parsers reject. ``X-API-Key`` still carries the raw ``user:password`` pair.

    Returns:
        dict: Authorization, X-API-Key and Content-Type headers.
    """
    # Local import keeps this fix self-contained; base64 is stdlib.
    import base64

    api_key = f"{settings.basic_auth_user}:{settings.basic_auth_password}"
    basic_token = base64.b64encode(api_key.encode("utf-8")).decode("ascii")
    return {"Authorization": f"Basic {basic_token}", "X-API-Key": api_key, "Content-Type": "application/json"}

894 

async def _notify_gateway_added(self, gateway: DbGateway) -> None:
    """
    Broadcast a ``gateway_added`` event describing a newly added gateway.

    Args:
        gateway: The gateway whose id, name, url, description and enabled flag
            are placed in the event's ``data`` field.
    """
    payload = dict(
        id=gateway.id,
        name=gateway.name,
        url=gateway.url,
        description=gateway.description,
        enabled=gateway.enabled,
    )
    stamp = datetime.now(timezone.utc).isoformat()
    await self._publish_event({"type": "gateway_added", "data": payload, "timestamp": stamp})

914 

async def _notify_gateway_activated(self, gateway: DbGateway) -> None:
    """
    Broadcast a ``gateway_activated`` event for ``gateway``.

    Args:
        gateway: The gateway whose id, name, url and enabled flag are placed
            in the event's ``data`` field.
    """
    gateway_data = {
        "id": gateway.id,
        "name": gateway.name,
        "url": gateway.url,
        "enabled": gateway.enabled,
    }
    activated_event = dict(
        type="gateway_activated",
        data=gateway_data,
        timestamp=datetime.now(timezone.utc).isoformat(),
    )
    await self._publish_event(activated_event)

933 

async def _notify_gateway_deactivated(self, gateway: DbGateway) -> None:
    """
    Broadcast a ``gateway_deactivated`` event for ``gateway``.

    Args:
        gateway: Gateway database object; its id, name, url and enabled flag
            are copied into the event's ``data`` field.
    """
    summary = {field: getattr(gateway, field) for field in ("id", "name", "url", "enabled")}
    await self._publish_event(
        {
            "type": "gateway_deactivated",
            "data": summary,
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }
    )

952 

async def _notify_gateway_deleted(self, gateway_info: Dict[str, Any]) -> None:
    """
    Broadcast a ``gateway_deleted`` event.

    Args:
        gateway_info: Details of the deleted gateway; forwarded unchanged as
            the event's ``data`` field.
    """
    deleted_event: Dict[str, Any] = dict(type="gateway_deleted", data=gateway_info)
    deleted_event["timestamp"] = datetime.now(timezone.utc).isoformat()
    await self._publish_event(deleted_event)

966 

async def _notify_gateway_removed(self, gateway: DbGateway) -> None:
    """
    Broadcast a ``gateway_removed`` event (removal by deactivation).

    Args:
        gateway: The gateway whose id, name and enabled flag are placed in the
            event's ``data`` field.
    """
    removed_summary = dict(id=gateway.id, name=gateway.name, enabled=gateway.enabled)
    when = datetime.now(timezone.utc).isoformat()
    notification = {"type": "gateway_removed", "data": removed_summary, "timestamp": when}
    await self._publish_event(notification)

980 

async def _publish_event(self, event: Dict[str, Any]) -> None:
    """
    Deliver ``event`` to every subscriber queue.

    Each queue in ``self._event_subscribers`` receives the same event object,
    in iteration order, and each ``put`` is awaited before moving on.

    Args:
        event: Event dictionary to deliver.
    """
    subscribers = self._event_subscribers
    for subscriber_queue in subscribers:
        await subscriber_queue.put(event)