Coverage for /Users/ajo/work/jumpstarter/jumpstarter/packages/jumpstarter-driver-network/jumpstarter_driver_network/streams/websocket.py: 46%
61 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-26 15:50 +0200
from contextlib import suppress
from dataclasses import dataclass, field
from typing import Tuple

from anyio import BrokenResourceError, ClosedResourceError, WouldBlock, create_memory_object_stream
from anyio.abc import AnyByteStream, ObjectStream
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from websockets.asyncio.client import ClientConnection as WSSClientConnection
from wsproto import ConnectionType, WSConnection
from wsproto.connection import ConnectionState
from wsproto.events import (
    AcceptConnection,
    CloseConnection,
    Message,
    Ping,
    Request,
)
from wsproto.frame_protocol import CloseReason
from wsproto.utilities import LocalProtocolError, RemoteProtocolError
22@dataclass(kw_only=True)
23class WebsocketServerStream(ObjectStream[bytes]):
24 stream: AnyByteStream
26 ws: WSConnection = field(init=False, default_factory=lambda: WSConnection(ConnectionType.SERVER))
27 queue: Tuple[MemoryObjectSendStream[bytes], MemoryObjectReceiveStream[bytes]] = field(
28 init=False,
29 default_factory=lambda: create_memory_object_stream[bytes](32), # ty: ignore[call-non-callable]
30 )
32 async def send(self, data: bytes) -> None:
33 try:
34 self.ws.receive_data(data)
35 except RemoteProtocolError as e:
36 raise BrokenResourceError from e
38 try:
39 for event in self.ws.events():
40 match event:
41 case Request():
42 await self.queue[0].send(self.ws.send(AcceptConnection()))
43 case CloseConnection():
44 await self.queue[0].send(self.ws.send(event.response()))
45 case Message():
46 await self.stream.send(event.data)
47 case Ping():
48 await self.queue[0].send(self.ws.send(event.response()))
49 except LocalProtocolError as e:
50 raise BrokenResourceError from e
52 async def receive(self) -> bytes:
53 with suppress(WouldBlock):
54 return self.queue[1].receive_nowait()
56 if self.ws.state == ConnectionState.CONNECTING:
57 return await self.queue[1].receive()
59 try:
60 return self.ws.send(Message(data=await self.stream.receive()))
61 except LocalProtocolError as e:
62 raise BrokenResourceError from e
64 async def send_eof(self):
65 # websocket does not have half closed connections
66 pass
68 async def aclose(self):
69 with suppress(LocalProtocolError):
70 await self.stream.send(self.ws.send(CloseConnection(code=CloseReason.NORMAL_CLOSURE)))
71 await self.stream.aclose()
@dataclass(kw_only=True)
class WebsocketClientStream(ObjectStream[bytes]):
    """
    Websocket client streaming.

    Adapts a ``websockets`` client connection to the anyio ``ObjectStream``
    interface: each :meth:`send` is one websocket message and each :meth:`receive`
    returns one whole message as ``bytes``.
    """

    # Established client connection from ``websockets.asyncio.client``.
    conn: WSSClientConnection

    async def send(self, data: bytes) -> None:
        """Send ``data`` as one websocket message (bytes are sent as a binary frame)."""
        await self.conn.send(data)

    async def receive(self) -> bytes:
        """
        Return the next websocket message as bytes.

        ``decode=False`` returns text frames as their raw UTF-8 bytes instead of
        ``str``, so the return type is always ``bytes``.
        """
        return await self.conn.recv(decode=False)

    async def send_eof(self):
        # websocket does not have half closed connections
        pass

    async def aclose(self):
        """Close the underlying websocket connection."""
        await self.conn.close()