Coverage for src/oidc_provider_mock/_internal.py: 89%

171 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-11 15:07 +0100

1import functools 

2import secrets 

3from collections.abc import Callable 

4from dataclasses import dataclass, field 

5from datetime import UTC, datetime, timedelta 

6from http import HTTPStatus 

7from urllib.parse import urlencode, urlsplit 

8 

9import flask 

10import pydantic 

11import werkzeug.exceptions 

12from authlib import jose 

13 

14 

15class _AccessToken: 

16 def __init__(self, sub: str, expires_in: timedelta): 

17 self.sub = sub 

18 self.expires_at = datetime.now(UTC) + expires_in 

19 self.token = secrets.token_urlsafe(16) 

20 

21 

22@dataclass(kw_only=True) 

23class _AuthorizationGrant: 

24 sub: str 

25 nonce: str | None 

26 client_id: str 

27 code: str = field(init=False) 

28 expires_at: datetime = field(init=False) 

29 

30 def __post_init__(self): 

31 # TODO: allow configuration of expiration 

32 self.expires_at = datetime.now(UTC) + timedelta(seconds=60) 

33 self.code = secrets.token_urlsafe(16) 

34 

35 def valid(self) -> bool: 

36 return self.expires_at > datetime.now(UTC) 

37 

38 

39@dataclass(kw_only=True) 

40class User: 

41 sub: str 

42 claims: dict[str, str] = field(default_factory=dict) 

43 userinfo: dict[str, object] = field(default_factory=dict) 

44 

45 

class State:
    """In-memory store behind the provider.

    Holds issued access tokens, pending authorization grants, known users and
    the RSA key used to sign ID tokens. Everything lives in plain lists/dicts
    on the instance; nothing is persisted.
    """

    _access_tokens: list[_AccessToken]
    _authorization_grants: list[_AuthorizationGrant]
    _users: dict[str, User]

    def __init__(
        self,
        access_token_lifetime: timedelta = timedelta(hours=1),
    ) -> None:
        """Create an empty state with a freshly generated 2048-bit RSA key.

        ``access_token_lifetime`` is the default lifetime for tokens created by
        ``add_access_token``.
        """
        self.key = jose.RSAKey.generate_key(2048, is_private=True)  # pyright: ignore[reportUnknownMemberType]
        self._access_tokens_lifetime = access_token_lifetime
        self._users = {}
        # TODO: limit number of items
        self._access_tokens = []
        self._authorization_grants = []

    def get_access_token(self, token: str) -> _AccessToken | None:
        """Return the first stored access token with value ``token``, else ``None``.

        Expiry is not checked here.
        """
        for access_token in self._access_tokens:
            if access_token.token == token:
                return access_token
        return None

    def get_authorization(self, code: str) -> _AuthorizationGrant | None:
        """Return the still-valid grant for ``code``, else ``None``.

        An expired grant found by the lookup is removed from the store.
        """
        grant = next(
            (candidate for candidate in self._authorization_grants if candidate.code == code),
            None,
        )
        if grant is None:
            return None

        if grant.valid():
            return grant

        # Expired grants are only pruned lazily, when they are looked up.
        self._authorization_grants.remove(grant)
        return None

    def add_authorization_grant(self, grant: _AuthorizationGrant):
        """Store ``grant`` so a later ``get_authorization`` can find it."""
        self._authorization_grants.append(grant)

    def add_access_token(
        self, sub: str, expires_in: timedelta | None = None
    ) -> _AccessToken:
        """Create, store and return a new access token for ``sub``.

        A missing (or zero, since ``timedelta(0)`` is falsy) ``expires_in``
        falls back to the lifetime given to the constructor.
        """
        lifetime = expires_in or self._access_tokens_lifetime
        access_token = _AccessToken(sub, lifetime)
        self._access_tokens.append(access_token)
        return access_token

    @staticmethod
    def current() -> "State":
        """Return the state that ``bind_to_app_context`` exposed for this request."""
        state: State = flask.g.oidc_mock_provider_state
        return state

    def bind_to_app_context(self, app: flask.Flask):
        """Register a before-request hook on ``app`` that stores this state in
        ``flask.g`` so ``State.current()`` can find it."""

        def provide_state():
            flask.g.oidc_mock_provider_state = self

        app.before_request(provide_state)

    def update_user(self, user: User) -> User:
        """Store ``user``, or merge its claims/userinfo into the stored user with
        the same ``sub``. Return the user object that is stored."""
        stored = self._users.setdefault(user.sub, user)
        if stored is not user:
            stored.claims.update(user.claims)
            stored.userinfo.update(user.userinfo)
        return stored

    def get_user(self, sub: str) -> User | None:
        """Return the stored user with subject ``sub``, or ``None``."""
        return self._users.get(sub)

