Coverage for agentos/enterprise/auth.py: 53%
207 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
1"""
2AgentOS Enterprise — SSO & RBAC.
4功能:
5 - RBAC 角色模型(admin / developer / viewer / agent)
6 - 权限定义与校验
7 - SSO 集成接口(OIDC / SAML 抽象)
8 - JWT Token 签发与验证
9 - 会话管理
10"""
12from __future__ import annotations
14import hashlib
15import hmac
16import json
17import time
18from dataclasses import dataclass, field
19from enum import Enum
20from typing import Optional
23# ── 权限系统 ──
class Permission(str, Enum):
    """Fine-grained permission definitions.

    Each value is a ``"<resource>:<action>"`` string. Because the enum mixes
    in ``str``, members compare equal to their plain string values
    (``Permission.AGENT_READ == "agent:read"``).
    """
    # Agent
    AGENT_CREATE = "agent:create"
    AGENT_READ = "agent:read"
    AGENT_UPDATE = "agent:update"
    AGENT_DELETE = "agent:delete"
    AGENT_RUN = "agent:run"
    # Tools
    TOOLS_LIST = "tools:list"
    TOOLS_EXECUTE = "tools:execute"
    TOOLS_MANAGE = "tools:manage"
    # API Keys
    KEYS_CREATE = "keys:create"
    KEYS_READ = "keys:read"
    KEYS_REVOKE = "keys:revoke"
    # Tenants
    TENANT_READ = "tenant:read"
    TENANT_MANAGE = "tenant:manage"
    # Audit
    AUDIT_READ = "audit:read"
    AUDIT_EXPORT = "audit:export"
    # Admin
    # NOTE(review): "admin:*" looks like a wildcard, but the permission checks
    # visible in this module use exact set membership, so holding ADMIN_ALL
    # does not imply any other permission. Confirm that this is intended.
    ADMIN_ALL = "admin:*"
    SYSTEM_CONFIG = "system:config"
class Role(str, Enum):
    """Predefined roles.

    The permission set of each role is defined in ``ROLE_PERMISSIONS``.
    Members compare equal to their plain string values (``Role.ADMIN == "admin"``).
    """
    ADMIN = "admin"
    DEVELOPER = "developer"
    VIEWER = "viewer"
    AGENT = "agent"
# Role -> permission-set mapping for the built-in roles.
# NOTE: the values are mutable sets shared by every reader of this dict;
# mutating one changes the policy for the whole process.
# NOTE(review): DEVELOPER does not get TENANT_READ, although the read-only
# VIEWER role does. Confirm whether that asymmetry is intentional.
ROLE_PERMISSIONS: dict[Role, set[Permission]] = {
    Role.ADMIN: set(Permission),  # every permission defined in the enum
    Role.DEVELOPER: {
        Permission.AGENT_CREATE, Permission.AGENT_READ, Permission.AGENT_UPDATE,
        Permission.AGENT_RUN,
        Permission.TOOLS_LIST, Permission.TOOLS_EXECUTE,
        Permission.KEYS_CREATE, Permission.KEYS_READ,
        Permission.AUDIT_READ,
    },
    Role.VIEWER: {
        Permission.AGENT_READ,
        Permission.TOOLS_LIST,
        Permission.KEYS_READ,
        Permission.AUDIT_READ,
        Permission.TENANT_READ,
    },
    Role.AGENT: {
        Permission.AGENT_RUN,
        Permission.TOOLS_EXECUTE,
    },
}
@dataclass
class User:
    """User entity.

    Effective permissions are derived from ``roles`` plus
    ``custom_permissions`` by the RBAC engine; a ``disabled`` user has none.
    """
    user_id: str
    username: str
    email: str
    roles: list[Role]
    tenant_id: str
    custom_permissions: set[Permission] = field(default_factory=set)  # extra grants on top of the roles
    disabled: bool = False  # when True, the RBAC engine resolves an empty permission set
    created_at: float = field(default_factory=time.time)  # Unix timestamp (seconds) from time.time()
    metadata: dict = field(default_factory=dict)  # free-form; no schema is defined in this module
