Coverage for agentos/tools/cors.py: 0%
100 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-03 08:19 +0800
"""
CORS — Cross-Origin Resource Sharing middleware helper.

Provides:
  - A fluent builder for assembling a CORS configuration
  - Origin validation against an exact allowlist or ``*`` wildcard patterns
  - Control over allowed request methods (GET/POST/PUT/DELETE/...) and headers
  - Construction of preflight (OPTIONS) response headers
  - Max-Age, Credentials and Expose-Headers support
  - Output as a plain dict of response headers
"""

from __future__ import annotations

import re
from typing import Dict, List, Optional, Pattern, Set, Union

# ============================================================================
# CORS
# ============================================================================

# Request methods this module treats as CORS-safelisted.
SAFELISTED_METHODS = frozenset(("GET", "HEAD", "POST"))

# Request header names (lower-case) this module treats as CORS-safelisted.
SAFELISTED_HEADERS = frozenset(
    (
        "accept",
        "accept-language",
        "content-language",
        "content-type",
    )
)
class CORSConfig:
    """Fluent builder for a CORS policy and the response headers it implies.

    Usage:
        cors = (CORSConfig()
            .allow_origins("https://example.com", "https://app.example.com")
            .allow_methods("GET", "POST", "PUT")
            .allow_headers("Content-Type", "Authorization")
            .allow_credentials()
            .max_age(3600)
        )

        # Check if origin is allowed
        ok = cors.is_origin_allowed("https://example.com")

        # Build preflight response headers
        headers = cors.preflight_headers("https://example.com")

    Each value given to :meth:`allow_origins` is one of three kinds:
    ``"*"`` allows any origin; a value containing ``*`` is a wildcard pattern,
    where each ``*`` matches any run of characters; anything else is an exact
    origin. Trailing slashes are ignored for exact origins and patterns alike.

    Method names are stored upper-cased and header names lower-cased.
    """

    def __init__(self) -> None:
        self._origins: List[str] = []              # exact origins, trailing "/" stripped
        self._origin_patterns: List[Pattern] = []  # compiled wildcard patterns (fullmatch)
        self._allow_any_origin: bool = False
        self._methods: Set[str] = set()            # upper-cased method names
        self._headers: Set[str] = set()            # lower-cased; "*" means "any requested header"
        self._expose_headers: Set[str] = set()     # lower-cased
        self._allow_credentials: bool = False
        self._max_age: Optional[int] = None

    # ---------- Fluent setters ----------

    def allow_origins(self, *origins: str) -> CORSConfig:
        """Add allowed origins and return ``self`` for chaining.

        ``"*"`` allows every origin, a value containing ``*`` is compiled into
        a wildcard pattern, and any other value is stored as an exact origin.
        """
        for origin in origins:
            # Normalize exactly as is_origin_allowed() normalizes the incoming
            # Origin; otherwise "https://*.example.com/" could never match.
            normalized = origin.rstrip("/")
            if origin == "*" or normalized == "*":
                self._allow_any_origin = True
            elif "*" in normalized:
                # Escape the literal parts, then turn each escaped "*" into ".*".
                regex = re.escape(normalized).replace(r"\*", ".*")
                self._origin_patterns.append(re.compile(regex))
            else:
                self._origins.append(normalized)
        return self

    def allow_methods(self, *methods: str) -> CORSConfig:
        """Add allowed request methods (stored upper-cased)."""
        self._methods.update(m.upper() for m in methods)
        return self

    def allow_headers(self, *headers: str) -> CORSConfig:
        """Add allowed request header names (stored lower-cased)."""
        self._headers.update(h.lower() for h in headers)
        return self

    def expose_headers(self, *headers: str) -> CORSConfig:
        """Add response header names the client may read (stored lower-cased)."""
        self._expose_headers.update(h.lower() for h in headers)
        return self

    def allow_credentials(self) -> CORSConfig:
        """Emit ``Access-Control-Allow-Credentials: true`` on allowed responses."""
        self._allow_credentials = True
        return self

    def max_age(self, seconds: int) -> CORSConfig:
        """Set the ``Access-Control-Max-Age`` value sent on preflight responses."""
        self._max_age = seconds
        return self

    # ---------- Convenience ----------

    def allow_all_origins(self) -> CORSConfig:
        """Allow every origin (equivalent to ``allow_origins("*")``)."""
        self._allow_any_origin = True
        return self

    def allow_all_methods(self) -> CORSConfig:
        """Replace the method set with a fixed list of common HTTP methods."""
        self._methods = {"GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"}
        return self

    def allow_all_headers(self) -> CORSConfig:
        """Replace the header set with ``{"*"}`` (accept any requested header)."""
        self._headers = {"*"}
        return self

    # ---------- Logic ----------

    def is_origin_allowed(self, origin: str) -> bool:
        """Return True if *origin* is permitted by this configuration.

        A trailing slash on *origin* is ignored. A wildcard pattern must match
        the entire normalized origin.
        """
        if self._allow_any_origin:
            return True
        origin = origin.rstrip("/")
        if origin in self._origins:
            return True
        # fullmatch (not match + "$"): "$" also matches before a trailing
        # newline, which would let such an origin be echoed into a header.
        return any(pattern.fullmatch(origin) for pattern in self._origin_patterns)

    def _origin_headers(self, origin: str) -> Dict[str, str]:
        """Return the Allow-Origin / Allow-Credentials headers for an allowed origin."""
        headers: Dict[str, str] = {}
        # The "*" wildcard cannot be combined with credentials, so when
        # credentials are enabled the concrete origin is echoed back instead.
        # NOTE(review): no "Vary: Origin" is added here even when the origin
        # is echoed; shared caches may need it — consider adding it at the caller.
        if self._allow_any_origin and not self._allow_credentials:
            headers["Access-Control-Allow-Origin"] = "*"
        else:
            headers["Access-Control-Allow-Origin"] = origin
        if self._allow_credentials:
            headers["Access-Control-Allow-Credentials"] = "true"
        return headers

    def preflight_headers(
        self,
        origin: str,
        request_method: Optional[str] = None,
        request_headers: Optional[Union[str, List[str]]] = None,
    ) -> Dict[str, str]:
        """Build response headers for a preflight OPTIONS request.

        Args:
            origin: Value of the request's ``Origin`` header.
            request_method: Value of ``Access-Control-Request-Method``; it is
                echoed back only when no methods are configured.
            request_headers: Requested header names, given either as a list or
                as the raw comma-separated ``Access-Control-Request-Headers``
                string. Surrounding whitespace and empty entries are ignored.

        Returns:
            A dict of response headers. The dict is empty if *origin* is not
            allowed.
        """
        if not self.is_origin_allowed(origin):
            return {}

        headers = self._origin_headers(origin)

        # Methods: advertise every configured method; if none are configured,
        # echo the requested method back.
        if self._methods:
            headers["Access-Control-Allow-Methods"] = ", ".join(sorted(self._methods))
        elif request_method:
            headers["Access-Control-Allow-Methods"] = request_method.upper()

        # Headers
        if isinstance(request_headers, str):
            request_headers = request_headers.split(",")
        if request_headers:
            # Tolerate whitespace (e.g. from splitting "a, b") and drop blanks.
            allowed = {h.strip().lower() for h in request_headers} - {""}
            # An empty configured set, or one containing "*", accepts all
            # requested headers; otherwise keep only the configured ones.
            if self._headers and "*" not in self._headers:
                allowed &= self._headers
            if allowed:
                headers["Access-Control-Allow-Headers"] = ", ".join(sorted(allowed))
        elif self._headers:
            headers["Access-Control-Allow-Headers"] = ", ".join(sorted(self._headers))

        # Expose
        if self._expose_headers:
            headers["Access-Control-Expose-Headers"] = ", ".join(sorted(self._expose_headers))

        # Max-Age
        if self._max_age is not None:
            headers["Access-Control-Max-Age"] = str(self._max_age)

        return headers

    def actual_headers(self, origin: str) -> Dict[str, str]:
        """Build response headers for the actual (non-preflight) request.

        Returns an empty dict if *origin* is not allowed.
        """
        if not self.is_origin_allowed(origin):
            return {}

        headers = self._origin_headers(origin)
        if self._expose_headers:
            headers["Access-Control-Expose-Headers"] = ", ".join(sorted(self._expose_headers))
        return headers

    def is_preflight(self, method: str, headers: Optional[List[str]] = None) -> bool:
        """Return True if *method* is OPTIONS (case-insensitive).

        The caller is assumed to have already checked that an ``Origin``
        header is present.

        NOTE(review): *headers* is accepted but not inspected. A stricter
        check would also require ``Access-Control-Request-Method`` to be
        present; confirm what callers pass here before tightening it.
        """
        return method.upper() == "OPTIONS"