Coverage for agentos/api/websocket.py: 0%

262 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 09:59 +0800

1""" 

2WebSocket 双向流式通信 — Agent 实时交互层。 

3 

4基于 websockets 库,提供 Agent 与客户端之间的全双工实时通信。 

5支持流式进度报告、Agent 状态广播、父子 Agent 监控、暂停/恢复/取消。 

6 

7协议(JSON,双向): 

8 

9 Client → Server: 

10 {"type": "run", "task": "...", "session_id": "..."} 

11 {"type": "cancel", "session_id": "..."} 

12 {"type": "pause", "session_id": "..."} 

13 {"type": "resume", "session_id": "..."} 

14 {"type": "ping"} 

15 

16 Server → Client: 

17 {"type": "token", "text": "...", "seq": N} 

18 {"type": "progress", "value": 0.5, "step": "..."} 

19 {"type": "tool_call", "name": "...", "args": {...}} 

20 {"type": "tool_result", "name": "...", "result": ...} 

21 {"type": "status", "status": "running"|"paused"|"..."} 

22 {"type": "done", "output": "...", "iterations": N} 

23 {"type": "error", "message": "..."} 

24 {"type": "heartbeat"} 

25 {"type": "child_update", "agent_id": "...", "status": "..."} 

26 

27使用示例:: 

28 

29 from agentos.api.websocket import AgentWebSocket, serve_ws 

30 

31 mgr = SubAgentManager() 

32 

33 async def my_run(spec, ctx): 

34 await ctx.report_progress(0.5, "thinking") 

35 return "answer", 1 

36 

37 ws = AgentWebSocket(manager=mgr, run_func=my_run) 

38 await serve_ws(ws.handler, port=8765) 

39""" 

40 

41from __future__ import annotations 

42 

43import asyncio 

44import json 

45import time 

46import uuid 

47from dataclasses import dataclass, field 

48from enum import Enum 

49from typing import Any, Callable, Awaitable 

50 

51import websockets 

52from websockets.server import WebSocketServerProtocol 

53 

54from agentos.subagent.manager import SubAgentManager, SubAgentSpec, SubAgentResult 

55from agentos.subagent.parent_child import ChildContext, ChildHandle, ChildStatus 

56 

57 

58# ────────────────────────────────────────────── 

59# 消息协议 

60# ────────────────────────────────────────────── 

61 

62 

class WSMsgType(str, Enum):
    """Names of every message kind exchanged over the WebSocket.

    Members are ``str`` subclasses, so they compare equal to (and JSON-encode
    as) their plain string values.
    """

    # --- requests sent by the client ---
    RUN = "run"
    CANCEL = "cancel"
    PAUSE = "pause"
    RESUME = "resume"
    PING = "ping"

    # --- events emitted by the server ---
    TOKEN = "token"
    PROGRESS = "progress"
    TOOL_CALL = "tool_call"
    TOOL_RESULT = "tool_result"
    STATUS = "status"
    DONE = "done"
    ERROR = "error"
    HEARTBEAT = "heartbeat"
    CHILD_UPDATE = "child_update"

82 

83 

