Coverage for agentos/protocols/registry.py: 39%
361 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 16:36 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 16:36 +0800
"""
AgentOS v1.14.0 — A2A Agent Registry (service discovery).

Agent registry implementation:
- Service registration / deregistration / heartbeats
- Capability advertisement (Agent Card)
- Service discovery (query by capability / role / tag)
- Health checking: agents whose heartbeat times out are marked OFFLINE and
  dropped from the discovery indices
"""
from __future__ import annotations

import asyncio
import json
import logging
import time
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Coroutine, Dict, List, Optional, Set, Union
22# ── Agent Info ───────────────────────────────
@dataclass
class AgentInfo:
    """Minimal agent description accepted by ``AgentRegistry.register``.

    A lighter alternative to ``DiscoveryCard``: ``register`` converts it into
    a card, using ``agent_id`` as the display name, ``endpoint`` as the card
    URL and wrapping every capability string in a ``DiscoveryCapability``.
    """

    agent_id: str
    endpoint: str = ""  # becomes DiscoveryCard.url on registration
    capabilities: List[str] = field(default_factory=list)  # capability names
    version: str = "1.0.0"
    # The fields below are not consulted by AgentRegistry.register.
    transport: str = "grpc"
    status: str = "active"
    heartbeat_ts: float = field(default_factory=time.time)
37# ── Agent Card ───────────────────────────────
@dataclass
class DiscoveryCapability:
    """Agent capability description used by the registry for discovery.

    Distinct from ``protocols.contracts.AgentCapability`` (contract-oriented);
    this class is what the registry indexes and matches on (by ``name``).
    """

    name: str  # capability name (e.g. "code_review", "pdf_parsing")
    description: str = ""
    version: str = "1.0.0"
    input_schema: Optional[Dict[str, Any]] = None
    performance: Dict[str, Any] = field(default_factory=dict)  # performance metrics

    def to_dict(self) -> dict:
        """Serialize to a plain dict (inverse of :meth:`from_dict`)."""
        return {
            "name": self.name,
            "description": self.description,
            "version": self.version,
            "input_schema": self.input_schema,
            "performance": self.performance,
        }

    @classmethod
    def from_dict(cls, d: dict) -> "DiscoveryCapability":
        """Build a capability from a dict such as one produced by :meth:`to_dict`.

        Missing keys fall back to the field defaults. An explicit
        ``"performance": None`` (common in JSON payloads) is normalized to an
        empty dict so the field always holds a dict.
        """
        performance = d.get("performance")
        return cls(
            name=d.get("name", ""),
            description=d.get("description", ""),
            version=d.get("version", "1.0.0"),
            input_schema=d.get("input_schema"),
            performance={} if performance is None else performance,
        )
