Coverage for agentos/protocols/a2a.py: 38%

483 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 16:36 +0800

1""" 

2AgentOS v1.2.2 — A2A (Agent-to-Agent) 协议实现。 

3 

4基因来源: Google A2A Protocol (agent-to-agent-protocol.google.com) 

5 

6A2A 协议核心概念: 

7- Task: 异步工作单元,带状态机 (SUBMITTED→WORKING→COMPLETED/FAILED/CANCELLED) 

8- Message: 多模态消息,支持 text/file/data parts 

9- Artifact: 任务产生的输出物,带 MIME 类型 

10- Handoff: Agent 间任务移交 

11- Session: 多轮对话上下文 

12 

13协议层: 

14- REST: GET/POST /tasks, /tasks/{id} 

15- Future: WebSocket 推送 (v1.3+) 

16""" 

17 

18from __future__ import annotations 

19 

20import json 

21import uuid 

22import time 

23from dataclasses import dataclass, field 

24from enum import Enum 

25from typing import Any, Callable, Dict, List, Optional 

26 

27 

28# ── 基础枚举 ──────────────────────────────────── 

29 

class TaskState(str, Enum):
    """Lifecycle state of an A2A task.

    Each member is a ``str`` whose value is the lowercase wire string, so a
    member compares equal to its value (e.g. ``TaskState.WORKING == "working"``).
    """

    # Non-terminal states.
    SUBMITTED = "submitted"
    WORKING = "working"
    # Terminal states.
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"

39 

40 

class TaskStatus(str, Enum):
    """Lower-case-named alias of the task states, used by the compliance test suite.

    Note the spelling ``canceled`` (single "l") here, whereas ``TaskState`` uses
    ``cancelled``; the two enums are not interchangeable for that member.
    """

    submitted = "submitted"
    working = "working"
    completed = "completed"
    failed = "failed"
    canceled = "canceled"

50 

51 

class AgentCard:
    """A2A agent card: an agent's identity, endpoint and capabilities.

    Provides ``model_dump()`` so code expecting a Pydantic-style object (the
    compliance test suite) can serialise it like a Pydantic model.
    """

    # Attribute names emitted by model_dump(), in output order.
    _DUMP_FIELDS = (
        "name",
        "description",
        "url",
        "version",
        "capabilities",
        "provider",
        "authentication",
        "default_input_modes",
        "default_output_modes",
        "skills",
    )

    def __init__(
        self,
        name: str,
        description: str,
        url: str,
        version: str,
        capabilities: list,
        provider: dict,
        authentication: Optional[Any] = None,
        default_input_modes: Optional[list] = None,
        default_output_modes: Optional[list] = None,
        skills: Optional[list] = None,
    ):
        """Store the card fields.

        Empty or missing input/output mode lists fall back to ``["text"]``;
        an empty or missing ``skills`` becomes a fresh empty list.
        """
        self.name = name
        self.description = description
        self.url = url
        self.version = version
        self.capabilities = capabilities
        self.provider = provider
        self.authentication = authentication
        self.default_input_modes = default_input_modes if default_input_modes else ["text"]
        self.default_output_modes = default_output_modes if default_output_modes else ["text"]
        self.skills = skills if skills else []

    def model_dump(self) -> dict:
        """Return every card field as a plain dict (values are not copied)."""
        return {attr: getattr(self, attr) for attr in self._DUMP_FIELDS}

93 

94 

class A2AMessageBus:
    """In-process A2A message bus: agents register by id, ``send`` reports reachability.

    ``send`` does not deliver the message anywhere; it only reports whether the
    target id has been registered.
    """

    def __init__(self):
        # agent_id -> registered agent object (may be None).
        self._agents: Dict[str, Any] = dict()

    def register_agent(self, agent_id: str, agent: Any = None) -> None:
        """Register ``agent`` under ``agent_id``, replacing any previous entry."""
        self._agents.update({agent_id: agent})

    async def send(self, target_agent: str, message: Any) -> bool:
        """Return True if ``target_agent`` is registered; ``message`` is unused."""
        registered_ids = self._agents.keys()
        return target_agent in registered_ids

107 

108 

class PartType(str, Enum):
    """Kind of content in a message part; serialised as the part's ``"type"`` value."""

    TEXT = "text"  # TextPart
    FILE = "file"  # FilePart
    DATA = "data"  # DataPart

116 

117 

class MessageRole(str, Enum):
    """Author of an A2A message."""

    USER = "user"  # requesting side
    AGENT = "agent"  # responding agent

124 

125 

126# ── Message Parts ────────────────────────────── 

127 

