Coverage for src/edwh_auth_rbac/model.py: 78%
195 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-22 15:42 +0200
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-22 15:42 +0200
1import copy
2import datetime as dt
3import hashlib
4import hmac
5import typing
6import uuid
7from typing import Optional
8from uuid import UUID
10import dateutil.parser
11from pydal import DAL, Field, SQLCustomType
12from pydal.objects import SQLALL, Query, Table
14from .helpers import IS_IN_LIST
class DEFAULT:
    """Sentinel for "argument not supplied", distinct from None; callers test it with ``is DEFAULT``."""


# Validity window assigned to permissions that are granted without explicit bounds.
DEFAULT_STARTS = dt.datetime(year=2000, month=1, day=1)
DEFAULT_ENDS = dt.datetime(year=3000, month=1, day=1)
25def unstr_datetime(s: dt.datetime | str) -> dt.datetime:
26 """json helper... might values arrive as str"""
27 return dateutil.parser.parse(s) if isinstance(s, str) else s
30class Password:
31 """
32 Encode a password using: Password.encode('secret')
33 """
35 @classmethod
36 def hmac_hash(cls, value: str, key: str, salt: str = None) -> str:
37 digest_alg = hashlib.sha512
38 d = hmac.new(str(key).encode(), str(value).encode(), digest_alg)
39 if salt:
40 d.update(str(salt).encode())
41 return d.hexdigest()
43 @classmethod
44 def validate(cls, password: str, candidate: str) -> bool:
45 salt, hashed = candidate.split(":", 1)
46 return cls.hmac_hash(value=password, key="secret_start", salt=salt) == hashed
48 @classmethod
49 def encode(cls, password: str) -> str:
50 salt = uuid.uuid4().hex
51 return salt + ":" + cls.hmac_hash(value=password, key="secret_start", salt=salt)
def is_uuid(s) -> bool:
    """
    Return True when ``uuid.UUID(s)`` accepts ``s``, False otherwise.

    Only string forms are accepted. A ``uuid.UUID`` instance, an int or None yields False
    because ``UUID()`` raises for them.
    """
    try:
        UUID(s)
    except Exception:
        return False
    else:
        return True
# Any value accepted by key_lookup_query: an email address, a numeric row id,
# an object_id (UUID) or a firstname / group short code.
IdentityKey: typing.TypeAlias = str | int | UUID
# Values used as identity.object_type in lookups.
ObjectTypes = typing.Literal["user", "group", "item"]
def key_lookup_query(db: DAL, identity_key: IdentityKey, object_type: ObjectTypes | None = None) -> Query:
    """
    Build a pydal query selecting identity rows that match ``identity_key``.

    The key is interpreted, in order, as:

    - an email address when its text contains "@" (compared lowercased);
    - the numeric row id when it is an ``int`` (numeric strings are NOT treated as ids);
    - the object_id when it parses as a UUID (compared lowercased; ``uuid.UUID``
      instances are accepted too);
    - otherwise the firstname, which doubles as the short code for groups.

    :param db: dal db connection
    :param identity_key: the key to look up
    :param object_type: when given, additionally restrict to this object_type
    :return: the (unexecuted) query
    """
    if isinstance(identity_key, UUID):
        # UUID instances have no .lower() and are rejected by is_uuid(),
        # so without this they would silently fall through to a firstname lookup.
        identity_key = str(identity_key)

    if "@" in str(identity_key):
        query = db.identity.email == identity_key.lower()
    elif isinstance(identity_key, int):
        query = db.identity.id == identity_key
    elif is_uuid(identity_key):
        query = db.identity.object_id == identity_key.lower()
    else:
        query = db.identity.firstname == identity_key

    if object_type:
        query &= db.identity.object_type == object_type

    return query
def key_lookup(db: DAL, identity_key: IdentityKey, object_type: ObjectTypes | None = None) -> str | None:
    """
    Resolve ``identity_key`` to the object_id of exactly one identity.

    :param db: dal db connection
    :param identity_key: email, numeric id, object_id or firstname (see key_lookup_query)
    :param object_type: optional restriction on identity.object_type
    :return: the object_id, or None when nothing matches
    :raises ValueError: when more than one identity matches. This can happen, for
        example, with a shared email address.
    """
    rows = db(key_lookup_query(db, identity_key, object_type)).select(db.identity.object_id)

    if not rows:
        return None
    if len(rows) > 1:
        raise ValueError(f"Key lookup for {identity_key} returned {len(rows)} results.")

    return rows.first().object_id