class RBACEngine:
    """RBAC permission engine.

    Features:
      - role permissions combined with per-user custom permissions
      - admin inheritance (``Role.ADMIN`` always holds every permission)
      - batch permission checks
      - custom roles registered at runtime via ``register_custom_role``
    """

    def __init__(self):
        # Custom role name -> permissions. Consulted for role names that are
        # not built-in ``ROLE_PERMISSIONS`` keys.
        self._custom_roles: dict[str, set[Permission]] = {}

    def _role_permissions(self, role) -> set[Permission]:
        """Return the (shared, do-not-mutate) permission set for a built-in or custom role."""
        if role in ROLE_PERMISSIONS:
            return ROLE_PERMISSIONS[role]
        return self._custom_roles.get(role, set())

    def get_permissions(self, user: User) -> set[Permission]:
        """Return every permission *user* effectively holds.

        Disabled users get an empty set. A user holding ``Role.ADMIN`` gets
        every ``Permission``. Otherwise the result is the union of
        ``user.custom_permissions`` and the permissions of each role in
        ``user.roles``: built-in roles come from ``ROLE_PERMISSIONS``, and any
        other name comes from roles registered with ``register_custom_role``.

        The returned set is always a fresh object; mutating it does not
        change the policy.
        """
        if user.disabled:
            return set()

        # Admin inherits everything, so the per-role union is unnecessary.
        if Role.ADMIN in user.roles:
            return set(Permission)

        perms: set[Permission] = set(user.custom_permissions)
        for role in user.roles:
            perms |= self._role_permissions(role)
        return perms

    def check_permission(self, user: User, permission: Permission) -> bool:
        """Return True if *user* holds *permission*."""
        return permission in self.get_permissions(user)

    def check_permissions(self, user: User, permissions: list[Permission]) -> dict[Permission, bool]:
        """Batch check: map each requested permission to whether *user* holds it."""
        user_perms = self.get_permissions(user)
        return {p: p in user_perms for p in permissions}

    def has_any(self, user: User, permissions: list[Permission]) -> bool:
        """Return True if *user* holds at least one of *permissions* (False for an empty list)."""
        user_perms = self.get_permissions(user)
        return bool(user_perms & set(permissions))

    def has_all(self, user: User, permissions: list[Permission]) -> bool:
        """Return True if *user* holds every one of *permissions* (True for an empty list)."""
        user_perms = self.get_permissions(user)
        return set(permissions).issubset(user_perms)

    def register_custom_role(self, name: str, permissions: set[Permission]):
        """Register, or replace, a custom role.

        Users whose ``roles`` list contains *name* receive *permissions*
        through ``get_permissions``. Names that collide with a built-in role
        keep the built-in permission set.
        """
        # Store a copy so later mutation of the caller's set cannot silently
        # change the policy.
        self._custom_roles[name] = set(permissions)

    def get_role_permissions(self, role: Role) -> set[Permission]:
        """Return a copy of the permission set for a built-in or registered custom role.

        Unknown roles yield an empty set. A copy is returned so callers cannot
        mutate the shared global ``ROLE_PERMISSIONS`` policy by accident.
        """
        return set(self._role_permissions(role))
155# ── SSO 集成 ──
@dataclass
class OIDCConfig:
    """OpenID Connect provider configuration.

    Empty endpoint fields mean "derive from ``issuer``": the SSO helpers in
    this module fall back to ``{issuer}/authorize``, ``{issuer}/token`` and
    ``{issuer}/userinfo``.
    """
    issuer: str  # e.g. "https://accounts.google.com"
    client_id: str
    # Excluded from repr() so the secret cannot leak into logs or tracebacks.
    client_secret: str = field(repr=False)
    redirect_uri: str
    scopes: list[str] = field(default_factory=lambda: ["openid", "email", "profile"])
    authorization_endpoint: str = ""
    token_endpoint: str = ""
    userinfo_endpoint: str = ""
    jwks_uri: str = ""  # not consumed by the helpers in this module
@dataclass
class SAMLConfig:
    """SAML configuration for the identity provider (IdP) and our service provider (SP)."""
    idp_entity_id: str  # not used when building the AuthnRequest in this module
    idp_sso_url: str  # IdP SSO endpoint; the login redirect targets it and it is the request's Destination
    idp_certificate: str  # presumably the IdP signing certificate (PEM?) — not consumed in this module; confirm format
    sp_entity_id: str  # sent as the <saml:Issuer> of the AuthnRequest
    sp_acs_url: str  # sent as AssertionConsumerServiceURL in the AuthnRequest
    name_id_format: str = "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress"  # not currently emitted in the AuthnRequest
