Coverage for agentos/swarm/tool_registry.py: 31%

318 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 09:59 +0800

1""" 

2v1.9.8: Dynamic Tool Registry + Intelligent Tool Router. 

3 

4ToolRegistry: schema-based tool catalog with versioning, capability tags, and dependency tracking. 

5ToolRouter: LLM-driven tool selection with semantic matching, confidence scoring, and fallback chains. 

6""" 

7 

8from __future__ import annotations 

9 

10import json 

11import time 

12from collections import defaultdict 

13from dataclasses import dataclass, field 

14from enum import Enum 

15from typing import Any, Callable, Optional 

16 

17 

18# ── Tool Schema ─────────────────────────────────────────────────── 

19 

class ToolCategory(str, Enum):
    """Coarse tool family used as the first, cheapest routing filter.

    The class mixes in ``str``, so every member compares equal to its
    lowercase value (``ToolCategory.FILE == "file"``).
    """

    FILE = "file"
    NETWORK = "network"
    CODE = "code"
    SYSTEM = "system"
    DATA = "data"
    AGENT = "agent"
    # Default bucket for tools that declare no category of their own.
    CUSTOM = "custom"

29 

30 

@dataclass
class ToolParam:
    """Declarative description of a single argument accepted by a tool."""

    # Argument name.
    name: str
    # Type tag; the expected spellings are: str, int, float, bool, list, dict.
    type: str
    # Free-text explanation of the argument.
    description: str = ""
    # Whether the argument must be supplied.
    required: bool = False
    # Default value associated with the argument.
    default: Any = None
    # Closed set of permitted values; None means unrestricted.
    enum_values: list[str] | None = None
    # Lower / upper numeric bounds; None means unbounded.
    min_value: float | None = None
    max_value: float | None = None
    # Regex validation pattern.
    pattern: str = ""

43 

44 

@dataclass
class ToolSchema:
    """Complete tool schema definition.

    Bundles a tool's identity, parameters, search metadata, and (optionally)
    its callable implementation.
    """

    name: str  # Unique tool name
    description: str  # Human-readable description
    category: ToolCategory = ToolCategory.CUSTOM
    params: list[ToolParam] = field(default_factory=list)
    returns: str = "any"  # Return type description
    version: str = "1.0.0"
    capabilities: list[str] = field(default_factory=list)  # e.g. ["read", "text", "file"]
    tags: list[str] = field(default_factory=list)  # Searchable tags
    dependencies: list[str] = field(default_factory=list)  # Required other tools
    handler: Callable[..., Any] | None = None  # Actual implementation
    handler_ref: str = ""  # String reference for serialization
    cost_estimate: float = 0.0  # Relative cost (latency, tokens, etc)
    is_destructive: bool = False  # Data-modifying operations
    requires_auth: bool = False  # Needs authentication
    rate_limit: int = 0  # Max calls per minute, 0 = unlimited
    deprecated: bool = False
    deprecated_message: str = ""
    metadata: dict[str, Any] = field(default_factory=dict)

    # Python-style type tags -> JSON Schema type names. Deliberately left
    # un-annotated so the dataclass machinery does not turn it into a field.
    _JSON_SCHEMA_TYPES = {
        "str": "string",
        "int": "integer",
        "float": "number",
        "bool": "boolean",
        "list": "array",
        "dict": "object",
    }

    def to_openai_function(self) -> dict[str, Any]:
        """Export schema as OpenAI function-calling format.

        Parameter type tags written in Python style (``str``, ``int``,
        ``float``, ``bool``, ``list``, ``dict``) are translated to the JSON
        Schema names the format expects. Any other tag (including one that is
        already a JSON Schema name such as ``"string"``) is passed through
        unchanged.

        Returns:
            A ``{"type": "function", "function": {...}}`` dictionary whose
            ``parameters`` entry is an object schema built from ``params``.
        """
        properties: dict[str, dict[str, Any]] = {}
        required: list[str] = []
        for p in self.params:
            prop: dict[str, Any] = {
                "type": self._JSON_SCHEMA_TYPES.get(p.type, p.type),
                "description": p.description,
            }
            if p.enum_values:
                prop["enum"] = p.enum_values
            properties[p.name] = prop
            if p.required:
                required.append(p.name)

        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": {
                    "type": "object",
                    "properties": properties,
                    "required": required,
                },
            },
        }

    def match_score(self, query: str, keywords: list[str]) -> float:
        """Compute relevance score for a natural language query.

        Args:
            query: Free-text query; compared case-insensitively after
                stripping surrounding whitespace.
            keywords: Individual query terms; compared case-insensitively.
                Empty terms are ignored.

        Returns:
            A non-negative score; 0.0 means nothing matched. Weights: exact
            name 10, name contained in query 5, query contained in
            description 3, each keyword in description 1.5, each capability
            2, each tag 1, each parameter name 0.5.
        """
        query_lower = query.strip().lower()
        terms = [kw.lower() for kw in keywords if kw]
        name_lower = self.name.lower()
        score = 0.0

        # Name: exact match outranks "name mentioned somewhere in the query".
        if name_lower == query_lower:
            score += 10.0
        elif name_lower in query_lower:
            score += 5.0

        # Description. The empty string is a substring of everything, so an
        # empty query must not earn the whole-query bonus.
        desc_lower = self.description.lower()
        if query_lower and query_lower in desc_lower:
            score += 3.0
        score += 1.5 * sum(1 for kw in terms if kw in desc_lower)

        # Capability tag match
        for cap in self.capabilities:
            if cap.lower() in query_lower:
                score += 2.0

        # Tag match: either mentioned in the query text or equal to a keyword.
        for tag in self.tags:
            tag_lower = tag.lower()
            if tag_lower in query_lower or tag_lower in terms:
                score += 1.0

        # Parameter name match (user mentioned specific fields)
        for p in self.params:
            if p.name.lower() in query_lower:
                score += 0.5

        return score

