Coverage for /Users/antonigmitruk/golf/src/golf/auth/providers.py: 0%

179 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-08-16 18:46 +0200

1"""Modern authentication provider configurations for Golf MCP servers. 

2 

3This module provides configuration classes for FastMCP 2.11+ authentication providers, 

4replacing the legacy custom OAuth implementation with the new built-in auth system. 

5""" 

6 

import os
from typing import Any, Literal
from urllib.parse import urlparse

from pydantic import BaseModel, Field, ValidationInfo, field_validator, model_validator

12 

13 

class JWTAuthConfig(BaseModel):
    """Settings for verifying externally issued JWTs with FastMCP's JWTVerifier.

    Choose this provider when an external OAuth server (Auth0, Okta, ...) mints
    the tokens and the Golf server only has to check them.

    Exactly one key source must be configured: either a PEM public key
    (``public_key`` / ``public_key_env_var``) or a JWKS endpoint
    (``jwks_uri`` / ``jwks_uri_env_var``).

    Security Note:
        In production, set both ``issuer`` and ``audience`` (or their env-var
        names). Otherwise tokens minted for another issuer, service, or
        environment could be accepted. A ``UserWarning`` is emitted when either
        is missing and a production environment is detected.
    """

    provider_type: Literal["jwt"] = "jwt"

    # JWT verification settings
    public_key: str | None = Field(default=None, description="PEM-encoded public key for JWT verification")
    jwks_uri: str | None = Field(default=None, description="URI to fetch JSON Web Key Set for verification")
    issuer: str | None = Field(
        default=None, description="Expected JWT issuer claim (strongly recommended for production)"
    )
    audience: str | list[str] | None = Field(
        default=None, description="Expected JWT audience claim(s) (strongly recommended for production)"
    )
    algorithm: str = Field(default="RS256", description="JWT signing algorithm")

    # Scope and access control
    required_scopes: list[str] = Field(default_factory=list, description="Scopes required for all requests")

    # Names of environment variables that supply the values above at runtime
    public_key_env_var: str | None = Field(default=None, description="Environment variable name for public key")
    jwks_uri_env_var: str | None = Field(default=None, description="Environment variable name for JWKS URI")
    issuer_env_var: str | None = Field(default=None, description="Environment variable name for issuer")
    audience_env_var: str | None = Field(default=None, description="Environment variable name for audience")

    @model_validator(mode="after")
    def validate_jwt_config(self) -> "JWTAuthConfig":
        """Enforce a single key source and warn about weak production settings.

        Returns:
            The validated model instance, unchanged.

        Raises:
            ValueError: if no key source is configured, or if both a public-key
                source and a JWKS source are configured.
        """
        has_public_key_source = bool(self.public_key or self.public_key_env_var)
        has_jwks_source = bool(self.jwks_uri or self.jwks_uri_env_var)

        if not (has_public_key_source or has_jwks_source):
            raise ValueError("Either public_key, jwks_uri, or their environment variable equivalents must be provided")
        if has_public_key_source and has_jwks_source:
            raise ValueError("Provide either public_key or jwks_uri (or their env vars), not both")

        # Production is signalled by GOLF_ENV/ENVIRONMENT = prod|production or NODE_ENV = production.
        env = os.environ
        production_markers = ("prod", "production")
        running_in_production = (
            env.get("GOLF_ENV", "").lower() in production_markers
            or env.get("NODE_ENV", "").lower() == "production"
            or env.get("ENVIRONMENT", "").lower() in production_markers
        )
        if not running_in_production:
            return self

        # A claim counts as configured if either the literal value or its env-var name is set.
        missing_fields = [
            field_name
            for field_name, literal_value, env_var_name in (
                ("issuer", self.issuer, self.issuer_env_var),
                ("audience", self.audience, self.audience_env_var),
            )
            if not literal_value and not env_var_name
        ]
        if missing_fields:
            import warnings

            warnings.warn(
                f"JWT configuration is missing recommended fields for production: {', '.join(missing_fields)}. "
                "This may allow tokens from unintended issuers or audiences to be accepted.",
                UserWarning,
                stacklevel=2,
            )
        return self

