Coverage for agentos/protocols/a2a_streaming.py: 37%

126 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 09:59 +0800

1""" 

2A2A Streaming — real-time task status updates via SSE for A2A protocol. 

3 

4Provides push-based task lifecycle notifications so agents don't poll. 

5""" 

6 

7from __future__ import annotations 

8 

9import asyncio 

10import json 

11import time 

12from dataclasses import dataclass, field 

13from enum import Enum 

14from typing import Any, AsyncIterator, Callable, Dict, Optional 

15 

16from agentos.protocols.a2a import A2ATask, TaskState 

17 

18 

class A2AStreamEvent(str, Enum):
    """Event names pushed to A2A stream subscribers.

    Members subclass ``str``, so each one compares equal to its dotted wire
    name (for example ``A2AStreamEvent.TASK_CREATED == "task.created"``).
    """

    # --- task lifecycle -------------------------------------------------
    TASK_CREATED = "task.created"
    TASK_STARTED = "task.started"
    TASK_PROGRESS = "task.progress"
    TASK_COMPLETED = "task.completed"
    TASK_FAILED = "task.failed"
    TASK_CANCELLED = "task.cancelled"
    # --- task outputs ---------------------------------------------------
    ARTIFACT_ADDED = "artifact.added"
    # --- periodic liveness signal ---------------------------------------
    HEARTBEAT = "heartbeat"

30 

31 

@dataclass
class TaskProgress:
    """Snapshot of how far a running task has advanced.

    Every field is optional:

    * ``percent`` - completion amount; not validated here (presumably 0-100,
      confirm with producers).
    * ``message`` - free-form status text.
    * ``step`` - label of the current step.
    * ``metadata`` - extra details; each instance gets its own fresh dict.
    """

    percent: float = 0.0
    message: str = ""
    step: str = ""
    metadata: dict[str, Any] = field(default_factory=lambda: {})

40 

41 