130 

131 

132# ── Tool Registry ───────────────────────────────────────────────── 

133 

class ToolRegistry:
    """Central tool catalog with versioning, query, and lifecycle management.

    Features:
    - Schema-based registration
    - Keyword search over tool descriptions/capabilities/tags
    - Deprecation logging
    - Capability-, tag- and category-based lookup indexes
    - Sliding one-minute rate-limit bookkeeping
    """

    def __init__(self):
        """Create an empty registry."""
        self._tools: dict[str, ToolSchema] = {}
        # Secondary indexes: key -> names of the tools registered under it.
        self._by_category: dict[ToolCategory, list[str]] = defaultdict(list)
        self._by_capability: dict[str, list[str]] = defaultdict(list)
        self._by_tag: dict[str, list[str]] = defaultdict(list)
        self._usage_counts: dict[str, int] = defaultdict(int)
        # Per-tool timestamps (time.time()) of recorded calls.
        self._rate_trackers: dict[str, list[float]] = defaultdict(list)
        self._deprecation_log: list[dict[str, Any]] = []

    # ── index maintenance ────────────────────────────────────────

    @staticmethod
    def _add_to_bucket(bucket: list[str], name: str) -> None:
        """Append ``name`` to an index bucket unless it is already there."""
        if name not in bucket:
            bucket.append(name)

    @staticmethod
    def _drop_from_index(index: dict[Any, list[str]], key: Any, name: str) -> None:
        """Remove every occurrence of ``name`` under ``key``; drop empty keys."""
        bucket = index.get(key)
        if bucket is None:
            return
        remaining = [n for n in bucket if n != name]
        if remaining:
            index[key] = remaining
        else:
            del index[key]

    def _index(self, tool: ToolSchema) -> None:
        """Add ``tool`` to the category, capability and tag indexes."""
        self._add_to_bucket(self._by_category[tool.category], tool.name)
        for cap in tool.capabilities:
            self._add_to_bucket(self._by_capability[cap], tool.name)
        for tag in tool.tags:
            self._add_to_bucket(self._by_tag[tag], tool.name)

    def _unindex(self, tool: ToolSchema) -> None:
        """Remove ``tool`` from the category, capability and tag indexes."""
        self._drop_from_index(self._by_category, tool.category, tool.name)
        for cap in tool.capabilities:
            self._drop_from_index(self._by_capability, cap, tool.name)
        for tag in tool.tags:
            self._drop_from_index(self._by_tag, tag, tool.name)

    # ── registration ─────────────────────────────────────────────

    def register(self, tool: ToolSchema) -> ToolSchema:
        """Register a tool, replacing any tool already stored under its name.

        When a tool with the same name exists, its index entries are removed
        first so the indexes never hold duplicates or entries that describe
        the replaced definition.

        Returns:
            The tool that was passed in.
        """
        previous = self._tools.get(tool.name)
        if previous is not None:
            self._unindex(previous)

        self._tools[tool.name] = tool
        self._index(tool)
        return tool

    def register_many(self, tools: list[ToolSchema]) -> list[ToolSchema]:
        """Batch register tools; returns them in the order given."""
        return [self.register(t) for t in tools]

    def unregister(self, name: str) -> bool:
        """Remove a tool from registry.

        Usage counters and rate-limit history for the name are left untouched.

        Returns:
            True if a tool was removed, False if the name was unknown.
        """
        tool = self._tools.pop(name, None)
        if tool is None:
            return False
        self._unindex(tool)
        return True

    # ── lookup ───────────────────────────────────────────────────

    def get(self, name: str) -> ToolSchema | None:
        """Get tool by name.

        Deprecated tools are still returned; each such lookup appends an
        entry (tool, message, timestamp) to the internal deprecation log.

        Returns:
            The registered tool, or None if the name is unknown.
        """
        tool = self._tools.get(name)
        if tool and tool.deprecated:
            self._deprecation_log.append({
                "tool": name,
                "message": tool.deprecated_message,
                "timestamp": time.time(),
            })
        return tool

    def search(
        self,
        query: str,
        top_k: int = 5,
        category: ToolCategory | None = None,
        exclude_deprecated: bool = True,
    ) -> list[tuple[ToolSchema, float]]:
        """Search tools by natural language query.

        Args:
            query: Free-text query; split on whitespace into keywords.
            top_k: Maximum number of results; values <= 0 yield no results.
            category: When given, only tools of this category are considered.
            exclude_deprecated: Skip tools flagged as deprecated.

        Returns:
            Ranked list of (ToolSchema, score), best first. Only tools with a
            positive match score are included; the score carries a usage
            boost of 0.1 per recorded call, capped at 1.0.
        """
        if top_k <= 0:
            return []

        keywords = query.lower().split()
        candidates: list[tuple[ToolSchema, float]] = []

        for name, tool in self._tools.items():
            if category is not None and tool.category != category:
                continue
            if exclude_deprecated and tool.deprecated:
                continue
            score = tool.match_score(query, keywords)
            if score > 0:
                # .get() keeps a read-only search from inserting zero
                # entries into the usage-count defaultdict.
                usage_boost = min(self._usage_counts.get(name, 0) * 0.1, 1.0)
                candidates.append((tool, score + usage_boost))

        # Stable sort: equally scored tools keep registration order.
        candidates.sort(key=lambda pair: pair[1], reverse=True)
        return candidates[:top_k]

    def search_by_capability(self, capability: str) -> list[ToolSchema]:
        """Find all tools with a specific capability."""
        names = self._by_capability.get(capability, [])
        return [self._tools[n] for n in names if n in self._tools]

    def search_by_tag(self, tag: str) -> list[ToolSchema]:
        """Find all tools matching a tag."""
        names = self._by_tag.get(tag, [])
        return [self._tools[n] for n in names if n in self._tools]

    def list_categories(self) -> dict[ToolCategory, int]:
        """Count tools per category (categories without tools are omitted)."""
        return {cat: len(names) for cat, names in self._by_category.items() if names}

    def list_capabilities(self) -> list[str]:
        """List all registered capabilities, sorted."""
        return sorted(self._by_capability.keys())

    def list_tags(self) -> list[str]:
        """List all registered tags, sorted."""
        return sorted(self._by_tag.keys())

    def export_openai_functions(
        self,
        category: ToolCategory | None = None,
        exclude_deprecated: bool = True,
    ) -> list[dict[str, Any]]:
        """Export all tools as OpenAI function-calling format.

        Args:
            category: When given, only tools of this category are exported.
            exclude_deprecated: Skip tools flagged as deprecated.
        """
        result = []
        for tool in self._tools.values():
            if exclude_deprecated and tool.deprecated:
                continue
            if category and tool.category != category:
                continue
            result.append(tool.to_openai_function())
        return result

    # ── rate limiting and analytics ──────────────────────────────

    def check_rate_limit(self, name: str) -> bool:
        """Check if tool is within rate limit. Returns True if allowed.

        Unknown tools and tools whose ``rate_limit`` is <= 0 are always
        allowed. The check prunes timestamps older than 60 seconds but does
        not itself record a call; see ``record_usage``.
        """
        tool = self._tools.get(name)
        if not tool or tool.rate_limit <= 0:
            return True

        window_start = time.time() - 60  # 1-minute window
        recent = [t for t in self._rate_trackers[name] if t > window_start]
        self._rate_trackers[name] = recent
        return len(recent) < tool.rate_limit

    def record_usage(self, name: str) -> None:
        """Record a tool usage for rate limiting and analytics."""
        self._usage_counts[name] += 1
        self._rate_trackers[name].append(time.time())

    def get_stats(self) -> dict[str, Any]:
        """Get registry statistics.

        Returns:
            A dictionary with the keys ``total_tools``, ``categories``
            (``str(category)`` -> tool count), ``capabilities`` and ``tags``
            (distinct counts), ``deprecated`` (count) and ``top_used`` (up to
            ten ``(name, calls)`` pairs, most used first).
        """
        return {
            "total_tools": len(self._tools),
            "categories": {str(k): len(v) for k, v in self._by_category.items()},
            "capabilities": len(self._by_capability),
            "tags": len(self._by_tag),
            "deprecated": sum(1 for t in self._tools.values() if t.deprecated),
            "top_used": sorted(
                [(k, v) for k, v in self._usage_counts.items() if v > 0],
                key=lambda x: -x[1],
            )[:10],
        }

