Coverage for mcpgateway/config.py: 91%

194 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-09 11:03 +0100

1# -*- coding: utf-8 -*- 

2"""MCP Gateway Configuration. 

3 

4Copyright 2025 

5SPDX-License-Identifier: Apache-2.0 

6Authors: Mihai Criveti 

7 

8This module defines configuration settings for the MCP Gateway using Pydantic. 

9It loads configuration from environment variables with sensible defaults. 

10 

11Environment variables: 

12- APP_NAME: Gateway name (default: "MCP_Gateway") 

13- HOST: Host to bind to (default: "127.0.0.1") 

14- PORT: Port to listen on (default: 4444) 

15- DATABASE_URL: SQLAlchemy database URL (default: "sqlite:///./mcp.db") 

16- BASIC_AUTH_USER: Admin username (default: "admin") 

17- BASIC_AUTH_PASSWORD: Admin password (default: "changeme") 

18- LOG_LEVEL: Logging level (default: "INFO") 

19- SKIP_SSL_VERIFY: Disable SSL verification (default: False) 

20- AUTH_REQUIRED: Require authentication (default: True) 

21- TRANSPORT_TYPE: Enabled transport(s): "http", "ws", "sse", or "all" (default: "all") 

22- FEDERATION_ENABLED: Enable gateway federation (default: True) 

23- FEDERATION_DISCOVERY: Enable auto-discovery (default: False) 

24- FEDERATION_PEERS: List of peer gateway URLs (default: []) 

25- RESOURCE_CACHE_SIZE: Max cached resources (default: 1000) 

26- RESOURCE_CACHE_TTL: Cache TTL in seconds (default: 3600) 

27- TOOL_TIMEOUT: Tool invocation timeout in seconds (default: 60) 

28- PROMPT_CACHE_SIZE: Max cached prompts (default: 100) 

29- HEALTH_CHECK_INTERVAL: Gateway health check interval in seconds (default: 60) 

30""" 

31 

32# Standard 

33from functools import lru_cache 

34from importlib.resources import files 

35import json 

36from pathlib import Path 

37from typing import Annotated, Any, Dict, List, Optional, Set, Union 

38 

39# Third-Party 

40from fastapi import HTTPException 

41import jq 

42from jsonpath_ng.ext import parse 

43from jsonpath_ng.jsonpath import JSONPath 

44from pydantic import field_validator 

45from pydantic_settings import BaseSettings, NoDecode, SettingsConfigDict 

46 

47 

def _parse_env_list(raw: str) -> List[Any]:
    """Parse a list-valued environment variable string.

    One pair of matching outer quotes (``"`` or ``'``) is removed first. The
    remaining text is then interpreted as:

    * a JSON array, e.g. ``["http://a", "http://b"]`` - items are returned as-is
      so pydantic can validate their types;
    * a JSON string, e.g. ``"http://a, http://b"`` - its contents are treated as
      comma-separated text;
    * anything else (including plain ``http://a, http://b``) - split on commas.

    Args:
        raw: The raw environment-variable value.

    Returns:
        List[Any]: The parsed items. For comma-separated input, surrounding
        whitespace is stripped and empty entries are dropped.
    """
    text = raw.strip()
    if len(text) >= 2 and text[0] in "\"'" and text[-1] == text[0]:
        text = text[1:-1].strip()
    try:
        decoded: Any = json.loads(text)
    except json.JSONDecodeError:
        decoded = text
    if isinstance(decoded, list):
        return decoded
    if not isinstance(decoded, str):
        # A JSON scalar/object is not a list; never iterate it directly
        # (a string would otherwise become its individual characters).
        decoded = text
    return [part.strip() for part in decoded.split(",") if part.strip()]