@dataclass
class WSMessage:
    """A single WebSocket message: a ``type`` tag plus a flat payload dict.

    On the wire a message is one JSON object whose ``"type"`` key carries the
    tag and whose remaining keys are the payload.
    """

    type: str
    data: dict[str, Any] = field(default_factory=dict)

    @classmethod
    def parse(cls, raw: str | bytes) -> "WSMessage":
        """Decode a raw frame into a message.

        Args:
            raw: JSON text, or bytes that are decoded with the default codec.

        Returns:
            A message whose ``type`` is the payload's ``"type"`` value (``""``
            when absent) and whose ``data`` holds every other key.

        Raises:
            json.JSONDecodeError: If ``raw`` is not valid JSON.
            ValueError: If the JSON document is not an object (for example a
                list, string or number).
        """
        text = raw if isinstance(raw, str) else raw.decode()
        payload = json.loads(text)
        # Valid JSON is not necessarily an object; reject anything else with a
        # clear error instead of failing later with an AttributeError.
        if not isinstance(payload, dict):
            raise ValueError(
                "WebSocket message must be a JSON object, "
                f"got {type(payload).__name__}"
            )
        return cls(
            type=payload.get("type", ""),
            data={k: v for k, v in payload.items() if k != "type"},
        )

    def serialize(self) -> str:
        """Encode the message as a JSON object string, keeping non-ASCII text as is."""
        return json.dumps({"type": self.type, **self.data}, ensure_ascii=False)

    # -- factory helpers, one per server-to-client message kind --

    @classmethod
    def token(cls, text: str, seq: int = 0) -> "WSMessage":
        """Build a ``token`` message carrying ``text`` and its sequence number."""
        return cls(WSMsgType.TOKEN, {"text": text, "seq": seq})

    @classmethod
    def progress(cls, value: float, step: str = "", agent_id: str = "") -> "WSMessage":
        """Build a ``progress`` message for ``agent_id``."""
        return cls(WSMsgType.PROGRESS, {"value": value, "step": step, "agent_id": agent_id})

    @classmethod
    def tool_call(cls, name: str, args: dict) -> "WSMessage":
        """Build a ``tool_call`` message for tool ``name`` with ``args``."""
        return cls(WSMsgType.TOOL_CALL, {"name": name, "args": args})

    @classmethod
    def tool_result(cls, name: str, result: Any) -> "WSMessage":
        """Build a ``tool_result`` message for tool ``name``."""
        return cls(WSMsgType.TOOL_RESULT, {"name": name, "result": result})

    @classmethod
    def status(cls, status: str, agent_id: str = "") -> "WSMessage":
        """Build a ``status`` message; ``agent_id`` is passed through unchanged."""
        return cls(WSMsgType.STATUS, {"status": status, "agent_id": agent_id})

    @classmethod
    def done(cls, output: str, iterations: int = 0, agent_id: str = "") -> "WSMessage":
        """Build a ``done`` message with the final output and iteration count."""
        return cls(WSMsgType.DONE, {"output": output, "iterations": iterations, "agent_id": agent_id})

    @classmethod
    def error(cls, message: str, code: str = "UNKNOWN") -> "WSMessage":
        """Build an ``error`` message with a human-readable text and a code."""
        return cls(WSMsgType.ERROR, {"message": message, "code": code})

    @classmethod
    def heartbeat(cls) -> "WSMessage":
        """Build a ``heartbeat`` message stamped with the current ``time.time()``."""
        return cls(WSMsgType.HEARTBEAT, {"ts": time.time()})

    @classmethod
    def child_update(cls, agent_id: str, status: str, progress: float = 0, step: str = "") -> "WSMessage":
        """Build a ``child_update`` message describing one child agent."""
        return cls(WSMsgType.CHILD_UPDATE, {
            "agent_id": agent_id,
            "status": status,
            "progress": progress,
            "step": step,
        })

143 

144 

145# ────────────────────────────────────────────── 

146# 会话管理 

147# ────────────────────────────────────────────── 

148 

149 

@dataclass
class WSSession:
    """Per-connection state for one WebSocket client."""

    # Random 12-character hex identifier.
    session_id: str = field(default_factory=lambda: uuid.uuid4().hex[:12])
    # Wall-clock timestamps taken from time.time().
    connected_at: float = field(default_factory=time.time)
    last_active: float = field(default_factory=time.time)
    # Task executing the current agent run, if any.
    running_task: asyncio.Task | None = None
    # Handle of the agent associated with this session, if known.
    running_handle: ChildHandle | None = None
    # Task that polls the agent and streams progress.
    poll_task: asyncio.Task | None = None
    metadata: dict[str, Any] = field(default_factory=dict)

    @property
    def is_busy(self) -> bool:
        """Whether a run task exists and has not finished yet."""
        current = self.running_task
        if current is None:
            return False
        return not current.done()

    def touch(self):
        """Record the current time as the moment of last activity."""
        self.last_active = time.time()

167 

168 

169# ────────────────────────────────────────────── 

170# WebSocket Agent 核心 

171# ────────────────────────────────────────────── 

172 

173 

