Coverage for agentos/protocols/a2a.py: 38%
483 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 16:36 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 16:36 +0800
1"""
2AgentOS v1.2.2 — A2A (Agent-to-Agent) 协议实现。
4基因来源: Google A2A Protocol (agent-to-agent-protocol.google.com)
6A2A 协议核心概念:
7- Task: 异步工作单元,带状态机 (SUBMITTED→WORKING→COMPLETED/FAILED/CANCELLED)
8- Message: 多模态消息,支持 text/file/data parts
9- Artifact: 任务产生的输出物,带 MIME 类型
10- Handoff: Agent 间任务移交
11- Session: 多轮对话上下文
13协议层:
14- REST: GET/POST /tasks, /tasks/{id}
15- Future: WebSocket 推送 (v1.3+)
16"""
18from __future__ import annotations
20import json
21import uuid
22import time
23from dataclasses import dataclass, field
24from enum import Enum
25from typing import Any, Callable, Dict, List, Optional
28# ── 基础枚举 ────────────────────────────────────
class TaskState(str, Enum):
    """Lifecycle states of an A2A task.

    Transitions: SUBMITTED -> WORKING -> COMPLETED / FAILED / CANCELLED.
    Because of the ``str`` mixin, members compare equal to their raw values.
    """

    # Initial state of a freshly created task.
    SUBMITTED = "submitted"
    # Handler is running.
    WORKING = "working"
    # Terminal states.
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
class TaskStatus(str, Enum):
    """Lower-case alias of the A2A task states, used by the compliance test suite.

    Note the spelling ``canceled`` (single "l"): its value differs from
    ``TaskState.CANCELLED`` ("cancelled"), so the two enums are not
    interchangeable for that member.
    """

    submitted = "submitted"
    working = "working"
    completed = "completed"
    failed = "failed"
    canceled = "canceled"
class AgentCard:
    """A2A agent card describing an agent to its peers.

    A plain class (no Pydantic dependency) that offers the ``model_dump``
    method the compliance test suite expects from a Pydantic model.

    Args:
        name: Agent name.
        description: Human-readable description.
        url: URL where the agent is served.
        version: Agent version string.
        capabilities: Capability list, stored as given.
        provider: Provider information, stored as given.
        authentication: Optional authentication descriptor; opaque to this class.
        default_input_modes: Input modes; ``None`` or empty falls back to ``["text"]``.
        default_output_modes: Output modes; ``None`` or empty falls back to ``["text"]``.
        skills: Skill list; ``None`` or empty falls back to ``[]``.
    """

    def __init__(
        self,
        name: str,
        description: str,
        url: str,
        version: str,
        capabilities: list,
        provider: dict,
        authentication: Optional[Any] = None,
        default_input_modes: Optional[list] = None,
        default_output_modes: Optional[list] = None,
        skills: Optional[list] = None,
    ):
        self.name = name
        self.description = description
        self.url = url
        self.version = version
        self.capabilities = capabilities
        self.provider = provider
        self.authentication = authentication
        self.default_input_modes = default_input_modes or ["text"]
        self.default_output_modes = default_output_modes or ["text"]
        self.skills = skills or []

    def model_dump(self, *, exclude_none: bool = False) -> dict:
        """Return the card's fields as a new ``dict``.

        Values are deep-copied, so mutating the returned dict (for example,
        appending to ``capabilities``) never alters the card itself.

        Args:
            exclude_none: Drop keys whose value is ``None`` (same keyword name
                as Pydantic's ``model_dump``). Defaults to keeping every key.
        """
        import copy

        dumped = {
            "name": self.name,
            "description": self.description,
            "url": self.url,
            "version": self.version,
            "capabilities": self.capabilities,
            "provider": self.provider,
            "authentication": self.authentication,
            "default_input_modes": self.default_input_modes,
            "default_output_modes": self.default_output_modes,
            "skills": self.skills,
        }
        if exclude_none:
            dumped = {key: value for key, value in dumped.items() if value is not None}
        # Previously the card's own lists/dicts were returned by reference.
        return copy.deepcopy(dumped)
class A2AMessageBus:
    """Minimal A2A message bus that tracks registered agents by id.

    ``send`` delivers nothing; it only reports whether the target id has been
    registered.
    """

    def __init__(self):
        # agent_id -> agent object (None when only the id was registered)
        self._agents: Dict[str, Any] = {}

    def register_agent(self, agent_id: str, agent: Any = None) -> None:
        """Register ``agent`` under ``agent_id``, replacing any previous entry."""
        self._agents.update({agent_id: agent})

    async def send(self, target_agent: str, message: Any) -> bool:
        """Return True iff ``target_agent`` is registered; ``message`` is unused."""
        known_ids = self._agents.keys()
        return target_agent in known_ids
class PartType(str, Enum):
    """Kinds of content part an A2A message can carry.

    Each value is the ``"type"`` tag written by the matching part's
    ``to_dict`` and dispatched on by ``part_from_dict``.
    """

    TEXT = "text"  # TextPart
    FILE = "file"  # FilePart
    DATA = "data"  # DataPart
