Coverage for mcpgateway/config.py: 91% (194 statements)
coverage.py v7.9.2, created at 2025-07-09 11:03 +0100
# -*- coding: utf-8 -*-
"""MCP Gateway Configuration.

Copyright 2025
SPDX-License-Identifier: Apache-2.0
Authors: Mihai Criveti

This module defines configuration settings for the MCP Gateway using Pydantic.
It loads configuration from environment variables with sensible defaults.

Environment variables:
- APP_NAME: Gateway name (default: "MCP_Gateway")
- HOST: Host to bind to (default: "127.0.0.1")
- PORT: Port to listen on (default: 4444)
- DATABASE_URL: SQLite database URL (default: "sqlite:///./mcp.db")
- BASIC_AUTH_USER: Admin username (default: "admin")
- BASIC_AUTH_PASSWORD: Admin password (default: "changeme")
- LOG_LEVEL: Logging level (default: "INFO")
- SKIP_SSL_VERIFY: Disable SSL verification (default: False)
- AUTH_REQUIRED: Require authentication (default: True)
- TRANSPORT_TYPE: Transport mechanisms (default: "all")
- FEDERATION_ENABLED: Enable gateway federation (default: True)
- FEDERATION_DISCOVERY: Enable auto-discovery (default: False)
- FEDERATION_PEERS: List of peer gateway URLs (default: [])
- RESOURCE_CACHE_SIZE: Max cached resources (default: 1000)
- RESOURCE_CACHE_TTL: Cache TTL in seconds (default: 3600)
- TOOL_TIMEOUT: Tool invocation timeout (default: 60)
- PROMPT_CACHE_SIZE: Max cached prompts (default: 100)
- HEALTH_CHECK_INTERVAL: Gateway health check interval (default: 60)
"""
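
# Illustrative configuration via environment variables. The values below are
# example overrides only (not project defaults for deployment); any field of
# the Settings class further down can be set the same way, either in the
# process environment or in a `.env` file next to the application:
#
#   APP_NAME=MCP_Gateway
#   HOST=0.0.0.0
#   PORT=4444
#   DATABASE_URL=sqlite:///./mcp.db
#   BASIC_AUTH_USER=admin
#   BASIC_AUTH_PASSWORD=changeme
#   AUTH_REQUIRED=true
#   TRANSPORT_TYPE=all
#   LOG_LEVEL=INFO
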
# Standard
from functools import lru_cache
from importlib.resources import files
import json
from pathlib import Path
from typing import Annotated, Any, Dict, List, Optional, Set, Union

# Third-Party
from fastapi import HTTPException
import jq
from jsonpath_ng.ext import parse
from jsonpath_ng.jsonpath import JSONPath
from pydantic import field_validator
from pydantic_settings import BaseSettings, NoDecode, SettingsConfigDict


class Settings(BaseSettings):
    """MCP Gateway configuration settings."""

    # Basic Settings
    app_name: str = "MCP_Gateway"
    host: str = "127.0.0.1"
    port: int = 4444
    database_url: str = "sqlite:///./mcp.db"

    # Absolute paths resolved at import time (still overridable via env vars)
    templates_dir: Path = files("mcpgateway") / "templates"
    static_dir: Path = files("mcpgateway") / "static"
    app_root_path: str = ""

    # Protocol
    protocol_version: str = "2025-03-26"

    # Authentication
    basic_auth_user: str = "admin"
    basic_auth_password: str = "changeme"
    jwt_secret_key: str = "my-test-key"
    jwt_algorithm: str = "HS256"
    auth_required: bool = True
    token_expiry: int = 10080  # minutes

    # Encryption key phrase for auth storage
    auth_encryption_secret: str = "my-test-salt"

    # UI/Admin Feature Flags
    mcpgateway_ui_enabled: bool = True
    mcpgateway_admin_api_enabled: bool = True

    # Security
    skip_ssl_verify: bool = False
    # cors_enabled is read by the cors_settings property below; the default of
    # True is an assumption, since no declaration appears elsewhere in this class.
    cors_enabled: bool = True

    # For allowed_origins, strip quotes to ensure valid JSON is passed via env.
    # Tell pydantic *not* to decode this env var - our validator below will.
    allowed_origins: Annotated[Set[str], NoDecode] = {
        "http://localhost",
        "http://localhost:4444",
    }

    @field_validator("allowed_origins", mode="before")
    @classmethod
    def _parse_allowed_origins(cls, v):
        if isinstance(v, str):
            v = v.strip()
            if v[:1] in "\"'" and v[-1:] == v[:1]:  # strip one outer quote pair
                v = v[1:-1]
            try:
                parsed = set(json.loads(v))
            except json.JSONDecodeError:
                parsed = {s.strip() for s in v.split(",") if s.strip()}
            return parsed
        return set(v)
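
    # Illustrative inputs this validator accepts (example origins, not defaults):
    #   ALLOWED_ORIGINS='["https://a.example", "https://b.example"]'   -> JSON list, parsed into a set
    #   ALLOWED_ORIGINS=https://a.example,https://b.example            -> comma-separated fallback
    # The federation_peers validator below applies the same quote-stripping and
    # JSON-or-comma-separated parsing, but returns a list instead of a set.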

    # Logging
    log_level: str = "INFO"
    log_format: str = "json"  # json or text
    log_file: Optional[Path] = None

    # Transport
    transport_type: str = "all"  # http, ws, sse, all
    websocket_ping_interval: int = 30  # seconds
    sse_retry_timeout: int = 5000  # milliseconds

    # Federation
    federation_enabled: bool = True
    federation_discovery: bool = False

    # For federation_peers, strip quotes to ensure valid JSON is passed via env
    federation_peers: Annotated[List[str], NoDecode] = []

    # Lock file path used to mark gateway service initialization as done
    lock_file_path: str = "/tmp/gateway_init.done"

    @field_validator("federation_peers", mode="before")
    @classmethod
    def _parse_federation_peers(cls, v):
        if isinstance(v, str):
            v = v.strip()
            if v[:1] in "\"'" and v[-1:] == v[:1]:
                v = v[1:-1]
            try:
                peers = json.loads(v)
            except json.JSONDecodeError:
                peers = [s.strip() for s in v.split(",") if s.strip()]
            return peers
        return list(v)

    federation_timeout: int = 30  # seconds
    federation_sync_interval: int = 300  # seconds

    # Resources
    resource_cache_size: int = 1000
    resource_cache_ttl: int = 3600  # seconds
    max_resource_size: int = 10 * 1024 * 1024  # 10MB
    allowed_mime_types: Set[str] = {
        "text/plain",
        "text/markdown",
        "text/html",
        "application/json",
        "application/xml",
        "image/png",
        "image/jpeg",
        "image/gif",
    }

    # Tools
    tool_timeout: int = 60  # seconds
    max_tool_retries: int = 3
    tool_rate_limit: int = 100  # requests per minute
    tool_concurrent_limit: int = 10

    # Prompts
    prompt_cache_size: int = 100
    max_prompt_size: int = 100 * 1024  # 100KB
    prompt_render_timeout: int = 10  # seconds

    # Health Checks
    health_check_interval: int = 60  # seconds
    health_check_timeout: int = 10  # seconds
    unhealthy_threshold: int = 5  # after this many failures, mark the gateway as Offline

    filelock_name: str = "gateway_service_leader.lock"

    # Default Roots
    default_roots: List[str] = []

    # Database
    db_pool_size: int = 200
    db_max_overflow: int = 10
    db_pool_timeout: int = 30
    db_pool_recycle: int = 3600
    db_max_retries: int = 3
    db_retry_interval_ms: int = 2000

    # Cache
    cache_type: str = "database"  # memory, redis, or database
    redis_url: Optional[str] = "redis://localhost:6379/0"
    cache_prefix: str = "mcpgw:"
    session_ttl: int = 3600
    message_ttl: int = 600
    redis_max_retries: int = 3
    redis_retry_interval_ms: int = 2000

    # Streamable HTTP transport
    use_stateful_sessions: bool = False  # set to False to use stateless sessions without an event store
    json_response_enabled: bool = True  # enable JSON responses instead of SSE streams

    # Development
    dev_mode: bool = False
    reload: bool = False
    debug: bool = False

    model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", case_sensitive=False, extra="ignore")

    gateway_tool_name_separator: str = "-"

    @property
    def api_key(self) -> str:
        """Generate API key from auth credentials.

        Returns:
            str: API key string in the format "username:password".
        """
        return f"{self.basic_auth_user}:{self.basic_auth_password}"

    @property
    def supports_http(self) -> bool:
        """Check if HTTP transport is enabled.

        Returns:
            bool: True if HTTP transport is enabled, False otherwise.
        """
        return self.transport_type in ["http", "all"]

    @property
    def supports_websocket(self) -> bool:
        """Check if WebSocket transport is enabled.

        Returns:
            bool: True if WebSocket transport is enabled, False otherwise.
        """
        return self.transport_type in ["ws", "all"]

    @property
    def supports_sse(self) -> bool:
        """Check if SSE transport is enabled.

        Returns:
            bool: True if SSE transport is enabled, False otherwise.
        """
        return self.transport_type in ["sse", "all"]

    @property
    def database_settings(self) -> dict:
        """Get SQLAlchemy database settings.

        Returns:
            dict: Dictionary containing SQLAlchemy database configuration options.
        """
        return {
            "pool_size": self.db_pool_size,
            "max_overflow": self.db_max_overflow,
            "pool_timeout": self.db_pool_timeout,
            "pool_recycle": self.db_pool_recycle,
            "connect_args": {"check_same_thread": False} if self.database_url.startswith("sqlite") else {},
        }
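
    # Illustrative consumption of database_settings with SQLAlchemy. How the
    # gateway's own db module actually wires this up may differ; create_engine
    # and its pool keyword arguments are standard SQLAlchemy APIs:
    #
    #   from sqlalchemy import create_engine
    #   engine = create_engine(settings.database_url, **settings.database_settings)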

    @property
    def cors_settings(self) -> dict:
        """Get CORS settings.

        Returns:
            dict: Dictionary containing CORS configuration options.
        """
        return (
            {
                "allow_origins": list(self.allowed_origins),
                "allow_credentials": True,
                "allow_methods": ["*"],
                "allow_headers": ["*"],
            }
            if self.cors_enabled
            else {}
        )
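
    # Illustrative FastAPI wiring for these options (the application's real
    # startup code may differ; CORSMiddleware is Starlette/FastAPI's standard
    # middleware and the keys above match its parameters):
    #
    #   from fastapi import FastAPI
    #   from fastapi.middleware.cors import CORSMiddleware
    #
    #   app = FastAPI()
    #   if settings.cors_settings:
    #       app.add_middleware(CORSMiddleware, **settings.cors_settings)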

    def validate_transport(self) -> None:
        """Validate transport configuration.

        Raises:
            ValueError: If the transport type is not one of the valid options.
        """
        valid_types = {"http", "ws", "sse", "all"}
        if self.transport_type not in valid_types:
            raise ValueError(f"Invalid transport type. Must be one of: {valid_types}")

    def validate_database(self) -> None:
        """Validate database configuration."""
        if self.database_url.startswith("sqlite"):
            db_path = Path(self.database_url.replace("sqlite:///", ""))
            db_dir = db_path.parent
            if not db_dir.exists():
                db_dir.mkdir(parents=True)
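
    # Illustrative effect: with DATABASE_URL=sqlite:///./data/mcp.db (a made-up
    # path), validate_database() creates the ./data directory if it is missing;
    # non-SQLite URLs are left untouched.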


def extract_using_jq(data, jq_filter=""):
    """
    Extracts data from a given input (string, dict, or list) using a jq filter string.

    Args:
        data (str, dict, list): The input JSON data. Can be a string, dict, or list.
        jq_filter (str): The jq filter string to extract the desired data.

    Returns:
        The result of applying the jq filter to the input data.
    """
    if jq_filter == "":
        return data
    if isinstance(data, str):
        # If the input is a string, parse it as JSON
        try:
            data = json.loads(data)
        except json.JSONDecodeError:
            return ["Invalid JSON string provided."]
    elif not isinstance(data, (dict, list)):
        # If the input is not a string, dict, or list, return an error message
        return ["Input data must be a JSON string, dictionary, or list."]

    # Apply the jq filter to the data
    try:
        # Pylint can't introspect C-extension modules, so it doesn't know that jq really does export an all() function.
        # pylint: disable=c-extension-no-member
        result = jq.all(jq_filter, data)  # Use `jq.all` to get all matches (returns a list)
        if result == [None]:
            result = "Error applying jsonpath filter"
    except Exception as e:
        message = "Error applying jsonpath filter: " + str(e)
        return message

    return result
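

# Illustrative behaviour of extract_using_jq. The inputs are made-up examples,
# and the outputs follow from the jq.all() call above:
#
#   extract_using_jq('{"a": {"b": 1}}', ".a.b")          -> [1]
#   extract_using_jq({"items": [1, 2, 3]}, ".items[]")   -> [1, 2, 3]
#   extract_using_jq({"a": 1}, "")                       -> {"a": 1}  (empty filter returns data unchanged)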


def jsonpath_modifier(data: Any, jsonpath: str = "$[*]", mappings: Optional[Dict[str, str]] = None) -> Union[List, Dict]:
    """
    Applies the given JSONPath expression and mappings to the data.
    Only the data the caller actually requested is returned.

    Args:
        data: The JSON data to query.
        jsonpath: The JSONPath expression to apply.
        mappings: Optional dictionary of mappings where keys are new field names
            and values are JSONPath expressions.

    Returns:
        Union[List, Dict]: A list (or mapped list) or a dict of extracted data.

    Raises:
        HTTPException: If there's an error parsing or executing the JSONPath expressions.
    """
    if not jsonpath:
        jsonpath = "$[*]"

    try:
        main_expr: JSONPath = parse(jsonpath)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid main JSONPath expression: {e}")

    try:
        main_matches = main_expr.find(data)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error executing main JSONPath: {e}")

    results = [match.value for match in main_matches]

    if mappings:
        mapped_results = []
        for item in results:
            mapped_item = {}
            for new_key, mapping_expr_str in mappings.items():
                try:
                    mapping_expr = parse(mapping_expr_str)
                except Exception as e:
                    raise HTTPException(status_code=400, detail=f"Invalid mapping JSONPath for key '{new_key}': {e}")
                try:
                    mapping_matches = mapping_expr.find(item)
                except Exception as e:
                    raise HTTPException(status_code=400, detail=f"Error executing mapping JSONPath for key '{new_key}': {e}")
                if not mapping_matches:
                    mapped_item[new_key] = None
                elif len(mapping_matches) == 1:
                    mapped_item[new_key] = mapping_matches[0].value
                else:
                    mapped_item[new_key] = [m.value for m in mapping_matches]
            mapped_results.append(mapped_item)
        results = mapped_results

    if len(results) == 1 and isinstance(results[0], dict):
        return results[0]
    return results
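

# Illustrative use of jsonpath_modifier with hypothetical data; the expression
# syntax is the jsonpath_ng extended grammar imported above:
#
#   rows = [{"name": "alice", "id": 1}, {"name": "bob", "id": 2}]
#   jsonpath_modifier(rows, "$[*]", {"user": "$.name"})
#   # -> [{"user": "alice"}, {"user": "bob"}]
#
#   jsonpath_modifier({"tool": {"name": "echo"}}, "$.tool")
#   # -> {"name": "echo"}   (a single dict match is unwrapped from the list)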


@lru_cache()
def get_settings() -> Settings:
    """Get cached settings instance.

    Returns:
        Settings: A cached instance of the Settings class.
    """
    # Instantiate a fresh Pydantic Settings object,
    # loading from env vars or .env exactly once.
    cfg = Settings()
    # Validate that transport_type is correct; will
    # raise if misconfigured.
    cfg.validate_transport()
    # Ensure sqlite DB directories exist if needed.
    cfg.validate_database()
    # Return the one-and-only Settings instance (cached).
    return cfg


# Create settings instance
settings = get_settings()
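
# Illustrative consumption of the cached settings. The import path matches this
# module; cache_clear() is the standard functools.lru_cache API rather than a
# project-specific helper:
#
#   from mcpgateway.config import settings
#   print(settings.port, settings.supports_sse)
#
#   # In tests, force a reload after monkeypatching environment variables:
#   from mcpgateway.config import get_settings
#   get_settings.cache_clear()
#   fresh = get_settings()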