302 

303 

304# ── Tool Router ─────────────────────────────────────────────────── 

305 

@dataclass
class RoutingDecision:
    """Outcome produced by the router for one task."""

    # Name of the selected tool.
    tool_name: str
    # Schema of the selected tool; None is allowed by the annotation.
    tool_schema: ToolSchema | None
    # Confidence in the selection, on a 0.0 - 1.0 scale.
    confidence: float
    # Explanation of why this tool was chosen.
    reasoning: str
    # Names of fallback tools.
    alternatives: list[str]
    # Parameters attached to the decision; a fresh empty dict by default.
    params: dict[str, Any] = field(default_factory=dict)

315 

316 

@dataclass
class RoutingContext:
    """Inputs the router consults when choosing a tool."""

    # The user's task description.
    task: str
    available_capabilities: list[str] = field(default_factory=list)
    # Category to favour; None expresses no preference.
    preferred_category: ToolCategory | None = None
    # When True, destructive tools are to be left out.
    exclude_destructive: bool = False
    # Minimum confidence threshold.
    min_confidence: float = 0.3
    # Maximum number of fallback alternatives.
    max_alternatives: int = 3

326 

327 

class ToolRouter:
    """Tool router with keyword matching, optional LLM selection and fallbacks.

    Selects the best tool for a given task by:
    1. LLM-driven selection (when a selector is configured and the task is
       long enough to be worth it)
    2. Keyword/capability search over the registry
    3. An empty "no match" decision when nothing scores
    """

    # Upper bound on tools listed in the LLM prompt, to avoid huge prompts.
    _MAX_PROMPT_TOOLS = 20

    def __init__(
        self,
        registry: ToolRegistry,
        llm_selector: Callable[..., Any] | None = None,
    ):
        """Bind the router to a registry.

        Args:
            registry: Catalog the router searches.
            llm_selector: Optional callable that receives a prompt string and
                returns either a dict or a JSON string describing its choice.
        """
        self.registry = registry
        self.llm_selector = llm_selector  # Optional LLM for smarter selection

    def route(self, context: RoutingContext) -> RoutingDecision:
        """Route a task to the best tool.

        Priority:
        1. LLM selector (if available), accepted only when its confidence
           reaches ``context.min_confidence``
        2. Registry search (keyword + capability matching)
        3. Empty decision with confidence 0.0
        """
        if self.llm_selector and self._is_llm_worthwhile(context.task):
            decision = self._llm_route(context)
            if decision is not None and decision.confidence >= context.min_confidence:
                return decision

        return self._semantic_route(context)

    def _is_llm_worthwhile(self, task: str) -> bool:
        """Heuristic: LLM routing is worthwhile for complex tasks.

        Returns False when the lowercased, stripped task equals a registered
        tool name or has at most two whitespace-separated words.
        """
        task_lower = task.lower().strip()
        if task_lower in self.registry._tools:
            return False
        return len(task_lower.split()) > 2

    def _llm_route(self, context: RoutingContext) -> RoutingDecision | None:
        """Use the LLM selector for tool selection.

        Returns:
            A decision, or None when the selector fails, answers in an
            unexpected shape, or names a tool that is unknown, deprecated, or
            destructive while destructive tools are excluded. None makes the
            caller fall back to registry search.
        """
        try:
            tools_desc = self._build_tools_description(context)
            prompt = (
                f"Task: {context.task}\n\n"
                f"Available tools:\n{tools_desc}\n\n"
                "Select the best tool. Reply with JSON:\n"
                '{"tool_name": "xxx", "confidence": 0.0-1.0, "reasoning": "why", '
                '"alternatives": ["tool2", "tool3"]}'
            )
            result = self.llm_selector(prompt)
            if isinstance(result, str):
                result = json.loads(result)
            if not isinstance(result, dict):
                return None

            tool_name = result.get("tool_name", "")
            if not isinstance(tool_name, str):
                return None
            tool = self.registry.get(tool_name)
            if not tool:
                return None
            # Such tools are never offered in the prompt, so the model must
            # not be able to select them anyway.
            if tool.deprecated:
                return None
            if context.exclude_destructive and tool.is_destructive:
                return None

            confidence = float(result.get("confidence", 0.5))
            if confidence != confidence:  # NaN never compares equal to itself
                return None
            confidence = max(0.0, min(confidence, 1.0))

            raw_alternatives = result.get("alternatives", [])
            if not isinstance(raw_alternatives, (list, tuple)):
                raw_alternatives = []
            # Keep only registered names other than the winner itself.
            alternatives = [
                alt for alt in raw_alternatives
                if isinstance(alt, str)
                and alt != tool_name
                and alt in self.registry._tools
            ]

            return RoutingDecision(
                tool_name=tool_name,
                tool_schema=tool,
                confidence=confidence,
                reasoning=str(result.get("reasoning", "")),
                alternatives=alternatives,
            )
        except Exception:
            # The selector is an arbitrary callable and its reply is
            # untrusted text; any failure here is a deliberate best-effort
            # fallback to registry search, not an error for the caller.
            return None

    def _semantic_route(self, context: RoutingContext) -> RoutingDecision:
        """Registry-search-based routing with confidence scoring.

        Confidence is the best raw score divided by 10 (the weight of an
        exact name match), clamped to [0, 1]. ``context.min_confidence`` is
        not applied here; the decision is returned for the caller to judge.
        """
        wanted = context.max_alternatives + 1
        # When destructive tools must be dropped, rank the whole registry
        # first: truncating before filtering could discard every safe tool.
        limit = len(self.registry._tools) if context.exclude_destructive else wanted
        candidates = self.registry.search(
            query=context.task,
            top_k=limit,
            category=context.preferred_category,
        )

        if not candidates:
            return RoutingDecision(
                tool_name="",
                tool_schema=None,
                confidence=0.0,
                reasoning="No matching tool found",
                alternatives=[],
            )

        if context.exclude_destructive:
            candidates = [
                (t, s) for t, s in candidates if not t.is_destructive
            ][:wanted]

        if not candidates:
            return RoutingDecision(
                tool_name="",
                tool_schema=None,
                confidence=0.0,
                reasoning="All matching tools are destructive (excluded)",
                alternatives=[],
            )

        best_tool, raw_score = candidates[0]
        # Same absolute scale whether or not there are runners-up. Dividing
        # by the maximum score would always give 1.0, because the list is
        # sorted best-first.
        confidence = max(0.0, min(raw_score / 10.0, 1.0))
        alternatives = [t.name for t, _ in candidates[1:context.max_alternatives + 1]]

        return RoutingDecision(
            tool_name=best_tool.name,
            tool_schema=best_tool,
            confidence=confidence,
            reasoning=f"Best match: {best_tool.name} (score={raw_score:.1f})",
            alternatives=alternatives,
        )

    def _build_tools_description(self, context: RoutingContext) -> str:
        """Build a compact tool description for the LLM prompt.

        Tools of the preferred category come first. Deprecated tools, and
        destructive tools when excluded, are skipped before the size cap is
        applied, so up to ``_MAX_PROMPT_TOOLS`` eligible tools are listed.
        """
        names = list(self.registry._tools)
        if context.preferred_category:
            preferred = list(
                self.registry._by_category.get(context.preferred_category, [])
            )
            # dict.fromkeys keeps first-seen order and removes repeats.
            names = list(dict.fromkeys(preferred + names))

        lines: list[str] = []
        for name in names:
            if len(lines) >= self._MAX_PROMPT_TOOLS:
                break
            tool = self.registry._tools.get(name)
            if tool is None:  # index entry without a registered tool
                continue
            if tool.deprecated:
                continue
            if context.exclude_destructive and tool.is_destructive:
                continue
            # "?" marks optional parameters; at most five are shown.
            params_desc = ", ".join(
                f"{p.name}:{p.type}" + ("?" if not p.required else "")
                for p in tool.params[:5]
            )
            cap_tags = ", ".join(tool.capabilities[:3])
            lines.append(
                f"- {tool.name}: {tool.description[:100]}. "
                f"Params: [{params_desc}]. Caps: [{cap_tags}]"
            )

        return "\n".join(lines)

