Coverage for mcpgateway/transports/websocket_transport.py: 97%

80 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-09 11:03 +0100

1# -*- coding: utf-8 -*- 

2"""WebSocket Transport Implementation. 

3 

4Copyright 2025 

5SPDX-License-Identifier: Apache-2.0 

6Authors: Mihai Criveti 

7 

8This module implements WebSocket transport for MCP, providing 

9full-duplex communication between client and server. 

10""" 

11 

12# Standard 

13import asyncio 

14import logging 

15from typing import Any, AsyncGenerator, Dict, Optional 

16 

17# Third-Party 

18from fastapi import WebSocket, WebSocketDisconnect 

19 

20# First-Party 

21from mcpgateway.config import settings 

22from mcpgateway.transports.base import Transport 

23 

# Module-level logger, named after this module via ``__name__``.
logger: logging.Logger = logging.getLogger(name=__name__)

25 

26 

class WebSocketTransport(Transport):
    """Transport implementation using WebSocket.

    Wraps a FastAPI ``WebSocket``. ``connect`` accepts the handshake and, when
    ``settings.websocket_ping_interval`` is positive, starts a background
    keep-alive task. ``disconnect`` stops that task and closes the socket.
    Messages are sent with ``send_json`` and received with ``receive_json``.
    """

    def __init__(self, websocket: WebSocket):
        """Initialize WebSocket transport.

        Args:
            websocket: FastAPI WebSocket connection. It is not accepted here;
                call ``connect`` before sending or receiving.
        """
        self._websocket = websocket
        self._connected = False
        self._ping_task: Optional[asyncio.Task] = None

    async def connect(self) -> None:
        """Accept the WebSocket handshake and start the keep-alive ping task.

        The ping task is only started when ``settings.websocket_ping_interval``
        is greater than zero.
        """
        await self._websocket.accept()
        self._connected = True

        if settings.websocket_ping_interval > 0:
            self._ping_task = asyncio.create_task(self._ping_loop())

        logger.info("WebSocket transport connected")

    async def disconnect(self) -> None:
        """Stop the ping task and close the WebSocket connection.

        Safe to call more than once and from inside the ping task itself. The
        ping task is never asked to cancel itself. ``_connected`` is cleared
        *before* the close is awaited, so concurrent or re-entrant calls do not
        send a second close.

        Raises:
            Exception: Propagates any exception raised by ``WebSocket.close``.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No running loop (interpreter shutdown, for example): further
            # asyncio calls are illegal, so there is nothing left to clean up.
            return

        # 1. Stop the keep-alive task, unless it is the task executing this
        #    code (the ping loop calls disconnect() from its own ``finally``).
        ping_task = getattr(self, "_ping_task", None)
        if ping_task is not None and not ping_task.done() and ping_task is not asyncio.current_task():
            ping_task.cancel()
            # asyncio.wait() does not re-raise the awaited task's CancelledError,
            # so a cancellation aimed at *this* coroutine still propagates
            # instead of being swallowed by an ``except CancelledError: pass``.
            await asyncio.wait({ping_task})
            # Retrieve any non-cancellation error so asyncio does not warn
            # that the task's exception was never retrieved.
            if not ping_task.cancelled() and ping_task.exception() is not None:
                logger.warning("Ping task ended with error: %s", ping_task.exception())

        # 2. Close the WebSocket connection (if still open).
        if getattr(self, "_connected", False):
            # Clear the flag first so a concurrent disconnect() cannot close twice.
            self._connected = False
            try:
                await self._websocket.close()
            finally:
                logger.info("WebSocket transport disconnected")

    async def send_message(self, message: Dict[str, Any]) -> None:
        """Send a message over WebSocket as JSON.

        Args:
            message: Message to send; it is passed to ``send_json``.

        Raises:
            RuntimeError: If transport is not connected
            Exception: Re-raised after logging if ``send_json`` fails
        """
        if not self._connected:
            raise RuntimeError("Transport not connected")

        try:
            await self._websocket.send_json(message)
        except Exception as e:
            logger.error("Failed to send message: %s", e)
            raise

    async def receive_message(self) -> AsyncGenerator[Dict[str, Any], None]:
        """Receive JSON messages from the WebSocket until it stops.

        A client disconnect or a receive error is logged and ends the
        generator; ``disconnect`` is awaited on every exit path.

        Yields:
            Messages decoded by ``receive_json``

        Raises:
            RuntimeError: If transport is not connected
        """
        if not self._connected:
            raise RuntimeError("Transport not connected")

        try:
            while True:
                yield await self._websocket.receive_json()
        except WebSocketDisconnect:
            logger.info("WebSocket client disconnected")
            self._connected = False
        except Exception as e:
            logger.error("Error receiving message: %s", e)
            # NOTE(review): clearing the flag here means disconnect() below
            # skips close(), so after e.g. a JSON decode error the socket is
            # left open for the caller to close — confirm this is intended.
            self._connected = False
        finally:
            await self.disconnect()

    async def is_connected(self) -> bool:
        """Check if transport is connected.

        Returns:
            True if connected
        """
        return self._connected

    async def _ping_loop(self) -> None:
        """Send periodic ``b"ping"`` frames and expect ``b"pong"`` replies.

        Runs while the transport is connected. It stops on a reply timeout
        (half the ping interval) or on an unexpected error, then tears the
        transport down via ``disconnect``.

        NOTE(review): this loop reads from the socket (``receive_bytes``)
        while ``receive_message`` may be reading concurrently
        (``receive_json``). Whichever reader receives a frame consumes it, so
        a client JSON message may be taken here, or a pong may reach
        ``receive_json``. Consider routing all reads through one reader.
        """
        try:
            while self._connected:
                await asyncio.sleep(settings.websocket_ping_interval)
                if not self._connected:
                    # Disconnected while sleeping: don't ping a closed socket.
                    break
                await self._websocket.send_bytes(b"ping")
                try:
                    resp = await asyncio.wait_for(
                        self._websocket.receive_bytes(),
                        timeout=settings.websocket_ping_interval / 2,
                    )
                    if resp != b"pong":
                        logger.warning("Invalid ping response")
                except asyncio.TimeoutError:
                    logger.warning("Ping timeout")
                    break
        except Exception as e:
            logger.error("Ping loop error: %s", e)
        finally:
            # Nothing in this class awaits the task after it exits on its own,
            # so an error raised while closing would otherwise surface only as
            # an "exception was never retrieved" warning.
            try:
                await self.disconnect()
            except Exception as e:
                logger.warning("Error while disconnecting from ping loop: %s", e)

    async def send_ping(self) -> None:
        """Send a manual ``b"ping"`` frame if the transport is connected."""
        if self._connected:
            await self._websocket.send_bytes(b"ping")