Coverage for agentos/enterprise/auth.py: 53%

207 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 09:59 +0800

1""" 

2AgentOS Enterprise — SSO & RBAC. 

3 

4功能: 

5 - RBAC 角色模型(admin / developer / viewer / agent) 

6 - 权限定义与校验 

7 - SSO 集成接口(OIDC / SAML 抽象) 

8 - JWT Token 签发与验证 

9 - 会话管理 

10""" 

11 

12from __future__ import annotations 

13 

14import hashlib 

15import hmac 

16import json 

17import time 

18from dataclasses import dataclass, field 

19from enum import Enum 

20from typing import Optional 

21 

22 

23# ── 权限系统 ── 

24 

25 

class Permission(str, Enum):
    """Fine-grained permission definitions.

    Each value is a ``"<resource>:<action>"`` string. Because the enum mixes
    in ``str``, every member compares equal to its string value.
    """
    # Agent
    AGENT_CREATE = "agent:create"
    AGENT_READ = "agent:read"
    AGENT_UPDATE = "agent:update"
    AGENT_DELETE = "agent:delete"
    AGENT_RUN = "agent:run"
    # Tools
    TOOLS_LIST = "tools:list"
    TOOLS_EXECUTE = "tools:execute"
    TOOLS_MANAGE = "tools:manage"
    # API Keys
    KEYS_CREATE = "keys:create"
    KEYS_READ = "keys:read"
    KEYS_REVOKE = "keys:revoke"
    # Tenants
    TENANT_READ = "tenant:read"
    TENANT_MANAGE = "tenant:manage"
    # Audit
    AUDIT_READ = "audit:read"
    AUDIT_EXPORT = "audit:export"
    # Admin
    # NOTE(review): ADMIN_ALL reads like a wildcard ("admin:*"), but nothing
    # in this module expands it -- it is checked by plain set membership like
    # any other member. Confirm whether callers elsewhere treat it specially.
    ADMIN_ALL = "admin:*"
    SYSTEM_CONFIG = "system:config"

51 

52 

class Role(str, Enum):
    """Predefined (built-in) roles.

    The permission set of each role is defined in ``ROLE_PERMISSIONS``.
    Members compare equal to their string values because of the ``str`` mixin.
    """
    ADMIN = "admin"
    DEVELOPER = "developer"
    VIEWER = "viewer"
    AGENT = "agent"

59 

60 

# Role -> permission mapping for the built-in roles.
# RBACEngine.get_permissions unions the sets of every role a user holds.
# The sets are mutable module-level objects: mutating one changes the policy
# for every user holding that role.
ROLE_PERMISSIONS: dict[Role, set[Permission]] = {
    Role.ADMIN: set(Permission),  # every member of Permission
    Role.DEVELOPER: {
        Permission.AGENT_CREATE, Permission.AGENT_READ, Permission.AGENT_UPDATE,
        Permission.AGENT_RUN,
        Permission.TOOLS_LIST, Permission.TOOLS_EXECUTE,
        Permission.KEYS_CREATE, Permission.KEYS_READ,
        Permission.AUDIT_READ,
    },
    Role.VIEWER: {
        Permission.AGENT_READ,
        Permission.TOOLS_LIST,
        Permission.KEYS_READ,
        Permission.AUDIT_READ,
        # NOTE(review): VIEWER is granted TENANT_READ but DEVELOPER is not;
        # confirm this asymmetry is intentional.
        Permission.TENANT_READ,
    },
    Role.AGENT: {
        # Non-human callers: may only run agents and execute tools.
        Permission.AGENT_RUN,
        Permission.TOOLS_EXECUTE,
    },
}

83 

84 

@dataclass
class User:
    """User entity, consumed by RBACEngine and SessionStore."""
    user_id: str
    username: str
    email: str
    # Built-in roles; their permissions are looked up in ROLE_PERMISSIONS.
    roles: list[Role]
    tenant_id: str
    # Extra permissions granted on top of those implied by ``roles``.
    custom_permissions: set[Permission] = field(default_factory=set)
    # A disabled user resolves to an empty permission set in RBACEngine.
    disabled: bool = False
    # Creation time as returned by time.time() (Unix seconds).
    created_at: float = field(default_factory=time.time)
    # Free-form extra data; not read by the code in this module.
    metadata: dict = field(default_factory=dict)

97 

98 