81 

82 

class StaticTokenConfig(BaseModel):
    """Configuration for static token verification for development/testing.

    Use this for local development and testing when you need predictable
    API keys without setting up a full OAuth server.

    WARNING: Never use in production! A ``UserWarning`` is emitted when this
    config is validated while a production environment is detected
    (GOLF_ENV/ENVIRONMENT set to "prod"/"production", or NODE_ENV set to
    "production").
    """

    provider_type: Literal["static"] = "static"

    # Static tokens mapping: token_string -> metadata
    tokens: dict[str, dict[str, Any]] = Field(
        default_factory=dict,
        description="Static tokens with their metadata (client_id, scopes, expires_at)",
    )

    # Scope and access control
    required_scopes: list[str] = Field(default_factory=list, description="Scopes required for all requests")

    @model_validator(mode="after")
    def warn_if_used_in_production(self) -> "StaticTokenConfig":
        """Warn (without failing) when static tokens are configured in production.

        Uses the same production detection as the other auth configs in this
        module. The check warns rather than raises so that existing deployments
        keep starting up.

        Returns:
            The validated model instance, unchanged.
        """
        is_production = (
            os.environ.get("GOLF_ENV", "").lower() in ("prod", "production")
            or os.environ.get("NODE_ENV", "").lower() == "production"
            or os.environ.get("ENVIRONMENT", "").lower() in ("prod", "production")
        )
        if is_production:
            import warnings

            warnings.warn(
                "StaticTokenConfig is intended for development/testing only and should not be used in production. "
                "Static tokens are long-lived shared secrets; use JWT verification or an OAuth server instead.",
                UserWarning,
                stacklevel=2,
            )
        return self

102 

103 

