Coverage for src/typedal/caching.py: 100%
169 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-08 16:34 +0200
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-08 16:34 +0200
1"""
2Helpers to facilitate db-based caching.
3"""
5import contextlib
6import hashlib
7import json
8import typing
9from datetime import datetime, timedelta, timezone
10from typing import Any, Iterable, Mapping, Optional, TypeVar
12import dill # nosec
13from pydal.objects import Field, Rows, Set
15from .core import TypedField, TypedRows, TypedTable
16from .types import Query
18if typing.TYPE_CHECKING: # pragma: no cover
19 from .core import TypeDAL
def get_now(tz: timezone = timezone.utc) -> datetime:
    """
    Return the current moment as a timezone-aware datetime.

    `tz` selects the timezone of the result and defaults to UTC.
    """
    current = datetime.now(tz=tz)
    return current
class _TypedalCache(TypedTable):
    """
    Internal table to store previously loaded models.
    """

    # Lookup key of the entry; `_load_from_cache` filters on this column.
    key: TypedField[str]
    # Serialized payload: written with `dill.dumps` in `save_to_cache`, read back with `dill.loads`.
    data: TypedField[bytes]
    # Moment the entry was stored; defaults to `get_now()` (UTC unless another tz is given).
    cached_at = TypedField(datetime, default=get_now)
    # Optional expiry moment; a falsy value (None) is treated as "never expires" when loading.
    expires_at: TypedField[datetime | None]
class _TypedalCacheDependency(TypedTable):
    """
    Internal table that stores dependencies to invalidate cache when a related table is updated.
    """

    # The cache entry that depends on the row identified by (table, idx).
    entry: TypedField[_TypedalCache]
    # Name of the table the cached data was derived from.
    table: TypedField[str]
    # ID of the row within `table`; `remove_cache` matches on this value.
    idx: TypedField[int]
def prepare(field: Any) -> str:
    """
    Prepare data to be used in a cache key.

    By sorting and stringifying data, queries can be syntactically different from each other
    but when semantically exactly the same will still be loaded from cache.

    Rules, checked in this order:
    - str: returned unchanged.
    - bytes / bytearray: hex-encoded; byte order is preserved.
    - mapping: keys stringified, values prepared recursively, dumped as JSON with sorted keys.
    - any other iterable: items prepared recursively, sorted and joined with ",".
    - bool: "1" or "0".
    - anything else: `str(field)`.
    """
    if isinstance(field, str):
        return field

    if isinstance(field, (bytes, bytearray)):
        # Byte strings are iterable, but they are ordered data: treating them as an unordered
        # collection of ints would give b"ab" and b"ba" the same key.
        return bytes(field).hex()

    if isinstance(field, Mapping):
        # dict is a Mapping, so one check covers both.
        data = {str(k): prepare(v) for k, v in field.items()}
        return json.dumps(data, sort_keys=True)

    if isinstance(field, Iterable):
        # Order of generic collections is deliberately ignored.
        return ",".join(sorted(prepare(item) for item in field))

    if isinstance(field, bool):
        return str(int(field))

    return str(field)
def create_cache_key(*fields: Any) -> str:
    """
    Build one string out of all given values.

    Every value is normalised with `prepare` and the results are joined with "|".
    """
    prepared_parts = [prepare(part) for part in fields]
    return "|".join(prepared_parts)
77def hash_cache_key(cache_key: str | bytes) -> str:
78 """
79 Hash the input cache key with SHA 256.
80 """
81 h = hashlib.sha256()
82 h.update(cache_key.encode() if isinstance(cache_key, str) else cache_key)
83 return h.hexdigest()
def create_and_hash_cache_key(*fields: Any) -> tuple[str, str]:
    """
    Build a cache key from the given values and hash it with SHA 256.

    Returns a `(plain key, hex digest)` tuple.
    """
    combined = create_cache_key(*fields)
    digest = hash_cache_key(combined)
    return combined, digest
# One dependency of a cache entry: the table name plus the id of a row in that table.
DependencyTuple = tuple[str, int]  # table + id
# Unique collection of dependencies, as produced by `_get_dependency_ids`.
DependencyTupleSet = set[DependencyTuple]
def _get_table_name(field: Field) -> str:
    """
    Return the name of the table a field belongs to.

    If the table's string form contains " AS " (an alias), only the part before it is used.
    """
    table_repr = str(field._table)
    name, _, _alias = table_repr.partition(" AS ")
    return name.strip()
