Coverage for src / fastapi_authly / auth.py: 17%

127 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-20 11:54 +0800

1"""认证模块 - 提供可重写的 FastAPI 路由集合。""" 

2 

3from typing import Any, Dict, Optional 

4 

5from fastapi import APIRouter, Depends, HTTPException, status 

6from fastapi.security import OAuth2PasswordBearer 

7from jose import JWTError 

8from pydantic import EmailStr 

9 

10from .core.config import AuthConfig, AuthDependencyConfig 

11from .core.security import BcryptPasswordHasher, JWTTokenService 

12from .interfaces import Mailer, PasswordHasher, TokenService, UserRepository 

13from .schemas.user import ( 

14 PasswordReset, 

15 PasswordResetRequest, 

16 Token, 

17 TokenData, 

18 UserCreate, 

19 UserPublic, 

20 UserLogin 

21) 

22from fastapi_authly.contrib.tortoise_pg import TortoiseUserRepository 

23 

24 

def _merge_config(config: Optional[AuthConfig]) -> AuthConfig:
    """Build the effective configuration for an ``AuthModule``.

    Starts from ``AuthConfig()`` constructed with no arguments (the
    environment-driven baseline). When ``config`` is supplied, every field
    that was explicitly set on it and is not ``None`` overrides the baseline.

    Args:
        config: Optional user-supplied configuration.

    Returns:
        The merged ``AuthConfig`` instance.
    """
    env_config = AuthConfig()
    if config is None:
        return env_config

    # Only explicitly-set fields override the baseline. A field the caller
    # left at its default must not clobber a value coming from the environment.
    # Attribute values are taken as-is rather than via model_dump(). model_dump()
    # would turn nested models into plain dicts, and model_copy(update=...)
    # does not re-validate what it is given.
    overrides = {}
    for field_name in config.model_fields_set:
        value = getattr(config, field_name)
        if value is not None:
            overrides[field_name] = value
    return env_config.model_copy(update=overrides)

32 

33 

