Coverage for agentos/api/websocket.py: 0%
262 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
1"""
WebSocket bidirectional streaming — the real-time interaction layer for Agents.

Built on the ``websockets`` library; provides full-duplex, real-time
communication between an Agent and its clients.
Supports streamed progress reports, Agent status broadcasts, parent/child
Agent monitoring, and pause/resume/cancel.

Protocol (JSON, bidirectional):

    Client → Server:
        {"type": "run", "task": "...", "session_id": "..."}
        {"type": "cancel", "session_id": "..."}
        {"type": "pause", "session_id": "..."}
        {"type": "resume", "session_id": "..."}
        {"type": "ping"}

    Server → Client:
        {"type": "token", "text": "...", "seq": N}
        {"type": "progress", "value": 0.5, "step": "..."}
        {"type": "tool_call", "name": "...", "args": {...}}
        {"type": "tool_result", "name": "...", "result": ...}
        {"type": "status", "status": "running"|"paused"|"..."}
        {"type": "done", "output": "...", "iterations": N}
        {"type": "error", "message": "..."}
        {"type": "heartbeat"}
        {"type": "child_update", "agent_id": "...", "status": "..."}

Usage example::

    from agentos.api.websocket import AgentWebSocket, serve_ws

    mgr = SubAgentManager()

    async def my_run(spec, ctx):
        await ctx.report_progress(0.5, "thinking")
        return "answer", 1

    ws = AgentWebSocket(manager=mgr, run_func=my_run)
    await serve_ws(ws.handler, port=8765)
39"""
41from __future__ import annotations
43import asyncio
44import json
45import time
46import uuid
47from dataclasses import dataclass, field
48from enum import Enum
49from typing import Any, Callable, Awaitable
51import websockets
52from websockets.server import WebSocketServerProtocol
54from agentos.subagent.manager import SubAgentManager, SubAgentSpec, SubAgentResult
55from agentos.subagent.parent_child import ChildContext, ChildHandle, ChildStatus
58# ──────────────────────────────────────────────
59# 消息协议
60# ──────────────────────────────────────────────
class WSMsgType(str, Enum):
    """Type tags carried in the ``type`` field of WebSocket messages.

    Members mix in ``str``, so each one compares equal to its raw string
    value (``WSMsgType.RUN == "run"``).
    """

    # Sent by the client to the server.
    RUN = "run"
    CANCEL = "cancel"
    PAUSE = "pause"
    RESUME = "resume"
    PING = "ping"

    # Sent by the server to the client.
    TOKEN = "token"
    PROGRESS = "progress"
    TOOL_CALL = "tool_call"
    TOOL_RESULT = "tool_result"
    STATUS = "status"
    DONE = "done"
    ERROR = "error"
    HEARTBEAT = "heartbeat"
    CHILD_UPDATE = "child_update"