class Settings(BaseSettings):
    """MCP Gateway configuration settings.

    Every field can be overridden by an environment variable of the same name
    (matched case-insensitively) or by an entry in a ``.env`` file; unknown
    variables are ignored (see ``model_config``).
    """

    # Basic Settings
    app_name: str = "MCP_Gateway"
    host: str = "127.0.0.1"
    port: int = 4444
    database_url: str = "sqlite:///./mcp.db"
    # Absolute paths resolved at import-time (still override-able via env vars)
    templates_dir: Path = files("mcpgateway") / "templates"
    static_dir: Path = files("mcpgateway") / "static"
    app_root_path: str = ""

    # Protocol
    protocol_version: str = "2025-03-26"

    # Authentication
    basic_auth_user: str = "admin"
    basic_auth_password: str = "changeme"
    jwt_secret_key: str = "my-test-key"
    jwt_algorithm: str = "HS256"
    auth_required: bool = True
    token_expiry: int = 10080  # minutes

    # Encryption key phrase for auth storage
    auth_encryption_secret: str = "my-test-salt"

    # UI/Admin Feature Flags
    mcpgateway_ui_enabled: bool = True
    mcpgateway_admin_api_enabled: bool = True

    # Security
    skip_ssl_verify: bool = False
    # Read by ``cors_settings``: when False it returns an empty dict.
    cors_enabled: bool = True

    # allowed_origins accepts a JSON array or a comma-separated string (quotes allowed).
    # NoDecode tells pydantic-settings *not* to JSON-decode this env var - our validator will.
    allowed_origins: Annotated[Set[str], NoDecode] = {
        "http://localhost",
        "http://localhost:4444",
    }

    @field_validator("allowed_origins", mode="before")
    @classmethod
    def _parse_allowed_origins(cls, v: Any) -> Set[Any]:
        """Normalise ALLOWED_ORIGINS (raw env string or iterable) into a set.

        Args:
            v: Raw value supplied to the field.

        Returns:
            Set[Any]: The origins; item types are validated by pydantic afterwards.
        """
        if isinstance(v, str):
            return set(_parse_env_list(v))
        return set(v)

    # Logging
    log_level: str = "INFO"
    log_format: str = "json"  # json or text
    log_file: Optional[Path] = None

    # Transport
    transport_type: str = "all"  # http, ws, sse, all
    websocket_ping_interval: int = 30  # seconds
    sse_retry_timeout: int = 5000  # milliseconds

    # Federation
    federation_enabled: bool = True
    federation_discovery: bool = False

    # federation_peers accepts a JSON array or a comma-separated string (quotes allowed).
    federation_peers: Annotated[List[str], NoDecode] = []

    # Lock file path for initializing gateway service initialize
    lock_file_path: str = "/tmp/gateway_init.done"

    @field_validator("federation_peers", mode="before")
    @classmethod
    def _parse_federation_peers(cls, v: Any) -> List[Any]:
        """Normalise FEDERATION_PEERS (raw env string or iterable) into a list.

        Args:
            v: Raw value supplied to the field.

        Returns:
            List[Any]: The peer URLs; item types are validated by pydantic afterwards.
        """
        if isinstance(v, str):
            return _parse_env_list(v)
        return list(v)

    federation_timeout: int = 30  # seconds
    federation_sync_interval: int = 300  # seconds

    # Resources
    resource_cache_size: int = 1000
    resource_cache_ttl: int = 3600  # seconds
    max_resource_size: int = 10 * 1024 * 1024  # 10MB
    allowed_mime_types: Set[str] = {
        "text/plain",
        "text/markdown",
        "text/html",
        "application/json",
        "application/xml",
        "image/png",
        "image/jpeg",
        "image/gif",
    }

    # Tools
    tool_timeout: int = 60  # seconds
    max_tool_retries: int = 3
    tool_rate_limit: int = 100  # requests per minute
    tool_concurrent_limit: int = 10

    # Prompts
    prompt_cache_size: int = 100
    max_prompt_size: int = 100 * 1024  # 100KB
    prompt_render_timeout: int = 10  # seconds

    # Health Checks
    health_check_interval: int = 60  # seconds
    health_check_timeout: int = 10  # seconds
    unhealthy_threshold: int = 5  # after this many failures, mark as Offline

    filelock_name: str = "gateway_service_leader.lock"

    # Default Roots
    default_roots: List[str] = []

    # Database
    db_pool_size: int = 200
    db_max_overflow: int = 10
    db_pool_timeout: int = 30
    db_pool_recycle: int = 3600
    db_max_retries: int = 3
    db_retry_interval_ms: int = 2000

    # Cache
    cache_type: str = "database"  # memory or redis or database
    redis_url: Optional[str] = "redis://localhost:6379/0"
    cache_prefix: str = "mcpgw:"
    session_ttl: int = 3600
    message_ttl: int = 600
    redis_max_retries: int = 3
    redis_retry_interval_ms: int = 2000

    # streamable http transport
    use_stateful_sessions: bool = False  # Set to False to use stateless sessions without event store
    json_response_enabled: bool = True  # Enable JSON responses instead of SSE streams

    # Development
    dev_mode: bool = False
    reload: bool = False
    debug: bool = False

    model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", case_sensitive=False, extra="ignore")

    gateway_tool_name_separator: str = "-"

    @property
    def api_key(self) -> str:
        """Generate API key from auth credentials.

        Returns:
            str: API key string in the format "username:password".
        """
        return f"{self.basic_auth_user}:{self.basic_auth_password}"

    @property
    def supports_http(self) -> bool:
        """Check if HTTP transport is enabled.

        Returns:
            bool: True if HTTP transport is enabled, False otherwise.
        """
        return self.transport_type in ["http", "all"]

    @property
    def supports_websocket(self) -> bool:
        """Check if WebSocket transport is enabled.

        Returns:
            bool: True if WebSocket transport is enabled, False otherwise.
        """
        return self.transport_type in ["ws", "all"]

    @property
    def supports_sse(self) -> bool:
        """Check if SSE transport is enabled.

        Returns:
            bool: True if SSE transport is enabled, False otherwise.
        """
        return self.transport_type in ["sse", "all"]

    @property
    def database_settings(self) -> dict:
        """Get SQLAlchemy database settings.

        Returns:
            dict: Dictionary containing SQLAlchemy database configuration options.
        """
        return {
            "pool_size": self.db_pool_size,
            "max_overflow": self.db_max_overflow,
            "pool_timeout": self.db_pool_timeout,
            "pool_recycle": self.db_pool_recycle,
            "connect_args": {"check_same_thread": False} if self.database_url.startswith("sqlite") else {},
        }

    @property
    def cors_settings(self) -> dict:
        """Get CORS settings.

        Returns:
            dict: CORS configuration options built from ``allowed_origins``, or an
            empty dict when ``cors_enabled`` is False.
        """
        return (
            {
                "allow_origins": list(self.allowed_origins),
                "allow_credentials": True,
                "allow_methods": ["*"],
                "allow_headers": ["*"],
            }
            if self.cors_enabled
            else {}
        )

    def validate_transport(self) -> None:
        """Validate transport configuration.

        Raises:
            ValueError: If the transport type is not one of the valid options.
        """
        valid_types = {"http", "ws", "sse", "all"}
        if self.transport_type not in valid_types:
            raise ValueError(f"Invalid transport type. Must be one of: {valid_types}")

    def validate_database(self) -> None:
        """Ensure the parent directory of a SQLite database file exists.

        Non-SQLite URLs are ignored. The file path is everything after ``:///``,
        so driver-qualified URLs such as ``sqlite+aiosqlite:///./mcp.db`` are
        handled too; any ``?query`` suffix is dropped.
        """
        if not self.database_url.startswith("sqlite"):
            return
        _, sep, path_part = self.database_url.partition(":///")
        if not sep:
            # e.g. "sqlite://" has no file path component.
            return
        db_dir = Path(path_part.split("?", 1)[0]).parent
        # exist_ok=True tolerates the directory already existing (or being
        # created concurrently between a check and the mkdir call).
        db_dir.mkdir(parents=True, exist_ok=True)