@dataclass
class TextPart:
    """Plain-text message part."""

    text: str
    meta: Dict[str, str] = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Return the wire form ``{"type": "text", "text": ..., "meta": ...}``."""
        payload = {"type": PartType.TEXT.value}
        payload["text"] = self.text
        payload["meta"] = self.meta
        return payload

    @classmethod
    def from_dict(cls, d: dict) -> "TextPart":
        """Build a part from its wire form; missing keys default to ``""`` / ``{}``."""
        body = d.get("text", "")
        extra = d.get("meta", {})
        return cls(text=body, meta=extra)

140 

141 

@dataclass
class FilePart:
    """Message part that references a file by URL."""

    url: str = ""
    filename: str = ""
    mime_type: str = "application/octet-stream"
    size: int = 0
    meta: Dict[str, str] = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Return the wire form, tagged with ``"type": "file"``."""
        payload = {"type": PartType.FILE.value}
        payload.update(
            url=self.url,
            filename=self.filename,
            mime_type=self.mime_type,
            size=self.size,
            meta=self.meta,
        )
        return payload

    @classmethod
    def from_dict(cls, d: dict) -> "FilePart":
        """Build a part from its wire form; absent keys take the field defaults."""
        kwargs = {
            "url": d.get("url", ""),
            "filename": d.get("filename", ""),
            "mime_type": d.get("mime_type", "application/octet-stream"),
            "size": d.get("size", 0),
            "meta": d.get("meta", {}),
        }
        return cls(**kwargs)

170 

171 

@dataclass
class DataPart:
    """Message part carrying structured data, optionally tagged with a schema URI."""

    data: Dict[str, Any] = field(default_factory=dict)
    schema_uri: str = ""
    meta: Dict[str, str] = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Return the wire form, tagged with ``"type": "data"``."""
        payload = {"type": PartType.DATA.value}
        payload.update(data=self.data, schema_uri=self.schema_uri, meta=self.meta)
        return payload

    @classmethod
    def from_dict(cls, d: dict) -> "DataPart":
        """Build a part from its wire form; absent keys take the field defaults."""
        kwargs = {
            "data": d.get("data", {}),
            "schema_uri": d.get("schema_uri", ""),
            "meta": d.get("meta", {}),
        }
        return cls(**kwargs)

194 

195 

def part_from_dict(d: dict):
    """Deserialise any message part, dispatching on its ``"type"`` value.

    Raises:
        ValueError: if ``"type"`` is missing or not one of text/file/data.
    """
    part_type = d.get("type", "")
    # Compared with ``==`` (not a dict lookup) so str-enum members also match.
    parsers = (
        (PartType.TEXT.value, TextPart.from_dict),
        (PartType.FILE.value, FilePart.from_dict),
        (PartType.DATA.value, DataPart.from_dict),
    )
    for wire_type, parse in parsers:
        if part_type == wire_type:
            return parse(d)
    raise ValueError(f"Unknown part type: {part_type}")

206 

207 

208# ── A2A Artifact ─────────────────────────────── 

209 

@dataclass
class A2AArtifact:
    """Output produced by a task.

    Carries inline bytes (``blob``), an external reference (``url``), or both.
    """

    name: str
    mime_type: str = "application/octet-stream"
    blob: Optional[bytes] = None
    url: str = ""
    size: int = 0
    description: str = ""
    meta: Dict[str, str] = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Return the JSON-safe wire form.

        ``url`` is included only when non-empty. ``blob`` is emitted as base64
        text under ``"blob_base64"`` whenever it is not ``None``. This includes an
        empty ``b""``, so inline-but-empty data survives a round trip instead of
        turning into ``None``.
        """
        import base64

        d = {
            "name": self.name,
            "mime_type": self.mime_type,
            "size": self.size,
            "description": self.description,
            "meta": self.meta,
        }
        if self.url:
            d["url"] = self.url
        if self.blob is not None:
            d["blob_base64"] = base64.b64encode(self.blob).decode("ascii")
        return d

    @classmethod
    def from_dict(cls, d: dict) -> "A2AArtifact":
        """Rebuild an artifact from its wire form, decoding ``blob_base64`` if present."""
        import base64

        artifact = cls(
            name=d.get("name", ""),
            mime_type=d.get("mime_type", "application/octet-stream"),
            url=d.get("url", ""),
            size=d.get("size", 0),
            description=d.get("description", ""),
            meta=d.get("meta", {}),
        )
        encoded = d.get("blob_base64")
        # A JSON null means "no inline data", not an error.
        if encoded is not None:
            artifact.blob = base64.b64decode(encoded)
        return artifact

252 

253 

254# ── A2A Message ──────────────────────────────── 

255 

