14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
@dataclass
class BaseConnClient:
    """Asynchronous WebSocket client with request/response matching and event dispatch.

    Requests sent through :meth:`send_request` are tracked in ``pending_requests``
    by ``request_id`` until a ``"response"`` message carrying the same id arrives.
    ``"event"`` messages are dispatched to the handlers registered via
    :meth:`on_event`. A background task (``receive_task``) reads incoming
    messages while ``is_connected`` is True.
    """

    # NOTE(review): pydantic-style config; it is not an annotated field, so a
    # stdlib dataclass ignores it — confirm it is still needed.
    model_config = {"arbitrary_types_allowed": True}
    # Server endpoint passed to ``websockets.connect``.
    url: Optional[str] = "ws://127.0.0.1:8000"
    # Set True by a successful ``connect()``; cleared by ``close()`` or when the
    # receive loop sees the connection close.
    is_connected: Optional[bool] = False
    # request_id -> (future resolved with the ConnResponse, optional callback).
    pending_requests: Dict[str, Tuple[asyncio.Future, Callable]] = field(default_factory=dict)
    # event_type -> handlers (sync or async) invoked for each matching ConnEvent.
    event_handlers: Dict[str, List[Callable]] = field(default_factory=dict)
    # Connection object returned by ``websockets.connect`` (None until connected).
    websocket: websockets.connect = None
    # Background task running ``_receive_messages``.
    receive_task: Optional[asyncio.Task] = None

    async def _handle_message(self, message: str):
        """Decode one raw message and route it by its ``"type"`` field.

        ``"response"`` messages resolve the matching pending request and then run
        its callback (sync or async). ``"event"`` messages are dispatched to the
        registered handlers. Other types are ignored. Errors are logged, never
        raised.
        """
        try:
            data = json.loads(message)
        except json.JSONDecodeError:
            logger.error(f"unable parse message: {message}")
            # Nothing to route; without this return ``data`` would be unbound below.
            return
        try:
            message_type = data.get("type")
            if message_type == "response":
                res = ConnResponse(**data)
                entry = self.pending_requests.pop(res.request_id, None)
                if entry is not None:
                    future, callback = entry
                    # The waiter may already be cancelled (e.g. by a timeout);
                    # set_result on a done future raises InvalidStateError.
                    if not future.done():
                        future.set_result(res)
                    if callback:
                        outcome = callback(res)
                        # Accept ``async def`` callbacks as well as plain functions.
                        if inspect.isawaitable(outcome):
                            await outcome
            elif message_type == "event":
                event = ConnEvent(**data)
                await self._handle_event(event)
        except Exception:
            logger.error(f"error in handle message: {traceback.format_exc()}")

    async def _handle_event(self, event: ConnEvent):
        """Run every handler registered for ``event.event_type`` concurrently.

        Coroutine functions are awaited directly; other callables run in a worker
        thread via ``asyncio.to_thread``. Each failing handler is logged on its
        own, so one failure neither hides nor is hidden by another.
        """
        if event.event_type not in self.event_handlers:
            logger.info(f"no handler for event type : {event.event_type}")
            return
        # Snapshot so results line up with handlers even if a handler registers
        # another handler for this event type while the dispatch is running.
        handlers = list(self.event_handlers[event.event_type])
        runs = [
            handler(event)
            if inspect.iscoroutinefunction(handler)
            else asyncio.to_thread(handler, event)
            for handler in handlers
        ]
        results = await asyncio.gather(*runs, return_exceptions=True)
        for handler, result in zip(handlers, results):
            if isinstance(result, BaseException):
                details = "".join(
                    traceback.format_exception(type(result), result, result.__traceback__)
                )
                logger.error(
                    f"error in event handler {handler!r} for event type "
                    f"{event.event_type}: {details}"
                )

    def _fail_pending_requests(self, reason: str) -> None:
        """Reject every outstanding request with ``ConnectionError(reason)`` and forget it."""
        pending = list(self.pending_requests.values())
        self.pending_requests.clear()
        for future, _callback in pending:
            if not future.done():
                # A fresh exception per future avoids sharing one traceback object.
                future.set_exception(ConnectionError(reason))

    async def _receive_messages(self):
        """Receive and handle messages until ``is_connected`` becomes False.

        When the server closes the connection, all pending requests are failed
        so their callers stop waiting.

        NOTE: messages are handled inline in this loop, so an event handler or
        response callback that awaits ``send_request`` on this same client blocks
        the loop and never receives its own response.
        """
        while self.is_connected:
            try:
                message = await self.websocket.recv()
                await self._handle_message(message)
            except websockets.exceptions.ConnectionClosed:
                logger.error("WebSocket connection closed")
                self.is_connected = False
                self._fail_pending_requests("websocket connection closed")
            except Exception as e:
                logger.error(f"error in receive or handle message: {e}")

    async def connect(self) -> bool:
        """Open the WebSocket connection and start the receive loop if needed.

        Returns True on success; on failure logs the error, marks the client
        disconnected and returns False.

        NOTE(review): calling this while already connected replaces
        ``self.websocket`` without closing the previous socket, and a still
        running ``receive_task`` is reused rather than restarted.
        """
        try:
            self.websocket = await websockets.connect(self.url, close_timeout=1)
            self.is_connected = True
            logger.info("success connected to %s", self.url)
            if self.receive_task is None or self.receive_task.done():
                self.receive_task = asyncio.create_task(self._receive_messages())
            return True
        except Exception as e:
            logger.error("failed to connect to %s: %s", self.url, e)
            self.is_connected = False
            return False

    async def _send_request(self, request: ConnRequest, callback: Callable = None) -> ConnResponse:
        """Send ``request`` and wait for the response carrying the same ``request_id``.

        ``callback``, if given, is invoked with the response by the receive loop.

        Raises:
            Exception: if the client is not connected.
            ConnectionError: if the connection closes before a response arrives.
            Any exception raised by ``websocket.send``.
        """
        if not self.is_connected:
            raise Exception("websocket not connected")
        future = asyncio.get_running_loop().create_future()
        self.pending_requests[request.request_id] = (future, callback)
        try:
            await self.websocket.send(request.model_dump_json())
            return await future
        finally:
            # Remove the entry however we leave (send error, cancellation,
            # timeout); a no-op when the receive loop already popped it.
            self.pending_requests.pop(request.request_id, None)

    async def send_request(self, command: CommandType, payload: Optional[Any] = None, callback: Optional[Callable] = None) -> ConnResponse:
        """Build a ``ConnRequest`` from ``command``/``payload``, send it, and await the response."""
        request = ConnRequest(command=command, payload=payload)
        return await self._send_request(request, callback)

    async def close(self):
        """Close the connection: stop the receive loop, close the socket, fail pending requests.

        Does nothing if the client is not currently marked connected.
        """
        if self.is_connected:
            self.is_connected = False
            if self.receive_task and not self.receive_task.done():
                self.receive_task.cancel()
                try:
                    await self.receive_task
                except asyncio.CancelledError:
                    pass
            if self.websocket:
                try:
                    await self.websocket.close()
                except Exception:
                    # Best-effort: the socket may already be closed.
                    pass
            # Nobody will read responses any more; release waiting callers.
            self._fail_pending_requests("websocket closed")
            logger.info("WebSocket closed.")

    def on_event(self, event_type: str, handler: Optional[Callable] = None):
        """Register an event handler; usable directly or as a decorator.

        ``client.on_event("x", fn)`` registers ``fn`` and returns it.
        ``@client.on_event("x")`` returns a decorator that registers the
        decorated function and returns it unchanged.
        """
        def register(func: Callable) -> Callable:
            self.event_handlers.setdefault(event_type, []).append(func)
            return func

        if handler is not None:
            return register(handler)
        return register
|