class OAuthServerConfig(BaseModel):
    """Configuration for full OAuth authorization server using FastMCP's OAuthProvider.

    Use this when you want your Golf server to act as a complete OAuth server,
    handling authorization flows and token issuance.

    A "production environment" below means GOLF_ENV or ENVIRONMENT is
    "prod"/"production", or NODE_ENV is "production".

    Security Considerations:
    - URLs must be absolute http(s) URLs with a host
    - In production, ``base_url`` may not point at a loopback/unspecified host
      (localhost, 127.0.0.0/8, ::1, 0.0.0.0, ::) to limit SSRF exposure
    - Plain-HTTP ``base_url`` / ``issuer_url`` trigger a ``UserWarning`` in production
    - Scopes are validated against the OAuth 2.0 scope-token grammar (RFC 6749 §3.3)
    - Client registration is disabled for security
    """

    provider_type: Literal["oauth_server"] = "oauth_server"

    # OAuth server URLs
    base_url: str = Field(..., description="Public URL of this Golf server (must use HTTPS in production)")
    issuer_url: str | None = Field(None, description="OAuth issuer URL (defaults to base_url, must be HTTPS)")
    service_documentation_url: str | None = Field(None, description="URL of service documentation")

    # Client registration settings
    valid_scopes: list[str] = Field(
        default_factory=list, description="Valid scopes for client registration (OAuth 2.0 format)"
    )
    default_scopes: list[str] = Field(default_factory=list, description="Default scopes for new clients")

    # Token revocation settings
    allow_token_revocation: bool = Field(True, description="Allow token revocation")

    # Access control
    required_scopes: list[str] = Field(default_factory=list, description="Scopes required for all requests")

    # Environment variable names for runtime configuration
    base_url_env_var: str | None = Field(None, description="Environment variable name for base URL")

    @staticmethod
    def _is_production_environment() -> bool:
        """Return True if GOLF_ENV/ENVIRONMENT is prod|production or NODE_ENV is production."""
        return (
            os.environ.get("GOLF_ENV", "").lower() in ("prod", "production")
            or os.environ.get("NODE_ENV", "").lower() == "production"
            or os.environ.get("ENVIRONMENT", "").lower() in ("prod", "production")
        )

    @staticmethod
    def _is_loopback_host(hostname: str | None) -> bool:
        """Return True if *hostname* is ``localhost`` or a loopback/unspecified IP literal."""
        if not hostname:
            return False
        # urlparse lower-cases .hostname, so a plain comparison is sufficient.
        if hostname == "localhost":
            return True
        import ipaddress

        try:
            address = ipaddress.ip_address(hostname)
        except ValueError:
            # Ordinary DNS names are not IP literals and are not treated as loopback.
            return False
        # Covers 127.0.0.0/8 and ::1 (loopback) plus 0.0.0.0 and :: (unspecified).
        return address.is_loopback or address.is_unspecified

    @field_validator("base_url")
    @classmethod
    def validate_base_url(cls, v: str) -> str:
        """Validate and strip ``base_url``.

        Requires an absolute http(s) URL. Warns on plain HTTP in production and
        rejects loopback/unspecified hosts in production.

        Returns:
            The URL with surrounding whitespace removed.

        Raises:
            ValueError: if the URL is empty, unparsable, lacks a scheme or host,
                uses a non-http(s) scheme, or targets a loopback host in production.
        """
        if not v or not v.strip():
            raise ValueError("base_url cannot be empty")

        url = v.strip()
        try:
            parsed = urlparse(url)
        except ValueError as e:
            # urlparse rejects e.g. unbalanced IPv6 brackets ("http://[::1").
            raise ValueError(f"Invalid base URL '{url}': {e}") from e

        if not parsed.scheme or not parsed.netloc:
            raise ValueError(f"Invalid base URL format: '{url}' - must include scheme and netloc")
        if parsed.scheme not in ("http", "https"):
            raise ValueError(f"Base URL must use http or https scheme: '{url}'")

        is_production = cls._is_production_environment()
        if is_production and parsed.scheme == "http":
            import warnings

            warnings.warn(
                f"Base URL '{url}' uses HTTP in production environment. "
                "HTTPS is strongly recommended for OAuth servers to prevent token interception.",
                UserWarning,
                stacklevel=2,
            )

        # Prevent common SSRF targets: loopback/unspecified hosts are only allowed outside production.
        if is_production and cls._is_loopback_host(parsed.hostname):
            raise ValueError(f"Base URL cannot use localhost/loopback addresses in production: '{url}'")

        return url

    @field_validator("issuer_url", "service_documentation_url")
    @classmethod
    def validate_optional_urls(cls, v: str | None, info: ValidationInfo) -> str | None:
        """Validate the optional ``issuer_url`` / ``service_documentation_url`` fields.

        Empty or whitespace-only values become "unset". Any other value must be
        an absolute http(s) URL. For ``issuer_url`` only, plain HTTP in
        production triggers a ``UserWarning``.

        Returns:
            ``v`` unchanged if falsy, ``None`` if whitespace-only, otherwise the stripped URL.

        Raises:
            ValueError: if the URL is unparsable, lacks a scheme or host, or uses
                a non-http(s) scheme.
        """
        if not v:
            return v

        url = v.strip()
        if not url:
            return None

        try:
            parsed = urlparse(url)
        except ValueError as e:
            raise ValueError(f"Invalid URL '{url}': {e}") from e

        if not parsed.scheme or not parsed.netloc:
            raise ValueError(f"Invalid URL format: '{url}' - must include scheme and netloc")
        if parsed.scheme not in ("http", "https"):
            raise ValueError(f"URL must use http or https scheme: '{url}'")

        # Identify the field being validated via pydantic's ValidationInfo. Model
        # fields are not stored in the class __dict__, so comparing against
        # cls.__dict__ can never detect the issuer field.
        if (
            info.field_name == "issuer_url"
            and parsed.scheme == "http"
            and cls._is_production_environment()
        ):
            import warnings

            warnings.warn(
                f"Issuer URL '{url}' uses HTTP in production. HTTPS is required for OAuth issuer URLs.",
                UserWarning,
                stacklevel=2,
            )

        return url

    @field_validator("valid_scopes", "default_scopes", "required_scopes")
    @classmethod
    def validate_scopes(cls, v: list[str]) -> list[str]:
        """Strip and validate each scope against the OAuth 2.0 scope-token grammar.

        Returns:
            The list of stripped scopes, in input order.

        Raises:
            ValueError: if a scope is empty, contains characters outside
                RFC 6749's scope-token set, or is longer than 128 characters.
        """
        if not v:
            return v

        # Broad catch-all names that usually violate least privilege (warned, not rejected).
        dangerous_scopes = {"admin", "root", "superuser", "system", "*", "all"}

        valid_scopes: list[str] = []
        for raw_scope in v:
            scope = raw_scope.strip()
            if not scope:
                raise ValueError("Scopes cannot be empty or whitespace-only")

            # RFC 6749 §3.3: scope-token = 1*( %x21 / %x23-5B / %x5D-7E ), i.e. printable
            # ASCII excluding space, double quote and backslash.
            if not all(32 < ord(c) < 127 and c not in ' "\\' for c in scope):
                raise ValueError(
                    f"Invalid scope format: '{scope}' - must be ASCII printable without spaces, quotes, or backslashes"
                )

            # Reasonable length limit to prevent abuse
            if len(scope) > 128:
                raise ValueError(f"Scope too long: '{scope}' - maximum 128 characters")

            if scope.lower() in dangerous_scopes:
                import warnings

                warnings.warn(
                    f"Potentially dangerous scope detected: '{scope}'. "
                    "Consider using more specific, principle-of-least-privilege scopes.",
                    UserWarning,
                    stacklevel=2,
                )

            valid_scopes.append(scope)

        return valid_scopes

    @model_validator(mode="after")
    def validate_oauth_server_config(self) -> "OAuthServerConfig":
        """Ensure default and required scopes are drawn from ``valid_scopes``.

        The subset checks only run when ``valid_scopes`` is non-empty.

        Raises:
            ValueError: if ``default_scopes`` or ``required_scopes`` contain a
                scope not listed in ``valid_scopes``.
        """
        if self.default_scopes and self.valid_scopes:
            invalid_defaults = set(self.default_scopes) - set(self.valid_scopes)
            if invalid_defaults:
                raise ValueError(f"default_scopes contains invalid scopes not in valid_scopes: {invalid_defaults}")

        if self.required_scopes and self.valid_scopes:
            invalid_required = set(self.required_scopes) - set(self.valid_scopes)
            if invalid_required:
                raise ValueError(f"required_scopes contains invalid scopes not in valid_scopes: {invalid_required}")

        return self