class AgentWebSocket:
    """WebSocket service that runs agents and streams their state to clients.

    Every connection gets its own ``WSSession``. A ``run`` request starts the
    agent in a background task so the connection keeps reading messages and
    ``cancel`` / ``pause`` / ``resume`` / ``ping`` can be handled mid-run.

    Args:
        manager: ``SubAgentManager`` used to spawn and inspect agents; a new
            instance is created when omitted.
        run_func: Optional coroutine ``(spec, ctx) -> (output, iterations)``
            that performs the agent work. Without it a built-in fallback
            reports progress and echoes the task.
        heartbeat_interval: Seconds between server-sent ``heartbeat`` messages.
        poll_interval: Seconds between polls of the running agent's handle.
        max_message_size: Maximum message size in bytes. NOTE(review): the
            value is only stored; nothing in this class enforces it.
    """

    def __init__(
        self,
        manager: SubAgentManager | None = None,
        run_func: Callable[[SubAgentSpec, ChildContext], Awaitable[tuple[str, int]]] | None = None,
        heartbeat_interval: float = 15.0,
        poll_interval: float = 0.5,
        max_message_size: int = 2 ** 20,
    ):
        self._mgr = manager or SubAgentManager()
        self._run = run_func
        self._heartbeat_interval = heartbeat_interval
        self._poll_interval = poll_interval
        self._max_message_size = max_message_size
        # session_id -> session, and connection -> session_id.
        self._sessions: dict[str, WSSession] = {}
        self._conn_session: dict[WebSocketServerProtocol, str] = {}

    # -- connection handler --

    async def handler(self, websocket: WebSocketServerProtocol) -> None:
        """Serve one connection until it closes.

        Registers a new session, starts the heartbeat, announces
        ``connected`` and then dispatches each incoming message. Errors raised
        while handling a message are reported to the client as ``error``
        messages; the session is always unregistered on exit.
        """
        session = WSSession()
        self._sessions[session.session_id] = session
        self._conn_session[websocket] = session.session_id

        heartbeat_task = asyncio.create_task(self._heartbeat_loop(websocket))

        try:
            await self._send(websocket, WSMessage.status("connected", session.session_id))

            async for raw in websocket:
                session.touch()
                try:
                    msg = WSMessage.parse(raw)
                    await self._dispatch(websocket, session, msg)
                except json.JSONDecodeError:
                    await self._send(websocket, WSMessage.error("Invalid JSON", "PARSE_ERROR"))
                except Exception as e:
                    await self._send(websocket, WSMessage.error(str(e)))

        except websockets.exceptions.ConnectionClosed:
            pass
        finally:
            heartbeat_task.cancel()
            try:
                await heartbeat_task
            except asyncio.CancelledError:
                pass
            finally:
                # Must run even if the heartbeat task died with an unexpected
                # exception, otherwise the session would stay registered.
                await self._cleanup_session(websocket, session)

    # -- message dispatch --

    async def _dispatch(
        self,
        ws: WebSocketServerProtocol,
        session: WSSession,
        msg: WSMessage,
    ) -> None:
        """Route ``msg`` to its handler, or reply with ``UNKNOWN_TYPE``."""
        handlers: dict[str, Callable] = {
            WSMsgType.RUN: self._handle_run,
            WSMsgType.CANCEL: self._handle_cancel,
            WSMsgType.PAUSE: self._handle_pause,
            WSMsgType.RESUME: self._handle_resume,
            WSMsgType.PING: self._handle_ping,
        }

        handler = handlers.get(msg.type)
        if handler:
            await handler(ws, session, msg)
        else:
            await self._send(ws, WSMessage.error(f"Unknown type: {msg.type}", "UNKNOWN_TYPE"))

    # -- run --

    async def _handle_run(
        self,
        ws: WebSocketServerProtocol,
        session: WSSession,
        msg: WSMessage,
    ) -> None:
        """Start an agent run for this session without blocking the connection.

        Replies ``BUSY`` when a run is still in progress and ``INVALID`` when
        the request has no ``task``. Otherwise announces ``running`` and
        launches the run and poll tasks. The run is intentionally NOT awaited
        here: awaiting it would stall the receive loop, so control messages
        sent during the run would only be read after it finished. Results and
        failures are pushed by ``_run_agent``; the ``cancelled`` status is sent
        by ``_handle_cancel``.
        """
        if session.is_busy:
            await self._send(ws, WSMessage.error("Session busy", "BUSY"))
            return

        task = msg.data.get("task", "")
        if not task:
            await self._send(ws, WSMessage.error("Missing 'task'", "INVALID"))
            return

        # Drop the handle left over from a previous run so the poller does not
        # keep reporting a finished agent's progress for the new run.
        session.running_handle = None

        await self._send(ws, WSMessage.status("running", session.session_id))

        session.running_task = asyncio.create_task(
            self._run_agent(session, task)
        )
        session.poll_task = asyncio.create_task(
            self._poll_agent(ws, session)
        )

    async def _run_agent(self, session: WSSession, task: str) -> None:
        """Run the agent to completion and push the outcome to the session.

        Runs as a background task, so failures are reported here as ``error``
        messages rather than propagated to a caller. Cancellation is re-raised
        untouched.
        """
        async def capturing_run(spec: SubAgentSpec, ctx: ChildContext) -> tuple[str, int]:
            if self._run:
                return await self._run(spec, ctx)
            # Default fallback when no run_func was supplied.
            await ctx.report_progress(0.5, "processing")
            await ctx.report_progress(1.0, "done")
            return f"Agent received: {task}", 1

        try:
            result = await self._mgr.spawn_fork(task=task, run_func=capturing_run)
            session.running_handle = self._mgr.get_handle(result.agent_id)
        except asyncio.CancelledError:
            raise
        except Exception as exc:
            await self.broadcast_to_session(session, WSMessage.error(str(exc)))
            return
        finally:
            # The run is over (finished, failed or cancelled): stop polling.
            if session.poll_task and not session.poll_task.done():
                session.poll_task.cancel()

        # Report the final outcome.
        if result.error:
            await self.broadcast_to_session(session, WSMessage.error(result.error))
            await self.broadcast_to_session(session, WSMessage.status("failed", result.agent_id))
        else:
            await self.broadcast_to_session(session, WSMessage.done(
                output=result.output,
                iterations=result.iterations,
                agent_id=result.agent_id,
            ))
            await self.broadcast_to_session(session, WSMessage.status("completed", result.agent_id))

    async def _poll_agent(
        self,
        ws: WebSocketServerProtocol,
        session: WSSession,
    ) -> None:
        """Poll the running agent's handle and stream progress changes.

        Loops while the session's run task is alive, sending a ``progress``
        message whenever the progress value or step changes, and stops early
        when the handle reports ``FAILED`` or ``CANCELLED``.
        """
        try:
            last_progress = -1.0
            last_step = ""
            while session.running_task and not session.running_task.done():
                handle = session.running_handle
                if handle is None:
                    # spawn_fork has not returned yet; look for the agent in
                    # the manager instead.
                    # NOTE(review): this takes the manager's most recent child,
                    # which may belong to another session or to a previous run
                    # when the manager is shared -- confirm against
                    # SubAgentManager.
                    children = self._mgr.list_children()
                    if children:
                        latest = children[-1]
                        sid = latest.get("agent_id", "")
                        handle = self._mgr.get_handle(sid)
                        if handle and handle.status not in (ChildStatus.IDLE,):
                            session.running_handle = handle

                if handle:
                    cur_progress = handle.info.progress
                    cur_step = handle.info.current_step
                    if cur_progress != last_progress or cur_step != last_step:
                        await self._send(ws, WSMessage.progress(
                            value=cur_progress,
                            step=cur_step,
                            agent_id=handle.agent_id,
                        ))
                        last_progress = cur_progress
                        last_step = cur_step

                    # Terminal states end the polling loop.
                    if handle.status == ChildStatus.FAILED:
                        await self._send(ws, WSMessage.error(
                            handle.info.error or "Agent failed",
                        ))
                        break
                    elif handle.status == ChildStatus.CANCELLED:
                        break

                await asyncio.sleep(self._poll_interval)
        except asyncio.CancelledError:
            pass

    # -- cancel / pause / resume --

    async def _handle_cancel(
        self,
        ws: WebSocketServerProtocol,
        session: WSSession,
        msg: WSMessage,
    ) -> None:
        """Cancel the session's agent and tasks, then confirm with ``cancelled``."""
        if session.running_handle:
            await session.running_handle.cancel()
        self._cancel_session_tasks(session)
        await self._send(ws, WSMessage.status("cancelled", session.session_id))

    async def _handle_pause(
        self,
        ws: WebSocketServerProtocol,
        session: WSSession,
        msg: WSMessage,
    ) -> None:
        """Pause the session's agent, or reply ``IDLE`` when there is none."""
        if session.running_handle:
            await session.running_handle.pause()
            await self._send(ws, WSMessage.status("paused", session.session_id))
        else:
            await self._send(ws, WSMessage.error("No agent to pause", "IDLE"))

    async def _handle_resume(
        self,
        ws: WebSocketServerProtocol,
        session: WSSession,
        msg: WSMessage,
    ) -> None:
        """Resume the session's agent, or reply ``IDLE`` when there is none."""
        if session.running_handle:
            await session.running_handle.resume()
            await self._send(ws, WSMessage.status("running", session.session_id))
        else:
            await self._send(ws, WSMessage.error("No agent to resume", "IDLE"))

    async def _handle_ping(
        self,
        ws: WebSocketServerProtocol,
        session: WSSession,
        msg: WSMessage,
    ) -> None:
        """Answer a client ``ping`` with a ``heartbeat`` message."""
        await self._send(ws, WSMessage.heartbeat())

    # -- heartbeat and broadcast --

    async def _heartbeat_loop(self, ws: WebSocketServerProtocol) -> None:
        """Send a ``heartbeat`` every ``heartbeat_interval`` seconds until stopped."""
        try:
            while True:
                await asyncio.sleep(self._heartbeat_interval)
                await self._send(ws, WSMessage.heartbeat())
        except (websockets.exceptions.ConnectionClosed, asyncio.CancelledError):
            pass

    async def broadcast(self, msg: WSMessage, exclude_session: str = "") -> None:
        """Send ``msg`` to every connection except ``exclude_session``.

        Connections found closed while sending are unregistered afterwards.
        """
        dead: list[WebSocketServerProtocol] = []
        # Iterate over a snapshot: the mapping can change while awaiting.
        for ws, sid in list(self._conn_session.items()):
            if sid == exclude_session:
                continue
            try:
                await ws.send(msg.serialize())
            except websockets.exceptions.ConnectionClosed:
                dead.append(ws)
        for ws in dead:
            await self._cleanup_ws(ws)

    async def broadcast_to_session(
        self,
        session: WSSession,
        msg: WSMessage,
    ) -> None:
        """Send ``msg`` to the connection that owns ``session``, if it is registered."""
        target = next(
            (ws for ws, sid in self._conn_session.items() if sid == session.session_id),
            None,
        )
        if target is not None:
            await self._send(target, msg)

    async def broadcast_child_status(self) -> None:
        """Broadcast one ``child_update`` per child listed by the manager."""
        children = self._mgr.list_children()
        for child in children:
            await self.broadcast(WSMessage.child_update(
                agent_id=child.get("agent_id", ""),
                status=child.get("status", "unknown"),
                progress=child.get("progress", 0),
                step=child.get("current_step", ""),
            ))

    # -- helpers --

    async def _send(self, ws: WebSocketServerProtocol, msg: WSMessage) -> None:
        """Send ``msg`` on ``ws``, silently ignoring an already-closed connection."""
        try:
            await ws.send(msg.serialize())
        except websockets.exceptions.ConnectionClosed:
            pass

    @staticmethod
    def _cancel_session_tasks(session: WSSession) -> None:
        """Request cancellation of the session's run and poll tasks if still active."""
        for pending in (session.running_task, session.poll_task):
            if pending and not pending.done():
                pending.cancel()

    async def _cleanup_session(
        self, ws: WebSocketServerProtocol, session: WSSession
    ) -> None:
        """Cancel the session's tasks and unregister the connection and session."""
        self._cancel_session_tasks(session)
        self._conn_session.pop(ws, None)
        self._sessions.pop(session.session_id, None)

    async def _cleanup_ws(self, ws: WebSocketServerProtocol) -> None:
        """Unregister ``ws`` and, when it had a session, cancel and drop it too."""
        sid = self._conn_session.pop(ws, None)
        if sid:
            session = self._sessions.pop(sid, None)
            if session:
                self._cancel_session_tasks(session)

    # -- properties --

    @property
    def manager(self) -> SubAgentManager:
        """The ``SubAgentManager`` backing this service."""
        return self._mgr

    @property
    def active_connections(self) -> int:
        """Number of currently registered connections."""
        return len(self._conn_session)

    @property
    def active_sessions(self) -> int:
        """Number of currently registered sessions."""
        return len(self._sessions)