@dataclass
class DiscoveryCard:
    """Agent discovery card — the self-description an agent publishes to the registry.

    Serialized with the Google A2A AgentCard camelCase keys (see
    :meth:`to_dict`). Distinct from ``protocols.agent_card.AgentCard``
    (local broadcast); this class targets remote-registry service discovery.
    """

    agent_id: str
    name: str
    description: str = ""
    url: str = ""  # the agent's service endpoint
    version: str = "1.0.0"
    capabilities: List[DiscoveryCapability] = field(default_factory=list)
    # Organization/team info; a "role" key here feeds the registry's role index.
    provider: Dict[str, Any] = field(default_factory=dict)
    default_input_modes: List[str] = field(default_factory=lambda: ["text"])
    default_output_modes: List[str] = field(default_factory=lambda: ["text"])
    skills: List[Dict[str, Any]] = field(default_factory=list)
    supports_streaming: bool = False
    supports_handoff: bool = True
    max_context_length: int = 128000
    preferred_model: str = ""
    service_tier: str = "standard"  # standard | premium | enterprise
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Serialize to a dict using A2A camelCase keys (inverse of :meth:`from_dict`)."""
        return {
            "agent_id": self.agent_id,
            "name": self.name,
            "description": self.description,
            "url": self.url,
            "version": self.version,
            "capabilities": [c.to_dict() for c in self.capabilities],
            "provider": self.provider,
            "defaultInputModes": self.default_input_modes,
            "defaultOutputModes": self.default_output_modes,
            "skills": self.skills,
            "supportsStreaming": self.supports_streaming,
            "supportsHandoff": self.supports_handoff,
            "maxContextLength": self.max_context_length,
            "preferredModel": self.preferred_model,
            "serviceTier": self.service_tier,
            "metadata": self.metadata,
        }

    @classmethod
    def from_dict(cls, d: dict) -> "DiscoveryCard":
        """Build a card from a dict such as one produced by :meth:`to_dict`.

        Missing keys fall back to the field defaults. For the container
        fields (capabilities, provider, modes, skills, metadata) an explicit
        ``None`` also falls back to the default instead of crashing (for
        ``capabilities``) or storing ``None``. Each capability may be a dict
        (see ``DiscoveryCapability.from_dict``) or a bare capability name.
        """

        def present(key: str, default: Any) -> Any:
            # Like d.get(key, default), but an explicit None maps to default too.
            value = d.get(key)
            return default if value is None else value

        capabilities = [
            DiscoveryCapability(name=c) if isinstance(c, str)
            else DiscoveryCapability.from_dict(c)
            for c in present("capabilities", [])
        ]
        return cls(
            agent_id=d.get("agent_id", ""),
            name=d.get("name", ""),
            description=d.get("description", ""),
            url=d.get("url", ""),
            version=d.get("version", "1.0.0"),
            capabilities=capabilities,
            provider=present("provider", {}),
            default_input_modes=present("defaultInputModes", ["text"]),
            default_output_modes=present("defaultOutputModes", ["text"]),
            skills=present("skills", []),
            supports_streaming=d.get("supportsStreaming", False),
            supports_handoff=d.get("supportsHandoff", True),
            max_context_length=d.get("maxContextLength", 128000),
            preferred_model=d.get("preferredModel", ""),
            service_tier=d.get("serviceTier", "standard"),
            metadata=present("metadata", {}),
        )

    def has_capability(self, cap_name: str) -> bool:
        """Return True if any of the card's capabilities is named *cap_name*."""
        return any(c.name == cap_name for c in self.capabilities)
148# ── Registry Entry ──────────────────────────
class AgentStatus(str, Enum):
    """Runtime state of a registered agent.

    Mixes in ``str`` so each member compares equal to its plain string value
    (e.g. ``AgentStatus.ONLINE == "online"``).
    """

    # Set on registration, and when a heartbeat/re-registration revives an
    # OFFLINE agent.
    ONLINE = "online"
    # BUSY / DEGRADED are never assigned within this module; presumably they
    # are set by callers — TODO confirm.
    BUSY = "busy"
    DEGRADED = "degraded"
    # Set by the registry's health sweep when an agent's heartbeat times out.
    OFFLINE = "offline"
@dataclass
class RegistryEntry:
    """One agent's record inside the registry.

    Pairs the agent's :class:`DiscoveryCard` with the liveness, load and
    statistics the registry tracks for it.
    """

    card: DiscoveryCard
    status: AgentStatus = AgentStatus.ONLINE
    # Both timestamps are wall-clock seconds from time.time().
    registered_at: float = field(default_factory=time.time)
    last_heartbeat: float = field(default_factory=time.time)
    load: float = 0.0  # expected range 0.0–1.0
    task_count: int = 0  # number of completed tasks
    error_count: int = 0
    avg_response_ms: float = 0.0
    tags: List[str] = field(default_factory=list)
    endpoint_health: Dict[str, Any] = field(default_factory=dict)

    def is_healthy(self, heartbeat_timeout: float = 30.0) -> bool:
        """Return True when the last heartbeat is younger than *heartbeat_timeout* seconds."""
        silence = time.time() - self.last_heartbeat
        return silence < heartbeat_timeout

    @property
    def uptime_seconds(self) -> float:
        """Seconds elapsed since the agent was registered."""
        now = time.time()
        return now - self.registered_at

    @property
    def endpoint(self) -> str:
        """The agent's service endpoint (alias for ``card.url``)."""
        card = self.card
        return card.url