class A2AStreamSession:
    """Manage the push-update stream for a single A2A task.

    Every subscriber owns a bounded ``asyncio.Queue`` of event dicts. A
    ``None`` item is the end-of-stream sentinel. It is delivered when the
    session closes and when a lagging subscriber is evicted, so that
    ``iter_events`` always terminates.
    """

    # Capacity of each subscriber queue. A subscriber whose queue is full when
    # an event is broadcast gets evicted.
    _QUEUE_MAXSIZE = 64

    def __init__(self, task: A2ATask):
        self.task_id = task.task_id
        # Queues carry event dicts, or ``None`` as the end-of-stream sentinel.
        self._subscribers: list[asyncio.Queue[dict | None]] = []
        self._closed = False
        self._heartbeat_task: Optional[asyncio.Task] = None

    async def start(self, heartbeat_s: float = 30.0):
        """Start a background task that broadcasts heartbeat events.

        Args:
            heartbeat_s: Seconds to sleep between heartbeats. The loop stops
                once the session is closed.

        The heartbeat payload stores the enum's string value, as ``emit``
        does. An f-string of a ``(str, Enum)`` member renders
        ``A2AStreamEvent.HEARTBEAT`` rather than ``heartbeat`` on Python
        3.11+, which would corrupt the SSE ``event:`` line.
        """

        async def _pulse():
            while not self._closed:
                await asyncio.sleep(heartbeat_s)
                if not self._closed:
                    await self._broadcast({
                        "event": A2AStreamEvent.HEARTBEAT.value,
                        "task_id": self.task_id,
                        "timestamp": time.time(),
                    })

        self._heartbeat_task = asyncio.create_task(_pulse())

    def subscribe(self) -> asyncio.Queue[dict | None]:
        """Register a new subscriber and return its event queue.

        If the session is already closed, the returned queue holds only the
        end-of-stream sentinel. Iterating it then finishes immediately
        instead of waiting forever for events that will never arrive.
        """
        q: asyncio.Queue[dict | None] = asyncio.Queue(maxsize=self._QUEUE_MAXSIZE)
        if self._closed:
            q.put_nowait(None)
        else:
            self._subscribers.append(q)
        return q

    def unsubscribe(self, sub: asyncio.Queue):
        """Remove *sub* from the session; queues not registered are ignored."""
        try:
            self._subscribers.remove(sub)
        except ValueError:
            pass

    async def emit(self, event: A2AStreamEvent, data: dict | None = None):
        """Broadcast *event*, keyed by its string value, to all subscribers.

        The payload carries ``event``, ``task_id`` and ``timestamp``
        (``time.time()``). ``data`` is attached only when it is truthy.
        """
        payload = {
            "event": event.value,
            "task_id": self.task_id,
            "timestamp": time.time(),
        }
        if data:
            payload["data"] = data
        await self._broadcast(payload)

    async def _broadcast(self, payload: dict):
        """Deliver *payload* to every subscriber without blocking.

        A subscriber whose queue is full is evicted. It is removed from the
        session and sent the end-of-stream sentinel, so its ``iter_events``
        loop ends instead of hanging on a queue nobody feeds any more.
        """
        evicted: list[asyncio.Queue] = []
        for q in self._subscribers:
            try:
                q.put_nowait(payload)
            except asyncio.QueueFull:
                evicted.append(q)
        for q in evicted:
            self.unsubscribe(q)
            self._push_sentinel(q)

    @staticmethod
    def _push_sentinel(q: asyncio.Queue) -> None:
        """Enqueue the ``None`` sentinel, discarding the oldest item if full."""
        try:
            q.put_nowait(None)
        except asyncio.QueueFull:
            # Losing one stale event is better than a consumer that never
            # learns the stream has ended.
            q.get_nowait()
            q.put_nowait(None)

    async def close(self):
        """Shut down the stream.

        Marks the session closed, cancels the heartbeat task (if started),
        sends every subscriber the end-of-stream sentinel (even when its
        queue is full) and forgets all subscribers.
        """
        self._closed = True
        if self._heartbeat_task:
            self._heartbeat_task.cancel()
        for q in self._subscribers:
            self._push_sentinel(q)
        self._subscribers.clear()

    async def iter_events(self, subscriber: asyncio.Queue) -> AsyncIterator[dict]:
        """Yield event dicts from *subscriber* until the ``None`` sentinel."""
        while True:
            event = await subscriber.get()
            if event is None:
                break
            yield event

    def to_sse(self, event: dict) -> str:
        """Format one event dict as an SSE message.

        Produces an ``event:`` line, then one ``id:`` line for each of
        ``task_id`` / ``timestamp`` that is present (in that order). It then
        adds the JSON-encoded ``data`` (``{}`` when absent) as ``data:``
        lines and terminates the message with a blank line.
        """
        name = event["event"]
        # Accept enum members as well as plain strings: formatting a
        # ``(str, Enum)`` member directly yields ``Class.MEMBER`` on 3.11+.
        name = getattr(name, "value", name)
        lines: list[str] = [f"event: {name}"]
        for key in ("task_id", "timestamp"):
            if key in event:
                # NOTE(review): SSE clients keep only the last ``id`` field of
                # a message, so the effective event id is ``timestamp=...``.
                lines.append(f"id: {key}={event[key]}")
        data_str = json.dumps(event.get("data", {}), ensure_ascii=False)
        # json.dumps without ``indent`` escapes newlines, so this is normally
        # a single line; splitting keeps the output valid SSE regardless.
        lines.extend(f"data: {line}" for line in data_str.split("\n"))
        return "\n".join(lines) + "\n\n"

132 

133 

class StreamingAggregator:
    """Streaming result aggregator, required by the compliance test suite.

    Buffers text chunks as they arrive and returns them concatenated.
    """

    def __init__(self):
        self._chunks: list[str] = list()

    def collect(self, chunk: str) -> None:
        """Append *chunk* to the buffer."""
        self._chunks += [chunk]

    def aggregated(self) -> str:
        """Return every collected chunk joined in arrival order, with no separator."""
        return "".join(piece for piece in self._chunks)

145 

146 

