Coverage for agentos/swarm/tool_registry.py: 31%
318 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
"""
v1.9.8: Dynamic Tool Registry + Intelligent Tool Router.

ToolRegistry: schema-based tool catalog with versioning, capability tags, and dependency tracking.
ToolRouter: LLM-driven tool selection with semantic matching, confidence scoring, and fallback chains.
"""
8from __future__ import annotations
10import json
11import time
12from collections import defaultdict
13from dataclasses import dataclass, field
14from enum import Enum
15from typing import Any, Callable, Optional
# ── Tool Schema ───────────────────────────────────────────────────
class ToolCategory(str, Enum):
    """Coarse-grained tool category, used as the first-level routing key.

    The class mixes in ``str``, so every member compares equal to its
    lowercase string value (``ToolCategory.FILE == "file"``).
    """

    FILE = "file"
    NETWORK = "network"
    CODE = "code"
    SYSTEM = "system"
    DATA = "data"
    AGENT = "agent"
    # Default bucket for tools that declare no explicit category.
    CUSTOM = "custom"
@dataclass
class ToolParam:
    """Declaration of a single parameter accepted by a tool.

    Only ``name`` and ``type`` are mandatory; every other attribute is an
    optional constraint or piece of documentation.
    """

    name: str
    type: str  # One of: str, int, float, bool, list, dict
    description: str = ""
    required: bool = False
    default: Any = None
    enum_values: Optional[list[str]] = None  # Allowed values, if restricted
    min_value: Optional[float] = None  # Inclusive lower bound for numbers
    max_value: Optional[float] = None  # Inclusive upper bound for numbers
    pattern: str = ""  # Regex validation pattern
@dataclass
class ToolSchema:
    """Complete schema definition of a tool.

    Holds the tool's identity (name, description, category, version), its
    parameter declarations, searchable metadata (capabilities, tags), the
    callable that implements it, and operational flags.
    """

    # Translation from the short type names used by ``ToolParam.type`` to the
    # JSON Schema type names expected in a function-calling "parameters"
    # object. Un-annotated on purpose so the dataclass machinery does not
    # treat it as a field.
    _JSON_SCHEMA_TYPES = {
        "str": "string",
        "int": "integer",
        "float": "number",
        "bool": "boolean",
        "list": "array",
        "dict": "object",
    }

    name: str  # Unique tool name
    description: str  # Human-readable description
    category: ToolCategory = ToolCategory.CUSTOM
    params: list[ToolParam] = field(default_factory=list)
    returns: str = "any"  # Return type description
    version: str = "1.0.0"
    capabilities: list[str] = field(default_factory=list)  # e.g. ["read", "text", "file"]
    tags: list[str] = field(default_factory=list)  # Searchable tags
    dependencies: list[str] = field(default_factory=list)  # Required other tools
    handler: Callable[..., Any] | None = None  # Actual implementation
    handler_ref: str = ""  # String reference for serialization
    cost_estimate: float = 0.0  # Relative cost (latency, tokens, etc)
    is_destructive: bool = False  # Data-modifying operations
    requires_auth: bool = False  # Needs authentication
    rate_limit: int = 0  # Max calls per minute, 0 = unlimited
    deprecated: bool = False
    deprecated_message: str = ""
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_openai_function(self) -> dict[str, Any]:
        """Export the schema in OpenAI function-calling format.

        Parameter types are translated to JSON Schema names ("str" becomes
        "string", "int" becomes "integer", and so on). A type that is not
        one of the known short names is passed through unchanged, so a
        schema already written with JSON Schema names keeps working.

        Returns:
            A dict of the form ``{"type": "function", "function": {...}}``.
        """
        properties: dict[str, Any] = {}
        required: list[str] = []
        for p in self.params:
            prop: dict[str, Any] = {
                "type": self._JSON_SCHEMA_TYPES.get(p.type, p.type),
                "description": p.description,
            }
            if p.enum_values:
                prop["enum"] = p.enum_values
            properties[p.name] = prop
            if p.required:
                required.append(p.name)

        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": {
                    "type": "object",
                    "properties": properties,
                    "required": required,
                },
            },
        }

    def match_score(self, query: str, keywords: list[str]) -> float:
        """Compute a relevance score for a natural language query.

        Scoring is additive: name match (10 exact / 5 contained in the
        query), query contained in the description (3), each keyword found
        in the description (1.5), each capability found in the query (2),
        each tag found in the query or among the keywords (1), and each
        parameter name found in the query (0.5).

        Args:
            query: Free-text query; compared case-insensitively.
            keywords: Individual query terms. They are lowercased here, so
                callers need not normalise them; empty strings are ignored.

        Returns:
            The accumulated score; 0.0 means nothing matched.
        """
        query_lower = query.lower()
        # Normalise once. Empty keywords are dropped because "" is a
        # substring of every description and would inflate every score.
        lowered_keywords = [kw.lower() for kw in keywords if kw]
        keyword_set = set(lowered_keywords)
        score = 0.0

        # Name match
        name_lower = self.name.lower()
        if name_lower == query_lower:
            score += 10.0
        elif name_lower in query_lower:
            score += 5.0

        # Description match
        desc_lower = self.description.lower()
        if query_lower in desc_lower:
            score += 3.0
        for kw in lowered_keywords:
            if kw in desc_lower:
                score += 1.5

        # Capability tag match
        for cap in self.capabilities:
            if cap.lower() in query_lower:
                score += 2.0

        # Tag match
        for tag in self.tags:
            tag_lower = tag.lower()
            if tag_lower in query_lower or tag_lower in keyword_set:
                score += 1.0

        # Parameter name match (user mentioned specific fields)
        for p in self.params:
            if p.name.lower() in query_lower:
                score += 0.5

        return score