class RBACEngine:
    """Role-based access control engine.

    A user's effective permissions are the union of:

    * the permissions of every built-in role in ``user.roles`` (looked up in
      ``ROLE_PERMISSIONS``), and
    * the user's own ``custom_permissions``.

    Users holding ``Role.ADMIN`` always receive every ``Permission``; disabled
    users receive none.

    Custom roles registered via :meth:`register_custom_role` are stored on the
    engine, but :meth:`get_permissions` does not consult them.
    """

    def __init__(self):
        # Custom role name -> private copy of its permission set.
        self._custom_roles: dict[str, set[Permission]] = {}

    def get_permissions(self, user: User) -> set[Permission]:
        """Return all effective permissions of ``user`` as a new set.

        The returned set is freshly built, so callers may mutate it without
        affecting the user or the shared role table.
        """
        if user.disabled:
            return set()

        # Admins hold everything; no need to merge individual role sets.
        if Role.ADMIN in user.roles:
            return set(Permission)

        perms: set[Permission] = set(user.custom_permissions)
        for role in user.roles:
            perms |= ROLE_PERMISSIONS.get(role, set())
        return perms

    def check_permission(self, user: User, permission: Permission) -> bool:
        """Return True if ``user`` holds ``permission``."""
        return permission in self.get_permissions(user)

    def check_permissions(
        self, user: User, permissions: list[Permission]
    ) -> dict[Permission, bool]:
        """Check several permissions at once, mapping each to True/False."""
        user_perms = self.get_permissions(user)
        return {p: p in user_perms for p in permissions}

    def has_any(self, user: User, permissions: list[Permission]) -> bool:
        """Return True if ``user`` holds at least one of ``permissions``.

        An empty ``permissions`` list yields False.
        """
        user_perms = self.get_permissions(user)
        return bool(user_perms & set(permissions))

    def has_all(self, user: User, permissions: list[Permission]) -> bool:
        """Return True if ``user`` holds every one of ``permissions``.

        An empty ``permissions`` list yields True.
        """
        user_perms = self.get_permissions(user)
        return set(permissions).issubset(user_perms)

    def register_custom_role(self, name: str, permissions: set[Permission]):
        """Register (or replace) a custom role called ``name``.

        A copy of ``permissions`` is stored, so later changes to the caller's
        set cannot silently alter the registered role.
        """
        self._custom_roles[name] = set(permissions)

    def get_role_permissions(self, role: Role) -> set[Permission]:
        """Return a copy of the permission set of built-in ``role``.

        Unknown roles yield an empty set. A copy is returned because the
        sets in ``ROLE_PERMISSIONS`` are shared module-level state; handing
        them out directly would let any caller rewrite the global policy.
        """
        return set(ROLE_PERMISSIONS.get(role, ()))

153 

154 

155# ── SSO 集成 ── 

156 

157 

@dataclass
class OIDCConfig:
    """OpenID Connect provider configuration.

    Empty endpoint fields fall back, in SSOProvider, to paths under
    ``issuer``: ``/authorize``, ``/token`` and ``/userinfo``. ``jwks_uri`` is
    not used by SSOProvider.
    """
    issuer: str  # e.g. "https://accounts.google.com"
    client_id: str
    # Excluded from repr() so the secret never ends up in logs or tracebacks;
    # still a required, positional field compared by __eq__.
    client_secret: str = field(repr=False)
    redirect_uri: str
    scopes: list[str] = field(default_factory=lambda: ["openid", "email", "profile"])
    authorization_endpoint: str = ""
    token_endpoint: str = ""
    userinfo_endpoint: str = ""
    jwks_uri: str = ""

170 

171 

@dataclass
class SAMLConfig:
    """SAML identity-provider / service-provider configuration."""
    idp_entity_id: str  # not read by SSOProvider.build_saml_login_url
    idp_sso_url: str  # IdP endpoint the login redirect is sent to
    # Presumably the IdP signing certificate (PEM) -- not used by the code
    # shown here; TODO confirm once response validation exists.
    idp_certificate: str
    sp_entity_id: str  # emitted as <saml:Issuer> in the AuthnRequest
    sp_acs_url: str  # emitted as AssertionConsumerServiceURL
    # Not currently placed in the AuthnRequest built by SSOProvider.
    name_id_format: str = "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress"

181 

182 

