Coverage for agentos/api/streaming.py: 39%
102 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
1"""
2Streaming SSE (Server-Sent Events) endpoint for agent interactions.
4Provides real-time streaming of agent outputs via HTTP SSE, enabling
5browser-based chat UIs and real-time monitoring dashboards.
6"""
from __future__ import annotations

import asyncio
import inspect
import json
import time
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any, AsyncIterator, Optional
@dataclass
class StreamEvent:
    """Single SSE event emitted by the stream."""

    # Event type, e.g. 'start', 'chunk', 'tool_call', 'tool_result', 'done',
    # 'error'.
    event: str
    # Event payload; must be JSON-serializable because it goes through
    # json.dumps in to_sse().
    data: dict[str, Any]
    # Optional event ID for resume support (SSE ``id`` field). An empty
    # string is still emitted, since "id: " has its own meaning in SSE.
    id: Optional[str] = None
    # Reconnection retry interval in milliseconds (SSE ``retry`` field);
    # 0 is a valid value and is emitted.
    retry: Optional[int] = None

    def to_sse(self) -> str:
        """Format as SSE wire format.

        Emits the ``id``, ``event``, ``data`` and ``retry`` fields, each only
        when set, one per line. The frame is then closed with the blank line
        that makes the client dispatch the event.

        ``data`` is serialized with ``json.dumps`` (no indent). That escapes
        newlines, so the payload always fits on a single ``data:`` line.
        """
        lines: list[str] = []
        if self.id is not None:
            lines.append(f"id: {self.id}")
        if self.event:
            lines.append(f"event: {self.event}")
        lines.append(f"data: {json.dumps(self.data, ensure_ascii=False)}")
        if self.retry is not None:
            lines.append(f"retry: {self.retry}")
        # Each field line ends with "\n". The extra "\n" forms the empty line
        # that terminates the event. Without it, consecutive events would be
        # merged by the client.
        return "".join(f"{line}\n" for line in lines) + "\n"
@dataclass
class StreamSession:
    """Track an active streaming session.

    Every field has a default, so ``StreamSession()`` works as a zero-argument
    factory (for example with ``collections.defaultdict``). Callers normally
    pass ``session_id`` explicitly.
    """

    # Identifier the session is registered under; "" until assigned.
    session_id: str = ""
    # Wall-clock creation time, from time.time().
    started_at: float = field(default_factory=time.time)
    # Number of SSE events emitted for this session so far.
    events_emitted: int = 0
    # Wall-clock time (time.time()) of the last event recorded; 0.0 if none.
    last_event_at: float = 0.0
    # Free-form per-session metadata; a fresh dict per instance.
    metadata: dict[str, Any] = field(default_factory=dict)