class MessageRole(str, Enum):
    """Sender of an A2A message: the user side or the agent side."""

    USER = "user"
    AGENT = "agent"
126# ── Message Parts ──────────────────────────────
@dataclass
class TextPart:
    """Plain-text message part."""

    text: str
    meta: Dict[str, str] = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Serialize as ``{"type": "text", "text": ..., "meta": ...}``."""
        payload = {"type": PartType.TEXT.value}
        payload["text"] = self.text
        payload["meta"] = self.meta
        return payload

    @classmethod
    def from_dict(cls, d: dict) -> "TextPart":
        """Build from a dict; missing ``text``/``meta`` default to ``""``/``{}``."""
        body = d.get("text", "")
        extra = d.get("meta", {})
        return cls(text=body, meta=extra)
@dataclass
class FilePart:
    """Message part that refers to a file by URL (no inline content)."""

    url: str = ""
    filename: str = ""
    mime_type: str = "application/octet-stream"
    size: int = 0
    meta: Dict[str, str] = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Serialize to a dict tagged with ``"type": "file"``."""
        return dict(
            type=PartType.FILE.value,
            url=self.url,
            filename=self.filename,
            mime_type=self.mime_type,
            size=self.size,
            meta=self.meta,
        )

    @classmethod
    def from_dict(cls, d: dict) -> "FilePart":
        """Build from a dict, using the field defaults for missing keys."""
        fallbacks = {
            "url": "",
            "filename": "",
            "mime_type": "application/octet-stream",
            "size": 0,
            "meta": {},
        }
        kwargs = {key: d.get(key, default) for key, default in fallbacks.items()}
        return cls(**kwargs)
@dataclass
class DataPart:
    """Message part carrying structured data, optionally tagged with a schema URI."""

    data: Dict[str, Any] = field(default_factory=dict)
    schema_uri: str = ""
    meta: Dict[str, str] = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Serialize to a dict tagged with ``"type": "data"``."""
        out: dict = {"type": PartType.DATA.value}
        out.update(data=self.data, schema_uri=self.schema_uri, meta=self.meta)
        return out

    @classmethod
    def from_dict(cls, d: dict) -> "DataPart":
        """Build from a dict; absent keys take the field defaults."""
        payload = d.get("data", {})
        uri = d.get("schema_uri", "")
        extra = d.get("meta", {})
        return cls(data=payload, schema_uri=uri, meta=extra)
def part_from_dict(d: dict):
    """Deserialize a part dict into a ``TextPart``, ``FilePart`` or ``DataPart``.

    Dispatches on ``d["type"]``. Raises ``ValueError`` when the type is
    missing or not one of the ``PartType`` values.
    """
    ptype = d.get("type", "")
    candidates = (
        (PartType.TEXT.value, TextPart),
        (PartType.FILE.value, FilePart),
        (PartType.DATA.value, DataPart),
    )
    for tag, part_cls in candidates:
        if ptype == tag:
            return part_cls.from_dict(d)
    raise ValueError(f"Unknown part type: {ptype}")
208# ── A2A Artifact ───────────────────────────────
@dataclass
class A2AArtifact:
    """Output produced by a task.

    Content is either carried inline (``blob``, serialized as base64) or
    referenced externally (``url``).
    """

    name: str
    mime_type: str = "application/octet-stream"
    blob: Optional[bytes] = None
    url: str = ""
    size: int = 0
    description: str = ""
    meta: Dict[str, str] = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Serialize to a JSON-friendly dict.

        ``url`` is included only when non-empty. ``blob`` is included as
        ``blob_base64`` whenever it is set, including an empty ``b""``, so that
        ``from_dict(to_dict())`` reproduces the blob exactly.
        """
        d = {
            "name": self.name,
            "mime_type": self.mime_type,
            "size": self.size,
            "description": self.description,
            "meta": self.meta,
        }
        if self.url:
            d["url"] = self.url
        # `is not None` rather than truthiness: b"" is a valid (empty) payload
        # and must not be turned into "no blob" on the wire.
        if self.blob is not None:
            import base64
            d["blob_base64"] = base64.b64encode(self.blob).decode("ascii")
        return d

    @classmethod
    def from_dict(cls, d: dict) -> "A2AArtifact":
        """Inverse of ``to_dict``; missing keys fall back to the field defaults."""
        artifact = cls(
            name=d.get("name", ""),
            mime_type=d.get("mime_type", "application/octet-stream"),
            url=d.get("url", ""),
            size=d.get("size", 0),
            description=d.get("description", ""),
            meta=d.get("meta", {}),
        )
        if "blob_base64" in d:
            import base64
            artifact.blob = base64.b64decode(d["blob_base64"])
        return artifact
