Coverage for src/typedal/caching.py: 100%
127 statements
coverage.py v7.3.2, created at 2023-12-18 13:49 +0100
1"""
2Helpers to facilitate db-based caching.
3"""
4import contextlib
5import hashlib
6import json
7import typing
8from datetime import datetime, timedelta, timezone
9from typing import Any, Iterable, Mapping, Optional, TypeVar
11import dill # nosec
12from pydal.objects import Field, Rows, Set
14from .core import TypedField, TypedRows, TypedTable
16if typing.TYPE_CHECKING: # pragma: no cover
17 from .core import TypeDAL


def get_now(tz: timezone = timezone.utc) -> datetime:
    """
    Get the default datetime, optionally in a specific timezone.
    """
    return datetime.now(tz)


class _TypedalCache(TypedTable):
    """
    Internal table to store previously loaded models.
    """

    key: TypedField[str]
    data: TypedField[bytes]
    cached_at = TypedField(datetime, default=get_now)
    expires_at: TypedField[datetime | None]


class _TypedalCacheDependency(TypedTable):
    """
    Internal table that stores dependencies to invalidate cache when a related table is updated.
    """

    entry: TypedField[_TypedalCache]
    table: TypedField[str]
    idx: TypedField[int]


def prepare(field: Any) -> str:
    """
    Prepare data to be used in a cache key.

    By sorting and stringifying data, queries that are syntactically different but semantically
    identical will still be loaded from the same cache entry.
    """
    if isinstance(field, str):
        return field
    elif isinstance(field, (dict, Mapping)):
        data = {str(k): prepare(v) for k, v in field.items()}
        return json.dumps(data, sort_keys=True)
    elif isinstance(field, Iterable):
        return ",".join(sorted([prepare(_) for _ in field]))
    elif isinstance(field, bool):
        return str(int(field))
    else:
        return str(field)


def create_cache_key(*fields: Any) -> str:
    """
    Turn any fields of data into a string.
    """
    return "|".join(prepare(field) for field in fields)


def hash_cache_key(cache_key: str | bytes) -> str:
    """
    Hash the input cache key with SHA 256.
    """
    h = hashlib.sha256()
    h.update(cache_key.encode() if isinstance(cache_key, str) else cache_key)
    return h.hexdigest()


def create_and_hash_cache_key(*fields: Any) -> tuple[str, str]:
    """
    Combine the input fields into one key and hash it with SHA 256.
    """
    key = create_cache_key(*fields)
    return key, hash_cache_key(key)
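

# Illustrative sketch: the key helpers above normalize semantically equal inputs so they map to
# the same cache entry, regardless of ordering. Assuming a plain interpreter session with this
# module imported:
#
#     >>> create_cache_key({"b": 2, "a": 1}) == create_cache_key({"a": 1, "b": 2})
#     True
#     >>> key, hashed = create_and_hash_cache_key("users", {"limit": 10})
#     >>> len(hashed)  # a sha256 hex digest is always 64 characters
#     64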


DependencyTuple = tuple[str, int]  # table + id
DependencyTupleSet = set[DependencyTuple]


def _get_table_name(field: Field) -> str:
    """
    Get the table name from a field or alias.
    """
    return str(field._table).split(" AS ")[0].strip()


def _get_dependency_ids(rows: Rows, dependency_keys: list[tuple[Field, str]]) -> DependencyTupleSet:
    # Collect a (table, id) pair for every row that has a value in one of the dependency fields.
    dependencies = set()
    for row in rows:
        for field, table in dependency_keys:
            if idx := row[field]:
                dependencies.add((table, idx))

    return dependencies


def _determine_dependencies_auto(_: TypedRows[Any], rows: Rows) -> DependencyTupleSet:
    # Automatic mode: every selected `<table>.id` field becomes a dependency key.
    dependency_keys = []
    for field in rows.fields:
        if str(field).endswith(".id"):
            table_name = _get_table_name(field)

            dependency_keys.append((field, table_name))

    return _get_dependency_ids(rows, dependency_keys)


def _determine_dependencies(instance: TypedRows[Any], rows: Rows, depends_on: list[Any]) -> DependencyTupleSet:
    # Explicit mode: only the fields listed in `depends_on` become dependency keys;
    # bare field names are qualified with the model's own table name.
    if not depends_on:
        return _determine_dependencies_auto(instance, rows)

    target_field_names = set()
    for field in depends_on:
        if "." not in field:
            field = f"{instance.model._table}.{field}"

        target_field_names.add(str(field))

    dependency_keys = []
    for field in rows.fields:
        if str(field) in target_field_names:
            table_name = _get_table_name(field)

            dependency_keys.append((field, table_name))

    return _get_dependency_ids(rows, dependency_keys)
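

# Illustrative sketch of the qualification step above, assuming a hypothetical model whose table
# is named "user": bare names in depends_on are prefixed with that table, while already-qualified
# names pass through unchanged before being matched against the selected fields.
#
#     >>> depends_on = ["id", "role.id"]
#     >>> sorted(name if "." in name else f"user.{name}" for name in depends_on)
#     ['role.id', 'user.id']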


def remove_cache(idx: int | Iterable[int], table: str) -> None:
    """
    Remove any cache entries that are dependent on one or multiple indices of a table.
    """
    if not isinstance(idx, Iterable):
        idx = [idx]

    related = (
        _TypedalCacheDependency.where(table=table).where(lambda row: row.idx.belongs(idx)).select("entry").to_sql()
    )

    _TypedalCache.where(_TypedalCache.id.belongs(related)).delete()
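

# Illustrative sketch, assuming a hypothetical "user" table: after rows 1 and 2 change, any cached
# result that recorded ("user", 1) or ("user", 2) as a dependency is deleted. TypeDAL normally
# wires this up via _remove_cache below, but it can also be called directly:
#
#     >>> remove_cache([1, 2], "user")  # drop every cache entry depending on these rows
#     >>> remove_cache(3, "user")       # a single id works as well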


def clear_cache() -> None:
    """
    Remove everything from the cache.
    """
    _TypedalCacheDependency.truncate()
    _TypedalCache.truncate()


def clear_expired() -> int:
    """
    Remove all expired items from the cache.

    By default, expired items are only removed when trying to access them.
    """
    now = get_now()
    # An entry is expired once its expires_at lies in the past; entries without an expiry (NULL) never match.
    return len(_TypedalCache.where(_TypedalCache.expires_at < now).delete())


def _remove_cache(s: Set, tablename: str) -> None:
    """
    Used as the _before_update and _after_update hook for every TypeDAL table (on by default).
    """
    indices = s.select("id").column("id")
    remove_cache(indices, tablename)


T_TypedTable = TypeVar("T_TypedTable", bound=TypedTable)


def get_expire(
    expires_at: Optional[datetime] = None, ttl: Optional[int | timedelta] = None, now: Optional[datetime] = None
) -> datetime | None:
    """
    Determine the expiry date based on an expires_at date or a ttl (in seconds or as a timedelta).
    """
    now = now or get_now()

    if expires_at and ttl:
        raise ValueError("Please supply either an `expires_at` date or a `ttl` in seconds, not both!")
    elif isinstance(ttl, timedelta):
        return now + ttl
    elif ttl:
        return now + timedelta(seconds=ttl)
    elif expires_at:
        return expires_at

    return None
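

# Illustrative sketch: at most one of `expires_at` and `ttl` may be given (both raises ValueError),
# and with neither the entry never expires.
#
#     >>> base = datetime(2023, 12, 18, 12, 0, tzinfo=timezone.utc)
#     >>> get_expire(ttl=3600, now=base)
#     datetime.datetime(2023, 12, 18, 13, 0, tzinfo=datetime.timezone.utc)
#     >>> get_expire(ttl=timedelta(days=1), now=base)
#     datetime.datetime(2023, 12, 19, 12, 0, tzinfo=datetime.timezone.utc)
#     >>> get_expire() is None
#     True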


def save_to_cache(
    instance: TypedRows[T_TypedTable],
    rows: Rows,
    expires_at: Optional[datetime] = None,
    ttl: Optional[int | timedelta] = None,
) -> TypedRows[T_TypedTable]:
    """
    Save a TypedRows result to the database, and save dependencies from rows.

    You can call .cache(...) with dependent fields (e.g. User.id), or this function will determine them automatically.
    """
    db = rows.db
    if (c := instance.metadata.get("cache", {})) and c.get("enabled") and (key := c.get("key")):
        expires_at = get_expire(expires_at=expires_at, ttl=ttl) or c.get("expires_at")

        deps = _determine_dependencies(instance, rows, c["depends_on"])

        entry = _TypedalCache.insert(
            key=key,
            data=dill.dumps(instance),
            expires_at=expires_at,
        )

        _TypedalCacheDependency.bulk_insert([{"entry": entry, "table": table, "idx": idx} for table, idx in deps])

        db.commit()
        instance.metadata["cache"]["status"] = "fresh"
    return instance
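

# Hedged usage sketch (the exact query-builder API may differ; the User model and call chain here
# are assumptions for illustration): callers normally opt in via .cache(...) on a query, and the
# builder then routes the result through save_to_cache()/load_from_cache() itself rather than
# these functions being called directly.
#
#     # rows = User.where(User.age > 18).cache(User.id, ttl=3600).collect()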


def _load_from_cache(key: str, db: "TypeDAL") -> Any | None:
    # Look up the cache entry; a missing or expired row counts as a miss.
    if not (row := _TypedalCache.where(key=key).first()):
        return None

    now = get_now()

    expires = row.expires_at.replace(tzinfo=timezone.utc) if row.expires_at else None

    if expires and now >= expires:
        row.delete_record()
        return None

    # Revive the pickled TypedRows instance and re-attach it to the live database.
    inst = dill.loads(row.data)  # nosec

    inst.metadata["cache"]["status"] = "cached"
    inst.metadata["cache"]["cached_at"] = row.cached_at
    inst.metadata["cache"]["expires_at"] = row.expires_at

    inst.db = db
    inst.model = db._class_map[inst.model]
    inst.model._setup_instance_methods(inst.model)  # type: ignore
    return inst


def load_from_cache(key: str, db: "TypeDAL") -> Any | None:
    """
    If 'key' matches a non-expired row in the database, try to load the dill.

    If anything fails, return None.
    """
    with contextlib.suppress(Exception):
        return _load_from_cache(key, db)

    return None  # pragma: no cover
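

# Illustrative sketch of a manual lookup, assuming `db` is a configured TypeDAL instance with the
# cache tables above defined: build a deterministic key, then try the cache first.
#
#     >>> _, key = create_and_hash_cache_key("user", {"active": True})
#     >>> cached = load_from_cache(key, db)  # None on a miss, an expired entry, or any error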