Coverage for src/oidc_provider_mock/_app.py: 92%
174 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-15 20:58 +0100
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-15 20:58 +0100
1import secrets
2from collections.abc import Callable, Sequence
3from dataclasses import dataclass
4from datetime import datetime, timedelta, timezone
5from http import HTTPStatus
6from typing import cast
7from uuid import uuid4
9import authlib.oauth2.rfc6749
10import authlib.oauth2.rfc6749.grants
11import authlib.oauth2.rfc6750
12import authlib.oidc.core.grants
13import flask
14import flask.typing
15import pydantic
16from authlib import jose
17from authlib.integrations import flask_oauth2
18from authlib.oauth2 import OAuth2Request
19from authlib.oauth2.rfc6749 import AccessDeniedError
20from typing_extensions import override
22from ._storage import (
23 AccessToken,
24 AuthorizationCode,
25 Client,
26 ClientSkipVerification,
27 Storage,
28 User,
29 storage,
30)
class AuthlibClient(authlib.oauth2.rfc6749.ClientMixin):
    """Wrap ``Client`` to implement authlib’s client protocol.

    A field of the wrapped client that holds a ``ClientSkipVerification``
    instance disables the corresponding check, so any value is accepted.
    """

    def __init__(self, client: Client) -> None:
        self._client = client

    @override
    def get_client_id(self):
        return self._client.id

    @override
    def get_default_redirect_uri(self) -> str:
        # No default redirect URI is provided; requests that rely on one are
        # not supported.
        raise NotImplementedError()

    @override
    def get_allowed_scope(self, scope: str) -> str:
        """Return the space-separated subset of ``scope`` this client may use."""
        return " ".join(s for s in scope.split() if s in self._client.allowed_scopes)

    @override
    def check_redirect_uri(self, redirect_uri: str) -> bool:
        if isinstance(self._client.redirect_uris, ClientSkipVerification):
            return True

        return redirect_uri in self._client.redirect_uris

    @override
    def check_client_secret(self, client_secret: str) -> bool:
        expected = self._client.secret
        # ``isinstance`` for consistency with ``check_redirect_uri``.
        if isinstance(expected, ClientSkipVerification):
            return True

        # A non-string secret can never equal the presented string secret.
        if not isinstance(expected, str):
            return False

        # Constant-time comparison so response timing does not reveal how much
        # of the secret matched. Compare bytes because ``compare_digest``
        # rejects ``str`` arguments containing non-ASCII characters.
        return secrets.compare_digest(client_secret.encode(), expected.encode())

    # TODO: restrict the allowed authentication methods per client/endpoint
    @override
    def check_endpoint_auth_method(self, method: str, endpoint: object):
        return method in {"client_secret_post", "client_secret_basic"}

    # TODO: restrict to the grant types the client is allowed to use;
    # currently every grant type is accepted.
    @override
    def check_grant_type(self, grant_type: str):
        return True

    @override
    def check_response_type(self, response_type: str):
        # Only the authorization code flow is accepted.
        return response_type == "code"
class TokenValidator(authlib.oauth2.rfc6750.BearerTokenValidator):
    """Resolve bearer tokens for endpoints guarded by ``require_oauth``."""

    @override
    def authenticate_token(self, token_string: str) -> AccessToken | None:
        """Return the stored access token for ``token_string``, or ``None``.

        Unknown tokens are reported as ``None`` rather than by raising. The
        base class's ``validate_token`` then rejects them with an
        ``invalid_token`` error and HTTP 401, as RFC 6750 section 3.1
        requires. Raising ``AccessDeniedError`` here would produce a 400
        ``access_denied`` response instead.
        """
        return storage.get_access_token(token_string)