# ── Tool Registry ─────────────────────────────────────────────────
class ToolRegistry:
    """Central tool catalog with versioning, query, and lifecycle management.

    Features:
    - Schema-based registration
    - Keyword search over tool descriptions/capabilities/tags
    - Deprecation logging
    - Capability-, tag- and category-based lookup
    - Sliding-window rate limit bookkeeping
    """

    # Width of the sliding rate-limit window, in seconds.
    _RATE_WINDOW_SECONDS = 60.0

    def __init__(self):
        self._tools: dict[str, ToolSchema] = {}
        self._by_category: dict[ToolCategory, list[str]] = defaultdict(list)
        self._by_capability: dict[str, list[str]] = defaultdict(list)
        self._by_tag: dict[str, list[str]] = defaultdict(list)
        self._usage_counts: dict[str, int] = defaultdict(int)
        self._rate_trackers: dict[str, list[float]] = defaultdict(list)
        self._deprecation_log: list[dict[str, Any]] = []

    @staticmethod
    def _index_add(index: dict[Any, list[str]], key: Any, name: str) -> None:
        """Add *name* under *key*, never storing the same name twice."""
        bucket = index[key]
        if name not in bucket:
            bucket.append(name)

    @staticmethod
    def _index_discard(index: dict[Any, list[str]], key: Any, name: str) -> None:
        """Remove every occurrence of *name* under *key*; drop empty keys."""
        bucket = index.get(key)
        if bucket is None:
            return
        remaining = [n for n in bucket if n != name]
        if remaining:
            index[key] = remaining
        else:
            del index[key]

    def _add_to_indexes(self, tool: ToolSchema) -> None:
        """Record *tool* in the category, capability and tag indexes."""
        self._index_add(self._by_category, tool.category, tool.name)
        for cap in tool.capabilities:
            self._index_add(self._by_capability, cap, tool.name)
        for tag in tool.tags:
            self._index_add(self._by_tag, tag, tool.name)

    def _remove_from_indexes(self, tool: ToolSchema) -> None:
        """Erase *tool* from the category, capability and tag indexes."""
        self._index_discard(self._by_category, tool.category, tool.name)
        for cap in tool.capabilities:
            self._index_discard(self._by_capability, cap, tool.name)
        for tag in tool.tags:
            self._index_discard(self._by_tag, tag, tool.name)

    def register(self, tool: ToolSchema) -> ToolSchema:
        """Register a tool, replacing any existing tool with the same name.

        The replacement is silent. Index entries of the previous schema are
        removed first, so a changed category/capability/tag set does not
        leave stale or duplicated entries behind.

        Returns:
            The tool that was passed in.
        """
        previous = self._tools.get(tool.name)
        if previous is not None:
            self._remove_from_indexes(previous)

        self._tools[tool.name] = tool
        self._add_to_indexes(tool)
        return tool

    def register_many(self, tools: list[ToolSchema]) -> list[ToolSchema]:
        """Batch register tools; returns them in the order given."""
        return [self.register(t) for t in tools]

    def unregister(self, name: str) -> bool:
        """Remove a tool from the registry.

        Usage counters and rate-limit timestamps for the name are kept.

        Returns:
            True if the tool existed and was removed, False otherwise.
        """
        tool = self._tools.pop(name, None)
        if tool is None:
            return False
        self._remove_from_indexes(tool)
        return True

    def get(self, name: str) -> ToolSchema | None:
        """Get a tool by name.

        A deprecated tool is still returned; the access is additionally
        appended to the internal deprecation log.

        Returns:
            The tool, or None when no tool has that name.
        """
        tool = self._tools.get(name)
        if tool and tool.deprecated:
            self._deprecation_log.append({
                "tool": name,
                "message": tool.deprecated_message,
                "timestamp": time.time(),
            })
        return tool

    def search(
        self,
        query: str,
        top_k: int = 5,
        category: ToolCategory | None = None,
        exclude_deprecated: bool = True,
    ) -> list[tuple[ToolSchema, float]]:
        """Search tools by natural language query.

        Args:
            query: Free-text query; split on whitespace into keywords.
            top_k: Maximum number of results; values below 1 yield no results.
            category: When given, only tools of this category are considered.
            exclude_deprecated: Skip deprecated tools.

        Returns:
            Ranked list of (ToolSchema, score), best first. The score is the
            tool's match score plus a usage boost capped at 1.0.
        """
        keywords = query.lower().split()
        candidates: list[tuple[ToolSchema, float]] = []

        for name, tool in self._tools.items():
            if category is not None and tool.category != category:
                continue
            if exclude_deprecated and tool.deprecated:
                continue
            score = tool.match_score(query, keywords)
            if score > 0:
                # Boost frequently-used tools. ``.get`` avoids inserting a
                # zero entry into the defaultdict for every tool searched.
                usage_boost = min(self._usage_counts.get(name, 0) * 0.1, 1.0)
                candidates.append((tool, score + usage_boost))

        candidates.sort(key=lambda x: -x[1])
        return candidates[:max(top_k, 0)]

    def search_by_capability(self, capability: str) -> list[ToolSchema]:
        """Find all tools with a specific capability."""
        names = self._by_capability.get(capability, [])
        return [self._tools[n] for n in names if n in self._tools]

    def search_by_tag(self, tag: str) -> list[ToolSchema]:
        """Find all tools matching a tag."""
        names = self._by_tag.get(tag, [])
        return [self._tools[n] for n in names if n in self._tools]

    def list_categories(self) -> dict[ToolCategory, int]:
        """Count tools per category (categories without tools are omitted)."""
        return {cat: len(names) for cat, names in self._by_category.items() if names}

    def list_capabilities(self) -> list[str]:
        """List all registered capabilities, sorted."""
        return sorted(self._by_capability.keys())

    def list_tags(self) -> list[str]:
        """List all registered tags, sorted."""
        return sorted(self._by_tag.keys())

    def export_openai_functions(
        self,
        category: ToolCategory | None = None,
        exclude_deprecated: bool = True,
    ) -> list[dict[str, Any]]:
        """Export tools in OpenAI function-calling format.

        Args:
            category: When given, only tools of this category are exported.
            exclude_deprecated: Skip deprecated tools.
        """
        result = []
        for tool in self._tools.values():
            if exclude_deprecated and tool.deprecated:
                continue
            if category and tool.category != category:
                continue
            result.append(tool.to_openai_function())
        return result

    def check_rate_limit(self, name: str) -> bool:
        """Check if a tool is within its rate limit.

        Unknown tools and tools whose ``rate_limit`` is 0 or negative are
        always allowed.

        Returns:
            True if another call is allowed right now.
        """
        tool = self._tools.get(name)
        if not tool or tool.rate_limit <= 0:
            return True

        window_start = time.time() - self._RATE_WINDOW_SECONDS
        # Clean old entries
        recent = [t for t in self._rate_trackers[name] if t > window_start]
        self._rate_trackers[name] = recent

        return len(recent) < tool.rate_limit

    def record_usage(self, name: str) -> None:
        """Record a tool usage for rate limiting and analytics.

        The usage counter is always incremented. Call timestamps are kept
        only for tools that have a positive rate limit, and entries older
        than the window are pruned on every call, so the timestamp list
        cannot grow without bound.
        """
        self._usage_counts[name] += 1

        tool = self._tools.get(name)
        if tool is None or tool.rate_limit <= 0:
            return

        now = time.time()
        window_start = now - self._RATE_WINDOW_SECONDS
        recent = [t for t in self._rate_trackers[name] if t > window_start]
        recent.append(now)
        self._rate_trackers[name] = recent

    def get_stats(self) -> dict[str, Any]:
        """Get registry statistics.

        Returns:
            A dict with the keys "total_tools", "categories", "capabilities",
            "tags", "deprecated" and "top_used" (up to ten (name, count)
            pairs, most used first).
        """
        return {
            "total_tools": len(self._tools),
            "categories": {str(k): len(v) for k, v in self._by_category.items()},
            "capabilities": len(self._by_capability),
            "tags": len(self._by_tag),
            "deprecated": sum(1 for t in self._tools.values() if t.deprecated),
            "top_used": sorted(
                [(k, v) for k, v in self._usage_counts.items() if v > 0],
                key=lambda x: -x[1],
            )[:10],
        }