283 

284 

class RemoteAuthConfig(BaseModel):
    """Configuration for remote authorization server integration.

    Use this when you have token verification logic and want to advertise
    the authorization servers that issue valid tokens (RFC 9728 compliance).
    Token verification is delegated to ``token_verifier_config``.
    """

    provider_type: Literal["remote"] = "remote"

    # Authorization servers that issue tokens
    authorization_servers: list[str] = Field(
        ..., description="List of authorization server URLs that issue valid tokens"
    )

    # This server's URL
    resource_server_url: str = Field(..., description="URL of this resource server")

    # Token verification (delegate to another config)
    token_verifier_config: JWTAuthConfig | StaticTokenConfig = Field(
        ..., description="Configuration for the underlying token verifier"
    )

    # Environment variable names for runtime configuration
    authorization_servers_env_var: str | None = Field(
        None, description="Environment variable name for comma-separated authorization server URLs"
    )
    resource_server_url_env_var: str | None = Field(
        None, description="Environment variable name for resource server URL"
    )

    @field_validator("authorization_servers")
    @classmethod
    def validate_authorization_servers(cls, v: list[str]) -> list[str]:
        """Require at least one authorization server and validate each URL.

        Returns:
            The URLs with surrounding whitespace stripped, in input order.

        Raises:
            ValueError: if the list is empty, or any entry is blank, unparsable,
                lacks a scheme or host, or uses a non-http(s) scheme.
        """
        if not v:
            raise ValueError(
                "authorization_servers cannot be empty - at least one authorization server URL is required"
            )

        valid_urls: list[str] = []
        for raw_url in v:
            url = raw_url.strip()
            if not url:
                raise ValueError("authorization_servers cannot contain empty URLs")

            # Only urlparse itself is wrapped, so the specific messages below are
            # not double-prefixed by the generic "Invalid ... URL" wrapper.
            try:
                parsed = urlparse(url)
            except ValueError as e:
                raise ValueError(f"Invalid authorization server URL '{url}': {e}") from e

            if not parsed.scheme or not parsed.netloc:
                raise ValueError(
                    f"Invalid URL format for authorization server: '{url}' - must include scheme and netloc"
                )
            if parsed.scheme not in ("http", "https"):
                raise ValueError(f"Authorization server URL must use http or https scheme: '{url}'")

            valid_urls.append(url)

        return valid_urls

    @field_validator("resource_server_url")
    @classmethod
    def validate_resource_server_url(cls, v: str) -> str:
        """Validate that ``resource_server_url`` is an absolute http(s) URL.

        Returns:
            The URL with surrounding whitespace stripped.

        Raises:
            ValueError: if the URL is empty, unparsable, lacks a scheme or host,
                or uses a non-http(s) scheme.
        """
        if not v or not v.strip():
            raise ValueError("resource_server_url cannot be empty")

        url = v.strip()
        try:
            parsed = urlparse(url)
        except ValueError as e:
            raise ValueError(f"Invalid resource server URL '{url}': {e}") from e

        if not parsed.scheme or not parsed.netloc:
            raise ValueError(f"Invalid URL format for resource server: '{url}' - must include scheme and netloc")
        if parsed.scheme not in ("http", "https"):
            raise ValueError(f"Resource server URL must use http or https scheme: '{url}'")

        return url

    @model_validator(mode="after")
    def validate_token_verifier_compatibility(self) -> "RemoteAuthConfig":
        """Sanity-check that the delegated token verifier config is usable.

        Raises:
            ValueError: if the verifier config is not a JWTAuthConfig or
                StaticTokenConfig, if a JWT config has no key source, or if a
                static config defines no tokens.
        """
        config = self.token_verifier_config

        if not isinstance(config, (JWTAuthConfig, StaticTokenConfig)):
            raise ValueError(
                f"token_verifier_config must be JWTAuthConfig or StaticTokenConfig, got {type(config).__name__}"
            )

        # Defensive: JWTAuthConfig's own validator already rejects a missing key source,
        # but instances built with model_construct() skip validation.
        if isinstance(config, JWTAuthConfig) and not (
            config.public_key or config.jwks_uri or config.public_key_env_var or config.jwks_uri_env_var
        ):
            raise ValueError(
                "JWT token verifier config must provide public_key, jwks_uri, or their environment variable equivalents"
            )

        # An empty StaticTokenConfig is valid on its own but useless as a verifier.
        if isinstance(config, StaticTokenConfig) and not config.tokens:
            raise ValueError("Static token verifier config must provide at least one token")

        return self

393 

394 

# Union of every supported auth configuration model. Each member declares a
# distinct ``provider_type`` literal ("jwt", "static", "oauth_server", "remote").
AuthConfig = (
    JWTAuthConfig
    | StaticTokenConfig
    | OAuthServerConfig
    | RemoteAuthConfig
)