@dataclass
class SSOUser:
    """User identity returned by an SSO login."""
    external_id: str  # OIDC: the userinfo "sub" claim
    email: str
    display_name: str
    provider: str  # "oidc" / "saml"
    # Full provider payload (for OIDC, the userinfo JSON object).
    raw_claims: dict = field(default_factory=dict)

191 

192 

class SSOProvider:
    """SSO abstraction layer: one interface over OIDC and SAML."""

    @staticmethod
    def _append_query(url: str, params: dict) -> str:
        """Append URL-encoded ``params`` to ``url``.

        Uses ``&`` when ``url`` already carries a query string (for example,
        an endpoint with a policy parameter), so the result never contains
        two ``?`` separators.
        """
        import urllib.parse

        sep = "&" if "?" in url else "?"
        return f"{url}{sep}{urllib.parse.urlencode(params)}"

    @staticmethod
    def _json_object(resp) -> Optional[dict]:
        """Return the JSON body of a 200 response if it is an object, else None."""
        if resp.status_code != 200:
            return None
        try:
            body = resp.json()
        except ValueError:  # includes JSONDecodeError / UnicodeDecodeError
            return None
        return body if isinstance(body, dict) else None

    @staticmethod
    def build_oidc_login_url(config: OIDCConfig, state: str = "", nonce: str = "") -> str:
        """Build the OIDC authorization-code login URL.

        Args:
            config: Provider configuration. ``authorization_endpoint`` falls
                back to ``<issuer>/authorize`` when empty.
            state: Anti-CSRF value that the provider echoes back.
            nonce: Value the provider embeds in the ID token.

        Returns:
            The complete authorization URL.

        Note:
            When ``state`` or ``nonce`` is omitted, a random value is
            generated but NOT returned, so the caller cannot verify it on
            the callback. Pass (and store) your own values for real
            CSRF/replay protection.
        """
        params = {
            "response_type": "code",
            "client_id": config.client_id,
            "redirect_uri": config.redirect_uri,
            "scope": " ".join(config.scopes),
            "state": state or _rand_str(16),
            "nonce": nonce or _rand_str(16),
        }
        ep = config.authorization_endpoint or f"{config.issuer.rstrip('/')}/authorize"
        return SSOProvider._append_query(ep, params)

    @staticmethod
    def build_saml_login_url(config: SAMLConfig, relay_state: str = "") -> str:
        """Build a SAML 2.0 HTTP-Redirect login URL.

        The AuthnRequest is raw-DEFLATE compressed, base64-encoded and then
        URL-encoded into ``SAMLRequest``, as the HTTP-Redirect binding
        requires. Configuration values are XML-escaped before insertion.

        Args:
            config: SAML configuration (reads ``idp_sso_url``,
                ``sp_acs_url`` and ``sp_entity_id``).
            relay_state: Optional opaque value that the IdP returns unchanged.

        Returns:
            ``idp_sso_url`` with ``SAMLRequest`` (plus ``RelayState`` if given).
        """
        import base64
        import uuid
        import zlib
        from xml.sax.saxutils import escape, quoteattr

        issue_instant = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
        saml_request = (
            '<?xml version="1.0" encoding="UTF-8"?>'
            '<samlp:AuthnRequest xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol"'
            f' ID="_{uuid.uuid4().hex}" Version="2.0"'
            f' IssueInstant="{issue_instant}"'
            # quoteattr escapes &, <, > and picks safe surrounding quotes.
            f" Destination={quoteattr(config.idp_sso_url)}"
            f" AssertionConsumerServiceURL={quoteattr(config.sp_acs_url)}>"
            '<saml:Issuer xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion">'
            f"{escape(config.sp_entity_id)}</saml:Issuer>"
            "</samlp:AuthnRequest>"
        )
        # Negative wbits => raw DEFLATE stream (no zlib header/checksum),
        # which is what IdPs expect for the Redirect binding.
        compressor = zlib.compressobj(
            zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS
        )
        deflated = compressor.compress(saml_request.encode("utf-8")) + compressor.flush()
        params = {"SAMLRequest": base64.b64encode(deflated).decode("ascii")}
        if relay_state:
            params["RelayState"] = relay_state
        return SSOProvider._append_query(config.idp_sso_url, params)

    @staticmethod
    async def exchange_oidc_code(config: OIDCConfig, code: str) -> Optional[SSOUser]:
        """Exchange an OIDC authorization code for the user's identity.

        Makes a token request, then a userinfo request. Requires the optional
        ``httpx`` dependency.

        Returns:
            An SSOUser built from the userinfo response, or None when any of
            these hold:

            * httpx is not installed;
            * either request returns a non-200 status or a non-object JSON
              body;
            * no ``access_token`` is issued;
            * the userinfo response lacks the ``sub`` claim.

        Note:
            The ID token is neither verified nor checked against the nonce;
            identity is taken solely from the userinfo endpoint.
        """
        try:
            import httpx
        except ImportError:
            return None

        token_ep = config.token_endpoint or f"{config.issuer.rstrip('/')}/token"
        userinfo_ep = config.userinfo_endpoint or f"{config.issuer.rstrip('/')}/userinfo"

        async with httpx.AsyncClient() as client:
            resp = await client.post(token_ep, data={
                "grant_type": "authorization_code",
                "code": code,
                "redirect_uri": config.redirect_uri,
                "client_id": config.client_id,
                "client_secret": config.client_secret,
            })
            token_data = SSOProvider._json_object(resp)
            access_token = token_data.get("access_token") if token_data else None
            if not access_token:
                # Avoid sending "Bearer None" to the userinfo endpoint.
                return None

            resp2 = await client.get(userinfo_ep, headers={
                "Authorization": f"Bearer {access_token}",
            })
            info = SSOProvider._json_object(resp2)

        if not info or not info.get("sub"):
            # "sub" is the stable subject identifier; without it distinct
            # users would all map onto the same empty external_id.
            return None
        return SSOUser(
            external_id=info["sub"],
            email=info.get("email", ""),
            display_name=info.get("name", info.get("preferred_username", "")),
            provider="oidc",
            raw_claims=info,
        )

