Coverage for agentos/api/sse.py: 47%

109 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 09:59 +0800

"""
SSE (Server-Sent Events) streaming helpers for async (ASGI) endpoints.

Provides a typed SSE event model with wire-format serialization, a bounded
event queue with periodic heartbeat events and backpressure (``send`` raises
``asyncio.QueueFull`` when the queue is full), and a helper that supplies the
headers and body iterator for a ``text/event-stream`` response.

Reconnection itself is performed by the client. This module only emits the
``retry`` hint and per-event ``id`` values; reading a reconnecting client's
``Last-Event-ID`` header is not handled here.
"""

7 

8import asyncio 

9import json 

10import time 

11from dataclasses import dataclass, field 

12from enum import Enum 

13from typing import Any, AsyncIterator, Callable, Dict, Optional 

14 

# Defaults used by SSEEvent and SSEStream below.
DEFAULT_RETRY_MS = 3_000  # `retry:` value (ms); only written to the wire when an event overrides it
DEFAULT_HEARTBEAT_S = 30  # seconds SSEStream's heartbeat task sleeps between heartbeat events
MAX_QUEUE_SIZE = 256  # default capacity of SSEStream's bounded event queue

18 

19 

class SSEEventType(str, Enum):
    """Names written to the ``event:`` field of an SSE frame.

    ``message`` is the standard SSE event type; the remaining members are
    AgentOS extensions. Because the enum mixes in ``str``, every member
    compares equal to its plain-string value (``SSEEventType.DONE == "done"``).
    """

    # Standard SSE event type.
    MESSAGE = "message"

    # Model output and tool activity.
    TOKEN = "token"
    TOOL_CALL = "tool_call"
    TOOL_RESULT = "tool_result"

    # Errors and end-of-stream.
    ERROR = "error"
    DONE = "done"

    # Liveness and side-channel information.
    PING = "ping"
    HEARTBEAT = "heartbeat"
    METADATA = "metadata"

32 

33 

@dataclass
class SSEEvent:
    """A single SSE event to be serialized to the wire.

    Attributes:
        event: Event type written to the ``event:`` field. May be an
            ``SSEEventType`` member (its ``.value`` is used) or a plain string;
            an empty value omits the field.
        data: Payload. ``dict``/``list`` values are JSON-encoded; anything
            else is converted with ``str()``.
        id: Event id written to the ``id:`` field; empty omits the field.
        retry: Reconnection delay in milliseconds; the ``retry:`` field is
            only emitted when it differs from ``DEFAULT_RETRY_MS``.
    """

    event: str = SSEEventType.MESSAGE
    data: Any = ""
    id: str = ""
    retry: int = DEFAULT_RETRY_MS

    @staticmethod
    def _single_line(field_name: str, value: str) -> str:
        """Return *value* unchanged, rejecting CR/LF that would split the SSE field."""
        if "\r" in value or "\n" in value:
            raise ValueError(f"SSE {field_name} must not contain line breaks: {value!r}")
        return value

    def serialize(self) -> str:
        """Serialize to raw SSE wire format: one ``name: value`` line per field,
        terminated by a blank line.

        Raises:
            ValueError: if ``event`` or ``id`` contains a CR or LF character,
                which would otherwise break the frame or inject extra fields.
        """
        lines: list[str] = []
        if self.event:
            # The field is typed ``str``: accept enum members and plain strings alike.
            if isinstance(self.event, Enum):
                name = str(self.event.value)
            else:
                name = str(self.event)
            lines.append(f"event: {self._single_line('event', name)}")
        if self.id:
            lines.append(f"id: {self._single_line('id', str(self.id))}")
        if self.retry != DEFAULT_RETRY_MS:
            lines.append(f"retry: {self.retry}")

        if isinstance(self.data, (dict, list)):
            data_str = json.dumps(self.data, ensure_ascii=False)
        else:
            data_str = str(self.data)

        # SSE parsers treat CRLF, lone CR and lone LF all as line terminators.
        # Normalize them before splitting so every payload line gets its own
        # "data:" prefix; a bare "\r" would otherwise start an injected field.
        normalized = data_str.replace("\r\n", "\n").replace("\r", "\n")
        lines.extend(f"data: {line}" for line in normalized.split("\n"))
        return "\n".join(lines) + "\n\n"

    @classmethod
    def token(cls, text: str, seq: int = 0) -> "SSEEvent":
        """Build a ``token`` event carrying a text fragment and its sequence number."""
        return cls(event=SSEEventType.TOKEN, data={"text": text, "seq": seq})

    @classmethod
    def tool_call(cls, name: str, args: dict) -> "SSEEvent":
        """Build a ``tool_call`` event with the tool name and its arguments."""
        return cls(
            event=SSEEventType.TOOL_CALL,
            data={"name": name, "arguments": args},
        )

    @classmethod
    def tool_result(cls, name: str, result: Any) -> "SSEEvent":
        """Build a ``tool_result`` event; *result* must be JSON-serializable."""
        return cls(
            event=SSEEventType.TOOL_RESULT,
            data={"name": name, "result": result},
        )

    @classmethod
    def error(cls, message: str, code: str = "UNKNOWN") -> "SSEEvent":
        """Build an ``error`` event with a message and an error code."""
        return cls(
            event=SSEEventType.ERROR,
            data={"message": message, "code": code},
        )

    @classmethod
    def done(cls, metadata: dict[str, Any] | None = None) -> "SSEEvent":
        """Build a ``done`` event; the payload is *metadata* or ``{}``."""
        return cls(
            event=SSEEventType.DONE,
            data=metadata or {},
        )

    @classmethod
    def metadata(cls, meta: dict[str, Any]) -> "SSEEvent":
        """Build a ``metadata`` event whose payload is *meta*."""
        return cls(event=SSEEventType.METADATA, data=meta)