500 

501 

502# ────────────────────────────────────────────── 

503# 便捷启动 

504# ────────────────────────────────────────────── 

505 

506 

async def serve_ws(
    ws_handler,
    host: str = "0.0.0.0",
    port: int = 8765,
    **kwargs,
):
    """Start a WebSocket server and keep it running until cancelled.

    Args:
        ws_handler: ``AgentWebSocket.handler`` or any compatible coroutine
            connection handler.
        host: Address to listen on.
        port: Port to listen on.
        **kwargs: Extra options forwarded to ``websockets.serve``. ``max_size``
            defaults to ``2 ** 20`` and ``ping_interval`` to ``20`` when not
            supplied.

    Example::

        mgr = SubAgentManager()
        ws = AgentWebSocket(manager=mgr)
        await serve_ws(ws.handler, port=8765)
    """
    # Take the two defaulted options out of kwargs first so they are not
    # passed twice when the rest is forwarded.
    size_limit = kwargs.pop("max_size", 2 ** 20)
    keepalive = kwargs.pop("ping_interval", 20)
    server = websockets.serve(
        ws_handler,
        host=host,
        port=port,
        max_size=size_limit,
        ping_interval=keepalive,
        **kwargs,
    )
    async with server:
        print(f"WebSocket server listening on ws://{host}:{port}")
        # A future that is never resolved: blocks here forever.
        await asyncio.Future()