483 

484 

485# ── Tool Execution Engine ───────────────────────────────────────── 

486 

class ToolExecutionError(Exception):
    """Signals that running a tool failed.

    Attributes:
        tool_name: Name of the tool the failure relates to.
        recoverable: Flag supplied by the raiser; True by default.

    The exception text is the message prefixed with ``[tool_name]``.
    """

    def __init__(self, tool_name: str, message: str, recoverable: bool = True):
        super().__init__(f"[{tool_name}] {message}")
        self.recoverable = recoverable
        self.tool_name = tool_name

493 

494 

class ToolExecutor:
    """Execution engine for registered tools with safety and error handling.

    Features:
    - Rate limit enforcement
    - Parameter validation
    - Destructive operation confirmation
    - Error categorization (recoverable vs fatal)

    ``timeout`` is stored on the instance but is not enforced by ``execute``.
    """

    # Type tag -> (accepted Python types, wording used in the error message).
    _TYPE_CHECKS = {
        "str": ((str,), "string"),
        "int": ((int,), "int"),
        "float": ((int, float), "number"),
        "bool": ((bool,), "bool"),
        "list": ((list, tuple), "list"),
        "dict": ((dict,), "dict"),
    }

    def __init__(
        self,
        registry: ToolRegistry,
        timeout: float = 30.0,
        require_destructive_confirm: bool = True,
    ):
        """Bind the executor to a registry.

        Args:
            registry: Catalog used to resolve tools and track usage.
            timeout: Stored as ``self.timeout``.
            require_destructive_confirm: When True, destructive tools are
                parked for confirmation instead of running immediately.
        """
        self.registry = registry
        self.timeout = timeout
        self.require_destructive_confirm = require_destructive_confirm
        # tool name -> parameters of the call awaiting confirmation. Keyed by
        # name, so a newer pending call replaces an older one for that tool.
        self._pending_confirmations: dict[str, dict[str, Any]] = {}

    def execute(
        self,
        tool_name: str,
        params: dict[str, Any] | None = None,
        force: bool = False,
    ) -> Any:
        """Execute a registered tool with safety checks.

        Checks run in this order: registration, deprecation, rate limit,
        destructive confirmation, parameter validation, handler presence.
        Usage is recorded only once all checks pass, just before the handler
        is invoked, so a failing handler still counts against the rate limit
        but a rejected call does not.

        Args:
            tool_name: Registered tool name
            params: Tool parameters, passed to the handler as keyword args
            force: Skip destructive confirmation (use with caution)

        Returns:
            Tool execution result

        Raises:
            ToolExecutionError: On any failed check or when the handler
                raises (the original exception is chained as the cause).
        """
        params = params or {}

        tool = self.registry.get(tool_name)
        if not tool:
            raise ToolExecutionError(tool_name, f"Tool '{tool_name}' not registered", recoverable=False)

        if tool.deprecated:
            raise ToolExecutionError(
                tool_name,
                f"Tool deprecated: {tool.deprecated_message}",
                recoverable=False,
            )

        # Rate limit check
        if not self.registry.check_rate_limit(tool_name):
            raise ToolExecutionError(
                tool_name,
                f"Rate limit exceeded ({tool.rate_limit}/min)",
                recoverable=True,
            )

        # Destructive check: park the call until it is confirmed.
        if tool.is_destructive and self.require_destructive_confirm and not force:
            self._pending_confirmations[tool_name] = params
            raise ToolExecutionError(
                tool_name,
                "Destructive operation requires confirmation (pass force=True to skip)",
                recoverable=False,
            )

        # Parameter validation
        self._validate_params(tool, params)

        # A tool without a handler cannot run, so it must not consume
        # rate-limit budget or inflate usage statistics.
        if not tool.handler:
            raise ToolExecutionError(
                tool_name,
                "No handler registered for tool",
                recoverable=False,
            )

        self.registry.record_usage(tool_name)

        try:
            return tool.handler(**params)
        except Exception as e:
            raise ToolExecutionError(
                tool_name,
                f"Execution failed: {str(e)}",
                recoverable=True,
            ) from e

    def confirm_destructive(self, tool_name: str) -> Any:
        """Confirm and execute a pending destructive operation.

        The pending entry is removed before execution, whether or not the
        execution then succeeds.

        Raises:
            ToolExecutionError: If nothing is pending for ``tool_name`` or
                the execution itself fails.
        """
        if tool_name not in self._pending_confirmations:
            raise ToolExecutionError(tool_name, "No pending confirmation", recoverable=False)
        params = self._pending_confirmations.pop(tool_name)
        return self.execute(tool_name, params, force=True)

    def cancel_destructive(self, tool_name: str) -> bool:
        """Cancel a pending destructive operation.

        Returns:
            True if a pending entry was discarded, False if none existed.
        """
        if tool_name in self._pending_confirmations:
            del self._pending_confirmations[tool_name]
            return True
        return False

    def _validate_params(self, tool: ToolSchema, params: dict[str, Any]) -> None:
        """Validate parameters against schema.

        Checks presence of required parameters, then for each supplied
        parameter its type tag, enum membership and numeric range. Unknown
        type tags are not checked. Parameters not declared in the schema are
        ignored here.

        Raises:
            ToolExecutionError: (non-recoverable) on the first violation.
        """
        # NOTE(review): ToolParam.pattern is declared as a regex validation
        # pattern but is not enforced here — confirm the intended matching
        # semantics (search vs. full match) before adding it.
        for p in tool.params:
            if p.name not in params:
                if p.required:
                    raise ToolExecutionError(
                        tool.name,
                        f"Missing required parameter: {p.name} ({p.description})",
                        recoverable=False,
                    )
                continue

            value = params[p.name]
            # bool is a subclass of int, so isinstance alone would let
            # True/False pass as numbers.
            is_bool = isinstance(value, bool)

            # Type check
            check = self._TYPE_CHECKS.get(p.type)
            if check is not None:
                accepted, label = check
                bool_as_number = is_bool and p.type in ("int", "float")
                if bool_as_number or not isinstance(value, accepted):
                    raise ToolExecutionError(
                        tool.name,
                        f"Parameter '{p.name}' must be {label}",
                        recoverable=False,
                    )

            # Enum check
            if p.enum_values and value not in p.enum_values:
                raise ToolExecutionError(
                    tool.name,
                    f"Parameter '{p.name}' must be one of: {p.enum_values}",
                    recoverable=False,
                )

            # Range check (real numbers only)
            if isinstance(value, (int, float)) and not is_bool:
                if p.min_value is not None and value < p.min_value:
                    raise ToolExecutionError(
                        tool.name,
                        f"Parameter '{p.name}' minimum is {p.min_value}",
                        recoverable=False,
                    )
                if p.max_value is not None and value > p.max_value:
                    raise ToolExecutionError(
                        tool.name,
                        f"Parameter '{p.name}' maximum is {p.max_value}",
                        recoverable=False,
                    )

    def get_pending_confirmations(self) -> list[str]:
        """List tools awaiting destructive confirmation."""
        return list(self._pending_confirmations.keys())

