Coverage for tests\unit\test_multifig.py: 98%

64 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-04 01:50 -0700

1from pathlib import Path 

2 

3import matplotlib.pyplot as plt 

4import numpy as np 

5 

6from pattern_lens.figure_util import matplotlib_multifigure_saver 

7 

# Relative root (resolved against the current working directory) under which
# each test below creates its own output subdirectory for saved figures.
TEMP_DIR: Path = Path("tests", "_temp")

9 

10 

def test_matplotlib_multifigure_saver():
    """Test the matplotlib_multifigure_saver decorator.

    Decorates a function that draws on two named axes, calls it with an output
    directory, and checks that one ``.svgz`` file per figure name was written
    and that the decorator recorded ``"svgz"`` as ``figure_save_fmt`` on the
    wrapped function.
    """
    # Create a temporary directory for saving figures
    temp_dir = TEMP_DIR / "test_matplotlib_multifigure_saver"
    temp_dir.mkdir(parents=True, exist_ok=True)

    # Files the decorator is expected to produce ("<func>.<name>.<fmt>")
    hist_file = temp_dir / "multi_plot.hist.svgz"
    heatmap_file = temp_dir / "multi_plot.heatmap.svgz"

    # The output directory survives between runs (exist_ok=True), so delete any
    # files left by a previous run; otherwise the existence checks below could
    # pass even if the decorator saved nothing on this run.
    for stale_file in (hist_file, heatmap_file):
        stale_file.unlink(missing_ok=True)

    # Define a test function with the decorator
    @matplotlib_multifigure_saver(names=["hist", "heatmap"])
    def multi_plot(attn_matrix, axes_dict):
        # Plot histogram
        axes_dict["hist"].hist(attn_matrix.flatten(), bins=30)
        axes_dict["hist"].set_title("Attention Values Histogram")

        # Plot heatmap
        axes_dict["heatmap"].matshow(attn_matrix, cmap="viridis")
        axes_dict["heatmap"].set_title("Attention Heatmap")

    # Generate a test attention matrix
    attn_matrix = np.random.rand(10, 10).astype(np.float32)

    # Call the decorated function
    multi_plot(attn_matrix, temp_dir)

    # Check that both figures were saved
    assert hist_file.exists(), "Histogram figure file was not saved"
    assert heatmap_file.exists(), "Heatmap figure file was not saved"

    # Check that figure_save_fmt attribute was set correctly
    assert hasattr(multi_plot, "figure_save_fmt")
    assert multi_plot.figure_save_fmt == "svgz"

44 

45 

def test_matplotlib_multifigure_saver_custom_format():
    """Test the matplotlib_multifigure_saver decorator with a custom format.

    Same as the default-format test but passes ``fmt="png"``. It checks that the
    saved files use the ``.png`` extension and that ``figure_save_fmt`` is
    ``"png"``.
    """
    # Create a temporary directory for saving figures
    temp_dir = TEMP_DIR / "test_matplotlib_multifigure_saver_custom_format"
    temp_dir.mkdir(parents=True, exist_ok=True)

    # Files the decorator is expected to produce with the custom format
    plot1_file = temp_dir / "multi_plot_png.plot1.png"
    plot2_file = temp_dir / "multi_plot_png.plot2.png"

    # Remove leftovers from a previous run so the existence checks below only
    # succeed if this call actually wrote the files.
    for stale_file in (plot1_file, plot2_file):
        stale_file.unlink(missing_ok=True)

    # Define a test function with the decorator and custom format
    @matplotlib_multifigure_saver(names=["plot1", "plot2"], fmt="png")
    def multi_plot_png(attn_matrix, axes_dict):
        # Simple plots for each axis
        axes_dict["plot1"].plot(attn_matrix.mean(axis=0))
        axes_dict["plot1"].set_title("Mean by Column")

        axes_dict["plot2"].plot(attn_matrix.mean(axis=1))
        axes_dict["plot2"].set_title("Mean by Row")

    # Generate a test attention matrix
    attn_matrix = np.random.rand(10, 10).astype(np.float32)

    # Call the decorated function
    multi_plot_png(attn_matrix, temp_dir)

    # Check that both figures were saved with the custom format
    assert plot1_file.exists(), "Plot1 figure file was not saved"
    assert plot2_file.exists(), "Plot2 figure file was not saved"

    # Check that figure_save_fmt attribute was set correctly
    assert hasattr(multi_plot_png, "figure_save_fmt")
    assert multi_plot_png.figure_save_fmt == "png"

78 

79 

def test_matplotlib_multifigure_saver_error_handling():
    """Test error handling in the matplotlib_multifigure_saver decorator.

    The decorated function draws one axis successfully and then raises
    ZeroDivisionError. The exception must propagate to the caller, and no
    figure file may be written. That includes the axis that was drawn
    before the error.
    """
    # Create a temporary directory for saving figures
    temp_dir = TEMP_DIR / "test_matplotlib_multifigure_saver_error_handling"
    temp_dir.mkdir(parents=True, exist_ok=True)

    good_plot_file = temp_dir / "plot_with_error.good_plot.svgz"
    error_plot_file = temp_dir / "plot_with_error.error_plot.svgz"

    # Clear leftovers from earlier runs so the "not saved" checks below reflect
    # only what this call did.
    for stale_file in (good_plot_file, error_plot_file):
        stale_file.unlink(missing_ok=True)

    # Define a test function that raises an error for one of the plots
    @matplotlib_multifigure_saver(names=["good_plot", "error_plot"])
    def plot_with_error(attn_matrix, axes_dict):
        # This plot should work fine
        axes_dict["good_plot"].matshow(attn_matrix)

        # This plot will raise an error
        axes_dict["error_plot"].plot(1 / 0)  # Division by zero

    # Generate a test attention matrix
    attn_matrix = np.random.rand(10, 10).astype(np.float32)

    # Call the decorated function and expect the error to propagate. Only the
    # call sits in the try body; the file checks run afterwards so a failing
    # check is not reported as chained onto the ZeroDivisionError.
    try:
        plot_with_error(attn_matrix, temp_dir)
    except ZeroDivisionError:
        pass  # expected
    else:
        raise AssertionError("Expected ZeroDivisionError but no exception was raised")

    # Check that nothing was saved, not even the figure drawn before the error
    assert not good_plot_file.exists(), "Good plot file was saved before the error"
    assert not error_plot_file.exists(), "Error plot file was saved despite the error"

106 

107 

def test_matplotlib_multifigure_saver_cleanup():
    """Test that matplotlib_multifigure_saver properly closes figures.

    Snapshots the number of open pyplot figures, runs a decorated two-figure
    plot, and requires that the open-figure count afterwards is unchanged.
    """
    out_dir = TEMP_DIR / "test_matplotlib_multifigure_saver_cleanup"
    out_dir.mkdir(parents=True, exist_ok=True)

    # How many pyplot figures were open before the decorator ran
    n_open_before = len(plt.get_fignums())

    @matplotlib_multifigure_saver(names=["plot1", "plot2"])
    def multi_plot_cleanup(attn_matrix, axes_dict):
        first_row, second_row = attn_matrix[0], attn_matrix[1]
        axes_dict["plot1"].plot(first_row)
        axes_dict["plot2"].plot(second_row)

    # Random 10x10 float32 input; only the first two rows are plotted
    multi_plot_cleanup(np.random.rand(10, 10).astype(np.float32), out_dir)

    # Every figure the decorator opened must have been closed again
    n_open_after = len(plt.get_fignums())
    assert n_open_after == n_open_before, "Figures were not properly closed"