Coverage for src / fastapi_authly / auth.py: 17%
127 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-20 11:54 +0800
1"""认证模块 - 提供可重写的 FastAPI 路由集合。"""
3from typing import Any, Dict, Optional
5from fastapi import APIRouter, Depends, HTTPException, status
6from fastapi.security import OAuth2PasswordBearer
7from jose import JWTError
8from pydantic import EmailStr
10from .core.config import AuthConfig, AuthDependencyConfig
11from .core.security import BcryptPasswordHasher, JWTTokenService
12from .interfaces import Mailer, PasswordHasher, TokenService, UserRepository
13from .schemas.user import (
14 PasswordReset,
15 PasswordResetRequest,
16 Token,
17 TokenData,
18 UserCreate,
19 UserPublic,
20 UserLogin
21)
22from fastapi_authly.contrib.tortoise_pg import TortoiseUserRepository
def _merge_config(config: Optional[AuthConfig]) -> AuthConfig:
    """Return the effective configuration for an ``AuthModule``.

    A freshly constructed ``AuthConfig()`` provides the base values (intended
    to be the environment-driven settings). When ``config`` is supplied, every
    one of its declared fields whose value is not ``None`` replaces the base
    value; fields left as ``None`` keep the base value.

    Args:
        config: User-supplied overrides, or ``None`` to use the base config.

    Returns:
        A new ``AuthConfig`` instance; ``config`` itself is not modified.
    """
    env_config = AuthConfig()
    if config is None:
        return env_config

    # Read attribute values directly instead of calling ``config.model_dump()``:
    # dumping recursively converts nested model fields into plain dicts (and
    # omits fields declared with ``exclude=True``), while
    # ``model_copy(update=...)`` performs no validation. The dumped dicts would
    # therefore be stored verbatim with the wrong type, or the fields lost.
    overrides = {}
    for name in type(config).model_fields:
        value = getattr(config, name)
        if value is not None:  # None means "not provided" -> keep base value
            overrides[name] = value
    return env_config.model_copy(update=overrides)
