Coverage for src/oidc_provider_mock/_internal.py: 89%
171 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-11 15:07 +0100
1import functools
2import secrets
3from collections.abc import Callable
4from dataclasses import dataclass, field
5from datetime import UTC, datetime, timedelta
6from http import HTTPStatus
7from urllib.parse import urlencode, urlsplit
9import flask
10import pydantic
11import werkzeug.exceptions
12from authlib import jose
15class _AccessToken:
16 def __init__(self, sub: str, expires_in: timedelta):
17 self.sub = sub
18 self.expires_at = datetime.now(UTC) + expires_in
19 self.token = secrets.token_urlsafe(16)
22@dataclass(kw_only=True)
23class _AuthorizationGrant:
24 sub: str
25 nonce: str | None
26 client_id: str
27 code: str = field(init=False)
28 expires_at: datetime = field(init=False)
30 def __post_init__(self):
31 # TODO: allow configuration of expiration
32 self.expires_at = datetime.now(UTC) + timedelta(seconds=60)
33 self.code = secrets.token_urlsafe(16)
35 def valid(self) -> bool:
36 return self.expires_at > datetime.now(UTC)
39@dataclass(kw_only=True)
40class User:
41 sub: str
42 claims: dict[str, str] = field(default_factory=dict)
43 userinfo: dict[str, object] = field(default_factory=dict)
class State:
    """In-memory storage for the mock provider.

    Holds the RSA signing key, known users, pending authorization grants, and
    issued access tokens. An instance is stored on ``flask.g`` before each
    request (see ``bind_to_app_context``) and retrieved with ``current()``.
    """

    _access_tokens: list[_AccessToken]
    _authorization_grants: list[_AuthorizationGrant]
    _users: dict[str, User]

    def __init__(
        self,
        access_token_lifetime: timedelta = timedelta(hours=1),
    ) -> None:
        """Create empty storage and a fresh 2048-bit RSA signing key.

        Args:
            access_token_lifetime: lifetime used by ``add_access_token`` when
                the caller does not pass an explicit ``expires_in``.
        """
        self._access_tokens_lifetime = access_token_lifetime
        # TODO: limit number of items
        self._access_tokens = []
        self._authorization_grants = []
        self.key = jose.RSAKey.generate_key(2048, is_private=True)  # pyright: ignore[reportUnknownMemberType]
        self._users = {}

    def get_access_token(self, token: str) -> _AccessToken | None:
        """Return the unexpired access token whose value is ``token``, else ``None``.

        Expired tokens are removed from storage as a side effect, so they can
        no longer authorize requests and do not accumulate.
        """
        now = datetime.now(UTC)
        self._access_tokens = [t for t in self._access_tokens if t.expires_at > now]
        return next((t for t in self._access_tokens if t.token == token), None)

    def get_authorization(self, code: str) -> _AuthorizationGrant | None:
        """Return the valid grant for authorization ``code``, else ``None``.

        A grant found to be expired is removed from storage.
        """
        authorization_grant = next(
            (a for a in self._authorization_grants if a.code == code),
            None,
        )

        if not authorization_grant:
            return None

        if not authorization_grant.valid():
            self._authorization_grants.remove(authorization_grant)
            return None

        return authorization_grant

    def add_authorization_grant(self, grant: _AuthorizationGrant):
        """Store ``grant`` so its code can later be exchanged for tokens."""
        self._authorization_grants.append(grant)

    def add_access_token(
        self, sub: str, expires_in: timedelta | None = None
    ) -> _AccessToken:
        """Issue, store and return a new access token for subject ``sub``.

        ``expires_in`` defaults to the lifetime configured in ``__init__``.
        An explicit zero ``timedelta`` is honoured rather than being treated
        as "not given".
        """
        lifetime = self._access_tokens_lifetime if expires_in is None else expires_in
        identity = _AccessToken(sub, lifetime)
        self._access_tokens.append(identity)
        return identity

    @staticmethod
    def current() -> "State":
        """Return the State bound to the current Flask application context."""
        return flask.g.oidc_mock_provider_state

    def bind_to_app_context(self, app: flask.Flask):
        """Register a ``before_request`` hook on ``app`` that stores this state on ``flask.g``.

        The hook is registered on the whole app, so it runs for every request
        the app handles, not only for blueprint routes.
        """

        @app.before_request
        def provide_state():
            flask.g.oidc_mock_provider_state = self

    def update_user(self, user: User) -> User:
        """Insert ``user``, or merge its claims and userinfo into the existing user.

        Returns the stored ``User`` instance.
        """
        existing = self._users.get(user.sub)
        if not existing:
            self._users[user.sub] = user
            return user

        existing.claims.update(user.claims)
        existing.userinfo.update(user.userinfo)
        return existing

    def get_user(self, sub: str) -> User | None:
        """Return the user with subject ``sub``, or ``None`` if unknown."""
        return self._users.get(sub, None)