@dataclass
class SSOUser:
    """User identity returned by an SSO provider."""
    external_id: str  # provider-side subject id (the OIDC "sub" claim for OIDC logins)
    email: str
    display_name: str
    provider: str  # "oidc" / "saml"
    raw_claims: dict = field(default_factory=dict)  # claims exactly as returned by the provider
class SSOProvider:
    """SSO abstraction layer — one interface over OIDC and SAML.

    All methods are static; the class only serves as a namespace.
    """

    @staticmethod
    def _append_query(base_url: str, params: dict) -> str:
        """Append URL-encoded *params* to *base_url*.

        Uses ``&`` when *base_url* already carries a query string (for example a
        tenant-specific IdP endpoint), so the result never contains two ``?``.
        """
        import urllib.parse
        if "?" not in base_url:
            sep = "?"
        elif base_url.endswith(("?", "&")):
            sep = ""
        else:
            sep = "&"
        return f"{base_url}{sep}{urllib.parse.urlencode(params)}"

    @staticmethod
    def build_oidc_login_url(config: OIDCConfig, state: str = "", nonce: str = "") -> str:
        """Build the OIDC authorization-code-flow login URL.

        Args:
            config: OIDC provider configuration.
            state: CSRF ``state`` value; a random 16-character string is used when empty.
            nonce: ``nonce`` value; a random 16-character string is used when empty.

        Returns:
            ``config.authorization_endpoint`` (or ``{issuer}/authorize`` when that
            is empty) with the request parameters appended.

        NOTE: auto-generated ``state``/``nonce`` values are not returned, so a
        caller that wants to validate the callback must pass its own values.
        """
        params = {
            "response_type": "code",
            "client_id": config.client_id,
            "redirect_uri": config.redirect_uri,
            "scope": " ".join(config.scopes),
            "state": state or _rand_str(16),
            "nonce": nonce or _rand_str(16),
        }
        ep = config.authorization_endpoint or f"{config.issuer.rstrip('/')}/authorize"
        return SSOProvider._append_query(ep, params)

    @staticmethod
    def build_saml_login_url(config: SAMLConfig, relay_state: str = "") -> str:
        """Build a SAML 2.0 HTTP-Redirect-binding login URL.

        The AuthnRequest XML is raw-DEFLATE compressed and then Base64 encoded
        into the ``SAMLRequest`` parameter, as the HTTP-Redirect binding
        requires. Config values are XML-escaped before being embedded.

        Args:
            config: SAML IdP/SP configuration.
            relay_state: Optional ``RelayState`` value, added only when non-empty.

        Returns:
            ``config.idp_sso_url`` with ``SAMLRequest`` (and ``RelayState``) appended.
        """
        import base64
        import uuid
        import zlib
        from xml.sax.saxutils import escape

        def attr(value: str) -> str:
            # Escape &, <, > and " so the value is safe inside a double-quoted attribute.
            return escape(value, {'"': "&quot;"})

        issue_instant = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
        saml_request = (
            '<?xml version="1.0" encoding="UTF-8"?>'
            '<samlp:AuthnRequest xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol"'
            f' ID="_{uuid.uuid4().hex}" Version="2.0"'
            f' IssueInstant="{issue_instant}"'
            f' Destination="{attr(config.idp_sso_url)}"'
            f' AssertionConsumerServiceURL="{attr(config.sp_acs_url)}">'
            '<saml:Issuer xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion">'
            f'{escape(config.sp_entity_id)}</saml:Issuer>'
            '</samlp:AuthnRequest>'
        )
        # wbits=-15 -> raw DEFLATE stream (no zlib header/trailer), as the binding requires.
        compressor = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -15)
        deflated = compressor.compress(saml_request.encode("utf-8")) + compressor.flush()
        encoded = base64.b64encode(deflated).decode("ascii")

        params = {"SAMLRequest": encoded}
        if relay_state:
            params["RelayState"] = relay_state
        return SSOProvider._append_query(config.idp_sso_url, params)

    @staticmethod
    async def exchange_oidc_code(config: OIDCConfig, code: str) -> Optional[SSOUser]:
        """Exchange an OIDC authorization ``code`` for tokens, then fetch userinfo.

        Requires ``httpx``. Returns None when:
          - httpx is not installed;
          - either HTTP call returns a non-200 status;
          - a response body is not valid JSON;
          - no ``access_token`` is returned;
          - the userinfo response lacks the mandatory ``sub`` claim.

        Transport errors raised by httpx are not caught here.

        NOTE(review): the ID token is not validated (signature, nonce); the
        identity comes solely from the userinfo endpoint.
        """
        try:
            import httpx
        except ImportError:
            return None

        base = config.issuer.rstrip('/')
        token_ep = config.token_endpoint or f"{base}/token"
        userinfo_ep = config.userinfo_endpoint or f"{base}/userinfo"

        async with httpx.AsyncClient() as client:
            resp = await client.post(token_ep, data={
                "grant_type": "authorization_code",
                "code": code,
                "redirect_uri": config.redirect_uri,
                "client_id": config.client_id,
                "client_secret": config.client_secret,
            })
            if resp.status_code != 200:
                return None
            try:
                token_data = resp.json()
            except ValueError:
                return None
            access_token = token_data.get("access_token") if isinstance(token_data, dict) else None
            if not access_token:
                # Without a token the userinfo call would send "Bearer None".
                return None

            resp2 = await client.get(userinfo_ep, headers={
                "Authorization": f"Bearer {access_token}",
            })
            if resp2.status_code != 200:
                return None
            try:
                info = resp2.json()
            except ValueError:
                return None

        # The UserInfo response MUST contain "sub" (OIDC Core 5.3.2). Accepting
        # it without "sub" would map every such login to external_id == "".
        if not isinstance(info, dict) or not info.get("sub"):
            return None
        return SSOUser(
            external_id=info["sub"],
            email=info.get("email", ""),
            display_name=info.get("name", info.get("preferred_username", "")),
            provider="oidc",
            raw_claims=info,
        )
