Coverage for agentos/api/middleware.py: 57%
74 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
"""AgentOS API middleware — request/response processing pipeline.

Provides authentication, CORS, request tracing, request-ID injection,
and rate-limiting middleware for the AgentOS API server.
"""
from __future__ import annotations

import hmac
import time
import uuid
from dataclasses import dataclass, field
from typing import Callable, Optional
# ── Request ID ────────────────────────────────────────────────────────────────
@dataclass
class RequestContext:
    """Request context.

    Plain record of per-request metadata. Every field has an empty default,
    so an instance can be built with only the values known at creation time.
    """

    # Identifier attached to the request.
    request_id: str = ""
    # Reading of ``time.monotonic()`` that ``elapsed_ms`` measures from.
    start_time: float = 0.0
    # Remaining request attributes; all are plain strings.
    method: str = ""
    path: str = ""
    client_ip: str = ""
    user_agent: str = ""

    @property
    def elapsed_ms(self) -> float:
        """Return the milliseconds between ``start_time`` and now.

        Both ends are taken from the monotonic clock, so the value is only
        meaningful when ``start_time`` was itself a ``time.monotonic()``
        reading.
        """
        elapsed_seconds = time.monotonic() - self.start_time
        return elapsed_seconds * 1000
class RequestIDMiddleware:
    """Inject X-Request-ID into every request and propagate it.

    A request ID supplied by the client is reused; when none is supplied a
    fresh one is generated.
    """

    def __init__(self, header: str = "X-Request-ID"):
        """Remember the name of the header that carries the request ID."""
        self.header = header

    def _incoming_id(self, headers: dict) -> str:
        """Return the request ID found in *headers*, or ``""`` when absent.

        The lowercase spelling of the header name wins, then the configured
        spelling. Any other casing is matched case-insensitively as a last
        resort.
        """
        wanted = self.header.lower()
        if wanted in headers:
            return headers[wanted]
        if self.header in headers:
            return headers[self.header]
        # HTTP field names are case-insensitive, so "X-REQUEST-ID" and
        # "X-Request-Id" must be honoured too instead of silently minting a
        # new ID and breaking trace propagation.
        for name, value in headers.items():
            if isinstance(name, str) and name.lower() == wanted:
                return value
        return ""

    def process_request(self, headers: dict) -> RequestContext:
        """Build the request context for an incoming request.

        Args:
            headers: Mapping of request header names to values.

        Returns:
            A ``RequestContext`` whose ``request_id`` is the client-supplied
            ID (or a generated one) and whose ``start_time`` is the current
            monotonic clock reading.
        """
        rid = self._incoming_id(headers)
        if not rid:
            # First 12 characters of the canonical UUID text: eight hex
            # digits, a hyphen, then three more hex digits.
            rid = str(uuid.uuid4())[:12]
        return RequestContext(request_id=rid, start_time=time.monotonic())
# ── CORS ──────────────────────────────────────────────────────────────────────
@dataclass
class CORSConfig:
    """CORS configuration.

    The list-valued settings use default factories, so each instance gets
    its own list objects.
    """

    # Origins the API accepts; the default is the wildcard.
    allow_origins: list[str] = field(
        default_factory=lambda: ["*"],
    )
    # Methods advertised to clients.
    allow_methods: list[str] = field(
        default_factory=lambda: [
            "GET",
            "POST",
            "PUT",
            "DELETE",
            "PATCH",
            "OPTIONS",
        ],
    )
    # Request headers advertised as allowed.
    allow_headers: list[str] = field(
        default_factory=lambda: [
            "Content-Type",
            "Authorization",
            "X-Request-ID",
        ],
    )
    # Response headers to expose; empty by default.
    expose_headers: list[str] = field(default_factory=list)
    # Value for the max-age setting (86400 by default).
    max_age: int = 86400
    # Whether credentials are allowed.
    allow_credentials: bool = False
