Coverage for agentos/api/middleware.py: 57%

74 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 09:59 +0800

1"""AgentOS API middleware — request/response processing pipeline. 

2 

3Provides authentication, CORS, request tracing, request-ID injection, 

4and rate-limiting middleware for the AgentOS API server. 

5""" 

6 

7from __future__ import annotations 

8 

9import time 

10import uuid 

11from dataclasses import dataclass, field 

12from typing import Callable, Optional 

13 

14 

15# ── Request ID ──────────────────────────────────────────────────────────────── 

16 

17 

@dataclass
class RequestContext:
    """Per-request metadata carried through the middleware pipeline."""

    # Identifier printed at the start of each request log line.
    request_id: str = ""
    # Reading of time.monotonic() taken when the request began.
    start_time: float = 0.0
    method: str = ""
    path: str = ""
    client_ip: str = ""
    user_agent: str = ""

    @property
    def elapsed_ms(self) -> float:
        """Milliseconds since ``start_time``, measured on the monotonic clock."""
        delta_seconds = time.monotonic() - self.start_time
        return delta_seconds * 1000

31 

32 

class RequestIDMiddleware:
    """Inject an X-Request-ID into every request and propagate it.

    A client-supplied ID is reused when it is present and safe to echo;
    otherwise a fresh 12-character ID is generated from ``uuid.uuid4()``.
    """

    def __init__(self, header: str = "X-Request-ID", max_length: int = 128):
        """
        Args:
            header: Name of the request header that carries the request ID.
            max_length: Longest client-supplied ID that is reused as-is;
                longer values are replaced by a generated ID.
        """
        self.header = header
        self.max_length = max_length

    @staticmethod
    def _lookup(headers: dict, name: str):
        """Return the value of header *name* (case-insensitive), or ``""``."""
        lowered = name.lower()
        # Same precedence as before: the lower-case key, then the exact key.
        if lowered in headers:
            return headers[lowered]
        if name in headers:
            return headers[name]
        # HTTP field names are case-insensitive, so also accept spellings
        # such as "X-Request-Id".
        for key, value in headers.items():
            if isinstance(key, str) and key.lower() == lowered:
                return value
        return ""

    def _is_acceptable(self, rid) -> bool:
        """Return whether a client-supplied ID may be reused verbatim."""
        # The ID is written into log lines, so refuse control characters
        # (e.g. CR/LF, which enable log/header injection) and unbounded sizes.
        return (
            isinstance(rid, str)
            and 0 < len(rid) <= self.max_length
            and rid.isprintable()
        )

    def process_request(self, headers: dict) -> RequestContext:
        """Create the RequestContext for an incoming request.

        Args:
            headers: Incoming request headers.

        Returns:
            A RequestContext with ``request_id`` and ``start_time`` (a
            ``time.monotonic()`` reading) set; other fields keep their defaults.
        """
        rid = self._lookup(headers, self.header)
        if not self._is_acceptable(rid):
            rid = str(uuid.uuid4())[:12]
        return RequestContext(request_id=rid, start_time=time.monotonic())

44 

45 

46# ── CORS ────────────────────────────────────────────────────────────────────── 

47 

48 

@dataclass
class CORSConfig:
    """CORS configuration."""

    # Allowed origins; "*" is the CORS wildcard.
    allow_origins: list[str] = field(default_factory=["*"].copy)
    allow_methods: list[str] = field(
        default_factory=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"].copy
    )
    allow_headers: list[str] = field(
        default_factory=["Content-Type", "Authorization", "X-Request-ID"].copy
    )
    expose_headers: list[str] = field(default_factory=list)
    # Preflight cache lifetime in seconds (one day).
    max_age: int = field(default=24 * 60 * 60)
    allow_credentials: bool = field(default=False)

58 

59 