189# ── Agent Registry ──────────────────────────
class AgentRegistry:
    """A2A agent registry (service discovery).

    Features:
    - register: register an agent; it stays healthy while it keeps sending
      heartbeats
    - discover: find agents by capability / role / tag
    - health check: agents whose heartbeat times out are marked OFFLINE and
      removed from the discovery indices (their entries are kept, so a later
      heartbeat or re-registration revives them)
    - subscribe: subscribe to registry events

    Usage:
        registry = AgentRegistry(heartbeat_timeout=30)
        registry.register(agent_card)

        # Find agents that can handle code_review
        agents = registry.discover(capability="code_review")

        # Find by tags
        agents = registry.discover(tags=["production", "high-priority"])
    """

    _logger = logging.getLogger(__name__)

    def __init__(
        self,
        heartbeat_timeout: float = 30.0,
        health_check_interval: float = 10.0,
        auto_cleanup: bool = True,
    ):
        """
        Args:
            heartbeat_timeout: seconds without a heartbeat after which an
                agent counts as unhealthy.
            health_check_interval: seconds between background health sweeps.
            auto_cleanup: whether ``start()`` launches the background sweep.
        """
        self._entries: Dict[str, RegistryEntry] = {}
        # Reverse indices: key -> set of agent_ids. Keys with no agents left
        # are pruned by _cleanup_indices.
        self._capability_index: Dict[str, Set[str]] = {}
        self._tag_index: Dict[str, Set[str]] = {}
        self._role_index: Dict[str, Set[str]] = {}
        self.heartbeat_timeout = heartbeat_timeout
        self.health_check_interval = health_check_interval
        self.auto_cleanup = auto_cleanup

        # Event name -> subscriber callbacks.
        self._subscribers: Dict[str, List[Callable]] = {
            "register": [],
            "deregister": [],
            "status_change": [],
            "heartbeat": [],
        }

        # Background health-check task (created by start()).
        self._health_check_task: Optional[asyncio.Task] = None
        self._running = False

    # ── Lifecycle ─────────────────────────────

    async def start(self) -> None:
        """Start the registry (launch the health-check loop when auto_cleanup is on)."""
        if self._running:
            return
        self._running = True
        if self.auto_cleanup:
            self._health_check_task = asyncio.create_task(self._health_check_loop())

    async def stop(self) -> None:
        """Stop the registry and cancel the health-check loop, if running."""
        self._running = False
        if self._health_check_task:
            self._health_check_task.cancel()
            try:
                await self._health_check_task
            except asyncio.CancelledError:
                pass
            self._health_check_task = None

    async def _health_check_loop(self) -> None:
        """Run a health sweep every ``health_check_interval`` seconds until stopped."""
        while self._running:
            try:
                self._check_all_health()
            except Exception:
                # Keep the loop alive; the sweep is retried next interval.
                self._logger.exception("Agent registry health check failed")
            await asyncio.sleep(self.health_check_interval)

    def _check_all_health(self) -> None:
        """Mark agents with timed-out heartbeats OFFLINE and drop them from the indices."""
        went_offline = []
        for agent_id, entry in list(self._entries.items()):
            if entry.status != AgentStatus.OFFLINE and not entry.is_healthy(
                self.heartbeat_timeout
            ):
                self._update_status(agent_id, AgentStatus.OFFLINE)
                went_offline.append(agent_id)

        for agent_id in went_offline:
            self._cleanup_indices(agent_id)

    # ── Registration ─────────────────────────

    @staticmethod
    def _card_from_info(info: "AgentInfo") -> DiscoveryCard:
        """Convert a lightweight AgentInfo into a DiscoveryCard."""
        return DiscoveryCard(
            agent_id=info.agent_id,
            name=info.agent_id,
            description=f"Agent {info.agent_id} v{info.version}",
            version=info.version,
            capabilities=[DiscoveryCapability(name=c) for c in info.capabilities],
            url=info.endpoint,
        )

    def register(
        self,
        card_or_info: Union[DiscoveryCard, "AgentInfo"],
        tags: Optional[List[str]] = None,
    ) -> str:
        """Register an agent, or refresh an existing registration.

        Args:
            card_or_info: the agent's DiscoveryCard, or a lightweight
                AgentInfo that is converted into one.
            tags: custom tags for ``discover(tags=...)``. On re-registration
                an empty/omitted value keeps the existing tags.

        Returns:
            The agent_id.

        Re-registering a known agent replaces its card, refreshes its
        heartbeat, rebuilds its index entries and revives it to ONLINE if it
        was OFFLINE. The "register" event is emitted for new agents only.
        """
        if isinstance(card_or_info, AgentInfo):
            card = self._card_from_info(card_or_info)
        else:
            card = card_or_info

        agent_id = card.agent_id
        now = time.time()

        entry = self._entries.get(agent_id)
        if entry is not None:
            entry.card = card
            entry.last_heartbeat = now
            entry.tags = tags or entry.tags
            # Capabilities/tags/role may have changed, and a health sweep may
            # already have removed this agent from the indices: rebuild them.
            self._cleanup_indices(agent_id)
            self._build_indices(agent_id, card, entry.tags)
            if entry.status == AgentStatus.OFFLINE:
                self._update_status(agent_id, AgentStatus.ONLINE)
            return agent_id

        tag_list = tags or []
        entry = RegistryEntry(
            card=card,
            status=AgentStatus.ONLINE,
            registered_at=now,
            last_heartbeat=now,
            tags=tag_list,
        )
        self._entries[agent_id] = entry
        self._build_indices(agent_id, card, tag_list)
        self._emit("register", {"agent_id": agent_id, "card": card.to_dict()})
        return agent_id

    def deregister(self, agent_id: str) -> bool:
        """Remove an agent. Returns False if it was not registered."""
        if agent_id not in self._entries:
            return False
        entry = self._entries.pop(agent_id)
        self._cleanup_indices(agent_id)
        self._emit("deregister", {"agent_id": agent_id, "card": entry.card.to_dict()})
        return True

    def heartbeat(self, agent_id: str, load: Optional[float] = None) -> bool:
        """Record a heartbeat from an agent.

        Args:
            agent_id: Agent ID
            load: current load; clamped into 0.0–1.0

        Returns:
            True if the agent is registered, else False.

        An OFFLINE agent is re-indexed and switched back to ONLINE.
        """
        entry = self._entries.get(agent_id)
        if entry is None:
            return False
        entry.last_heartbeat = time.time()
        if load is not None:
            entry.load = max(0.0, min(1.0, load))
        if entry.status == AgentStatus.OFFLINE:
            # The health sweep dropped this agent from the indices; restore
            # them so it becomes discoverable again.
            self._build_indices(agent_id, entry.card, entry.tags)
            self._update_status(agent_id, AgentStatus.ONLINE)
        self._emit("heartbeat", {"agent_id": agent_id, "load": entry.load})
        return True

    def update_stats(
        self,
        agent_id: str,
        task_count: Optional[int] = None,
        error_count: Optional[int] = None,
        avg_response_ms: Optional[float] = None,
    ) -> bool:
        """Overwrite the given statistics of an agent. Returns False if unknown."""
        if agent_id not in self._entries:
            return False
        entry = self._entries[agent_id]
        if task_count is not None:
            entry.task_count = task_count
        if error_count is not None:
            entry.error_count = error_count
        if avg_response_ms is not None:
            entry.avg_response_ms = avg_response_ms
        return True

    # ── Discovery ────────────────────────────

    def discover(
        self,
        capability: Optional[str] = None,
        tags: Optional[List[str]] = None,
        role: Optional[str] = None,
        status: Optional[AgentStatus] = None,
        service_tier: Optional[str] = None,
        supports_streaming: Optional[bool] = None,
        min_health: bool = True,
        limit: int = 50,
    ) -> List[RegistryEntry]:
        """Service discovery: filter agents by the given criteria.

        Args:
            capability: capability name to match
            tags: required tags (AND logic)
            role: provider role to match
            status: required status
            service_tier: required service tier
            supports_streaming: required streaming support
            min_health: only return agents whose heartbeat is fresh
            limit: maximum number of results

        Returns:
            Matching RegistryEntry objects ordered by status (ONLINE first),
            then ascending load, then agent_id (for a deterministic order).
        """
        candidates: Set[str] = set(self._entries)

        # Index-backed filters narrow the candidate set first.
        if capability:
            candidates &= self._capability_index.get(capability, set())
        for tag in tags or ():
            candidates &= self._tag_index.get(tag, set())
        if role:
            candidates &= self._role_index.get(role, set())

        results = []
        for agent_id in candidates:
            entry = self._entries[agent_id]
            if status and entry.status != status:
                continue
            if min_health and not entry.is_healthy(self.heartbeat_timeout):
                continue
            if service_tier and entry.card.service_tier != service_tier:
                continue
            if (
                supports_streaming is not None
                and entry.card.supports_streaming != supports_streaming
            ):
                continue
            results.append(entry)

        status_order = {
            AgentStatus.ONLINE: 0,
            AgentStatus.BUSY: 1,
            AgentStatus.DEGRADED: 2,
            AgentStatus.OFFLINE: 3,
        }
        results.sort(
            key=lambda e: (status_order.get(e.status, 9), e.load, e.card.agent_id)
        )
        return results[:limit]

    def discover_one(
        self,
        capability: Optional[str] = None,
        load_balanced: bool = True,
        **kwargs,
    ) -> Optional[RegistryEntry]:
        """Return the single best matching agent, or None.

        Args:
            capability: capability name to match
            load_balanced: prefer the least-loaded agent. ``discover()``
                already orders by ascending load, so the first result is used
                either way.
            **kwargs: extra ``discover()`` filters. ``status`` (default
                ONLINE), ``min_health`` (default True) and ``limit`` may be
                overridden here.
        """
        params: Dict[str, Any] = {
            "status": AgentStatus.ONLINE,
            "min_health": True,
            "limit": 10,
        }
        params.update(kwargs)
        results = self.discover(capability=capability, **params)
        return results[0] if results else None

    def get_agent(self, agent_id: str) -> Optional[RegistryEntry]:
        """Return the entry for *agent_id*, or None."""
        return self._entries.get(agent_id)

    def list_all(self, include_offline: bool = False) -> List[RegistryEntry]:
        """List registered agents.

        Args:
            include_offline: when False (default), skip agents whose status
                is OFFLINE or whose heartbeat has timed out — the same notion
                of "offline" used by ``get_stats()`` and ``online_count``.
        """
        entries = list(self._entries.values())
        if include_offline:
            return entries
        return [
            e for e in entries
            if e.status != AgentStatus.OFFLINE
            and e.is_healthy(self.heartbeat_timeout)
        ]

    # ── Stats ─────────────────────────────────

    @property
    def total_agents(self) -> int:
        """Number of registered agents (any status)."""
        return len(self._entries)

    @property
    def online_count(self) -> int:
        """Number of ONLINE agents with a fresh heartbeat."""
        return sum(
            1 for e in self._entries.values()
            if e.status == AgentStatus.ONLINE and e.is_healthy(self.heartbeat_timeout)
        )

    def get_stats(self) -> Dict[str, Any]:
        """Return registry statistics.

        An agent counts as offline when its status is OFFLINE or its
        heartbeat has timed out; otherwise it is counted under its status.
        """
        online = 0
        busy = 0
        degraded = 0
        offline = 0
        for e in self._entries.values():
            if e.status == AgentStatus.OFFLINE or not e.is_healthy(self.heartbeat_timeout):
                offline += 1
            elif e.status == AgentStatus.ONLINE:
                online += 1
            elif e.status == AgentStatus.BUSY:
                busy += 1
            elif e.status == AgentStatus.DEGRADED:
                degraded += 1

        return {
            "total": self.total_agents,
            "online": online,
            "busy": busy,
            "degraded": degraded,
            "offline": offline,
            "capabilities": list(self._capability_index.keys()),
            "tags": list(self._tag_index.keys()),
            "roles": list(self._role_index.keys()),
        }

    # ── Events ────────────────────────────────

    def subscribe(
        self,
        event: str,
        callback: Callable[[Dict[str, Any]], Any],
    ) -> None:
        """Subscribe to a registry event.

        Args:
            event: 'register' | 'deregister' | 'status_change' | 'heartbeat'
                (any other name is silently ignored)
            callback: called with the event payload dict
        """
        if event in self._subscribers:
            self._subscribers[event].append(callback)

    def _emit(self, event: str, data: Dict[str, Any]) -> None:
        """Invoke every subscriber of *event* with *data* (best-effort)."""
        for cb in list(self._subscribers.get(event, [])):
            try:
                cb(data)
            except Exception:
                # One failing subscriber must not break the registry
                # operation or the remaining subscribers.
                self._logger.exception("Agent registry subscriber for %r failed", event)

    # ── Internal Helpers ─────────────────────

    def _update_status(self, agent_id: str, new_status: AgentStatus) -> None:
        """Set an agent's status and emit 'status_change' if it actually changed."""
        entry = self._entries.get(agent_id)
        if entry is None:
            return
        old_status = entry.status
        entry.status = new_status
        if old_status != new_status:
            self._emit("status_change", {
                "agent_id": agent_id,
                "old_status": old_status.value,
                "new_status": new_status.value,
            })

    def _build_indices(
        self,
        agent_id: str,
        card: DiscoveryCard,
        tags: List[str],
    ) -> None:
        """Add *agent_id* to the capability, tag and role reverse indices."""
        for cap in card.capabilities:
            self._capability_index.setdefault(cap.name, set()).add(agent_id)
        for tag in tags:
            self._tag_index.setdefault(tag, set()).add(agent_id)
        role = card.provider.get("role", "")
        if role:
            self._role_index.setdefault(role, set()).add(agent_id)

    def _cleanup_indices(self, agent_id: str) -> None:
        """Remove *agent_id* from all indices, pruning keys left without agents."""
        for index in (self._capability_index, self._tag_index, self._role_index):
            for key in [k for k, ids in index.items() if agent_id in ids]:
                ids = index[key]
                ids.discard(agent_id)
                if not ids:
                    del index[key]
