Coverage for src/oidc_provider_mock/_storage.py: 98%

99 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-15 20:58 +0100

1from collections.abc import Sequence 

2from dataclasses import dataclass, field 

3from datetime import datetime, timezone 

4from typing import cast 

5 

6import authlib.oauth2.rfc6749 

7import authlib.oidc.core 

8import flask 

9import werkzeug.local 

10from authlib import jose 

11from typing_extensions import override 

12 

13 

class ClientSkipVerification:
    """Marker type meaning "do not verify this value".

    It may be used in place of a client secret or of a client's redirect
    URIs to indicate that verification of that value is skipped.
    """

17 

18 

@dataclass(frozen=True, kw_only=True)
class Client:
    """Immutable, keyword-only record describing a client.

    ``secret`` and ``redirect_uris`` each accept a ``ClientSkipVerification``
    value instead of real data, which indicates that verification of that
    value is skipped.
    """

    # Identifier of the client.
    id: str
    # Client secret, or ``ClientSkipVerification`` to skip checking it.
    secret: str | ClientSkipVerification
    # Redirect URIs for this client, or ``ClientSkipVerification`` to skip
    # checking them.
    redirect_uris: Sequence[str] | ClientSkipVerification
    # Scopes associated with the client; defaults to "openid" and "profile".
    allowed_scopes: Sequence[str] = ("openid", "profile")

25 

26 

27@dataclass(kw_only=True, frozen=True) 

28class User: 

29 sub: str 

30 claims: dict[str, str] = field(default_factory=dict) 

31 userinfo: dict[str, object] = field(default_factory=dict) 

32 

33 

@dataclass(frozen=True, kw_only=True)
class AuthorizationCode(authlib.oidc.core.AuthorizationCodeMixin):
    """Immutable, keyword-only record of an issued authorization code.

    The accessor methods below implement ``AuthorizationCodeMixin`` by
    returning the stored field values unchanged.
    """

    # The authorization code value itself.
    code: str
    # Identifier of the client the code belongs to.
    client_id: str
    # Redirect URI recorded with the code.
    redirect_uri: str
    # Identifier of the user the code belongs to.
    user_id: str
    # Scope recorded with the code, as a single string.
    scope: str
    # Nonce recorded with the code, or ``None`` when there is none.
    nonce: str | None

    # Implement AuthorizationCodeMixin

    @override
    def get_redirect_uri(self):
        """Return the stored ``redirect_uri``."""
        return self.redirect_uri

    @override
    def get_scope(self):
        """Return the stored ``scope`` string."""
        return self.scope

    @override
    def get_nonce(self) -> str | None:
        """Return the stored ``nonce`` (possibly ``None``)."""
        return self.nonce

    @override
    def get_auth_time(self) -> int | None:
        """Always return ``None``; no authentication time is recorded."""
        return None

60 

61 

@dataclass(frozen=True, kw_only=True)
class AccessToken(authlib.oauth2.rfc6749.TokenMixin):
    """Immutable, keyword-only record of an issued access token.

    Implements ``TokenMixin``. Client checking and revocation are not
    implemented: every client is accepted and no token is ever reported
    as revoked.
    """

    # The access token value itself.
    token: str
    # Identifier of the user the token belongs to.
    user_id: str
    # Scope recorded with the token, as a single string.
    scope: str
    # Instant from which the token counts as expired. It is compared against
    # a timezone-aware UTC "now" in `is_expired`.
    expires_at: datetime

    def get_user(self):
        """Look up this token's user via the module-level ``storage``.

        Returns whatever ``storage.get_user`` returns for ``user_id``.
        """
        return storage.get_user(self.user_id)

    # Implement `TokenMixin`

    @override
    def check_client(self, client: Client):
        """Accept any client; the actual check is not implemented yet."""
        # TODO implement
        return True

    @override
    def is_expired(self):
        """Return ``True`` once the current UTC time has reached ``expires_at``."""
        now = datetime.now(timezone.utc)
        return now >= self.expires_at

    @override
    def is_revoked(self):
        """Always return ``False``."""
        return False

    @override
    def get_scope(self) -> str:
        """Return the stored ``scope`` string."""
        return self.scope

90 

91 

class Storage:
    """In-memory store for the provider's key, clients, users and grants.

    Everything is held in plain dictionaries (keyed by the record's own
    identifier) and a set of nonces. Storing a record under an existing key
    replaces the previous record.
    """

    # RSA key generated when the storage is created.
    jwk: jose.RSAKey

    _clients: dict[str, Client]
    _users: dict[str, User]
    _authorization_codes: dict[str, AuthorizationCode]
    _access_tokens: dict[str, AccessToken]
    _nonces: set[str]

    def __init__(self) -> None:
        """Create empty collections and generate a new private RSA key."""
        self._clients = {}
        self._users = {}
        self._authorization_codes = {}
        self._access_tokens = {}
        self._nonces = set()
        self.jwk = jose.RSAKey.generate_key(is_private=True)  # pyright: ignore[reportUnknownMemberType]

    # Users

    def get_user(self, sub: str) -> User | None:
        """Return the user stored under ``sub``, or ``None`` if absent."""
        return self._users.get(sub)

    def store_user(self, user: User):
        """Store ``user`` keyed by its ``sub``."""
        self._users[user.sub] = user

    # Authorization codes

    def get_authorization_code(self, code: str) -> AuthorizationCode | None:
        """Return the authorization code stored under ``code``, or ``None``."""
        return self._authorization_codes.get(code)

    def store_authorization_code(self, code: AuthorizationCode):
        """Store ``code`` keyed by its ``code`` value."""
        self._authorization_codes[code.code] = code

    def remove_authorization_code(self, code: str) -> AuthorizationCode | None:
        """Remove and return the authorization code stored under ``code``.

        Returns ``None`` when no such code is stored.
        """
        return self._authorization_codes.pop(code, None)

    # Access tokens

    def get_access_token(self, token: str) -> AccessToken | None:
        """Return the access token stored under ``token``, or ``None``."""
        return self._access_tokens.get(token)

    def store_access_token(self, access_token: AccessToken):
        """Store ``access_token`` keyed by its ``token`` value."""
        self._access_tokens[access_token.token] = access_token

    # Clients

    def get_client(self, id: str) -> Client | None:
        """Return the client stored under ``id``, or ``None`` if absent."""
        return self._clients.get(id)

    def store_client(self, client: Client):
        """Store ``client`` keyed by its ``id``."""
        self._clients[client.id] = client

    # Nonces

    def add_nonce(self, nonce: str):
        """Record ``nonce`` in the set of known nonces."""
        self._nonces.add(nonce)

    def exists_nonce(self, nonce: str) -> bool:
        """Return whether ``nonce`` has been recorded with ``add_nonce``."""
        return nonce in self._nonces

141 

142 

def _current_storage():
    """Return the object stored as ``flask.g.oidc_provider_mock_storage``."""
    return flask.g.oidc_provider_mock_storage


# Module-level proxy: each access is forwarded to whatever
# `flask.g.oidc_provider_mock_storage` holds at that moment. The `cast` only
# tells type checkers to treat the proxy as a `Storage`; it has no runtime
# effect.
storage = cast("Storage", werkzeug.local.LocalProxy(_current_storage))