class CORSMiddleware:
    """Add CORS headers to every response.

    ``apply`` works in two modes:

    * ``request_origin`` omitted (legacy): ``Access-Control-Allow-Origin`` is
      the first configured origin, whoever is asking.
    * ``request_origin`` given (the value of the request's ``Origin`` header):
      the origin is matched against ``allow_origins`` so multi-origin
      allow-lists work. A wildcard configuration with credentials reflects
      the concrete origin, because browsers reject ``*`` together with
      ``Access-Control-Allow-Credentials: true``.
    """

    def __init__(self, config: Optional[CORSConfig] = None):
        self.config = config or CORSConfig()

    def _resolve_origin(self, request_origin: Optional[str]) -> tuple[Optional[str], bool]:
        """Return ``(allow_origin, varies)``.

        ``allow_origin`` is the Access-Control-Allow-Origin value, or None when
        the header must be omitted. ``varies`` is True when the result depends
        on the request's Origin, so caches need ``Vary: Origin``.
        """
        origins = self.config.allow_origins
        if not request_origin:
            # Legacy behaviour: no request origin known, emit the first entry.
            return (origins[0] if origins else ""), False
        wildcard = "*" in origins
        if wildcard and not self.config.allow_credentials:
            return "*", False
        if wildcard or request_origin in origins:
            return request_origin, True
        return None, True

    @staticmethod
    def _add_vary_origin(response_headers: dict) -> None:
        """Add ``Origin`` to the ``Vary`` header unless it is already covered."""
        tokens = [t.strip() for t in response_headers.get("Vary", "").split(",") if t.strip()]
        if not any(t == "*" or t.lower() == "origin" for t in tokens):
            tokens.append("Origin")
        response_headers["Vary"] = ", ".join(tokens)

    def apply(self, response_headers: dict, request_origin: Optional[str] = None) -> dict:
        """Write CORS headers into *response_headers* (mutated in place).

        Args:
            response_headers: Response header mapping to update.
            request_origin: Optional ``Origin`` header of the request. When
                omitted or empty, the legacy first-origin behaviour is used.

        Returns:
            The same *response_headers* object.
        """
        cfg = self.config
        allow_origin, varies = self._resolve_origin(request_origin)
        if allow_origin is None:
            # Origin not allowed: make sure no stale value lets it through.
            response_headers.pop("Access-Control-Allow-Origin", None)
        else:
            response_headers["Access-Control-Allow-Origin"] = allow_origin
        if varies:
            self._add_vary_origin(response_headers)
        response_headers["Access-Control-Allow-Methods"] = ", ".join(cfg.allow_methods)
        response_headers["Access-Control-Allow-Headers"] = ", ".join(cfg.allow_headers)
        if cfg.expose_headers:
            response_headers["Access-Control-Expose-Headers"] = ", ".join(cfg.expose_headers)
        response_headers["Access-Control-Max-Age"] = str(cfg.max_age)
        if cfg.allow_credentials:
            response_headers["Access-Control-Allow-Credentials"] = "true"
        return response_headers

77 

78 

79# ── Auth ────────────────────────────────────────────────────────────────────── 

80 

81 

@dataclass
class AuthConfig:
    """Authentication configuration."""

    # Request header expected to carry the API key.
    api_key_header: str = field(default="X-API-Key")
    # Expected key value; AuthMiddleware skips the check when this is empty.
    api_key: str = field(default="")
    # When False, AuthMiddleware treats every request as authorized.
    enabled: bool = field(default=True)

88 

89 

class AuthMiddleware:
    """Simple API-key authentication middleware.

    Every request is treated as authorized when ``config.enabled`` is False
    or ``config.api_key`` is empty.
    """

    def __init__(self, config: Optional[AuthConfig] = None):
        self.config = config or AuthConfig()

    @staticmethod
    def _lookup(headers: dict, name: str):
        """Return the value of header *name* (case-insensitive), or ``""``."""
        lowered = name.lower()
        # Same precedence as before: the lower-case key, then the exact key.
        if lowered in headers:
            return headers[lowered]
        if name in headers:
            return headers[name]
        # HTTP field names are case-insensitive ("X-Api-Key" must also match).
        for key, value in headers.items():
            if isinstance(key, str) and key.lower() == lowered:
                return value
        return ""

    def authenticate(self, headers: dict) -> tuple[bool, str]:
        """Check the API key carried in *headers*.

        Args:
            headers: Incoming request headers.

        Returns:
            ``(authorized, message)``; *message* is empty on success.
        """
        import hmac

        if not self.config.enabled or not self.config.api_key:
            # NOTE(review): fails open -- an enabled config with no key set
            # accepts every request. Kept for backward compatibility.
            return True, ""
        provided = self._lookup(headers, self.config.api_key_header)
        # Constant-time comparison avoids leaking the key through response
        # timing; comparing UTF-8 bytes also supports non-ASCII keys, which
        # compare_digest rejects as str.
        if not isinstance(provided, str) or not hmac.compare_digest(
            provided.encode("utf-8", "surrogatepass"),
            self.config.api_key.encode("utf-8", "surrogatepass"),
        ):
            return False, "Invalid or missing API key"
        return True, ""

104 

105 

106# ── Request logger ──────────────────────────────────────────────────────────── 

107 

108 

class RequestLogMiddleware:
    """Emit one line per request: ID, method, path, status and elapsed time."""

    def __init__(self, logger: Optional[Callable[[str], None]] = None):
        # Any callable that accepts a string; a falsy value selects print().
        self._log = logger if logger else print

    def log(self, ctx: RequestContext, status: int) -> str:
        """Format the request summary, pass it to the logger and return it."""
        head = f"[{ctx.request_id}] {ctx.method} {ctx.path}"
        tail = f"→ {status} ({ctx.elapsed_ms:.1f}ms)"
        line = f"{head} {tail}"
        self._log(line)
        return line

122 

123 

124# ── Middleware stack ────────────────────────────────────────────────────────── 

125 

126 

127class MiddlewareStack: 

128 """Ordered middleware pipeline for the AgentOS API.""" 

129 

130 def __init__( 

131 self, 

132 cors: Optional[CORSMiddleware] = None, 

133 auth: Optional[AuthMiddleware] = None, 

134 req_log: Optional[RequestLogMiddleware] = None, 

135 req_id: Optional[RequestIDMiddleware] = None, 

136 ): 

137 self.cors = cors or CORSMiddleware() 

138 self.auth = auth or AuthMiddleware() 

139 self.req_log = req_log or RequestLogMiddleware() 

140 self.req_id = req_id or RequestIDMiddleware()