@dataclass
class A2AMessage:
    """Multimodal A2A message: a role plus an ordered list of content parts."""

    role: MessageRole = MessageRole.USER
    parts: list = field(default_factory=list)  # items: TextPart | FilePart | DataPart
    message_id: str = field(default_factory=lambda: f"msg-{uuid.uuid4().hex[:8]}")
    timestamp: float = field(default_factory=time.time)
    meta: Dict[str, str] = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Serialise the message; each part is converted with its own ``to_dict``."""
        role_value = self.role.value
        part_dicts = [part.to_dict() for part in self.parts]
        return {
            "message_id": self.message_id,
            "role": role_value,
            "parts": part_dicts,
            "timestamp": self.timestamp,
            "meta": self.meta,
        }

    @classmethod
    def from_dict(cls, d: dict) -> "A2AMessage":
        """Rebuild a message; absent fields get a USER role, a fresh id and the current time."""
        role = MessageRole(d.get("role", "user"))
        parts = list(map(part_from_dict, d.get("parts", [])))
        message_id = d.get("message_id", f"msg-{uuid.uuid4().hex[:8]}")
        timestamp = d.get("timestamp", time.time())
        meta = d.get("meta", {})
        return cls(
            message_id=message_id,
            role=role,
            parts=parts,
            timestamp=timestamp,
            meta=meta,
        )

    @classmethod
    def _single_text(cls, role: MessageRole, text: str) -> "A2AMessage":
        """Build a message with one TextPart authored by ``role``."""
        return cls(role=role, parts=[TextPart(text=text)])

    @classmethod
    def user_text(cls, text: str) -> "A2AMessage":
        """Build a USER message containing a single text part."""
        return cls._single_text(MessageRole.USER, text)

    @classmethod
    def agent_text(cls, text: str) -> "A2AMessage":
        """Build an AGENT message containing a single text part."""
        return cls._single_text(MessageRole.AGENT, text)

    def get_text(self) -> str:
        """Join the text of every TextPart with single spaces (other parts are skipped)."""
        texts = [part.text for part in self.parts if isinstance(part, TextPart)]
        return " ".join(texts)

297 

298 

299# ── A2A Task ─────────────────────────────────── 

300 

@dataclass
class A2ATask:
    """Asynchronous A2A task.

    State machine: SUBMITTED -> WORKING -> COMPLETED / FAILED / CANCELLED.
    ``start_working``, ``complete`` and ``cancel`` raise ``ValueError`` when
    called from a state they do not accept. ``fail`` is accepted from any state.
    """

    task_id: str = field(default_factory=lambda: f"task-{uuid.uuid4().hex[:8]}")
    state: TaskState = TaskState.SUBMITTED
    input: A2AMessage | None = None
    output: A2AMessage | None = None
    artifacts: List[A2AArtifact] = field(default_factory=list)
    error: Optional[str] = None
    meta: Dict[str, Any] = field(default_factory=dict)
    _created: float = field(default_factory=time.time)
    _updated: float = field(default_factory=time.time)
    # One (previous_state, previous _updated timestamp) entry per transition.
    _state_history: List[tuple] = field(default_factory=list)

    def start_working(self) -> None:
        """Move SUBMITTED -> WORKING."""
        if self.state == TaskState.SUBMITTED:
            self._transition(TaskState.WORKING)
            return
        raise ValueError(f"Cannot start from state {self.state}")

    def complete(self, output: A2AMessage | None = None) -> None:
        """Move WORKING -> COMPLETED, storing ``output`` and clearing any error."""
        if self.state != TaskState.WORKING:
            raise ValueError(f"Cannot complete from state {self.state}")
        self.output, self.error = output, None
        self._transition(TaskState.COMPLETED)

    def fail(self, error: str) -> None:
        """Move to FAILED from any state, recording ``error``."""
        self.error = error
        self._transition(TaskState.FAILED)

    def cancel(self) -> None:
        """Move SUBMITTED/WORKING -> CANCELLED."""
        cancellable = (TaskState.SUBMITTED, TaskState.WORKING)
        if self.state in cancellable:
            self._transition(TaskState.CANCELLED)
            return
        raise ValueError(f"Cannot cancel from state {self.state}")

    def add_artifact(self, artifact: A2AArtifact) -> None:
        """Append an output artifact."""
        self.artifacts.append(artifact)

    def is_terminal(self) -> bool:
        """True once the task is COMPLETED, FAILED or CANCELLED."""
        terminal_states = (TaskState.COMPLETED, TaskState.FAILED, TaskState.CANCELLED)
        return self.state in terminal_states

    def _transition(self, new_state: TaskState) -> None:
        """Record the current state in the history, then switch to ``new_state``."""
        previous = (self.state, self._updated)
        self._state_history.append(previous)
        self.state, self._updated = new_state, time.time()

    def to_dict(self) -> dict:
        """Serialise the task (the state history is not included)."""

        def _message_dict(message):
            return message.to_dict() if message else None

        return {
            "task_id": self.task_id,
            "state": self.state.value,
            "input": _message_dict(self.input),
            "output": _message_dict(self.output),
            "artifacts": [artifact.to_dict() for artifact in self.artifacts],
            "error": self.error,
            "meta": self.meta,
            "created": self._created,
            "updated": self._updated,
        }

    @classmethod
    def from_dict(cls, d: dict) -> "A2ATask":
        """Rebuild a task; absent fields get a fresh id, SUBMITTED state and the current time."""
        kwargs = {
            "task_id": d.get("task_id", f"task-{uuid.uuid4().hex[:8]}"),
            "state": TaskState(d.get("state", "submitted")),
            "error": d.get("error"),
            "meta": d.get("meta", {}),
            "_created": d.get("created", time.time()),
            "_updated": d.get("updated", time.time()),
        }
        task = cls(**kwargs)
        raw_input, raw_output = d.get("input"), d.get("output")
        if raw_input:
            task.input = A2AMessage.from_dict(raw_input)
        if raw_output:
            task.output = A2AMessage.from_dict(raw_output)
        task.artifacts = [A2AArtifact.from_dict(item) for item in d.get("artifacts", [])]
        return task

    def to_json(self) -> str:
        """Serialise to indented JSON, keeping non-ASCII characters as-is."""
        payload = self.to_dict()
        return json.dumps(payload, ensure_ascii=False, indent=2)

    @classmethod
    def from_json(cls, json_str: str) -> "A2ATask":
        """Inverse of ``to_json``."""
        payload = json.loads(json_str)
        return cls.from_dict(payload)

390 

391 

392# ── A2A Handoff ──────────────────────────────── 

393 

@dataclass
class A2AHandoff:
    """Request to hand a task over from one agent to another."""

    handoff_id: str = field(default_factory=lambda: f"hoff-{uuid.uuid4().hex[:8]}")
    source_agent: str = ""
    target_agent: str = ""
    task: A2ATask | None = None
    reason: str = ""
    metadata: Dict[str, Any] = field(default_factory=dict)
    timestamp: float = field(default_factory=time.time)

    def to_dict(self) -> dict:
        """Serialise the handoff, embedding the task's dict form (or None)."""
        task_dict = self.task.to_dict() if self.task else None
        return {
            "handoff_id": self.handoff_id,
            "source_agent": self.source_agent,
            "target_agent": self.target_agent,
            "task": task_dict,
            "reason": self.reason,
            "metadata": self.metadata,
            "timestamp": self.timestamp,
        }

    @classmethod
    def from_dict(cls, d: dict) -> "A2AHandoff":
        """Rebuild a handoff; absent fields get a fresh id and the current time."""
        raw_task = d.get("task")
        embedded = A2ATask.from_dict(d["task"]) if raw_task else None
        return cls(
            handoff_id=d.get("handoff_id", f"hoff-{uuid.uuid4().hex[:8]}"),
            source_agent=d.get("source_agent", ""),
            target_agent=d.get("target_agent", ""),
            task=embedded,
            reason=d.get("reason", ""),
            metadata=d.get("metadata", {}),
            timestamp=d.get("timestamp", time.time()),
        )

    def to_json(self) -> str:
        """Serialise to indented JSON, keeping non-ASCII characters as-is."""
        payload = self.to_dict()
        return json.dumps(payload, ensure_ascii=False, indent=2)

    @classmethod
    def from_json(cls, json_str: str) -> "A2AHandoff":
        """Inverse of ``to_json``."""
        payload = json.loads(json_str)
        return cls.from_dict(payload)

