Coverage for agentos/tools/cors.py: 0%

100 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-03 08:19 +0800

1""" 

2CORS — Cross-Origin Resource Sharing middleware helper. 

3 

4Supports: 

5 - Fluent builder for CORS configuration 

6 - Origin validation (allowlist/regex) 

7 - Method (GET/POST/PUT/DELETE/etc.) and header control 

8 - Preflight (OPTIONS) response builder 

9 - Max-Age, Credentials, Expose-Headers 

10 - Serialize to dict of response headers 

11""" 

12 

13from __future__ import annotations 

14 

15import re 

16from typing import Dict, List, Optional, Pattern, Set, Union 

17 

18 

19# ============================================================================ 

20# CORS 

21# ============================================================================ 

22 

# HTTP methods regarded as CORS-safelisted (upper-case).
SAFELISTED_METHODS = frozenset(("GET", "HEAD", "POST"))

# Request header names regarded as CORS-safelisted (lower-case).
SAFELISTED_HEADERS = frozenset(
    (
        "accept",
        "accept-language",
        "content-language",
        "content-type",
    )
)

27 

28 

class CORSConfig:
    """Fluent builder for a CORS policy plus helpers that render it as headers.

    Every setter returns ``self`` so calls can be chained.

    Usage:
        cors = (CORSConfig()
            .allow_origins("https://example.com", "https://app.example.com")
            .allow_methods("GET", "POST", "PUT")
            .allow_headers("Content-Type", "Authorization")
            .allow_credentials()
            .max_age(3600)
        )

        # Check if origin is allowed
        ok = cors.is_origin_allowed("https://example.com")

        # Build preflight response headers
        headers = cors.preflight_headers("https://example.com")

    NOTE(review): when a specific origin is echoed back (rather than ``*``),
    cacheable responses normally also need ``Vary: Origin``. This class does
    not emit it; callers are expected to add it themselves if required.
    """

    def __init__(self):
        # Exact origins, stored without a trailing "/".
        self._origins: List[str] = []
        # Compiled regexes derived from glob-style origins ("https://*.example.com").
        self._origin_patterns: List[Pattern] = []
        self._allow_any_origin: bool = False
        # Method names are stored upper-case, header names lower-case.
        self._methods: Set[str] = set()
        self._headers: Set[str] = set()
        self._expose_headers: Set[str] = set()
        self._allow_credentials: bool = False
        self._max_age: Optional[int] = None

    # ---------- Fluent setters ----------

    def allow_origins(self, *origins: str) -> "CORSConfig":
        """Register allowed origins.

        Each entry may be ``"*"`` (allow any origin), a glob containing ``*``
        (each ``*`` matches any run of characters), or an exact origin.
        A trailing ``/`` is ignored for every form, mirroring the
        normalisation applied in :meth:`is_origin_allowed`.
        """
        for origin in origins:
            # Normalise first so exact and glob entries are treated alike;
            # candidates are stripped the same way before being compared.
            origin = origin.rstrip("/")
            if origin == "*":
                self._allow_any_origin = True
            elif "*" in origin:
                # Convert glob to regex: escape everything, then re-open "*".
                pattern = re.escape(origin).replace(r"\*", ".*")
                self._origin_patterns.append(re.compile(pattern))
            else:
                self._origins.append(origin)
        return self

    def allow_methods(self, *methods: str) -> "CORSConfig":
        """Add HTTP methods to the allowed set (stored upper-case)."""
        self._methods.update(m.upper() for m in methods)
        return self

    def allow_headers(self, *headers: str) -> "CORSConfig":
        """Add request header names to the allowed set (stored lower-case)."""
        self._headers.update(h.lower() for h in headers)
        return self

    def expose_headers(self, *headers: str) -> "CORSConfig":
        """Add response header names to expose (stored lower-case)."""
        self._expose_headers.update(h.lower() for h in headers)
        return self

    def allow_credentials(self) -> "CORSConfig":
        """Enable ``Access-Control-Allow-Credentials: true``."""
        self._allow_credentials = True
        return self

    def max_age(self, seconds: int) -> "CORSConfig":
        """Set the value emitted as ``Access-Control-Max-Age``."""
        self._max_age = seconds
        return self

    # ---------- Convenience ----------

    def allow_all_origins(self) -> "CORSConfig":
        """Allow every origin (same effect as ``allow_origins("*")``)."""
        self._allow_any_origin = True
        return self

    def allow_all_methods(self) -> "CORSConfig":
        """Replace the allowed methods with the common HTTP verb set."""
        self._methods = {"GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"}
        return self

    def allow_all_headers(self) -> "CORSConfig":
        """Replace the allowed headers with the ``*`` wildcard."""
        self._headers = {"*"}
        return self

    # ---------- Logic ----------

    def is_origin_allowed(self, origin: str) -> bool:
        """Return True if *origin* is permitted by this configuration.

        A trailing ``/`` on *origin* is ignored. Glob patterns must match the
        whole origin string.
        """
        if self._allow_any_origin:
            return True
        origin = origin.rstrip("/")
        if origin in self._origins:
            return True
        # fullmatch (not match + "$") so a trailing newline cannot slip past
        # the end anchor and later be echoed into a response header.
        return any(p.fullmatch(origin) for p in self._origin_patterns)

    def _origin_headers(self, origin: str) -> Dict[str, str]:
        """Build the Allow-Origin / Allow-Credentials pair for an allowed origin."""
        headers: Dict[str, str] = {}
        # The "*" wildcard is only emitted when credentials are off; otherwise
        # the concrete origin is echoed back.
        if self._allow_any_origin and not self._allow_credentials:
            headers["Access-Control-Allow-Origin"] = "*"
        else:
            headers["Access-Control-Allow-Origin"] = origin
        if self._allow_credentials:
            headers["Access-Control-Allow-Credentials"] = "true"
        return headers

    def preflight_headers(
        self,
        origin: str,
        request_method: Optional[str] = None,
        request_headers: Optional[List[str]] = None,
    ) -> Dict[str, str]:
        """Build response headers for a preflight OPTIONS request.

        Args:
            origin: Value of the request's ``Origin`` header.
            request_method: Requested method; only used (upper-cased) when no
                methods have been configured.
            request_headers: Requested header names. When given, the response
                lists those that are allowed; when omitted, the configured
                header set is listed instead.

        Returns:
            A dict of response headers, or an empty dict if *origin* is not
            allowed.
        """
        if not self.is_origin_allowed(origin):
            return {}

        headers = self._origin_headers(origin)

        # Methods: advertise the full configured set; with nothing configured,
        # fall back to echoing the requested method.
        if self._methods:
            headers["Access-Control-Allow-Methods"] = ", ".join(sorted(self._methods))
        elif request_method:
            headers["Access-Control-Allow-Methods"] = request_method.upper()

        # Headers: intersect the request with the configured set, unless the
        # set is empty or contains "*", in which case the request is echoed.
        if request_headers:
            allowed = {h.lower() for h in request_headers}
            if self._headers and "*" not in self._headers:
                allowed &= self._headers
            if allowed:
                headers["Access-Control-Allow-Headers"] = ", ".join(sorted(allowed))
        elif self._headers:
            headers["Access-Control-Allow-Headers"] = ", ".join(sorted(self._headers))

        # Expose
        if self._expose_headers:
            headers["Access-Control-Expose-Headers"] = ", ".join(sorted(self._expose_headers))

        # Max-Age
        if self._max_age is not None:
            headers["Access-Control-Max-Age"] = str(self._max_age)

        return headers

    def actual_headers(self, origin: str) -> Dict[str, str]:
        """Build response headers for the actual (non-preflight) request.

        Returns an empty dict if *origin* is not allowed.
        """
        if not self.is_origin_allowed(origin):
            return {}

        headers = self._origin_headers(origin)

        if self._expose_headers:
            headers["Access-Control-Expose-Headers"] = ", ".join(sorted(self._expose_headers))

        return headers

    def is_preflight(self, method: str, headers: Optional[List[str]] = None) -> bool:
        """Check if a request is a CORS preflight request.

        Only the method is inspected: any ``OPTIONS`` request (case-insensitive)
        is reported as a preflight. *headers* is accepted but currently unused;
        the presence of the Origin header is assumed to be checked by the caller.
        """
        return method.upper() == "OPTIONS"