114 

115 

# All provider endpoints are registered on this blueprint. Registering it
# requires a `state` option holding a `State` (see bind_state_to_app_context).
blueprint = flask.Blueprint("oidc-provider", import_name=__name__)

117 

118 

@blueprint.record
def bind_state_to_app_context(setup_state: flask.blueprints.BlueprintSetupState):
    """Bind the ``state`` registration option to the app the blueprint joins.

    Runs when the blueprint is registered. Raises ``KeyError`` if the ``state``
    option is absent and ``TypeError`` if it is not a ``State``.
    """
    app = setup_state.app
    assert isinstance(app, flask.Flask)
    state = setup_state.options["state"]
    if isinstance(state, State):
        state.bind_to_app_context(app)
        return
    raise TypeError("Blueprint option `state` must be an instance of `State`")

126 

127 

@blueprint.after_request
def after_request(response: flask.Response):
    """Mark this blueprint's responses ``cache-control: no-cache`` by default.

    A cache-control header already set by the handler is left untouched.
    """
    if "cache-control" not in response.headers:
        response.headers["cache-control"] = "no-cache"
    return response

132 

133 

@blueprint.get("/")
def home():
    """Render the provider's landing page."""
    page = "index.html"
    return flask.render_template(page)

137 

138 

@blueprint.get("/.well-known/openid-configuration")
def openid_config():
    """Serve the OpenID Provider metadata document (OpenID Connect Discovery).

    Endpoint URLs are absolute and the issuer is the request's host URL
    without a trailing slash (the same value used as ``iss`` in ID tokens).
    """
    return flask.jsonify({
        "issuer": flask.request.host_url.rstrip("/"),
        "authorization_endpoint": flask.url_for(
            f".{authorize.__name__}", _external=True
        ),
        "token_endpoint": flask.url_for(".get_token", _external=True),
        "userinfo_endpoint": flask.url_for(".userinfo", _external=True),
        "jwks_uri": flask.url_for(".jwks", _external=True),
        # Only the authorization code flow is implemented: the authorization
        # endpoint rejects response types without "code" and only ever
        # redirects back with a code, never with tokens.
        "response_types_supported": ["code"],
        # Without this the discovery default also implies "implicit".
        "grant_types_supported": ["authorization_code"],
        "subject_types_supported": ["public"],
        "id_token_signing_alg_values_supported": ["RS256"],
    })

155 

156 

@blueprint.get("/jwks")
def jwks():
    """Serve the state's signing key as a JSON Web Key Set."""
    key_set = jose.KeySet((State.current().key,))  # type: ignore
    return flask.jsonify(key_set.as_dict())  # type: ignore

160 

161 

162def show_bad_request_details[**T, R](fn: Callable[T, R]) -> Callable[T, R]: 

163 @functools.wraps(fn) 

164 def wrapped(*args: T.args, **kwargs: T.kwargs): 

165 try: 

166 return fn(*args, **kwargs) 

167 except werkzeug.exceptions.BadRequestKeyError as ex: 

168 ex.show_exception = True 

169 raise ex 

170 

171 return wrapped 

172 

173 

@blueprint.route("/oauth2/authorize", methods=("GET", "POST"))
@show_bad_request_details
def authorize():
    """Authorization endpoint.

    GET renders the authorization form; any other method (the form's POST)
    processes the user's answer.
    """
    if flask.request.method != "GET":
        return process_authorization()
    return ask_authorization()

