Coverage for tests / unit / test_figure_util.py: 100%

175 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-22 18:15 -0700

1import base64 

2import gzip 

3import io 

4import re 

5from pathlib import Path 

6 

7import jaxtyping 

8import matplotlib.pyplot as plt 

9import numpy as np 

10import pytest 

11from PIL import Image 

12 

13from pattern_lens.figure_util import ( 

14 MATPLOTLIB_FIGURE_FMT, 

15 matplotlib_figure_saver, 

16 matrix_as_svg, 

17 save_matrix_wrapper, 

18) 

19 

# Scratch directory shared by every test in this module. Nothing here deletes
# it, so files written by earlier runs remain on disk.
# NOTE(review): the path is relative, so it resolves against the current
# working directory (presumably the repository root when pytest runs) — confirm.
TEMP_DIR: Path = Path("tests", ".temp")

21 

22 

def test_matplotlib_figure_saver():
    """The bare ``@matplotlib_figure_saver`` decorator writes ``plot_matrix.<MATPLOTLIB_FIGURE_FMT>``.

    The file name is expected to follow the decorated function's name, placed
    inside the directory passed as the second call argument.
    """
    TEMP_DIR.mkdir(parents=True, exist_ok=True)

    saved_file = TEMP_DIR / f"plot_matrix.{MATPLOTLIB_FIGURE_FMT}"
    # TEMP_DIR persists across runs, and other tests in this module also write
    # `plot_matrix.*`. Remove any stale output first; otherwise the `exists()`
    # check below could pass even if the decorator wrote nothing.
    saved_file.unlink(missing_ok=True)

    @matplotlib_figure_saver
    def plot_matrix(attn_matrix, ax):
        ax.matshow(attn_matrix, cmap="viridis")
        ax.axis("off")

    attn_matrix = np.random.rand(10, 10).astype(np.float32)
    plot_matrix(attn_matrix, TEMP_DIR)

    assert saved_file.exists(), "Matplotlib figure file was not saved"

36 

37 

def test_matplotlib_figure_saver_exception():
    """An error raised inside the decorated plot function propagates as the same ValueError."""
    TEMP_DIR.mkdir(parents=True, exist_ok=True)

    @matplotlib_figure_saver
    def faulty_plot(attn_matrix, ax):  # noqa: ARG001
        raise ValueError("Intentional failure for testing")

    data = np.random.rand(10, 10).astype(np.float32)
    expected_msg = "Intentional failure for testing"
    with pytest.raises(ValueError, match=expected_msg):
        faulty_plot(data, TEMP_DIR)

48 

49 

def test_matrix_as_svg_normalization():
    """With ``normalize=True``, values above 1 still yield an SVG embedding a base64 PNG."""
    unnormalized = np.array([[2, 4], [6, 8]], dtype=np.float32)
    svg = matrix_as_svg(unnormalized, normalize=True)

    expectations = (
        ("image href=", "SVG content is malformed"),
        ("data:image/png;base64,", "Base64 encoding is missing"),
    )
    for needle, failure_msg in expectations:
        assert needle in svg, failure_msg

55 

56 

def test_matrix_as_svg_no_normalization():
    """With ``normalize=False``, an already in-range matrix yields an SVG embedding a base64 PNG."""
    in_range = np.array([[0.1, 0.4], [0.6, 0.9]], dtype=np.float32)
    svg = matrix_as_svg(in_range, normalize=False)

    expectations = (
        ("image href=", "SVG content is malformed"),
        ("data:image/png;base64,", "Base64 encoding is missing"),
    )
    for needle, failure_msg in expectations:
        assert needle in svg, failure_msg

62 

63 

def test_matrix_as_svg_invalid_range():
    """Without normalization, a negative entry trips the [0, 1] range assertion."""
    out_of_range = np.array([[-1, 2], [3, 4]], dtype=np.float32)
    expected = r"Matrix values must be in range \[0, 1\], or normalize must be True"
    with pytest.raises(AssertionError, match=expected):
        matrix_as_svg(out_of_range, normalize=False)

71 

72 