254# ── A2A Message ────────────────────────────────
@dataclass
class A2AMessage:
    """Multi-modal A2A message made of text, file and/or data parts."""

    role: MessageRole = MessageRole.USER
    parts: list = field(default_factory=list)  # items: TextPart | FilePart | DataPart
    message_id: str = field(default_factory=lambda: f"msg-{uuid.uuid4().hex[:8]}")
    timestamp: float = field(default_factory=time.time)
    meta: Dict[str, str] = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Serialize to a dict, delegating each part to its own ``to_dict``."""
        role_value = self.role.value
        serialized_parts = [part.to_dict() for part in self.parts]
        return {
            "message_id": self.message_id,
            "role": role_value,
            "parts": serialized_parts,
            "timestamp": self.timestamp,
            "meta": self.meta,
        }

    @classmethod
    def from_dict(cls, d: dict) -> "A2AMessage":
        """Deserialize a message; a missing id or timestamp is freshly generated.

        Raises ``ValueError`` for an unknown role or part type.
        """
        role = MessageRole(d.get("role", "user"))
        parts = list(map(part_from_dict, d.get("parts", [])))
        return cls(
            message_id=d.get("message_id", f"msg-{uuid.uuid4().hex[:8]}"),
            role=role,
            parts=parts,
            timestamp=d.get("timestamp", time.time()),
            meta=d.get("meta", {}),
        )

    @classmethod
    def user_text(cls, text: str) -> "A2AMessage":
        """Shortcut: a USER message holding a single text part."""
        return cls(role=MessageRole.USER, parts=[TextPart(text=text)])

    @classmethod
    def agent_text(cls, text: str) -> "A2AMessage":
        """Shortcut: an AGENT message holding a single text part."""
        single_part = TextPart(text=text)
        return cls(role=MessageRole.AGENT, parts=[single_part])

    def get_text(self) -> str:
        """Join the text of every ``TextPart`` with single spaces."""
        texts = [part.text for part in self.parts if isinstance(part, TextPart)]
        return " ".join(texts)
299# ── A2A Task ───────────────────────────────────
@dataclass
class A2ATask:
    """Asynchronous A2A task with a small state machine.

    Transitions enforced by the methods below::

        start_working():  SUBMITTED           -> WORKING
        complete():       WORKING             -> COMPLETED
        cancel():         SUBMITTED | WORKING -> CANCELLED
        fail():           any state           -> FAILED   (no guard)

    Guarded transitions raise ``ValueError`` from a disallowed state. Every
    transition appends ``(previous_state, previous _updated value)`` to
    ``_state_history``; that history is not included in ``to_dict``.
    """

    task_id: str = field(default_factory=lambda: f"task-{uuid.uuid4().hex[:8]}")
    state: TaskState = TaskState.SUBMITTED
    input: A2AMessage | None = None
    output: A2AMessage | None = None
    artifacts: List[A2AArtifact] = field(default_factory=list)
    error: Optional[str] = None
    meta: Dict[str, Any] = field(default_factory=dict)
    _created: float = field(default_factory=time.time)
    _updated: float = field(default_factory=time.time)
    _state_history: List[tuple] = field(default_factory=list)  # [(state, timestamp)]

    def start_working(self) -> None:
        """SUBMITTED -> WORKING."""
        if self.state == TaskState.SUBMITTED:
            self._transition(TaskState.WORKING)
        else:
            raise ValueError(f"Cannot start from state {self.state}")

    def complete(self, output: A2AMessage | None = None) -> None:
        """WORKING -> COMPLETED; stores ``output`` and clears any ``error``."""
        if self.state == TaskState.WORKING:
            self.output = output
            self.error = None
            self._transition(TaskState.COMPLETED)
        else:
            raise ValueError(f"Cannot complete from state {self.state}")

    def fail(self, error: str) -> None:
        """Any state -> FAILED, recording ``error``."""
        self.error = error
        self._transition(TaskState.FAILED)

    def cancel(self) -> None:
        """SUBMITTED or WORKING -> CANCELLED."""
        if self.state in (TaskState.SUBMITTED, TaskState.WORKING):
            self._transition(TaskState.CANCELLED)
        else:
            raise ValueError(f"Cannot cancel from state {self.state}")

    def add_artifact(self, artifact: A2AArtifact) -> None:
        """Append an output artifact."""
        self.artifacts.append(artifact)

    def is_terminal(self) -> bool:
        """True once the task is COMPLETED, FAILED or CANCELLED."""
        finished = (TaskState.COMPLETED, TaskState.FAILED, TaskState.CANCELLED)
        return self.state in finished

    def _transition(self, new_state: TaskState) -> None:
        # Remember where we came from before switching.
        snapshot = (self.state, self._updated)
        self._state_history.append(snapshot)
        self.state = new_state
        self._updated = time.time()

    def to_dict(self) -> dict:
        """Serialize to a JSON-friendly dict (``_created``/``_updated`` become ``created``/``updated``)."""
        serialized_input = self.input.to_dict() if self.input else None
        serialized_output = self.output.to_dict() if self.output else None
        return {
            "task_id": self.task_id,
            "state": self.state.value,
            "input": serialized_input,
            "output": serialized_output,
            "artifacts": [artifact.to_dict() for artifact in self.artifacts],
            "error": self.error,
            "meta": self.meta,
            "created": self._created,
            "updated": self._updated,
        }

    @classmethod
    def from_dict(cls, d: dict) -> "A2ATask":
        """Inverse of ``to_dict``; absent keys get fresh ids/timestamps or defaults."""
        task = cls(
            task_id=d.get("task_id", f"task-{uuid.uuid4().hex[:8]}"),
            state=TaskState(d.get("state", "submitted")),
            error=d.get("error"),
            meta=d.get("meta", {}),
            _created=d.get("created", time.time()),
            _updated=d.get("updated", time.time()),
        )
        raw_input = d.get("input")
        if raw_input:
            task.input = A2AMessage.from_dict(raw_input)
        raw_output = d.get("output")
        if raw_output:
            task.output = A2AMessage.from_dict(raw_output)
        task.artifacts = [A2AArtifact.from_dict(item) for item in d.get("artifacts", [])]
        return task

    def to_json(self) -> str:
        """Serialize ``to_dict()`` as pretty-printed, non-ASCII-preserving JSON."""
        return json.dumps(self.to_dict(), ensure_ascii=False, indent=2)

    @classmethod
    def from_json(cls, json_str: str) -> "A2ATask":
        """Parse a JSON string produced by ``to_json``."""
        return cls.from_dict(json.loads(json_str))
392# ── A2A Handoff ────────────────────────────────
@dataclass
class A2AHandoff:
    """Request to hand a task over from one agent to another."""

    handoff_id: str = field(default_factory=lambda: f"hoff-{uuid.uuid4().hex[:8]}")
    source_agent: str = ""
    target_agent: str = ""
    task: A2ATask | None = None
    reason: str = ""
    metadata: Dict[str, Any] = field(default_factory=dict)
    timestamp: float = field(default_factory=time.time)

    def to_dict(self) -> dict:
        """Serialize to a dict; the embedded task (if any) is serialized too."""
        serialized_task = self.task.to_dict() if self.task else None
        return {
            "handoff_id": self.handoff_id,
            "source_agent": self.source_agent,
            "target_agent": self.target_agent,
            "task": serialized_task,
            "reason": self.reason,
            "metadata": self.metadata,
            "timestamp": self.timestamp,
        }

    @classmethod
    def from_dict(cls, d: dict) -> "A2AHandoff":
        """Inverse of ``to_dict``; absent keys get a fresh id/timestamp or defaults."""
        raw_task = d.get("task")
        embedded = A2ATask.from_dict(raw_task) if raw_task else None
        return cls(
            handoff_id=d.get("handoff_id", f"hoff-{uuid.uuid4().hex[:8]}"),
            source_agent=d.get("source_agent", ""),
            target_agent=d.get("target_agent", ""),
            task=embedded,
            reason=d.get("reason", ""),
            metadata=d.get("metadata", {}),
            timestamp=d.get("timestamp", time.time()),
        )

    def to_json(self) -> str:
        """Serialize ``to_dict()`` as pretty-printed, non-ASCII-preserving JSON."""
        return json.dumps(self.to_dict(), ensure_ascii=False, indent=2)

    @classmethod
    def from_json(cls, json_str: str) -> "A2AHandoff":
        """Parse a JSON string produced by ``to_json``."""
        return cls.from_dict(json.loads(json_str))
439# ── A2A Session ────────────────────────────────
@dataclass
class A2ASession:
    """Multi-turn A2A conversation context: message history plus related tasks."""

    session_id: str = field(default_factory=lambda: f"sess-{uuid.uuid4().hex[:8]}")
    history: List[A2AMessage] = field(default_factory=list)
    tasks: List[A2ATask] = field(default_factory=list)
    metadata: Dict[str, Any] = field(default_factory=dict)
    created: float = field(default_factory=time.time)

    def add_message(self, msg: A2AMessage) -> None:
        """Append ``msg`` to the history."""
        self.history.append(msg)

    def add_task(self, task: A2ATask) -> None:
        """Append ``task`` to the session's task list."""
        self.tasks.append(task)

    def get_last_n_messages(self, n: int = 10) -> List[A2AMessage]:
        """Return a new list with the ``n`` most recent messages, oldest first.

        ``n <= 0`` yields an empty list. A bare ``history[-n:]`` would return
        the *entire* history for ``n == 0`` (because ``-0 == 0``) and would
        drop leading messages for negative ``n``.
        """
        if n <= 0:
            return []
        return self.history[-n:]

    def to_dict(self) -> dict:
        """Serialize the session, including every message and task, to a dict."""
        return {
            "session_id": self.session_id,
            "history": [m.to_dict() for m in self.history],
            "tasks": [t.to_dict() for t in self.tasks],
            "metadata": self.metadata,
            "created": self.created,
        }