# Datetime column stored as text in a char(35): written with isoformat(" ")
# and read back through dateutil's parser.
my_datetime = SQLCustomType(
    type="string",
    native="char(35)",
    encoder=lambda value: value.isoformat(" "),
    decoder=lambda text: dateutil.parser.parse(text),
)
class RbacKwargs(typing.TypedDict, total=False):
    """Keyword options understood by define_auth_rbac_model (all keys optional at the type level)."""

    # values accepted for identity.object_type
    allowed_types: list[str]
    # pydal migrate flag for the identity, membership and permission tables
    migrate: bool
class Identity(typing.Protocol):
    """Structural type of a row from the ``identity`` table."""

    # 36-character uuid string identifying the user / group / item
    object_id: str
    object_type: str
    created: dt.datetime
    email: str
    # also used as short code for groups
    firstname: str
    fullname: str
    # "<salt>:<hex digest>" as produced by Password.encode
    encoded_password: str
def define_auth_rbac_model(db: DAL, other_args: RbacKwargs):
    """
    Define the RBAC tables and the two recursive-membership views on ``db``.

    Tables: identity, membership, permission. Views (always migrate=False):
    recursive_memberships, recursive_members. The views are not created here;
    presumably they already exist in the database. TODO confirm where they are created.

    :param db: dal db connection
    :param other_args: ``allowed_types`` (read with ``[]``, so effectively required)
        restricts identity.object_type; ``migrate`` (default False) is passed to
        pydal for the three tables.
    :raises KeyError: when ``allowed_types`` is missing from ``other_args``
    """
    migrate = other_args.get("migrate", False)

    db.define_table(
        "identity",
        # std uuid from uuid libs are 36 chars long.
        # The default is a callable so every insert gets a fresh uuid: a plain
        # str(uuid.uuid4()) is evaluated once, at definition time, giving every
        # defaulted row the same value (which then violates unique=True).
        Field("object_id", "string", length=36, unique=True, notnull=True, default=lambda: str(uuid.uuid4())),
        Field("object_type", "string", requires=IS_IN_LIST(other_args["allowed_types"])),
        Field("created", "datetime", default=dt.datetime.now),
        # email needn't be unique, groups can share email addresses, and with people too
        Field("email", "string"),
        Field("firstname", "string", comment="also used as short code for groups"),
        Field("fullname", "string"),
        Field("encoded_password", "string"),
        migrate=migrate,
    )

    db.define_table(
        "membership",
        # both are effectively: reference:identity.object_id
        Field("subject", "string", length=36, notnull=True),
        Field("member_of", "string", length=36, notnull=True),
        # Field('starts','datetime', default=DEFAULT_STARTS),
        # Field('ends','datetime', default=DEFAULT_ENDS),
        migrate=migrate,
    )

    db.define_table(
        "permission",
        Field("privilege", "string", length=20),
        # reference:identity.object_id
        Field("identity_object_id", "string", length=36),
        Field("target_object_id", "string", length=36),
        # Field('scope'), lets bail scope for now. every one needs a rule for everything
        # just to make sure, no 'wildcards' and 'every dossier for org x' etc ...
        Field("starts", type=my_datetime, default=DEFAULT_STARTS),
        Field("ends", type=my_datetime, default=DEFAULT_ENDS),
        migrate=migrate,
    )

    db.define_table(
        "recursive_memberships",
        Field("root"),
        Field("object_id"),
        Field("object_type"),
        Field("level", "integer"),
        Field("email"),
        Field("firstname"),
        Field("fullname"),
        migrate=False,  # view
        primarykey=["root", "object_id"],  # composed, no primary key
    )
    db.define_table(
        "recursive_members",
        Field("root"),
        Field("object_id"),
        Field("object_type"),
        Field("level", "integer"),
        Field("email"),
        Field("firstname"),
        Field("fullname"),
        migrate=False,  # view
        primarykey=["root", "object_id"],  # composed, no primary key
    )