def _get_dependency_ids(rows: Rows, dependency_keys: list[tuple[Field, str]]) -> DependencyTupleSet:
    """
    Collect `(table, id)` pairs for every row and every dependency field.

    Falsy ids (e.g. None for a missing joined row) are skipped.
    """
    return {
        (table, idx)
        for row in rows
        for field, table in dependency_keys
        if (idx := row[field])
    }
def _determine_dependencies_auto(_: TypedRows[Any], rows: Rows) -> DependencyTupleSet:
    """
    Derive dependencies from every selected field whose name ends in ".id".

    The first argument is unused; it mirrors the signature used by `_determine_dependencies`.
    """
    id_fields = [
        (candidate, _get_table_name(candidate)) for candidate in rows.fields if str(candidate).endswith(".id")
    ]
    return _get_dependency_ids(rows, id_fields)
def _determine_dependencies(instance: TypedRows[Any], rows: Rows, depends_on: list[Any]) -> DependencyTupleSet:
    """
    Collect the `(table, id)` pairs a cached result depends on.

    When `depends_on` is empty, all selected "*.id" fields are used (see `_determine_dependencies_auto`).
    Otherwise only the selected fields named in `depends_on` are used; a name without a "."
    is prefixed with the table of `instance.model`.
    """
    if not depends_on:
        return _determine_dependencies_auto(instance, rows)

    target_field_names = set()
    for field in depends_on:
        # Stringify before the "." check: `depends_on` entries may be field objects rather than
        # plain strings, and `"." in <object>` would depend on that object's container protocol.
        field_name = str(field)
        if "." not in field_name:
            field_name = f"{instance.model._table}.{field_name}"

        target_field_names.add(field_name)

    dependency_keys = [
        (selected, _get_table_name(selected)) for selected in rows.fields if str(selected) in target_field_names
    ]

    return _get_dependency_ids(rows, dependency_keys)
def remove_cache(idx: int | Iterable[int], table: str) -> None:
    """
    Delete every cache entry that depends on one or more row ids of `table`.

    `idx` may be a single id or an iterable of ids.
    """
    ids = idx if isinstance(idx, Iterable) else [idx]

    dependency_query = _TypedalCacheDependency.where(table=table).where(lambda row: row.idx.belongs(ids))
    # SQL of the subselect yielding the ids of the affected cache entries.
    related_entries_sql = dependency_query.select("entry").to_sql()

    _TypedalCache.where(_TypedalCache.id.belongs(related_entries_sql)).delete()
def clear_cache() -> None:
    """
    Empty the whole cache: dependencies first, then the entries themselves.
    """
    for cache_table in (_TypedalCacheDependency, _TypedalCache):
        cache_table.truncate()
def clear_expired() -> int:
    """
    Delete every cache entry whose expiry moment has passed and return how many were removed.

    By default, expired items are only removed when trying to access them.
    """
    now = get_now()

    # `!= None` is intentional: it is handed to the query layer, not evaluated as a Python boolean.
    has_expiry = _TypedalCache.where(_TypedalCache.expires_at != None)  # noqa: E711
    expired = has_expiry.where(_TypedalCache.expires_at < now)

    removed = expired.delete()
    return len(removed)
def _remove_cache(s: Set, tablename: str) -> None:
    """
    Invalidate cache entries for all rows matched by the set `s` of table `tablename`.

    Meant to be used as the table._before_update and table._after_update hook of every TypeDAL table
    (on by default); the registration itself is not visible in this module.
    """
    affected_ids = s.select("id").column("id")
    remove_cache(affected_ids, tablename)
