Coverage for src/typedal/caching.py: 100%
123 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-04 18:50 +0100
1"""
2Helpers to facilitate db-based caching.
3"""
5import contextlib
6import hashlib
7import json
8from datetime import datetime, timedelta, timezone
9from typing import Any, Iterable, Mapping, Optional, TypeVar
11import dill # nosec
12from pydal.objects import Field, Rows, Set
14from .core import TypedField, TypedRows, TypedTable
17def get_now(tz: timezone = timezone.utc) -> datetime:
18 """
19 Get the default datetime, optionally in a specific timezone.
20 """
21 return datetime.now(tz)
class _TypedalCache(TypedTable):
    """
    Internal table to store previously loaded models.
    """

    # Lookup key identifying the cached query result.
    key: TypedField[str]
    # dill-serialized TypedRows instance (see save_to_cache / _load_from_cache).
    data: TypedField[bytes]
    # When the entry was stored; defaults to the current UTC time.
    cached_at = TypedField(datetime, default=get_now)
    # Moment after which the entry is considered stale; None means no expiry.
    expires_at: TypedField[datetime | None]
class _TypedalCacheDependency(TypedTable):
    """
    Internal table that stores dependencies to invalidate cache when a related table is updated.
    """

    # The cache entry that depends on the (table, idx) row below.
    entry: TypedField[_TypedalCache]
    # Name of the table the cached result depends on.
    table: TypedField[str]
    # Row id within that table.
    idx: TypedField[int]
def prepare(field: Any) -> str:
    """
    Prepare data to be used in a cache key.

    Sorting and stringifying the data means queries that differ only syntactically
    still map to the same cache key when they are semantically identical.
    """
    if isinstance(field, str):
        return field
    if isinstance(field, (dict, Mapping)):
        normalized = {str(key): prepare(value) for key, value in field.items()}
        return json.dumps(normalized, sort_keys=True)
    if isinstance(field, Iterable):
        parts = sorted(prepare(item) for item in field)
        return ",".join(parts)
    if isinstance(field, bool):
        return str(int(field))
    return str(field)
def create_cache_key(*fields: Any) -> str:
    """
    Turn any fields of data into a single string key.
    """
    parts = [prepare(field) for field in fields]
    return "|".join(parts)
72def hash_cache_key(cache_key: str | bytes) -> str:
73 """
74 Hash the input cache key with SHA 256.
75 """
76 h = hashlib.sha256()
77 h.update(cache_key.encode() if isinstance(cache_key, str) else cache_key)
78 return h.hexdigest()
def create_and_hash_cache_key(*fields: Any) -> tuple[str, str]:
    """
    Combine the input fields into one key and hash it with SHA 256.

    Returns both the plain key and its hash.
    """
    plain = create_cache_key(*fields)
    hashed = hash_cache_key(plain)
    return plain, hashed
# A dependency is identified by a table name plus a row id within that table.
DependencyTuple = tuple[str, int]  # table + id
DependencyTupleSet = set[DependencyTuple]
def _get_table_name(field: Field) -> str:
    """
    Get the table name from a field or alias.

    E.g. a table string like "users AS u" resolves to "users".
    """
    full_name = str(field._table)
    return full_name.split(" AS ")[0].strip()
def _get_dependency_ids(rows: Rows, dependency_keys: list[tuple[Field, str]]) -> DependencyTupleSet:
    """
    Collect a (table, id) pair for every truthy dependency field value in the rows.
    """
    found: DependencyTupleSet = set()
    for row in rows:
        for field, table in dependency_keys:
            row_id = row[field]
            if row_id:
                found.add((table, row_id))

    return found
def _determine_dependencies_auto(_: TypedRows[Any], rows: Rows) -> DependencyTupleSet:
    """
    Automatically treat every `<table>.id` column in the result as a cache dependency.
    """
    keys = [
        (field, _get_table_name(field))
        for field in rows.fields
        if str(field).endswith(".id")
    ]
    return _get_dependency_ids(rows, keys)
def _determine_dependencies(instance: TypedRows[Any], rows: Rows, depends_on: list[Any]) -> DependencyTupleSet:
    """
    Resolve the cache dependencies for a result set.

    Without an explicit `depends_on` list, fall back to auto-detecting `.id` columns.
    Bare field names in `depends_on` are qualified with the model's table name.
    """
    if not depends_on:
        return _determine_dependencies_auto(instance, rows)

    targets = set()
    for dependency in depends_on:
        if "." not in dependency:
            dependency = f"{instance.model._table}.{dependency}"
        targets.add(str(dependency))

    keys = [
        (field, _get_table_name(field))
        for field in rows.fields
        if str(field) in targets
    ]
    return _get_dependency_ids(rows, keys)
