Coverage for src/typedal/caching.py: 100%

127 statements  

coverage.py v7.3.2, created at 2023-12-18 13:49 +0100

"""
Helpers to facilitate db-based caching.
"""

import contextlib
import hashlib
import json
import typing
from datetime import datetime, timedelta, timezone
from typing import Any, Iterable, Mapping, Optional, TypeVar

import dill  # nosec
from pydal.objects import Field, Rows, Set

from .core import TypedField, TypedRows, TypedTable

if typing.TYPE_CHECKING:  # pragma: no cover
    from .core import TypeDAL


def get_now(tz: timezone = timezone.utc) -> datetime:
    """
    Get the default datetime, optionally in a specific timezone.
    """
    return datetime.now(tz)


class _TypedalCache(TypedTable):
    """
    Internal table to store previously loaded models.
    """

    key: TypedField[str]
    data: TypedField[bytes]
    cached_at = TypedField(datetime, default=get_now)
    expires_at: TypedField[datetime | None]


class _TypedalCacheDependency(TypedTable):
    """
    Internal table that stores dependencies to invalidate cache when a related table is updated.
    """

    entry: TypedField[_TypedalCache]
    table: TypedField[str]
    idx: TypedField[int]


def prepare(field: Any) -> str:
    """
    Prepare data to be used in a cache key.

    By sorting and stringifying data, queries that are syntactically different
    but semantically identical will still be loaded from cache.
    """
    if isinstance(field, str):
        return field
    elif isinstance(field, (dict, Mapping)):
        data = {str(k): prepare(v) for k, v in field.items()}
        return json.dumps(data, sort_keys=True)
    elif isinstance(field, Iterable):
        return ",".join(sorted([prepare(_) for _ in field]))
    elif isinstance(field, bool):
        return str(int(field))
    else:
        return str(field)
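As a quick illustration (not part of the listed module): because dict keys are sorted and iterables are sorted before joining, reordered but semantically equal inputs produce identical key fragments.

# Illustrative usage of prepare(), assuming the package is importable as `typedal`.
from typedal.caching import prepare

assert prepare({"b": 2, "a": 1}) == prepare({"a": 1, "b": 2})
assert prepare({"x", "y"}) == prepare({"y", "x"})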

def create_cache_key(*fields: Any) -> str:
    """
    Turn any fields of data into a string.
    """
    return "|".join(prepare(field) for field in fields)


def hash_cache_key(cache_key: str | bytes) -> str:
    """
    Hash the input cache key with SHA-256.
    """
    h = hashlib.sha256()
    h.update(cache_key.encode() if isinstance(cache_key, str) else cache_key)
    return h.hexdigest()


def create_and_hash_cache_key(*fields: Any) -> tuple[str, str]:
    """
    Combine the input fields into one key and hash it with SHA-256.
    """
    key = create_cache_key(*fields)
    return key, hash_cache_key(key)
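Together these helpers turn arbitrary query parts into one deterministic, fixed-length key. A short sketch (again, not part of the module itself):

# Illustrative usage: build a combined key from several fields and hash it.
from typedal.caching import create_and_hash_cache_key

key, hashed = create_and_hash_cache_key("person", {"limit": 10}, ["name", "id"])
print(key)     # person|{"limit": "10"}|id,name
print(hashed)  # 64-character hexadecimal SHA-256 digest of the key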

DependencyTuple = tuple[str, int]  # table + id
DependencyTupleSet = set[DependencyTuple]


def _get_table_name(field: Field) -> str:
    """
    Get the table name from a field or alias.
    """
    return str(field._table).split(" AS ")[0].strip()


def _get_dependency_ids(rows: Rows, dependency_keys: list[tuple[Field, str]]) -> DependencyTupleSet:
    dependencies = set()
    for row in rows:
        for field, table in dependency_keys:
            if idx := row[field]:
                dependencies.add((table, idx))

    return dependencies


def _determine_dependencies_auto(_: TypedRows[Any], rows: Rows) -> DependencyTupleSet:
    dependency_keys = []
    for field in rows.fields:
        if str(field).endswith(".id"):
            table_name = _get_table_name(field)

            dependency_keys.append((field, table_name))

    return _get_dependency_ids(rows, dependency_keys)


def _determine_dependencies(instance: TypedRows[Any], rows: Rows, depends_on: list[Any]) -> DependencyTupleSet:
    if not depends_on:
        return _determine_dependencies_auto(instance, rows)

    target_field_names = set()
    for field in depends_on:
        if "." not in field:
            field = f"{instance.model._table}.{field}"

        target_field_names.add(str(field))

    dependency_keys = []
    for field in rows.fields:
        if str(field) in target_field_names:
            table_name = _get_table_name(field)

            dependency_keys.append((field, table_name))

    return _get_dependency_ids(rows, dependency_keys)
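The dependency helpers only need objects that behave like rows of field values. A minimal sketch, with plain dicts and string keys standing in for the pydal Rows and Field objects that the real callers pass:

# Hypothetical stand-in data to show what _get_dependency_ids() produces.
from typedal.caching import _get_dependency_ids

rows = [{"person.id": 1, "pet.id": 7}, {"person.id": 2, "pet.id": None}]
keys = [("person.id", "person"), ("pet.id", "pet")]

print(_get_dependency_ids(rows, keys))
# {("person", 1), ("person", 2), ("pet", 7)} -- rows without an id are skipped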

def remove_cache(idx: int | Iterable[int], table: str) -> None:
    """
    Remove any cache entries that are dependent on one or multiple indices of a table.
    """
    if not isinstance(idx, Iterable):
        idx = [idx]

    related = (
        _TypedalCacheDependency.where(table=table).where(lambda row: row.idx.belongs(idx)).select("entry").to_sql()
    )

    _TypedalCache.where(_TypedalCache.id.belongs(related)).delete()


def clear_cache() -> None:
    """
    Remove everything from the cache.
    """
    _TypedalCacheDependency.truncate()
    _TypedalCache.truncate()


def clear_expired() -> int:
    """
    Remove all expired items from the cache.

    By default, expired items are only removed when trying to access them.
    """
    now = get_now()
    return len(_TypedalCache.where(_TypedalCache.expires_at <= now).delete())


def _remove_cache(s: Set, tablename: str) -> None:
    """
    Used as the table._before_update and table._after_update for every TypeDAL table (on by default).
    """
    indices = s.select("id").column("id")
    remove_cache(indices, tablename)


T_TypedTable = TypeVar("T_TypedTable", bound=TypedTable)


def get_expire(
    expires_at: Optional[datetime] = None, ttl: Optional[int | timedelta] = None, now: Optional[datetime] = None
) -> datetime | None:
    """
    Based on an expires_at date or a ttl (in seconds or a timedelta), determine the expiry date.
    """
    now = now or get_now()

    if expires_at and ttl:
        raise ValueError("Please only supply an `expires_at` date or a `ttl` in seconds!")
    elif isinstance(ttl, timedelta):
        return now + ttl
    elif ttl:
        return now + timedelta(seconds=ttl)
    elif expires_at:
        return expires_at

    return None
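For instance, a ttl of 60 seconds yields a timestamp one minute past the given "now", while omitting both arguments means the entry never expires:

# Illustrative usage of get_expire().
from datetime import datetime, timedelta, timezone
from typedal.caching import get_expire

now = datetime(2024, 1, 1, tzinfo=timezone.utc)
assert get_expire(ttl=60, now=now) == now + timedelta(seconds=60)
assert get_expire(ttl=timedelta(hours=1), now=now) == now + timedelta(hours=1)
assert get_expire(now=now) is None  # no expires_at and no ttl: cache forever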

def save_to_cache(
    instance: TypedRows[T_TypedTable],
    rows: Rows,
    expires_at: Optional[datetime] = None,
    ttl: Optional[int | timedelta] = None,
) -> TypedRows[T_TypedTable]:
    """
    Save a TypedRows result to the database, and save dependencies from rows.

    You can call .cache(...) with dependent fields (e.g. User.id) or this function will determine them automatically.
    """
    db = rows.db
    if (c := instance.metadata.get("cache", {})) and c.get("enabled") and (key := c.get("key")):
        expires_at = get_expire(expires_at=expires_at, ttl=ttl) or c.get("expires_at")

        deps = _determine_dependencies(instance, rows, c["depends_on"])

        entry = _TypedalCache.insert(
            key=key,
            data=dill.dumps(instance),
            expires_at=expires_at,
        )

        _TypedalCacheDependency.bulk_insert([{"entry": entry, "table": table, "idx": idx} for table, idx in deps])

        db.commit()
        instance.metadata["cache"]["status"] = "fresh"
    return instance


def _load_from_cache(key: str, db: "TypeDAL") -> Any | None:
    if not (row := _TypedalCache.where(key=key).first()):
        return None

    now = get_now()

    expires = row.expires_at.replace(tzinfo=timezone.utc) if row.expires_at else None

    if expires and now >= expires:
        row.delete_record()
        return None

    inst = dill.loads(row.data)  # nosec

    inst.metadata["cache"]["status"] = "cached"
    inst.metadata["cache"]["cached_at"] = row.cached_at
    inst.metadata["cache"]["expires_at"] = row.expires_at

    inst.db = db
    inst.model = db._class_map[inst.model]
    inst.model._setup_instance_methods(inst.model)  # type: ignore
    return inst


def load_from_cache(key: str, db: "TypeDAL") -> Any | None:
    """
    If 'key' matches a non-expired row in the database, try to load the dill.

    If anything fails, return None.
    """
    with contextlib.suppress(Exception):
        return _load_from_cache(key, db)

    return None  # pragma: no cover
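Putting it together: a hedged sketch of the overall flow, based on the .cache(...) hook mentioned in save_to_cache's docstring. The model definition and query-builder calls below are assumptions for illustration, not a verbatim API reference.

# Hedged sketch; the exact TypeDAL builder API may differ.
from typedal import TypeDAL, TypedTable, TypedField

db = TypeDAL("sqlite:memory")

class Person(TypedTable):
    name: TypedField[str]

db.define(Person)

# First run: the TypedRows result is pickled with dill and stored in
# _TypedalCache, plus one _TypedalCacheDependency row per related record id.
people = Person.where(lambda row: row.name != "").cache(ttl=3600).collect()

# A semantically equal query later loads the pickled result via load_from_cache.
people_again = Person.where(lambda row: row.name != "").cache(ttl=3600).collect()

# Updating or deleting Person rows fires _remove_cache through the table's
# before/after update hooks, invalidating every dependent cache entry.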