Coverage for mcpgateway/federation/discovery.py: 44%
176 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-09 11:03 +0100
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-09 11:03 +0100
1# -*- coding: utf-8 -*-
2"""Federation Discovery Service.
4Copyright 2025
5SPDX-License-Identifier: Apache-2.0
6Authors: Mihai Criveti
8This module implements automatic peer discovery for MCP Gateways.
9It supports multiple discovery mechanisms:
10- DNS-SD service discovery
11- Static peer lists
12- Peer exchange protocol
13- Manual registration
14"""
16# Standard
17import asyncio
18from dataclasses import dataclass
19from datetime import datetime, timedelta, timezone
20import logging
21import os
22import socket
23from typing import Dict, List, Optional
24from urllib.parse import urlparse
26# Third-Party
27import httpx
28from zeroconf import ServiceInfo, ServiceStateChange
29from zeroconf.asyncio import AsyncServiceBrowser, AsyncZeroconf
31# First-Party
32from mcpgateway.config import settings
33from mcpgateway.models import ServerCapabilities
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Protocol version this gateway advertises over DNS-SD and sends in its
# "initialize" requests. The PROTOCOL_VERSION environment variable overrides
# the built-in default.
PROTOCOL_VERSION = os.environ.get("PROTOCOL_VERSION", "2025-03-26")
@dataclass
class DiscoveredPeer:
    """Record kept for every peer gateway the discovery service knows about.

    Attributes:
        url: Peer URL; also the key under which the peer is stored and the
            base for the ``/initialize`` and ``/peers`` requests.
        name: Optional display name supplied by the discovery source.
        protocol_version: Protocol version recorded for the peer, if any.
        capabilities: Capabilities returned by the peer's ``initialize``
            response, if any.
        discovered_at: Timestamp taken when the peer was first added.
        last_seen: Timestamp of the most recent successful contact or
            re-announcement; used to expire stale peers.
        source: How the peer was found (this module uses ``"static"``,
            ``"dns-sd"`` and ``"exchange"``).
    """

    # Field order is part of the positional constructor; do not reorder.
    url: str
    name: Optional[str]
    protocol_version: Optional[str]
    capabilities: Optional[ServerCapabilities]
    discovered_at: datetime
    last_seen: datetime
    source: str
class LocalDiscoveryService:
    """Base class that prepares this gateway's DNS-SD (zeroconf) advertisement.

    It only builds the ``ServiceInfo`` describing the local gateway; the
    subclass is responsible for registering it and browsing for peers.
    """

    def __init__(self):
        """Build the ``ServiceInfo`` advertised for the local gateway."""
        # Service info for local discovery
        self._service_type = "_mcp._tcp.local."
        self._service_info = ServiceInfo(
            self._service_type,
            f"{settings.app_name}.{self._service_type}",
            # inet_aton only understands dotted-quad IPv4 strings, which is
            # exactly what _get_local_addresses() is restricted to.
            addresses=[socket.inet_aton(addr) for addr in self._get_local_addresses()],
            port=settings.port,
            properties={
                "name": settings.app_name,
                "version": "1.0.0",
                "protocol": PROTOCOL_VERSION,
            },
        )

    def _get_local_addresses(self) -> List[str]:
        """Get the list of local, non-loopback IPv4 addresses.

        The host name is resolved for ``AF_INET`` only: the result is fed to
        ``socket.inet_aton`` by ``__init__``, which raises on IPv6 text such as
        ``::1``. ``getaddrinfo`` yields one entry per socket type, so repeated
        addresses are collapsed while keeping first-seen order.

        Returns:
            List of unique IPv4 address strings. Falls back to
            ``["127.0.0.1"]`` when nothing usable is found; ``"127.0.0.1"`` is
            also appended when resolution fails part-way through.
        """
        addresses: List[str] = []
        try:
            # Get all IPv4 addresses the host name resolves to
            for entry in socket.getaddrinfo(socket.gethostname(), None, family=socket.AF_INET):
                addr = entry[4][0]
                # Skip localhost and addresses already collected
                if addr.startswith("127.") or addr in addresses:
                    continue
                addresses.append(addr)
        except Exception as e:
            logger.warning(f"Failed to get local addresses: {e}")
            # Fall back to localhost
            addresses.append("127.0.0.1")

        return addresses or ["127.0.0.1"]