273# ── 会话管理 ──
@dataclass
class Session:
    """A user's login session.

    Unless given explicitly, ``expires_at`` is one hour (3600 s) after
    construction time.
    """
    session_id: str
    user_id: str
    tenant_id: str
    roles: list[Role]
    created_at: float = field(default_factory=time.time)
    expires_at: float = field(default_factory=lambda: time.time() + 3600)  # one hour from now
    ip_address: str = ""
    user_agent: str = ""

    def is_expired(self) -> bool:
        """Return True once the current time is strictly past ``expires_at``."""
        now = time.time()
        return self.expires_at < now
class SessionStore:
    """In-memory session store (replace with Redis or similar in production).

    Expired sessions are dropped lazily by ``get`` and in bulk by
    ``purge_expired``; no background cleanup runs.
    """

    def __init__(self):
        self._sessions: dict[str, Session] = {}

    def create(self, user: User, ip: str = "", ua: str = "", ttl: int = 3600) -> Session:
        """Create, store and return a new session for *user*.

        Args:
            user: The authenticated user.
            ip: Client IP address to record.
            ua: Client User-Agent string to record.
            ttl: Lifetime in seconds from now.
        """
        import secrets
        session = Session(
            # 128 bits from the OS CSPRNG (a truncated uuid4 hex would carry only ~60).
            session_id=f"sess_{secrets.token_hex(16)}",
            user_id=user.user_id,
            tenant_id=user.tenant_id,
            # Snapshot the roles so later edits to user.roles don't alias into the session.
            roles=list(user.roles),
            expires_at=time.time() + ttl,
            ip_address=ip,
            user_agent=ua,
        )
        self._sessions[session.session_id] = session
        return session

    def get(self, session_id: str) -> Optional[Session]:
        """Return the live session for *session_id*, or None.

        An expired session is removed from the store when it is looked up.
        """
        s = self._sessions.get(session_id)
        if s and s.is_expired():
            del self._sessions[session_id]
            return None
        return s

    def revoke(self, session_id: str):
        """Remove one session; unknown ids are ignored."""
        self._sessions.pop(session_id, None)

    def revoke_user_sessions(self, user_id: str):
        """Remove every session belonging to *user_id*."""
        to_remove = [sid for sid, s in self._sessions.items() if s.user_id == user_id]
        for sid in to_remove:
            del self._sessions[sid]

    def purge_expired(self) -> int:
        """Remove every expired session and return how many were removed.

        Without this, sessions that are never fetched again stay in memory forever.
        """
        expired = [sid for sid, s in self._sessions.items() if s.is_expired()]
        for sid in expired:
            del self._sessions[sid]
        return len(expired)

    def stats(self) -> dict:
        """Return ``{"total": stored sessions, "active": non-expired sessions}``."""
        active = sum(1 for s in self._sessions.values() if not s.is_expired())
        return {"total": len(self._sessions), "active": active}