626# ── A2A Registry Bridge ─────────────────────
class A2ARegistryBridge:
    """Bridge between an AgentRegistry and A2A clients.

    Discovers agents through the registry and creates one cached
    ``A2AClient`` per agent id.

    Usage:
        bridge = A2ARegistryBridge(registry)
        client = await bridge.get_client(capability="code_review")
        result = await client.send_and_wait_for_reply("Review this code...")
    """

    def __init__(self, registry: AgentRegistry):
        self._registry = registry
        self._clients: Dict[str, Any] = {}  # agent_id -> A2AClient

    async def get_client(
        self,
        capability: Optional[str] = None,
        agent_id: Optional[str] = None,
        **kwargs,
    ) -> Any:
        """Return an A2A client for a registered agent.

        Args:
            capability: discover the best agent with this capability (used
                when ``agent_id`` is not given)
            agent_id: use this specific agent
            **kwargs: extra filters forwarded to ``discover_one``

        Returns:
            The cached client for the chosen agent, or a newly created one.
            If the agent's endpoint changes, call :meth:`invalidate` to force
            a fresh client.

        Raises:
            RuntimeError: if no matching agent is available.
        """
        if agent_id and agent_id in self._clients:
            return self._clients[agent_id]

        if agent_id:
            entry = self._registry.get_agent(agent_id)
        else:
            entry = self._registry.discover_one(capability=capability, **kwargs)

        if not entry:
            target = f"agent_id={agent_id}" if agent_id else f"capability={capability}"
            raise RuntimeError(f"No available agent found for {target}")

        # Reuse an existing client for this agent instead of replacing (and
        # leaking) it.
        cached = self._clients.get(entry.card.agent_id)
        if cached is not None:
            return cached

        # Imported lazily, only when a new client is actually needed.
        from agentos.protocols.a2a import A2AClient

        client = A2AClient(
            base_url=entry.card.url,
            agent_name=entry.card.name,
        )
        self._clients[entry.card.agent_id] = client
        return client

    async def close_all(self) -> None:
        """Close every cached client (best-effort) and clear the cache."""
        for cached_id, client in list(self._clients.items()):
            try:
                await client.close()
            except Exception:
                # Keep closing the remaining clients, but leave a trace.
                logging.getLogger(__name__).warning(
                    "Failed to close A2A client for agent %r", cached_id, exc_info=True
                )
        self._clients.clear()

    def invalidate(self, agent_id: str) -> None:
        """Drop the cached client for *agent_id* (no-op if absent).

        The dropped client is not closed here.
        """
        self._clients.pop(agent_id, None)
