Coverage for /Users/ajo/work/jumpstarter/jumpstarter/packages/jumpstarter-driver-network/jumpstarter_driver_network/streams/websocket.py: 46%

61 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-26 15:50 +0200

1from contextlib import suppress 

2from dataclasses import dataclass, field 

3from typing import Tuple 

4 

5from anyio import BrokenResourceError, WouldBlock, create_memory_object_stream 

6from anyio.abc import AnyByteStream, ObjectStream 

7from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream 

8from websockets.asyncio.client import ClientConnection as WSSClientConnection 

9from wsproto import ConnectionType, WSConnection 

10from wsproto.connection import ConnectionState 

11from wsproto.events import ( 

12 AcceptConnection, 

13 CloseConnection, 

14 Message, 

15 Ping, 

16 Request, 

17) 

18from wsproto.frame_protocol import CloseReason 

19from wsproto.utilities import LocalProtocolError, RemoteProtocolError 

20 

21 

22@dataclass(kw_only=True) 

23class WebsocketServerStream(ObjectStream[bytes]): 

24 stream: AnyByteStream 

25 

26 ws: WSConnection = field(init=False, default_factory=lambda: WSConnection(ConnectionType.SERVER)) 

27 queue: Tuple[MemoryObjectSendStream[bytes], MemoryObjectReceiveStream[bytes]] = field( 

28 init=False, 

29 default_factory=lambda: create_memory_object_stream[bytes](32), # ty: ignore[call-non-callable] 

30 ) 

31 

32 async def send(self, data: bytes) -> None: 

33 try: 

34 self.ws.receive_data(data) 

35 except RemoteProtocolError as e: 

36 raise BrokenResourceError from e 

37 

38 try: 

39 for event in self.ws.events(): 

40 match event: 

41 case Request(): 

42 await self.queue[0].send(self.ws.send(AcceptConnection())) 

43 case CloseConnection(): 

44 await self.queue[0].send(self.ws.send(event.response())) 

45 case Message(): 

46 await self.stream.send(event.data) 

47 case Ping(): 

48 await self.queue[0].send(self.ws.send(event.response())) 

49 except LocalProtocolError as e: 

50 raise BrokenResourceError from e 

51 

52 async def receive(self) -> bytes: 

53 with suppress(WouldBlock): 

54 return self.queue[1].receive_nowait() 

55 

56 if self.ws.state == ConnectionState.CONNECTING: 

57 return await self.queue[1].receive() 

58 

59 try: 

60 return self.ws.send(Message(data=await self.stream.receive())) 

61 except LocalProtocolError as e: 

62 raise BrokenResourceError from e 

63 

64 async def send_eof(self): 

65 # websocket does not have half closed connections 

66 pass 

67 

68 async def aclose(self): 

69 with suppress(LocalProtocolError): 

70 await self.stream.send(self.ws.send(CloseConnection(code=CloseReason.NORMAL_CLOSURE))) 

71 await self.stream.aclose() 

72 

73 

@dataclass(kw_only=True)
class WebsocketClientStream(ObjectStream[bytes]):
    """
    Websocket client streaming.

    Adapts a ``websockets`` asyncio client connection to anyio's
    ``ObjectStream[bytes]`` interface: each ``send()`` / ``receive()`` call
    corresponds to one websocket message.
    """

    # Established websocket client connection used for all I/O.
    conn: WSSClientConnection

    async def send(self, data: bytes) -> None:
        """Send ``data`` to the server as one websocket message."""
        await self.conn.send(data)

    async def receive(self) -> bytes:
        """Return the payload of the next websocket message as bytes.

        ``decode=False`` makes text frames come back as their raw UTF-8 bytes
        instead of ``str``, so the ``-> bytes`` contract holds for every frame.
        """
        return await self.conn.recv(decode=False)

    async def send_eof(self) -> None:
        # websocket does not have half closed connections
        pass

    async def aclose(self) -> None:
        """Close the websocket connection."""
        await self.conn.close()