332# ── JWT ──
class JWTManager:
    """Minimal HS256 JWT issuing/verification with no third-party dependencies.

    PyJWT / jwcrypto are recommended for production use.
    """

    def __init__(self, secret: str):
        self.secret = secret

    @staticmethod
    def _b64url_encode(data: bytes) -> str:
        """Base64url-encode *data* without ``=`` padding (JWS compact form)."""
        import base64
        return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")

    @staticmethod
    def _b64url_decode(segment: str) -> bytes:
        """Decode an unpadded base64url *segment*, restoring exactly the padding it needs."""
        import base64
        return base64.urlsafe_b64decode(segment + "=" * (-len(segment) % 4))

    def _sign(self, signing_input: str) -> str:
        """Return the base64url HMAC-SHA256 signature of *signing_input* under ``self.secret``."""
        digest = hmac.new(self.secret.encode(), signing_input.encode(), hashlib.sha256).digest()
        return self._b64url_encode(digest)

    def encode(self, payload: dict, ttl: int = 3600) -> str:
        """Issue a JWT carrying *payload*.

        ``iat`` is set to now and ``exp`` to now + *ttl* seconds; these
        override any ``iat``/``exp`` keys already present in *payload*.
        """
        now = int(time.time())  # single timestamp so exp - iat == ttl exactly
        header = {"alg": "HS256", "typ": "JWT"}
        claims = {**payload, "iat": now, "exp": now + ttl}
        segments = [
            self._b64url_encode(json.dumps(header).encode()),
            self._b64url_encode(json.dumps(claims).encode()),
        ]
        segments.append(self._sign(".".join(segments)))
        return ".".join(segments)

    def decode(self, token: str) -> Optional[dict]:
        """Verify *token* and return its claims, or None if it is invalid.

        A token is rejected when:
          - it is not a three-segment string;
          - its signature does not match;
          - a segment is not valid base64url/UTF-8/JSON;
          - the header does not declare ``alg == "HS256"``;
          - the claims are not a JSON object;
          - ``exp`` is missing, non-numeric or in the past.
        """
        if not isinstance(token, str):
            return None
        parts = token.split(".")
        if len(parts) != 3:
            return None
        header_b64, payload_b64, sig_b64 = parts

        expected_sig = self._sign(f"{header_b64}.{payload_b64}")
        # Compare bytes: hmac.compare_digest raises TypeError for non-ASCII str.
        if not hmac.compare_digest(sig_b64.encode(), expected_sig.encode()):
            return None

        try:
            header = json.loads(self._b64url_decode(header_b64).decode("utf-8"))
            payload = json.loads(self._b64url_decode(payload_b64).decode("utf-8"))
        except ValueError:  # bad base64 (binascii.Error), bad UTF-8 or bad JSON
            return None

        # Only HS256 is ever issued here; reject anything claiming another algorithm.
        if not isinstance(header, dict) or header.get("alg") != "HS256":
            return None
        if not isinstance(payload, dict):
            return None

        exp = payload.get("exp")
        if isinstance(exp, bool) or not isinstance(exp, (int, float)) or exp < time.time():
            return None
        return payload
395# ── 工具函数 ──
398def _rand_str(n: int) -> str:
399 import secrets
400 return secrets.token_hex(n // 2 + 1)[:n]
def require_permission(permission: Permission):
    """Decorator factory marking a callable as requiring *permission* (placeholder).

    Enforcement must be implemented in framework middleware, so calling the
    decorated function always raises ``NotImplementedError``. The wrapper
    keeps the wrapped function's metadata (``functools.wraps``) and exposes the
    declared permission as ``wrapper.required_permission`` so middleware can
    discover it.

    Raises:
        NotImplementedError: whenever the decorated function is called.
    """
    import functools

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Real use would resolve the current user from request context.
            raise NotImplementedError("权限检查需在框架中间件中实现")
        wrapper.required_permission = permission
        return wrapper
    return decorator