469# ── A2A Client ─────────────────────────────────
class A2AClient:
    """A2A protocol client.

    Submits tasks to a remote agent, queries their status and fetches results.

    v1.3.13: retries, auth header, streaming subscription, pooled connections.

    Args:
        base_url: Root URL of the remote agent; a trailing ``/`` is stripped.
        timeout: Per-request timeout in seconds, passed to ``httpx.AsyncClient``.
        max_retries: Attempts made by ``send_task`` on connect errors or
            timeouts (at least one attempt is always made).
        retry_backoff: Base delay in seconds; after failed attempt ``k`` the
            client waits ``retry_backoff * 2**k`` before retrying.
        auth_token: If non-empty, sent as ``Authorization: Bearer <token>``.
        agent_name: Appended to the ``User-Agent`` header when non-empty.
    """

    def __init__(
        self,
        base_url: str,
        timeout: float = 30.0,
        max_retries: int = 3,
        retry_backoff: float = 1.0,
        auth_token: str = "",
        agent_name: str = "",
    ):
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        self.max_retries = max_retries
        self.retry_backoff = retry_backoff
        self.auth_token = auth_token
        self.agent_name = agent_name
        self._client: Any = None

    def _headers(self) -> dict[str, str]:
        """Default headers sent with every request."""
        h = {"User-Agent": f"AgentOS-A2A/{self.agent_name}" if self.agent_name else "AgentOS-A2A"}
        if self.auth_token:
            h["Authorization"] = f"Bearer {self.auth_token}"
        return h

    async def _get_client(self) -> Any:
        """Lazily create the shared ``httpx.AsyncClient`` (httpx imported on first use)."""
        if self._client is None:
            import httpx
            self._client = httpx.AsyncClient(
                timeout=self.timeout,
                headers=self._headers(),
                limits=httpx.Limits(max_keepalive_connections=10, max_connections=50),
            )
        return self._client

    async def close(self) -> None:
        """Close the pooled HTTP client; the next request creates a new one."""
        if self._client is not None:
            await self._client.aclose()
            self._client = None

    async def _retry(self, coro, *args, **kwargs):
        """Await ``coro(*args, **kwargs)``, retrying on connect errors and timeouts.

        Makes ``max(1, self.max_retries)`` attempts with exponential backoff
        between them and re-raises the last transient error. Any other
        exception propagates immediately without retry.
        """
        import asyncio
        import httpx

        # Guarantee one attempt: with max_retries <= 0 the old loop never ran
        # and ended in `raise None` (a TypeError).
        attempts = max(1, self.max_retries)
        for attempt in range(attempts):
            try:
                return await coro(*args, **kwargs)
            except (httpx.ConnectError, httpx.TimeoutException):
                if attempt == attempts - 1:
                    raise
                await asyncio.sleep(self.retry_backoff * (2 ** attempt))

    async def send_task(self, task: A2ATask) -> A2ATask:
        """POST /tasks — submit ``task`` and return the task as echoed by the server.

        Retried on connect errors/timeouts; an HTTP error status raises
        ``httpx.HTTPStatusError`` without retry.
        """
        client = await self._get_client()

        async def _do():
            resp = await client.post(f"{self.base_url}/tasks", json=task.to_dict())
            resp.raise_for_status()
            return A2ATask.from_dict(resp.json())

        return await self._retry(_do)

    async def get_task(self, task_id: str) -> Optional[A2ATask]:
        """GET /tasks/{id} — fetch a task's status and result.

        Returns ``None`` on *any* failure (HTTP error status, network error,
        malformed body), so "not found" is indistinguishable from a transient
        error here.
        """
        client = await self._get_client()
        try:
            resp = await client.get(f"{self.base_url}/tasks/{task_id}")
            resp.raise_for_status()
            return A2ATask.from_dict(resp.json())
        except Exception:
            return None

    async def cancel_task(self, task_id: str) -> bool:
        """DELETE /tasks/{id} — True for a status code below 400, False otherwise or on error."""
        client = await self._get_client()
        try:
            resp = await client.delete(f"{self.base_url}/tasks/{task_id}")
            return resp.status_code < 400
        except Exception:
            return False

    async def handoff(self, handoff: A2AHandoff) -> bool:
        """POST /handoff — True for a status code below 400, False otherwise or on error."""
        client = await self._get_client()
        try:
            resp = await client.post(f"{self.base_url}/handoff", json=handoff.to_dict())
            return resp.status_code < 400
        except Exception:
            return False

    async def wait_for_completion(
        self,
        task_id: str,
        poll_interval: float = 1.0,
        max_wait: float = 60.0,
    ) -> A2ATask:
        """Poll ``get_task`` until the task reaches a terminal state.

        The budget is measured on a monotonic clock, so time spent inside the
        HTTP requests counts toward ``max_wait`` and a non-positive
        ``poll_interval`` cannot spin forever.

        Raises:
            RuntimeError: ``get_task`` returned ``None``.
            TimeoutError: no terminal state within ``max_wait`` seconds.
        """
        import asyncio

        deadline = time.monotonic() + max_wait
        while time.monotonic() < deadline:
            task = await self.get_task(task_id)
            if task is None:
                raise RuntimeError(f"Task {task_id} not found")
            if task.is_terminal():
                return task
            await asyncio.sleep(poll_interval)
        raise TimeoutError(f"Task {task_id} did not complete within {max_wait}s")

    async def send_and_wait_for_reply(
        self,
        text: str,
        target_agent: str = "",
        poll_interval: float = 1.0,
        max_wait: float = 60.0,
    ) -> str:
        """Convenience: send a text task and wait for the reply text.

        Returns the output text, ``"[Error] <error>"`` for a failed task, or
        ``""`` when the task finished with neither output nor error.
        """
        task = new_task(text, target_agent=target_agent)
        task = await self.send_task(task)
        result = await self.wait_for_completion(task.task_id, poll_interval, max_wait)
        if result.output:
            return result.output.get_text()
        if result.error:
            return f"[Error] {result.error}"
        return ""

    @staticmethod
    def _parse_sse_message(lines: List[str]) -> dict[str, str]:
        """Parse the non-blank lines of one SSE event block into a dict.

        Applies the SSE field rules: ``field: value`` with the space after the
        colon optional, lines starting with ``:`` are comments, and multiple
        ``data`` lines are joined with newlines. Only ``event`` and ``data``
        are kept; other fields (``id``, ``retry``) are ignored.
        """
        event_data: dict[str, str] = {}
        data_lines: List[str] = []
        for line in lines:
            if not line or line.startswith(":"):
                continue
            name, _, value = line.partition(":")
            if value.startswith(" "):
                value = value[1:]
            if name == "event":
                event_data["event"] = value
            elif name == "data":
                data_lines.append(value)
        if data_lines:
            event_data["data"] = "\n".join(data_lines)
        return event_data

    async def subscribe_task_stream(
        self,
        task_id: str,
        on_event: Callable[[dict], Any] | None = None,
    ) -> None:
        """Subscribe to GET /tasks/{id}/stream (Server-Sent Events).

        Each event block (terminated by a blank line) is parsed into
        ``{"event": ..., "data": ...}`` (keys present only when sent) and
        passed to ``on_event``; coroutine callbacks are awaited. Blocks with
        neither field, such as keep-alive comments, are skipped. Returns when
        the server closes the stream.
        """
        import inspect

        client = await self._get_client()
        async with client.stream("GET", f"{self.base_url}/tasks/{task_id}/stream") as resp:
            resp.raise_for_status()
            pending: List[str] = []
            async for raw_line in resp.aiter_lines():
                # Strip any terminator the httpx version may leave on the line
                # so CRLF streams parse exactly like LF ones.
                line = raw_line.rstrip("\r\n")
                if line:
                    pending.append(line)
                    continue
                event_data = self._parse_sse_message(pending)
                pending = []
                if event_data and on_event is not None:
                    outcome = on_event(event_data)
                    if inspect.isawaitable(outcome):
                        await outcome
