Coverage for /home/fedora/jumpstarter/packages/jumpstarter-driver-network/jumpstarter_driver_network/streams/websocket.py: 46%

61 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-05 20:29 +0000

1from contextlib import suppress 

2from dataclasses import dataclass, field 

3from typing import Tuple 

4 

5from anyio import BrokenResourceError, WouldBlock, create_memory_object_stream 

6from anyio.abc import AnyByteStream, ObjectStream 

7from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream 

8from websockets.asyncio.client import ClientConnection as WSSClientConnection 

9from wsproto import ConnectionType, WSConnection 

10from wsproto.connection import ConnectionState 

11from wsproto.events import ( 

12 AcceptConnection, 

13 CloseConnection, 

14 Message, 

15 Ping, 

16 Request, 

17) 

18from wsproto.frame_protocol import CloseReason 

19from wsproto.utilities import LocalProtocolError, RemoteProtocolError 

20 

21 

22@dataclass(kw_only=True) 

23class WebsocketServerStream(ObjectStream[bytes]): 

24 stream: AnyByteStream 

25 

26 ws: WSConnection = field(init=False, default_factory=lambda: WSConnection(ConnectionType.SERVER)) 

27 queue: Tuple[MemoryObjectSendStream[bytes], MemoryObjectReceiveStream[bytes]] = field( 

28 init=False, default_factory=lambda: create_memory_object_stream[bytes](32) 

29 ) 

30 

31 async def send(self, data: bytes) -> None: 

32 try: 

33 self.ws.receive_data(data) 

34 except RemoteProtocolError as e: 

35 raise BrokenResourceError from e 

36 

37 try: 

38 for event in self.ws.events(): 

39 match event: 

40 case Request(): 

41 await self.queue[0].send(self.ws.send(AcceptConnection())) 

42 case CloseConnection(): 

43 await self.queue[0].send(self.ws.send(event.response())) 

44 case Message(): 

45 await self.stream.send(event.data) 

46 case Ping(): 

47 await self.queue[0].send(self.ws.send(event.response())) 

48 except LocalProtocolError as e: 

49 raise BrokenResourceError from e 

50 

51 async def receive(self) -> bytes: 

52 with suppress(WouldBlock): 

53 return self.queue[1].receive_nowait() 

54 

55 if self.ws.state == ConnectionState.CONNECTING: 

56 return await self.queue[1].receive() 

57 

58 try: 

59 return self.ws.send(Message(data=await self.stream.receive())) 

60 except LocalProtocolError as e: 

61 raise BrokenResourceError from e 

62 

63 async def send_eof(self): 

64 # websocket does not have half closed connections 

65 pass 

66 

67 async def aclose(self): 

68 with suppress(LocalProtocolError): 

69 await self.stream.send(self.ws.send(CloseConnection(code=CloseReason.NORMAL_CLOSURE))) 

70 await self.stream.aclose() 

71 

72 

@dataclass(kw_only=True)
class WebsocketClientStream(ObjectStream[bytes]):
    """
    Websocket client streaming.

    Adapts a ``websockets`` asyncio client connection to a bytes-oriented
    ObjectStream: each ``send``/``receive`` call maps to one websocket message.
    Exceptions raised by the connection propagate unchanged.
    """

    # Established websockets asyncio client connection carrying the messages.
    conn: WSSClientConnection

    async def send(self, data: bytes) -> None:
        """Send ``data`` as a single websocket message."""
        await self.conn.send(data)

    async def receive(self) -> bytes:
        """
        Return the payload of the next websocket message as ``bytes``.

        ``decode=False`` makes text frames come back as their raw UTF-8 bytes
        instead of ``str``, so the ``-> bytes`` contract holds for both binary
        and text frames.
        """
        return await self.conn.recv(decode=False)

    async def send_eof(self):
        """No-op: websocket connections cannot be half-closed."""

    async def aclose(self):
        """Close the underlying websocket connection."""
        await self.conn.close()

91