437 

438 

439# ── A2A Session ──────────────────────────────── 

440 

@dataclass
class A2ASession:
    """Multi-turn A2A session context: message history plus related tasks."""

    session_id: str = field(default_factory=lambda: f"sess-{uuid.uuid4().hex[:8]}")
    history: List[A2AMessage] = field(default_factory=list)
    tasks: List[A2ATask] = field(default_factory=list)
    metadata: Dict[str, Any] = field(default_factory=dict)
    created: float = field(default_factory=time.time)

    def add_message(self, msg: A2AMessage) -> None:
        """Append ``msg`` to the history."""
        self.history.append(msg)

    def add_task(self, task: A2ATask) -> None:
        """Append ``task`` to the session's task list."""
        self.tasks.append(task)

    def get_last_n_messages(self, n: int = 10) -> List[A2AMessage]:
        """Return up to the ``n`` most recent messages, oldest first.

        ``n <= 0`` returns an empty list. Without this guard, ``history[-0:]``
        would return the whole history, and negative ``n`` would drop messages
        from the front.
        """
        if n <= 0:
            return []
        return self.history[-n:]

    def to_dict(self) -> dict:
        """Serialise the session, including every message and task."""
        return {
            "session_id": self.session_id,
            "history": [m.to_dict() for m in self.history],
            "tasks": [t.to_dict() for t in self.tasks],
            "metadata": self.metadata,
            "created": self.created,
        }

467 

468 

469# ── A2A Client ───────────────────────────────── 

470 