# ── Tool Router ───────────────────────────────────────────────────
@dataclass
class RoutingDecision:
    """Outcome of routing a task to a tool."""

    tool_name: str  # Name of the selected tool
    tool_schema: Optional[ToolSchema]  # Schema of the selected tool, if any
    confidence: float  # 0.0 - 1.0
    reasoning: str  # Why this tool was chosen
    alternatives: list[str]  # Fallback tool names
    params: dict[str, Any] = field(default_factory=dict)
@dataclass
class RoutingContext:
    """Inputs and constraints that steer a tool routing decision."""

    task: str  # User's task description
    available_capabilities: list[str] = field(default_factory=list)
    preferred_category: Optional[ToolCategory] = None
    exclude_destructive: bool = False
    min_confidence: float = 0.3  # Minimum confidence threshold
    max_alternatives: int = 3  # Max fallback alternatives
class ToolRouter:
    """Intelligent tool router with semantic matching and fallback chains.

    Selects the best tool for a given task by:
    1. LLM-driven selection (when a selector is configured and the task is
       complex enough), gated by the context's minimum confidence
    2. Keyword/capability search over the registry as the fallback
    """

    # Raw search score that maps to a confidence of 1.0. An exact name match
    # contributes 10 points in ToolSchema.match_score, so that is the anchor.
    _FULL_CONFIDENCE_SCORE = 10.0
    # Upper bound on the number of tools listed in an LLM prompt.
    _MAX_PROMPT_TOOLS = 20

    def __init__(
        self,
        registry: ToolRegistry,
        llm_selector: Callable[..., Any] | None = None,
    ):
        self.registry = registry
        self.llm_selector = llm_selector  # Optional LLM for smarter selection

    def route(self, context: RoutingContext) -> RoutingDecision:
        """Route a task to the best tool.

        Priority:
        1. LLM selector (if available and worthwhile for this task); its
           decision is used only when it reaches ``context.min_confidence``
        2. Semantic search over the registry
        """
        # Try LLM-based routing
        if self.llm_selector and self._is_llm_worthwhile(context.task):
            decision = self._llm_route(context)
            if decision and decision.confidence >= context.min_confidence:
                return decision

        # Semantic search routing
        return self._semantic_route(context)

    def _is_llm_worthwhile(self, task: str) -> bool:
        """Heuristic: LLM routing is worthwhile only for complex tasks.

        Returns False when the task is just a registered tool name or has
        at most two words.
        """
        task_lower = task.lower().strip()
        # If task is just a tool name, skip LLM
        if task_lower in self.registry._tools:
            return False
        # Very short tasks (up to two words) do not need the LLM either
        return len(task_lower.split()) > 2

    def _llm_route(self, context: RoutingContext) -> RoutingDecision | None:
        """Use the LLM selector for tool selection.

        The selector receives a prompt and must return either a dict or a
        JSON string with the keys "tool_name", "confidence", "reasoning" and
        "alternatives".

        Returns:
            A decision, or None when the selector fails, answers in an
            unusable shape, or names a tool that is unknown, deprecated, or
            destructive while destructive tools are excluded.
        """
        try:
            tools_desc = self._build_tools_description(context)
            prompt = (
                f"Task: {context.task}\n\n"
                f"Available tools:\n{tools_desc}\n\n"
                "Select the best tool. Reply with JSON:\n"
                '{"tool_name": "xxx", "confidence": 0.0-1.0, "reasoning": "why", '
                '"alternatives": ["tool2", "tool3"]}'
            )
            result = self.llm_selector(prompt)
            if isinstance(result, str):
                result = json.loads(result)
            if not isinstance(result, dict):
                return None

            tool_name = result.get("tool_name", "")
            tool = self.registry.get(tool_name)
            if not tool:
                return None
            # The prompt never offers these tools; refuse them if the model
            # names one anyway so the caller's constraints always hold.
            if tool.deprecated:
                return None
            if context.exclude_destructive and tool.is_destructive:
                return None

            # Keep the confidence inside the documented 0.0 - 1.0 range.
            confidence = max(0.0, min(float(result.get("confidence", 0.5)), 1.0))
            raw_alternatives = result.get("alternatives", [])
            if isinstance(raw_alternatives, (list, tuple)):
                alternatives = [str(a) for a in raw_alternatives]
            else:
                alternatives = []

            return RoutingDecision(
                tool_name=tool_name,
                tool_schema=tool,
                confidence=confidence,
                reasoning=str(result.get("reasoning", "")),
                alternatives=alternatives,
            )
        except Exception:
            # Deliberate best-effort: any selector or parsing failure simply
            # hands control back to the semantic fallback in route().
            return None

    def _semantic_route(self, context: RoutingContext) -> RoutingDecision:
        """Search-based routing with confidence scoring.

        Confidence is the best candidate's raw score divided by the score of
        an exact name match, capped at 1.0, regardless of how many
        candidates were found.

        Returns:
            A decision for the best match; when nothing matches (or every
            match is destructive and those are excluded) the decision has an
            empty tool name and confidence 0.0.
        """
        wanted = max(context.max_alternatives, 0) + 1
        # Destructive tools are filtered after ranking, so rank the whole
        # catalog in that case; otherwise safe matches ranked below the
        # cut-off would be lost.
        if context.exclude_destructive:
            fetch = max(len(self.registry._tools), wanted)
        else:
            fetch = wanted

        candidates = self.registry.search(
            query=context.task,
            top_k=fetch,
            category=context.preferred_category,
        )

        if not candidates:
            return RoutingDecision(
                tool_name="",
                tool_schema=None,
                confidence=0.0,
                reasoning="No matching tool found",
                alternatives=[],
            )

        # Filter destructive tools if excluded
        if context.exclude_destructive:
            candidates = [
                (t, s) for t, s in candidates if not t.is_destructive
            ]

        if not candidates:
            return RoutingDecision(
                tool_name="",
                tool_schema=None,
                confidence=0.0,
                reasoning="All matching tools are destructive (excluded)",
                alternatives=[],
            )

        candidates = candidates[:wanted]
        best_tool, raw_score = candidates[0]
        # Normalize against an absolute anchor. Dividing by the best score in
        # the list would always give 1.0, since the best candidate is first.
        confidence = max(0.0, min(raw_score / self._FULL_CONFIDENCE_SCORE, 1.0))
        alternatives = [t.name for t, _ in candidates[1:]]

        return RoutingDecision(
            tool_name=best_tool.name,
            tool_schema=best_tool,
            confidence=confidence,
            reasoning=f"Best match: {best_tool.name} (score={raw_score:.1f})",
            alternatives=alternatives,
        )

    def _build_tools_description(self, context: RoutingContext) -> str:
        """Build a compact, one-line-per-tool description for the LLM prompt.

        Tools of the preferred category come first. Deprecated tools, and
        destructive ones when excluded, are skipped before the size limit is
        applied, so the prompt can use the whole limit for eligible tools.
        """
        catalog = self.registry._tools
        names = list(catalog)
        if context.preferred_category:
            preferred = self.registry._by_category.get(context.preferred_category, [])
            # dict.fromkeys de-duplicates while keeping first-seen order.
            names = list(dict.fromkeys(list(preferred) + names))

        lines: list[str] = []
        for name in names:
            if len(lines) >= self._MAX_PROMPT_TOOLS:  # Avoid huge prompts
                break
            tool = catalog.get(name)
            if tool is None:  # Index entry without a registered tool
                continue
            if tool.deprecated:
                continue
            if context.exclude_destructive and tool.is_destructive:
                continue
            params_desc = ", ".join(
                f"{p.name}:{p.type}" + ("?" if not p.required else "")
                for p in tool.params[:5]
            )
            cap_tags = ", ".join(tool.capabilities[:3])
            lines.append(
                f"- {tool.name}: {tool.description[:100]}. "
                f"Params: [{params_desc}]. Caps: [{cap_tags}]"
            )

        return "\n".join(lines)