def add_identity(
    db: DAL,
    email: str,
    member_of: list[IdentityKey],
    name: str = None,
    firstname: str = None,
    fullname: str = None,
    password: str = None,
    gid: str | UUID = None,
    object_type: ObjectTypes = None,
) -> str:
    """
    Create an identity and add it to the given groups.

    The parameters ``name`` and ``firstname`` are synonyms; a truthy ``name`` wins.

    :param db: dal db connection
    :param email: stored lowercased and stripped
    :param member_of: keys of groups to join; keys that do not resolve to a group are skipped
    :param gid: object_id to use (str or UUID); a fresh uuid4 when omitted
    :param object_type: required, e.g. "user", "group" or "item"
    :return: the object_id of the new identity, as str
    :raises ValueError: when ``object_type`` is missing or the insert fails validation
    """
    email = email.lower().strip()
    if object_type is None:
        raise ValueError("object_type parameter expected")

    # Keep the object_id as text. key_lookup_query only recognises UUID *strings*,
    # so handing a uuid.UUID instance to add_membership below would fail to resolve.
    object_id = str(gid) if gid else str(uuid.uuid4())

    result = db.identity.validate_and_insert(
        object_id=object_id,
        object_type=object_type,
        email=email,
        firstname=name or firstname or None,
        fullname=fullname,
        # NOTE(review): a None password is encoded as the literal text "None", so such
        # an identity can be authenticated with the password "None". Storing None
        # instead needs authenticate_user / Password.validate to accept a missing hash.
        encoded_password=Password.encode(password),
    )
    errors = getattr(result, "errors", None)
    if errors:
        # Without this the caller would receive an object_id that was never stored.
        raise ValueError(f"could not add identity {object_id}: {errors}")
    db.commit()

    for key in member_of or ():
        group_oid = key_lookup(db, key, "group")
        if group_oid is None:
            # Unknown group: skip it. Passing None on would turn into a
            # "firstname IS NULL" lookup that can match an unrelated group.
            continue
        # add_membership commits on its own.
        add_membership(db, identity_key=object_id, group_key=group_oid)

    return object_id
def add_group(db: DAL, email: str, name: str, member_of: list[IdentityKey]):
    """Create an identity of object_type "group" named ``name`` and return its object_id."""
    return add_identity(
        db,
        email=email,
        member_of=member_of,
        name=name,
        object_type="group",
    )
def remove_identity(db: DAL, object_id: IdentityKey):
    """
    Delete the identity whose object_id equals ``object_id`` and commit.

    ``object_id`` is compared directly with identity.object_id; other key forms are not resolved.

    :return: True when at least one row was deleted
    """
    # TODO: remove permissions and group memberships as well
    deleted_count = db(db.identity.object_id == object_id).delete()
    db.commit()
    return deleted_count > 0
def get_identity(db: DAL, key: IdentityKey, object_type: ObjectTypes = None):
    """
    Fetch the first identity row matching ``key``.

    :param db: dal db connection
    :param key: email, numeric id, object_id or firstname (see key_lookup_query)
    :param object_type: optional restriction on identity.object_type
    :return: the first matching row, or None when not found
    """
    matching = db(key_lookup_query(db, key, object_type)).select()
    return matching.first()
def get_user(db: DAL, key: IdentityKey):
    """
    Fetch an identity of object_type "user".

    :param db: dal db connection
    :param key: email, numeric id, object_id or firstname
    :return: the first matching user row, or None when not found
    """
    return get_identity(db, key, "user")
def get_group(db: DAL, key: IdentityKey):
    """
    Fetch an identity of object_type "group".

    :param db: dal db connection
    :param key: the group's name (firstname / short code), numeric id, object_id or email address
    :return: the first matching group row, or None when not found
    """
    return get_identity(db, key, "group")
def authenticate_user(db: DAL, password: str = None, user: Identity = None, key: IdentityKey = None):
    """
    Check ``password`` against a user's stored password hash.

    Pass either the user row as ``user`` or a lookup ``key`` (resolved with get_user).

    :return: True when the password matches. False for an empty password, when neither
        ``user`` nor ``key`` is given, for an unknown user, or for a user without a
        stored hash.
    """
    if not password:
        return False
    if not user:
        if key is None:
            # get_user(db, None) would look for a user whose firstname IS NULL and
            # could authenticate against an arbitrary account.
            return False
        user = get_user(db, key)
        if not user:
            return False
    encoded = user.encoded_password
    if not encoded:
        return False
    return Password.validate(password, encoded)