class A2AClient:
    """A2A protocol client.

    Sends tasks to a remote agent, queries their state and fetches results.

    v1.3.13: retries + auth header + streaming subscription + pooled connections.
    """

    def __init__(
        self,
        base_url: str,
        timeout: float = 30.0,
        max_retries: int = 3,
        retry_backoff: float = 1.0,
        auth_token: str = "",
        agent_name: str = "",
    ):
        """Configure the client; the HTTP connection pool is created lazily.

        Args:
            base_url: Root URL of the remote A2A server (a trailing "/" is stripped).
            timeout: Timeout passed to the underlying ``httpx.AsyncClient``.
            max_retries: Attempts made by ``_retry``; values below 1 mean one attempt.
            retry_backoff: Base delay in seconds; retry ``i`` first waits
                ``retry_backoff * 2**i``.
            auth_token: Sent as ``Authorization: Bearer <token>`` when non-empty.
            agent_name: Appended to the ``User-Agent`` header when non-empty.
        """
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        self.max_retries = max_retries
        self.retry_backoff = retry_backoff
        self.auth_token = auth_token
        self.agent_name = agent_name
        self._client: Any = None

    def _headers(self) -> dict[str, str]:
        """Build the default request headers (User-Agent plus optional bearer auth)."""
        h = {"User-Agent": f"AgentOS-A2A/{self.agent_name}" if self.agent_name else "AgentOS-A2A"}
        if self.auth_token:
            h["Authorization"] = f"Bearer {self.auth_token}"
        return h

    async def _get_client(self) -> Any:
        """Return the shared ``httpx.AsyncClient``, creating it on first use.

        Headers are captured when the client is created. Later changes to
        ``auth_token`` / ``agent_name`` only take effect after ``close()``.
        """
        if self._client is None:
            import httpx
            self._client = httpx.AsyncClient(
                timeout=self.timeout,
                headers=self._headers(),
                limits=httpx.Limits(max_keepalive_connections=10, max_connections=50),
            )
        return self._client

    async def close(self) -> None:
        """Close the pooled HTTP client; a new one is created on next use."""
        if self._client:
            await self._client.aclose()
            self._client = None

    async def _retry(self, coro, *args, **kwargs):
        """Await ``coro(*args, **kwargs)``, retrying connection errors and timeouts.

        Only ``httpx.ConnectError`` and ``httpx.TimeoutException`` are retried,
        with exponential backoff. Any other exception propagates immediately.
        At least one attempt is always made. Previously, ``max_retries <= 0``
        made this execute ``raise None``.

        Raises:
            The last retried exception once all attempts are exhausted.
        """
        import asyncio
        import httpx

        attempts = max(1, self.max_retries)
        for attempt in range(attempts):
            try:
                return await coro(*args, **kwargs)
            except (httpx.ConnectError, httpx.TimeoutException):
                if attempt == attempts - 1:
                    raise
                await asyncio.sleep(self.retry_backoff * (2 ** attempt))

    async def send_task(self, task: A2ATask) -> A2ATask:
        """POST /tasks: submit ``task`` and return the server's view of it.

        NOTE(review): the POST is retried on connect errors/timeouts. If the
        server received a request whose response timed out, the task may be
        submitted twice.
        """
        client = await self._get_client()

        async def _do():
            resp = await client.post(f"{self.base_url}/tasks", json=task.to_dict())
            resp.raise_for_status()
            return A2ATask.from_dict(resp.json())

        return await self._retry(_do)

    async def get_task(self, task_id: str) -> Optional[A2ATask]:
        """GET /tasks/{id}: fetch a task's state and result.

        Best-effort: returns ``None`` on any failure (HTTP error status,
        network error or unparsable body), not only when the task is missing.
        """
        client = await self._get_client()
        try:
            resp = await client.get(f"{self.base_url}/tasks/{task_id}")
            resp.raise_for_status()
            return A2ATask.from_dict(resp.json())
        except Exception:
            return None

    async def cancel_task(self, task_id: str) -> bool:
        """DELETE /tasks/{id}: request cancellation; False on error status or exception."""
        client = await self._get_client()
        try:
            resp = await client.delete(f"{self.base_url}/tasks/{task_id}")
            return resp.status_code < 400
        except Exception:
            return False

    async def handoff(self, handoff: A2AHandoff) -> bool:
        """POST /handoff: hand a task to another agent; False on error status or exception."""
        client = await self._get_client()
        try:
            resp = await client.post(f"{self.base_url}/handoff", json=handoff.to_dict())
            return resp.status_code < 400
        except Exception:
            return False

    async def wait_for_completion(
        self,
        task_id: str,
        poll_interval: float = 1.0,
        max_wait: float = 60.0,
    ) -> A2ATask:
        """Poll ``get_task`` until the task reaches a terminal state.

        ``max_wait`` bounds wall-clock time, including request latency, and is
        measured with ``time.monotonic()``. The loop therefore terminates even
        when ``poll_interval <= 0``.

        Raises:
            RuntimeError: ``get_task`` returned None (task missing or fetch failed).
            TimeoutError: the task was not terminal within ``max_wait`` seconds.
        """
        import asyncio

        deadline = time.monotonic() + max_wait
        while time.monotonic() < deadline:
            task = await self.get_task(task_id)
            if task is None:
                raise RuntimeError(f"Task {task_id} not found")
            if task.is_terminal():
                return task
            # Never sleep past the deadline.
            remaining = deadline - time.monotonic()
            await asyncio.sleep(max(0.0, min(poll_interval, remaining)))
        raise TimeoutError(f"Task {task_id} did not complete within {max_wait}s")

    async def send_and_wait_for_reply(
        self,
        text: str,
        target_agent: str = "",
        poll_interval: float = 1.0,
        max_wait: float = 60.0,
    ) -> str:
        """Send a text task and wait for the reply text.

        Returns the output text. If the task failed, returns
        ``"[Error] <error>"``. If there is neither output nor error, returns
        ``""``.
        """
        task = new_task(text, target_agent=target_agent)
        task = await self.send_task(task)
        result = await self.wait_for_completion(task.task_id, poll_interval, max_wait)
        if result.output:
            return result.output.get_text()
        if result.error:
            return f"[Error] {result.error}"
        return ""

    @staticmethod
    def _parse_sse_event(raw: str) -> dict[str, str] | None:
        """Parse one Server-Sent Events block (the text between blank lines).

        Follows the SSE field rules for ``event`` and ``data``:
        - a single space after the colon is optional and stripped;
        - multiple ``data`` lines are joined with "\\n";
        - comment lines (starting with ":") and unknown fields are ignored.

        Returns:
            A dict holding whichever of ``"event"`` / ``"data"`` were present,
            or ``None`` if neither was (e.g. a keepalive comment).
        """
        event_data: dict[str, str] = {}
        data_lines: list[str] = []
        for line in raw.split("\n"):
            if not line or line.startswith(":"):
                continue
            name, _, value = line.partition(":")
            if value.startswith(" "):
                value = value[1:]
            if name == "event":
                event_data["event"] = value
            elif name == "data":
                data_lines.append(value)
        if data_lines:
            event_data["data"] = "\n".join(data_lines)
        return event_data or None

    async def subscribe_task_stream(
        self,
        task_id: str,
        on_event: Callable[[dict], Any] | None = None,
    ) -> None:
        """Subscribe to ``GET /tasks/{id}/stream`` and report each SSE event.

        ``on_event`` receives ``{"event": ..., "data": ...}`` dicts. It may be a
        plain function or a coroutine function; awaitable results are awaited.
        Blocks that carry neither field (keepalive comments) are skipped.
        Returns when the server closes the stream.
        """
        import inspect

        client = await self._get_client()
        async with client.stream("GET", f"{self.base_url}/tasks/{task_id}/stream") as resp:
            resp.raise_for_status()
            buffer = ""
            async for chunk in resp.aiter_text():
                # Normalise CRLF on the accumulated buffer so "\r\n\r\n"-framed
                # events split correctly, even when "\r" and "\n" arrive in
                # different chunks.
                buffer = (buffer + chunk).replace("\r\n", "\n")
                while "\n\n" in buffer:
                    raw, buffer = buffer.split("\n\n", 1)
                    event_data = self._parse_sse_event(raw)
                    if event_data is None or on_event is None:
                        continue
                    result = on_event(event_data)
                    if inspect.isawaitable(result):
                        await result

