Coverage for src/typedal/caching.py: 100%

123 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-12-04 18:50 +0100

1""" 

2Helpers to facilitate db-based caching. 

3""" 

4 

5import contextlib 

6import hashlib 

7import json 

8from datetime import datetime, timedelta, timezone 

9from typing import Any, Iterable, Mapping, Optional, TypeVar 

10 

11import dill # nosec 

12from pydal.objects import Field, Rows, Set 

13 

14from .core import TypedField, TypedRows, TypedTable 

15 

16 

def get_now(tz: timezone = timezone.utc) -> datetime:
    """
    Return the current moment as a timezone-aware datetime (UTC by default).
    """
    current = datetime.now(tz)
    return current

22 

23 

class _TypedalCache(TypedTable):
    """
    Internal table to store previously loaded models.

    Each row holds one serialized query result (written by ``save_to_cache``)
    together with bookkeeping about when it was stored and when it expires.
    """

    # Cache key identifying the stored query result.
    key: TypedField[str]
    # dill-serialized TypedRows payload.
    data: TypedField[bytes]
    # Moment the entry was written; defaults to the current UTC time.
    cached_at = TypedField(datetime, default=get_now)
    # Optional expiry moment; None means the entry never expires.
    expires_at: TypedField[datetime | None]

33 

34 

class _TypedalCacheDependency(TypedTable):
    """
    Internal table that stores dependencies to invalidate cache when a related table is updated.
    """

    # The cache entry this dependency belongs to.
    entry: TypedField[_TypedalCache]
    # Name of the table the cached result depends on.
    table: TypedField[str]
    # Row id within that table.
    idx: TypedField[int]

43 

44 

def prepare(field: Any) -> str:
    """
    Normalize a value into a deterministic string for use in a cache key.

    Dicts are dumped as JSON with sorted keys, and other iterables are sorted
    before joining, so queries that are syntactically different but
    semantically identical still map to the same cache key.
    """
    if isinstance(field, str):
        return field
    if isinstance(field, (dict, Mapping)):
        normalized = {str(key): prepare(value) for key, value in field.items()}
        return json.dumps(normalized, sort_keys=True)
    if isinstance(field, Iterable):
        parts = sorted(prepare(item) for item in field)
        return ",".join(parts)
    if isinstance(field, bool):
        return str(int(field))
    return str(field)

63 

64 

def create_cache_key(*fields: Any) -> str:
    """
    Combine any number of values into a single cache-key string.

    Each value is normalized via `prepare`, then all parts are joined with '|'.
    """
    parts = [prepare(field) for field in fields]
    return "|".join(parts)

70 

71 

72def hash_cache_key(cache_key: str | bytes) -> str: 

73 """ 

74 Hash the input cache key with SHA 256. 

75 """ 

76 h = hashlib.sha256() 

77 h.update(cache_key.encode() if isinstance(cache_key, str) else cache_key) 

78 return h.hexdigest() 

79 

80 

def create_and_hash_cache_key(*fields: Any) -> tuple[str, str]:
    """
    Combine the given fields into one cache key and return it with its SHA 256 hash.

    Returns a (plain_key, hashed_key) tuple.
    """
    plain = create_cache_key(*fields)
    hashed = hash_cache_key(plain)
    return plain, hashed

87 

88 

# One cache dependency: the table name plus the row id the cached result used.
DependencyTuple = tuple[str, int]  # table + id
# All dependencies collected for a single cached query result.
DependencyTupleSet = set[DependencyTuple]

91 

92 

def _get_table_name(field: Field) -> str:
    """
    Get the table name from a field or alias.

    An aliased table renders as "name AS alias"; only the real name is returned.
    """
    table_repr = str(field._table)
    name, _sep, _alias = table_repr.partition(" AS ")
    return name.strip()

98 

99 

def _get_dependency_ids(rows: Rows, dependency_keys: list[tuple[Field, str]]) -> DependencyTupleSet:
    """
    Collect (table, id) pairs for every non-empty id value in the given rows.

    Rows whose id field holds a falsy value (e.g. missing join) are skipped.
    """
    dependencies: DependencyTupleSet = set()
    for row in rows:
        for field, table in dependency_keys:
            row_id = row[field]
            if row_id:
                dependencies.add((table, row_id))
    return dependencies

108 

109 

def _determine_dependencies_auto(_: TypedRows[Any], rows: Rows) -> DependencyTupleSet:
    """
    Automatically derive cache dependencies from every '<table>.id' field in the rows.
    """
    id_fields = [
        (field, _get_table_name(field))
        for field in rows.fields
        if str(field).endswith(".id")
    ]
    return _get_dependency_ids(rows, id_fields)

119 

120 

def _determine_dependencies(instance: TypedRows[Any], rows: Rows, depends_on: list[Any]) -> DependencyTupleSet:
    """
    Resolve cache dependencies from the user-supplied `depends_on` fields.

    Falls back to automatic detection (every '<table>.id' column) when
    `depends_on` is empty. Unqualified names are prefixed with the instance's
    own table name before being matched against the row fields.
    """
    if not depends_on:
        return _determine_dependencies_auto(instance, rows)

    wanted = set()
    for dependency in depends_on:
        if "." not in dependency:
            dependency = f"{instance.model._table}.{dependency}"
        wanted.add(str(dependency))

    matched = [
        (field, _get_table_name(field))
        for field in rows.fields
        if str(field) in wanted
    ]
    return _get_dependency_ids(rows, matched)

140 

