Coverage for src/oidc_provider_mock/_storage.py: 98%
99 statements
coverage.py v7.6.10, created at 2025-01-15 20:58 +0100
1from collections.abc import Sequence
2from dataclasses import dataclass, field
3from datetime import datetime, timezone
4from typing import cast
6import authlib.oauth2.rfc6749
7import authlib.oidc.core
8import flask
9import werkzeug.local
10from authlib import jose
11from typing_extensions import override
class ClientSkipVerification:
    """Marker type that switches off a client verification step.

    Put an instance of this class where a client secret or a list of
    redirect URIs would normally go. The corresponding check is then
    skipped instead of being enforced.
    """
@dataclass(kw_only=True, frozen=True)
class Client:
    """An OAuth2 client known to the provider.

    Both ``secret`` and ``redirect_uris`` may be given a
    ``ClientSkipVerification`` instance instead of a real value, which
    disables the corresponding check.
    """

    # Client identifier; `Storage.store_client` uses it as the lookup key.
    id: str
    # Shared client secret, or a marker meaning "do not verify the secret".
    secret: str | ClientSkipVerification
    # Registered redirect URIs, or a marker meaning "accept any redirect URI".
    redirect_uris: Sequence[str] | ClientSkipVerification
    # Scopes permitted for this client (an immutable tuple by default).
    allowed_scopes: Sequence[str] = ("openid", "profile")
27@dataclass(kw_only=True, frozen=True)
28class User:
29 sub: str
30 claims: dict[str, str] = field(default_factory=dict)
31 userinfo: dict[str, object] = field(default_factory=dict)
@dataclass(kw_only=True, frozen=True)
class AuthorizationCode(authlib.oidc.core.AuthorizationCodeMixin):
    """An issued authorization code and the request data bound to it.

    Adapts the stored fields to authlib's OIDC ``AuthorizationCodeMixin``
    through the getter methods below.
    """

    # The code value; `Storage.store_authorization_code` uses it as the key.
    code: str
    client_id: str
    redirect_uri: str
    # Identifier of the authorizing user (presumably a `User.sub` — confirm).
    user_id: str
    scope: str
    # Nonce to carry into the ID token, or None when there is none.
    nonce: str | None

    # ---- AuthorizationCodeMixin interface ----

    @override
    def get_redirect_uri(self):
        """Return the redirect URI recorded with this code."""
        return self.redirect_uri

    @override
    def get_scope(self):
        """Return the scope recorded with this code."""
        return self.scope

    @override
    def get_nonce(self) -> str | None:
        """Return the stored nonce, or ``None`` when absent."""
        return self.nonce

    @override
    def get_auth_time(self) -> int | None:
        """Authentication time is not tracked, so nothing is reported."""
        return None
@dataclass(kw_only=True, frozen=True)
class AccessToken(authlib.oauth2.rfc6749.TokenMixin):
    """A bearer access token issued for a user.

    Implements authlib's ``TokenMixin`` on top of the stored fields.
    """

    # Token string; `Storage.store_access_token` uses it as the lookup key.
    token: str
    # `sub` of the user the token belongs to (resolved via `get_user`).
    user_id: str
    # Scope granted to the token.
    scope: str
    # Expiry instant. It is compared with an aware UTC "now", so it must be
    # timezone-aware; a naive datetime would make the comparison raise
    # TypeError.
    expires_at: datetime

    def get_user(self):
        """Return the token's `User` from the active storage, or ``None``."""
        return storage.get_user(self.user_id)

    # Implement `TokenMixin`

    @override
    def check_client(self, client: Client):
        # TODO implement: the token does not record which client it was
        # issued to, so every client is currently accepted.
        return True

    @override
    def get_expires_in(self) -> int:
        """Return the token's remaining lifetime in whole seconds.

        ``TokenMixin`` declares this method, and the base version raises
        ``NotImplementedError``. The issuance time is not stored, so the
        value is measured from now and clamped at 0 once the token has
        expired.
        """
        remaining = self.expires_at - datetime.now(timezone.utc)
        return max(0, int(remaining.total_seconds()))

    @override
    def is_expired(self):
        """Return True once the current UTC time reaches ``expires_at``."""
        return datetime.now(timezone.utc) >= self.expires_at

    @override
    def is_revoked(self):
        """Revocation is not supported, so tokens are never revoked."""
        return False

    @override
    def get_scope(self) -> str:
        """Return the scope granted to this token."""
        return self.scope
class Storage:
    """In-memory registry for the mock provider's state.

    Holds a signing key plus dicts of clients, users, authorization codes
    and access tokens, each keyed by the object's own identifier, and a set
    of seen nonces. Nothing is persisted, and this class never evicts
    entries.
    """

    # RSA private key created per instance (presumably used to sign tokens —
    # confirm against callers).
    jwk: jose.RSAKey

    _clients: dict[str, Client]
    _users: dict[str, User]
    _authorization_codes: dict[str, AuthorizationCode]
    _access_tokens: dict[str, AccessToken]
    _nonces: set[str]

    def __init__(self) -> None:
        # Each Storage gets its own freshly generated private key.
        self.jwk = jose.RSAKey.generate_key(is_private=True)  # pyright: ignore[reportUnknownMemberType]
        self._clients = {}
        self._users = {}
        self._authorization_codes = {}
        self._access_tokens = {}
        self._nonces = set()

    # ---- users ----

    def get_user(self, sub: str) -> User | None:
        """Look up a user by subject; ``None`` if unknown."""
        return self._users.get(sub, None)

    def store_user(self, user: User) -> None:
        """Add or replace the user stored under ``user.sub``."""
        key = user.sub
        self._users[key] = user

    # ---- authorization codes ----

    def get_authorization_code(self, code: str) -> AuthorizationCode | None:
        """Look up an authorization code; ``None`` if unknown."""
        return self._authorization_codes.get(code, None)

    def store_authorization_code(self, code: AuthorizationCode) -> None:
        """Add or replace the code stored under ``code.code``."""
        key = code.code
        self._authorization_codes[key] = code

    def remove_authorization_code(self, code: str) -> AuthorizationCode | None:
        """Delete a code and return it; ``None`` if it was not stored."""
        return self._authorization_codes.pop(code, None)

    # ---- access tokens ----

    def get_access_token(self, token: str) -> AccessToken | None:
        """Look up an access token by its string; ``None`` if unknown."""
        return self._access_tokens.get(token, None)

    def store_access_token(self, access_token: AccessToken) -> None:
        """Add or replace the token stored under ``access_token.token``."""
        key = access_token.token
        self._access_tokens[key] = access_token

    # ---- clients ----

    def get_client(self, id: str) -> Client | None:
        """Look up a client by identifier; ``None`` if unknown."""
        return self._clients.get(id, None)

    def store_client(self, client: Client) -> None:
        """Add or replace the client stored under ``client.id``."""
        key = client.id
        self._clients[key] = client

    # ---- nonces ----

    def add_nonce(self, nonce: str) -> None:
        """Record a nonce as seen (the set only grows)."""
        self._nonces.add(nonce)

    def exists_nonce(self, nonce: str) -> bool:
        """Return whether ``nonce`` has been recorded before."""
        return nonce in self._nonces
def _current_storage() -> Storage:
    # Read from `flask.g` on every access instead of being captured at import.
    return flask.g.oidc_provider_mock_storage


# Module-level handle to the active `Storage`: the proxy forwards each
# attribute access to whatever object `flask.g.oidc_provider_mock_storage`
# holds at that moment. `cast` only informs the type checker.
storage = cast("Storage", werkzeug.local.LocalProxy(_current_storage))