class AuthorizationCodeGrant(authlib.oauth2.rfc6749.AuthorizationCodeGrant):
    """Authorization code grant whose codes are kept in ``storage``."""

    @override
    def query_authorization_code(
        self, code: str, client: AuthlibClient
    ) -> AuthorizationCode | None:
        """Return the stored code, but only if it was issued to ``client``."""
        stored = storage.get_authorization_code(code)
        if not stored or stored.client_id != client.get_client_id():
            return None
        return stored

    @override
    def delete_authorization_code(self, authorization_code: AuthorizationCode):
        """Remove ``authorization_code`` from storage."""
        storage.remove_authorization_code(authorization_code.code)

    @override
    def authenticate_user(self, authorization_code: AuthorizationCode) -> User | None:
        """Look up the user the code was issued for."""
        return storage.get_user(authorization_code.user_id)

    @override
    def save_authorization_code(self, code: str, request: object):
        """Persist ``code`` along with the user, client and request details."""
        assert isinstance(request, OAuth2Request)
        assert isinstance(request.user, User)
        client = cast("AuthlibClient", request.client)
        assert isinstance(request.redirect_uri, str)  # type: ignore
        record = AuthorizationCode(
            code=code,
            user_id=request.user.sub,
            client_id=client.get_client_id(),
            redirect_uri=request.redirect_uri,
            scope=request.scope,
            nonce=request.data.get("nonce"),  # type: ignore
        )
        storage.store_authorization_code(record)

    # TODO: consider validating the requested scope by overriding
    # ``validate_authorization_request`` before delegating to the base class.
class OpenIdGrantExtension:
    """Hooks required by authlib's OpenID Connect grants, backed by ``storage``."""

    def exists_nonce(self, nonce: str, request: OAuth2Request) -> bool:
        """Report whether ``storage`` already knows ``nonce``."""
        return storage.exists_nonce(nonce)

    def get_jwt_config(self, *args: object, **kwargs: object):
        """Return the JWT settings handed to authlib.

        The key is ``storage.jwk`` with the RS256 algorithm, ``exp`` is 3600,
        and the issuer is the current request's host URL without the
        trailing slash.
        """
        signing_key = storage.jwk
        issuer = flask.request.host_url.rstrip("/")
        return {
            "key": signing_key,
            "alg": "RS256",
            "exp": 3600,
            "iss": issuer,
        }

    def generate_user_info(self, user: User, scope: Sequence[str]):
        """Return the user's claims with ``sub`` set; ``scope`` is not consulted."""
        info = dict(user.claims)
        info["sub"] = user.sub
        return info
class OpenIDCode(OpenIdGrantExtension, authlib.oidc.core.OpenIDCode):
    """OpenID Connect extension for the authorization code grant."""
class ImplicitGrant(OpenIdGrantExtension, authlib.oidc.core.OpenIDImplicitGrant):
    """OpenID Connect implicit grant using the shared ``OpenIdGrantExtension`` hooks."""
class HybridGrant(OpenIdGrantExtension, authlib.oidc.core.OpenIDHybridGrant):
    """OpenID Connect hybrid grant using the shared ``OpenIdGrantExtension`` hooks."""
# Module-level singletons shared by every app the blueprint is registered on.
# NOTE(review): ``setup`` calls ``authorization.init_app`` on each
# registration, so with several apps the most recent registration's
# ``query_client``/``save_token`` (and thus its storage) presumably applies
# to all of them — confirm.
# TODO: turn into context variables
authorization = flask_oauth2.AuthorizationServer()

# Guards resource endpoints such as ``userinfo``; ``setup`` registers the
# bearer token validator on it.
require_oauth = flask_oauth2.ResourceProtector()