class AuthModule:
    """Authentication module that assembles a FastAPI router of auth endpoints.

    The default wiring implements only minimal logic. Every external
    dependency (password hasher, token service, user repository, mailer) can
    be replaced through ``AuthDependencyConfig``.
    """

    def __init__(
        self,
        config: Optional[AuthConfig] = None,
        dependencies: Optional[AuthDependencyConfig] = None,
    ):
        """Resolve configuration and dependencies, then register the routes.

        Args:
            config: Optional user configuration, merged over ``AuthConfig()``
                by ``_merge_config``.
            dependencies: Optional dependency overrides. Any hasher, token
                service or user repository left unset falls back to
                ``BcryptPasswordHasher``, ``JWTTokenService`` or
                ``TortoiseUserRepository`` respectively. The mailer has no
                fallback and stays ``None`` when not provided.
        """
        self.config = _merge_config(config)
        self.dependencies = dependencies or AuthDependencyConfig()

        self.password_hasher: PasswordHasher = (
            self.dependencies.password_hasher or BcryptPasswordHasher()
        )
        self.token_service: TokenService = (
            self.dependencies.token_service
            or JWTTokenService(
                secret_key=self.config.secret_key,
                algorithm=self.config.algorithm,
                access_token_expire_minutes=self.config.access_token_expire_minutes,
                refresh_token_expire_days=self.config.refresh_token_expire_days,
            )
        )
        # With the TortoiseUserRepository fallback this is normally set. The
        # guard in _ensure_user_repo remains as a safety net.
        self.user_repository: Optional[UserRepository] = (
            self.dependencies.user_repository or TortoiseUserRepository()
        )

        self.mailer: Optional[Mailer] = self.dependencies.mailer

        self.router = APIRouter(
            prefix=self.config.router_prefix, tags=self.config.router_tags
        )
        # tokenUrl is "<prefix>/<token_url>" with surrounding slashes stripped,
        # i.e. a relative URL in the generated OpenAPI schema.
        # NOTE(review): the login route below is hard-coded to "/login" and takes
        # a JSON body. OAuth2PasswordBearer advertises the OAuth2 password flow
        # (form fields) at config.token_url. Confirm these agree.
        self.oauth2_scheme = OAuth2PasswordBearer(
            tokenUrl=f"{self.config.router_prefix.rstrip('/')}/{self.config.token_url}".strip(
                "/"
            )
        )

        self._setup_routes()

    def _setup_routes(self) -> None:
        """Register token routes always, and the other groups per feature flags."""
        self._add_token_routes()
        if self.config.enable_user_registration:
            self._add_user_routes()
        if self.config.enable_password_recovery:
            self._add_password_routes()
        if self.config.enable_token_refresh:
            self._add_refresh_routes()

    def _ensure_user_repo(self) -> UserRepository:
        """Return the user repository, or raise HTTP 501 if none is configured."""
        if not self.user_repository:
            raise HTTPException(
                status_code=status.HTTP_501_NOT_IMPLEMENTED,
                detail="User repository not provided. Supply a UserRepository implementation.",
            )
        return self.user_repository

    def _add_token_routes(self) -> None:
        """Register ``POST /login`` and ``POST /token/verify``."""

        @self.router.post("/login", response_model=Token)
        async def login_for_access_token(
            body: UserLogin,
        ) -> Token:
            """Authenticate with username/password and issue tokens.

            Returns an access token, plus a refresh token when token refresh is
            enabled. An unknown user and a wrong password produce the same 400
            response body.
            """
            repo = self._ensure_user_repo()
            user = await repo.get_by_name(body.username)
            if not user:
                # NOTE(review): this path skips password verification, so its
                # response time differs from the wrong-password path and may
                # reveal which usernames exist.
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Incorrect username or password",
                )
            if not self.password_hasher.verify(body.password, getattr(user, "hashed_password", "")):
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Incorrect username or password",
                )

            if not user.is_active:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST, detail="Inactive user"
                )

            user_id = str(getattr(user, "id", ""))
            access_token = self.token_service.create_access_token(user_id)
            refresh_token = (
                self.token_service.create_refresh_token(user_id)
                if self.config.enable_token_refresh
                else None
            )
            return Token(access_token=access_token, refresh_token=refresh_token)

        @self.router.post("/token/verify")
        async def verify_token(token: str = Depends(self.oauth2_scheme)) -> Dict[str, Any]:
            """Decode the bearer token and report its subject and expiry.

            Unlike ``/me``, no ``verify_type`` is passed to ``decode_token``.
            Raises 401 when decoding fails with ``JWTError``.
            """
            try:
                payload = self.token_service.decode_token(token)
                token_data = TokenData(**payload)
                return {"valid": True, "user_id": token_data.sub, "exp": token_data.exp}
            except JWTError:
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="Invalid token",
                    headers={"WWW-Authenticate": "Bearer"},
                )

    def _add_user_routes(self) -> None:
        """Register ``POST /register`` and ``GET /me``."""

        @self.router.post("/register", response_model=UserPublic)
        async def register_user(user: UserCreate) -> UserPublic:
            """Create a user, hashing the password before it reaches the repository."""
            repo = self._ensure_user_repo()
            hashed_pw = self.password_hasher.hash(user.password)
            # The hash replaces the plaintext in the ``password`` field. The
            # repository is presumably expected to store that field as the hash
            # (TODO confirm).
            user_to_create = user.model_copy(update={"password": hashed_pw})
            created = await repo.create_user(user_to_create)
            return await repo.to_public(created)

        @self.router.get("/me", response_model=UserPublic)
        async def get_current_user(token: str = Depends(self.oauth2_scheme)) -> UserPublic:
            """Return the public profile for the subject of an access token."""
            repo = self._ensure_user_repo()
            try:
                payload = self.token_service.decode_token(token, verify_type="access")
                token_data = TokenData(**payload)
            except JWTError:
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="Could not validate credentials",
                    headers={"WWW-Authenticate": "Bearer"},
                )

            user = await repo.get_by_id(token_data.sub)
            if not user:
                raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
            return await repo.to_public(user)

    def _add_password_routes(self) -> None:
        """Register ``POST /password/reset-request`` and ``POST /password/reset``."""

        @self.router.post("/password/reset-request")
        async def request_password_reset(request: PasswordResetRequest) -> Dict[str, str]:
            """Email a reset token if the account exists.

            Returns the same message whether or not the account exists.
            Raises 501 if no mailer is configured.
            """
            repo = self._ensure_user_repo()
            if not self.mailer:
                raise HTTPException(
                    status_code=status.HTTP_501_NOT_IMPLEMENTED,
                    detail="Mailer not provided. Supply a Mailer implementation.",
                )

            user = await repo.get_by_email(request.email)
            if not user:
                # Intentionally generic so the response does not reveal account existence.
                return {"detail": "If the account exists, a reset email will be sent."}

            # NOTE(review): the reset token is an ordinary refresh token. The
            # emailed token can therefore also be exchanged at /token/refresh,
            # and a refresh token obtained at login can reset the password.
            # Consider a dedicated, short-lived reset token type.
            token = self.token_service.create_refresh_token(str(getattr(user, "id", "")))
            await self.mailer.send_password_reset(request, token)
            return {"detail": "If the account exists, a reset email will be sent."}

        @self.router.post("/password/reset")
        async def reset_password(payload: PasswordReset) -> Dict[str, str]:
            """Set a new password for the subject of a reset token.

            Raises 400 for an invalid, expired, wrong-type or subject-less
            token. Raises 404 if the user no longer exists, and 501 if neither
            persistence hook is available.
            """
            repo = self._ensure_user_repo()
            try:
                # Reset tokens are minted with create_refresh_token, so require
                # that type. Previously any valid token (e.g. an access token)
                # was accepted here.
                data = self.token_service.decode_token(payload.token, verify_type="refresh")
            except JWTError:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Invalid or expired token",
                )

            user_id = data.get("sub")
            if not user_id:
                # A token without a subject cannot identify whose password to reset.
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Invalid or expired token",
                )

            user = await repo.get_by_id(user_id)
            if not user:
                raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")

            hashed = self.password_hasher.hash(payload.new_password)

            if hasattr(repo, "update_password"):
                await repo.update_password(user_id, hashed)  # type: ignore[attr-defined]
                return {"detail": "Password updated"}

            if hasattr(user, "set_password"):
                # NOTE(review): nothing here saves the user afterwards, and the
                # value passed is already hashed. Confirm set_password persists
                # the change and does not hash again.
                user.set_password(hashed)
                return {"detail": "Password updated"}

            raise HTTPException(
                status_code=status.HTTP_501_NOT_IMPLEMENTED,
                detail="Provide `update_password` on UserRepository or `set_password` on user model.",
            )

    def _add_refresh_routes(self) -> None:
        """Register ``POST /token/refresh``."""

        @self.router.post("/token/refresh", response_model=Token)
        async def refresh_access_token(refresh_token: str) -> Token:
            """Exchange a refresh token for a new access token.

            ``refresh_token`` is a plain ``str`` parameter, which FastAPI reads
            from the query string. The token's subject must still resolve to
            an active user, mirroring the checks performed at login.
            """
            repo = self._ensure_user_repo()
            try:
                payload = self.token_service.decode_token(refresh_token, verify_type="refresh")
            except JWTError:
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="Invalid refresh token",
                    headers={"WWW-Authenticate": "Bearer"},
                )

            user_id = payload.get("sub")
            # Without this lookup, a deleted or deactivated account could keep
            # minting access tokens until its refresh token expires.
            user = await repo.get_by_id(user_id) if user_id else None
            if not user:
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="Invalid refresh token",
                    headers={"WWW-Authenticate": "Bearer"},
                )
            if not user.is_active:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST, detail="Inactive user"
                )

            access_token = self.token_service.create_access_token(user_id)
            return Token(access_token=access_token, token_type="bearer")

    def get_router(self) -> APIRouter:
        """Return the router holding all registered auth routes."""
        return self.router

230 

231 

def create_auth_router(
    *,
    config: Optional[AuthConfig] = None,
    dependencies: Optional[AuthDependencyConfig] = None,
) -> APIRouter:
    """Convenience helper: build a configurable authentication router.

    Args:
        config: Optional configuration forwarded to ``AuthModule``.
        dependencies: Optional dependency overrides forwarded to ``AuthModule``.

    Returns:
        The ``APIRouter`` produced by a freshly constructed ``AuthModule``.
    """
    return AuthModule(config=config, dependencies=dependencies).get_router()