271 

272 

273# ── 会话管理 ── 

274 

275 

@dataclass
class Session:
    """An authenticated user session.

    Unless given explicitly, ``expires_at`` is one hour after the moment the
    session object is constructed.
    """
    session_id: str
    user_id: str
    tenant_id: str
    roles: list[Role]
    created_at: float = field(default_factory=time.time)
    expires_at: float = field(default_factory=lambda: time.time() + 3600)  # 1 hour
    ip_address: str = ""
    user_agent: str = ""

    def is_expired(self) -> bool:
        """Report whether the current time is strictly past ``expires_at``."""
        now = time.time()
        return self.expires_at < now

290 

291 

class SessionStore:
    """In-memory session store (replace with e.g. Redis in production).

    Expired sessions are evicted lazily by :meth:`get`, or in bulk by
    :meth:`purge_expired`.
    """

    def __init__(self):
        # session_id -> Session
        self._sessions: dict[str, Session] = {}

    def create(self, user: User, ip: str = "", ua: str = "", ttl: int = 3600) -> Session:
        """Create, store and return a new session for ``user``.

        Args:
            user: The authenticated user; its id, tenant and roles are copied
                into the session.
            ip: Client IP address to record.
            ua: Client User-Agent to record.
            ttl: Session lifetime in seconds from now.
        """
        import secrets

        session = Session(
            # 128 random bits from a CSPRNG; the previous truncated uuid4 hex
            # carried only ~60 random bits.
            session_id=f"sess_{secrets.token_hex(16)}",
            user_id=user.user_id,
            tenant_id=user.tenant_id,
            # Copy so later edits to user.roles neither leak into nor are
            # affected by this session.
            roles=list(user.roles),
            expires_at=time.time() + ttl,
            ip_address=ip,
            user_agent=ua,
        )
        self._sessions[session.session_id] = session
        return session

    def get(self, session_id: str) -> Optional[Session]:
        """Return the live session for ``session_id``, or None.

        An expired session is removed from the store when looked up.
        """
        s = self._sessions.get(session_id)
        if s is not None and s.is_expired():
            self._sessions.pop(session_id, None)
            return None
        return s

    def revoke(self, session_id: str):
        """Remove one session; unknown ids are ignored."""
        self._sessions.pop(session_id, None)

    def revoke_user_sessions(self, user_id: str):
        """Remove every session that belongs to ``user_id``."""
        to_remove = [sid for sid, s in self._sessions.items() if s.user_id == user_id]
        for sid in to_remove:
            del self._sessions[sid]

    def purge_expired(self) -> int:
        """Drop all expired sessions and return how many were removed.

        :meth:`get` only evicts the session being looked up, so expired
        sessions that are never requested again would otherwise accumulate.
        """
        expired = [sid for sid, s in self._sessions.items() if s.is_expired()]
        for sid in expired:
            del self._sessions[sid]
        return len(expired)

    def stats(self) -> dict:
        """Return ``{"total": stored sessions, "active": non-expired sessions}``."""
        active = sum(1 for s in self._sessions.values() if not s.is_expired())
        return {"total": len(self._sessions), "active": active}