# Blueprint carrying every provider endpoint. Registering it requires a
# `state` option holding a `State` instance (see bind_state_to_app_context).
blueprint = flask.Blueprint(name="oidc-provider", import_name=__name__)
@blueprint.record
def bind_state_to_app_context(setup_state: flask.blueprints.BlueprintSetupState):
    """Bind the ``state`` registration option to the app the blueprint joins.

    Raises:
        TypeError: if the ``state`` option is missing or is not a ``State``.
            (Previously a missing option surfaced as a bare ``KeyError``.)
    """
    app = setup_state.app
    assert isinstance(app, flask.Flask)
    state = setup_state.options.get("state")
    if not isinstance(state, State):
        raise TypeError("Blueprint option `state` must be an instance of `State`")
    state.bind_to_app_context(app)
@blueprint.after_request
def after_request(response: flask.Response):
    """Mark blueprint responses ``no-cache`` unless the view set Cache-Control itself."""
    headers = response.headers
    headers.setdefault("cache-control", "no-cache")
    return response
@blueprint.get("/")
def home():
    """Render the provider's landing page template."""
    template_name = "index.html"
    return flask.render_template(template_name)
@blueprint.get("/.well-known/openid-configuration")
def openid_config():
    """Serve the OpenID Connect discovery document.

    Endpoint URLs are absolute, built from the current request. The issuer is
    the request's host URL without the trailing slash, the same value the
    token endpoint writes into the ``iss`` claim of ID tokens.
    """
    return flask.jsonify({
        "issuer": flask.request.host_url.rstrip("/"),
        "authorization_endpoint": flask.url_for(".authorize", _external=True),
        "token_endpoint": flask.url_for(".get_token", _external=True),
        "userinfo_endpoint": flask.url_for(".userinfo", _external=True),
        "jwks_uri": flask.url_for(".jwks", _external=True),
        # Only the authorization code flow is implemented: the authorization
        # endpoint rejects any response_type that does not contain "code",
        # so implicit-flow response types must not be advertised.
        "response_types_supported": ["code"],
        # The discovery default would otherwise include "implicit".
        "grant_types_supported": ["authorization_code"],
        "subject_types_supported": ["public"],
        "id_token_signing_alg_values_supported": ["RS256"],
    })
@blueprint.get("/jwks")
def jwks():
    """Publish the provider's signing key as a JSON Web Key Set."""
    key_set = jose.KeySet((State.current().key,))
    payload = key_set.as_dict()  # type: ignore
    return flask.jsonify(payload)