class AuthModule:
    """Authentication module with minimal default logic.

    Every external dependency (password hasher, token service, user
    repository, mailer) can be replaced through ``AuthDependencyConfig``.
    All routes are registered on ``self.router`` when the instance is built.
    """

    def __init__(
        self,
        config: Optional[AuthConfig] = None,
        dependencies: Optional[AuthDependencyConfig] = None,
    ):
        """Resolve configuration and dependencies, then register the routes.

        Args:
            config: Overrides merged on top of the base ``AuthConfig()``
                (see ``_merge_config``); ``None`` uses the base config only.
            dependencies: Optional replacements. Missing entries fall back to
                ``BcryptPasswordHasher``, a ``JWTTokenService`` built from the
                merged config, and ``TortoiseUserRepository``. The mailer has
                no default and stays ``None`` unless supplied.
        """
        self.config = _merge_config(config)
        self.dependencies = dependencies or AuthDependencyConfig()

        self.password_hasher: PasswordHasher = (
            self.dependencies.password_hasher or BcryptPasswordHasher()
        )
        self.token_service: TokenService = (
            self.dependencies.token_service
            or JWTTokenService(
                secret_key=self.config.secret_key,
                algorithm=self.config.algorithm,
                access_token_expire_minutes=self.config.access_token_expire_minutes,
                refresh_token_expire_days=self.config.refresh_token_expire_days,
            )
        )
        self.user_repository = (
            self.dependencies.user_repository or TortoiseUserRepository()
        )
        self.mailer: Optional[Mailer] = self.dependencies.mailer

        self.router = APIRouter(
            prefix=self.config.router_prefix, tags=self.config.router_tags
        )
        # Join prefix and token_url with exactly one slash (a leading slash on
        # token_url would otherwise yield "prefix//login"), then strip the outer
        # slashes as before so the OpenAPI tokenUrl stays relative.
        token_path = (
            f"{self.config.router_prefix.rstrip('/')}/"
            f"{self.config.token_url.lstrip('/')}"
        )
        self.oauth2_scheme = OAuth2PasswordBearer(tokenUrl=token_path.strip("/"))

        self._setup_routes()

    def _setup_routes(self) -> None:
        """Register token routes, plus each optional group enabled in the config."""
        self._add_token_routes()
        if self.config.enable_user_registration:
            self._add_user_routes()
        if self.config.enable_password_recovery:
            self._add_password_routes()
        if self.config.enable_token_refresh:
            self._add_refresh_routes()

    def _ensure_user_repo(self) -> UserRepository:
        """Return the user repository.

        Raises:
            HTTPException: 501 if no repository is configured.
        """
        if not self.user_repository:
            raise HTTPException(
                status_code=status.HTTP_501_NOT_IMPLEMENTED,
                detail="User repository not provided. Supply a UserRepository implementation.",
            )
        return self.user_repository

    @staticmethod
    def _unauthorized(detail: str) -> HTTPException:
        """Build a 401 error carrying the ``WWW-Authenticate: Bearer`` header."""
        return HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=detail,
            headers={"WWW-Authenticate": "Bearer"},
        )

    def _add_token_routes(self) -> None:
        """Register ``POST /login`` and ``POST /token/verify``."""

        @self.router.post("/login", response_model=Token)
        async def login_for_access_token(body: UserLogin) -> Token:
            repo = self._ensure_user_repo()
            user = await repo.get_by_name(body.username)
            # One message for "unknown user" and "wrong password" so the
            # response does not reveal which usernames exist.
            if not user or not self.password_hasher.verify(
                body.password, getattr(user, "hashed_password", "")
            ):
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Incorrect username or password",
                )
            # Checked only after the password, so account state is not
            # disclosed to callers who do not know the password.
            if not user.is_active:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Inactive user",
                )

            user_id = str(getattr(user, "id", ""))
            access_token = self.token_service.create_access_token(user_id)
            refresh_token = (
                self.token_service.create_refresh_token(user_id)
                if self.config.enable_token_refresh
                else None
            )
            return Token(access_token=access_token, refresh_token=refresh_token)

        @self.router.post("/token/verify")
        async def verify_token(
            token: str = Depends(self.oauth2_scheme),
        ) -> Dict[str, Any]:
            try:
                payload = self.token_service.decode_token(token)
                token_data = TokenData(**payload)
            except (JWTError, ValueError):
                # ValueError covers pydantic's ValidationError: a correctly
                # signed token whose claims do not fit TokenData is still an
                # invalid token (401), not a server error (500).
                raise self._unauthorized("Invalid token") from None
            return {"valid": True, "user_id": token_data.sub, "exp": token_data.exp}

    def _add_user_routes(self) -> None:
        """Register ``POST /register`` and ``GET /me``."""

        @self.router.post("/register", response_model=UserPublic)
        async def register_user(user: UserCreate) -> UserPublic:
            repo = self._ensure_user_repo()
            hashed_pw = self.password_hasher.hash(user.password)
            # The repository receives the hash in the ``password`` field,
            # never the plain-text value.
            user_to_create = user.model_copy(update={"password": hashed_pw})
            created = await repo.create_user(user_to_create)
            return await repo.to_public(created)

        @self.router.get("/me", response_model=UserPublic)
        async def get_current_user(
            token: str = Depends(self.oauth2_scheme),
        ) -> UserPublic:
            repo = self._ensure_user_repo()
            try:
                payload = self.token_service.decode_token(token, verify_type="access")
                token_data = TokenData(**payload)
            except (JWTError, ValueError):
                # ValueError: claims failed TokenData validation.
                raise self._unauthorized("Could not validate credentials") from None

            user = await repo.get_by_id(token_data.sub)
            if not user:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
                )
            return await repo.to_public(user)

    def _add_password_routes(self) -> None:
        """Register ``POST /password/reset-request`` and ``POST /password/reset``."""
        import inspect  # local: only needed to await async ``set_password`` hooks

        @self.router.post("/password/reset-request")
        async def request_password_reset(
            request: PasswordResetRequest,
        ) -> Dict[str, str]:
            repo = self._ensure_user_repo()
            if not self.mailer:
                raise HTTPException(
                    status_code=status.HTTP_501_NOT_IMPLEMENTED,
                    detail="Mailer not provided. Supply a Mailer implementation.",
                )

            user = await repo.get_by_email(request.email)
            if not user:
                # Same response as the success path, to avoid account enumeration.
                return {"detail": "If the account exists, a reset email will be sent."}

            # NOTE(review): the reset token is an ordinary refresh token, so it
            # is also accepted by /token/refresh, and a refresh token issued by
            # /login is accepted by /password/reset. A dedicated token type
            # would separate these flows.
            token = self.token_service.create_refresh_token(str(getattr(user, "id", "")))
            await self.mailer.send_password_reset(request, token)
            return {"detail": "If the account exists, a reset email will be sent."}

        @self.router.post("/password/reset")
        async def reset_password(payload: PasswordReset) -> Dict[str, str]:
            repo = self._ensure_user_repo()
            try:
                # Reset tokens are minted by create_refresh_token above.
                # Requiring that type stops an access token from being used to
                # change the password.
                data = self.token_service.decode_token(
                    payload.token, verify_type="refresh"
                )
            except JWTError:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Invalid or expired token",
                ) from None

            user_id = data.get("sub")
            if not user_id:
                # A token without a subject cannot identify an account.
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Invalid or expired token",
                )

            user = await repo.get_by_id(user_id)
            if not user:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
                )

            hashed = self.password_hasher.hash(payload.new_password)

            # Preferred: a repository-level update.
            update_password = getattr(repo, "update_password", None)
            if update_password is not None:
                await update_password(user_id, hashed)
                return {"detail": "Password updated"}

            # Fallback: a model-level hook. It may be sync or async; a coroutine
            # that is never awaited would silently skip the update.
            set_password = getattr(user, "set_password", None)
            if set_password is not None:
                outcome = set_password(hashed)
                if inspect.isawaitable(outcome):
                    await outcome
                # NOTE(review): nothing here persists the model; this assumes
                # set_password saves the change itself — confirm.
                return {"detail": "Password updated"}

            raise HTTPException(
                status_code=status.HTTP_501_NOT_IMPLEMENTED,
                detail="Provide `update_password` on UserRepository or `set_password` on user model.",
            )

    def _add_refresh_routes(self) -> None:
        """Register ``POST /token/refresh``."""

        @self.router.post("/token/refresh", response_model=Token)
        async def refresh_access_token(refresh_token: str) -> Token:
            try:
                payload = self.token_service.decode_token(
                    refresh_token, verify_type="refresh"
                )
            except JWTError:
                raise self._unauthorized("Invalid refresh token") from None

            user_id = payload.get("sub")
            if not user_id:
                raise self._unauthorized("Invalid refresh token")

            # Re-check the account, applying the same is_active rule as /login.
            # Otherwise a deleted or deactivated user could keep minting access
            # tokens until the refresh token expires.
            repo = self._ensure_user_repo()
            user = await repo.get_by_id(user_id)
            if not user or not user.is_active:
                raise self._unauthorized("Invalid refresh token")

            access_token = self.token_service.create_access_token(user_id)
            return Token(access_token=access_token, token_type="bearer")

    def get_router(self) -> APIRouter:
        """Return the router holding every registered authentication route."""
        return self.router
def create_auth_router(
    *,
    config: Optional[AuthConfig] = None,
    dependencies: Optional[AuthDependencyConfig] = None,
) -> APIRouter:
    """Convenience factory: build an ``AuthModule`` and return its router.

    Args:
        config: Passed to ``AuthModule`` as ``config``.
        dependencies: Passed to ``AuthModule`` as ``dependencies``.

    Returns:
        The ``APIRouter`` produced by ``AuthModule.get_router()``.
    """
    auth_module = AuthModule(dependencies=dependencies, config=config)
    router = auth_module.get_router()
    return router