630 

631 

632# ── A2A Server ───────────────────────────────── 

633 

class A2AServer:
    """A2A protocol server.

    Receives and processes inter-agent task requests.

    Usage:
        server = A2AServer()
        server.register_handler("my-agent", my_handler)
        # Integrate with FastAPI:
        app = FastAPI()
        server.mount_routes(app)
    """

    def __init__(
        self,
        task_store=None,
        stream_manager=None,
        require_auth: bool = False,
        auth_tokens: List[str] | None = None,
    ):
        """Create a server.

        Args:
            task_store: Store exposing ``save_task`` / ``get_task`` / ``list_tasks``
                / ``cleanup_terminal``. When None, an ``InMemoryTaskStore`` is
                created lazily on first use.
            stream_manager: Optional streaming backend. When set, its
                ``notify_state_change(task, old_state)`` coroutine is awaited on
                task transitions, and it backs the SSE endpoint.
            require_auth: When True, creating, cancelling and handing off tasks
                requires a bearer token listed in ``auth_tokens``.
            auth_tokens: Accepted bearer tokens.
        """
        self._handlers: Dict[str, Callable] = {}
        self._task_store = task_store
        self._stream_manager = stream_manager
        self.require_auth = require_auth
        self.auth_tokens: set[str] = set(auth_tokens or [])
        self._default_store_created = False

    def _ensure_store(self):
        """Create the default in-memory store if none was supplied."""
        if self._task_store is None:
            from agentos.protocols.a2a_store import InMemoryTaskStore
            self._task_store = InMemoryTaskStore()
            self._default_store_created = True

    @property
    def task_store(self):
        """The task store, created on first access when none was supplied."""
        self._ensure_store()
        return self._task_store

    def _is_authorized(self, token: str) -> bool:
        """True when auth is disabled or ``token`` is one of ``auth_tokens``."""
        return not self.require_auth or token in self.auth_tokens

    def register_handler(
        self,
        agent_name: str,
        handler: Callable,
    ) -> None:
        """Register the handler for ``agent_name``, replacing any previous one.

        Handler signature: ``handler(task: A2ATask) -> A2AMessage``. It may be a
        plain function or return an awaitable (e.g. an ``async def``).
        """
        self._handlers[agent_name] = handler

    async def process_task(self, body: dict, auth_token: str = "") -> dict:
        """Parse an incoming task, run its target agent's handler and persist the result.

        The handler is chosen by ``body["meta"]["target_agent"]``. The task is
        saved on arrival and again in its final state. Exceptions raised by
        the handler (or an invalid incoming state) turn the task FAILED with
        ``str(exc)`` as the error. When a stream manager is configured,
        ``notify_state_change(task, previous_state)`` is awaited for every
        transition performed here.

        Args:
            body: Task in ``A2ATask.to_dict()`` form.
            auth_token: Bearer token presented by the caller.

        Returns:
            The final task as a dict. An unauthorized request gets a FAILED task
            back. That task is stored only if no task with the same id exists, so
            an unauthenticated caller cannot overwrite someone else's task.
        """
        import inspect

        self._ensure_store()

        if not self._is_authorized(auth_token):
            task = A2ATask.from_dict(body)
            task.fail("Unauthorized: invalid or missing A2A auth token")
            if self._task_store.get_task(task.task_id) is None:
                self._task_store.save_task(task)
            return task.to_dict()

        task = A2ATask.from_dict(body)
        self._task_store.save_task(task)

        target = (body.get("meta") or {}).get("target_agent", "")
        handler = self._handlers.get(target)

        # State the task is leaving in the transition reported at the end.
        prev_state = task.state
        if not handler and target:
            task.fail(f"No handler for agent '{target}'")
        elif not handler:
            task.fail("No target agent specified in meta")
        else:
            try:
                task.start_working()
                started_from, prev_state = prev_state, task.state
                if self._stream_manager:
                    await self._stream_manager.notify_state_change(task, started_from)
                result = handler(task)
                output = await result if inspect.isawaitable(result) else result
                task.complete(output)
            except Exception as e:
                task.fail(str(e))

        if self._stream_manager:
            await self._stream_manager.notify_state_change(task, prev_state)
        self._task_store.save_task(task)
        return task.to_dict()

    def get_task(self, task_id: str) -> Optional[A2ATask]:
        """Look up a task by id in the store."""
        return self.task_store.get_task(task_id)

    def list_tasks(self, state: TaskState | None = None) -> list[A2ATask]:
        """List stored tasks, optionally filtered by ``state``."""
        return self.task_store.list_tasks(state=state)

    def cleanup_old(self, max_age_seconds: float = 3600.0) -> int:
        """Delegate to the store's ``cleanup_terminal(max_age_seconds)`` and return its result."""
        return self.task_store.cleanup_terminal(max_age_seconds)

    # ── FastAPI route builder ─────────────────────

    def mount_routes(self, app, prefix: str = "") -> None:
        """Mount the standard A2A routes on a FastAPI/Starlette app.

        Routes:
            POST   {prefix}/tasks               — create a task
            GET    {prefix}/tasks               — list tasks (400 on an invalid ``state``)
            GET    {prefix}/tasks/{id}          — fetch a task
            DELETE {prefix}/tasks/{id}          — cancel a task
            GET    {prefix}/tasks/{id}/stream   — SSE event stream
            POST   {prefix}/handoff             — task handoff

        With ``require_auth``, POST /tasks reports an unauthorized request as a
        FAILED task. DELETE and POST /handoff reject it with HTTP 401.

        Raises:
            ImportError: if FastAPI/Starlette are not installed.
        """
        try:
            from fastapi import FastAPI, Request, HTTPException
            from starlette.responses import StreamingResponse
        except ImportError:
            raise ImportError("FastAPI and Starlette are required for mount_routes()")

        server = self

        def _with_request_type(fn):
            # This module uses `from __future__ import annotations`, so the
            # endpoint annotation is the string "Request". FastAPI evaluates
            # string annotations against the module globals, where the locally
            # imported `Request` is not visible. Pin the real class before the
            # route is registered.
            if fn.__annotations__.get("request") == "Request":
                fn.__annotations__["request"] = Request
            return fn

        def _bearer(request) -> str:
            return request.headers.get("Authorization", "").removeprefix("Bearer ")

        @app.post(f"{prefix}/tasks")
        @_with_request_type
        async def create_task(request: Request):
            body = await request.json()
            return await server.process_task(body, auth_token=_bearer(request))

        @app.get(f"{prefix}/tasks")
        async def list_tasks_endpoint(state: str = ""):
            try:
                task_state = TaskState(state) if state else None
            except ValueError:
                raise HTTPException(status_code=400, detail=f"Invalid task state: {state}") from None
            tasks = server.list_tasks(state=task_state)
            return [t.to_dict() for t in tasks]

        @app.get(f"{prefix}/tasks/{{task_id}}")
        async def get_task_endpoint(task_id: str):
            task = server.get_task(task_id)
            if task is None:
                raise HTTPException(status_code=404, detail="Task not found")
            return task.to_dict()

        @app.delete(f"{prefix}/tasks/{{task_id}}")
        @_with_request_type
        async def cancel_task_endpoint(task_id: str, request: Request):
            if not server._is_authorized(_bearer(request)):
                raise HTTPException(status_code=401, detail="Unauthorized")
            task = server.get_task(task_id)
            if task is None:
                raise HTTPException(status_code=404, detail="Task not found")
            if task.is_terminal():
                raise HTTPException(status_code=400, detail="Task already terminal")
            old_state = task.state
            task.cancel()
            if server._stream_manager:
                await server._stream_manager.notify_state_change(task, old_state)
            server.task_store.save_task(task)
            return {"status": "cancelled", "task_id": task_id}

        @app.get(f"{prefix}/tasks/{{task_id}}/stream")
        async def stream_task(task_id: str):
            if server._stream_manager is None:
                raise HTTPException(status_code=501, detail="Streaming not enabled")

            async def event_generator():
                stream = server._stream_manager
                session = stream.get_session(task_id)
                if session is None:
                    yield f"event: error\ndata: {{\"error\": \"Session not found\"}}\n\n"
                    return
                sub = session.subscribe()
                try:
                    async for evt in session.iter_events(sub):
                        yield session.to_sse(evt)
                except Exception:
                    # Best-effort stream: any error simply ends the response.
                    pass

            return StreamingResponse(
                event_generator(),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                    "X-Accel-Buffering": "no",
                },
            )

        @app.post(f"{prefix}/handoff")
        @_with_request_type
        async def handoff_endpoint(request: Request):
            # Check auth before parsing the body of an unauthenticated request.
            if not server._is_authorized(_bearer(request)):
                raise HTTPException(status_code=401, detail="Unauthorized")
            body = await request.json()
            handoff = A2AHandoff.from_dict(body)
            if handoff.task:
                # Via the property so a default store exists even if no task
                # has been processed yet (the raw attribute may still be None).
                server.task_store.save_task(handoff.task)
            return {"status": "received", "handoff_id": handoff.handoff_id}

