Coverage for tests/unit/test_figure_util.py: 100%
175 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-04-06 15:09 -0600
1import base64
2import gzip
3import io
4import re
5from pathlib import Path
7import jaxtyping
8import matplotlib.pyplot as plt
9import numpy as np
10import pytest
11from PIL import Image
13from pattern_lens.figure_util import (
14 MATPLOTLIB_FIGURE_FMT,
15 matplotlib_figure_saver,
16 matrix_as_svg,
17 save_matrix_wrapper,
18)
# Scratch directory (relative to the current working directory) that the
# figure/SVGZ saving tests write their output files into.
TEMP_DIR: Path = Path("tests") / "_temp"
def test_matplotlib_figure_saver():
    """The bare decorator saves `plot_matrix.<MATPLOTLIB_FIGURE_FMT>` into the target dir."""
    TEMP_DIR.mkdir(parents=True, exist_ok=True)
    saved_file = TEMP_DIR / f"plot_matrix.{MATPLOTLIB_FIGURE_FMT}"
    # TEMP_DIR persists between runs; drop any leftover file so the existence
    # check below proves that *this* call wrote it.
    saved_file.unlink(missing_ok=True)

    @matplotlib_figure_saver
    def plot_matrix(attn_matrix, ax):
        ax.matshow(attn_matrix, cmap="viridis")
        ax.axis("off")

    attn_matrix = np.random.rand(10, 10).astype(np.float32)
    plot_matrix(attn_matrix, TEMP_DIR)

    assert saved_file.exists(), "Matplotlib figure file was not saved"
def test_matplotlib_figure_saver_exception():
    """An error raised by the decorated plotting function reaches the caller unchanged."""
    TEMP_DIR.mkdir(parents=True, exist_ok=True)
    error_text = "Intentional failure for testing"

    @matplotlib_figure_saver
    def faulty_plot(attn_matrix, ax):  # noqa: ARG001
        raise ValueError(error_text)

    matrix = np.random.rand(10, 10).astype(np.float32)
    with pytest.raises(ValueError, match=error_text):
        faulty_plot(matrix, TEMP_DIR)
def test_matrix_as_svg_normalization():
    """With normalize=True, values outside [0, 1] still yield an SVG embedding a base64 PNG."""
    svg = matrix_as_svg(
        np.array([[2, 4], [6, 8]], dtype=np.float32),
        normalize=True,
    )
    for needle, failure in (
        ("image href=", "SVG content is malformed"),
        ("data:image/png;base64,", "Base64 encoding is missing"),
    ):
        assert needle in svg, failure
def test_matrix_as_svg_no_normalization():
    """With normalize=False, in-range values yield an SVG embedding a base64 PNG."""
    in_range = np.array([[0.1, 0.4], [0.6, 0.9]], dtype=np.float32)
    svg = matrix_as_svg(in_range, normalize=False)
    has_image_tag = "image href=" in svg
    has_png_data_uri = "data:image/png;base64," in svg
    assert has_image_tag, "SVG content is malformed"
    assert has_png_data_uri, "Base64 encoding is missing"
def test_matrix_as_svg_invalid_range():
    """normalize=False with values outside [0, 1] trips the range assertion."""
    out_of_range = np.array([[-1, 2], [3, 4]], dtype=np.float32)
    expected_msg = r"Matrix values must be in range \[0, 1\], or normalize must be True"
    with pytest.raises(AssertionError, match=expected_msg):
        matrix_as_svg(out_of_range, normalize=False)
def test_matrix_as_svg_invalid_dims():
    """A 3-D array is rejected, either by an assertion or by jaxtyping's shape check."""
    cube = np.random.rand(5, 5, 5).astype(np.float32)
    accepted_errors = (AssertionError, jaxtyping.TypeCheckError)
    with pytest.raises(accepted_errors):
        matrix_as_svg(cube, normalize=True)
def test_matrix_as_svg_invalid_cmap_fixed():
    """An unknown colormap name surfaces as a KeyError that names the colormap."""
    values = np.array([[0.1, 0.4], [0.6, 0.9]], dtype=np.float32)
    expected = "'invalid_cmap' is not a known colormap name"
    with pytest.raises(KeyError, match=expected):
        matrix_as_svg(values, cmap="invalid_cmap")
# Test with only `fmt` given (no normalize/cmap options)
def test_save_matrix_as_svgz_wrapper_no_args():
    """`save_matrix_wrapper(fmt="svgz")` saves the returned matrix as `no_op.svgz`."""
    TEMP_DIR.mkdir(parents=True, exist_ok=True)
    saved_file = TEMP_DIR / "no_op.svgz"
    # TEMP_DIR persists between runs; remove any leftover file so the
    # existence check proves this call saved it.
    saved_file.unlink(missing_ok=True)

    @save_matrix_wrapper(fmt="svgz")
    def no_op(matrix):
        return matrix

    test_matrix = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)
    no_op(test_matrix, TEMP_DIR)

    assert saved_file.exists(), "SVGZ file was not saved in no-args case"