98 

99 

class SSEStream:
    """SSE stream with heartbeats and backpressure handling.

    Events are buffered in a bounded ``asyncio.Queue`` (``max_queue`` items);
    ``None`` is used internally as the end-of-stream sentinel.

    Usage::

        stream = SSEStream(retry_ms=3000)
        await stream.start()                  # optional heartbeat task
        # Producer
        await stream.send(SSEEvent.token("Hello"))
        await stream.send(SSEEvent.done())
        await stream.close()

        # Consumer (ASGI)
        async for chunk in stream.iter_chunks():
            yield chunk

    Whenever the consumer stops iterating, the stream is marked closed, the
    heartbeat task is cancelled, and any still-queued events are discarded.
    This covers three cases: the sentinel arrives, the loop is exited early,
    or the generator is cancelled/closed (e.g. by the server on client
    disconnect). A later ``send`` then raises ``RuntimeError``, and ``close``
    returns immediately instead of blocking on a full queue.
    """

    def __init__(
        self,
        retry_ms: int = DEFAULT_RETRY_MS,
        heartbeat_s: float = DEFAULT_HEARTBEAT_S,
        max_queue: int = MAX_QUEUE_SIZE,
    ):
        self.retry_ms = retry_ms
        self.heartbeat_s = heartbeat_s
        self.queue: asyncio.Queue[SSEEvent | None] = asyncio.Queue(maxsize=max_queue)
        self._closed = False
        self._heartbeat_task: Optional[asyncio.Task] = None
        self._last_event_id = 0

    async def start(self):
        """Start the heartbeat background task.

        No-op if a heartbeat task is already running or the stream is closed.
        """
        if self._closed:
            return
        if self._heartbeat_task is not None and not self._heartbeat_task.done():
            return
        self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())

    async def _heartbeat_loop(self):
        """Enqueue a heartbeat event every ``heartbeat_s`` seconds until closed.

        Uses ``put_nowait`` and skips the beat when the queue is full. A full
        queue already means data is pending, and a blocking put would compete
        with real events for the freed slot.
        """
        try:
            while not self._closed:
                await asyncio.sleep(self.heartbeat_s)
                if self._closed:
                    break
                try:
                    self.queue.put_nowait(
                        SSEEvent(
                            event=SSEEventType.HEARTBEAT,
                            data={"ts": time.time()},
                        )
                    )
                except asyncio.QueueFull:
                    pass  # backlog pending; a heartbeat adds nothing
        except asyncio.CancelledError:
            pass

    async def send(self, event: SSEEvent):
        """Enqueue *event*, assigning the next sequential ``id`` if it has none.

        Raises:
            RuntimeError: if the stream is closed (or its consumer has stopped).
            asyncio.QueueFull: if the queue is at capacity (backpressure). The
                id counter is rolled back, so retrying the same event keeps
                the ids consistent.
        """
        if self._closed:
            raise RuntimeError("Stream is closed")
        self._last_event_id += 1
        if not event.id:
            event.id = str(self._last_event_id)
        try:
            self.queue.put_nowait(event)
        except asyncio.QueueFull:
            self._last_event_id -= 1
            raise

    async def close(self):
        """Signal end of stream and stop the heartbeat task. Idempotent."""
        if self._closed:
            return
        self._closed = True
        # Stop the heartbeat first so nothing can be queued after the sentinel.
        await self._stop_heartbeat()
        # May wait for the consumer to free a slot. If the consumer stops
        # instead, _release_consumer drains the queue, which unblocks this put.
        await self.queue.put(None)  # Sentinel

    async def _stop_heartbeat(self):
        """Cancel the heartbeat task, if any, and wait for it to finish."""
        task, self._heartbeat_task = self._heartbeat_task, None
        if task is None:
            return
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass

    def _release_consumer(self):
        """Tear down after the consumer stops iterating. Safe to call repeatedly."""
        self._closed = True
        if self._heartbeat_task is not None:
            # Synchronous cancel only: this may run inside a generator finalizer.
            self._heartbeat_task.cancel()
        # Nobody will read the remaining events; discarding them also wakes any
        # producer blocked in queue.put() (e.g. close()).
        while True:
            try:
                self.queue.get_nowait()
            except asyncio.QueueEmpty:
                break

    async def iter_events(self) -> AsyncIterator[SSEEvent]:
        """Yield queued events until the end-of-stream sentinel arrives.

        However iteration ends, the stream is released (see the class
        docstring).
        """
        if self._closed and self.queue.empty():
            return  # already closed and fully drained: nothing can arrive
        try:
            while True:
                event = await self.queue.get()
                if event is None:
                    break
                yield event
        finally:
            self._release_consumer()

    async def iter_chunks(self) -> AsyncIterator[str]:
        """Yield raw SSE wire-format chunks.

        When ``retry_ms`` differs from ``DEFAULT_RETRY_MS``, the first chunk is
        a standalone ``retry:`` field. This tells the client which reconnection
        delay to use.
        """
        events = self.iter_events()
        try:
            if self.retry_ms != DEFAULT_RETRY_MS:
                yield f"retry: {self.retry_ms}\n\n"
            async for event in events:
                yield event.serialize()
        finally:
            # Close the inner generator deterministically (otherwise it is
            # only finalized when garbage-collected), and release even if it
            # was never started.
            await events.aclose()
            self._release_consumer()

180 

181 

class SSEResponse:
    """Supply the headers and body iterator for an SSE HTTP response.

    Wraps an ``SSEStream``; the caller hands both pieces to its framework's
    streaming response type.

    Usage (Starlette / FastAPI)::

        from starlette.responses import StreamingResponse

        sse = SSEResponse(stream)
        return StreamingResponse(
            sse.body(),
            media_type="text/event-stream",
            headers=sse.headers(),
        )
    """

    HEADERS = {
        "Content-Type": "text/event-stream",
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",  # header nginx reads to disable proxy buffering
    }

    def __init__(self, stream: SSEStream):
        self.stream = stream

    def headers(self) -> dict[str, str]:
        """Return a fresh copy of ``HEADERS`` so callers can modify it freely."""
        return {**self.HEADERS}

    async def body(self) -> AsyncIterator[str]:
        """Relay the stream's wire-format chunks as an ASGI body iterator."""
        chunks = self.stream.iter_chunks()
        async for piece in chunks:
            yield piece