Coverage for agentos/api/sse.py: 47%
109 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
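"""
SSE (Server-Sent Events) streaming — async stream helpers for ASGI apps.

Provides typed SSE events, a queue-backed stream with periodic heartbeats
and bounded-queue backpressure, and a helper that supplies the headers and
body iterator for a streaming HTTP response. Reconnection itself is done by
the client; this module only writes the ``retry`` hint into the wire format.
"""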
8import asyncio
9import json
10import time
11from dataclasses import dataclass, field
12from enum import Enum
13from typing import Any, AsyncIterator, Callable, Dict, Optional
# Default client reconnection delay in milliseconds; SSEEvent leaves the
# ``retry:`` field out of the wire format when an event uses this value.
DEFAULT_RETRY_MS = 3_000
# Default number of seconds between heartbeat events (SSEStream.heartbeat_s).
DEFAULT_HEARTBEAT_S = 30
# Default upper bound on events buffered in an SSEStream queue.
MAX_QUEUE_SIZE = 256
class SSEEventType(str, Enum):
    """Event names written to the SSE ``event:`` field.

    ``message`` is the SSE protocol's default event name; the remaining
    members are AgentOS-specific. Members are ``str`` instances, so they
    compare equal to their raw values (``SSEEventType.DONE == "done"``).
    """

    # Protocol default.
    MESSAGE = "message"

    # Model output and tool use.
    TOKEN = "token"
    TOOL_CALL = "tool_call"
    TOOL_RESULT = "tool_result"

    # Terminal states.
    ERROR = "error"
    DONE = "done"

    # Liveness signals.
    PING = "ping"
    HEARTBEAT = "heartbeat"

    # Auxiliary information attached to a stream.
    METADATA = "metadata"
@dataclass
class SSEEvent:
    """A single SSE event to be serialized to the wire.

    The fields map onto SSE wire fields in :meth:`serialize`; the
    classmethods are convenience constructors for the AgentOS event types.
    """

    # Value for the ``event:`` field: an SSEEventType member or a plain
    # string. A falsy value omits the field.
    event: str = SSEEventType.MESSAGE
    # Payload: dicts/lists are JSON-encoded, anything else goes through str().
    data: Any = ""
    # Value for the ``id:`` field; omitted when empty.
    id: str = ""
    # Reconnection delay (ms); written only when it differs from DEFAULT_RETRY_MS.
    retry: int = DEFAULT_RETRY_MS

    def serialize(self) -> str:
        """Serialize to raw SSE wire format.

        Returns:
            The ``event``/``id``/``retry``/``data`` field lines joined by
            ``"\\n"`` and terminated by a blank line (``"\\n\\n"``).
        """
        lines: list[str] = []
        if self.event:
            # Enum members must contribute their value: str() of a
            # (str, Enum) member gives e.g. "SSEEventType.TOKEN". Plain
            # strings, allowed by the field annotation, have no ``.value``
            # and are used as-is.
            if isinstance(self.event, Enum):
                event_name = self.event.value
            else:
                event_name = str(self.event)
            lines.append(f"event: {event_name}")
        if self.id:
            lines.append(f"id: {self.id}")
        if self.retry != DEFAULT_RETRY_MS:
            lines.append(f"retry: {self.retry}")

        if isinstance(self.data, (dict, list)):
            data_str = json.dumps(self.data, ensure_ascii=False)
        else:
            data_str = str(self.data)

        # SSE treats CRLF, LF and a bare CR all as line terminators. Normalize
        # to LF before splitting so a bare CR cannot end a ``data:`` line
        # early and turn the rest of the payload into an unknown field that
        # the client discards.
        normalized = data_str.replace("\r\n", "\n").replace("\r", "\n")
        for line in normalized.split("\n"):
            lines.append(f"data: {line}")
        return "\n".join(lines) + "\n\n"

    @classmethod
    def token(cls, text: str, seq: int = 0) -> "SSEEvent":
        """Build a TOKEN event carrying a text fragment and its sequence number."""
        return cls(event=SSEEventType.TOKEN, data={"text": text, "seq": seq})

    @classmethod
    def tool_call(cls, name: str, args: dict) -> "SSEEvent":
        """Build a TOOL_CALL event with the tool name and its arguments."""
        return cls(
            event=SSEEventType.TOOL_CALL,
            data={"name": name, "arguments": args},
        )

    @classmethod
    def tool_result(cls, name: str, result: Any) -> "SSEEvent":
        """Build a TOOL_RESULT event with the tool name and its result."""
        return cls(
            event=SSEEventType.TOOL_RESULT,
            data={"name": name, "result": result},
        )

    @classmethod
    def error(cls, message: str, code: str = "UNKNOWN") -> "SSEEvent":
        """Build an ERROR event with a message and an error code."""
        return cls(
            event=SSEEventType.ERROR,
            data={"message": message, "code": code},
        )

    @classmethod
    def done(cls, metadata: dict[str, Any] | None = None) -> "SSEEvent":
        """Build a DONE event; a missing or empty *metadata* becomes ``{}``."""
        return cls(
            event=SSEEventType.DONE,
            data=metadata or {},
        )

    @classmethod
    def metadata(cls, meta: dict[str, Any]) -> "SSEEvent":
        """Build a METADATA event carrying *meta* as its payload."""
        return cls(event=SSEEventType.METADATA, data=meta)
class SSEStream:
    """SSE stream with heartbeats and backpressure handling.

    Events are buffered in a bounded ``asyncio.Queue``; ``None`` on the queue
    is the end-of-stream sentinel written by :meth:`close`.

    Usage::

        stream = SSEStream(retry_ms=3000)
        # Producer
        await stream.queue.put(SSEEvent.token("Hello"))
        await stream.queue.put(SSEEvent.done())
        await stream.close()

        # Consumer (ASGI)
        async for chunk in stream.iter_chunks():
            yield chunk
    """

    def __init__(
        self,
        retry_ms: int = DEFAULT_RETRY_MS,
        heartbeat_s: float = DEFAULT_HEARTBEAT_S,
        max_queue: int = MAX_QUEUE_SIZE,
    ):
        """
        Args:
            retry_ms: Client reconnection delay (ms). Announced once at the
                start of :meth:`iter_chunks` when it differs from
                DEFAULT_RETRY_MS.
            heartbeat_s: Seconds between heartbeat events once :meth:`start`
                has been called.
            max_queue: Maximum number of buffered events.
        """
        self.retry_ms = retry_ms
        self.heartbeat_s = heartbeat_s
        self.queue: asyncio.Queue[SSEEvent | None] = asyncio.Queue(maxsize=max_queue)
        self._closed = False
        self._heartbeat_task: Optional[asyncio.Task] = None
        self._last_event_id = 0

    async def start(self):
        """Start the heartbeat background task.

        A no-op while a previously started heartbeat task is still running,
        so repeated calls do not leave orphaned heartbeat tasks behind.
        """
        if self._heartbeat_task is not None and not self._heartbeat_task.done():
            return
        self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())

    async def _heartbeat_loop(self):
        """Enqueue a HEARTBEAT event every ``heartbeat_s`` seconds until closed.

        Uses ``queue.put`` rather than :meth:`send`, so heartbeats carry no
        event id and wait for space instead of raising on a full queue.
        Cancellation ends the loop quietly.
        """
        try:
            while not self._closed:
                await asyncio.sleep(self.heartbeat_s)
                if not self._closed:
                    await self.queue.put(
                        SSEEvent(
                            event=SSEEventType.HEARTBEAT,
                            data={"ts": time.time()},
                        )
                    )
        except asyncio.CancelledError:
            pass

    async def send(self, event: SSEEvent):
        """Enqueue an event without waiting.

        An event without an ``id`` is assigned the next sequential id.

        Raises:
            RuntimeError: The stream has been closed.
            asyncio.QueueFull: The buffer is full (backpressure exceeded). No
                id is consumed and *event* is left unmodified.
        """
        if self._closed:
            raise RuntimeError("Stream is closed")
        # Check capacity before touching the counter or the event, so a
        # rejected send leaves no gap in the id sequence. No await separates
        # this check from put_nowait(), so the queue cannot fill in between.
        if self.queue.full():
            raise asyncio.QueueFull
        self._last_event_id += 1
        if not event.id:
            event.id = str(self._last_event_id)
        self.queue.put_nowait(event)

    async def close(self):
        """Signal end of stream and stop the heartbeat task.

        Idempotent: only the first call enqueues the end-of-stream sentinel,
        and that call may wait for queue space if the consumer is behind.
        If the consumer has already stopped (see :meth:`iter_events`), no
        sentinel is enqueued, so this never blocks on an abandoned queue.
        """
        already_closed = self._closed
        self._closed = True
        # Stop the heartbeat before enqueueing the sentinel so it cannot
        # compete with the sentinel for the last free queue slot.
        if self._heartbeat_task:
            self._heartbeat_task.cancel()
            try:
                await self._heartbeat_task
            except asyncio.CancelledError:
                pass
        if not already_closed:
            await self.queue.put(None)  # Sentinel

    async def iter_events(self) -> AsyncIterator[SSEEvent]:
        """Async iterator over enqueued events, ending at the close sentinel.

        When this generator finishes for any reason (sentinel reached, or the
        generator is closed or cancelled because the consumer went away), the
        stream is marked closed and the heartbeat task is cancelled. An
        abandoned stream therefore stops producing heartbeats, and later
        :meth:`send` calls fail fast with RuntimeError.
        """
        try:
            while True:
                event = await self.queue.get()
                if event is None:
                    break
                yield event
        finally:
            self._closed = True
            if self._heartbeat_task is not None:
                self._heartbeat_task.cancel()

    async def iter_chunks(self) -> AsyncIterator[str]:
        """Async iterator yielding raw SSE wire-format chunks.

        When ``retry_ms`` differs from DEFAULT_RETRY_MS, a field-only
        ``retry:`` block is yielded first so the configured reconnection delay
        actually reaches the client. That block carries no data, so no event
        is dispatched for it.
        """
        if self.retry_ms != DEFAULT_RETRY_MS:
            yield f"retry: {self.retry_ms}\n\n"
        async for event in self.iter_events():
            yield event.serialize()
class SSEResponse:
    """Pair an SSEStream with the headers and body needed for an HTTP response.

    Hand both pieces to a streaming response class, e.g. Starlette / FastAPI::

        from starlette.responses import StreamingResponse

        sse = SSEResponse(stream)
        return StreamingResponse(
            sse.body(),
            media_type="text/event-stream",
            headers=sse.headers(),
        )
    """

    # Fixed header set; headers() hands out copies so callers can extend it.
    HEADERS = {
        "Content-Type": "text/event-stream",
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",
    }

    def __init__(self, stream: SSEStream):
        self.stream = stream

    def headers(self) -> dict[str, str]:
        """Return a fresh, caller-owned copy of HEADERS."""
        header_copy = {**self.HEADERS}
        return header_copy

    async def body(self) -> AsyncIterator[str]:
        """Relay the wrapped stream's wire-format chunks one at a time."""
        chunk_source = self.stream.iter_chunks()
        async for wire_chunk in chunk_source:
            yield wire_chunk