# Test with keyword-only arguments
def test_save_matrix_as_svgz_wrapper_with_args():
    """Passing `normalize`/`cmap` keywords still saves `scale_matrix.svgz`."""
    TEMP_DIR.mkdir(parents=True, exist_ok=True)
    saved_file = TEMP_DIR / "scale_matrix.svgz"
    # TEMP_DIR persists between runs, and test_save_matrix_as_svgz_wrapper_keyword_only
    # writes the same file name; clear it so the check proves this call saved it.
    saved_file.unlink(missing_ok=True)

    @save_matrix_wrapper(normalize=True, cmap="plasma")
    def scale_matrix(matrix):
        return matrix * 2

    test_matrix = np.array([[0.5, 0.6], [0.7, 0.8]], dtype=np.float32)
    scale_matrix(test_matrix, TEMP_DIR)

    assert saved_file.exists(), "SVGZ file was not saved with keyword-only arguments"
# Test exception handling
def test_save_matrix_as_svgz_wrapper_exceptions():
    """Out-of-range output with normalize=False propagates the range AssertionError."""
    TEMP_DIR.mkdir(parents=True, exist_ok=True)

    @save_matrix_wrapper(normalize=False)
    def invalid_range(matrix):
        return matrix * 2

    too_large = np.array([[2, 3], [4, 5]], dtype=np.float32)
    range_error = r"Matrix values must be in range \[0, 1\], or normalize must be True.*"
    with pytest.raises(AssertionError, match=range_error):
        invalid_range(too_large, TEMP_DIR)
# Test keyword-only arguments enforced
def test_save_matrix_as_svgz_wrapper_keyword_only():
    """Decorator configured purely by keywords saves `scale_matrix.svgz`.

    NOTE(review): this body is identical to
    test_save_matrix_as_svgz_wrapper_with_args and does not check that
    positional options are rejected. Confirm save_matrix_wrapper's signature
    and add that check (or delete this duplicate).
    """
    TEMP_DIR.mkdir(parents=True, exist_ok=True)
    saved_file = TEMP_DIR / "scale_matrix.svgz"
    # Another test writes the same file name and TEMP_DIR persists between
    # runs; clear it so the check proves this call saved it.
    saved_file.unlink(missing_ok=True)

    @save_matrix_wrapper(normalize=True, cmap="plasma")
    def scale_matrix(matrix):
        return matrix * 2

    test_matrix = np.array([[0.5, 0.6], [0.7, 0.8]], dtype=np.float32)
    scale_matrix(test_matrix, TEMP_DIR)

    assert saved_file.exists(), "SVGZ file was not saved with keyword-only arguments"
# Test multiple calls to the decorator
def test_save_matrix_as_svgz_wrapper_multiple():
    """Each call of the decorated function saves `scale_by_factor.svgz` again."""
    TEMP_DIR.mkdir(parents=True, exist_ok=True)
    saved_file = TEMP_DIR / "scale_by_factor.svgz"

    @save_matrix_wrapper(normalize=True)
    def scale_by_factor(matrix):
        return matrix * 3

    matrix_1 = np.array([[0.1, 0.5], [0.7, 0.9]], dtype=np.float32)
    matrix_2 = np.array([[0.2, 0.6], [0.8, 1.0]], dtype=np.float32)

    # Clear the file before every call. Otherwise a file left by an earlier
    # call (or an earlier test run) would make the later checks pass vacuously.
    for call_idx, matrix in enumerate((matrix_1, matrix_2), start=1):
        saved_file.unlink(missing_ok=True)
        scale_by_factor(matrix, TEMP_DIR)
        assert saved_file.exists(), (
            f"SVGZ file was not saved for multiple calls (call {call_idx})"
        )
# Validate behavior when normalize is False and values are in range
def test_save_matrix_as_svgz_wrapper_no_normalization():
    """In-range output with normalize=False is saved as `pass_through.svgz`."""
    TEMP_DIR.mkdir(parents=True, exist_ok=True)
    saved_file = TEMP_DIR / "pass_through.svgz"
    # TEMP_DIR persists between runs; remove leftovers so the check is meaningful.
    saved_file.unlink(missing_ok=True)

    @save_matrix_wrapper(normalize=False)
    def pass_through(matrix):
        return matrix

    test_matrix = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)
    pass_through(test_matrix, TEMP_DIR)

    assert saved_file.exists(), (
        "SVGZ file was not saved when normalization was not applied"
    )
# Test with a non-trivial transform of the matrix
def test_save_matrix_as_svgz_wrapper_complex_matrix():
    """A sin-transformed 4x4 matrix with an explicit cmap is saved as `complex_processing.svgz`."""
    TEMP_DIR.mkdir(parents=True, exist_ok=True)
    saved_file = TEMP_DIR / "complex_processing.svgz"
    # TEMP_DIR persists between runs; remove leftovers so the check is meaningful.
    saved_file.unlink(missing_ok=True)

    @save_matrix_wrapper(normalize=True, cmap="viridis")
    def complex_processing(matrix):
        # sin(pi) is a tiny negative float; normalize=True keeps that valid.
        return np.sin(matrix)

    test_matrix = np.linspace(0, np.pi, 16).reshape(4, 4).astype(np.float32)
    complex_processing(test_matrix, TEMP_DIR)

    assert saved_file.exists(), "SVGZ file was not saved for complex matrix processing"
