Coverage for mcpgateway/services/tool_service.py: 91%
341 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-09 11:03 +0100
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-09 11:03 +0100
1# -*- coding: utf-8 -*-
2"""Tool Service Implementation.
4Copyright 2025
5SPDX-License-Identifier: Apache-2.0
6Authors: Mihai Criveti
8This module implements tool management and invocation according to the MCP specification.
9It handles:
10- Tool registration and validation
11- Tool invocation with schema validation
12- Tool federation across gateways
13- Event notifications for tool changes
14- Active/inactive tool management
15"""
17# Standard
18import asyncio
19import base64
20from datetime import datetime, timezone
21import json
22import logging
23import re
24import time
25from typing import Any, AsyncGenerator, Dict, List, Optional
27# Third-Party
28import httpx
29from mcp import ClientSession
30from mcp.client.sse import sse_client
31from mcp.client.streamable_http import streamablehttp_client
32from sqlalchemy import case, delete, func, literal, not_, select
33from sqlalchemy.exc import IntegrityError
34from sqlalchemy.orm import Session
36# First-Party
37from mcpgateway.config import settings
38from mcpgateway.db import Gateway as DbGateway
39from mcpgateway.db import server_tool_association
40from mcpgateway.db import Tool as DbTool
41from mcpgateway.db import ToolMetric
42from mcpgateway.models import TextContent, ToolResult
43from mcpgateway.schemas import (
44 ToolCreate,
45 ToolRead,
46 ToolUpdate,
47)
48from mcpgateway.utils.create_slug import slugify
49from mcpgateway.utils.services_auth import decode_auth
51# Local
52from ..config import extract_using_jq
54logger = logging.getLogger(__name__)
57class ToolError(Exception):
58 """Base class for tool-related errors."""
61class ToolNotFoundError(ToolError):
62 """Raised when a requested tool is not found."""
65class ToolNameConflictError(ToolError):
66 """Raised when a tool name conflicts with existing (active or inactive) tool."""
68 def __init__(self, name: str, enabled: bool = True, tool_id: Optional[int] = None):
69 """Initialize the error with tool information.
71 Args:
72 name: The conflicting tool name.
73 enabled: Whether the existing tool is enabled or not.
74 tool_id: ID of the existing tool if available.
75 """
76 self.name = name
77 self.enabled = enabled
78 self.tool_id = tool_id
79 message = f"Tool already exists with name: {name}"
80 if not enabled:
81 message += f" (currently inactive, ID: {tool_id})"
82 super().__init__(message)
85class ToolValidationError(ToolError):
86 """Raised when tool validation fails."""
89class ToolInvocationError(ToolError):
90 """Raised when tool invocation fails."""
93class ToolService:
94 """Service for managing and invoking tools.
96 Handles:
97 - Tool registration and deregistration.
98 - Tool invocation and validation.
99 - Tool federation.
100 - Event notifications.
101 - Active/inactive tool management.
102 """
104 def __init__(self):
105 """Initialize the tool service."""
106 self._event_subscribers: List[asyncio.Queue] = []
107 self._http_client = httpx.AsyncClient(timeout=settings.federation_timeout, verify=not settings.skip_ssl_verify)
109 async def initialize(self) -> None:
110 """Initialize the service."""
111 logger.info("Initializing tool service")
113 async def shutdown(self) -> None:
114 """Shutdown the service."""
115 await self._http_client.aclose()
116 logger.info("Tool service shutdown complete")
118 def _convert_tool_to_read(self, tool: DbTool) -> ToolRead:
119 """
120 Converts a DbTool instance into a ToolRead model, including aggregated metrics and
121 new API gateway fields: request_type and authentication credentials (masked).
123 Args:
124 tool (DbTool): The ORM instance of the tool.
126 Returns:
127 ToolRead: The Pydantic model representing the tool, including aggregated metrics and new fields.
128 """
129 tool_dict = tool.__dict__.copy()
130 tool_dict.pop("_sa_instance_state", None)
131 tool_dict["execution_count"] = tool.execution_count
132 tool_dict["metrics"] = tool.metrics_summary
133 tool_dict["request_type"] = tool.request_type
134 tool_dict["annotations"] = tool.annotations or {}
136 decoded_auth_value = decode_auth(tool.auth_value)
137 if tool.auth_type == "basic":
138 decoded_bytes = base64.b64decode(decoded_auth_value["Authorization"].split("Basic ")[1])
139 username, password = decoded_bytes.decode("utf-8").split(":")
140 tool_dict["auth"] = {
141 "auth_type": "basic",
142 "username": username,
143 "password": "********" if password else None,
144 }
145 elif tool.auth_type == "bearer":
146 tool_dict["auth"] = {
147 "auth_type": "bearer",
148 "token": "********" if decoded_auth_value["Authorization"] else None,
149 }
150 elif tool.auth_type == "authheaders":
151 tool_dict["auth"] = {
152 "auth_type": "authheaders",
153 "auth_header_key": next(iter(decoded_auth_value)),
154 "auth_header_value": "********" if decoded_auth_value[next(iter(decoded_auth_value))] else None,
155 }
156 else:
157 tool_dict["auth"] = None
159 tool_dict["name"] = tool.name
160 tool_dict["gateway_slug"] = tool.gateway_slug if tool.gateway_slug else ""
161 tool_dict["original_name_slug"] = tool.original_name_slug
163 return ToolRead.model_validate(tool_dict)
165 async def _record_tool_metric(self, db: Session, tool: DbTool, start_time: float, success: bool, error_message: Optional[str]) -> None:
166 """
167 Records a metric for a tool invocation.
169 This function calculates the response time using the provided start time and records
170 the metric details (including whether the invocation was successful and any error message)
171 into the database. The metric is then committed to the database.
173 Args:
174 db (Session): The SQLAlchemy database session.
175 tool (DbTool): The tool that was invoked.
176 start_time (float): The monotonic start time of the invocation.
177 success (bool): True if the invocation succeeded; otherwise, False.
178 error_message (Optional[str]): The error message if the invocation failed, otherwise None.
179 """
180 end_time = time.monotonic()
181 response_time = end_time - start_time
182 metric = ToolMetric(
183 tool_id=tool.id,
184 response_time=response_time,
185 is_success=success,
186 error_message=error_message,
187 )
188 db.add(metric)
189 db.commit()
191 async def register_tool(self, db: Session, tool: ToolCreate) -> ToolRead:
192 """Register a new tool.
194 Args:
195 db: Database session.
196 tool: Tool creation schema.
198 Returns:
199 Created tool information.
201 Raises:
202 ToolNameConflictError: If tool name already exists.
203 ToolError: For other tool registration errors.
204 """
205 try:
206 if not tool.gateway_id:
207 existing_tool = db.execute(select(DbTool).where(DbTool.name == tool.name)).scalar_one_or_none()
208 else:
209 existing_tool = db.execute(select(DbTool).where(DbTool.name == tool.name).where(DbTool.gateway_id == tool.gateway_id)).scalar_one_or_none()
210 if existing_tool:
211 raise ToolNameConflictError(
212 existing_tool.name,
213 enabled=existing_tool.enabled,
214 tool_id=existing_tool.id,
215 )
217 if tool.auth is None:
218 auth_type = None
219 auth_value = None
220 else:
221 auth_type = tool.auth.auth_type
222 auth_value = tool.auth.auth_value
224 db_tool = DbTool(
225 original_name=tool.name,
226 original_name_slug=slugify(tool.name),
227 url=str(tool.url),
228 description=tool.description,
229 integration_type=tool.integration_type,
230 request_type=tool.request_type,
231 headers=tool.headers,
232 input_schema=tool.input_schema,
233 annotations=tool.annotations,
234 jsonpath_filter=tool.jsonpath_filter,
235 auth_type=auth_type,
236 auth_value=auth_value,
237 gateway_id=tool.gateway_id,
238 )
239 db.add(db_tool)
240 db.commit()
241 db.refresh(db_tool)
242 await self._notify_tool_added(db_tool)
243 logger.info(f"Registered tool: {db_tool.name}")
244 return self._convert_tool_to_read(db_tool)
245 except IntegrityError:
246 db.rollback()
247 raise ToolError(f"Tool already exists: {tool.name}")
248 except Exception as e:
249 db.rollback()
250 raise ToolError(f"Failed to register tool: {str(e)}")
252 async def list_tools(self, db: Session, include_inactive: bool = False, cursor: Optional[str] = None) -> List[ToolRead]:
253 """
254 Retrieve a list of registered tools from the database.
256 Args:
257 db (Session): The SQLAlchemy database session.
258 include_inactive (bool): If True, include inactive tools in the result.
259 Defaults to False.
260 cursor (Optional[str], optional): An opaque cursor token for pagination. Currently,
261 this parameter is ignored. Defaults to None.
263 Returns:
264 List[ToolRead]: A list of registered tools represented as ToolRead objects.
265 """
266 query = select(DbTool)
267 cursor = None # Placeholder for pagination; ignore for now
268 logger.debug(f"Listing tools with include_inactive={include_inactive}, cursor={cursor}")
269 if not include_inactive:
270 query = query.where(DbTool.enabled)
271 tools = db.execute(query).scalars().all()
272 return [self._convert_tool_to_read(t) for t in tools]
274 async def list_server_tools(self, db: Session, server_id: str, include_inactive: bool = False, cursor: Optional[str] = None) -> List[ToolRead]:
275 """
276 Retrieve a list of registered tools from the database.
278 Args:
279 db (Session): The SQLAlchemy database session.
280 server_id (str): Server ID
281 include_inactive (bool): If True, include inactive tools in the result.
282 Defaults to False.
283 cursor (Optional[str], optional): An opaque cursor token for pagination. Currently,
284 this parameter is ignored. Defaults to None.
286 Returns:
287 List[ToolRead]: A list of registered tools represented as ToolRead objects.
288 """
289 query = select(DbTool).join(server_tool_association, DbTool.id == server_tool_association.c.tool_id).where(server_tool_association.c.server_id == server_id)
290 cursor = None # Placeholder for pagination; ignore for now
291 logger.debug(f"Listing server tools for server_id={server_id} with include_inactive={include_inactive}, cursor={cursor}")
292 if not include_inactive:
293 query = query.where(DbTool.enabled)
294 tools = db.execute(query).scalars().all()
295 return [self._convert_tool_to_read(t) for t in tools]
297 async def get_tool(self, db: Session, tool_id: str) -> ToolRead:
298 """Get a specific tool by ID.
300 Args:
301 db: Database session.
302 tool_id: Tool ID to retrieve.
304 Returns:
305 Tool information.
307 Raises:
308 ToolNotFoundError: If tool not found.
309 """
310 tool = db.get(DbTool, tool_id)
311 if not tool:
312 raise ToolNotFoundError(f"Tool not found: {tool_id}")
313 return self._convert_tool_to_read(tool)
315 async def delete_tool(self, db: Session, tool_id: str) -> None:
316 """Permanently delete a tool from the database.
318 Args:
319 db: Database session.
320 tool_id: Tool ID to delete.
322 Raises:
323 ToolNotFoundError: If tool not found.
324 ToolError: For other deletion errors.
325 """
326 try:
327 tool = db.get(DbTool, tool_id)
328 if not tool:
329 raise ToolNotFoundError(f"Tool not found: {tool_id}")
330 tool_info = {"id": tool.id, "name": tool.name}
331 db.delete(tool)
332 db.commit()
333 await self._notify_tool_deleted(tool_info)
334 logger.info(f"Permanently deleted tool: {tool_info['name']}")
335 except Exception as e:
336 db.rollback()
337 raise ToolError(f"Failed to delete tool: {str(e)}")
339 async def toggle_tool_status(self, db: Session, tool_id: str, activate: bool, reachable: bool) -> ToolRead:
340 """Toggle tool active status.
342 Args:
343 db: Database session.
344 tool_id: Tool ID to toggle.
345 activate: True to activate, False to deactivate.
346 reachable: True if the tool is reachable, False otherwise.
348 Returns:
349 Updated tool information.
351 Raises:
352 ToolNotFoundError: If tool not found.
353 ToolError: For other errors.
354 """
355 try:
356 tool = db.get(DbTool, tool_id)
357 if not tool:
358 raise ToolNotFoundError(f"Tool not found: {tool_id}")
360 is_activated = is_reachable = False
361 if tool.enabled != activate:
362 tool.enabled = activate
363 is_activated = True
365 if tool.reachable != reachable: 365 ↛ 366line 365 didn't jump to line 366 because the condition on line 365 was never true
366 tool.reachable = reachable
367 is_reachable = True
369 if is_activated or is_reachable:
370 tool.updated_at = datetime.now(timezone.utc)
372 db.commit()
373 db.refresh(tool)
374 if activate:
375 await self._notify_tool_activated(tool)
376 else:
377 await self._notify_tool_deactivated(tool)
378 logger.info(f"Tool: {tool.name} is {'enabled' if activate else 'disabled'}{' and accessible' if reachable else ' but inaccessible'}")
380 return self._convert_tool_to_read(tool)
381 except Exception as e:
382 db.rollback()
383 raise ToolError(f"Failed to toggle tool status: {str(e)}")
385 async def invoke_tool(self, db: Session, name: str, arguments: Dict[str, Any]) -> ToolResult:
386 """
387 Invoke a registered tool and record execution metrics.
389 Args:
390 db: Database session.
391 name: Name of tool to invoke.
392 arguments: Tool arguments.
394 Returns:
395 Tool invocation result.
397 Raises:
398 ToolNotFoundError: If tool not found.
399 ToolInvocationError: If invocation fails.
400 """
401 separator = literal(settings.gateway_tool_name_separator)
402 slug_expr = case(
403 (
404 DbTool.gateway_slug.is_(None), # pylint: disable=no-member
405 DbTool.original_name_slug,
406 ), # WHEN gateway_slug IS NULL
407 else_=DbTool.gateway_slug + separator + DbTool.original_name_slug, # ELSE gateway_slug||sep||original
408 )
409 tool = db.execute(select(DbTool).where(slug_expr == name).where(DbTool.enabled)).scalar_one_or_none()
410 if not tool:
411 inactive_tool = db.execute(select(DbTool).where(slug_expr == name).where(not_(DbTool.enabled))).scalar_one_or_none()
412 if inactive_tool:
413 raise ToolNotFoundError(f"Tool '{name}' exists but is inactive")
414 raise ToolNotFoundError(f"Tool not found: {name}")
416 # is_reachable = db.execute(select(DbTool.reachable).where(slug_expr == name)).scalar_one_or_none()
417 is_reachable = tool.reachable
419 if not is_reachable: 419 ↛ 420line 419 didn't jump to line 420 because the condition on line 419 was never true
420 raise ToolNotFoundError(f"Tool '{name}' exists but is currently offline. Please verify if it is running.")
422 start_time = time.monotonic()
423 success = False
424 error_message = None
425 try:
426 # tool.validate_arguments(arguments)
427 # Build headers with auth if necessary.
428 headers = tool.headers or {}
429 if tool.integration_type == "REST":
430 credentials = decode_auth(tool.auth_value)
431 headers.update(credentials)
433 # Build the payload based on integration type.
434 payload = arguments.copy()
436 # Handle URL path parameter substitution
437 final_url = tool.url
438 if "{" in tool.url and "}" in tool.url:
439 # Extract path parameters from URL template and arguments
440 url_params = re.findall(r"\{(\w+)\}", tool.url)
441 url_substitutions = {}
443 for param in url_params:
444 if param in payload:
445 url_substitutions[param] = payload.pop(param) # Remove from payload
446 final_url = final_url.replace(f"{{{param}}}", str(url_substitutions[param]))
447 else:
448 raise ToolInvocationError(f"Required URL parameter '{param}' not found in arguments")
450 # Use the tool's request_type rather than defaulting to POST.
451 method = tool.request_type.upper()
452 if method == "GET":
453 response = await self._http_client.get(final_url, params=payload, headers=headers)
454 else:
455 response = await self._http_client.request(method, final_url, json=payload, headers=headers)
456 response.raise_for_status()
458 # Handle 204 No Content responses that have no body
459 if response.status_code == 204:
460 tool_result = ToolResult(content=[TextContent(type="text", text="Request completed successfully (No Content)")])
461 elif response.status_code not in [200, 201, 202, 206]:
462 result = response.json()
463 tool_result = ToolResult(
464 content=[TextContent(type="text", text=str(result["error"]) if "error" in result else "Tool error encountered")],
465 is_error=True,
466 )
467 else:
468 result = response.json()
469 filtered_response = extract_using_jq(result, tool.jsonpath_filter)
470 tool_result = ToolResult(content=[TextContent(type="text", text=json.dumps(filtered_response, indent=2))])
472 success = True
473 elif tool.integration_type == "MCP":
474 transport = tool.request_type.lower()
475 gateway = db.execute(select(DbGateway).where(DbGateway.id == tool.gateway_id).where(DbGateway.enabled)).scalar_one_or_none()
476 headers = decode_auth(gateway.auth_value)
478 async def connect_to_sse_server(server_url: str) -> str:
479 """
480 Connect to an MCP server running with SSE transport
482 Args:
483 server_url (str): MCP Server SSE URL
485 Returns:
486 str: Result of tool call
487 """
488 # Use async with directly to manage the context
489 async with sse_client(url=server_url, headers=headers) as streams:
490 async with ClientSession(*streams) as session:
491 # Initialize the session
492 await session.initialize()
493 tool_call_result = await session.call_tool(tool.original_name, arguments)
494 return tool_call_result
496 async def connect_to_streamablehttp_server(server_url: str) -> str:
497 """
498 Connect to an MCP server running with Streamable HTTP transport
500 Args:
501 server_url (str): MCP Server URL
503 Returns:
504 str: Result of tool call
505 """
506 # Use async with directly to manage the context
507 async with streamablehttp_client(url=server_url, headers=headers) as (read_stream, write_stream, _get_session_id):
508 async with ClientSession(read_stream, write_stream) as session:
509 # Initialize the session
510 await session.initialize()
511 tool_call_result = await session.call_tool(tool.original_name, arguments)
512 return tool_call_result
514 tool_gateway_id = tool.gateway_id
515 tool_gateway = db.execute(select(DbGateway).where(DbGateway.id == tool_gateway_id).where(DbGateway.enabled)).scalar_one_or_none()
517 tool_call_result = ToolResult(content=[TextContent(text="", type="text")])
518 if transport == "sse":
519 tool_call_result = await connect_to_sse_server(tool_gateway.url)
520 elif transport == "streamablehttp":
521 tool_call_result = await connect_to_streamablehttp_server(tool_gateway.url)
522 content = tool_call_result.model_dump(by_alias=True).get("content", [])
524 success = True
525 filtered_response = extract_using_jq(content, tool.jsonpath_filter)
526 tool_result = ToolResult(content=filtered_response)
527 else:
528 return ToolResult(content=[TextContent(type="text", text="Invalid tool type")])
530 return tool_result
531 except Exception as e:
532 error_message = str(e)
533 raise ToolInvocationError(f"Tool invocation failed: {error_message}")
534 finally:
535 await self._record_tool_metric(db, tool, start_time, success, error_message)
537 async def update_tool(self, db: Session, tool_id: str, tool_update: ToolUpdate) -> ToolRead:
538 """Update an existing tool.
540 Args:
541 db: Database session.
542 tool_id: ID of tool to update.
543 tool_update: Updated tool data.
545 Returns:
546 Updated tool information.
548 Raises:
549 ToolNotFoundError: If tool not found.
550 ToolError: For other tool update errors.
551 ToolNameConflictError: If tool name conflict occurs
552 """
553 try:
554 tool = db.get(DbTool, tool_id)
555 if not tool:
556 raise ToolNotFoundError(f"Tool not found: {tool_id}")
557 if tool_update.name is not None and not (tool_update.name == tool.name and tool_update.gateway_id == tool.gateway_id):
558 existing_tool = db.execute(select(DbTool).where(DbTool.name == tool_update.name).where(DbTool.gateway_id == tool_update.gateway_id).where(DbTool.id != tool_id)).scalar_one_or_none()
559 if existing_tool:
560 raise ToolNameConflictError(
561 tool_update.name,
562 enabled=existing_tool.enabled,
563 tool_id=existing_tool.id,
564 )
566 if tool_update.name is not None:
567 tool.name = tool_update.name
568 if tool_update.url is not None:
569 tool.url = str(tool_update.url)
570 if tool_update.description is not None:
571 tool.description = tool_update.description
572 if tool_update.integration_type is not None:
573 tool.integration_type = tool_update.integration_type
574 if tool_update.request_type is not None:
575 tool.request_type = tool_update.request_type
576 if tool_update.headers is not None:
577 tool.headers = tool_update.headers
578 if tool_update.input_schema is not None:
579 tool.input_schema = tool_update.input_schema
580 if tool_update.annotations is not None:
581 tool.annotations = tool_update.annotations
582 if tool_update.jsonpath_filter is not None:
583 tool.jsonpath_filter = tool_update.jsonpath_filter
585 if tool_update.auth is not None:
586 if tool_update.auth.auth_type is not None:
587 tool.auth_type = tool_update.auth.auth_type
588 if tool_update.auth.auth_value is not None:
589 tool.auth_value = tool_update.auth.auth_value
590 else:
591 tool.auth_type = None
593 tool.updated_at = datetime.now(timezone.utc)
594 db.commit()
595 db.refresh(tool)
596 await self._notify_tool_updated(tool)
597 logger.info(f"Updated tool: {tool.name}")
598 return self._convert_tool_to_read(tool)
599 except Exception as e:
600 db.rollback()
601 raise ToolError(f"Failed to update tool: {str(e)}")
603 async def _notify_tool_updated(self, tool: DbTool) -> None:
604 """
605 Notify subscribers of tool update.
607 Args:
608 tool: Tool updated
609 """
610 event = {
611 "type": "tool_updated",
612 "data": {"id": tool.id, "name": tool.name, "url": tool.url, "description": tool.description, "enabled": tool.enabled},
613 "timestamp": datetime.now(timezone.utc).isoformat(),
614 }
615 await self._publish_event(event)
617 async def _notify_tool_activated(self, tool: DbTool) -> None:
618 """
619 Notify subscribers of tool activation.
621 Args:
622 tool: Tool activated
623 """
624 event = {
625 "type": "tool_activated",
626 "data": {"id": tool.id, "name": tool.name, "enabled": tool.enabled},
627 "timestamp": datetime.now(timezone.utc).isoformat(),
628 }
629 await self._publish_event(event)
631 async def _notify_tool_deactivated(self, tool: DbTool) -> None:
632 """
633 Notify subscribers of tool deactivation.
635 Args:
636 tool: Tool deactivated
637 """
638 event = {
639 "type": "tool_deactivated",
640 "data": {"id": tool.id, "name": tool.name, "enabled": tool.enabled},
641 "timestamp": datetime.now(timezone.utc).isoformat(),
642 }
643 await self._publish_event(event)
645 async def _notify_tool_deleted(self, tool_info: Dict[str, Any]) -> None:
646 """
647 Notify subscribers of tool deletion.
649 Args:
650 tool_info: Dictionary on tool deleted
651 """
652 event = {
653 "type": "tool_deleted",
654 "data": tool_info,
655 "timestamp": datetime.now(timezone.utc).isoformat(),
656 }
657 await self._publish_event(event)
659 async def subscribe_events(self) -> AsyncGenerator[Dict[str, Any], None]:
660 """Subscribe to tool events.
662 Yields:
663 Tool event messages.
664 """
665 queue: asyncio.Queue = asyncio.Queue()
666 self._event_subscribers.append(queue)
667 try:
668 while True:
669 event = await queue.get()
670 yield event
671 finally:
672 self._event_subscribers.remove(queue)
674 async def _notify_tool_added(self, tool: DbTool) -> None:
675 """
676 Notify subscribers of tool addition.
678 Args:
679 tool: Tool added
680 """
681 event = {
682 "type": "tool_added",
683 "data": {
684 "id": tool.id,
685 "name": tool.name,
686 "url": tool.url,
687 "description": tool.description,
688 "enabled": tool.enabled,
689 },
690 "timestamp": datetime.now(timezone.utc).isoformat(),
691 }
692 await self._publish_event(event)
694 async def _notify_tool_removed(self, tool: DbTool) -> None:
695 """
696 Notify subscribers of tool removal (soft delete/deactivation).
698 Args:
699 tool: Tool removed
700 """
701 event = {
702 "type": "tool_removed",
703 "data": {"id": tool.id, "name": tool.name, "enabled": tool.enabled},
704 "timestamp": datetime.now(timezone.utc).isoformat(),
705 }
706 await self._publish_event(event)
708 async def _publish_event(self, event: Dict[str, Any]) -> None:
709 """
710 Publish event to all subscribers.
712 Args:
713 event: Event to publish
714 """
715 for queue in self._event_subscribers:
716 await queue.put(event)
718 async def _validate_tool_url(self, url: str) -> None:
719 """Validate tool URL is accessible.
721 Args:
722 url: URL to validate.
724 Raises:
725 ToolValidationError: If URL validation fails.
726 """
727 try:
728 response = await self._http_client.get(url)
729 response.raise_for_status()
730 except Exception as e:
731 raise ToolValidationError(f"Failed to validate tool URL: {str(e)}")
733 async def _check_tool_health(self, tool: DbTool) -> bool:
734 """Check if tool endpoint is healthy.
736 Args:
737 tool: Tool to check.
739 Returns:
740 True if tool is healthy.
741 """
742 try:
743 response = await self._http_client.get(tool.url)
744 return response.is_success
745 except Exception:
746 return False
748 async def event_generator(self) -> AsyncGenerator[Dict[str, Any], None]:
749 """Generate tool events for SSE.
751 Yields:
752 Tool events.
753 """
754 queue: asyncio.Queue = asyncio.Queue()
755 self._event_subscribers.append(queue)
756 try:
757 while True:
758 event = await queue.get()
759 yield event
760 finally:
761 self._event_subscribers.remove(queue)
763 # --- Metrics ---
764 async def aggregate_metrics(self, db: Session) -> Dict[str, Any]:
765 """
766 Aggregate metrics for all tool invocations.
768 Args:
769 db: Database session
771 Returns:
772 A dictionary with keys:
773 - total_executions
774 - successful_executions
775 - failed_executions
776 - failure_rate
777 - min_response_time
778 - max_response_time
779 - avg_response_time
780 - last_execution_time
781 """
783 total = db.execute(select(func.count(ToolMetric.id))).scalar() or 0 # pylint: disable=not-callable
784 successful = db.execute(select(func.count(ToolMetric.id)).where(ToolMetric.is_success)).scalar() or 0 # pylint: disable=not-callable
785 failed = db.execute(select(func.count(ToolMetric.id)).where(not_(ToolMetric.is_success))).scalar() or 0 # pylint: disable=not-callable
786 failure_rate = failed / total if total > 0 else 0.0
787 min_rt = db.execute(select(func.min(ToolMetric.response_time))).scalar()
788 max_rt = db.execute(select(func.max(ToolMetric.response_time))).scalar()
789 avg_rt = db.execute(select(func.avg(ToolMetric.response_time))).scalar()
790 last_time = db.execute(select(func.max(ToolMetric.timestamp))).scalar()
792 return {
793 "total_executions": total,
794 "successful_executions": successful,
795 "failed_executions": failed,
796 "failure_rate": failure_rate,
797 "min_response_time": min_rt,
798 "max_response_time": max_rt,
799 "avg_response_time": avg_rt,
800 "last_execution_time": last_time,
801 }
803 async def reset_metrics(self, db: Session, tool_id: Optional[int] = None) -> None:
804 """
805 Reset metrics for tool invocations.
807 If tool_id is provided, only the metrics for that specific tool will be deleted.
808 Otherwise, all tool metrics will be deleted (global reset).
810 Args:
811 db (Session): The SQLAlchemy database session.
812 tool_id (Optional[int]): Specific tool ID to reset metrics for.
813 """
815 if tool_id:
816 db.execute(delete(ToolMetric).where(ToolMetric.tool_id == tool_id))
817 else:
818 db.execute(delete(ToolMetric))
819 db.commit()