@dataclass
class WSMessage:
    """A single JSON message exchanged over the WebSocket.

    On the wire a message is one flat JSON object: the ``type`` key holds the
    message tag and every other key is carried in ``data``.
    """

    type: str
    data: dict[str, Any] = field(default_factory=dict)

    @classmethod
    def parse(cls, raw: str | bytes) -> "WSMessage":
        """Decode a raw text or binary frame into a message.

        Args:
            raw: JSON text, or bytes that are decoded with the default codec.

        Returns:
            A message whose ``type`` is the payload's ``"type"`` value
            (``""`` when absent) and whose ``data`` holds the remaining keys.

        Raises:
            json.JSONDecodeError: If ``raw`` is not valid JSON.
            ValueError: If the decoded JSON is not an object.
        """
        text = raw if isinstance(raw, str) else raw.decode()
        payload = json.loads(text)
        # Valid JSON need not be an object ("[1, 2]", "null", "3"); reject it
        # explicitly instead of failing with an AttributeError on `.get`.
        if not isinstance(payload, dict):
            raise ValueError("Message must be a JSON object")
        extras = {key: value for key, value in payload.items() if key != "type"}
        return cls(type=payload.get("type", ""), data=extras)

    def serialize(self) -> str:
        """Encode the message as a flat JSON object string (non-ASCII kept as-is)."""
        return json.dumps({"type": self.type, **self.data}, ensure_ascii=False)

    # ── Factory methods ───────────────────

    @classmethod
    def token(cls, text: str, seq: int = 0) -> "WSMessage":
        """Build a ``token`` message carrying a text chunk and its sequence number."""
        return cls(WSMsgType.TOKEN, {"text": text, "seq": seq})

    @classmethod
    def progress(cls, value: float, step: str = "", agent_id: str = "") -> "WSMessage":
        """Build a ``progress`` message with a progress value and step label."""
        return cls(WSMsgType.PROGRESS, {"value": value, "step": step, "agent_id": agent_id})

    @classmethod
    def tool_call(cls, name: str, args: dict) -> "WSMessage":
        """Build a ``tool_call`` message for tool ``name`` with its arguments."""
        return cls(WSMsgType.TOOL_CALL, {"name": name, "args": args})

    @classmethod
    def tool_result(cls, name: str, result: Any) -> "WSMessage":
        """Build a ``tool_result`` message for tool ``name``."""
        return cls(WSMsgType.TOOL_RESULT, {"name": name, "result": result})

    @classmethod
    def status(cls, status: str, agent_id: str = "") -> "WSMessage":
        """Build a ``status`` message."""
        return cls(WSMsgType.STATUS, {"status": status, "agent_id": agent_id})

    @classmethod
    def done(cls, output: str, iterations: int = 0, agent_id: str = "") -> "WSMessage":
        """Build a ``done`` message with the final output and iteration count."""
        return cls(WSMsgType.DONE, {"output": output, "iterations": iterations, "agent_id": agent_id})

    @classmethod
    def error(cls, message: str, code: str = "UNKNOWN") -> "WSMessage":
        """Build an ``error`` message with a human-readable text and a code."""
        return cls(WSMsgType.ERROR, {"message": message, "code": code})

    @classmethod
    def heartbeat(cls) -> "WSMessage":
        """Build a ``heartbeat`` message stamped with the current ``time.time()``."""
        return cls(WSMsgType.HEARTBEAT, {"ts": time.time()})

    @classmethod
    def child_update(cls, agent_id: str, status: str, progress: float = 0, step: str = "") -> "WSMessage":
        """Build a ``child_update`` message describing one child agent."""
        return cls(WSMsgType.CHILD_UPDATE, {
            "agent_id": agent_id,
            "status": status,
            "progress": progress,
            "step": step,
        })
145# ──────────────────────────────────────────────
146# 会话管理
147# ──────────────────────────────────────────────
@dataclass
class WSSession:
    """State tracked for a single WebSocket connection."""

    # Short random identifier: the first 12 hex characters of a UUID4.
    session_id: str = field(default_factory=lambda: uuid.uuid4().hex[:12])
    # Timestamps taken from time.time() when the session object is created.
    connected_at: float = field(default_factory=time.time)
    last_active: float = field(default_factory=time.time)
    # Background task executing the current agent run, if any.
    running_task: asyncio.Task | None = None
    # Handle of the agent associated with this session, once known.
    running_handle: ChildHandle | None = None
    # Background task that polls the agent for progress, if any.
    poll_task: asyncio.Task | None = None
    metadata: dict[str, Any] = field(default_factory=dict)

    @property
    def is_busy(self) -> bool:
        """True while a run task exists and has not finished."""
        task = self.running_task
        if task is None:
            return False
        return not task.done()

    def touch(self) -> None:
        """Record the current time as the moment of last activity."""
        self.last_active = time.time()