# ── Tool Execution Engine ─────────────────────────────────────────
class ToolExecutionError(Exception):
    """Signals that executing a tool failed.

    Attributes:
        tool_name: Name of the tool the failure relates to.
        recoverable: Flag supplied by the raiser; defaults to True.

    The exception message has the form ``[<tool_name>] <message>``.
    """

    def __init__(self, tool_name: str, message: str, recoverable: bool = True):
        formatted = f"[{tool_name}] {message}"
        super().__init__(formatted)
        self.tool_name = tool_name
        self.recoverable = recoverable
class ToolExecutor:
    """Execution engine for registered tools with safety and error handling.

    Features:
    - Rate limit enforcement
    - Parameter validation
    - Destructive operation confirmation
    - Error categorization (recoverable vs fatal)

    NOTE(review): ``timeout`` is stored on the instance but no method of
    this class enforces it; handlers run synchronously to completion.
    """

    # Accepted Python types and the label used in the error message, keyed by
    # the short type names a ToolParam may declare.
    _TYPE_CHECKS = {
        "str": ((str,), "string"),
        "int": ((int,), "int"),
        "float": ((int, float), "number"),
        "bool": ((bool,), "bool"),
        "list": ((list, tuple), "list"),
        "dict": ((dict,), "dict"),
    }

    def __init__(
        self,
        registry: ToolRegistry,
        timeout: float = 30.0,
        require_destructive_confirm: bool = True,
    ):
        self.registry = registry
        self.timeout = timeout
        self.require_destructive_confirm = require_destructive_confirm
        self._pending_confirmations: dict[str, dict[str, Any]] = {}

    def execute(
        self,
        tool_name: str,
        params: dict[str, Any] | None = None,
        force: bool = False,
    ) -> Any:
        """Execute a registered tool with safety checks.

        Checks run in this order: tool exists, not deprecated, within rate
        limit, destructive confirmation, parameter validation, handler
        present. Usage is recorded only once all checks have passed, so a
        call that can never run does not consume rate-limit budget.

        Args:
            tool_name: Registered tool name
            params: Tool parameters, passed to the handler as keyword args
            force: Skip destructive confirmation (use with caution)

        Returns:
            Tool execution result

        Raises:
            ToolExecutionError: On any failed check or handler failure. A
                ToolExecutionError raised by the handler itself propagates
                unchanged, keeping its own ``recoverable`` flag.
        """
        params = params or {}

        tool = self.registry.get(tool_name)
        if not tool:
            raise ToolExecutionError(tool_name, f"Tool '{tool_name}' not registered", recoverable=False)

        if tool.deprecated:
            raise ToolExecutionError(
                tool_name,
                f"Tool deprecated: {tool.deprecated_message}",
                recoverable=False,
            )

        # Rate limit check
        if not self.registry.check_rate_limit(tool_name):
            raise ToolExecutionError(
                tool_name,
                f"Rate limit exceeded ({tool.rate_limit}/min)",
                recoverable=True,
            )

        # Destructive check: park the params so confirm_destructive() can
        # replay the call later.
        if tool.is_destructive and self.require_destructive_confirm and not force:
            self._pending_confirmations[tool_name] = params
            raise ToolExecutionError(
                tool_name,
                "Destructive operation requires confirmation (pass force=True to skip)",
                recoverable=False,
            )

        # Parameter validation
        self._validate_params(tool, params)

        if not tool.handler:
            raise ToolExecutionError(
                tool_name,
                "No handler registered for tool",
                recoverable=False,
            )

        # Record usage before running the handler so a failing handler still
        # counts against the rate limit.
        self.registry.record_usage(tool_name)

        try:
            return tool.handler(**params)
        except ToolExecutionError:
            # Already categorized (e.g. by a nested execute call); do not
            # re-wrap and overwrite its recoverable flag.
            raise
        except Exception as e:
            raise ToolExecutionError(
                tool_name,
                f"Execution failed: {str(e)}",
                recoverable=True,
            ) from e

    def confirm_destructive(self, tool_name: str) -> Any:
        """Confirm and execute a pending destructive operation.

        Raises:
            ToolExecutionError: If nothing is pending for *tool_name*, or if
                the forced execution fails.
        """
        if tool_name not in self._pending_confirmations:
            raise ToolExecutionError(tool_name, "No pending confirmation", recoverable=False)
        params = self._pending_confirmations.pop(tool_name)
        return self.execute(tool_name, params, force=True)

    def cancel_destructive(self, tool_name: str) -> bool:
        """Cancel a pending destructive operation.

        Returns:
            True if a pending operation was discarded, False if none existed.
        """
        if tool_name in self._pending_confirmations:
            del self._pending_confirmations[tool_name]
            return True
        return False

    def _validate_params(self, tool: ToolSchema, params: dict[str, Any]) -> None:
        """Validate parameters against the tool's schema.

        For every declared parameter: required ones must be present, and a
        supplied value must satisfy the declared type, the enum list and the
        numeric range. Parameters not declared in the schema are not checked.

        NOTE(review): ``ToolParam.pattern`` is not enforced here.

        Raises:
            ToolExecutionError: (non-recoverable) on the first violation.
        """
        for p in tool.params:
            if p.name not in params:
                if p.required:
                    raise ToolExecutionError(
                        tool.name,
                        f"Missing required parameter: {p.name} ({p.description})",
                        recoverable=False,
                    )
                continue

            value = params[p.name]
            # bool is a subclass of int, so it must be singled out or True
            # and False would pass as numbers.
            is_bool = isinstance(value, bool)

            # Type check
            check = self._TYPE_CHECKS.get(p.type)
            if check is not None:
                accepted, label = check
                type_ok = isinstance(value, accepted)
                if p.type in ("int", "float") and is_bool:
                    type_ok = False
                if not type_ok:
                    raise ToolExecutionError(
                        tool.name,
                        f"Parameter '{p.name}' must be {label}",
                        recoverable=False,
                    )

            # Enum check
            if p.enum_values and value not in p.enum_values:
                raise ToolExecutionError(
                    tool.name,
                    f"Parameter '{p.name}' must be one of: {p.enum_values}",
                    recoverable=False,
                )

            # Range check (numbers only)
            if isinstance(value, (int, float)) and not is_bool:
                if p.min_value is not None and value < p.min_value:
                    raise ToolExecutionError(
                        tool.name,
                        f"Parameter '{p.name}' minimum is {p.min_value}",
                        recoverable=False,
                    )
                if p.max_value is not None and value > p.max_value:
                    raise ToolExecutionError(
                        tool.name,
                        f"Parameter '{p.name}' maximum is {p.max_value}",
                        recoverable=False,
                    )

    def get_pending_confirmations(self) -> list[str]:
        """List tools awaiting destructive confirmation."""
        return list(self._pending_confirmations.keys())
# ── Utility helpers ───────────────────────────────────────────────
def create_tool(
    name: str,
    description: str,
    handler: Callable,
    category: ToolCategory = ToolCategory.CUSTOM,
    params: list[ToolParam] | None = None,
    capabilities: list[str] | None = None,
    tags: list[str] | None = None,
    is_destructive: bool = False,
    rate_limit: int = 0,
    **kwargs,
) -> ToolSchema:
    """Build a ToolSchema from the most commonly used fields.

    ``params``, ``capabilities`` and ``tags`` default to empty lists when
    omitted (or passed as an empty/None value). Any extra keyword argument
    is forwarded unchanged to the ToolSchema constructor.
    """
    schema_fields: dict[str, Any] = {
        "name": name,
        "description": description,
        "category": category,
        "params": params if params else [],
        "capabilities": capabilities if capabilities else [],
        "tags": tags if tags else [],
        "handler": handler,
        "is_destructive": is_destructive,
        "rate_limit": rate_limit,
    }
    # Extra keys cannot clash with the ones above: each of those is a named
    # parameter of this function and therefore never lands in **kwargs.
    schema_fields.update(kwargs)
    return ToolSchema(**schema_fields)