181 

182 

def ask_authorization():
    """Validate the authorization request query and render the authorization form.

    Raises ``werkzeug.exceptions.BadRequest`` naming the first missing one of
    ``client_id``, ``redirect_uri`` and ``response_type``. Responds 400 with a
    plain-text message when ``response_type`` does not include ``code``.
    """
    query = flask.request.args
    # TODO: verify client_id matches redirect uri
    # A tuple (not a set) so the reported parameter is deterministic: set
    # iteration order over strings changes with hash randomization.
    for name in ("client_id", "redirect_uri", "response_type"):
        if name not in query:
            raise werkzeug.exceptions.BadRequest(
                f"{name} missing from query parameters"
            )

    response_types = query["response_type"].split(" ")
    if "code" not in response_types:
        return (
            'invalid response_type query parameter: only "code" is supported',
            400,
            {"content-type": "text/plain; charset=utf-8"},
        )

    # TODO include info about the client
    # TODO client verification
    return flask.render_template("authorization_form.html")

203 

204 

def _redirect_with_params(redirect_uri, params: dict[str, str]):
    """Redirect to ``redirect_uri`` (a ``urlsplit`` result) with ``params`` added.

    Query parameters already on the redirect URI are kept and the new ones are
    appended, as RFC 6749 section 3.1.2 requires.
    """
    extra = urlencode(params)
    query = f"{redirect_uri.query}&{extra}" if redirect_uri.query else extra
    return flask.redirect(redirect_uri._replace(query=query).geturl())


def process_authorization():
    """Handle the submitted authorization form and redirect back to the client.

    If the form's ``action`` is ``deny_access``, redirects with
    ``error=access_denied``. Otherwise stores a new authorization grant for the
    submitted ``sub`` and redirects with its ``code``. In both cases ``state``
    from the request query is echoed back when present.
    """
    query = flask.request.args
    redirect_uri = urlsplit(query["redirect_uri"])
    # TODO: ensure redirection only to localhost

    # RFC 6749 sections 4.1.2 and 4.1.2.1: `state` must be returned in both
    # the success and the error response if the client sent it.
    state_params: dict[str, str] = {}
    if "state" in query:
        state_params["state"] = query["state"]

    if flask.request.form.get("action") == "deny_access":
        return _redirect_with_params(
            redirect_uri, {"error": "access_denied", **state_params}
        )

    nonce = query.get("nonce", None)
    client_id = query["client_id"]

    authorization_grant = _AuthorizationGrant(
        sub=flask.request.form["sub"],
        nonce=nonce,
        client_id=client_id,
    )
    State.current().add_authorization_grant(authorization_grant)

    return _redirect_with_params(
        redirect_uri, {"code": authorization_grant.code, **state_params}
    )

235 

236 

@blueprint.post("/oauth2/token")
@show_bad_request_details
def get_token():
    """Token endpoint: exchange an authorization code for tokens.

    Reads ``code`` from the form body. Responds 400 with ``invalid_grant`` when
    the code is unknown or expired. Otherwise records the grant's user,
    issues an access token, and returns it with an RS256-signed ID token.
    """
    data = flask.request.form
    # TODO: params grant_type, redirect_uri, client_id
    # TODO: client auth
    state = State.current()
    authorization = state.get_authorization(data["code"])
    if not authorization:
        # RFC 6749 section 5.2: token endpoint errors use 400 Bad Request.
        return flask.jsonify({"error": "invalid_grant"}), HTTPStatus.BAD_REQUEST

    user = state.update_user(User(sub=authorization.sub))
    access_token = state.add_access_token(user.sub)
    now = datetime.now(UTC)

    claims: dict[str, object] = {
        **user.claims,
        "iss": flask.request.host_url.rstrip("/"),
        "aud": authorization.client_id,
        "sub": authorization.sub,
        # JWT NumericDate values as whole seconds for wide client compatibility.
        "iat": int(now.timestamp()),
        # TODO: allow configuration of expiration
        "exp": int((now + timedelta(hours=1)).timestamp()),
    }
    # OIDC Core: include `nonce` only when the authentication request sent
    # one, instead of emitting `"nonce": null`.
    if authorization.nonce is not None:
        claims["nonce"] = authorization.nonce

    id_token = jose.jwt.encode(  # pyright: ignore
        {"alg": "RS256", "kid": state.key.thumbprint()},
        claims,
        state.key,
    ).decode("utf-8")

    # Derived from the issued token, so a lifetime configured on State is
    # reported correctly instead of a hard-coded 3600.
    expires_in = round((access_token.expires_at - now).total_seconds())

    return (
        flask.jsonify({
            "access_token": access_token.token,
            "token_type": "Bearer",
            "expires_in": expires_in,
            # TODO: refresh tokens
            "id_token": id_token,
        }),
        HTTPStatus.OK,
        # RFC 6749 section 5.1: responses carrying tokens must not be cached.
        {"cache-control": "no-store", "pragma": "no-cache"},
    )