# Type variable for any TypedTable subclass; used to type the rows handled by `save_to_cache`.
T_TypedTable = TypeVar("T_TypedTable", bound=TypedTable)
190def get_expire(
191 expires_at: Optional[datetime] = None, ttl: Optional[int | timedelta] = None, now: Optional[datetime] = None
192) -> datetime | None:
193 """
194 Based on an expires_at date or a ttl (in seconds or a time delta), determine the expire date.
195 """
196 now = now or get_now()
198 if expires_at and ttl:
199 raise ValueError("Please only supply an `expired at` date or a `ttl` in seconds!")
200 elif isinstance(ttl, timedelta):
201 return now + ttl
202 elif ttl:
203 return now + timedelta(seconds=ttl)
204 elif expires_at:
205 return expires_at
207 return None
def save_to_cache(
    instance: TypedRows[T_TypedTable],
    rows: Rows,
    expires_at: Optional[datetime] = None,
    ttl: Optional[int | timedelta] = None,
) -> TypedRows[T_TypedTable]:
    """
    Store a typedrows result in the cache table, together with the rows it depends on.

    Nothing is stored unless the instance's cache metadata is enabled and has a key.
    Dependencies come from the `depends_on` metadata (e.g. set via .cache(User.id)) or are detected automatically.
    The (unchanged or freshly marked) instance is returned either way.
    """
    db = rows.db

    cache_meta = instance.metadata.get("cache", {})
    key = cache_meta.get("key") if (cache_meta and cache_meta.get("enabled")) else None
    if not key:
        return instance

    # Explicit arguments win; otherwise fall back to an expiry stored in the metadata.
    expires_at = get_expire(expires_at=expires_at, ttl=ttl) or cache_meta.get("expires_at")

    dependencies = _determine_dependencies(instance, rows, cache_meta["depends_on"])

    entry = _TypedalCache.insert(
        key=key,
        data=dill.dumps(instance),
        expires_at=expires_at,
    )

    dependency_rows = [{"entry": entry, "table": table, "idx": idx} for table, idx in dependencies]
    _TypedalCacheDependency.bulk_insert(dependency_rows)

    db.commit()
    # Set after serializing, so the stored copy does not carry the "fresh" status.
    instance.metadata["cache"]["status"] = "fresh"
    return instance
def _load_from_cache(key: str, db: "TypeDAL") -> Any | None:
    """
    Load and re-attach a cached instance for `key`, or return None when missing or expired.

    An expired row is deleted on access. Errors are not handled here; see `load_from_cache`.
    """
    cached = _TypedalCache.where(key=key).first()
    if not cached:
        return None

    current_time = get_now()

    expiry = cached.expires_at
    if expiry:
        # The stored value is interpreted as UTC before comparing against the aware 'now'.
        expiry = expiry.replace(tzinfo=timezone.utc)
        if current_time >= expiry:
            cached.delete_record()
            return None

    # Deserializing with dill can execute arbitrary code; the data column must stay trusted.
    inst = dill.loads(cached.data)  # nosec

    cache_meta = inst.metadata["cache"]
    cache_meta["status"] = "cached"
    cache_meta["cached_at"] = cached.cached_at
    cache_meta["expires_at"] = cached.expires_at

    # Re-bind to the live database and its registered model class.
    inst.db = db
    inst.model = db._class_map[inst.model]
    inst.model._setup_instance_methods(inst.model)  # type: ignore
    return inst
264def load_from_cache(key: str, db: "TypeDAL") -> Any | None:
265 """
266 If 'key' matches a non-expired row in the database, try to load the dill.
268 If anything fails, return None.
269 """
270 with contextlib.suppress(Exception):
271 return _load_from_cache(key, db)
273 return None # pragma: no cover
276def humanize_bytes(size: int | float) -> str:
277 """
278 Turn a number of bytes into a human-readable version (e.g. 124 GB).
279 """
280 if not size:
281 return "0"
283 suffixes = ["B", "KB", "MB", "GB", "TB", "PB"] # List of suffixes for different magnitudes
284 suffix_index = 0
286 while size > 1024 and suffix_index < len(suffixes) - 1:
287 suffix_index += 1
288 size /= 1024.0
290 return f"{size:.2f} {suffixes[suffix_index]}"
def _expired_and_valid_query() -> tuple[str, str]:
    """
    Build two SQL subselects of cache entry ids: `(expired, valid)`.

    'Valid' is defined as every entry that is not in the expired subselect, so entries without expiry count as valid.
    """

    def _is_expired(row: Any) -> Any:
        return (row.expires_at < get_now()) & (row.expires_at != None)  # noqa: E711

    expired_sql = _TypedalCache.where(_is_expired).select(_TypedalCache.id).to_sql()

    not_expired = ~_TypedalCache.id.belongs(expired_sql)
    valid_sql = _TypedalCache.where(not_expired).select(_TypedalCache.id).to_sql()

    return expired_sql, valid_sql
# Payload type of a `Stats` bucket (RowStats, TableStats or GenericStats in this module).
T = typing.TypeVar("T")
# Stats wrapper: the same payload computed for all entries, the non-expired ones and the expired ones.
Stats = typing.TypedDict("Stats", {"total": T, "valid": T, "expired": T})