class CORSMiddleware:
    """Add CORS headers to every response.

    Header values come from the supplied configuration; a default
    ``CORSConfig`` is created when none is given.
    """

    def __init__(self, config: Optional[CORSConfig] = None):
        """Store *config*, falling back to a default ``CORSConfig``."""
        self.config = config or CORSConfig()

    def _select_origin(self, request_origin: Optional[str]) -> str:
        """Choose the value for ``Access-Control-Allow-Origin``.

        Returns ``""`` when no origin is configured, the request's own origin
        when it is on the allow-list, and the first configured origin
        otherwise (the historical behaviour).
        """
        allowed = self.config.allow_origins
        if not allowed:
            return ""
        if request_origin is not None and request_origin in allowed:
            return request_origin
        return allowed[0]

    def apply(
        self,
        response_headers: dict,
        request_origin: Optional[str] = None,
    ) -> dict:
        """Write the CORS headers into *response_headers* and return it.

        Args:
            response_headers: Mutable mapping of response headers; it is
                updated in place.
            request_origin: Value of the request's ``Origin`` header, if the
                caller has it. The allow-origin header carries a single
                value, so a multi-origin allow-list only works when the
                matching origin is echoed back. When omitted, the first
                configured origin is emitted exactly as before.

        Returns:
            The same *response_headers* object.
        """
        cfg = self.config
        # NOTE(review): a wildcard origin combined with allow_credentials is
        # emitted as configured; browsers reject that pairing, so confirm the
        # configuration never relies on it.
        response_headers["Access-Control-Allow-Origin"] = self._select_origin(request_origin)
        response_headers["Access-Control-Allow-Methods"] = ", ".join(cfg.allow_methods)
        response_headers["Access-Control-Allow-Headers"] = ", ".join(cfg.allow_headers)
        if cfg.expose_headers:
            response_headers["Access-Control-Expose-Headers"] = ", ".join(cfg.expose_headers)
        response_headers["Access-Control-Max-Age"] = str(cfg.max_age)
        if cfg.allow_credentials:
            response_headers["Access-Control-Allow-Credentials"] = "true"

        # With several configured origins the answer depends on the request's
        # Origin header, so caches must key on it.
        if request_origin is not None and len(cfg.allow_origins) > 1:
            existing = response_headers.get("Vary", "")
            tokens = [part.strip().lower() for part in existing.split(",") if part.strip()]
            if "origin" not in tokens and "*" not in tokens:
                response_headers["Vary"] = f"{existing}, Origin" if tokens else "Origin"
        return response_headers
# ── Auth ──────────────────────────────────────────────────────────────────────
@dataclass
class AuthConfig:
    """Authentication configuration."""

    # Name of the request header that carries the API key.
    api_key_header: str = "X-API-Key"

    # Expected API key; empty by default.
    api_key: str = ""

    # Master switch for authentication; on by default.
    enabled: bool = True
class AuthMiddleware:
    """Simple API-key authentication middleware."""

    def __init__(self, config: Optional[AuthConfig] = None):
        """Store *config*, falling back to a default ``AuthConfig``."""
        self.config = config or AuthConfig()

    def _provided_key(self, headers: dict) -> object:
        """Return the API key sent in *headers*, or ``""`` when absent.

        The lowercase spelling of the configured header name wins, then the
        configured spelling; any other casing is matched case-insensitively
        because HTTP field names are case-insensitive.
        """
        name = self.config.api_key_header
        wanted = name.lower()
        if wanted in headers:
            return headers[wanted]
        if name in headers:
            return headers[name]
        for key, value in headers.items():
            if isinstance(key, str) and key.lower() == wanted:
                return value
        return ""

    @staticmethod
    def _keys_match(provided: object, expected: object) -> bool:
        """Compare two keys without leaking how many characters matched."""
        if isinstance(provided, str) and isinstance(expected, str):
            # compare_digest rejects non-ASCII str, so compare the encoded
            # bytes; "surrogatepass" keeps odd header text from raising.
            return hmac.compare_digest(
                provided.encode("utf-8", "surrogatepass"),
                expected.encode("utf-8", "surrogatepass"),
            )
        # Non-string values keep the original plain-equality semantics.
        return provided == expected

    def authenticate(self, headers: dict) -> tuple[bool, str]:
        """Return (authorized, message).

        Args:
            headers: Mapping of request header names to values.

        Returns:
            ``(True, "")`` when the request is authorized, otherwise
            ``(False, "Invalid or missing API key")``.
        """
        # NOTE(review): an enabled config with an empty api_key authorizes
        # every request (this is the default AuthConfig) — confirm this
        # fail-open behaviour is intended.
        if not self.config.enabled or not self.config.api_key:
            return True, ""
        provided = self._provided_key(headers)
        if not self._keys_match(provided, self.config.api_key):
            return False, "Invalid or missing API key"
        return True, ""
# ── Request logger ────────────────────────────────────────────────────────────
class RequestLogMiddleware:
    """Log every request with method, path, status, and elapsed time."""

    def __init__(self, logger: Optional[Callable[[str], None]] = None):
        """Use *logger* as the output sink; ``print`` is the fallback."""
        self._log = logger if logger else print

    def log(self, ctx: RequestContext, status: int) -> str:
        """Format one log line for a finished request, emit it, and return it.

        Args:
            ctx: Context of the request being logged.
            status: Status value to report for the response.

        Returns:
            The exact message that was handed to the logger.
        """
        request_part = f"[{ctx.request_id}] {ctx.method} {ctx.path} "
        outcome_part = f"→ {status} ({ctx.elapsed_ms:.1f}ms)"
        line = request_part + outcome_part
        self._log(line)
        return line
# ── Middleware stack ──────────────────────────────────────────────────────────
127class MiddlewareStack:
128 """Ordered middleware pipeline for the AgentOS API."""
130 def __init__(
131 self,
132 cors: Optional[CORSMiddleware] = None,
133 auth: Optional[AuthMiddleware] = None,
134 req_log: Optional[RequestLogMiddleware] = None,
135 req_id: Optional[RequestIDMiddleware] = None,
136 ):
137 self.cors = cors or CORSMiddleware()
138 self.auth = auth or AuthMiddleware()
139 self.req_log = req_log or RequestLogMiddleware()
140 self.req_id = req_id or RequestIDMiddleware()