class DiscoveryService(LocalDiscoveryService):
    """Service for automatic gateway discovery.

    Supports multiple discovery mechanisms:
    - DNS-SD for local network discovery
    - Static peer lists from configuration
    - Peer exchange with known gateways
    - Manual registration via API

    Peers are stored in a dict keyed by URL. Two background tasks maintain it:
    one drops peers that have not been seen recently, the other re-queries
    every peer and asks each one for its own peer list.
    """

    # A peer whose ``last_seen`` is older than this is removed by the cleanup loop.
    _STALE_AFTER = timedelta(minutes=10)
    # Seconds to sleep between two cleanup passes.
    _CLEANUP_INTERVAL = 60
    # Seconds to sleep between two refresh/exchange passes (5 minutes).
    _REFRESH_INTERVAL = 300

    def __init__(self):
        """Initialize discovery service."""
        super().__init__()

        self._zeroconf: Optional[AsyncZeroconf] = None
        self._browser: Optional[AsyncServiceBrowser] = None
        self._http_client = httpx.AsyncClient(timeout=settings.federation_timeout, verify=not settings.skip_ssl_verify)

        # Track discovered peers, keyed by URL
        self._discovered_peers: Dict[str, DiscoveredPeer] = {}

        # Background tasks, created by start()
        self._cleanup_task: Optional[asyncio.Task] = None
        self._refresh_task: Optional[asyncio.Task] = None

        # Strong references to in-flight DNS-SD handler tasks: the event loop
        # only keeps weak references, so an unreferenced task may be collected.
        self._pending_tasks: set = set()

    async def start(self) -> None:
        """
        Start discovery service.

        Registers the local service and starts browsing when
        ``settings.federation_discovery`` is enabled, launches the cleanup and
        refresh loops, then adds every URL in ``settings.federation_peers``.
        On failure the service is stopped again before the error propagates.

        Raises:
            Exception: If unable to start discovery service
        """
        try:
            # Initialize DNS-SD
            if settings.federation_discovery:
                self._zeroconf = AsyncZeroconf()
                await self._zeroconf.async_register_service(self._service_info)
                self._browser = AsyncServiceBrowser(
                    self._zeroconf.zeroconf,
                    self._service_type,
                    # zeroconf invokes handlers synchronously; registering the
                    # coroutine function directly would only create coroutine
                    # objects that are never awaited.
                    handlers=[self._schedule_service_state_change],
                )

            # Start background tasks
            self._cleanup_task = asyncio.create_task(self._cleanup_loop())
            self._refresh_task = asyncio.create_task(self._refresh_loop())

            # Load static peers
            for peer_url in settings.federation_peers:
                await self.add_peer(peer_url, source="static")

            logger.info("Discovery service started")

        except Exception as e:
            logger.error(f"Failed to start discovery service: {e}")
            await self.stop()
            raise

    async def _cancel_and_wait(self, task: Optional[asyncio.Task]) -> None:
        """Cancel a background task and wait until it has finished.

        Args:
            task: Task to cancel; ``None`` is ignored
        """
        if not task:
            return
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass

    async def stop(self) -> None:
        """Stop discovery service.

        Cancels the background and DNS-SD handler tasks, shuts down DNS-SD and
        closes the HTTP client. References are cleared before each teardown
        step, and the HTTP client and zeroconf instance are closed even when
        an earlier step raises.
        """
        try:
            # Cancel background tasks
            cleanup_task, self._cleanup_task = self._cleanup_task, None
            refresh_task, self._refresh_task = self._refresh_task, None
            await self._cancel_and_wait(cleanup_task)
            await self._cancel_and_wait(refresh_task)

            # Cancel DNS-SD handler tasks that are still running
            pending = list(self._pending_tasks)
            for task in pending:
                task.cancel()
            if pending:
                await asyncio.gather(*pending, return_exceptions=True)

            # Stop DNS-SD
            if self._browser:
                browser, self._browser = self._browser, None
                await browser.async_cancel()

            if self._zeroconf:
                zc, self._zeroconf = self._zeroconf, None
                try:
                    await zc.async_unregister_service(self._service_info)
                finally:
                    await zc.async_close()
        finally:
            # Close HTTP client
            await self._http_client.aclose()

        logger.info("Discovery service stopped")

    async def add_peer(self, url: str, source: str, name: Optional[str] = None) -> bool:
        """Add a new peer gateway.

        The URL must have a scheme and a network location. A URL that is
        already known only has its ``last_seen`` timestamp updated. Otherwise
        the peer is queried and stored only if that query succeeds.

        Args:
            url: Gateway URL
            source: Discovery source
            name: Optional gateway name

        Returns:
            True if peer was added; False if the URL is invalid, already
            known, or the peer could not be queried
        """
        # Validate URL
        try:
            parsed = urlparse(url)
            if not parsed.scheme or not parsed.netloc:
                logger.warning(f"Invalid peer URL: {url}")
                return False
        except Exception:
            logger.warning(f"Failed to parse peer URL: {url}")
            return False

        # Skip if already known
        known = self._discovered_peers.get(url)
        if known is not None:
            known.last_seen = datetime.now(timezone.utc)
            return False

        try:
            # Try to get gateway info
            capabilities = await self._get_gateway_info(url)

            # Add to discovered peers
            now = datetime.now(timezone.utc)
            self._discovered_peers[url] = DiscoveredPeer(
                url=url,
                name=name,
                protocol_version=PROTOCOL_VERSION,
                capabilities=capabilities,
                discovered_at=now,
                last_seen=now,
                source=source,
            )

            logger.info(f"Added peer gateway: {url} (via {source})")
            return True

        except Exception as e:
            logger.warning(f"Failed to add peer {url}: {e}")
            return False

    def get_discovered_peers(self) -> List[DiscoveredPeer]:
        """Get list of discovered peers.

        Returns:
            List of discovered peer information (a new list on every call)
        """
        return list(self._discovered_peers.values())

    async def refresh_peer(self, url: str) -> bool:
        """Refresh peer gateway information.

        Args:
            url: Gateway URL to refresh

        Returns:
            True if refresh succeeded; False if the peer is unknown, was
            removed while being queried, or the query failed
        """
        if url not in self._discovered_peers:
            return False

        try:
            capabilities = await self._get_gateway_info(url)
        except Exception as e:
            logger.warning(f"Failed to refresh peer {url}: {e}")
            return False

        # The peer may have been removed while the request was in flight.
        peer = self._discovered_peers.get(url)
        if peer is None:
            return False

        peer.capabilities = capabilities
        peer.last_seen = datetime.now(timezone.utc)
        return True

    async def remove_peer(self, url: str) -> None:
        """Remove a peer gateway.

        Unknown URLs are ignored.

        Args:
            url: Gateway URL to remove
        """
        self._discovered_peers.pop(url, None)

    def _schedule_service_state_change(
        self,
        zeroconf: AsyncZeroconf,
        service_type: str,
        name: str,
        state_change: ServiceStateChange,
    ) -> None:
        """Synchronous DNS-SD callback that schedules the async handler.

        The parameter names must stay as they are because the browser passes
        them as keyword arguments.

        Args:
            zeroconf: Zeroconf instance supplied by the browser
            service_type: Service type
            name: Service name
            state_change: Type of state change
        """
        task = asyncio.ensure_future(self._on_service_state_change(zeroconf, service_type, name, state_change))
        # Keep the task referenced until it completes.
        self._pending_tasks.add(task)
        task.add_done_callback(self._pending_tasks.discard)

    async def _on_service_state_change(
        self,
        zeroconf: AsyncZeroconf,
        service_type: str,
        name: str,
        state_change: ServiceStateChange,
    ) -> None:
        """Handle DNS-SD service changes.

        Only ``ServiceStateChange.Added`` is acted on: the service is resolved
        and its first address is registered as an ``http://`` peer.

        Args:
            zeroconf: Zeroconf instance
            service_type: Service type
            name: Service name
            state_change: Type of state change
        """
        if state_change is not ServiceStateChange.Added:
            return

        info = await zeroconf.async_get_service_info(service_type, name)
        if not info:
            return

        try:
            # Extract gateway info
            addresses = [socket.inet_ntoa(addr) for addr in info.addresses]
            if addresses:
                url = f"http://{addresses[0]}:{info.port}"
                # The "name" property may be missing or carry no value.
                peer_name = (info.properties.get(b"name") or b"").decode()

                # Add peer
                await self.add_peer(url, source="dns-sd", name=peer_name)

        except Exception as e:
            logger.warning(f"Failed to process discovered service {name}: {e}")

    async def _cleanup_loop(self) -> None:
        """Periodically clean up stale peers.

        Runs until cancelled; errors in a pass are logged and the loop continues.
        """
        while True:
            try:
                cutoff = datetime.now(timezone.utc) - self._STALE_AFTER
                # Collect first so the dict is not mutated while iterating.
                stale_urls = [url for url, peer in self._discovered_peers.items() if peer.last_seen < cutoff]
                for url in stale_urls:
                    await self.remove_peer(url)
                    logger.info(f"Removed stale peer: {url}")

            except Exception as e:
                logger.error(f"Peer cleanup error: {e}")

            await asyncio.sleep(self._CLEANUP_INTERVAL)

    async def _refresh_loop(self) -> None:
        """Periodically refresh peer information.

        Runs until cancelled; errors in a pass are logged and the loop continues.
        """
        while True:
            try:
                # Refresh all peers (snapshot: the dict can change across awaits)
                for url in list(self._discovered_peers.keys()):
                    await self.refresh_peer(url)

                # Exchange peers
                await self._exchange_peers()

            except Exception as e:
                logger.error(f"Peer refresh error: {e}")

            await asyncio.sleep(self._REFRESH_INTERVAL)

    async def _get_gateway_info(self, url: str) -> ServerCapabilities:
        """Get gateway capabilities.

        Posts an ``initialize`` JSON-RPC request to ``{url}/initialize`` and
        validates the ``capabilities`` field of the response.

        Args:
            url: Gateway URL

        Returns:
            Gateway capabilities

        Raises:
            ValueError: If protocol version is unsupported
        """
        # Build initialize request
        request = {
            "jsonrpc": "2.0",
            "id": 1,
            "method": "initialize",
            "params": {
                "protocol_version": PROTOCOL_VERSION,
                "capabilities": {"roots": {"listChanged": True}, "sampling": {}},
                "client_info": {"name": settings.app_name, "version": "1.0.0"},
            },
        }

        # Send request using the persistent HTTP client directly
        response = await self._http_client.post(f"{url}/initialize", json=request, headers=self._get_auth_headers())
        response.raise_for_status()
        result = response.json()

        # Validate response
        peer_version = result.get("protocol_version")
        if peer_version != PROTOCOL_VERSION:
            raise ValueError(f"Unsupported protocol version: {peer_version}")

        return ServerCapabilities.model_validate(result["capabilities"])

    async def _exchange_peers(self) -> None:
        """Exchange peer lists with known gateways.

        Fetches ``{url}/peers`` from every known peer and adds each returned
        dict entry that has a ``"url"`` key. A failure with one peer is logged
        and does not stop the exchange with the others.
        """
        for url in list(self._discovered_peers.keys()):
            try:
                # Get peer's peer list using the persistent HTTP client directly
                response = await self._http_client.get(f"{url}/peers", headers=self._get_auth_headers())
                response.raise_for_status()
                peers = response.json()

                # Add new peers from the response
                for peer in peers:
                    if isinstance(peer, dict) and "url" in peer:
                        await self.add_peer(peer["url"], source="exchange", name=peer.get("name"))

            except Exception as e:
                logger.warning(f"Failed to exchange peers with {url}: {e}")

    def _get_auth_headers(self) -> Dict[str, str]:
        """
        Get headers for gateway authentication.

        The Basic scheme requires ``user:password`` to be base64-encoded
        (RFC 7617); the ``X-API-Key`` header keeps the plain ``user:password``
        form.

        Returns:
            dict: Authorization header dict
        """
        # Standard
        import base64  # local import so the module's import block stays untouched

        api_key = f"{settings.basic_auth_user}:{settings.basic_auth_password}"
        credentials = base64.b64encode(api_key.encode("utf-8")).decode("ascii")
        return {"Authorization": f"Basic {credentials}", "X-API-Key": api_key}