Coverage for smartmdao / cache.py: 100%
94 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-02 20:01 +0200
1import functools
2import hashlib
3import pickle
4import os
5import logging
6from abc import ABC, abstractmethod
7from collections import defaultdict
# Initialize module-level logger.
# All backends report cache hits/misses at DEBUG level through this logger;
# the hosting application decides whether/where those records are emitted.
logger = logging.getLogger(__name__)
# --- 1. Abstract Backend Interface ---
class CacheBackend(ABC):
    """Contract that every cache storage backend must fulfil.

    Entries are addressed by a (func_name, key) pair: the function name
    namespaces the key so several cached functions can safely share one
    backend instance.
    """

    @abstractmethod
    def has(self, func_name, key):
        """Return True if a value is stored under (func_name, key)."""

    @abstractmethod
    def get(self, func_name, key):
        """Return the value stored under (func_name, key)."""

    @abstractmethod
    def set(self, func_name, key, value):
        """Store value under (func_name, key), overwriting any previous one."""
# --- 2. In-Memory Backend (Dictionary) ---
class MemoryBackend(CacheBackend):
    """Volatile cache held in a plain dict for the lifetime of the process."""

    def __init__(self):
        # Flat mapping of "func_name::key" -> cached value.
        self.store = {}

    def _make_key(self, func_name, key):
        # Prefix with the function name so different functions never collide.
        return "{}::{}".format(func_name, key)

    def has(self, func_name, key):
        return self._make_key(func_name, key) in self.store

    def get(self, func_name, key):
        logger.debug(f"[Memory] Cache hit for {func_name}")
        composite = self._make_key(func_name, key)
        return self.store[composite]

    def set(self, func_name, key, value):
        composite = self._make_key(func_name, key)
        self.store[composite] = value
class HistoryBackend(MemoryBackend):
    """
    MemoryBackend variant that additionally records every value ever
    stored, in chronological order, grouped by function name.
    Useful for plotting/inspecting an optimization trajectory afterwards.
    """

    def __init__(self):
        super().__init__()
        # func_name -> list of every value computed, in call order.
        self.history = defaultdict(list)

    def set(self, func_name, key, value):
        # Normal caching is delegated to the parent; on top of that we
        # append the value to the per-function history trail.
        super().set(func_name, key, value)
        self.history[func_name].append(value)
# --- 3. HDF5 Backend ---
class HDF5Backend(CacheBackend):
    """
    Best for Large Numpy Arrays.
    Limitation: Can only store data types HDF5 supports (scalars, strings, numpy arrays).
    For generic Python objects (classes, dicts), use Pickle instead.
    """

    def __init__(self, filepath):
        self.filepath = filepath
        # Deferred import so the module loads even without h5py installed.
        import h5py
        self.h5py = h5py

    def _dataset_path(self, func_name, key):
        # HDF5 paths are '/'-separated: one group per function,
        # one dataset per argument hash.
        return f"{func_name}/{key}"

    def has(self, func_name, key):
        # No file yet means nothing has ever been cached.
        if not os.path.exists(self.filepath):
            return False
        with self.h5py.File(self.filepath, 'r') as f:
            return self._dataset_path(func_name, key) in f

    def get(self, func_name, key):
        logger.debug(f"[HDF5] Cache hit for {func_name}")
        with self.h5py.File(self.filepath, 'r') as f:
            dataset = f[self._dataset_path(func_name, key)]
            # 0-d datasets are unwrapped to scalars; anything else is
            # returned as an array slice.
            return dataset[()] if dataset.shape == () else dataset[:]

    def set(self, func_name, key, value):
        with self.h5py.File(self.filepath, 'a') as f:
            # require_group creates the group on first use, reuses it after.
            group = f.require_group(func_name)
            # Overwrite semantics: drop any previous dataset first.
            if key in group:
                del group[key]
            group.create_dataset(key, data=value)
class PickleDiskBackend(CacheBackend):
    """Persistent cache storing one pickle file per (function, key) pair."""

    def __init__(self, directory="cache_dir"):
        self.directory = directory
        # Idempotent: no error if the directory already exists.
        os.makedirs(directory, exist_ok=True)

    def _path(self, func_name, key):
        filename = f"{func_name}_{key}.pkl"
        return os.path.join(self.directory, filename)

    def has(self, func_name, key):
        return os.path.exists(self._path(func_name, key))

    def get(self, func_name, key):
        logger.debug(f"[Pickle] Cache hit for {func_name}")
        # NOTE(review): pickle.load is only safe on files this cache
        # itself wrote — never point this backend at untrusted data.
        with open(self._path(func_name, key), 'rb') as fh:
            return pickle.load(fh)

    def set(self, func_name, key, value):
        with open(self._path(func_name, key), 'wb') as fh:
            pickle.dump(value, fh)
# --- 4. The Decorator ---
def generate_cache_key(kwargs):
    """
    Return a stable SHA-256 hex digest identifying *kwargs*.

    Items are sorted first so argument order is irrelevant
    (f(a=1, b=2) hashes the same as f(b=2, a=1)); the sorted pairs are
    then pickled so arbitrary picklable values can participate in the key.
    """
    payload = pickle.dumps(sorted(kwargs.items()))
    return hashlib.sha256(payload).hexdigest()
def cached(backend: CacheBackend):
    """
    Decorator factory: memoize a keyword-only function through *backend*.

    The wrapped function must be invoked with keyword arguments only;
    their hash (see generate_cache_key) becomes the cache key, namespaced
    by the function's __name__ inside the backend.
    """
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(**kwargs):
            key = generate_cache_key(kwargs)
            name = fn.__name__

            # Serve from the backend when a prior result exists.
            if backend.has(name, key):
                return backend.get(name, key)

            # Miss: compute, persist, then return.
            logger.debug(f"Cache miss for {fn.__name__}. Executing...")
            result = fn(**kwargs)
            backend.set(name, key, result)
            return result

        return wrapper
    return decorator