class StreamingAgent:
    """
    Agent that emits Server-Sent Events for real-time streaming.

    Example (FastAPI integration)::

        streaming = StreamingAgent(agent_loop)

        @app.get("/agent/stream")
        async def stream():
            return StreamingResponse(
                streaming.stream_chat("What is quantum computing?", "session-1"),
                media_type="text/event-stream"
            )
    """

    # SSE comment frame: a line starting with ":" plus the terminating blank
    # line. Clients ignore it; it only keeps an idle connection alive.
    _HEARTBEAT_FRAME = ": heartbeat\n\n"

    def __init__(
        self,
        agent_loop: Any = None,
        heartbeat_interval: float = 15.0,
    ):
        """
        Args:
            agent_loop: The underlying agent loop. If it has a ``run`` method,
                ``run(message)`` is called. It may return the result directly
                (sync) or an awaitable that resolves to it (async).
            heartbeat_interval: Seconds without a new chunk after which
                ``stream_chat`` emits a heartbeat comment. ``None`` or a
                value <= 0 disables heartbeats.
        """
        self._loop = agent_loop
        self._heartbeat = heartbeat_interval
        # Sessions are created on first use, with their id, by
        # _get_or_create_session.
        self._sessions: dict[str, StreamSession] = {}

    def _get_or_create_session(self, session_id: str) -> StreamSession:
        """Return the session tracked under *session_id*, creating it if new."""
        session = self._sessions.get(session_id)
        if session is None:
            session = StreamSession(session_id=session_id)
            self._sessions[session_id] = session
        return session

    async def stream_chat(
        self,
        message: str,
        session_id: str = "default",
    ) -> AsyncIterator[str]:
        """
        Stream a chat interaction as SSE events.

        The stream is a ``start`` event, then one ``chunk`` event per
        generated chunk, then a final ``done`` event. While waiting for the
        next chunk, a heartbeat comment is emitted every
        ``heartbeat_interval`` seconds. If chunk generation raises, an
        ``error`` event is emitted and the exception is re-raised.

        Yields:
            SSE-formatted strings suitable for HTTP response body.
        """
        session = self._get_or_create_session(session_id)
        t_start = time.perf_counter()

        yield StreamEvent(
            event="start",
            data={"session_id": session_id, "message": message},
        ).to_sse()
        session.events_emitted += 1

        interval = self._heartbeat
        timeout = interval if interval is not None and interval > 0 else None
        chunks = self._generate_chunks(message)
        # In-flight request for the next chunk. It is kept alive across
        # heartbeat timeouts, because cancelling it would tear down the
        # chunk generator.
        pending: Optional[asyncio.Future] = None
        try:
            while True:
                if pending is None:
                    pending = asyncio.ensure_future(chunks.__anext__())
                done, _ = await asyncio.wait({pending}, timeout=timeout)
                if not done:
                    yield self._HEARTBEAT_FRAME
                    continue
                finished, pending = pending, None
                try:
                    chunk = finished.result()
                except StopAsyncIteration:
                    break
                yield StreamEvent(
                    event="chunk",
                    data={"content": chunk, "session_id": session_id},
                ).to_sse()
                session.events_emitted += 1
                session.last_event_at = time.time()
        except Exception as exc:
            # Tell the client why the stream ends, then let the caller see it.
            yield self.emit_error(session_id, str(exc))
            raise
        finally:
            if pending is not None:
                # Stream abandoned while waiting (e.g. the client went away).
                pending.cancel()
                await asyncio.gather(pending, return_exceptions=True)
            await chunks.aclose()

        total_ms = (time.perf_counter() - t_start) * 1000
        yield StreamEvent(
            event="done",
            data={
                "session_id": session_id,
                "total_latency_ms": total_ms,
                "events_emitted": session.events_emitted,
            },
        ).to_sse()

    def stream_chat_sync(self, message: str, session_id: str = "default"):
        """Synchronous wrapper for stream_chat.

        Runs the whole stream to completion on a fresh event loop and returns
        the collected SSE strings as an iterable. It cannot be called from
        inside a running event loop (``asyncio.run`` raises RuntimeError).
        """
        return _SyncSSEWrapper(
            asyncio.run(self._collect_events(message, session_id))
        )

    async def _collect_events(
        self, message: str, session_id: str
    ) -> list[str]:
        """Drain stream_chat into a list of SSE strings."""
        return [sse async for sse in self.stream_chat(message, session_id)]

    async def _generate_chunks(self, message: str) -> AsyncIterator[str]:
        """Generate streaming text chunks. Override with real LLM integration.

        If an agent loop with a ``run`` method was given, its result is used
        (``result.output`` if present, otherwise the result itself). It is
        converted to text, split on whitespace and re-emitted word by word.
        Otherwise a canned simulated response is produced.
        """
        if self._loop is not None and hasattr(self._loop, "run"):
            result = self._loop.run(message)
            # Async agent loops return an awaitable. Without awaiting it, the
            # text would be "<coroutine object ...>".
            if inspect.isawaitable(result):
                result = await result
            # NOTE(review): a synchronous run() blocks the event loop, and
            # therefore heartbeats, until it returns.
            text = str(result.output) if hasattr(result, "output") else str(result)
            words = text.split()
            last_index = len(words) - 1
            for i, word in enumerate(words):
                yield word if i == last_index else word + " "
                await asyncio.sleep(0.02)  # simulate streaming
        else:
            # Fallback: simulate streaming
            yield f"Processing: {message}\n"
            await asyncio.sleep(0.3)
            for i in range(3):
                yield f"Agent step {i + 1}: analyzing...\n"
                await asyncio.sleep(0.5)
            yield f"Complete. Response for: {message}"

    def emit_tool_call(self, session_id: str, tool_name: str, args: dict) -> str:
        """Emit a tool_call SSE event (non-streaming helper).

        Returns the formatted SSE string. ``args`` must be JSON-serializable.
        """
        return StreamEvent(
            event="tool_call",
            data={
                "session_id": session_id,
                "tool": tool_name,
                "arguments": args,
            },
        ).to_sse()

    def emit_tool_result(
        self, session_id: str, tool_name: str, result: Any
    ) -> str:
        """Emit a tool_result SSE event.

        Returns the formatted SSE string. ``result`` must be
        JSON-serializable; otherwise json.dumps raises TypeError.
        """
        return StreamEvent(
            event="tool_result",
            data={
                "session_id": session_id,
                "tool": tool_name,
                "result": result,
            },
        ).to_sse()

    def emit_error(self, session_id: str, error: str) -> str:
        """Emit an error SSE event and return it as a formatted SSE string."""
        return StreamEvent(
            event="error",
            data={"session_id": session_id, "error": error},
        ).to_sse()

    def get_session(self, session_id: str) -> Optional[StreamSession]:
        """Return the session tracked under *session_id*, or None if unknown."""
        return self._sessions.get(session_id)

    def list_sessions(self) -> dict[str, StreamSession]:
        """Return a shallow copy of the session-id -> session mapping."""
        return dict(self._sessions)
224class _SyncSSEWrapper:
225 """Makes a list of SSE strings iterable for sync streaming."""
227 def __init__(self, events: list[str]):
228 self._events = events
230 def __iter__(self):
231 return iter(self._events)