Coverage for agentos/protocols/a2a_streaming.py: 37%
126 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
1"""
2A2A Streaming — real-time task status updates via SSE for A2A protocol.
4Provides push-based task lifecycle notifications so agents don't poll.
5"""
7from __future__ import annotations
9import asyncio
10import json
11import time
12from dataclasses import dataclass, field
13from enum import Enum
14from typing import Any, AsyncIterator, Callable, Dict, Optional
16from agentos.protocols.a2a import A2ATask, TaskState
class A2AStreamEvent(str, Enum):
    """Names of the events pushed over an A2A task stream.

    Members are also ``str`` instances, so a member compares equal to its
    wire name (for example ``"task.created"``).
    """

    # Task lifecycle events.
    TASK_CREATED = "task.created"
    TASK_STARTED = "task.started"
    TASK_PROGRESS = "task.progress"
    TASK_COMPLETED = "task.completed"
    TASK_FAILED = "task.failed"
    TASK_CANCELLED = "task.cancelled"

    # An artifact was attached to the task.
    ARTIFACT_ADDED = "artifact.added"

    # Periodic keep-alive event.
    HEARTBEAT = "heartbeat"
@dataclass
class TaskProgress:
    """A single progress report emitted while a task is running.

    Every field has a default, so ``TaskProgress()`` is a valid empty report.
    """

    # Completion figure; defaults to 0.0.
    percent: float = 0.0
    # Free-form status text.
    message: str = ""
    # Label of the step being reported.
    step: str = ""
    # Extra key/value details; each instance gets its own fresh dict.
    metadata: dict[str, Any] = field(default_factory=dict)
class A2AStreamSession:
    """Manages a streaming connection for a single task.

    Agents subscribe to receive push updates as the task progresses. Each
    subscriber owns a bounded queue of event dicts; a ``None`` item is the
    end-of-stream sentinel consumed by :meth:`iter_events`.
    """

    def __init__(self, task: A2ATask):
        """Create a session bound to ``task.task_id`` with no subscribers."""
        self.task_id = task.task_id
        self._subscribers: list[asyncio.Queue[dict]] = []
        self._closed = False
        self._heartbeat_task: Optional[asyncio.Task] = None

    async def start(self, heartbeat_s: float = 30.0):
        """Start the heartbeat loop.

        Calling this again while a heartbeat task is still running is a
        no-op, so repeated calls cannot leak background tasks.

        Args:
            heartbeat_s: Seconds to sleep between heartbeat events.
        """
        running = self._heartbeat_task
        if running is not None and not running.done():
            return

        async def _pulse():
            while not self._closed:
                await asyncio.sleep(heartbeat_s)
                # Re-check after sleeping: close() may have run meanwhile.
                if not self._closed:
                    await self._broadcast({
                        # Store the plain string, exactly as emit() does, so
                        # every payload carries the same kind of event name.
                        "event": A2AStreamEvent.HEARTBEAT.value,
                        "task_id": self.task_id,
                        "timestamp": time.time(),
                    })

        self._heartbeat_task = asyncio.create_task(_pulse())

    def subscribe(self) -> asyncio.Queue[dict]:
        """Register a new subscriber and return its event queue.

        If the session is already closed, the returned queue is not
        registered and already holds the end-of-stream sentinel, so a
        consumer iterating it finishes immediately instead of waiting forever.
        """
        q: asyncio.Queue[dict] = asyncio.Queue(maxsize=64)
        if self._closed:
            q.put_nowait(None)
            return q
        self._subscribers.append(q)
        return q

    def unsubscribe(self, sub: asyncio.Queue):
        """Remove a subscriber; unknown queues are ignored."""
        try:
            self._subscribers.remove(sub)
        except ValueError:
            pass

    async def emit(self, event: A2AStreamEvent, data: dict | None = None):
        """Push an event to all subscribers.

        Args:
            event: Event type; its ``.value`` is stored in the payload.
            data: Optional event body. Omitted from the payload when falsy.
        """
        payload = {
            "event": event.value,
            "task_id": self.task_id,
            "timestamp": time.time(),
        }
        if data:
            payload["data"] = data
        await self._broadcast(payload)

    @staticmethod
    def _force_sentinel(q: asyncio.Queue) -> None:
        """Guarantee that ``q`` ends with the end-of-stream sentinel.

        When the queue is full, the oldest pending item is discarded to make
        room, so the most recent events are the ones kept.
        """
        if q.full():
            q.get_nowait()
        q.put_nowait(None)

    async def _broadcast(self, payload: dict):
        """Deliver ``payload`` to every subscriber without blocking.

        A subscriber whose queue is full is dropped. It is sent the
        end-of-stream sentinel first so its consumer terminates rather than
        hanging on a queue that will never be fed again.
        """
        dead: list[asyncio.Queue] = []
        for q in self._subscribers:
            try:
                q.put_nowait(payload)
            except asyncio.QueueFull:
                dead.append(q)
        # Evict after the loop so the list is not mutated while iterating.
        for q in dead:
            self._force_sentinel(q)
            self.unsubscribe(q)

    async def close(self):
        """Shut down the stream.

        Cancels the heartbeat task (if any), sends every subscriber the
        end-of-stream sentinel, and forgets all subscribers.
        """
        self._closed = True
        if self._heartbeat_task:
            self._heartbeat_task.cancel()
        # A full queue must still receive the sentinel, otherwise its
        # consumer would block forever after draining the backlog.
        for q in self._subscribers:
            self._force_sentinel(q)
        self._subscribers.clear()

    async def iter_events(self, subscriber: asyncio.Queue) -> AsyncIterator[dict]:
        """Yield event dicts from ``subscriber`` until the sentinel arrives."""
        while True:
            event = await subscriber.get()
            if event is None:
                break
            yield event

    def to_sse(self, event: dict) -> str:
        """Format a single event dict into SSE wire format.

        Args:
            event: Payload with a required ``"event"`` key and optional
                ``"task_id"``, ``"timestamp"`` and ``"data"`` keys.

        Returns:
            The ``event:``, ``id:`` and ``data:`` lines followed by a blank
            line.

        Raises:
            KeyError: If ``event`` has no ``"event"`` key.
        """
        name = event["event"]
        if isinstance(name, Enum):
            # Formatting an enum member directly may print its qualified
            # name instead of its value, depending on the Python version.
            name = name.value
        lines: list[str] = [f"event: {name}"]
        # NOTE(review): this writes two ``id:`` lines; SSE clients presumably
        # keep only the last one. Kept as-is for wire compatibility -- confirm.
        lines.extend(
            f"id: {key}={event[key]}"
            for key in ("task_id", "timestamp")
            if key in event
        )
        data_str = json.dumps(event.get("data", {}), ensure_ascii=False)
        lines.extend(f"data: {line}" for line in data_str.split("\n"))
        return "\n".join(lines) + "\n\n"