330 

331 

332# ── JWT ── 

333 

334 

class JWTManager:
    """Minimal HS256 JWT issuer/verifier with no external dependencies.

    For production use, prefer a maintained library such as PyJWT or jwcrypto.
    """

    def __init__(self, secret: str):
        self.secret = secret

    @staticmethod
    def _b64url_encode(raw: bytes) -> str:
        """Base64url-encode ``raw`` without '=' padding, as JWS requires."""
        import base64

        return base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")

    @staticmethod
    def _b64url_decode(text: str) -> bytes:
        """Decode unpadded base64url ``text``, restoring the exact padding."""
        import base64

        return base64.urlsafe_b64decode(text + "=" * (-len(text) % 4))

    def _sign(self, signing_input: str) -> str:
        """Return the base64url HMAC-SHA256 signature of ``signing_input``."""
        digest = hmac.new(self.secret.encode(), signing_input.encode(), hashlib.sha256).digest()
        return self._b64url_encode(digest)

    def encode(self, payload: dict, ttl: int = 3600) -> str:
        """Issue a signed HS256 JWT.

        ``iat`` and ``exp`` (= ``iat + ttl``) are derived from a single clock
        read, so ``exp - iat == ttl`` always. They override any keys of the
        same name in ``payload``.
        """
        now = int(time.time())
        header = {"alg": "HS256", "typ": "JWT"}
        claims = {**payload, "iat": now, "exp": now + ttl}
        signing_input = ".".join(
            self._b64url_encode(json.dumps(part).encode()) for part in (header, claims)
        )
        return f"{signing_input}.{self._sign(signing_input)}"

    def decode(self, token: str) -> Optional[dict]:
        """Verify ``token`` and return its claims, or None if it is invalid.

        A token is rejected when any of these hold:

        * it is malformed;
        * its signature does not match;
        * its header does not declare HS256;
        * its claims are not a JSON object;
        * it has expired (``now >= exp``, or there is no ``exp`` claim);
        * it is not yet valid (``now < nbf``, when ``nbf`` is present).
        """
        try:
            parts = token.split(".")
            if len(parts) != 3:
                return None
            header_b64, payload_b64, sig_b64 = parts
            signing_input = f"{header_b64}.{payload_b64}"

            # Constant-time comparison; verify before trusting any content.
            if not hmac.compare_digest(sig_b64, self._sign(signing_input)):
                return None

            header = json.loads(self._b64url_decode(header_b64))
            # Only HS256 is ever issued here; refuse any other declared alg.
            if not isinstance(header, dict) or header.get("alg") != "HS256":
                return None

            payload = json.loads(self._b64url_decode(payload_b64))
            if not isinstance(payload, dict):
                return None

            now = time.time()
            # RFC 7519 4.1.4: invalid on or after exp (missing exp -> 0 -> rejected).
            if now >= payload.get("exp", 0):
                return None
            # RFC 7519 4.1.5: invalid before nbf, when present.
            if "nbf" in payload and now < payload["nbf"]:
                return None
            return payload
        except Exception:
            # Deliberate validation boundary: any malformed input
            # (bad base64/JSON, wrong types, non-str token) means "invalid".
            return None

393 

394 

395# ── 工具函数 ── 

396 

397 

398def _rand_str(n: int) -> str: 

399 import secrets 

400 return secrets.token_hex(n // 2 + 1)[:n] 

401 

402 

def require_permission(permission: Permission):
    """Decorator declaring that a callable requires ``permission`` (example only).

    The decorated callable is never executed: invoking it raises
    NotImplementedError, because the real check must happen in framework
    middleware that knows the current user.

    The required permission is exposed as ``wrapper.required_permission`` so
    such middleware can discover it. ``functools.wraps`` preserves the
    wrapped function's name, docstring and other metadata.

    Raises:
        NotImplementedError: Whenever the decorated callable is invoked.
    """
    import functools

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # A real integration would take the current user from the request context.
            raise NotImplementedError("权限检查需在框架中间件中实现")

        wrapper.required_permission = permission
        return wrapper

    return decorator