822 

823 

824# ── 便捷函数 ─────────────────────────────────── 

825 

def new_task(text: str, target_agent: str = "", **meta) -> A2ATask:
    """Build a new task whose input is a single user text message.

    Args:
        text: Text of the user message.
        target_agent: Stored as ``meta["target_agent"]``.
        **meta: Additional metadata entries, placed after ``target_agent``.
    """
    task_meta = {"target_agent": target_agent}
    task_meta.update(meta)
    message = A2AMessage.user_text(text)
    return A2ATask(input=message, meta=task_meta)

832 

833 

def new_handoff(
    task: A2ATask,
    source: str,
    target: str,
    reason: str = "",
) -> A2AHandoff:
    """Build a handoff of ``task`` from agent ``source`` to agent ``target``."""
    fields = dict(source_agent=source, target_agent=target, task=task, reason=reason)
    return A2AHandoff(**fields)

847 

848# ========================================================================== 

849# Compat: lightweight AgentRegistry for test_core.py 

850# ========================================================================== 

851from dataclasses import dataclass, field 

852from typing import List, Optional, Dict 

853import time 

854 

@dataclass
class AgentRecord:
    """Registry entry for one agent (compat shim for test_core.py)."""

    agent_id: str
    capabilities: List[str] = field(default_factory=list)
    endpoint: str = ""
    load: float = 0.0
    # time.time() of the most recent heartbeat; hidden from repr.
    _last_heartbeat: float = field(default_factory=time.time, repr=False)

    @property
    def healthy(self) -> bool:
        """True while the last heartbeat is less than 60 seconds old."""
        age = time.time() - self._last_heartbeat
        return age < 60.0