141 

def remove_cache(idx: int | Iterable[int], table: str) -> None:
    """
    Remove any cache entries that are dependent on one or multiple indices of a table.
    """
    indices = idx if isinstance(idx, Iterable) else [idx]

    dependent_entries = (
        _TypedalCacheDependency.where(table=table)
        .where(lambda row: row.idx.belongs(indices))
        .select("entry")
        .to_sql()
    )

    _TypedalCache.where(_TypedalCache.id.belongs(dependent_entries)).delete()

154 

155 

def clear_cache() -> None:
    """
    Wipe the cache completely.

    The dependency table is truncated before the entries it references.
    """
    for table in (_TypedalCacheDependency, _TypedalCache):
        table.truncate()

162 

163 

def clear_expired() -> int:
    """
    Remove all expired items from the cache.

    By default, expired items are only removed when trying to access them.

    Returns:
        The number of cache entries that were removed.
    """
    now = get_now()
    # Bug fix: the previous `expires_at > now` matched entries expiring in the
    # FUTURE, i.e. it deleted everything that was still valid and kept the
    # expired rows. An entry is expired once expires_at is at or before 'now'
    # (mirroring the `now >= expires` check in _load_from_cache). Entries with
    # a NULL expires_at never expire and are not matched by this comparison.
    return len(_TypedalCache.where(_TypedalCache.expires_at <= now).delete())

172 

173 

def _remove_cache(s: Set, tablename: str) -> None:
    """
    Hook installed as table._before_update and table._after_update for every TypeDAL table (on by default).

    Invalidates every cache entry depending on the rows matched by the given Set.
    """
    affected_ids = s.select("id").column("id")
    remove_cache(affected_ids, tablename)

180 

181 

# Type variable for any concrete TypedTable subclass, so cache helpers preserve the row type.
T_TypedTable = TypeVar("T_TypedTable", bound=TypedTable)

183 

184 

185def get_expire( 

186 expires_at: Optional[datetime] = None, ttl: Optional[int | timedelta] = None, now: Optional[datetime] = None 

187) -> datetime | None: 

188 """ 

189 Based on an expires_at date or a ttl (in seconds or a time delta), determine the expire date. 

190 """ 

191 now = now or get_now() 

192 

193 if expires_at and ttl: 

194 raise ValueError("Please only supply an `expired at` date or a `ttl` in seconds!") 

195 elif isinstance(ttl, timedelta): 

196 return now + ttl 

197 elif ttl: 

198 return now + timedelta(seconds=ttl) 

199 elif expires_at: 

200 return expires_at 

201 

202 return None 

203 

204 

def save_to_cache(
    instance: TypedRows[T_TypedTable],
    rows: Rows,
    expires_at: Optional[datetime] = None,
    ttl: Optional[int | timedelta] = None,
) -> TypedRows[T_TypedTable]:
    """
    Save a typedrows result to the database, and save dependencies from rows.

    You can call .cache(...) with dependent fields (e.g. User.id) or this function will determine them automatically.

    Args:
        instance: the query result to cache; only stored when its cache metadata is enabled and has a key.
        rows: the raw pydal rows backing the instance, used to extract dependencies.
        expires_at: optional absolute expiry moment for the cache entry.
        ttl: optional time-to-live, in seconds or as a timedelta.

    Returns:
        The same instance; its cache status is set to "fresh" when it was stored.
    """
    db = rows.db
    # Only store when caching is enabled for this query and a cache key was generated.
    if (c := instance.metadata.get("cache", {})) and c.get("enabled") and (key := c.get("key")):
        # Explicit arguments win; otherwise fall back to an expiry from the metadata.
        expires_at = get_expire(expires_at=expires_at, ttl=ttl) or c.get("expires_at")

        # (table, id) pairs whose modification should invalidate this entry.
        deps = _determine_dependencies(instance, rows, c["depends_on"])

        entry = _TypedalCache.insert(
            key=key,
            data=dill.dumps(instance),
            expires_at=expires_at,
        )

        _TypedalCacheDependency.bulk_insert([{"entry": entry, "table": table, "idx": idx} for table, idx in deps])

        db.commit()
        instance.metadata["cache"]["status"] = "fresh"
    return instance

233 

234 

def _load_from_cache(key: str) -> Any | None:
    """
    Look up 'key' in the cache table and unpickle the stored result.

    Expired entries are deleted on access (lazy expiry) and treated as a miss.
    Returns None on any cache miss.
    """
    if not (row := _TypedalCache.where(key=key).first()):
        return None

    now = get_now()

    # The stored expires_at is treated as naive UTC (presumably how the DB
    # returns it — TODO confirm); re-attach tzinfo so it can be compared
    # against the timezone-aware 'now'.
    expires = row.expires_at.replace(tzinfo=timezone.utc) if row.expires_at else None

    if expires and now >= expires:
        # Lazy expiry: purge the stale entry and report a miss.
        row.delete_record()
        return None

    inst = dill.loads(row.data)  # nosec
    # Annotate the result so callers can tell it was served from cache.
    inst.metadata["cache"]["status"] = "cached"
    inst.metadata["cache"]["cached_at"] = row.cached_at
    inst.metadata["cache"]["expires_at"] = row.expires_at
    return inst

252 

253 

def load_from_cache(key: str) -> Any | None:
    """
    If 'key' matches a non-expired row in the database, try to load the dill.

    If anything fails, return None.
    """
    try:
        return _load_from_cache(key)
    except Exception:
        return None