# ── Global singleton ─────────────────────────────────
# Shared module-level registry created at import time with default settings.
# Its background health-check loop only runs after `await default_registry.start()`.
default_registry = AgentRegistry()
# ── AgentRecord (test compatibility) ──
# NOTE(review): these re-imports duplicate the module's top-level imports.
from dataclasses import dataclass, field
from typing import Optional
import time

# NOTE(review): the `try: from agentos.protocols.a2a import AgentRecord` block
# that follows rebinds `AgentRecord` unconditionally (to a2a's class or to a
# local fallback). This definition therefore looks dead, unless something reads
# `registry.AgentRecord` while that a2a import is still executing (e.g. a
# circular import). Confirm before removing.
@dataclass
class AgentRecord:
    """Simple agent record; shadowed by the compat definition that follows."""
    agent_id: str
    capabilities: list = field(default_factory=list)
    endpoint: str = ""
    load: float = 0.0
    healthy: bool = True
    name: str = ""
    version: str = "1.0"
    registered_at: float = field(default_factory=time.time)
# ==========================================================================
# Compat: AgentRecord + lightweight compat API for test_core.py
# ==========================================================================

# Prefer the AgentRecord from the a2a module to avoid duplication; fall back to
# a minimal local definition when that module cannot be imported.
try:
    from agentos.protocols.a2a import AgentRecord