169# ──────────────────────────────────────────────
170# WebSocket Agent 核心
171# ──────────────────────────────────────────────
class AgentWebSocket:
    """WebSocket front-end that runs agents and streams their progress to clients.

    Every connection gets its own ``WSSession``. A ``run`` message starts the
    agent in a background task, so the receive loop stays free to process
    ``cancel`` / ``pause`` / ``resume`` / ``ping`` while the agent is running.

    Args:
        manager: SubAgentManager used to spawn and look up agents; a new one
            is created when ``None`` is passed.
        run_func: Optional coroutine ``(spec, ctx) -> (output, iterations)``
            that executes the task. A built-in fallback is used when omitted.
        heartbeat_interval: Seconds between server-sent heartbeat messages.
        poll_interval: Seconds between polls of the running agent's state.
        max_message_size: Maximum message size in bytes (stored on the
            instance; not enforced by this class).
    """

    def __init__(
        self,
        manager: SubAgentManager | None = None,
        run_func: Callable[[SubAgentSpec, ChildContext], Awaitable[tuple[str, int]]] | None = None,
        heartbeat_interval: float = 15.0,
        poll_interval: float = 0.5,
        max_message_size: int = 2 ** 20,
    ):
        # Compare against None explicitly: a caller-supplied manager that
        # happens to be falsy must not be silently replaced by a new one.
        self._mgr = manager if manager is not None else SubAgentManager()
        self._run = run_func
        self._heartbeat_interval = heartbeat_interval
        self._poll_interval = poll_interval
        self._max_message_size = max_message_size
        self._sessions: dict[str, WSSession] = {}
        self._conn_session: dict[WebSocketServerProtocol, str] = {}

    # ── Connection handler ────────────────

    async def handler(self, websocket: WebSocketServerProtocol) -> None:
        """Serve one client connection until it closes.

        Registers a new session, starts the heartbeat loop, then parses and
        dispatches each incoming frame. Errors raised while handling a frame
        are reported to the client as ``error`` messages rather than closing
        the connection.
        """
        session = WSSession()
        self._sessions[session.session_id] = session
        self._conn_session[websocket] = session.session_id

        heartbeat_task = asyncio.create_task(self._heartbeat_loop(websocket))

        try:
            await self._send(websocket, WSMessage.status("connected", session.session_id))

            async for raw in websocket:
                session.touch()
                try:
                    msg = WSMessage.parse(raw)
                    await self._dispatch(websocket, session, msg)
                except json.JSONDecodeError:
                    await self._send(websocket, WSMessage.error("Invalid JSON", "PARSE_ERROR"))
                except Exception as e:
                    # Per-message boundary: report and keep the connection open.
                    await self._send(websocket, WSMessage.error(str(e)))

        except websockets.exceptions.ConnectionClosed:
            pass
        finally:
            heartbeat_task.cancel()
            try:
                await heartbeat_task
            except asyncio.CancelledError:
                pass
            await self._cleanup_session(websocket, session)

    # ── Message dispatch ──────────────────

    async def _dispatch(
        self,
        ws: WebSocketServerProtocol,
        session: WSSession,
        msg: WSMessage,
    ) -> None:
        """Route a client message to its handler, or reply with UNKNOWN_TYPE."""
        handlers: dict[str, Callable] = {
            WSMsgType.RUN: self._handle_run,
            WSMsgType.CANCEL: self._handle_cancel,
            WSMsgType.PAUSE: self._handle_pause,
            WSMsgType.RESUME: self._handle_resume,
            WSMsgType.PING: self._handle_ping,
        }

        # A client may send a non-string (even unhashable) "type"; treat it as
        # unknown instead of letting the dict lookup raise TypeError.
        handler = handlers.get(msg.type) if isinstance(msg.type, str) else None
        if handler:
            await handler(ws, session, msg)
        else:
            await self._send(ws, WSMessage.error(f"Unknown type: {msg.type}", "UNKNOWN_TYPE"))

    # ── run ───────────────────────────────

    async def _handle_run(
        self,
        ws: WebSocketServerProtocol,
        session: WSSession,
        msg: WSMessage,
    ) -> None:
        """Start an agent run for this session in the background.

        Replies with a ``BUSY`` error when a run is already in progress and
        an ``INVALID`` error when the message has no ``task``.
        """
        if session.is_busy:
            await self._send(ws, WSMessage.error("Session busy", "BUSY"))
            return

        task = msg.data.get("task", "")
        if not task:
            await self._send(ws, WSMessage.error("Missing 'task'", "INVALID"))
            return

        await self._send(ws, WSMessage.status("running", session.session_id))

        # Forget the handle of any previous run so the poller does not report
        # the finished agent's state for the new run.
        session.running_handle = None
        session.running_task = asyncio.create_task(
            self._run_agent(session, task)
        )
        session.poll_task = asyncio.create_task(
            self._poll_agent(ws, session)
        )
        # Deliberately do NOT await the run here: this coroutine executes
        # inside the connection's receive loop, and blocking it would keep
        # cancel/pause/resume/ping from being processed until the run ends.
        # Results and failures are pushed to the client by _run_agent.

    async def _run_agent(self, session: WSSession, task: str) -> None:
        """Run the agent and push its final result to the session's client."""
        async def capturing_run(spec: SubAgentSpec, ctx: ChildContext) -> tuple[str, int]:
            if self._run:
                return await self._run(spec, ctx)
            # Default fallback when no run_func was supplied.
            await ctx.report_progress(0.5, "processing")
            await ctx.report_progress(1.0, "done")
            return f"Agent received: {task}", 1

        try:
            result = await self._mgr.spawn_fork(task=task, run_func=capturing_run)
        except asyncio.CancelledError:
            # Cancellation is acknowledged to the client by whoever requested it.
            self._stop_poll(session)
            raise
        except Exception as exc:
            # Nothing awaits this task, so failures must be reported here.
            self._stop_poll(session)
            await self.broadcast_to_session(session, WSMessage.error(str(exc)))
            await self.broadcast_to_session(session, WSMessage.status("failed", session.session_id))
            return

        session.running_handle = self._mgr.get_handle(result.agent_id)

        # Stop polling before sending the final messages.
        self._stop_poll(session)

        # Determine the final state and send it.
        if result.error:
            await self.broadcast_to_session(session, WSMessage.error(result.error))
            await self.broadcast_to_session(session, WSMessage.status("failed", result.agent_id))
        else:
            await self.broadcast_to_session(session, WSMessage.done(
                output=result.output,
                iterations=result.iterations,
                agent_id=result.agent_id,
            ))
            await self.broadcast_to_session(session, WSMessage.status("completed", result.agent_id))

    async def _poll_agent(
        self,
        ws: WebSocketServerProtocol,
        session: WSSession,
    ) -> None:
        """Poll the running agent and stream progress changes to the client."""
        try:
            last_progress = -1.0
            last_step = ""
            while session.running_task and not session.running_task.done():
                handle = session.running_handle
                if handle is None:
                    # spawn_fork has not returned yet; fall back to the most
                    # recently listed child of the manager.
                    # NOTE(review): presumably the last listed child belongs to
                    # this session — verify when several sessions share a manager.
                    children = self._mgr.list_children()
                    if children:
                        latest = children[-1]
                        sid = latest.get("agent_id", "")
                        handle = self._mgr.get_handle(sid)
                        if handle and handle.status not in (ChildStatus.IDLE,):
                            session.running_handle = handle

                if handle:
                    cur_progress = handle.info.progress
                    cur_step = handle.info.current_step
                    # Only send when something actually changed.
                    if cur_progress != last_progress or cur_step != last_step:
                        await self._send(ws, WSMessage.progress(
                            value=cur_progress,
                            step=cur_step,
                            agent_id=handle.agent_id,
                        ))
                        last_progress = cur_progress
                        last_step = cur_step

                    # Report terminal states and stop polling.
                    if handle.status == ChildStatus.FAILED:
                        await self._send(ws, WSMessage.error(
                            handle.info.error or "Agent failed",
                        ))
                        break
                    elif handle.status == ChildStatus.CANCELLED:
                        break

                await asyncio.sleep(self._poll_interval)
        except asyncio.CancelledError:
            pass

    # ── cancel / pause / resume ───────────

    async def _handle_cancel(
        self,
        ws: WebSocketServerProtocol,
        session: WSSession,
        msg: WSMessage,
    ) -> None:
        """Cancel the session's agent and background tasks, then acknowledge."""
        if session.running_handle:
            await session.running_handle.cancel()
        self._cancel_session_tasks(session)
        await self._send(ws, WSMessage.status("cancelled", session.session_id))

    async def _handle_pause(
        self,
        ws: WebSocketServerProtocol,
        session: WSSession,
        msg: WSMessage,
    ) -> None:
        """Pause the session's agent, or reply with an IDLE error if none is known."""
        if session.running_handle:
            await session.running_handle.pause()
            await self._send(ws, WSMessage.status("paused", session.session_id))
        else:
            await self._send(ws, WSMessage.error("No agent to pause", "IDLE"))

    async def _handle_resume(
        self,
        ws: WebSocketServerProtocol,
        session: WSSession,
        msg: WSMessage,
    ) -> None:
        """Resume the session's agent, or reply with an IDLE error if none is known."""
        if session.running_handle:
            await session.running_handle.resume()
            await self._send(ws, WSMessage.status("running", session.session_id))
        else:
            await self._send(ws, WSMessage.error("No agent to resume", "IDLE"))

    async def _handle_ping(
        self,
        ws: WebSocketServerProtocol,
        session: WSSession,
        msg: WSMessage,
    ) -> None:
        """Answer a client ping with a heartbeat message."""
        await self._send(ws, WSMessage.heartbeat())

    # ── Heartbeat and broadcast ───────────

    async def _heartbeat_loop(self, ws: WebSocketServerProtocol) -> None:
        """Send a heartbeat every ``heartbeat_interval`` seconds until cancelled."""
        try:
            while True:
                await asyncio.sleep(self._heartbeat_interval)
                await self._send(ws, WSMessage.heartbeat())
        except (websockets.exceptions.ConnectionClosed, asyncio.CancelledError):
            pass

    async def broadcast(self, msg: WSMessage, exclude_session: str = "") -> None:
        """Send ``msg`` to every connected client except ``exclude_session``.

        Connections found closed while sending are cleaned up afterwards.
        """
        dead: list[WebSocketServerProtocol] = []
        # Iterate over a snapshot: the mapping can change while we await.
        for ws, sid in list(self._conn_session.items()):
            if sid == exclude_session:
                continue
            try:
                await ws.send(msg.serialize())
            except websockets.exceptions.ConnectionClosed:
                dead.append(ws)
        for ws in dead:
            await self._cleanup_ws(ws)

    async def broadcast_to_session(
        self,
        session: WSSession,
        msg: WSMessage,
    ) -> None:
        """Send ``msg`` to the connection that owns ``session``, if any."""
        for ws, sid in self._conn_session.items():
            if sid == session.session_id:
                try:
                    await ws.send(msg.serialize())
                except websockets.exceptions.ConnectionClosed:
                    pass
                # Stop at the first match; this also ends the iteration
                # before the mapping could be observed changing.
                return

    async def broadcast_child_status(self) -> None:
        """Broadcast a ``child_update`` for every child listed by the manager."""
        children = self._mgr.list_children()
        for child in children:
            await self.broadcast(WSMessage.child_update(
                agent_id=child.get("agent_id", ""),
                status=child.get("status", "unknown"),
                progress=child.get("progress", 0),
                step=child.get("current_step", ""),
            ))

    # ── Helpers ───────────────────────────

    async def _send(self, ws: WebSocketServerProtocol, msg: WSMessage) -> None:
        """Send ``msg`` on ``ws``, ignoring a connection that is already closed."""
        try:
            await ws.send(msg.serialize())
        except websockets.exceptions.ConnectionClosed:
            pass

    @staticmethod
    def _stop_poll(session: WSSession) -> None:
        """Cancel the session's poll task if it is still running."""
        if session.poll_task and not session.poll_task.done():
            session.poll_task.cancel()

    @staticmethod
    def _cancel_session_tasks(session: WSSession) -> None:
        """Cancel the session's run and poll tasks if they are still running."""
        if session.running_task and not session.running_task.done():
            session.running_task.cancel()
        if session.poll_task and not session.poll_task.done():
            session.poll_task.cancel()

    async def _cleanup_session(
        self, ws: WebSocketServerProtocol, session: WSSession
    ) -> None:
        """Cancel a session's tasks and drop it and its connection from the registries."""
        self._cancel_session_tasks(session)
        self._conn_session.pop(ws, None)
        self._sessions.pop(session.session_id, None)

    async def _cleanup_ws(self, ws: WebSocketServerProtocol) -> None:
        """Drop a connection and, if it had a session, cancel that session's tasks."""
        sid = self._conn_session.pop(ws, None)
        if sid:
            session = self._sessions.pop(sid, None)
            if session:
                self._cancel_session_tasks(session)

    # ── Properties ────────────────────────

    @property
    def manager(self) -> SubAgentManager:
        """The SubAgentManager used by this server."""
        return self._mgr

    @property
    def active_connections(self) -> int:
        """Number of currently registered connections."""
        return len(self._conn_session)

    @property
    def active_sessions(self) -> int:
        """Number of currently registered sessions."""
        return len(self._sessions)
502# ──────────────────────────────────────────────
503# 便捷启动
504# ──────────────────────────────────────────────
async def serve_ws(
    ws_handler,
    host: str = "0.0.0.0",
    port: int = 8765,
    **kwargs,
):
    """Start a WebSocket server and keep it running until cancelled.

    Args:
        ws_handler: ``AgentWebSocket.handler`` or any compatible coroutine handler.
        host: Address to listen on.
        port: Port to listen on.
        **kwargs: Extra options forwarded to ``websockets.serve``. ``max_size``
            defaults to ``2 ** 20`` and ``ping_interval`` to ``20`` when not given.

    Example::

        mgr = SubAgentManager()
        ws = AgentWebSocket(manager=mgr)
        await serve_ws(ws.handler, port=8765)
    """
    # Fill in defaults only for options the caller did not provide.
    kwargs.setdefault("max_size", 2 ** 20)
    kwargs.setdefault("ping_interval", 20)

    server = websockets.serve(ws_handler, host=host, port=port, **kwargs)
    async with server:
        print(f"WebSocket server listening on ws://{host}:{port}")
        # Wait on a future that is never resolved, i.e. run forever.
        await asyncio.Future()