# Every route of the mock provider is attached to this blueprint.
blueprint = flask.Blueprint(name="oidc-provider-mock-authlib", import_name=__name__)
169@dataclass(kw_only=True, frozen=True)
170class Config:
171 require_client_registration: bool = False
@blueprint.record
def setup(setup_state: flask.blueprints.BlueprintSetupState):
    """Configure the OAuth 2.0 / OpenID Connect machinery for the registering app.

    Runs when the blueprint is registered. Reads the optional ``config``
    option, creates a fresh ``Storage`` for the app (exposed to each request
    as ``flask.g.oidc_provider_mock_storage``), and configures the
    module-level ``authorization`` server and ``require_oauth`` protector.

    :raises TypeError: if the ``config`` option is not a ``Config``.
    """
    flask_app = setup_state.app
    assert isinstance(flask_app, flask.Flask)

    config = setup_state.options.get("config", Config())
    if not isinstance(config, Config):
        raise TypeError(
            f"Expected {Config.__name__} as `config` option for blueprint, got {type(config)}"
        )

    storage = Storage()

    def set_storage():
        flask.g.oidc_provider_mock_storage = storage

    flask_app.before_request(set_storage)

    def query_client(client_id: str) -> AuthlibClient | None:
        # Unknown clients are accepted without any verification unless
        # registration is required.
        found = storage.get_client(client_id)
        if not found and not config.require_client_registration:
            found = Client(
                id=client_id,
                secret=ClientSkipVerification(),
                redirect_uris=ClientSkipVerification(),
            )
        return AuthlibClient(found) if found else None

    def save_token(token: dict[str, object], request: OAuth2Request):
        assert token["token_type"] == "Bearer"
        access_token = token["access_token"]
        assert isinstance(access_token, str)
        expires_in = token["expires_in"]
        assert isinstance(expires_in, int)
        user = request.user
        assert isinstance(user, User)

        expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in)
        storage.store_access_token(
            AccessToken(
                token=access_token,
                user_id=user.sub,
                # request.scope may actually be None
                scope=request.scope or "",
                expires_at=expires_at,
            )
        )

    authorization.init_app(  # type: ignore
        flask_app,
        query_client=query_client,
        save_token=save_token,
    )

    # TODO: Make this configurable
    openid_code = OpenIDCode(require_nonce=False)
    authorization.register_grant(AuthorizationCodeGrant, [openid_code])  # type: ignore
    for grant_cls in (ImplicitGrant, HybridGrant):
        authorization.register_grant(grant_cls)  # type: ignore

    require_oauth.register_token_validator(TokenValidator())
def app(*, require_client_registration: bool = False) -> flask.Flask:
    """Create a flask app for the OpenID provider.

    Call :any:`app().run() <flask.Flask.run>` to start the server.

    :param require_client_registration: If false (the default) any client ID and
        secret can be used to authenticate with the token endpoint, and any
        redirect URI is accepted. If true, clients have to be registered using
        the `OAuth 2.0 Dynamic Client Registration Protocol
        <https://datatracker.ietf.org/doc/html/rfc7591>`_.
    :returns: A new app with the provider blueprint registered.
    """
    flask_app = flask.Flask(__name__)
    config = Config(require_client_registration=require_client_registration)
    flask_app.register_blueprint(blueprint, config=config)
    return flask_app
@blueprint.get("/")
def home():
    """Render the provider's landing page."""
    template_name = "index.html"
    return flask.render_template(template_name)
@blueprint.get("/.well-known/openid-configuration")
def openid_config():
    """Serve the OpenID Provider metadata (OpenID Connect Discovery 1.0).

    Endpoint URLs are built from the current request. ``issuer`` is derived
    from the request's host URL, the same way as the ``iss`` value in
    ``OpenIdGrantExtension.get_jwt_config``.
    """

    def url_for(fn: Callable[..., object]) -> str:
        # Endpoints are named after their view functions within the blueprint.
        return flask.url_for(f".{fn.__name__}", _external=True)

    return flask.jsonify({
        "issuer": flask.request.host_url.rstrip("/"),
        "authorization_endpoint": url_for(authorize),
        "token_endpoint": url_for(issue_token),
        "userinfo_endpoint": url_for(userinfo),
        "registration_endpoint": url_for(register_client),
        "jwks_uri": url_for(jwks),
        # Clients only accept the ``code`` response type
        # (``AuthlibClient.check_response_type``). Advertising the implicit
        # ``id_token`` flows would point relying parties at requests that are
        # always rejected. Without ``grant_types_supported`` the discovery
        # spec's default would also advertise the implicit grant.
        # TODO properly populate the remaining fields
        "response_types_supported": ["code"],
        "grant_types_supported": ["authorization_code"],
        "subject_types_supported": ["public"],
        "id_token_signing_alg_values_supported": ["RS256"],
    })
@blueprint.get("/jwks")
def jwks():
    """Publish ``storage.jwk`` as a JSON Web Key Set."""
    key_set = jose.KeySet((storage.jwk,))
    return flask.jsonify(key_set.as_dict())  # pyright: ignore[reportUnknownMemberType]
class RegisterClientBody(pydantic.BaseModel):
    # Request body of ``register_client`` (RFC 7591 client metadata). Only
    # ``redirect_uris`` is read; each entry must be a valid HTTP(S) URL.
    redirect_uris: Sequence[pydantic.HttpUrl]