275 

276 

def _bearer_challenge(error: str | None = None):
    """Build a 401 response with a ``WWW-Authenticate: Bearer`` challenge.

    ``error`` is an RFC 6750 error code such as ``invalid_token``; it is put in
    the challenge and in the JSON body. When the request carried no usable
    Bearer credentials, RFC 6750 section 3.1 says not to include an error code,
    so the challenge is plain ``Bearer`` and the body's ``error`` is empty.
    """
    challenge = "Bearer" if error is None else f'Bearer error="{error}"'
    return (
        flask.jsonify({"error": error or ""}),
        HTTPStatus.UNAUTHORIZED,
        {"www-authenticate": challenge},
    )


@blueprint.route("/userinfo", methods=["GET", "POST"])
def userinfo():
    """Return user info for the provided OAuth2 Bearer token

    Responds 401 when there is no Bearer token, when the token is unknown or
    expired, or when no user is stored for the token's subject.
    """
    credentials = flask.request.authorization
    if not credentials or credentials.type != "bearer" or not credentials.token:
        return _bearer_challenge()

    state = State.current()
    access_token = state.get_access_token(credentials.token)
    # get_access_token does not check expiry, so do it here. Same boundary as
    # _AuthorizationGrant.valid: expired once expires_at is not in the future.
    if not access_token or access_token.expires_at <= datetime.now(UTC):
        return _bearer_challenge("invalid_token")

    user = state.get_user(access_token.sub)
    if not user:
        return _bearer_challenge("invalid_token")

    return flask.jsonify(user.userinfo), HTTPStatus.OK

312 

313 

# Request body accepted by the `POST /users` endpoint (validated in strict
# mode there). The fields mirror `User`.
class UserCreatePayload(pydantic.BaseModel):
    # Subject identifier of the user to create or update.
    sub: str
    # Extra ID token claims; merged into an existing user's claims.
    claims: dict[str, str] = pydantic.Field(default_factory=dict)
    # Userinfo response body; merged into an existing user's userinfo.
    userinfo: dict[str, object] = pydantic.Field(default_factory=dict)

318 

319 

@blueprint.route("/users", methods=["POST"])
def create_user():
    """Create a user, or merge claims/userinfo into the user with the same ``sub``.

    Expects a JSON body matching ``UserCreatePayload`` (strict validation).
    Responds 201 with an empty body on success, and 400 with pydantic's
    validation errors as JSON when the payload is invalid.
    """
    # TODO: document this endpoint for users
    body = flask.request.json
    try:
        payload = UserCreatePayload.model_validate(body, strict=True)
    except pydantic.ValidationError as ex:
        # Previously escaped the view and surfaced as a 500 Internal Server Error.
        return ex.json(), HTTPStatus.BAD_REQUEST, {"content-type": "application/json"}

    user = User(sub=payload.sub, claims=payload.claims, userinfo=payload.userinfo)
    State.current().update_user(user)
    return "", HTTPStatus.CREATED