class StreamingAggregator:
    """Streaming result aggregator -- required by the compliance test suite."""

    def __init__(self):
        """Start with no collected chunks."""
        self._chunks: list[str] = []

    def collect(self, chunk: str) -> None:
        """Append ``chunk`` to the collected pieces."""
        pieces = self._chunks
        pieces.append(chunk)

    def aggregated(self) -> str:
        """Return all collected chunks concatenated in arrival order."""
        separator = ""
        return separator.join(self._chunks)
class A2AStreamManager:
    """Global manager for A2A task streaming sessions.

    Tracks all active task streams and dispatches events on state transitions.
    """

    def __init__(self):
        """Create a manager with no sessions and no state-change hook."""
        self._sessions: dict[str, A2AStreamSession] = {}
        self._on_state_change: Optional[Callable] = None

    def on_state_change(self, callback: Callable[[A2ATask, TaskState, TaskState], Any]):
        """Register a hook called on every state transition.

        The hook is invoked by :meth:`notify_state_change` as
        ``callback(task, old_state, new_state)``. It may be a plain function
        or return an awaitable, which is then awaited. Registering a new
        hook replaces the previous one.
        """
        self._on_state_change = callback

    def create_session(self, task: A2ATask) -> A2AStreamSession:
        """Create a streaming session for a new task and track it by task id."""
        session = A2AStreamSession(task)
        self._sessions[task.task_id] = session
        return session

    def get_session(self, task_id: str) -> Optional[A2AStreamSession]:
        """Return the tracked session for ``task_id``, or ``None``."""
        return self._sessions.get(task_id)

    async def notify_state_change(self, task: A2ATask, old_state: TaskState):
        """Handle a task state transition.

        Subscribers of the task's session (if one exists) receive the
        matching event, and the session is closed and forgotten once the
        task is terminal. The registered state-change hook then runs,
        whether or not the task has a session.

        Args:
            task: The task, already carrying its new state.
            old_state: The state the task was in before the transition.
        """
        session = self._sessions.get(task.task_id)
        if session is not None:
            await self._publish_transition(session, task, old_state)
        # The hook runs last so a failing hook cannot prevent subscribers
        # from being notified or a finished session from being cleaned up.
        await self._run_state_hook(task, old_state)

    async def _publish_transition(
        self,
        session: A2AStreamSession,
        task: A2ATask,
        old_state: TaskState,
    ):
        """Emit the event for ``task``'s new state; close terminal sessions."""
        event_map = {
            TaskState.SUBMITTED: A2AStreamEvent.TASK_CREATED,
            TaskState.WORKING: A2AStreamEvent.TASK_STARTED,
            TaskState.COMPLETED: A2AStreamEvent.TASK_COMPLETED,
            TaskState.FAILED: A2AStreamEvent.TASK_FAILED,
            TaskState.CANCELLED: A2AStreamEvent.TASK_CANCELLED,
        }
        # States without a dedicated event are reported as progress.
        event = event_map.get(task.state, A2AStreamEvent.TASK_PROGRESS)
        await session.emit(event, {
            "previous_state": old_state.value,
            "current_state": task.state.value,
            "error": task.error,
        })

        if task.is_terminal():
            await session.close()
            # pop() rather than del: the entry may already be gone.
            self._sessions.pop(task.task_id, None)

    async def _run_state_hook(self, task: A2ATask, old_state: TaskState):
        """Invoke the registered state-change hook, awaiting it if needed."""
        hook = self._on_state_change
        if hook is None:
            return
        import inspect

        outcome = hook(task, old_state, task.state)
        if inspect.isawaitable(outcome):
            await outcome

    async def notify_artifact(self, task_id: str, artifact_name: str):
        """Emit an ``ARTIFACT_ADDED`` event if ``task_id`` has a session."""
        session = self._sessions.get(task_id)
        if session:
            await session.emit(A2AStreamEvent.ARTIFACT_ADDED, {
                "artifact_name": artifact_name,
            })

    async def notify_progress(
        self,
        task_id: str,
        progress: TaskProgress,
    ):
        """Push a progress update to the subscribers of ``task_id``.

        Does nothing when no session is tracked for ``task_id``.
        """
        session = self._sessions.get(task_id)
        if session:
            await session.emit(A2AStreamEvent.TASK_PROGRESS, {
                "percent": progress.percent,
                "message": progress.message,
                "step": progress.step,
                "metadata": progress.metadata,
            })

    async def shutdown(self):
        """Gracefully close all sessions and forget them."""
        # Iterate a snapshot so the dict is not read while it may change.
        for session in list(self._sessions.values()):
            await session.close()
        self._sessions.clear()