632# ── A2A Server ─────────────────────────────────
class A2AServer:
    """A2A protocol server: receives and runs agent-to-agent task requests.

    Usage::

        server = A2AServer()
        server.register_handler("my-agent", my_handler)
        # Mount onto FastAPI:
        app = FastAPI()
        server.mount_routes(app)

    Args:
        task_store: Task persistence backend providing ``save_task``,
            ``get_task``, ``list_tasks(state=...)`` and ``cleanup_terminal``.
            An ``InMemoryTaskStore`` is created lazily when omitted.
        stream_manager: Optional SSE backend exposing
            ``notify_state_change(task, old_state)`` and ``get_session(task_id)``.
        require_auth: When true, requests must carry a bearer token that is
            contained in ``auth_tokens``.
        auth_tokens: Accepted bearer tokens.
    """

    def __init__(
        self,
        task_store=None,
        stream_manager=None,
        require_auth: bool = False,
        auth_tokens: List[str] | None = None,
    ):
        self._handlers: Dict[str, Callable] = {}
        self._task_store = task_store
        self._stream_manager = stream_manager
        self.require_auth = require_auth
        self.auth_tokens: set[str] = set(auth_tokens or [])
        self._default_store_created = False

    def _ensure_store(self):
        """Create the default in-memory store on first use."""
        if self._task_store is None:
            from agentos.protocols.a2a_store import InMemoryTaskStore
            self._task_store = InMemoryTaskStore()
            self._default_store_created = True

    @property
    def task_store(self):
        """The task store, created lazily if none was supplied."""
        self._ensure_store()
        return self._task_store

    def _is_authorized(self, auth_token: str) -> bool:
        """True when auth is disabled or ``auth_token`` is an accepted token."""
        return not self.require_auth or auth_token in self.auth_tokens

    def register_handler(
        self,
        agent_name: str,
        handler: Callable,
    ) -> None:
        """Register the handler for ``agent_name``.

        Handler signature: ``async def handler(task: A2ATask) -> A2AMessage``
        (a synchronous callable returning the message also works).
        """
        self._handlers[agent_name] = handler

    async def process_task(self, body: dict, auth_token: str = "") -> dict:
        """Parse an incoming task, run its target handler and return the result.

        The handler is selected by ``body["meta"]["target_agent"]``. Handler
        exceptions mark the task FAILED with the exception text. The final
        task is persisted and returned as a dict.

        Unauthorized requests receive a FAILED task but are *not* persisted:
        saving them would let an unauthenticated caller overwrite any stored
        task simply by reusing its ``task_id``.
        """
        import inspect

        self._ensure_store()

        if not self._is_authorized(auth_token):
            task = A2ATask.from_dict(body)
            task.fail("Unauthorized: invalid or missing A2A auth token")
            return task.to_dict()

        task = A2ATask.from_dict(body)
        # State the task was in before the transition being reported; kept
        # up to date so each stream notification carries the right origin.
        old_state = task.state
        self._task_store.save_task(task)

        target = (body.get("meta") or {}).get("target_agent", "")
        handler = self._handlers.get(target)

        if not handler and target:
            task.fail(f"No handler for agent '{target}'")
        elif not handler:
            task.fail("No target agent specified in meta")
        else:
            try:
                task.start_working()
                submitted_state, old_state = old_state, task.state
                if self._stream_manager:
                    await self._stream_manager.notify_state_change(task, submitted_state)
                result = handler(task)
                output = await result if inspect.isawaitable(result) else result
                task.complete(output)
            except Exception as e:
                task.fail(str(e))

        if self._stream_manager:
            await self._stream_manager.notify_state_change(task, old_state)
        self._task_store.save_task(task)
        return task.to_dict()

    def get_task(self, task_id: str) -> Optional[A2ATask]:
        """Look up a task in the store."""
        return self.task_store.get_task(task_id)

    def list_tasks(self, state: TaskState | None = None) -> list[A2ATask]:
        """List stored tasks, optionally filtered by ``state``."""
        return self.task_store.list_tasks(state=state)

    def cleanup_old(self, max_age_seconds: float = 3600.0) -> int:
        """Delegate to the store's ``cleanup_terminal`` and return its result."""
        return self.task_store.cleanup_terminal(max_age_seconds)

    # ── FastAPI route builder ─────────────────────

    def mount_routes(self, app, prefix: str = "") -> None:
        """Mount the standard A2A routes on a FastAPI app.

        Routes:
            POST   {prefix}/tasks               create (and run) a task
            GET    {prefix}/tasks               list tasks, optional ``?state=``
            GET    {prefix}/tasks/{id}          fetch a task
            DELETE {prefix}/tasks/{id}          cancel a task
            GET    {prefix}/tasks/{id}/stream   SSE event stream
            POST   {prefix}/handoff             accept a task handoff

        When ``require_auth`` is set, every route checks the bearer token:
        ``POST /tasks`` reports a rejection as a FAILED task, the other routes
        answer 401.

        Raises:
            ImportError: FastAPI/Starlette are not installed.
        """
        try:
            from fastapi import HTTPException, Request
            from starlette.responses import StreamingResponse
        except ImportError as exc:
            raise ImportError("FastAPI and Starlette are required for mount_routes()") from exc

        server = self

        def _bearer_token(request) -> str:
            return request.headers.get("Authorization", "").removeprefix("Bearer ")

        def _check_auth(request) -> None:
            if not server._is_authorized(_bearer_token(request)):
                raise HTTPException(status_code=401, detail="Unauthorized")

        def _bind_request(endpoint):
            # `from __future__ import annotations` turns every annotation into
            # a string, and FastAPI resolves those strings against the module
            # globals, where the locally imported `Request` is not visible.
            # Bind the real class so FastAPI injects the request object.
            endpoint.__annotations__["request"] = Request
            return endpoint

        @app.post(f"{prefix}/tasks")
        @_bind_request
        async def create_task(request: Request):
            body = await request.json()
            return await server.process_task(body, auth_token=_bearer_token(request))

        @app.get(f"{prefix}/tasks")
        @_bind_request
        async def list_tasks_endpoint(request: Request, state: str = ""):
            _check_auth(request)
            try:
                task_state = TaskState(state) if state else None
            except ValueError:
                raise HTTPException(status_code=400, detail=f"Invalid task state: {state}") from None
            return [t.to_dict() for t in server.list_tasks(state=task_state)]

        @app.get(f"{prefix}/tasks/{{task_id}}")
        @_bind_request
        async def get_task_endpoint(task_id: str, request: Request):
            _check_auth(request)
            task = server.get_task(task_id)
            if task is None:
                raise HTTPException(status_code=404, detail="Task not found")
            return task.to_dict()

        @app.delete(f"{prefix}/tasks/{{task_id}}")
        @_bind_request
        async def cancel_task_endpoint(task_id: str, request: Request):
            _check_auth(request)
            task = server.get_task(task_id)
            if task is None:
                raise HTTPException(status_code=404, detail="Task not found")
            if task.is_terminal():
                raise HTTPException(status_code=400, detail="Task already terminal")
            old_state = task.state
            task.cancel()
            server.task_store.save_task(task)
            # Same notification process_task emits for its transitions.
            if server._stream_manager:
                await server._stream_manager.notify_state_change(task, old_state)
            return {"status": "cancelled", "task_id": task_id}

        @app.get(f"{prefix}/tasks/{{task_id}}/stream")
        @_bind_request
        async def stream_task(task_id: str, request: Request):
            _check_auth(request)
            stream = server._stream_manager
            if stream is None:
                raise HTTPException(status_code=501, detail="Streaming not enabled")

            async def event_generator():
                session = stream.get_session(task_id)
                if session is None:
                    yield f"event: error\ndata: {{\"error\": \"Session not found\"}}\n\n"
                    return
                sub = session.subscribe()
                try:
                    async for evt in session.iter_events(sub):
                        yield session.to_sse(evt)
                except Exception:
                    # Headers are already sent, so no HTTP error is possible;
                    # tell the client explicitly instead of ending silently.
                    yield "event: error\ndata: {\"error\": \"Stream failed\"}\n\n"

            return StreamingResponse(
                event_generator(),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                    "X-Accel-Buffering": "no",
                },
            )

        @app.post(f"{prefix}/handoff")
        @_bind_request
        async def handoff_endpoint(request: Request):
            _check_auth(request)
            body = await request.json()
            handoff = A2AHandoff.from_dict(body)
            if handoff.task:
                # Use the property: `_task_store` may still be None here.
                server.task_store.save_task(handoff.task)
            return {"status": "received", "handoff_id": handoff.handoff_id}
