Coverage for mcpgateway/transports/websocket_transport.py: 97%
80 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-09 11:03 +0100
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-09 11:03 +0100
1# -*- coding: utf-8 -*-
2"""WebSocket Transport Implementation.
4Copyright 2025
5SPDX-License-Identifier: Apache-2.0
6Authors: Mihai Criveti
8This module implements WebSocket transport for MCP, providing
9full-duplex communication between client and server.
10"""
12# Standard
13import asyncio
14import logging
15from typing import Any, AsyncGenerator, Dict, Optional
17# Third-Party
18from fastapi import WebSocket, WebSocketDisconnect
20# First-Party
21from mcpgateway.config import settings
22from mcpgateway.transports.base import Transport
logger = logging.getLogger(__name__)


class WebSocketTransport(Transport):
    """Transport implementation using WebSocket.

    Messages are exchanged as JSON objects. When
    ``settings.websocket_ping_interval`` is positive, a background task also
    sends an application-level ``b"ping"`` frame every interval and expects a
    ``b"pong"`` reply within half an interval.
    """

    def __init__(self, websocket: WebSocket):
        """Initialize WebSocket transport.

        Args:
            websocket: FastAPI WebSocket connection
        """
        self._websocket = websocket
        self._connected = False
        self._ping_task: Optional[asyncio.Task] = None

    async def connect(self) -> None:
        """Accept the WebSocket handshake and start the keep-alive task.

        The keep-alive (ping) task is only started when
        ``settings.websocket_ping_interval`` is greater than zero.
        """
        await self._websocket.accept()
        self._connected = True

        if settings.websocket_ping_interval > 0:
            self._ping_task = asyncio.create_task(self._ping_loop())

        logger.info("WebSocket transport connected")

    async def disconnect(self) -> None:
        """Stop the ping task and close the WebSocket if it is still open.

        Safe to call more than once, and from inside the ping task itself
        (``_ping_loop`` calls it from its ``finally``). Returns immediately
        when there is no running event loop.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No running loop (interpreter shutdown, for example) - the
            # asyncio calls below would be illegal.
            return

        # 1. Stop the ping task. Skip it when it has already finished, or when
        #    it is the task executing this very call: a task cannot await
        #    itself.
        ping_task = getattr(self, "_ping_task", None)
        should_cancel = (
            ping_task is not None
            and not ping_task.done()
            and ping_task is not asyncio.current_task()
        )
        if should_cancel:
            ping_task.cancel()
            try:
                await ping_task  # allow it to exit gracefully
            except asyncio.CancelledError:
                pass

        # 2. Close the WebSocket connection (if still open).
        if getattr(self, "_connected", False):
            try:
                await self._websocket.close()
            finally:
                self._connected = False
                logger.info("WebSocket transport disconnected")

    async def send_message(self, message: Dict[str, Any]) -> None:
        """Send a message over WebSocket.

        Args:
            message: Message to send

        Raises:
            RuntimeError: If transport is not connected
            Exception: If unable to send json to websocket
        """
        if not self._connected:
            raise RuntimeError("Transport not connected")

        try:
            await self._websocket.send_json(message)
        except Exception as e:
            logger.error("Failed to send message: %s", e)
            raise

    async def receive_message(self) -> AsyncGenerator[Dict[str, Any], None]:
        """Receive messages from WebSocket.

        Iteration ends when the client disconnects or a receive error occurs.
        In every case (including the consumer closing the generator early),
        ``disconnect()`` runs on the way out.

        Yields:
            Received messages

        Raises:
            RuntimeError: If transport is not connected
        """
        if not self._connected:
            raise RuntimeError("Transport not connected")

        try:
            while True:
                message = await self._websocket.receive_json()
                yield message
        except WebSocketDisconnect:
            logger.info("WebSocket client disconnected")
            # The peer already closed the socket; clear the flag so that
            # disconnect() below does not try to close it a second time.
            self._connected = False
        except Exception as e:
            # The socket may still be open here (e.g. the payload was not
            # valid JSON). Leave _connected set so disconnect() below actually
            # closes it instead of leaking the connection.
            logger.error("Error receiving message: %s", e)
        finally:
            await self.disconnect()

    async def is_connected(self) -> bool:
        """Check if transport is connected.

        Returns:
            True if connected
        """
        return self._connected

    async def _ping_loop(self) -> None:
        """Send periodic ping messages to keep connection alive.

        Every ``settings.websocket_ping_interval`` seconds this sends
        ``b"ping"`` and waits up to half an interval for a reply. A missing
        reply ends the loop; a reply other than ``b"pong"`` is only logged.
        Whenever the loop exits (timeout, error or cancellation) the
        connection is torn down via ``disconnect()``.
        """
        # NOTE(review): receive_bytes() below runs concurrently with
        # receive_message()'s receive_json() on the same socket, so either
        # coroutine may consume a frame meant for the other (a JSON message
        # read here is lost; a bytes frame read there fails to decode).
        # Consider routing pongs through a single reader - TODO confirm the
        # intended design.
        try:
            while self._connected:
                await asyncio.sleep(settings.websocket_ping_interval)
                await self._websocket.send_bytes(b"ping")
                try:
                    resp = await asyncio.wait_for(
                        self._websocket.receive_bytes(),
                        timeout=settings.websocket_ping_interval / 2,
                    )
                    if resp != b"pong":
                        logger.warning("Invalid ping response")
                except asyncio.TimeoutError:
                    logger.warning("Ping timeout")
                    break
        except Exception as e:
            logger.error("Ping loop error: %s", e)
        finally:
            await self.disconnect()

    async def send_ping(self) -> None:
        """Send a manual ``b"ping"`` frame if the transport is connected."""
        if self._connected:
            await self._websocket.send_bytes(b"ping")