162def show_bad_request_details[**T, R](fn: Callable[T, R]) -> Callable[T, R]:
163 @functools.wraps(fn)
164 def wrapped(*args: T.args, **kwargs: T.kwargs):
165 try:
166 return fn(*args, **kwargs)
167 except werkzeug.exceptions.BadRequestKeyError as ex:
168 ex.show_exception = True
169 raise ex
171 return wrapped
@blueprint.route("/oauth2/authorize", methods=("GET", "POST"))
@show_bad_request_details
def authorize():
    """Authorization endpoint: GET shows the consent form, POST handles the user's answer."""
    if flask.request.method != "GET":
        return process_authorization()
    return ask_authorization()
def ask_authorization():
    """Validate the authorization request query and render the consent form.

    Raises:
        werkzeug.exceptions.BadRequest: if ``client_id``, ``redirect_uri`` or
            ``response_type`` is missing from the query string.

    Returns a plain-text 400 response when ``response_type`` does not contain
    ``code``, the only response type this provider supports.
    """
    query = flask.request.args
    # TODO: verify client_id matches redirect uri
    # A tuple (not a set) keeps the reported parameter deterministic when
    # several are missing; set iteration order of strings varies per process.
    for name in ("client_id", "redirect_uri", "response_type"):
        if name not in query:
            raise werkzeug.exceptions.BadRequest(
                f"{name} missing from query parameters"
            )

    response_types = query["response_type"].split(" ")
    if "code" not in response_types:
        return (
            'invalid response_type query parameter: only "code" is supported',
            400,
            {"content-type": "text/plain; charset=utf-8"},
        )

    # TODO include info about the client
    # TODO client verification
    return flask.render_template("authorization_form.html")
def _redirect_with_params(redirect_uri: str, params: dict[str, str]):
    """Redirect to ``redirect_uri`` with ``params`` appended to its query string.

    Query parameters already present in the redirect URI are kept, as
    RFC 6749 §3.1.2 requires.
    """
    parts = urlsplit(redirect_uri)
    query = "&".join(q for q in (parts.query, urlencode(params)) if q)
    return flask.redirect(parts._replace(query=query).geturl())


def process_authorization():
    """Handle the consent form submission and redirect back to the client.

    If the form's ``action`` is ``deny_access``, redirects with
    ``error=access_denied``. Otherwise stores a new authorization grant for
    the submitted ``sub`` and redirects with its ``code``. In both cases the
    request's ``state`` parameter, if any, is echoed back.
    """
    query = flask.request.args
    redirect_uri = query["redirect_uri"]
    # TODO: ensure redirection only to localhost

    # RFC 6749 §4.1.2 and §4.1.2.1: `state` must be returned on success and
    # on error whenever the client sent it.
    state_params = {"state": query["state"]} if "state" in query else {}

    if flask.request.form.get("action") == "deny_access":
        return _redirect_with_params(
            redirect_uri, {"error": "access_denied", **state_params}
        )

    nonce = query.get("nonce", None)
    client_id = query["client_id"]

    authorization_grant = _AuthorizationGrant(
        sub=flask.request.form["sub"],
        nonce=nonce,
        client_id=client_id,
    )
    State.current().add_authorization_grant(authorization_grant)

    return _redirect_with_params(
        redirect_uri, {"code": authorization_grant.code, **state_params}
    )
