Coverage for mcpgateway/services/server_service.py: 70%

234 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-09 11:03 +0100

1# -*- coding: utf-8 -*- 

2""" 

3 

4Copyright 2025 

5SPDX-License-Identifier: Apache-2.0 

6Authors: Mihai Criveti 

7 

8MCP Gateway Server Service 

9 

10This module implements server management for the MCP Servers Catalog. 

11It handles server registration, listing, retrieval, updates, activation toggling, and deletion. 

12It also publishes event notifications for server changes. 

13""" 

14 

15# Standard 

16import asyncio 

17from datetime import datetime, timezone 

18import logging 

19from typing import Any, AsyncGenerator, Dict, List, Optional 

20 

21# Third-Party 

22import httpx 

23from sqlalchemy import delete, func, not_, select 

24from sqlalchemy.exc import IntegrityError 

25from sqlalchemy.orm import Session 

26 

27# First-Party 

28from mcpgateway.config import settings 

29from mcpgateway.db import Prompt as DbPrompt 

30from mcpgateway.db import Resource as DbResource 

31from mcpgateway.db import Server as DbServer 

32from mcpgateway.db import ServerMetric 

33from mcpgateway.db import Tool as DbTool 

34from mcpgateway.schemas import ServerCreate, ServerMetrics, ServerRead, ServerUpdate 

35 

36logger = logging.getLogger(__name__) 

37 

38 

class ServerError(Exception):
    """Common ancestor of every error raised for server-related failures.

    The more specific server exceptions in this module derive from this
    class, so catching ``ServerError`` also catches them.
    """

41 

42 

class ServerNotFoundError(ServerError):
    """Signals that a lookup for a requested server found no record."""

45 

46 

class ServerNameConflictError(ServerError):
    """Signals that a server name is already taken by another record.

    The conflicting name, the activity flag of the existing record and its
    ID are kept on the exception as ``name``, ``is_active`` and ``server_id``.
    """

    def __init__(self, name: str, is_active: bool = True, server_id: Optional[int] = None):
        """Build the error message and record details of the conflict.

        Args:
            name: The server name that is already in use.
            is_active: Whether the existing server holding the name is active.
            server_id: Identifier of the existing server, if known.
        """
        # Only inactive holders get the extra hint carrying their ID.
        suffix = "" if is_active else f" (currently inactive, ID: {server_id})"
        super().__init__(f"Server already exists with name: {name}" + suffix)
        self.name = name
        self.is_active = is_active
        self.server_id = server_id

58 

59 

