# Coverage for smart_pipeline/cache.py: 100% (94 statements)
# Report generated by coverage.py v7.13.5 at 2026-05-02 13:46 +0200

1import functools 

2import hashlib 

3import pickle 

4import os 

5import logging 

6from abc import ABC, abstractmethod 

7from collections import defaultdict 

8 

# Module-level logger (PEP 282 convention): backends emit cache hit/miss
# diagnostics at DEBUG level through this logger.
logger = logging.getLogger(__name__)

11 

12# --- 1. Abstract Backend Interface --- 

class CacheBackend(ABC):
    """
    Abstract interface shared by all cache backends.

    A backend stores one value per (func_name, key) pair, where
    ``func_name`` namespaces entries per decorated function and ``key``
    is a hash of the call's arguments.
    """

    @abstractmethod
    def get(self, func_name, key):
        """Return the value cached under (func_name, key)."""

    @abstractmethod
    def set(self, func_name, key, value):
        """Persist ``value`` under (func_name, key), overwriting any previous entry."""

    @abstractmethod
    def has(self, func_name, key):
        """Return True if an entry exists for (func_name, key)."""

25 

26# --- 2. In-Memory Backend (Dictionary) --- 

class MemoryBackend(CacheBackend):
    """Volatile in-process backend backed by a plain dict (lost on exit)."""

    def __init__(self):
        # Flat mapping of "func::key" -> cached value.
        self.store = {}

    def _make_key(self, func_name, key):
        # Namespace per function so identical argument hashes from
        # different functions never collide.
        return f"{func_name}::{key}"

    def has(self, func_name, key):
        """Return True if an entry exists for (func_name, key)."""
        composite = self._make_key(func_name, key)
        return composite in self.store

    def get(self, func_name, key):
        """Return the cached value; raises KeyError if absent."""
        logger.debug(f"[Memory] Cache hit for {func_name}")
        composite = self._make_key(func_name, key)
        return self.store[composite]

    def set(self, func_name, key, value):
        """Store ``value``, overwriting any previous entry."""
        composite = self._make_key(func_name, key)
        self.store[composite] = value

43 

class HistoryBackend(MemoryBackend):
    """
    MemoryBackend variant that additionally records every value ever
    written, per function, in chronological order (e.g. for plotting).
    """

    def __init__(self):
        super().__init__()
        # Maps function_name -> list of values, in write order.
        self.history = defaultdict(list)

    def set(self, func_name, key, value):
        """Store in the regular cache, then append to the audit trail."""
        super().set(func_name, key, value)
        self.history[func_name].append(value)

60 

61# --- 3. HDF5 Backend --- 

class HDF5Backend(CacheBackend):
    """
    Disk backend built on a single HDF5 file — best for large numpy arrays.

    Limitation: only data types HDF5 supports can be stored (scalars,
    strings, numpy arrays). For arbitrary Python objects (classes, dicts),
    use PickleDiskBackend instead.
    """

    def __init__(self, filepath):
        self.filepath = filepath
        # Deferred import keeps h5py an optional dependency: users of the
        # other backends never pay for it.
        import h5py
        self.h5py = h5py

    def has(self, func_name, key):
        """Return True if a dataset exists at func_name/key in the file."""
        if not os.path.exists(self.filepath):
            return False
        with self.h5py.File(self.filepath, 'r') as f:
            return f"{func_name}/{key}" in f

    def get(self, func_name, key):
        """Load the cached dataset, unwrapping 0-d datasets to scalars."""
        logger.debug(f"[HDF5] Cache hit for {func_name}")
        with self.h5py.File(self.filepath, 'r') as f:
            node = f[f"{func_name}/{key}"]
            # 0-d datasets come back as scalars, everything else as arrays.
            return node[()] if node.shape == () else node[:]

    def set(self, func_name, key, value):
        """Write ``value`` under func_name/key, overwriting any old dataset."""
        with self.h5py.File(self.filepath, 'a') as f:
            if func_name not in f:
                f.create_group(func_name)
            group = f[func_name]
            # HDF5 cannot overwrite in place; delete the old dataset first.
            if key in group:
                del group[key]
            group.create_dataset(key, data=value)

99 

class PickleDiskBackend(CacheBackend):
    """Disk backend that pickles each cached value to its own file."""

    def __init__(self, directory="cache_dir"):
        self.directory = directory
        # Idempotent: safe when the directory already exists.
        os.makedirs(directory, exist_ok=True)

    def _path(self, func_name, key):
        # One file per (function, key) pair.
        filename = f"{func_name}_{key}.pkl"
        return os.path.join(self.directory, filename)

    def has(self, func_name, key):
        """Return True if a pickle file exists for (func_name, key)."""
        return os.path.exists(self._path(func_name, key))

    def get(self, func_name, key):
        """Unpickle and return the cached value."""
        logger.debug(f"[Pickle] Cache hit for {func_name}")
        with open(self._path(func_name, key), 'rb') as fh:
            return pickle.load(fh)

    def set(self, func_name, key, value):
        """Pickle ``value`` to disk, overwriting any previous file."""
        with open(self._path(func_name, key), 'wb') as fh:
            pickle.dump(value, fh)

119 

120# --- 4. The Decorator --- 

def generate_cache_key(kwargs):
    """
    Create a stable, order-independent hash of the input arguments.

    Arguments are serialized with pickle (so complex values are supported)
    and digested with SHA-256.

    Args:
        kwargs: dict mapping argument name -> value. Values must be
            picklable.

    Returns:
        64-character hexadecimal SHA-256 digest.
    """
    # Sort kwargs so argument order doesn't matter: f(a=1, b=2) == f(b=2, a=1)
    sorted_items = sorted(kwargs.items())
    # Pin the pickle protocol: the interpreter's default protocol has
    # changed between Python versions, which would silently produce
    # different keys (and invalidate on-disk caches) after an upgrade.
    # Protocol 4 matches the default on Python 3.8+, so existing keys
    # are unchanged there.
    serialized = pickle.dumps(sorted_items, protocol=4)
    return hashlib.sha256(serialized).hexdigest()

130 

def cached(backend: CacheBackend):
    """
    Decorator factory: memoize a function's results in ``backend``.

    The cache key is derived from the call's arguments via
    ``generate_cache_key``. Keyword arguments are order-independent.
    Positional arguments are also supported (the original wrapper only
    accepted ``**kwargs`` and raised TypeError on positional calls);
    they are keyed by position, so the same call expressed positionally
    vs. by keyword produces distinct keys — prefer one calling
    convention per function.
    """
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # 1. Generate key based on function input. Positional args are
            #    folded in under synthetic names; with no positional args
            #    the payload equals `kwargs`, so keys for existing
            #    keyword-only callers (and any persisted caches) are
            #    unchanged.
            payload = {f"__arg{i}__": v for i, v in enumerate(args)}
            payload.update(kwargs)
            key = generate_cache_key(payload)

            # 2. Serve from the backend on a hit.
            if backend.has(fn.__name__, key):
                return backend.get(fn.__name__, key)

            # 3. Cache miss: run the function.
            logger.debug(f"Cache miss for {fn.__name__}. Executing...")
            result = fn(*args, **kwargs)

            # 4. Persist the result for next time.
            backend.set(fn.__name__, key, result)
            return result
        return wrapper
    return decorator