Coverage for mcpgateway/services/tool_service.py: 91%

341 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-09 11:03 +0100

1# -*- coding: utf-8 -*- 

2"""Tool Service Implementation. 

3 

4Copyright 2025 

5SPDX-License-Identifier: Apache-2.0 

6Authors: Mihai Criveti 

7 

8This module implements tool management and invocation according to the MCP specification. 

9It handles: 

10- Tool registration and validation 

11- Tool invocation with schema validation 

12- Tool federation across gateways 

13- Event notifications for tool changes 

14- Active/inactive tool management 

15""" 

16 

17# Standard 

18import asyncio 

19import base64 

20from datetime import datetime, timezone 

21import json 

22import logging 

23import re 

24import time 

25from typing import Any, AsyncGenerator, Dict, List, Optional 

26 

27# Third-Party 

28import httpx 

29from mcp import ClientSession 

30from mcp.client.sse import sse_client 

31from mcp.client.streamable_http import streamablehttp_client 

32from sqlalchemy import case, delete, func, literal, not_, select 

33from sqlalchemy.exc import IntegrityError 

34from sqlalchemy.orm import Session 

35 

36# First-Party 

37from mcpgateway.config import settings 

38from mcpgateway.db import Gateway as DbGateway 

39from mcpgateway.db import server_tool_association 

40from mcpgateway.db import Tool as DbTool 

41from mcpgateway.db import ToolMetric 

42from mcpgateway.models import TextContent, ToolResult 

43from mcpgateway.schemas import ( 

44 ToolCreate, 

45 ToolRead, 

46 ToolUpdate, 

47) 

48from mcpgateway.utils.create_slug import slugify 

49from mcpgateway.utils.services_auth import decode_auth 

50 

51# Local 

52from ..config import extract_using_jq 

53 

# Module-level logger, named after this module, shared by everything below.
logger = logging.getLogger(name=__name__)

55 

56 

class ToolError(Exception):
    """Root exception for every tool-service failure.

    All other tool errors in this module derive from this class, so catching
    it handles any of them.
    """

59 

60 

class ToolNotFoundError(ToolError):
    """Signals that no tool matches the requested identifier or name."""

63 

64 

class ToolNameConflictError(ToolError):
    """Signals that a tool name is already taken by an existing tool.

    The existing tool may be active or inactive; ``enabled`` and ``tool_id``
    describe it.
    """

    def __init__(self, name: str, enabled: bool = True, tool_id: Optional[int] = None):
        """Record details about the clashing tool and build the error message.

        Args:
            name: The name that caused the conflict.
            enabled: False when the existing tool is inactive; the message then
                also reports its ID.
            tool_id: Identifier of the existing tool, if known.
        """
        self.name, self.enabled, self.tool_id = name, enabled, tool_id
        inactive_note = "" if enabled else f" (currently inactive, ID: {tool_id})"
        super().__init__(f"Tool already exists with name: {name}{inactive_note}")

83 

84 

class ToolValidationError(ToolError):
    """Signals that a tool (for example its URL) failed validation."""

87 

88 

class ToolInvocationError(ToolError):
    """Signals that calling a tool's backend did not complete successfully."""

91 

92 

