Coverage for src/typedal/caching.py: 100%

169 statements  

coverage.py v7.5.1, created at 2024-08-05 19:10 +0200

1""" 

2Helpers to facilitate db-based caching. 

3""" 

4 

5import contextlib 

6import hashlib 

7import json 

8import typing 

9from datetime import datetime, timedelta, timezone 

10from typing import Any, Iterable, Mapping, Optional, TypeVar 

11 

12import dill # nosec 

13from pydal.objects import Field, Rows, Set 

14 

15from .core import TypedField, TypedRows, TypedTable 

16from .types import Query 

17 

18if typing.TYPE_CHECKING: 

19 from .core import TypeDAL 

20 

21 

22def get_now(tz: timezone = timezone.utc) -> datetime: 

23 """ 

24 Get the default datetime, optionally in a specific timezone. 

25 """ 

26 return datetime.now(tz) 

27 

28 

29class _TypedalCache(TypedTable): 

30 """ 

31 Internal table to store previously loaded models. 

32 """ 

33 

34 key: TypedField[str] 

35 data: TypedField[bytes] 

36 cached_at = TypedField(datetime, default=get_now) 

37 expires_at: TypedField[datetime | None] 

38 

39 

40class _TypedalCacheDependency(TypedTable): 

41 """ 

42 Internal table that stores dependencies to invalidate cache when a related table is updated. 

43 """ 

44 

45 entry: TypedField[_TypedalCache] 

46 table: TypedField[str] 

47 idx: TypedField[int] 

48 

49 

50def prepare(field: Any) -> str: 

51 """ 

52 Prepare data to be used in a cache key. 

53 

54 By sorting and stringifying data, queries can be syntactically different from each other \ 

55 but when semantically exactly the same will still be loaded from cache. 

56 """ 

57 if isinstance(field, str): 

58 return field 

59 elif isinstance(field, (dict, Mapping)): 

60 data = {str(k): prepare(v) for k, v in field.items()} 

61 return json.dumps(data, sort_keys=True) 

62 elif isinstance(field, Iterable): 

63 return ",".join(sorted([prepare(_) for _ in field])) 

64 elif isinstance(field, bool): 

65 return str(int(field)) 

66 else: 

67 return str(field) 

68 

69 

70def create_cache_key(*fields: Any) -> str: 

71 """ 

72 Turn any fields of data into a string. 

73 """ 

74 return "|".join(prepare(field) for field in fields) 

75 

76 

77def hash_cache_key(cache_key: str | bytes) -> str: 

78 """ 

79 Hash the input cache key with SHA 256. 

80 """ 

81 h = hashlib.sha256() 

82 h.update(cache_key.encode() if isinstance(cache_key, str) else cache_key) 

83 return h.hexdigest() 

84 

85 

86def create_and_hash_cache_key(*fields: Any) -> tuple[str, str]: 

87 """ 

88 Combine the input fields into one key and hash it with SHA 256. 

89 """ 

90 key = create_cache_key(*fields) 

91 return key, hash_cache_key(key) 
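
# Illustrative usage (editor's addition, not part of the original module):
# because `prepare` sorts mappings and iterables before stringifying them,
# semantically equivalent inputs yield identical keys and hashes.
#
#   >>> create_cache_key({"b": 2, "a": 1}) == create_cache_key({"a": 1, "b": 2})
#   True
#   >>> key, hashed = create_and_hash_cache_key("user", ["name", "id"], 10)
#   >>> key
#   'user|id,name|10'
#   >>> len(hashed)  # SHA-256 hex digest
#   64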



DependencyTuple = tuple[str, int]  # table + id
DependencyTupleSet = set[DependencyTuple]


def _get_table_name(field: Field) -> str:
    """
    Get the table name from a field or alias.
    """
    return str(field._table).split(" AS ")[0].strip()


def _get_dependency_ids(rows: Rows, dependency_keys: list[tuple[Field, str]]) -> DependencyTupleSet:
    """Collect (table, id) pairs for every dependency column that holds a value in the given rows."""
    dependencies = set()
    for row in rows:
        for field, table in dependency_keys:
            if idx := row[field]:
                dependencies.add((table, idx))

    return dependencies


def _determine_dependencies_auto(_: TypedRows[Any], rows: Rows) -> DependencyTupleSet:
    """Derive cache dependencies automatically from every selected `<table>.id` column."""
    dependency_keys = []
    for field in rows.fields:
        if str(field).endswith(".id"):
            table_name = _get_table_name(field)

            dependency_keys.append((field, table_name))

    return _get_dependency_ids(rows, dependency_keys)


def _determine_dependencies(instance: TypedRows[Any], rows: Rows, depends_on: list[Any]) -> DependencyTupleSet:
    """Resolve cache dependencies from `depends_on`, falling back to automatic detection of `.id` columns."""
    if not depends_on:
        return _determine_dependencies_auto(instance, rows)

    target_field_names = set()
    for field in depends_on:
        if "." not in field:
            field = f"{instance.model._table}.{field}"

        target_field_names.add(str(field))

    dependency_keys = []
    for field in rows.fields:
        if str(field) in target_field_names:
            table_name = _get_table_name(field)

            dependency_keys.append((field, table_name))

    return _get_dependency_ids(rows, dependency_keys)
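
# Illustrative note (editor's addition, not part of the original module): with
# no explicit `depends_on`, every selected `<table>.id` column becomes a
# dependency, so a cached join is invalidated when any referenced row changes.
# `typed_rows` and `raw_rows` below are hypothetical placeholders.
#
#   >>> _determine_dependencies(typed_rows, raw_rows, depends_on=[])  # doctest: +SKIP
#   {("user", 1), ("post", 7)}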



def remove_cache(idx: int | Iterable[int], table: str) -> None:
    """
    Remove any cache entries that are dependent on one or multiple indices of a table.
    """
    if not isinstance(idx, Iterable):
        idx = [idx]

    related = (
        _TypedalCacheDependency.where(table=table).where(lambda row: row.idx.belongs(idx)).select("entry").to_sql()
    )

    _TypedalCache.where(_TypedalCache.id.belongs(related)).delete()


def clear_cache() -> None:
    """
    Remove everything from the cache.
    """
    _TypedalCacheDependency.truncate()
    _TypedalCache.truncate()


def clear_expired() -> int:
    """
    Remove all expired items from the cache.

    By default, expired items are only removed when trying to access them.
    """
    now = get_now()
    return len(_TypedalCache.where(_TypedalCache.expires_at != None).where(_TypedalCache.expires_at < now).delete())
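
# Illustrative usage (editor's addition, not part of the original module):
# periodic cleanup, e.g. from a cron job or scheduler.
#
#   removed = clear_expired()
#   print(f"Purged {removed} expired cache entries.")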



def _remove_cache(s: Set, tablename: str) -> None:
    """
    Used as the table._before_update and table._after_update for every TypeDAL table (on by default).
    """
    indices = s.select("id").column("id")
    remove_cache(indices, tablename)
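
# Hedged sketch (editor's addition): roughly how such a hook could be wired up
# via pydal's callback lists. TypeDAL registers these automatically when
# caching is enabled (on by default), so this is illustrative only and the
# exact wiring may differ.
#
#   table._before_update.append(lambda s, _fields: _remove_cache(s, table._tablename))
#   table._before_delete.append(lambda s: _remove_cache(s, table._tablename))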



T_TypedTable = TypeVar("T_TypedTable", bound=TypedTable)


def get_expire(
    expires_at: Optional[datetime] = None, ttl: Optional[int | timedelta] = None, now: Optional[datetime] = None
) -> datetime | None:
    """
    Based on an expires_at date or a ttl (in seconds or a timedelta), determine the expiration date.
    """
    now = now or get_now()

    if expires_at and ttl:
        raise ValueError("Please supply either an `expires_at` date or a `ttl` in seconds, not both!")
    elif isinstance(ttl, timedelta):
        return now + ttl
    elif ttl:
        return now + timedelta(seconds=ttl)
    elif expires_at:
        return expires_at

    return None
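
# Illustrative behaviour (editor's addition, not part of the original module):
#
#   >>> base = datetime(2024, 1, 1, tzinfo=timezone.utc)
#   >>> get_expire(ttl=60, now=base)
#   datetime.datetime(2024, 1, 1, 0, 1, tzinfo=datetime.timezone.utc)
#   >>> get_expire(ttl=timedelta(hours=1), now=base) == base + timedelta(hours=1)
#   True
#   >>> get_expire(expires_at=base, ttl=60)  # supplying both raises ValueError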



def save_to_cache(
    instance: TypedRows[T_TypedTable],
    rows: Rows,
    expires_at: Optional[datetime] = None,
    ttl: Optional[int | timedelta] = None,
) -> TypedRows[T_TypedTable]:
    """
    Save a TypedRows result to the database, and save dependencies from rows.

    You can call .cache(...) with dependent fields (e.g. User.id), or this function will determine them automatically.
    """
    db = rows.db
    if (c := instance.metadata.get("cache", {})) and c.get("enabled") and (key := c.get("key")):
        expires_at = get_expire(expires_at=expires_at, ttl=ttl) or c.get("expires_at")

        deps = _determine_dependencies(instance, rows, c["depends_on"])

        entry = _TypedalCache.insert(
            key=key,
            data=dill.dumps(instance),
            expires_at=expires_at,
        )

        _TypedalCacheDependency.bulk_insert([{"entry": entry, "table": table, "idx": idx} for table, idx in deps])

        db.commit()
        instance.metadata["cache"]["status"] = "fresh"
    return instance


def _load_from_cache(key: str, db: "TypeDAL") -> Any | None:
    """Load the cached instance stored under `key`; expired entries are deleted and None is returned."""
    if not (row := _TypedalCache.where(key=key).first()):
        return None

    now = get_now()

    expires = row.expires_at.replace(tzinfo=timezone.utc) if row.expires_at else None

    if expires and now >= expires:
        row.delete_record()
        return None

    inst = dill.loads(row.data)  # nosec

    inst.metadata["cache"]["status"] = "cached"
    inst.metadata["cache"]["cached_at"] = row.cached_at
    inst.metadata["cache"]["expires_at"] = row.expires_at

    inst.db = db
    inst.model = db._class_map[inst.model]
    inst.model._setup_instance_methods(inst.model)  # type: ignore
    return inst


def load_from_cache(key: str, db: "TypeDAL") -> Any | None:
    """
    If 'key' matches a non-expired row in the database, try to load the dill-pickled result.

    If anything fails, return None.
    """
    with contextlib.suppress(Exception):
        return _load_from_cache(key, db)

    return None  # pragma: no cover
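
# Hedged usage sketch (editor's addition): the typical round trip goes through
# the query builder's `.cache(...)`, which fills in the metadata checked above;
# builder method names such as `.collect()` are assumptions and may differ
# between TypeDAL versions.
#
#   rows = User.where(User.age > 18).cache(ttl=300).collect()  # miss: query db, then save_to_cache()
#   rows = User.where(User.age > 18).cache(ttl=300).collect()  # hit: served by load_from_cache()
#   rows.metadata["cache"]["status"]  # "fresh" after the first call, "cached" after the second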



def humanize_bytes(size: int | float) -> str:
    """
    Turn a number of bytes into a human-readable version (e.g. 124 GB).
    """
    if not size:
        return "0"

    suffixes = ["B", "KB", "MB", "GB", "TB", "PB"]  # List of suffixes for different magnitudes
    suffix_index = 0

    while size > 1024 and suffix_index < len(suffixes) - 1:
        suffix_index += 1
        size /= 1024.0

    return f"{size:.2f} {suffixes[suffix_index]}"



def _expired_and_valid_query() -> tuple[str, str]:
    """Build SQL subqueries selecting the ids of expired and of non-expired cache entries."""
    expired_items = (
        _TypedalCache.where(lambda row: (row.expires_at < get_now()) & (row.expires_at != None))
        .select(_TypedalCache.id)
        .to_sql()
    )

    valid_items = _TypedalCache.where(~_TypedalCache.id.belongs(expired_items)).select(_TypedalCache.id).to_sql()

    return expired_items, valid_items


T = typing.TypeVar("T")
Stats = typing.TypedDict("Stats", {"total": T, "valid": T, "expired": T})

RowStats = typing.TypedDict(
    "RowStats",
    {
        "Dependent Cache Entries": int,
    },
)


def _row_stats(db: "TypeDAL", table: str, query: Query) -> RowStats:
    """Count the distinct cache entries that have a dependency on `table` matching `query`."""
    count_field = _TypedalCacheDependency.entry.count()
    stats: TypedRows[_TypedalCacheDependency] = db(query & (_TypedalCacheDependency.table == table)).select(
        _TypedalCacheDependency.entry, count_field, groupby=_TypedalCacheDependency.entry
    )
    return {
        "Dependent Cache Entries": len(stats),
    }


def row_stats(db: "TypeDAL", table: str, row_id: str) -> Stats[RowStats]:
    """
    Collect caching stats for a specific table row (by ID).
    """
    expired_items, valid_items = _expired_and_valid_query()

    query = _TypedalCacheDependency.idx == row_id

    return {
        "total": _row_stats(db, table, query),
        "valid": _row_stats(db, table, _TypedalCacheDependency.entry.belongs(valid_items) & query),
        "expired": _row_stats(db, table, _TypedalCacheDependency.entry.belongs(expired_items) & query),
    }


TableStats = typing.TypedDict(
    "TableStats",
    {
        "Dependent Cache Entries": int,
        "Associated Table IDs": int,
    },
)


def _table_stats(db: "TypeDAL", table: str, query: Query) -> TableStats:
    """Count distinct cache entries depending on `table` and the total number of row ids they reference."""
    count_field = _TypedalCacheDependency.entry.count()
    stats: TypedRows[_TypedalCacheDependency] = db(query & (_TypedalCacheDependency.table == table)).select(
        _TypedalCacheDependency.entry, count_field, groupby=_TypedalCacheDependency.entry
    )
    return {
        "Dependent Cache Entries": len(stats),
        "Associated Table IDs": sum(stats.column(count_field)),
    }


def table_stats(db: "TypeDAL", table: str) -> Stats[TableStats]:
    """
    Collect caching stats for a table.
    """
    expired_items, valid_items = _expired_and_valid_query()

    return {
        "total": _table_stats(db, table, _TypedalCacheDependency.id > 0),
        "valid": _table_stats(db, table, _TypedalCacheDependency.entry.belongs(valid_items)),
        "expired": _table_stats(db, table, _TypedalCacheDependency.entry.belongs(expired_items)),
    }


GenericStats = typing.TypedDict(
    "GenericStats",
    {
        "entries": int,
        "dependencies": int,
        "size": str,
    },
)


def _calculate_stats(db: "TypeDAL", query: Query) -> GenericStats:
    """Count cache entries and dependency rows matching `query`, plus the humanized total size of stored data."""
    sum_len_field = _TypedalCache.data.len().sum()
    size_row = db(query).select(sum_len_field).first()

    size = size_row[sum_len_field] if size_row else 0  # type: ignore

    return {
        "entries": _TypedalCache.where(query).count(),
        "dependencies": db(_TypedalCacheDependency.entry.belongs(query)).count(),
        "size": humanize_bytes(size),
    }


def calculate_stats(db: "TypeDAL") -> Stats[GenericStats]:
    """
    Collect generic caching stats.
    """
    expired_items, valid_items = _expired_and_valid_query()

    return {
        "total": _calculate_stats(db, _TypedalCache.id > 0),
        "valid": _calculate_stats(db, _TypedalCache.id.belongs(valid_items)),
        "expired": _calculate_stats(db, _TypedalCache.id.belongs(expired_items)),
    }
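
# Illustrative usage (editor's addition, not part of the original module):
# `db` is assumed to be a configured TypeDAL instance with caching enabled.
#
#   stats = calculate_stats(db)
#   stats["valid"]["entries"], stats["valid"]["size"]  # e.g. (12, '3.40 MB')
#   table_stats(db, "user")["total"]["Dependent Cache Entries"]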