def remove_cache(idx: int | Iterable[int], table: str) -> None:
    """
    Remove any cache entries that are dependant on one or multiple indices of a table.
    """
    # Normalize a single id into a list so `belongs` always receives an iterable.
    if not isinstance(idx, Iterable):
        idx = [idx]

    # Build a subquery (as SQL) selecting the cache entries that depend on any of these rows.
    related = (
        _TypedalCacheDependency.where(table=table).where(lambda row: row.idx.belongs(idx)).select("entry").to_sql()
    )

    # Delete those cache entries. NOTE(review): dependency rows of the deleted entries
    # are presumably cleaned up by the table hooks or a cascading delete — confirm.
    _TypedalCache.where(_TypedalCache.id.belongs(related)).delete()
def clear_cache() -> None:
    """
    Remove everything from the cache.
    """
    # Truncate the dependency table first so no dependency row points at a removed entry.
    _TypedalCacheDependency.truncate()
    _TypedalCache.truncate()
def clear_expired() -> int:
    """
    Remove all expired items from the cache.

    By default, expired items are only removed when trying to access them.

    Returns the number of removed cache entries.
    """
    now = get_now()
    # An entry is expired once its expiry moment is at or before `now`
    # (mirroring _load_from_cache, which evicts when `now >= expires`).
    # The previous `>` comparison was inverted: it deleted entries that
    # had NOT yet expired and kept the expired ones.
    return len(_TypedalCache.where(_TypedalCache.expires_at <= now).delete())
def _remove_cache(s: Set, tablename: str) -> None:
    """
    Used as the table._before_update and table._after_update for every TypeDAL table (on by default).
    """
    row_ids = s.select("id").column("id")
    remove_cache(row_ids, tablename)
# Type variable bound to TypedTable; keeps TypedRows generic through save_to_cache.
T_TypedTable = TypeVar("T_TypedTable", bound=TypedTable)
185def get_expire(
186 expires_at: Optional[datetime] = None, ttl: Optional[int | timedelta] = None, now: Optional[datetime] = None
187) -> datetime | None:
188 """
189 Based on an expires_at date or a ttl (in seconds or a time delta), determine the expire date.
190 """
191 now = now or get_now()
193 if expires_at and ttl:
194 raise ValueError("Please only supply an `expired at` date or a `ttl` in seconds!")
195 elif isinstance(ttl, timedelta):
196 return now + ttl
197 elif ttl:
198 return now + timedelta(seconds=ttl)
199 elif expires_at:
200 return expires_at
202 return None
def save_to_cache(
    instance: TypedRows[T_TypedTable],
    rows: Rows,
    expires_at: Optional[datetime] = None,
    ttl: Optional[int | timedelta] = None,
) -> TypedRows[T_TypedTable]:
    """
    Save a typedrows result to the database, and save dependencies from rows.

    You can call .cache(...) with dependent fields (e.g. User.id) or this function will determine them automatically.
    """
    db = rows.db
    # Only store when caching was requested: metadata must hold an enabled cache config with a key.
    if (c := instance.metadata.get("cache", {})) and c.get("enabled") and (key := c.get("key")):
        # Explicit expires_at/ttl arguments win; otherwise fall back to the config's expiry.
        expires_at = get_expire(expires_at=expires_at, ttl=ttl) or c.get("expires_at")

        deps = _determine_dependencies(instance, rows, c["depends_on"])

        # Serialize the whole TypedRows instance with dill so it can be restored as-is.
        entry = _TypedalCache.insert(
            key=key,
            data=dill.dumps(instance),
            expires_at=expires_at,
        )

        # One dependency row per (table, id) pair so updates to those rows invalidate this entry.
        _TypedalCacheDependency.bulk_insert([{"entry": entry, "table": table, "idx": idx} for table, idx in deps])

        db.commit()
        instance.metadata["cache"]["status"] = "fresh"
    return instance
def _load_from_cache(key: str) -> Any | None:
    """
    Load and deserialize the cache entry stored under `key`.

    Returns None when no entry exists; expired entries are deleted on access
    (lazy eviction) and also yield None.
    """
    if not (row := _TypedalCache.where(key=key).first()):
        return None

    now = get_now()

    # NOTE(review): assumes the stored expires_at comes back naive and represents UTC
    # (matching get_now's default) — confirm against the database adapter.
    expires = row.expires_at.replace(tzinfo=timezone.utc) if row.expires_at else None

    if expires and now >= expires:
        row.delete_record()
        return None

    inst = dill.loads(row.data)  # nosec
    # Mark the restored instance so callers can tell it came from cache, and when.
    inst.metadata["cache"]["status"] = "cached"
    inst.metadata["cache"]["cached_at"] = row.cached_at
    inst.metadata["cache"]["expires_at"] = row.expires_at
    return inst
def load_from_cache(key: str) -> Any | None:
    """
    If 'key' matches a non-expired row in the database, try to load the dill.

    If anything fails, return None.
    """
    try:
        return _load_from_cache(key)
    except Exception:
        # Deliberate best-effort: a failing cache lookup must never be fatal.
        return None  # pragma: no cover