294 

def extract_using_jq(data, jq_filter=""):
    """
    Run a jq program over JSON-like input and return every value it emits.

    Args:
        data (str, dict, list): Input document. A ``str`` is decoded with
            ``json.loads`` first; ``dict`` and ``list`` values are used directly.
        jq_filter (str): jq program text. The empty string (the default) disables
            filtering, and ``data`` is returned unchanged.

    Returns:
        On success, the list produced by ``jq.all``. Failures are reported as
        return values rather than exceptions:

        - a one-element list with a message for an undecodable string or an
          unsupported input type;
        - the string ``"Error applying jsonpath filter"`` when jq emits only a
          single ``null``;
        - ``"Error applying jsonpath filter: <reason>"`` when jq raises.
    """
    if jq_filter == "":
        return data

    if isinstance(data, (dict, list)):
        document = data
    elif isinstance(data, str):
        try:
            document = json.loads(data)
        except json.JSONDecodeError:
            return ["Invalid JSON string provided."]
    else:
        return ["Input data must be a JSON string, dictionary, or list."]

    try:
        # Pylint can't introspect C-extension modules, so it doesn't know that jq really does export an all() function.
        # pylint: disable=c-extension-no-member
        outputs = jq.all(jq_filter, document)  # every value the program emits, as a list
    except Exception as exc:
        return "Error applying jsonpath filter: " + str(exc)

    # A lone null result is surfaced as an error string instead of [None].
    if outputs == [None]:
        return "Error applying jsonpath filter"
    return outputs