@blueprint.post("/oauth2/token")
@show_bad_request_details
def get_token():
    """Exchange an authorization code for an access token and a signed ID token.

    Reads the ``code`` form parameter. Responds with ``invalid_grant`` and
    HTTP 400 when the code is unknown or expired. Otherwise it registers the
    grant's subject as a user if needed, issues an access token, signs an
    RS256 ID token with the provider key, and returns both in the token
    response.
    """
    data = flask.request.form
    # TODO: params grant_type, redirect_uri, client_id
    # TODO: client auth
    state = State.current()
    authorization = state.get_authorization(data["code"])
    if not authorization:
        # RFC 6749 §5.2: token endpoint error responses use 400 Bad Request.
        return flask.jsonify({"error": "invalid_grant"}), HTTPStatus.BAD_REQUEST

    user = state.update_user(User(sub=authorization.sub))
    access_token = state.add_access_token(user.sub)
    now = datetime.now(UTC)

    # User claims come first so the registered claims below always win.
    claims: dict[str, object] = {
        **user.claims,
        "iss": flask.request.host_url.rstrip("/"),
        "aud": authorization.client_id,
        "sub": authorization.sub,
        "nonce": authorization.nonce,
        "iat": now.timestamp(),
        # TODO: allow configuration of expiration
        "exp": (now + timedelta(hours=1)).timestamp(),
    }
    # OIDC Core §2: `nonce` is echoed only when the authorization request
    # carried one, so omit the claim instead of emitting `"nonce": null`.
    if claims["nonce"] is None:
        del claims["nonce"]

    id_token = jose.jwt.encode(  # pyright: ignore
        {
            "alg": "RS256",
            "kid": state.key.thumbprint(),
        },
        claims,
        state.key,
    ).decode("utf-8")

    return flask.jsonify({
        "access_token": access_token.token,
        "token_type": "Bearer",
        # Report the token's real lifetime (configurable on State) instead of
        # a hardcoded 3600.
        "expires_in": round((access_token.expires_at - now).total_seconds()),
        # "refresh_token": "REFRESH_TOKEN",
        "id_token": id_token,
    })
def _bearer_unauthorized(error: str | None = None):
    """Build a 401 response with a Bearer challenge (RFC 6750 §3).

    ``error`` goes into both the JSON body and the ``WWW-Authenticate``
    header. When it is ``None`` (request without credentials), no error code
    is put in the header and the body keeps an empty ``error`` string.
    """
    challenge = "Bearer" if error is None else f'Bearer error="{error}"'
    return (
        flask.jsonify({"error": error or ""}),
        HTTPStatus.UNAUTHORIZED,
        {"www-authenticate": challenge},
    )


@blueprint.route("/userinfo", methods=["GET", "POST"])
def userinfo():
    """Return user info for the provided OAuth2 Bearer token.

    Responds 401 when the request has no Bearer token, or when the token is
    unknown, expired, or belongs to an unknown user (``invalid_token``).
    """
    credentials = flask.request.authorization
    if not credentials or credentials.type != "bearer" or not credentials.token:
        # RFC 6750 §3.1: no error code when the request lacks credentials.
        return _bearer_unauthorized()

    state = State.current()
    identity = state.get_access_token(credentials.token)
    if not identity or identity.expires_at <= datetime.now(UTC):
        return _bearer_unauthorized("invalid_token")

    user = state.get_user(identity.sub)
    if not user:
        return _bearer_unauthorized("invalid_token")

    return flask.jsonify(user.userinfo), HTTPStatus.OK
class UserCreatePayload(pydantic.BaseModel):
    # JSON body accepted by POST /users (validated in strict mode there).
    # Field order and the absence of a class docstring are deliberate: both
    # feed pydantic's generated schema.
    sub: str
    claims: dict[str, str] = pydantic.Field(default_factory=dict)  # merged into ID tokens
    userinfo: dict[str, object] = pydantic.Field(default_factory=dict)  # served at /userinfo
@blueprint.route("/users", methods=["POST"])
def create_user():
    """Create a user, or merge claims/userinfo into an existing one.

    Expects a JSON body matching ``UserCreatePayload`` (strict validation).
    Returns 201 with an empty body on success. Returns 400 with pydantic's
    JSON error list when validation fails; previously this surfaced as a
    500 Internal Server Error.
    """
    # TODO: document this endpoint for users
    try:
        payload = UserCreatePayload.model_validate(flask.request.json, strict=True)
    except pydantic.ValidationError as ex:
        return (
            ex.json(),
            HTTPStatus.BAD_REQUEST,
            {"content-type": "application/json"},
        )

    user = User(sub=payload.sub, claims=payload.claims, userinfo=payload.userinfo)
    State.current().update_user(user)
    return "", HTTPStatus.CREATED