653 

654 

655# ── Utility helpers ─────────────────────────────────────────────── 

656 

def create_tool(
    name: str,
    description: str,
    handler: Callable,
    category: ToolCategory = ToolCategory.CUSTOM,
    params: list[ToolParam] | None = None,
    capabilities: list[str] | None = None,
    tags: list[str] | None = None,
    is_destructive: bool = False,
    rate_limit: int = 0,
    **kwargs,
) -> ToolSchema:
    """Build a ``ToolSchema`` without spelling out every field.

    Args:
        name: Tool name.
        description: Human-readable description.
        handler: Callable stored as the tool's implementation.
        category: Tool category; defaults to ``ToolCategory.CUSTOM``.
        params: Parameter definitions; None or empty becomes a new empty list.
        capabilities: Capability labels; None or empty becomes a new empty list.
        tags: Search tags; None or empty becomes a new empty list.
        is_destructive: Marks the tool as data-modifying.
        rate_limit: Forwarded to the schema's ``rate_limit`` field.
        **kwargs: Any further ``ToolSchema`` fields, forwarded unchanged.

    Returns:
        The newly constructed schema.
    """
    # Falsy collections (None or empty) are replaced by fresh lists so the
    # resulting schemas never share one list object.
    param_list = params if params else []
    capability_list = capabilities if capabilities else []
    tag_list = tags if tags else []

    return ToolSchema(
        name=name,
        description=description,
        category=category,
        params=param_list,
        capabilities=capability_list,
        tags=tag_list,
        handler=handler,
        is_destructive=is_destructive,
        rate_limit=rate_limit,
        **kwargs,
    )