824# ── 便捷函数 ───────────────────────────────────
def new_task(text: str, target_agent: str = "", **meta) -> A2ATask:
    """Create a task whose input is a single user text message.

    ``target_agent`` and any extra keyword arguments are stored in the task's
    ``meta`` dict, with ``target_agent`` first.
    """
    message = A2AMessage.user_text(text)
    task_meta: Dict[str, Any] = {"target_agent": target_agent}
    task_meta.update(meta)
    return A2ATask(input=message, meta=task_meta)
def new_handoff(
    task: A2ATask,
    source: str,
    target: str,
    reason: str = "",
) -> A2AHandoff:
    """Build an ``A2AHandoff`` moving ``task`` from agent ``source`` to ``target``."""
    handoff_kwargs = dict(
        source_agent=source,
        target_agent=target,
        task=task,
        reason=reason,
    )
    return A2AHandoff(**handoff_kwargs)
848# ==========================================================================
849# Compat: lightweight AgentRegistry for test_core.py
850# ==========================================================================
851from dataclasses import dataclass, field
852from typing import List, Optional, Dict
853import time
@dataclass
class AgentRecord:
    """Registry entry for one agent (lightweight compat type for test_core.py)."""

    agent_id: str
    capabilities: List[str] = field(default_factory=list)
    endpoint: str = ""
    load: float = 0.0
    _last_heartbeat: float = field(default_factory=time.time, repr=False)

    @property
    def healthy(self) -> bool:
        """True if the last heartbeat is less than 60 seconds old."""
        age = time.time() - self._last_heartbeat
        return age < 60.0