# Payload of `row_stats`: number of distinct cache entries depending on one specific row.
RowStats = typing.TypedDict(
    "RowStats",
    {
        "Dependent Cache Entries": int,
    },
)
def _row_stats(db: "TypeDAL", table: str, query: Query) -> RowStats:
    """
    Count the distinct cache entries that match `query` and depend on `table`.
    """
    entry_field = _TypedalCacheDependency.entry
    count_field = entry_field.count()

    scoped_query = query & (_TypedalCacheDependency.table == table)
    # One result row per cache entry, thanks to the groupby.
    stats: TypedRows[_TypedalCacheDependency] = db(scoped_query).select(
        entry_field, count_field, groupby=entry_field
    )

    result: RowStats = {
        "Dependent Cache Entries": len(stats),
    }
    return result
def row_stats(db: "TypeDAL", table: str, row_id: str) -> Stats[RowStats]:
    """
    Gather cache statistics for one row (by ID) of a table, split into total / valid / expired.
    """
    expired_items, valid_items = _expired_and_valid_query()

    for_this_row = _TypedalCacheDependency.idx == row_id

    total = _row_stats(db, table, for_this_row)
    valid = _row_stats(db, table, _TypedalCacheDependency.entry.belongs(valid_items) & for_this_row)
    expired = _row_stats(db, table, _TypedalCacheDependency.entry.belongs(expired_items) & for_this_row)

    return {
        "total": total,
        "valid": valid,
        "expired": expired,
    }
# Payload of `table_stats`:
# - "Dependent Cache Entries": number of distinct cache entries depending on the table;
# - "Associated Table IDs": sum of the per-entry dependency counts for that table.
TableStats = typing.TypedDict(
    "TableStats",
    {
        "Dependent Cache Entries": int,
        "Associated Table IDs": int,
    },
)
def _table_stats(db: "TypeDAL", table: str, query: Query) -> TableStats:
    """
    Count cache entries depending on `table` (within `query`) and sum their dependency rows.
    """
    entry_field = _TypedalCacheDependency.entry
    count_field = entry_field.count()

    scoped_query = query & (_TypedalCacheDependency.table == table)
    # Grouped per cache entry: each result row carries that entry's number of dependency rows.
    stats: TypedRows[_TypedalCacheDependency] = db(scoped_query).select(
        entry_field, count_field, groupby=entry_field
    )

    entry_count = len(stats)
    dependency_total = sum(stats.column(count_field))
    return {
        "Dependent Cache Entries": entry_count,
        "Associated Table IDs": dependency_total,
    }
def table_stats(db: "TypeDAL", table: str) -> Stats[TableStats]:
    """
    Gather cache statistics for a whole table, split into total / valid / expired.
    """
    expired_items, valid_items = _expired_and_valid_query()

    # `id > 0` serves as the match-everything condition for the total bucket.
    total = _table_stats(db, table, _TypedalCacheDependency.id > 0)
    valid = _table_stats(db, table, _TypedalCacheDependency.entry.belongs(valid_items))
    expired = _table_stats(db, table, _TypedalCacheDependency.entry.belongs(expired_items))

    return {
        "total": total,
        "valid": valid,
        "expired": expired,
    }
# Payload of `calculate_stats`:
# - "entries": number of cache entries;
# - "dependencies": number of dependency rows belonging to those entries;
# - "size": summed length of the stored data, formatted by `humanize_bytes`.
GenericStats = typing.TypedDict(
    "GenericStats",
    {
        "entries": int,
        "dependencies": int,
        "size": str,
    },
)
def _calculate_stats(db: "TypeDAL", query: Query) -> GenericStats:
    """
    Compute entry count, dependency count and total payload size for the cache entries matching `query`.
    """
    total_length = _TypedalCache.data.len().sum()
    size_row = db(query).select(total_length).first()

    # No result row means there is nothing to measure.
    size = size_row[total_length] if size_row else 0  # type: ignore

    entry_count = _TypedalCache.where(query).count()
    dependency_count = db(_TypedalCacheDependency.entry.belongs(query)).count()

    return {
        "entries": entry_count,
        "dependencies": dependency_count,
        "size": humanize_bytes(size),
    }
def calculate_stats(db: "TypeDAL") -> Stats[GenericStats]:
    """
    Gather overall cache statistics, split into total / valid / expired.
    """
    expired_items, valid_items = _expired_and_valid_query()

    # `id > 0` serves as the match-everything condition for the total bucket.
    total = _calculate_stats(db, _TypedalCache.id > 0)
    valid = _calculate_stats(db, _TypedalCache.id.belongs(valid_items))
    expired = _calculate_stats(db, _TypedalCache.id.belongs(expired_items))

    return {
        "total": total,
        "valid": valid,
        "expired": expired,
    }