866 

867 

class AgentRegistry:
    """Lightweight in-memory agent registry (compat shim for test_core.py).

    Agents are keyed by ``agent_id``. ``default_ttl`` is the heartbeat age in
    seconds after which queries with ``healthy_only=True`` skip an agent.
    """

    def __init__(self, name: str = "default", default_ttl: float = 60.0):
        self._records: Dict[str, AgentRecord] = {}
        self.name = name
        self.default_ttl = default_ttl

    def register(self, record: AgentRecord):
        """Add or replace ``record`` and stamp it with a fresh heartbeat."""
        record._last_heartbeat = time.time()
        self._records[record.agent_id] = record

    def get(self, agent_id: str) -> Optional[AgentRecord]:
        """Return the record for ``agent_id``, or None if unknown."""
        return self._records.get(agent_id)

    def _is_fresh(self, record: AgentRecord) -> bool:
        """True if ``record``'s last heartbeat is younger than ``default_ttl``."""
        return (time.time() - record._last_heartbeat) < self.default_ttl

    def find_by_capability(
        self, capability: str, *, healthy_only: bool = False
    ) -> List[AgentRecord]:
        """Return records advertising ``capability``, in registration order.

        With ``healthy_only=True``, records whose heartbeat is older than
        ``default_ttl`` are excluded.
        """
        return [
            r
            for r in self._records.values()
            if capability in r.capabilities and (not healthy_only or self._is_fresh(r))
        ]

    def heartbeat(self, agent_id: str):
        """Refresh the heartbeat of ``agent_id``; unknown ids are ignored."""
        r = self._records.get(agent_id)
        if r:
            r._last_heartbeat = time.time()

    def pick_least_loaded(
        self, capability: str, *, healthy_only: bool = False
    ) -> Optional[AgentRecord]:
        """Return the matching record with the lowest ``load``, or None if none match.

        With ``healthy_only=True``, only agents whose heartbeat is within
        ``default_ttl`` are considered, so work is not routed to stale agents.
        """
        candidates = self.find_by_capability(capability, healthy_only=healthy_only)
        if not candidates:
            return None
        return min(candidates, key=lambda r: r.load)
893 return min(candidates, key=lambda r: r.load)