Coverage for mcpgateway/services/root_service.py: 92%

69 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-09 11:03 +0100

1# -*- coding: utf-8 -*- 

2"""Root Service Implementation. 

3 

4Copyright 2025 

5SPDX-License-Identifier: Apache-2.0 

6Authors: Mihai Criveti 

7 

8This module implements root directory management according to the MCP specification. 

9It handles root registration, validation, and change notifications. 

10""" 

11 

12# Standard 

13import asyncio 

14import logging 

15import os 

16from typing import AsyncGenerator, Dict, List, Optional 

17from urllib.parse import urlparse 

18 

19# First-Party 

20from mcpgateway.config import settings 

21from mcpgateway.models import Root 

22 

# Module-level logger, named after this module so records can be filtered per module.
logger: logging.Logger = logging.getLogger(name=__name__)

24 

25 

class RootServiceError(Exception):
    """Base error raised by the root service.

    Used for invalid root URIs, duplicate registrations, and unknown roots.
    """

28 

29 

class RootService:
    """MCP root service.

    Manages roots that can be exposed to MCP clients.
    Handles:
    - Root registration and URI normalization
    - Change notifications to subscribers

    No permission or access checks are performed on registered roots.
    """

    def __init__(self) -> None:
        """Initialize an empty root registry with no subscribers."""
        # Registered roots, keyed by their normalized URI string.
        self._roots: Dict[str, Root] = {}
        # One queue per active ``subscribe_changes`` generator.
        self._subscribers: List[asyncio.Queue] = []

    async def initialize(self) -> None:
        """Initialize the service and register the configured default roots.

        Each entry of ``settings.default_roots`` is passed to :meth:`add_root`.
        An entry that is rejected (invalid or duplicate) is logged and skipped,
        so one bad entry does not abort startup.
        """
        logger.info("Initializing root service")
        for root_uri in settings.default_roots:
            try:
                await self.add_root(root_uri)
            except RootServiceError as e:
                logger.error("Failed to add default root %s: %s", root_uri, e)

    async def shutdown(self) -> None:
        """Shut down the service, discarding all roots and subscriber queues.

        Subscriber generators that are still running are not sent any
        event; their queues are simply dropped from the registry.
        """
        logger.info("Shutting down root service")
        self._roots.clear()
        self._subscribers.clear()

    async def list_roots(self) -> List[Root]:
        """List available roots.

        Returns:
            List of registered roots, in registration order.
        """
        return list(self._roots.values())

    async def add_root(self, uri: str, name: Optional[str] = None) -> Root:
        """Add a new root.

        Args:
            uri: Root URI, or a filesystem path (converted to a ``file://`` URI).
            name: Optional display name. Defaults to the last path segment
                of the URI, or the full URI when that segment is empty.

        Returns:
            Created root object.

        Raises:
            RootServiceError: If the URI is invalid or the root already exists.
        """
        try:
            root_uri = self._make_root_uri(uri)
        except ValueError as e:
            raise RootServiceError(f"Invalid root URI: {e}") from e

        if root_uri in self._roots:
            raise RootServiceError(f"Root already exists: {root_uri}")

        # No access check is performed; the root is registered as given.
        root_obj = Root(
            uri=root_uri,
            name=name or os.path.basename(urlparse(root_uri).path) or root_uri,
        )
        self._roots[root_uri] = root_obj

        await self._notify_root_added(root_obj)
        logger.info("Added root: %s", root_uri)
        return root_obj

    async def remove_root(self, root_uri: str) -> None:
        """Remove a registered root.

        Args:
            root_uri: Normalized URI of the root to remove (the key under
                which :meth:`add_root` stored it).

        Raises:
            RootServiceError: If root not found.
        """
        try:
            root_obj = self._roots.pop(root_uri)
        except KeyError:
            raise RootServiceError(f"Root not found: {root_uri}") from None
        await self._notify_root_removed(root_obj)
        logger.info("Removed root: %s", root_uri)

    async def subscribe_changes(self) -> AsyncGenerator[Dict, None]:
        """Subscribe to root changes.

        Registers a private queue for the lifetime of the generator and
        unregisters it when the generator is closed.

        Yields:
            Root change events (``{"type": ..., "data": {...}}`` dicts).
        """
        queue: asyncio.Queue = asyncio.Queue()
        self._subscribers.append(queue)
        try:
            while True:
                yield await queue.get()
        finally:
            # shutdown() may already have cleared the subscriber list.
            # An unconditional remove() would then raise ValueError while
            # the generator is being closed.
            if queue in self._subscribers:
                self._subscribers.remove(queue)

    def _make_root_uri(self, uri: str) -> str:
        """Convert input to a valid URI.

        If no scheme is provided, the input is treated as a filesystem path.
        It is made absolute (relative to the current working directory) and
        prefixed with ``file://``. Inputs that already carry a scheme
        (e.g. http, https, file) are returned unchanged.

        Args:
            uri: Input URI or filesystem path.

        Returns:
            A URI string.

        Raises:
            ValueError: If ``uri`` is empty or contains only whitespace.
        """
        if not uri or not uri.strip():
            raise ValueError("URI must not be empty")
        parsed = urlparse(uri)
        if not parsed.scheme:
            # Without abspath, a relative path like "foo/bar" would become
            # "file://foo/bar", which URI parsers read as host "foo".
            # NOTE(review): Windows drive paths ("C:\\x") parse with scheme
            # "c" and are returned as-is; confirm whether that matters here.
            return f"file://{os.path.abspath(uri)}"
        return uri

    async def _notify_root_added(self, root: Root) -> None:
        """Notify subscribers of root addition.

        Args:
            root: Added root.
        """
        event = {"type": "root_added", "data": {"uri": root.uri, "name": root.name}}
        await self._notify_subscribers(event)

    async def _notify_root_removed(self, root: Root) -> None:
        """Notify subscribers of root removal.

        Args:
            root: Removed root.
        """
        event = {"type": "root_removed", "data": {"uri": root.uri}}
        await self._notify_subscribers(event)

    async def _notify_subscribers(self, event: Dict) -> None:
        """Send an event to every subscriber queue (best effort).

        Args:
            event: Event to send.
        """
        # ``put`` is awaited, so iterate over a snapshot. The subscriber list
        # could otherwise change (a subscriber closing) at a suspension point.
        for queue in list(self._subscribers):
            try:
                await queue.put(event)
            except Exception as e:  # one failing subscriber must not starve the rest
                logger.error("Failed to notify subscriber: %s", e)