331 

332 

def jsonpath_modifier(data: Any, jsonpath: str = "$[*]", mappings: Optional[Dict[str, str]] = None) -> Union[List, Dict]:
    """
    Applies the given JSONPath expression and mappings to the data.
    Only return data that is required by the user dynamically.

    Args:
        data: The JSON data to query.
        jsonpath: The JSONPath expression to apply. A falsy value falls back to ``"$[*]"``.
        mappings: Optional dictionary of mappings where keys are new field names
                  and values are JSONPath expressions evaluated against each matched item.

    Returns:
        Union[List, Dict]: The list of matched values (or of mapped dicts when
        ``mappings`` is given). If exactly one result remains and it is a dict,
        that dict is returned on its own.

    Raises:
        HTTPException: 400 if the main or a mapping JSONPath expression cannot be
            parsed or fails while executing.
    """
    if not jsonpath:
        jsonpath = "$[*]"

    try:
        main_expr: JSONPath = parse(jsonpath)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid main JSONPath expression: {e}") from e

    try:
        main_matches = main_expr.find(data)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error executing main JSONPath: {e}") from e

    results = [match.value for match in main_matches]

    if mappings:
        # Parsed mapping expressions keyed by output field name. Each one is parsed
        # on first use and then reused for every item, instead of being re-parsed
        # per item. Parsing lazily keeps the original behavior when there are no
        # results and the original order in which errors are reported.
        compiled: Dict[str, JSONPath] = {}
        mapped_results = []
        for item in results:
            mapped_item = {}
            for new_key, mapping_expr_str in mappings.items():
                mapping_expr = compiled.get(new_key)
                if mapping_expr is None:
                    try:
                        mapping_expr = parse(mapping_expr_str)
                    except Exception as e:
                        raise HTTPException(status_code=400, detail=f"Invalid mapping JSONPath for key '{new_key}': {e}") from e
                    compiled[new_key] = mapping_expr
                try:
                    mapping_matches = mapping_expr.find(item)
                except Exception as e:
                    raise HTTPException(status_code=400, detail=f"Error executing mapping JSONPath for key '{new_key}': {e}") from e
                # No match -> None, one match -> its value, several -> list of values.
                if not mapping_matches:
                    mapped_item[new_key] = None
                elif len(mapping_matches) == 1:
                    mapped_item[new_key] = mapping_matches[0].value
                else:
                    mapped_item[new_key] = [m.value for m in mapping_matches]
            mapped_results.append(mapped_item)
        results = mapped_results

    if len(results) == 1 and isinstance(results[0], dict):
        return results[0]
    return results

390 

391 

@lru_cache()
def get_settings() -> Settings:
    """Build, validate and memoise the gateway ``Settings``.

    On the first call a new ``Settings`` object is loaded from environment
    variables / ``.env``. Its transport type is then checked and, for SQLite
    URLs, the database directory is created. ``lru_cache`` makes every later
    call return that same object.

    Returns:
        Settings: The cached configuration instance.
    """
    instance = Settings()
    # Raises ValueError on an unknown transport type.
    instance.validate_transport()
    # Creates the SQLite database's parent directory when needed.
    instance.validate_database()
    return instance

409 

410 

# Shared settings object, built once at import time through the cached factory.
settings: Settings = get_settings()