Coverage for tests/unit/test_multifig.py: 98%
64 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-22 18:15 -0700
1from pathlib import Path
3import matplotlib.pyplot as plt
4import numpy as np
6from pattern_lens.figure_util import matplotlib_multifigure_saver
# Root directory under which each test creates its own output subdirectory.
TEMP_DIR: Path = Path("tests") / ".temp"
def test_matplotlib_multifigure_saver():
    """Test the matplotlib_multifigure_saver decorator with its default format.

    The decorated function receives a dict of axes keyed by the names passed
    to the decorator. The test expects that calling it with a directory writes
    one ``<function name>.<axis name>.svgz`` file per name, and that the
    wrapper exposes ``figure_save_fmt == "svgz"``.
    """
    # Create a temporary directory for saving figures
    temp_dir = TEMP_DIR / "test_matplotlib_multifigure_saver"
    temp_dir.mkdir(parents=True, exist_ok=True)

    hist_file = temp_dir / "multi_plot.hist.svgz"
    heatmap_file = temp_dir / "multi_plot.heatmap.svgz"
    # TEMP_DIR persists between runs; delete leftovers so the existence checks
    # below can only pass if this run actually wrote the files.
    for stale_file in (hist_file, heatmap_file):
        stale_file.unlink(missing_ok=True)

    # Define a test function with the decorator
    @matplotlib_multifigure_saver(names=["hist", "heatmap"])
    def multi_plot(attn_matrix, axes_dict):
        # Plot histogram
        axes_dict["hist"].hist(attn_matrix.flatten(), bins=30)
        axes_dict["hist"].set_title("Attention Values Histogram")

        # Plot heatmap
        axes_dict["heatmap"].matshow(attn_matrix, cmap="viridis")
        axes_dict["heatmap"].set_title("Attention Heatmap")

    # Generate a test attention matrix
    attn_matrix = np.random.rand(10, 10).astype(np.float32)

    # Call the decorated function; the second argument is the save directory
    multi_plot(attn_matrix, temp_dir)

    # Check that both figures were saved
    assert hist_file.exists(), "Histogram figure file was not saved"
    assert heatmap_file.exists(), "Heatmap figure file was not saved"

    # Check that figure_save_fmt attribute was set correctly
    assert hasattr(multi_plot, "figure_save_fmt")
    assert multi_plot.figure_save_fmt == "svgz"
def test_matplotlib_multifigure_saver_custom_format():
    """Test the matplotlib_multifigure_saver decorator with a custom format.

    With ``fmt="png"`` the test expects one ``<function name>.<axis name>.png``
    file per axis name, and ``figure_save_fmt == "png"`` on the wrapper.
    """
    # Create a temporary directory for saving figures
    temp_dir = TEMP_DIR / "test_matplotlib_multifigure_saver_custom_format"
    temp_dir.mkdir(parents=True, exist_ok=True)

    plot1_file = temp_dir / "multi_plot_png.plot1.png"
    plot2_file = temp_dir / "multi_plot_png.plot2.png"
    # TEMP_DIR persists between runs; delete leftovers so the existence checks
    # below can only pass if this run actually wrote the files.
    for stale_file in (plot1_file, plot2_file):
        stale_file.unlink(missing_ok=True)

    # Define a test function with the decorator and custom format
    @matplotlib_multifigure_saver(names=["plot1", "plot2"], fmt="png")
    def multi_plot_png(attn_matrix, axes_dict):
        # Simple plots for each axis
        axes_dict["plot1"].plot(attn_matrix.mean(axis=0))
        axes_dict["plot1"].set_title("Mean by Column")

        axes_dict["plot2"].plot(attn_matrix.mean(axis=1))
        axes_dict["plot2"].set_title("Mean by Row")

    # Generate a test attention matrix
    attn_matrix = np.random.rand(10, 10).astype(np.float32)

    # Call the decorated function; the second argument is the save directory
    multi_plot_png(attn_matrix, temp_dir)

    # Check that both figures were saved with the custom format
    assert plot1_file.exists(), "Plot1 figure file was not saved"
    assert plot2_file.exists(), "Plot2 figure file was not saved"

    # Check that figure_save_fmt attribute was set correctly
    assert hasattr(multi_plot_png, "figure_save_fmt")
    assert multi_plot_png.figure_save_fmt == "png"
def test_matplotlib_multifigure_saver_error_handling():
    """Test error handling in the matplotlib_multifigure_saver decorator.

    The test expects an exception raised inside the decorated function to
    propagate to the caller unchanged. It also expects that no figure file is
    written, not even for an axis that was drawn before the error occurred.
    """
    # Create a temporary directory for saving figures
    temp_dir = TEMP_DIR / "test_matplotlib_multifigure_saver_error_handling"
    temp_dir.mkdir(parents=True, exist_ok=True)

    good_plot_file = temp_dir / "plot_with_error.good_plot.svgz"
    # TEMP_DIR persists between runs; delete any leftover so the "not saved"
    # check below reflects only this run.
    good_plot_file.unlink(missing_ok=True)

    # Define a test function that raises an error for one of the plots
    @matplotlib_multifigure_saver(names=["good_plot", "error_plot"])
    def plot_with_error(attn_matrix, axes_dict):
        # This plot should work fine
        axes_dict["good_plot"].matshow(attn_matrix)

        # This plot will raise an error (division by zero)
        axes_dict["error_plot"].plot(1 / 0)

    # Generate a test attention matrix
    attn_matrix = np.random.rand(10, 10).astype(np.float32)

    # Call the decorated function and expect the error to propagate
    try:
        plot_with_error(attn_matrix, temp_dir)
    except ZeroDivisionError:
        pass
    else:
        raise AssertionError("Expected ZeroDivisionError but no exception was raised")

    # The error aborted the call, so the already-drawn figure must NOT be saved
    assert not good_plot_file.exists(), "Good plot file was saved before the error"
def test_matplotlib_multifigure_saver_cleanup():
    """Verify that the decorator leaves no additional matplotlib figures open.

    The number of open figures (per ``plt.get_fignums()``) must be the same
    before and after one call of a decorated two-axis plotting function.
    """
    out_dir = TEMP_DIR / "test_matplotlib_multifigure_saver_cleanup"
    out_dir.mkdir(parents=True, exist_ok=True)

    # Baseline: figures already open before the decorated call
    figures_before = len(plt.get_fignums())

    @matplotlib_multifigure_saver(names=["plot1", "plot2"])
    def multi_plot_cleanup(attn_matrix, axes_dict):
        # Row 0 goes to "plot1", row 1 to "plot2"
        for row_idx, axis_name in enumerate(("plot1", "plot2")):
            axes_dict[axis_name].plot(attn_matrix[row_idx])

    data = np.random.rand(10, 10).astype(np.float32)
    multi_plot_cleanup(data, out_dir)

    figures_after = len(plt.get_fignums())
    assert figures_after == figures_before, "Figures were not properly closed"