def test_matrix_as_svg_invalid_dims():
    """A 3-D array is rejected by either an internal assertion or jaxtyping's runtime check."""
    cube = np.random.rand(5, 5, 5).astype(np.float32)
    accepted_errors = (AssertionError, jaxtyping.TypeCheckError)
    with pytest.raises(accepted_errors):
        matrix_as_svg(cube, normalize=True)

77 

78 

def test_matrix_as_svg_invalid_cmap_fixed():
    """An unknown colormap name surfaces as a ``KeyError`` that names the bad cmap."""
    in_range = np.array([[0.1, 0.4], [0.6, 0.9]], dtype=np.float32)
    expected = "'invalid_cmap' is not a known colormap name"
    with pytest.raises(KeyError, match=expected):
        matrix_as_svg(in_range, cmap="invalid_cmap")

83 

84 

def test_save_matrix_as_svgz_wrapper_no_args():
    """``save_matrix_wrapper`` configured with only ``fmt="svgz"`` writes ``no_op.svgz``.

    Despite the "no_args" name, ``fmt`` is passed explicitly. ``normalize`` and
    ``cmap`` are left at whatever defaults ``save_matrix_wrapper`` defines.
    """
    TEMP_DIR.mkdir(parents=True, exist_ok=True)

    saved_file = TEMP_DIR / "no_op.svgz"
    # TEMP_DIR persists across runs. Remove any stale output so the `exists()`
    # check below proves this call actually wrote the file.
    saved_file.unlink(missing_ok=True)

    @save_matrix_wrapper(fmt="svgz")
    def no_op(matrix):
        return matrix

    test_matrix = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)
    no_op(test_matrix, TEMP_DIR)

    assert saved_file.exists(), "SVGZ file was not saved in no-args case"

98 

99 

def test_save_matrix_as_svgz_wrapper_with_args():
    """``save_matrix_wrapper`` given ``normalize`` and ``cmap`` keywords writes ``scale_matrix.svgz``.

    No ``fmt`` is passed, so the ``.svgz`` extension comes from the wrapper's
    default format.
    """
    TEMP_DIR.mkdir(parents=True, exist_ok=True)

    saved_file = TEMP_DIR / "scale_matrix.svgz"
    # `scale_matrix.svgz` is also written by the keyword_only test, and
    # TEMP_DIR persists across runs. Remove stale output so `exists()` is meaningful.
    saved_file.unlink(missing_ok=True)

    @save_matrix_wrapper(normalize=True, cmap="plasma")
    def scale_matrix(matrix):
        return matrix * 2

    test_matrix = np.array([[0.5, 0.6], [0.7, 0.8]], dtype=np.float32)
    scale_matrix(test_matrix, TEMP_DIR)

    assert saved_file.exists(), "SVGZ file was not saved with keyword-only arguments"

113 

114 

def test_save_matrix_as_svgz_wrapper_exceptions():
    """With ``normalize=False``, an out-of-range result from the wrapped function raises.

    The wrapped function doubles a matrix that is already outside [0, 1], so the
    wrapper's range assertion must fire with its explanatory message.
    """
    TEMP_DIR.mkdir(parents=True, exist_ok=True)

    @save_matrix_wrapper(normalize=False)
    def invalid_range(matrix):
        return matrix * 2

    out_of_range = np.array([[2, 3], [4, 5]], dtype=np.float32)
    expected_msg = r"Matrix values must be in range \[0, 1\], or normalize must be True.*"
    with pytest.raises(AssertionError, match=expected_msg):
        invalid_range(out_of_range, TEMP_DIR)

129 

130 

def test_save_matrix_as_svgz_wrapper_keyword_only():
    """Keyword-configured ``save_matrix_wrapper`` writes ``scale_matrix.svgz``.

    NOTE(review): despite its name, this test does not check that positional
    configuration is rejected. It exercises the same call as
    ``test_save_matrix_as_svgz_wrapper_with_args``. Consider asserting the
    error raised for positional arguments once that contract is confirmed.
    """
    TEMP_DIR.mkdir(parents=True, exist_ok=True)

    saved_file = TEMP_DIR / "scale_matrix.svgz"
    # The with_args test writes the same file earlier in the run. Without
    # removing it, the `exists()` check below would pass vacuously.
    saved_file.unlink(missing_ok=True)

    @save_matrix_wrapper(normalize=True, cmap="plasma")
    def scale_matrix(matrix):
        return matrix * 2

    test_matrix = np.array([[0.5, 0.6], [0.7, 0.8]], dtype=np.float32)
    scale_matrix(test_matrix, TEMP_DIR)

    assert saved_file.exists(), "SVGZ file was not saved with keyword-only arguments"