except ImportError:
    from dataclasses import dataclass, field
    import time as _time

    @dataclass
    class AgentRecord:
        """Fallback compat record: id, capabilities, endpoint and load."""

        agent_id: str
        capabilities: list = field(default_factory=list)
        endpoint: str = ""
        load: float = 0.0
        # Refreshed by the compat register/heartbeat helpers; drives `healthy`.
        _last_heartbeat: float = field(default_factory=_time.time, repr=False)

        @property
        def healthy(self) -> bool:
            """True while the last heartbeat is under 60 seconds old."""
            elapsed = _time.time() - self._last_heartbeat
            return elapsed < 60.0
def _patch_agent_registry():
    """Extend AgentRegistry with a lightweight compat API (used by test_core.py).

    Compat ``AgentRecord`` objects live in a separate per-instance store
    (``_compat_records``). Anything else passed to ``register`` /
    ``heartbeat`` is forwarded to the original AgentRegistry implementation
    with its full signature (``tags`` / ``load``) and return value preserved.

    Installs on AgentRegistry:
        register(record, tags=None)     -- compat record, or DiscoveryCard/AgentInfo
        get(agent_id)                   -- compat record lookup
        find_by_capability(capability)  -- compat records listing *capability*
        heartbeat(agent_id, load=None)  -- compat record, or real registry entry
        pick_least_loaded(capability)   -- lowest-load compat record, or None

    Safe to call more than once: the original methods are captured only once.
    """
    import time as _time

    def _records(self):
        # Lazily create the per-instance compat store.
        records = getattr(self, "_compat_records", None)
        if records is None:
            records = self._compat_records = {}
        return records

    def compat_register(self, record, tags=None):
        """Store a compat AgentRecord, or delegate to the original register()."""
        # Uses the module-level AgentRecord, i.e. a2a's class when importable,
        # otherwise the local fallback, so registration never depends on a2a.
        if isinstance(record, AgentRecord):
            record._last_heartbeat = _time.time()
            _records(self)[record.agent_id] = record
            return record.agent_id
        return self._orig_register(record, tags=tags)

    def compat_get(self, agent_id):
        """Return the compat record for *agent_id*, or None."""
        return _records(self).get(agent_id)

    def compat_find_by_capability(self, capability):
        """Return every compat record whose capabilities include *capability*."""
        return [r for r in _records(self).values() if capability in r.capabilities]

    def compat_heartbeat(self, agent_id, load=None):
        """Refresh a compat record, else delegate to the original heartbeat()."""
        record = _records(self).get(agent_id)
        if record is None:
            return self._orig_heartbeat(agent_id, load=load)
        record._last_heartbeat = _time.time()
        if load is not None:
            # Same clamping as the registry's own heartbeat().
            record.load = max(0.0, min(1.0, load))
        return True

    def compat_pick_least_loaded(self, capability):
        """Return the compat record with the lowest load for *capability*, or None."""
        candidates = compat_find_by_capability(self, capability)
        if not candidates:
            return None
        return min(candidates, key=lambda r: r.load)

    if not hasattr(AgentRegistry, "_orig_register"):
        AgentRegistry._orig_register = AgentRegistry.register
    if not hasattr(AgentRegistry, "_orig_heartbeat"):
        AgentRegistry._orig_heartbeat = AgentRegistry.heartbeat
    AgentRegistry.register = compat_register
    AgentRegistry.get = compat_get
    AgentRegistry.find_by_capability = compat_find_by_capability
    AgentRegistry.heartbeat = compat_heartbeat
    AgentRegistry.pick_least_loaded = compat_pick_least_loaded


_patch_agent_registry()