def test_matrix_as_svg_dimensions():
    """The SVG's width/height/viewBox attributes track the matrix shape.

    NOTE(review): width is matched against the row count and height against the
    column count, which is the transpose of the usual image convention. Confirm
    this is what matrix_as_svg intends.
    """
    shapes = (
        (5, 10),  # Non-square
        (3, 3),  # Small square
        (100, 50),  # Large non-square
    )
    for rows, cols in shapes:
        svg = matrix_as_svg(np.random.rand(rows, cols), normalize=True)
        expected_fragments = (
            f'width="{rows}"',
            f'height="{cols}"',
            f'viewBox="0 0 {rows} {cols}"',
        )
        for fragment in expected_fragments:
            assert fragment in svg
def test_save_matrix_as_svgz_wrapper_content():
    """The saved `.svgz` is gzip text containing an SVG with an embedded base64 image."""
    TEMP_DIR.mkdir(parents=True, exist_ok=True)
    saved_file = TEMP_DIR / "identity.svgz"
    # A file left by an earlier run would otherwise be read and pass even if
    # this call failed to write anything.
    saved_file.unlink(missing_ok=True)

    @save_matrix_wrapper(normalize=True)
    def identity(matrix):
        return matrix

    test_matrix = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)
    identity(test_matrix, TEMP_DIR)

    assert saved_file.exists(), "SVGZ file was not saved"
    with gzip.open(saved_file, "rt") as f:
        content = f.read()
    assert "svg" in content
    assert "image href=" in content
    assert "base64" in content
@pytest.mark.parametrize("fmt", ["png", "pdf", "svg"])
def test_matplotlib_figure_saver_formats(fmt):
    """The decorator factory's `fmt` keyword controls the saved file's extension."""
    TEMP_DIR.mkdir(parents=True, exist_ok=True)
    saved_file = TEMP_DIR / f"plot_matrix.{fmt}"
    # TEMP_DIR persists between runs (and other tests write plot_matrix.*);
    # remove leftovers so the check proves this call produced the file.
    saved_file.unlink(missing_ok=True)

    # TYPING: error: Too few arguments [call-arg]
    @matplotlib_figure_saver(None, fmt=fmt)  # type: ignore[call-arg]
    def plot_matrix(attn_matrix, ax):
        ax.matshow(attn_matrix)
        ax.axis("off")

    matrix = np.random.rand(5, 5)
    plot_matrix(matrix, TEMP_DIR)

    assert saved_file.exists(), f"File not saved for format {fmt}"
def test_matrix_as_svg_empty():
    """A 0x0 float32 matrix is rejected with a 'Matrix cannot be empty' assertion."""
    zero_by_zero = np.empty((0, 0), dtype=np.float32)
    with pytest.raises(AssertionError, match="Matrix cannot be empty"):
        matrix_as_svg(zero_by_zero)
def test_matplotlib_figure_saver_cleanup():
    """Calling the decorated function leaves the number of open pyplot figures unchanged."""
    TEMP_DIR.mkdir(parents=True, exist_ok=True)
    open_before = len(plt.get_fignums())

    @matplotlib_figure_saver
    def plot_matrix(attn_matrix, ax):
        ax.matshow(attn_matrix)

    plot_matrix(np.random.rand(5, 5), TEMP_DIR)

    # No figure objects should remain after the save.
    open_after = len(plt.get_fignums())
    assert open_after == open_before, "Figure not properly cleaned up"
def test_matrix_as_svg_non_numeric():
    """A string-valued array is rejected with a TypeError."""
    letters = np.array([list("ab"), list("cd")])
    with pytest.raises(TypeError):
        matrix_as_svg(letters)
def test_matrix_as_svg_format():
    """`matrix_as_svg` returns a bare <svg> element whose data URI holds a valid PNG."""
    # create a small 2x2 matrix
    matrix = np.array([[0.0, 0.5], [1.0, 0.75]], dtype=float)

    svg_str = matrix_as_svg(matrix)

    # ensure it's got the correct SVG wrapper
    assert svg_str.startswith("<svg"), "SVG should start with <svg>"
    assert svg_str.endswith("</svg>"), "SVG should end with </svg>"

    # find the embedded base64 image data
    match = re.search(r'data:image/png;base64,([^"]+)', svg_str)
    assert match, "Expected an embedded PNG in data URI format"

    png_data = base64.b64decode(match.group(1))

    # Image.open only parses the header lazily. verify() checks the whole
    # stream, so a truncated or corrupt PNG actually fails here.
    with Image.open(io.BytesIO(png_data)) as img:
        assert img.format == "PNG", f"Embedded image is {img.format}, not PNG"
        img.verify()