144 

145 

def test_save_matrix_as_svgz_wrapper_multiple():
    """Calling the same decorated function twice still leaves ``scale_by_factor.svgz`` on disk.

    Both calls share the function name, so they presumably write the same path,
    with the second call overwriting the first.
    """
    TEMP_DIR.mkdir(parents=True, exist_ok=True)

    saved_file = TEMP_DIR / "scale_by_factor.svgz"
    # TEMP_DIR persists across runs. Remove stale output so `exists()` proves
    # these calls wrote the file.
    saved_file.unlink(missing_ok=True)

    @save_matrix_wrapper(normalize=True)
    def scale_by_factor(matrix):
        return matrix * 3

    matrix_1 = np.array([[0.1, 0.5], [0.7, 0.9]], dtype=np.float32)
    matrix_2 = np.array([[0.2, 0.6], [0.8, 1.0]], dtype=np.float32)

    scale_by_factor(matrix_1, TEMP_DIR)
    scale_by_factor(matrix_2, TEMP_DIR)

    assert saved_file.exists(), "SVGZ file was not saved for multiple calls"

163 

164 

def test_save_matrix_as_svgz_wrapper_no_normalization():
    """With ``normalize=False`` and values already in [0, 1], the wrapper saves without error."""
    TEMP_DIR.mkdir(parents=True, exist_ok=True)

    saved_file = TEMP_DIR / "pass_through.svgz"
    # TEMP_DIR persists across runs. Remove stale output so `exists()` proves
    # this call wrote the file.
    saved_file.unlink(missing_ok=True)

    @save_matrix_wrapper(normalize=False)
    def pass_through(matrix):
        return matrix

    test_matrix = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)
    pass_through(test_matrix, TEMP_DIR)

    assert saved_file.exists(), (
        "SVGZ file was not saved when normalization was not applied"
    )

180 

181 

def test_save_matrix_as_svgz_wrapper_complex_matrix():
    """A non-trivial (``np.sin``) transform of a 4x4 matrix is normalized and saved.

    "Complex" refers to the processing step, not a complex-valued dtype. The
    input is real float32.
    """
    TEMP_DIR.mkdir(parents=True, exist_ok=True)

    saved_file = TEMP_DIR / "complex_processing.svgz"
    # TEMP_DIR persists across runs. Remove stale output so `exists()` proves
    # this call wrote the file.
    saved_file.unlink(missing_ok=True)

    @save_matrix_wrapper(normalize=True, cmap="viridis")
    def complex_processing(matrix):
        return np.sin(matrix)

    test_matrix = np.linspace(0, np.pi, 16).reshape(4, 4).astype(np.float32)
    complex_processing(test_matrix, TEMP_DIR)

    assert saved_file.exists(), "SVGZ file was not saved for complex matrix processing"

195 

196 

def test_matrix_as_svg_dimensions():
    """SVG ``width``/``height``/``viewBox`` track the matrix shape for several shapes.

    The assertions pin ``width`` to ``shape[0]`` (rows) and ``height`` to
    ``shape[1]`` (columns), as the current output is expected to do.
    """
    # Non-square, small square, large non-square.
    matrices = [np.random.rand(*shape) for shape in ((5, 10), (3, 3), (100, 50))]

    for mat in matrices:
        rows, cols = mat.shape
        svg = matrix_as_svg(mat, normalize=True)
        for fragment in (
            f'width="{rows}"',
            f'height="{cols}"',
            f'viewBox="0 0 {rows} {cols}"',
        ):
            assert fragment in svg

211 

212 