def add_membership(db: DAL, identity_key: IdentityKey, group_key: IdentityKey):
    """
    Record ``identity_key`` as a direct member of the group ``group_key`` and commit.

    No row is inserted when the membership already exists.

    :param identity_key: any key accepted by key_lookup (email, id, object_id, firstname)
    :param group_key: key of an identity with object_type "group"
    :raises ValueError: when either key is None or does not resolve
    """
    # A None key would fall through to a "firstname == None" lookup and could match an unrelated row.
    identity_oid = key_lookup(db, identity_key) if identity_key is not None else None
    if identity_oid is None:
        # str() so that int / UUID keys produce this ValueError instead of a TypeError
        raise ValueError("invalid identity_oid key: " + str(identity_key))

    group = get_group(db, group_key) if group_key is not None else None
    if not group:
        raise ValueError("invalid group key: " + str(group_key))

    membership = db.membership
    query = (membership.subject == identity_oid) & (membership.member_of == group.object_id)
    if db(query).count() == 0:
        membership.validate_and_insert(
            subject=identity_oid,
            member_of=group.object_id,
        )
    db.commit()
def remove_membership(db: DAL, identity_key: IdentityKey, group_key: IdentityKey):
    """
    Delete the direct membership of ``identity_key`` in the group ``group_key`` and commit.

    :return: number of membership rows deleted (0 when they were not members)
    :raises ValueError: when either key is None or does not resolve. This mirrors
        add_membership; previously an AttributeError escaped here.
    """
    # A None key would fall through to a "firstname == None" lookup and could match an unrelated row.
    identity = get_identity(db, identity_key) if identity_key is not None else None
    if not identity:
        raise ValueError("invalid identity key: " + str(identity_key))

    group = get_group(db, group_key) if group_key is not None else None
    if not group:
        raise ValueError("invalid group key: " + str(group_key))

    membership = db.membership
    query = (membership.subject == identity.object_id) & (membership.member_of == group.object_id)
    deleted = db(query).delete()
    db.commit()
    return deleted
def get_memberships(db: DAL, object_id: IdentityKey, bare: bool = True):
    """
    Select rows of the ``recursive_memberships`` view whose root is ``object_id``.

    :param object_id: compared verbatim with the view's root column (not resolved via key_lookup)
    :param bare: when True only object_id and object_type are selected, otherwise all columns
    """
    view = db.recursive_memberships
    rows_for_root = db(view.root == object_id)
    if bare:
        return rows_for_root.select(view.object_id, view.object_type)
    return rows_for_root.select()
def get_members(db: DAL, object_id: IdentityKey, bare: bool = True):
    """
    Select rows of the ``recursive_members`` view whose root is ``object_id``.

    :param object_id: compared verbatim with the view's root column (not resolved via key_lookup)
    :param bare: when True only object_id and object_type are selected, otherwise all columns
    """
    view = db.recursive_members
    rows_for_root = db(view.root == object_id)
    if bare:
        return rows_for_root.select(view.object_id, view.object_type)
    return rows_for_root.select()
def add_permission(
    db: DAL,
    identity_key: IdentityKey,
    target_oid: IdentityKey,
    privilege: str,
    starts: dt.datetime | str = DEFAULT_STARTS,
    ends: dt.datetime | str = DEFAULT_ENDS,
):
    """
    Grant ``privilege`` on ``target_oid`` to ``identity_key`` for the window [starts, ends] and commit.

    Nothing is inserted (a notice is printed instead) when has_permission already
    reports the privilege at ``starts``.

    :param identity_key: any key accepted by key_lookup
    :param target_oid: stored verbatim as permission.target_object_id
    :param starts: datetime or parseable string
    :param ends: datetime or parseable string
    :raises ValueError: when ``identity_key`` does not resolve to an identity
    """
    identity_oid = key_lookup(db, identity_key)
    if identity_oid is None:
        # Inserting would create a permission with a NULL identity_object_id.
        raise ValueError("invalid identity key: " + str(identity_key))

    starts = unstr_datetime(starts)
    ends = unstr_datetime(ends)

    if has_permission(db, identity_oid, target_oid, privilege, when=starts):
        # Permission already granted: skip it.
        # (The previous str.format(**locals()) referenced a name that is not a local, raising KeyError.)
        print(f"{privilege} permission already granted to {identity_key} on {target_oid} @ {starts} ")
        return

    db.permission.validate_and_insert(
        privilege=privilege,
        identity_object_id=identity_oid,
        target_object_id=target_oid,
        starts=starts,
        ends=ends,
    )
    db.commit()