class ServerService:
    """Service for managing MCP Servers in the catalog.

    Provides methods to create, list, retrieve, update, toggle status, and delete server records.
    Also supports event notifications for changes in server data.

    Error contract: the mutating methods roll the session back on any failure.
    The specific ``ServerError`` subclasses listed in each method's ``Raises``
    section are re-raised unchanged; every other failure is wrapped in a plain
    ``ServerError`` that carries the original exception as its cause.
    """

    def __init__(self) -> None:
        """Create the subscriber registry and the HTTP client owned by the service."""
        # One asyncio.Queue per active subscribe_events() generator.
        self._event_subscribers: List[asyncio.Queue] = []
        self._http_client = httpx.AsyncClient(timeout=settings.federation_timeout, verify=not settings.skip_ssl_verify)

    async def initialize(self) -> None:
        """Initialize the server service (currently only logs)."""
        logger.info("Initializing server service")

    async def shutdown(self) -> None:
        """Shutdown the server service by closing the HTTP client."""
        await self._http_client.aclose()
        logger.info("Server service shutdown complete")

    def _convert_server_to_read(self, server: DbServer) -> ServerRead:
        """
        Converts a DbServer instance into a ServerRead model, including aggregated metrics.

        Args:
            server (DbServer): The ORM instance of the server.

        Returns:
            ServerRead: The Pydantic model representing the server, including aggregated metrics.
        """
        server_dict = server.__dict__.copy()
        server_dict.pop("_sa_instance_state", None)

        # Aggregate server.metrics once; counts default to 0 and the
        # time-based figures to None when there are no records.
        metrics = server.metrics if hasattr(server, "metrics") else []
        total = len(metrics)
        if total > 0:
            failed = sum(1 for m in metrics if not m.is_success)
            successful = total - failed
            failure_rate = failed / total
            response_times = [m.response_time for m in metrics]
            min_rt = min(response_times)
            max_rt = max(response_times)
            avg_rt = sum(response_times) / total
            last_time = max(m.timestamp for m in metrics)
        else:
            successful = 0
            failed = 0
            failure_rate = 0.0
            min_rt = max_rt = avg_rt = last_time = None

        server_dict["metrics"] = {
            "total_executions": total,
            "successful_executions": successful,
            "failed_executions": failed,
            "failure_rate": failure_rate,
            "min_response_time": min_rt,
            "max_response_time": max_rt,
            "avg_response_time": avg_rt,
            "last_execution_time": last_time,
        }
        # Tools are exposed by name, resources and prompts by ID.
        server_dict["associated_tools"] = [tool.name for tool in server.tools] if server.tools else []
        server_dict["associated_resources"] = [res.id for res in server.resources] if server.resources else []
        server_dict["associated_prompts"] = [prompt.id for prompt in server.prompts] if server.prompts else []
        return ServerRead.model_validate(server_dict)

    def _assemble_associated_items(
        self,
        tools: Optional[List[str]],
        resources: Optional[List[str]],
        prompts: Optional[List[str]],
    ) -> Dict[str, Any]:
        """
        Assemble the associated items dictionary from the separate fields.

        Args:
            tools: List of tool IDs.
            resources: List of resource IDs.
            prompts: List of prompt IDs.

        Returns:
            A dictionary with keys "tools", "resources", and "prompts"; a
            missing or empty argument yields an empty list for its key.
        """
        return {
            "tools": tools or [],
            "resources": resources or [],
            "prompts": prompts or [],
        }

    async def register_server(self, db: Session, server_in: ServerCreate) -> ServerRead:
        """
        Register a new server in the catalog and validate that all associated items exist.

        This function performs the following steps:
        1. Checks if a server with the same name already exists.
        2. Creates a new server record.
        3. For each ID provided in associated_tools, associated_resources, and associated_prompts,
           verifies that the corresponding item exists. If an item does not exist, an error is raised.
           Blank IDs are skipped.
        4. Associates the verified items to the new server.
        5. Commits the transaction, refreshes the ORM instance, and forces the loading of relationship data.
        6. Notifies subscribers of the addition and returns the validated response.

        Args:
            db (Session): The SQLAlchemy database session.
            server_in (ServerCreate): The server creation schema containing server details and lists of
                associated tool, resource, and prompt IDs (as strings).

        Returns:
            ServerRead: The newly created server, with associated item IDs.

        Raises:
            ServerNameConflictError: If a server with the same name already exists.
            ServerError: If any associated tool, resource, or prompt does not exist, or if any other
                registration error occurs.
        """
        try:
            # Check for an existing server with the same name.
            existing = db.execute(select(DbServer).where(DbServer.name == server_in.name)).scalar_one_or_none()
            if existing:
                raise ServerNameConflictError(server_in.name, is_active=existing.is_active, server_id=existing.id)

            # Create the new server record.
            db_server = DbServer(
                name=server_in.name,
                description=server_in.description,
                icon=server_in.icon,
                is_active=True,
            )
            db.add(db_server)

            # Associate tools, verifying each exists (tool IDs are looked up as given).
            if server_in.associated_tools:
                for tool_id in server_in.associated_tools:
                    if tool_id.strip() == "":
                        continue
                    tool_obj = db.get(DbTool, tool_id)
                    if not tool_obj:
                        raise ServerError(f"Tool with id {tool_id} does not exist.")
                    db_server.tools.append(tool_obj)

            # Associate resources, verifying each exists (IDs are converted to int).
            if server_in.associated_resources:
                for resource_id in server_in.associated_resources:
                    if resource_id.strip() == "":
                        continue
                    resource_obj = db.get(DbResource, int(resource_id))
                    if not resource_obj:
                        raise ServerError(f"Resource with id {resource_id} does not exist.")
                    db_server.resources.append(resource_obj)

            # Associate prompts, verifying each exists (IDs are converted to int).
            if server_in.associated_prompts:
                for prompt_id in server_in.associated_prompts:
                    if prompt_id.strip() == "":
                        continue
                    prompt_obj = db.get(DbPrompt, int(prompt_id))
                    if not prompt_obj:
                        raise ServerError(f"Prompt with id {prompt_id} does not exist.")
                    db_server.prompts.append(prompt_obj)

            # Commit the new record and refresh.
            db.commit()
            db.refresh(db_server)
            # Force load the relationship attributes.
            _ = db_server.tools, db_server.resources, db_server.prompts

            # Snapshot used only for debug logging.
            server_data = {
                "id": db_server.id,
                "name": db_server.name,
                "description": db_server.description,
                "icon": db_server.icon,
                "created_at": db_server.created_at,
                "updated_at": db_server.updated_at,
                "is_active": db_server.is_active,
                "associated_tools": [str(tool.id) for tool in db_server.tools],
                "associated_resources": [str(resource.id) for resource in db_server.resources],
                "associated_prompts": [str(prompt.id) for prompt in db_server.prompts],
            }
            logger.debug(f"Server Data: {server_data}")
            await self._notify_server_added(db_server)
            logger.info(f"Registered server: {server_in.name}")
            return self._convert_server_to_read(db_server)
        except ServerNameConflictError:
            # Documented error: let callers see the specific type instead of
            # a generic wrapper.
            db.rollback()
            raise
        except IntegrityError as exc:
            db.rollback()
            raise ServerError(f"Server already exists: {server_in.name}") from exc
        except Exception as exc:
            db.rollback()
            raise ServerError(f"Failed to register server: {str(exc)}") from exc

    async def list_servers(self, db: Session, include_inactive: bool = False) -> List[ServerRead]:
        """List all registered servers.

        Args:
            db: Database session.
            include_inactive: Whether to include inactive servers.

        Returns:
            A list of ServerRead objects.
        """
        query = select(DbServer)
        if not include_inactive:
            query = query.where(DbServer.is_active)
        servers = db.execute(query).scalars().all()
        return [self._convert_server_to_read(s) for s in servers]

    async def get_server(self, db: Session, server_id: str) -> ServerRead:
        """Retrieve server details by ID.

        Args:
            db: Database session.
            server_id: The unique identifier of the server.

        Returns:
            The corresponding ServerRead object.

        Raises:
            ServerNotFoundError: If no server with the given ID exists.
        """
        server = db.get(DbServer, server_id)
        if not server:
            raise ServerNotFoundError(f"Server not found: {server_id}")
        # Snapshot used only for debug logging.
        server_data = {
            "id": server.id,
            "name": server.name,
            "description": server.description,
            "icon": server.icon,
            "created_at": server.created_at,
            "updated_at": server.updated_at,
            "is_active": server.is_active,
            "associated_tools": [tool.name for tool in server.tools],
            "associated_resources": [res.id for res in server.resources],
            "associated_prompts": [prompt.id for prompt in server.prompts],
        }
        logger.debug(f"Server Data: {server_data}")
        return self._convert_server_to_read(server)

    async def update_server(self, db: Session, server_id: str, server_update: ServerUpdate) -> ServerRead:
        """Update an existing server.

        Only fields that are not None on ``server_update`` are applied. When an
        association list is provided it replaces the current one; IDs that do
        not resolve to an existing item are silently skipped.

        Args:
            db: Database session.
            server_id: The unique identifier of the server.
            server_update: Server update schema with new data.

        Returns:
            The updated ServerRead object.

        Raises:
            ServerNotFoundError: If the server is not found.
            ServerNameConflictError: If a new name conflicts with an existing server.
            ServerError: For other update errors.
        """
        try:
            server = db.get(DbServer, server_id)
            if not server:
                raise ServerNotFoundError(f"Server not found: {server_id}")

            # Check for name conflict if name is being changed
            if server_update.name and server_update.name != server.name:
                conflict = db.execute(select(DbServer).where(DbServer.name == server_update.name).where(DbServer.id != server_id)).scalar_one_or_none()
                if conflict:
                    raise ServerNameConflictError(
                        server_update.name,
                        is_active=conflict.is_active,
                        server_id=conflict.id,
                    )

            # Update simple fields
            if server_update.name is not None:
                server.name = server_update.name
            if server_update.description is not None:
                server.description = server_update.description
            if server_update.icon is not None:
                server.icon = server_update.icon

            # Update associated tools if provided
            if server_update.associated_tools is not None:
                server.tools = []
                for tool_id in server_update.associated_tools:
                    tool_obj = db.get(DbTool, tool_id)
                    if tool_obj:
                        server.tools.append(tool_obj)

            # Update associated resources if provided
            if server_update.associated_resources is not None:
                server.resources = []
                for resource_id in server_update.associated_resources:
                    resource_obj = db.get(DbResource, int(resource_id))
                    if resource_obj:
                        server.resources.append(resource_obj)

            # Update associated prompts if provided
            if server_update.associated_prompts is not None:
                server.prompts = []
                for prompt_id in server_update.associated_prompts:
                    prompt_obj = db.get(DbPrompt, int(prompt_id))
                    if prompt_obj:
                        server.prompts.append(prompt_obj)

            server.updated_at = datetime.now(timezone.utc)
            db.commit()
            db.refresh(server)
            # Force loading relationships
            _ = server.tools, server.resources, server.prompts

            await self._notify_server_updated(server)
            logger.info(f"Updated server: {server.name}")

            # Snapshot used only for debug logging.
            server_data = {
                "id": server.id,
                "name": server.name,
                "description": server.description,
                "icon": server.icon,
                "created_at": server.created_at,
                "updated_at": server.updated_at,
                "is_active": server.is_active,
                "associated_tools": [tool.id for tool in server.tools],
                "associated_resources": [res.id for res in server.resources],
                "associated_prompts": [prompt.id for prompt in server.prompts],
            }
            logger.debug(f"Server Data: {server_data}")
            return self._convert_server_to_read(server)
        except (ServerNotFoundError, ServerNameConflictError):
            # Documented errors: propagate the specific type unchanged.
            db.rollback()
            raise
        except Exception as exc:
            db.rollback()
            raise ServerError(f"Failed to update server: {str(exc)}") from exc

    async def toggle_server_status(self, db: Session, server_id: str, activate: bool) -> ServerRead:
        """Toggle the activation status of a server.

        Nothing is written and no event is published when the server is
        already in the requested state.

        Args:
            db: Database session.
            server_id: The unique identifier of the server.
            activate: True to activate, False to deactivate.

        Returns:
            The updated ServerRead object.

        Raises:
            ServerNotFoundError: If the server is not found.
            ServerError: For other errors.
        """
        try:
            server = db.get(DbServer, server_id)
            if not server:
                raise ServerNotFoundError(f"Server not found: {server_id}")

            if server.is_active != activate:
                server.is_active = activate
                server.updated_at = datetime.now(timezone.utc)
                db.commit()
                db.refresh(server)
                if activate:
                    await self._notify_server_activated(server)
                else:
                    await self._notify_server_deactivated(server)
                logger.info(f"Server {server.name} {'activated' if activate else 'deactivated'}")

            # Snapshot used only for debug logging.
            server_data = {
                "id": server.id,
                "name": server.name,
                "description": server.description,
                "icon": server.icon,
                "created_at": server.created_at,
                "updated_at": server.updated_at,
                "is_active": server.is_active,
                "associated_tools": [tool.id for tool in server.tools],
                "associated_resources": [res.id for res in server.resources],
                "associated_prompts": [prompt.id for prompt in server.prompts],
            }
            logger.debug(f"Server Data: {server_data}")
            return self._convert_server_to_read(server)
        except ServerNotFoundError:
            # Documented error: propagate the specific type unchanged.
            db.rollback()
            raise
        except Exception as exc:
            db.rollback()
            raise ServerError(f"Failed to toggle server status: {str(exc)}") from exc

    async def delete_server(self, db: Session, server_id: str) -> None:
        """Permanently delete a server.

        Args:
            db: Database session.
            server_id: The unique identifier of the server.

        Raises:
            ServerNotFoundError: If the server is not found.
            ServerError: For other deletion errors.
        """
        try:
            server = db.get(DbServer, server_id)
            if not server:
                raise ServerNotFoundError(f"Server not found: {server_id}")

            # Capture identifying fields before the row is deleted so they
            # can still be reported in the event and the log.
            server_info = {"id": server.id, "name": server.name}
            db.delete(server)
            db.commit()

            await self._notify_server_deleted(server_info)
            logger.info(f"Deleted server: {server_info['name']}")
        except ServerNotFoundError:
            # Documented error: propagate the specific type unchanged.
            db.rollback()
            raise
        except Exception as exc:
            db.rollback()
            raise ServerError(f"Failed to delete server: {str(exc)}") from exc

    async def _publish_event(self, event: Dict[str, Any]) -> None:
        """
        Publish an event to all subscribed queues.

        Args:
            event: Event to publish
        """
        for queue in self._event_subscribers:
            await queue.put(event)

    async def subscribe_events(self) -> AsyncGenerator[Dict[str, Any], None]:
        """Subscribe to server events.

        A dedicated queue is registered for the lifetime of the generator and
        removed again when the generator is closed.

        Yields:
            Server event messages.
        """
        queue: asyncio.Queue = asyncio.Queue()
        self._event_subscribers.append(queue)
        try:
            while True:
                event = await queue.get()
                yield event
        finally:
            self._event_subscribers.remove(queue)

    @staticmethod
    def _make_event(event_type: str, data: Dict[str, Any]) -> Dict[str, Any]:
        """Wrap ``data`` in the event envelope: type, data and a UTC ISO timestamp."""
        return {
            "type": event_type,
            "data": data,
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }

    @staticmethod
    def _server_event_data(server: DbServer) -> Dict[str, Any]:
        """Build the full server payload shared by the added/updated events."""
        return {
            "id": server.id,
            "name": server.name,
            "description": server.description,
            "icon": server.icon,
            "associated_tools": [tool.id for tool in server.tools] if server.tools else [],
            "associated_resources": [res.id for res in server.resources] if server.resources else [],
            "associated_prompts": [prompt.id for prompt in server.prompts] if server.prompts else [],
            "is_active": server.is_active,
        }

    async def _notify_server_added(self, server: DbServer) -> None:
        """
        Notify subscribers that a new server has been added.

        Args:
            server: Server to add
        """
        await self._publish_event(self._make_event("server_added", self._server_event_data(server)))

    async def _notify_server_updated(self, server: DbServer) -> None:
        """
        Notify subscribers that a server has been updated.

        Args:
            server: Server to update
        """
        await self._publish_event(self._make_event("server_updated", self._server_event_data(server)))

    async def _notify_server_activated(self, server: DbServer) -> None:
        """
        Notify subscribers that a server has been activated.

        Args:
            server: Server to activate
        """
        data = {"id": server.id, "name": server.name, "is_active": True}
        await self._publish_event(self._make_event("server_activated", data))

    async def _notify_server_deactivated(self, server: DbServer) -> None:
        """
        Notify subscribers that a server has been deactivated.

        Args:
            server: Server to deactivate
        """
        data = {"id": server.id, "name": server.name, "is_active": False}
        await self._publish_event(self._make_event("server_deactivated", data))

    async def _notify_server_deleted(self, server_info: Dict[str, Any]) -> None:
        """
        Notify subscribers that a server has been deleted.

        Args:
            server_info: Dictionary on server to be deleted
        """
        await self._publish_event(self._make_event("server_deleted", server_info))

    # --- Metrics ---
    async def aggregate_metrics(self, db: Session) -> ServerMetrics:
        """
        Aggregate metrics for all server invocations across all servers.

        Args:
            db: Database session

        Returns:
            ServerMetrics: Aggregated metrics computed from all ServerMetric records.
        """
        total_executions = db.execute(select(func.count()).select_from(ServerMetric)).scalar() or 0  # pylint: disable=not-callable

        successful_executions = db.execute(select(func.count()).select_from(ServerMetric).where(ServerMetric.is_success)).scalar() or 0  # pylint: disable=not-callable

        failed_executions = db.execute(select(func.count()).select_from(ServerMetric).where(not_(ServerMetric.is_success))).scalar() or 0  # pylint: disable=not-callable

        min_response_time = db.execute(select(func.min(ServerMetric.response_time))).scalar()

        max_response_time = db.execute(select(func.max(ServerMetric.response_time))).scalar()

        avg_response_time = db.execute(select(func.avg(ServerMetric.response_time))).scalar()

        last_execution_time = db.execute(select(func.max(ServerMetric.timestamp))).scalar()

        return ServerMetrics(
            total_executions=total_executions,
            successful_executions=successful_executions,
            failed_executions=failed_executions,
            # Guard the division: with no records the rate is reported as 0.0.
            failure_rate=(failed_executions / total_executions) if total_executions > 0 else 0.0,
            min_response_time=min_response_time,
            max_response_time=max_response_time,
            avg_response_time=avg_response_time,
            last_execution_time=last_execution_time,
        )

    async def reset_metrics(self, db: Session) -> None:
        """
        Reset all server metrics by deleting all records from the server metrics table.

        Args:
            db: Database session
        """
        db.execute(delete(ServerMetric))
        db.commit()