@blueprint.post("/register-client")
def register_client():
    """Register a new client (OAuth 2.0 Dynamic Client Registration, RFC 7591).

    Responds with ``201 Created`` and the generated client ID and secret.
    Metadata that fails validation is answered with ``400 Bad Request`` and
    an ``invalid_client_metadata`` error (RFC 7591, section 3.2.2) instead of
    an unhandled exception.
    """
    try:
        payload = RegisterClientBody.model_validate(flask.request.json)
    except pydantic.ValidationError as e:
        return flask.jsonify({
            "error": "invalid_client_metadata",
            "error_description": str(e),
        }), HTTPStatus.BAD_REQUEST

    client = Client(
        id=str(uuid4()),
        secret=secrets.token_urlsafe(16),
        redirect_uris=[str(uri) for uri in payload.redirect_uris],
    )
    storage.store_client(client)
    return flask.jsonify({
        "client_id": client.id,
        "client_secret": client.secret,
        "redirect_uris": client.redirect_uris,
        # For now, limit the accepted flow configuration.
        # RFC 7591 defines ``token_endpoint_auth_method`` as a single string.
        "token_endpoint_auth_method": "client_secret_basic",
        "grant_types": ["authorization_code"],
        "response_types": ["code"],
    }), HTTPStatus.CREATED
@blueprint.route("/oauth2/authorize", methods=["GET", "POST"])
def authorize() -> flask.typing.ResponseReturnValue:
    """Authorization endpoint.

    ``GET`` validates the authorization request and renders
    ``authorization_form.html``. ``POST`` reads the ``sub`` form field,
    creates that user if it is not stored yet, and lets authlib build the
    authorization response for it.
    """
    if flask.request.method == "GET":
        try:
            # Only validates the request parameters; the grant is not used.
            authorization.get_consent_grant()  # type: ignore
        except authlib.oauth2.rfc6749.errors.OAuth2Error as e:
            # Answer invalid requests (unknown client, bad redirect URI,
            # unsupported response type, ...) with an OAuth 2.0 error
            # response, as ``create_authorization_response`` does for POST,
            # instead of failing with a server error.
            return authorization.handle_error_response(flask.request, e)  # type: ignore

        return flask.render_template("authorization_form.html")

    # TODO: validate sub
    sub = flask.request.form["sub"]
    user = storage.get_user(sub)
    if not user:
        user = User(sub=sub)
        storage.store_user(user)
    return authorization.create_authorization_response(grant_user=user)  # type: ignore
@blueprint.route("/oauth2/token", methods=["POST"])
def issue_token() -> flask.typing.ResponseReturnValue:
    """Token endpoint; request validation and token issuance are delegated to authlib."""
    response = authorization.create_token_response()  # pyright: ignore
    return response
@blueprint.route("/oauth/userinfo", methods=["GET", "POST"])
@require_oauth()
def userinfo():
    """Return the stored userinfo of the bearer token's user.

    ``sub`` always comes from the token, overriding any ``sub`` entry in the
    stored userinfo.
    """
    # TODO implement filtering by scope
    token = flask_oauth2.current_token
    claims = dict(token.get_user().userinfo)
    claims["sub"] = token.user_id
    return flask.jsonify(claims)
class SetUserBody(pydantic.BaseModel):
    # Request body of ``set_user``. Both fields default to empty objects.
    # Stored as ``User.claims``; ``OpenIdGrantExtension.generate_user_info``
    # returns these (plus ``sub``).
    claims: dict[str, str] = pydantic.Field(default_factory=dict)
    # Stored as ``User.userinfo`` and served by the ``userinfo`` endpoint.
    userinfo: dict[str, object] = pydantic.Field(default_factory=dict)
@blueprint.route("/users/<sub>", methods=["PUT"])
def set_user(sub: str):
    """Create or replace the user identified by ``sub``.

    The JSON body is validated strictly against ``SetUserBody``. Responds
    with ``204 No Content`` on success. If validation fails, responds with
    ``400 Bad Request`` and the validation message instead of an unhandled
    exception.
    """
    try:
        payload = SetUserBody.model_validate(flask.request.json, strict=True)
    except pydantic.ValidationError as e:
        return flask.jsonify({"error": str(e)}), HTTPStatus.BAD_REQUEST

    storage.store_user(User(sub=sub, claims=payload.claims, userinfo=payload.userinfo))
    return "", HTTPStatus.NO_CONTENT