def remove_permission(
    db: DAL, identity_key: IdentityKey, target_oid: IdentityKey, privilege: str, when: dt.datetime | str = DEFAULT
):
    """
    Delete the ``privilege`` permissions of ``identity_key`` on ``target_oid`` that are valid at ``when``, then commit.

    Only permissions granted directly to the identity are removed; memberships are not expanded.

    :param when: datetime or parseable string; defaults to now
    :return: True when at least one permission row was deleted
    """
    identity_oid = key_lookup(db, identity_key)
    if identity_oid is None:
        # Unknown identity: nothing to remove. Querying with None would instead match
        # permission rows whose identity_object_id IS NULL.
        return False

    when = dt.datetime.now() if when is DEFAULT else unstr_datetime(when)

    permission = db.permission
    query = (
        (permission.identity_object_id == identity_oid)
        & (permission.target_object_id == target_oid)
        & (permission.privilege == privilege)
        & (permission.starts <= when)
        & (permission.ends >= when)
    )
    removed = db(query).delete() > 0
    db.commit()
    return removed
def with_alias(db: DAL, source: Table, alias: str):
    """
    Register a shallow copy of ``source`` on ``db`` under the name ``alias`` and return it.

    This is a hand-rolled replacement for pydal's ``Table.with_alias``, presumably
    needed for keyed tables / views that use a composed ``primarykey``. Every field is
    cloned and re-bound to the copy, and the copy's ``_tablename`` is set to ``alias``.
    Side effect: ``db[alias]`` is overwritten.
    """
    aliased = copy.copy(source)
    aliased["ALL"] = SQLALL(aliased)
    aliased["_tablename"] = alias

    for name in aliased.fields:
        field = source[name].clone()
        field.bind(aliased)
        aliased[name] = field

    if "id" in source and "id" not in aliased.fields:
        aliased["id"] = aliased[source.id.name]

    source_id = getattr(source, "_id", None)
    if source_id:
        aliased._id = aliased[source_id.name]

    db[alias] = aliased
    return aliased
def has_permission(
    db: DAL, user_or_group_key: IdentityKey, target_oid: IdentityKey, privilege: str, when: dt.datetime | str = DEFAULT
):
    """
    Return True when ``privilege`` is granted from the identity's side to the target's side at ``when``.

    A permission row counts when:

    - its identity_object_id is any object_id that recursive_memberships lists for the
      identity's root;
    - its target_object_id is any object_id that recursive_memberships lists for
      ``target_oid``;
    - its [starts, ends] window contains ``when``.

    The object_ids listed for a root are presumably the root itself plus the groups
    it belongs to; confirm against the view definition.

    :param user_or_group_key: any key accepted by key_lookup
    :param when: datetime or parseable string; defaults to now
    Side effect: registers the aliases "left" and "right" on ``db`` (via with_alias).
    """
    root_oid = key_lookup(db, user_or_group_key)
    moment = dt.datetime.now() if when is DEFAULT else unstr_datetime(when)

    perm = db.permission
    # Ugly hack: pydal's own with_alias does not cope with the keyed view, so the
    # self-join uses hand-made aliases ("left" = holder side, "right" = target side).
    holder = with_alias(db, db.recursive_memberships, "left")
    target = with_alias(db, db.recursive_memberships, "right")

    query = (
        (holder.root == root_oid)
        & (target.root == target_oid)
        & (perm.identity_object_id == holder.object_id)
        & (perm.target_object_id == target.object_id)
        & (perm.privilege == privilege)
        & (perm.starts <= moment)
        & (perm.ends >= moment)
    )
    return db(query).count() > 0