class AgentRegistry:
    """In-memory agent registry with capability lookup and heartbeats.

    Args:
        name: Registry name (informational only).
        default_ttl: Seconds after the last heartbeat at which an agent counts
            as stale; stale agents are skipped by ``pick_least_loaded``.
    """

    def __init__(self, name: str = "default", default_ttl: float = 60.0):
        self._records: Dict[str, AgentRecord] = {}
        self.name = name
        self.default_ttl = default_ttl

    def register(self, record: AgentRecord):
        """Add or replace ``record`` (keyed by ``agent_id``), stamping a fresh heartbeat."""
        record._last_heartbeat = time.time()
        self._records[record.agent_id] = record

    def get(self, agent_id: str) -> Optional[AgentRecord]:
        """Return the record for ``agent_id``, or None if unknown."""
        return self._records.get(agent_id)

    def find_by_capability(self, capability: str) -> List[AgentRecord]:
        """Return every record listing ``capability``, regardless of heartbeat age."""
        return [r for r in self._records.values() if capability in r.capabilities]

    def heartbeat(self, agent_id: str):
        """Refresh the heartbeat of ``agent_id``; unknown ids are ignored."""
        r = self._records.get(agent_id)
        if r:
            r._last_heartbeat = time.time()

    def _is_fresh(self, record: AgentRecord) -> bool:
        # Heartbeat age strictly below the registry's TTL.
        return (time.time() - record._last_heartbeat) < self.default_ttl

    def pick_least_loaded(self, capability: str) -> Optional[AgentRecord]:
        """Return the lowest-``load`` non-stale agent with ``capability``, or None.

        Agents whose last heartbeat is ``default_ttl`` seconds old or older are
        skipped, so work is not routed to agents that stopped reporting (before
        this check ``default_ttl`` was unused). Ties go to the earliest entry in
        registry order, since ``min`` keeps the first minimum.
        """
        candidates = [r for r in self.find_by_capability(capability) if self._is_fresh(r)]
        if not candidates:
            return None
        return min(candidates, key=lambda r: r.load)