class ToolService:
    """Service for managing and invoking tools.

    Handles:
    - Tool registration and deregistration.
    - Tool invocation and validation.
    - Tool federation.
    - Event notifications.
    - Active/inactive tool management.
    """

    def __init__(self):
        """Initialize the tool service.

        Creates the in-memory list of event subscriber queues and the shared
        HTTP client used for REST invocations and URL checks. The client is
        configured from ``settings.federation_timeout`` and
        ``settings.skip_ssl_verify``.
        """
        self._event_subscribers: List[asyncio.Queue] = []
        self._http_client = httpx.AsyncClient(timeout=settings.federation_timeout, verify=not settings.skip_ssl_verify)

    async def initialize(self) -> None:
        """Initialize the service (currently only logs)."""
        logger.info("Initializing tool service")

    async def shutdown(self) -> None:
        """Shutdown the service, closing the shared HTTP client."""
        await self._http_client.aclose()
        logger.info("Tool service shutdown complete")

    def _convert_tool_to_read(self, tool: DbTool) -> ToolRead:
        """Convert a DbTool instance into a ToolRead model.

        Copies the ORM attributes and adds aggregated metrics, ``request_type``,
        annotations and a masked view of the stored authentication credentials.
        Secrets (password, token, header value) are never exposed; only whether
        one is set.

        Args:
            tool (DbTool): The ORM instance of the tool.

        Returns:
            ToolRead: The Pydantic model representing the tool, including
            aggregated metrics and masked auth information.
        """
        tool_dict = tool.__dict__.copy()
        tool_dict.pop("_sa_instance_state", None)  # SQLAlchemy bookkeeping, not model data
        tool_dict["execution_count"] = tool.execution_count
        tool_dict["metrics"] = tool.metrics_summary
        tool_dict["request_type"] = tool.request_type
        tool_dict["annotations"] = tool.annotations or {}

        decoded_auth_value = decode_auth(tool.auth_value)
        if tool.auth_type == "basic":
            encoded_credentials = decoded_auth_value["Authorization"].split("Basic ")[1]
            # Split on the first colon only: RFC 7617 permits ':' inside the password.
            username, password = base64.b64decode(encoded_credentials).decode("utf-8").split(":", 1)
            tool_dict["auth"] = {
                "auth_type": "basic",
                "username": username,
                "password": "********" if password else None,
            }
        elif tool.auth_type == "bearer":
            tool_dict["auth"] = {
                "auth_type": "bearer",
                "token": "********" if decoded_auth_value["Authorization"] else None,
            }
        elif tool.auth_type == "authheaders":
            # Only the first stored header is surfaced; its value is masked.
            header_key = next(iter(decoded_auth_value))
            tool_dict["auth"] = {
                "auth_type": "authheaders",
                "auth_header_key": header_key,
                "auth_header_value": "********" if decoded_auth_value[header_key] else None,
            }
        else:
            tool_dict["auth"] = None

        tool_dict["name"] = tool.name
        tool_dict["gateway_slug"] = tool.gateway_slug if tool.gateway_slug else ""
        tool_dict["original_name_slug"] = tool.original_name_slug

        return ToolRead.model_validate(tool_dict)

    async def _record_tool_metric(self, db: Session, tool: DbTool, start_time: float, success: bool, error_message: Optional[str]) -> None:
        """Record a metric row for a tool invocation and commit it.

        Args:
            db (Session): The SQLAlchemy database session.
            tool (DbTool): The tool that was invoked.
            start_time (float): ``time.monotonic()`` value taken when the invocation started.
            success (bool): True if the invocation succeeded; otherwise, False.
            error_message (Optional[str]): The error message if the invocation failed, otherwise None.
        """
        response_time = time.monotonic() - start_time
        metric = ToolMetric(
            tool_id=tool.id,
            response_time=response_time,
            is_success=success,
            error_message=error_message,
        )
        db.add(metric)
        db.commit()

    async def register_tool(self, db: Session, tool: ToolCreate) -> ToolRead:
        """Register a new tool.

        Args:
            db: Database session.
            tool: Tool creation schema.

        Returns:
            Created tool information.

        Raises:
            ToolNameConflictError: If a tool with the same name (scoped to the
                gateway when ``tool.gateway_id`` is set) already exists.
            ToolError: For other tool registration errors.
        """
        try:
            query = select(DbTool).where(DbTool.name == tool.name)
            if tool.gateway_id:
                # Names only need to be unique within the same gateway.
                query = query.where(DbTool.gateway_id == tool.gateway_id)
            existing_tool = db.execute(query).scalar_one_or_none()
            if existing_tool:
                raise ToolNameConflictError(
                    existing_tool.name,
                    enabled=existing_tool.enabled,
                    tool_id=existing_tool.id,
                )

            auth_type = tool.auth.auth_type if tool.auth is not None else None
            auth_value = tool.auth.auth_value if tool.auth is not None else None

            db_tool = DbTool(
                original_name=tool.name,
                original_name_slug=slugify(tool.name),
                url=str(tool.url),
                description=tool.description,
                integration_type=tool.integration_type,
                request_type=tool.request_type,
                headers=tool.headers,
                input_schema=tool.input_schema,
                annotations=tool.annotations,
                jsonpath_filter=tool.jsonpath_filter,
                auth_type=auth_type,
                auth_value=auth_value,
                gateway_id=tool.gateway_id,
            )
            db.add(db_tool)
            db.commit()
            db.refresh(db_tool)
            await self._notify_tool_added(db_tool)
            logger.info(f"Registered tool: {db_tool.name}")
            return self._convert_tool_to_read(db_tool)
        except ToolNameConflictError:
            # Propagate unchanged so callers can tell a name conflict apart
            # from other registration failures (it is still a ToolError).
            db.rollback()
            raise
        except IntegrityError as ie:
            db.rollback()
            raise ToolError(f"Tool already exists: {tool.name}") from ie
        except Exception as e:
            db.rollback()
            raise ToolError(f"Failed to register tool: {str(e)}") from e

    async def list_tools(self, db: Session, include_inactive: bool = False, cursor: Optional[str] = None) -> List[ToolRead]:
        """Retrieve a list of registered tools from the database.

        Args:
            db (Session): The SQLAlchemy database session.
            include_inactive (bool): If True, include inactive tools in the result.
                Defaults to False.
            cursor (Optional[str], optional): An opaque cursor token for pagination.
                Currently it is only logged, not applied. Defaults to None.

        Returns:
            List[ToolRead]: A list of registered tools represented as ToolRead objects.
        """
        query = select(DbTool)
        logger.debug(f"Listing tools with include_inactive={include_inactive}, cursor={cursor}")
        if not include_inactive:
            query = query.where(DbTool.enabled)
        tools = db.execute(query).scalars().all()
        return [self._convert_tool_to_read(t) for t in tools]

    async def list_server_tools(self, db: Session, server_id: str, include_inactive: bool = False, cursor: Optional[str] = None) -> List[ToolRead]:
        """Retrieve the tools associated with a given server.

        Args:
            db (Session): The SQLAlchemy database session.
            server_id (str): Server ID whose tools are listed (via ``server_tool_association``).
            include_inactive (bool): If True, include inactive tools in the result.
                Defaults to False.
            cursor (Optional[str], optional): An opaque cursor token for pagination.
                Currently it is only logged, not applied. Defaults to None.

        Returns:
            List[ToolRead]: The server's tools represented as ToolRead objects.
        """
        query = select(DbTool).join(server_tool_association, DbTool.id == server_tool_association.c.tool_id).where(server_tool_association.c.server_id == server_id)
        logger.debug(f"Listing server tools for server_id={server_id} with include_inactive={include_inactive}, cursor={cursor}")
        if not include_inactive:
            query = query.where(DbTool.enabled)
        tools = db.execute(query).scalars().all()
        return [self._convert_tool_to_read(t) for t in tools]

    async def get_tool(self, db: Session, tool_id: str) -> ToolRead:
        """Get a specific tool by ID.

        Args:
            db: Database session.
            tool_id: Tool ID to retrieve.

        Returns:
            Tool information.

        Raises:
            ToolNotFoundError: If tool not found.
        """
        tool = db.get(DbTool, tool_id)
        if not tool:
            raise ToolNotFoundError(f"Tool not found: {tool_id}")
        return self._convert_tool_to_read(tool)

    async def delete_tool(self, db: Session, tool_id: str) -> None:
        """Permanently delete a tool from the database.

        Args:
            db: Database session.
            tool_id: Tool ID to delete.

        Raises:
            ToolNotFoundError: If tool not found.
            ToolError: For other deletion errors.
        """
        try:
            tool = db.get(DbTool, tool_id)
            if not tool:
                raise ToolNotFoundError(f"Tool not found: {tool_id}")
            # Capture identifying data before the row is deleted.
            tool_info = {"id": tool.id, "name": tool.name}
            db.delete(tool)
            db.commit()
            await self._notify_tool_deleted(tool_info)
            logger.info(f"Permanently deleted tool: {tool_info['name']}")
        except ToolNotFoundError:
            db.rollback()
            raise
        except Exception as e:
            db.rollback()
            raise ToolError(f"Failed to delete tool: {str(e)}") from e

    async def toggle_tool_status(self, db: Session, tool_id: str, activate: bool, reachable: bool) -> ToolRead:
        """Set a tool's enabled and reachable flags.

        ``updated_at`` is bumped only when one of the flags actually changes. An
        activated/deactivated event is published on every call, according to
        ``activate``.

        Args:
            db: Database session.
            tool_id: Tool ID to toggle.
            activate: True to activate, False to deactivate.
            reachable: True if the tool is reachable, False otherwise.

        Returns:
            Updated tool information.

        Raises:
            ToolNotFoundError: If tool not found.
            ToolError: For other errors.
        """
        try:
            tool = db.get(DbTool, tool_id)
            if not tool:
                raise ToolNotFoundError(f"Tool not found: {tool_id}")

            changed = False
            if tool.enabled != activate:
                tool.enabled = activate
                changed = True
            if tool.reachable != reachable:
                tool.reachable = reachable
                changed = True
            if changed:
                tool.updated_at = datetime.now(timezone.utc)

            db.commit()
            db.refresh(tool)
            if activate:
                await self._notify_tool_activated(tool)
            else:
                await self._notify_tool_deactivated(tool)
            logger.info(f"Tool: {tool.name} is {'enabled' if activate else 'disabled'}{' and accessible' if reachable else ' but inaccessible'}")

            return self._convert_tool_to_read(tool)
        except ToolNotFoundError:
            db.rollback()
            raise
        except Exception as e:
            db.rollback()
            raise ToolError(f"Failed to toggle tool status: {str(e)}") from e

    async def invoke_tool(self, db: Session, name: str, arguments: Dict[str, Any]) -> ToolResult:
        """Invoke a registered tool and record execution metrics.

        The tool is looked up by its exposed name: ``original_name_slug``, or
        ``gateway_slug + separator + original_name_slug`` when it belongs to a
        gateway. REST tools are called over HTTP; MCP tools are called through
        their gateway using SSE or Streamable HTTP. A metric row is recorded for
        every attempt that gets past the lookup and reachability checks.

        Args:
            db: Database session.
            name: Name of tool to invoke.
            arguments: Tool arguments.

        Returns:
            Tool invocation result.

        Raises:
            ToolNotFoundError: If the tool is not found, inactive, or marked unreachable.
            ToolInvocationError: If invocation fails.
        """
        separator = literal(settings.gateway_tool_name_separator)
        slug_expr = case(
            (
                DbTool.gateway_slug.is_(None),  # pylint: disable=no-member
                DbTool.original_name_slug,
            ),  # WHEN gateway_slug IS NULL
            else_=DbTool.gateway_slug + separator + DbTool.original_name_slug,  # ELSE gateway_slug||sep||original
        )
        tool = db.execute(select(DbTool).where(slug_expr == name).where(DbTool.enabled)).scalar_one_or_none()
        if not tool:
            inactive_tool = db.execute(select(DbTool).where(slug_expr == name).where(not_(DbTool.enabled))).scalar_one_or_none()
            if inactive_tool:
                raise ToolNotFoundError(f"Tool '{name}' exists but is inactive")
            raise ToolNotFoundError(f"Tool not found: {name}")

        if not tool.reachable:
            raise ToolNotFoundError(f"Tool '{name}' exists but is currently offline. Please verify if it is running.")

        start_time = time.monotonic()
        success = False
        error_message = None
        try:
            if tool.integration_type == "REST":
                tool_result = await self._invoke_rest_tool(tool, arguments)
            elif tool.integration_type == "MCP":
                tool_result = await self._invoke_mcp_tool(db, tool, arguments)
            else:
                return ToolResult(content=[TextContent(type="text", text="Invalid tool type")])
            # Only mark success once the full result (including filtering) is built.
            success = True
            return tool_result
        except Exception as e:
            error_message = str(e)
            raise ToolInvocationError(f"Tool invocation failed: {error_message}") from e
        finally:
            await self._record_tool_metric(db, tool, start_time, success, error_message)

    async def _invoke_rest_tool(self, tool: DbTool, arguments: Dict[str, Any]) -> ToolResult:
        """Call a REST-integrated tool over HTTP and wrap the response.

        ``{param}`` placeholders in ``tool.url`` are filled from ``arguments``;
        those arguments are removed from the request payload. GET sends the
        payload as query parameters; any other method sends it as a JSON body.

        Args:
            tool: The REST tool to call.
            arguments: Tool arguments (not mutated).

        Returns:
            ToolResult built from the HTTP response.

        Raises:
            ToolInvocationError: If a URL placeholder has no matching argument.
            httpx.HTTPStatusError: If ``raise_for_status`` rejects the response.
        """
        # Copy so credentials are not written into the ORM-loaded headers dict.
        headers = dict(tool.headers or {})
        headers.update(decode_auth(tool.auth_value))

        payload = arguments.copy()
        final_url = tool.url
        if "{" in tool.url and "}" in tool.url:
            # dict.fromkeys de-duplicates while keeping order, so a placeholder
            # used more than once is substituted once (replace() hits all copies).
            for param in dict.fromkeys(re.findall(r"\{(\w+)\}", tool.url)):
                if param not in payload:
                    raise ToolInvocationError(f"Required URL parameter '{param}' not found in arguments")
                final_url = final_url.replace(f"{{{param}}}", str(payload.pop(param)))

        method = tool.request_type.upper()
        if method == "GET":
            response = await self._http_client.get(final_url, params=payload, headers=headers)
        else:
            response = await self._http_client.request(method, final_url, json=payload, headers=headers)
        response.raise_for_status()

        # 204 No Content has no body to parse.
        if response.status_code == 204:
            return ToolResult(content=[TextContent(type="text", text="Request completed successfully (No Content)")])

        result = response.json()
        if response.status_code not in [200, 201, 202, 206]:
            return ToolResult(
                content=[TextContent(type="text", text=str(result["error"]) if "error" in result else "Tool error encountered")],
                is_error=True,
            )

        filtered_response = extract_using_jq(result, tool.jsonpath_filter)
        return ToolResult(content=[TextContent(type="text", text=json.dumps(filtered_response, indent=2))])

    async def _invoke_mcp_tool(self, db: Session, tool: DbTool, arguments: Dict[str, Any]) -> ToolResult:
        """Call an MCP tool through its (enabled) gateway.

        The transport is taken from ``tool.request_type`` (``sse`` or
        ``streamablehttp``, case-insensitive). Any other value yields an empty
        text result, which is then passed through the jq filter.

        Args:
            db: Database session used to load the tool's gateway.
            tool: The MCP tool to call.
            arguments: Tool arguments forwarded unchanged to the remote tool.

        Returns:
            ToolResult whose content is the jq-filtered remote result content.

        Raises:
            ToolInvocationError: If the tool's gateway does not exist or is disabled.
        """
        transport = tool.request_type.lower()
        gateway = db.execute(select(DbGateway).where(DbGateway.id == tool.gateway_id).where(DbGateway.enabled)).scalar_one_or_none()
        if gateway is None:
            raise ToolInvocationError(f"Gateway for tool '{tool.name}' not found or inactive")
        headers = decode_auth(gateway.auth_value)

        if transport == "sse":
            async with sse_client(url=gateway.url, headers=headers) as streams:
                async with ClientSession(*streams) as session:
                    await session.initialize()
                    tool_call_result = await session.call_tool(tool.original_name, arguments)
        elif transport == "streamablehttp":
            async with streamablehttp_client(url=gateway.url, headers=headers) as (read_stream, write_stream, _get_session_id):
                async with ClientSession(read_stream, write_stream) as session:
                    await session.initialize()
                    tool_call_result = await session.call_tool(tool.original_name, arguments)
        else:
            tool_call_result = ToolResult(content=[TextContent(text="", type="text")])

        content = tool_call_result.model_dump(by_alias=True).get("content", [])
        filtered_response = extract_using_jq(content, tool.jsonpath_filter)
        return ToolResult(content=filtered_response)

    async def update_tool(self, db: Session, tool_id: str, tool_update: ToolUpdate) -> ToolRead:
        """Update an existing tool.

        Only fields that are not None in ``tool_update`` are applied.

        Args:
            db: Database session.
            tool_id: ID of tool to update.
            tool_update: Updated tool data.

        Returns:
            Updated tool information.

        Raises:
            ToolNotFoundError: If tool not found.
            ToolNameConflictError: If tool name conflict occurs.
            ToolError: For other tool update errors.
        """
        try:
            tool = db.get(DbTool, tool_id)
            if not tool:
                raise ToolNotFoundError(f"Tool not found: {tool_id}")
            if tool_update.name is not None and not (tool_update.name == tool.name and tool_update.gateway_id == tool.gateway_id):
                existing_tool = db.execute(select(DbTool).where(DbTool.name == tool_update.name).where(DbTool.gateway_id == tool_update.gateway_id).where(DbTool.id != tool_id)).scalar_one_or_none()
                if existing_tool:
                    raise ToolNameConflictError(
                        tool_update.name,
                        enabled=existing_tool.enabled,
                        tool_id=existing_tool.id,
                    )

            if tool_update.name is not None:
                tool.name = tool_update.name
            if tool_update.url is not None:
                tool.url = str(tool_update.url)
            if tool_update.description is not None:
                tool.description = tool_update.description
            if tool_update.integration_type is not None:
                tool.integration_type = tool_update.integration_type
            if tool_update.request_type is not None:
                tool.request_type = tool_update.request_type
            if tool_update.headers is not None:
                tool.headers = tool_update.headers
            if tool_update.input_schema is not None:
                tool.input_schema = tool_update.input_schema
            if tool_update.annotations is not None:
                tool.annotations = tool_update.annotations
            if tool_update.jsonpath_filter is not None:
                tool.jsonpath_filter = tool_update.jsonpath_filter

            if tool_update.auth is not None:
                if tool_update.auth.auth_type is not None:
                    tool.auth_type = tool_update.auth.auth_type
                if tool_update.auth.auth_value is not None:
                    tool.auth_value = tool_update.auth.auth_value
            else:
                # NOTE(review): an update that omits ``auth`` clears auth_type but
                # keeps auth_value; confirm this is the intended semantics.
                tool.auth_type = None

            tool.updated_at = datetime.now(timezone.utc)
            db.commit()
            db.refresh(tool)
            await self._notify_tool_updated(tool)
            logger.info(f"Updated tool: {tool.name}")
            return self._convert_tool_to_read(tool)
        except (ToolNotFoundError, ToolNameConflictError):
            db.rollback()
            raise
        except Exception as e:
            db.rollback()
            raise ToolError(f"Failed to update tool: {str(e)}") from e

    async def _notify_tool_updated(self, tool: DbTool) -> None:
        """Publish a ``tool_updated`` event.

        Args:
            tool: Tool updated
        """
        event = {
            "type": "tool_updated",
            "data": {"id": tool.id, "name": tool.name, "url": tool.url, "description": tool.description, "enabled": tool.enabled},
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }
        await self._publish_event(event)

    async def _notify_tool_activated(self, tool: DbTool) -> None:
        """Publish a ``tool_activated`` event.

        Args:
            tool: Tool activated
        """
        event = {
            "type": "tool_activated",
            "data": {"id": tool.id, "name": tool.name, "enabled": tool.enabled},
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }
        await self._publish_event(event)

    async def _notify_tool_deactivated(self, tool: DbTool) -> None:
        """Publish a ``tool_deactivated`` event.

        Args:
            tool: Tool deactivated
        """
        event = {
            "type": "tool_deactivated",
            "data": {"id": tool.id, "name": tool.name, "enabled": tool.enabled},
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }
        await self._publish_event(event)

    async def _notify_tool_deleted(self, tool_info: Dict[str, Any]) -> None:
        """Publish a ``tool_deleted`` event.

        Args:
            tool_info: Dictionary describing the deleted tool (id and name).
        """
        event = {
            "type": "tool_deleted",
            "data": tool_info,
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }
        await self._publish_event(event)

    async def subscribe_events(self) -> AsyncGenerator[Dict[str, Any], None]:
        """Subscribe to tool events.

        Registers a fresh queue for this subscriber and removes it when the
        generator is closed.

        Yields:
            Tool event messages.
        """
        queue: asyncio.Queue = asyncio.Queue()
        self._event_subscribers.append(queue)
        try:
            while True:
                event = await queue.get()
                yield event
        finally:
            self._event_subscribers.remove(queue)

    async def _notify_tool_added(self, tool: DbTool) -> None:
        """Publish a ``tool_added`` event.

        Args:
            tool: Tool added
        """
        event = {
            "type": "tool_added",
            "data": {
                "id": tool.id,
                "name": tool.name,
                "url": tool.url,
                "description": tool.description,
                "enabled": tool.enabled,
            },
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }
        await self._publish_event(event)

    async def _notify_tool_removed(self, tool: DbTool) -> None:
        """Publish a ``tool_removed`` event (soft delete/deactivation).

        Args:
            tool: Tool removed
        """
        event = {
            "type": "tool_removed",
            "data": {"id": tool.id, "name": tool.name, "enabled": tool.enabled},
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }
        await self._publish_event(event)

    async def _publish_event(self, event: Dict[str, Any]) -> None:
        """Put an event on every subscriber queue.

        Args:
            event: Event to publish
        """
        for queue in self._event_subscribers:
            await queue.put(event)

    async def _validate_tool_url(self, url: str) -> None:
        """Validate that a tool URL answers a GET with a success status.

        Args:
            url: URL to validate.

        Raises:
            ToolValidationError: If URL validation fails.
        """
        try:
            response = await self._http_client.get(url)
            response.raise_for_status()
        except Exception as e:
            raise ToolValidationError(f"Failed to validate tool URL: {str(e)}") from e

    async def _check_tool_health(self, tool: DbTool) -> bool:
        """Check if the tool endpoint answers a GET with a success status.

        Args:
            tool: Tool to check.

        Returns:
            True if tool is healthy; False on a non-success status or any error.
        """
        try:
            response = await self._http_client.get(tool.url)
            return response.is_success
        except Exception:
            return False

    async def event_generator(self) -> AsyncGenerator[Dict[str, Any], None]:
        """Generate tool events for SSE.

        Same behavior as ``subscribe_events``. It is kept as a separate
        generator (rather than delegating) so queue removal runs
        deterministically when this generator is closed.

        Yields:
            Tool events.
        """
        queue: asyncio.Queue = asyncio.Queue()
        self._event_subscribers.append(queue)
        try:
            while True:
                event = await queue.get()
                yield event
        finally:
            self._event_subscribers.remove(queue)

    # --- Metrics ---
    async def aggregate_metrics(self, db: Session) -> Dict[str, Any]:
        """Aggregate metrics for all tool invocations.

        Args:
            db: Database session

        Returns:
            A dictionary with keys:
                - total_executions
                - successful_executions
                - failed_executions
                - failure_rate
                - min_response_time
                - max_response_time
                - avg_response_time
                - last_execution_time
        """
        total = db.execute(select(func.count(ToolMetric.id))).scalar() or 0  # pylint: disable=not-callable
        successful = db.execute(select(func.count(ToolMetric.id)).where(ToolMetric.is_success)).scalar() or 0  # pylint: disable=not-callable
        failed = db.execute(select(func.count(ToolMetric.id)).where(not_(ToolMetric.is_success))).scalar() or 0  # pylint: disable=not-callable
        failure_rate = failed / total if total > 0 else 0.0
        min_rt = db.execute(select(func.min(ToolMetric.response_time))).scalar()
        max_rt = db.execute(select(func.max(ToolMetric.response_time))).scalar()
        avg_rt = db.execute(select(func.avg(ToolMetric.response_time))).scalar()
        last_time = db.execute(select(func.max(ToolMetric.timestamp))).scalar()

        return {
            "total_executions": total,
            "successful_executions": successful,
            "failed_executions": failed,
            "failure_rate": failure_rate,
            "min_response_time": min_rt,
            "max_response_time": max_rt,
            "avg_response_time": avg_rt,
            "last_execution_time": last_time,
        }

    async def reset_metrics(self, db: Session, tool_id: Optional[int] = None) -> None:
        """Reset metrics for tool invocations.

        If tool_id is provided, only the metrics for that specific tool are
        deleted. Otherwise (None), all tool metrics are deleted (global reset).

        Args:
            db (Session): The SQLAlchemy database session.
            tool_id (Optional[int]): Specific tool ID to reset metrics for.
        """
        # Compare against None so a falsy id (0, "") never triggers a global wipe.
        if tool_id is not None:
            db.execute(delete(ToolMetric).where(ToolMetric.tool_id == tool_id))
        else:
            db.execute(delete(ToolMetric))
        db.commit()