def test_save_matrix_as_svgz_wrapper_content():
    """The ``.svgz`` written by the wrapper gunzips to SVG text with an embedded base64 image."""
    TEMP_DIR.mkdir(parents=True, exist_ok=True)

    saved_file = TEMP_DIR / "identity.svgz"
    # TEMP_DIR persists across runs. Remove stale output so the content checked
    # below comes from this call, not a previous run.
    saved_file.unlink(missing_ok=True)

    @save_matrix_wrapper(normalize=True)
    def identity(matrix):
        return matrix

    test_matrix = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)
    identity(test_matrix, TEMP_DIR)

    # Explicit encoding, so decoding does not depend on the platform locale.
    with gzip.open(saved_file, "rt", encoding="utf-8") as f:
        content = f.read()
    assert "svg" in content
    assert "image href=" in content
    assert "base64" in content

229 

230 

@pytest.mark.parametrize("fmt", ["png", "pdf", "svg"])
def test_matplotlib_figure_saver_formats(fmt):
    """``matplotlib_figure_saver`` with an explicit ``fmt`` writes ``plot_matrix.<fmt>``."""
    TEMP_DIR.mkdir(parents=True, exist_ok=True)

    saved_file = TEMP_DIR / f"plot_matrix.{fmt}"
    # Other tests here also write `plot_matrix.*` into the persistent TEMP_DIR.
    # Remove any stale file so the existence check reflects this call.
    saved_file.unlink(missing_ok=True)

    # Passing `None` first presumably selects the decorator-factory form — confirm.
    # TYPING: error: Too few arguments [call-arg]
    @matplotlib_figure_saver(None, fmt=fmt)  # type: ignore[call-arg]
    def plot_matrix(attn_matrix, ax):
        ax.matshow(attn_matrix)
        ax.axis("off")

    matrix = np.random.rand(5, 5)
    plot_matrix(matrix, TEMP_DIR)
    assert saved_file.exists(), f"File not saved for format {fmt}"

245 

246 

def test_matrix_as_svg_empty():
    """A 0x0 matrix is rejected with an explicit "empty" assertion."""
    empty_matrix = np.empty((0, 0), dtype=np.float32)
    with pytest.raises(AssertionError, match="Matrix cannot be empty"):
        matrix_as_svg(empty_matrix)

251 

252 

def test_matplotlib_figure_saver_cleanup():
    """The decorator closes the figure it creates, so the open-figure count is unchanged."""
    TEMP_DIR.mkdir(parents=True, exist_ok=True)
    figures_before = len(plt.get_fignums())

    @matplotlib_figure_saver
    def plot_matrix(attn_matrix, ax):
        ax.matshow(attn_matrix)

    plot_matrix(np.random.rand(5, 5), TEMP_DIR)

    # Any figure left open by the decorator would show up in pyplot's registry.
    figures_after = len(plt.get_fignums())
    assert figures_after == figures_before, "Figure not properly cleaned up"

266 

267 

def test_matrix_as_svg_non_numeric():
    """A string-dtype 2x2 array is rejected with ``TypeError``."""
    letters = np.array(list("abcd")).reshape(2, 2)
    with pytest.raises(TypeError):
        matrix_as_svg(letters)

272 

273 

def test_matrix_as_svg_format():
    """``matrix_as_svg`` returns a bare ``<svg>...</svg>`` string embedding a valid PNG data URI."""
    matrix = np.array([[0.0, 0.5], [1.0, 0.75]], dtype=float)

    svg_str = matrix_as_svg(matrix)

    # The result is a standalone SVG element, with nothing before or after it.
    assert svg_str.startswith("<svg"), "SVG should start with <svg>"
    assert svg_str.endswith("</svg>"), "SVG should end with </svg>"

    # Pull the base64 payload out of the data URI (it ends at the closing quote).
    match = re.search(r'data:image/png;base64,([^"]+)', svg_str)
    assert match, "Expected an embedded PNG in data URI format"

    png_data = base64.b64decode(match.group(1))

    # `Image.open` is lazy and only identifies the file from its header.
    # `verify()` checks the payload's integrity, so a truncated or corrupt PNG
    # actually fails the test. The `with` block closes the image afterwards.
    with Image.open(io.BytesIO(png_data)) as img:
        assert img.format == "PNG", f"Embedded image is {img.format}, expected PNG"
        img.verify()