class A2AStreamManager:
    """Registry of active A2A task streaming sessions, keyed by task id.

    Translates task state transitions, artifact additions and progress
    updates into events on the matching ``A2AStreamSession``.
    """

    def __init__(self):
        self._sessions: dict[str, A2AStreamSession] = {}
        self._on_state_change: Optional[Callable] = None

    def on_state_change(self, callback: Callable[[A2ATask, TaskState, TaskState], Any]):
        """Register a hook invoked on every state transition.

        ``notify_state_change`` calls it as
        ``callback(task, old_state, new_state)``, whether or not the task has
        a streaming session. If the hook returns an awaitable, that awaitable
        is awaited. Registering a new hook replaces the previous one.
        """
        self._on_state_change = callback

    def create_session(self, task: A2ATask) -> A2AStreamSession:
        """Create a streaming session for *task* and register it by task id.

        NOTE(review): a session already registered under the same task id is
        overwritten without being closed, so its subscribers are never sent
        an end-of-stream sentinel. Confirm callers never re-create sessions.
        """
        session = A2AStreamSession(task)
        self._sessions[task.task_id] = session
        return session

    def get_session(self, task_id: str) -> Optional[A2AStreamSession]:
        """Return the session registered for *task_id*, or ``None``."""
        return self._sessions.get(task_id)

    async def notify_state_change(self, task: A2ATask, old_state: TaskState):
        """Publish a task state transition.

        If the task has a session, emits the event mapped from ``task.state``
        (``TASK_PROGRESS`` for unmapped states). The event data carries the
        previous state, the current state and ``task.error``. Once the task is
        terminal, the session is unregistered and closed.

        Finally, invokes the hook registered via ``on_state_change``;
        exceptions raised by the hook propagate to the caller.
        """
        session = self._sessions.get(task.task_id)
        if session is not None:
            event_map = {
                TaskState.SUBMITTED: A2AStreamEvent.TASK_CREATED,
                TaskState.WORKING: A2AStreamEvent.TASK_STARTED,
                TaskState.COMPLETED: A2AStreamEvent.TASK_COMPLETED,
                TaskState.FAILED: A2AStreamEvent.TASK_FAILED,
                TaskState.CANCELLED: A2AStreamEvent.TASK_CANCELLED,
            }
            event = event_map.get(task.state, A2AStreamEvent.TASK_PROGRESS)
            await session.emit(event, {
                "previous_state": old_state.value,
                "current_state": task.state.value,
                "error": task.error,
            })

            if task.is_terminal():
                # pop() rather than del: the entry may already have been
                # removed (e.g. by shutdown()) while this coroutine awaited.
                self._sessions.pop(task.task_id, None)
                await session.close()

        await self._run_state_change_hook(task, old_state)

    async def _run_state_change_hook(self, task: A2ATask, old_state: TaskState) -> None:
        """Call the registered state-change hook, awaiting it if it is async."""
        import inspect  # local import: only needed for the hook dispatch

        callback = self._on_state_change
        if callback is None:
            return
        result = callback(task, old_state, task.state)
        if inspect.isawaitable(result):
            await result

    async def notify_artifact(self, task_id: str, artifact_name: str):
        """Emit ``ARTIFACT_ADDED`` with *artifact_name* if *task_id* has a session."""
        session = self._sessions.get(task_id)
        if session:
            await session.emit(A2AStreamEvent.ARTIFACT_ADDED, {
                "artifact_name": artifact_name,
            })

    async def notify_progress(
        self,
        task_id: str,
        progress: TaskProgress,
    ):
        """Emit ``TASK_PROGRESS`` with the fields of *progress*, if a session exists."""
        session = self._sessions.get(task_id)
        if session:
            await session.emit(A2AStreamEvent.TASK_PROGRESS, {
                "percent": progress.percent,
                "message": progress.message,
                "step": progress.step,
                "metadata": progress.metadata,
            })

    async def shutdown(self):
        """Close every registered session, then clear the registry."""
        for session in list(self._sessions.values()):
            await session.close()
        self._sessions.clear()