Coverage for tests/integration/test_clis.py: 100%

51 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-04-06 15:09 -0600

1import sys 

2from unittest import mock 

3 

4import pytest 

5 

6from pattern_lens.activations import main as activations_main 

7from pattern_lens.figures import main as figures_main 

8from pattern_lens.server import main as server_main 

9 

10 

def test_activations_cli():
    """Check that the activations CLI forwards every parsed option to ``activations_main``.

    ``sys.argv`` is replaced with a full set of flags, and the underlying
    ``activations_main`` and ``SpinnerContext`` are mocked out. The keyword
    arguments of the single resulting call are then compared, in order,
    against the values implied by those flags.
    """
    argv = [
        "pattern_lens.activations",
        "--model", "gpt2",
        "--prompts", "test_prompts.jsonl",
        "--save-path", "test_data",
        "--min-chars", "100",
        "--max-chars", "1000",
        "--n-samples", "5",
        "--force",
        "--raw-prompts",
        "--shuffle",
        "--device", "cpu",
    ]

    # Expected keyword arguments, in the order they are checked.
    # Entries whose expected value is the `True` singleton are checked with
    # `is`; everything else is checked with `==`.
    expected = {
        "model_name": "gpt2",
        "prompts_path": "test_prompts.jsonl",
        "save_path": "test_data",
        "min_chars": 100,
        "max_chars": 1000,
        "n_samples": 5,
        "force": True,
        "raw_prompts": True,
        "shuffle": True,
        "device": "cpu",
    }

    with (
        mock.patch.object(sys, "argv", argv),
        mock.patch("pattern_lens.activations.activations_main") as fake_main,
        # Keep the real spinner from running during the test.
        mock.patch("pattern_lens.activations.SpinnerContext"),
    ):
        activations_main()

    fake_main.assert_called_once()
    call_kwargs = fake_main.call_args.kwargs

    for key, want in expected.items():
        got = call_kwargs[key]
        if want is True:
            assert got is True
        else:
            assert got == want

58 

59 

def test_figures_cli():
    """Check that the figures CLI passes model, save path, sample count and force to ``figures_main``.

    ``figures_main`` and ``SpinnerContext`` inside ``pattern_lens.figures`` are
    mocked, so only the argument parsing and forwarding are exercised.
    """
    cli_argv = ["pattern_lens.figures"]
    cli_argv += ["--model", "gpt2"]
    cli_argv += ["--save-path", "test_data"]
    cli_argv += ["--n-samples", "5"]
    # NOTE(review): unlike the activations test, `--force` is followed by an
    # explicit "True" value here. Presumably the figures parser declares
    # --force as taking a value -- confirm against pattern_lens.figures.
    cli_argv += ["--force", "True"]

    with (
        mock.patch.object(sys, "argv", cli_argv),
        mock.patch("pattern_lens.figures.figures_main") as figures_mock,
        # Keep the real spinner from running during the test.
        mock.patch("pattern_lens.figures.SpinnerContext"),
    ):
        figures_main()

    figures_mock.assert_called_once()
    received = figures_mock.call_args.kwargs

    assert received["model_name"] == "gpt2"
    assert received["save_path"] == "test_data"
    assert received["n_samples"] == 5
    assert received["force"] is True

90 

91 

def test_figures_cli_with_multiple_models():
    """Check that a comma-separated ``--model`` list yields one ``figures_main`` call per model.

    The calls must arrive in the same order the models were listed on the
    command line.
    """
    models = ["gpt2", "pythia-70m"]
    argv = [
        "pattern_lens.figures",
        "--model",
        ",".join(models),  # -> "gpt2,pythia-70m"
        "--save-path",
        "test_data",
    ]

    with (
        mock.patch.object(sys, "argv", argv),
        mock.patch("pattern_lens.figures.figures_main") as mock_figures,
        # Keep the real spinner from running during the test.
        mock.patch("pattern_lens.figures.SpinnerContext"),
    ):
        figures_main()

    # One call per listed model...
    assert mock_figures.call_count == len(models)

    # ...each one for the model at the matching position.
    for (_, call_kwargs), expected_model in zip(mock_figures.call_args_list, models):
        assert call_kwargs["model_name"] == expected_model

120 

121 

def test_server_cli():
    """Test the server command line interface.

    Runs the server's ``main`` with a patched ``sys.argv`` and a mocked
    ``socketserver.TCPServer``. The mocked server's ``serve_forever`` raises
    ``KeyboardInterrupt``, and the test expects the CLI to end with
    ``SystemExit``.
    """
    test_args = [
        "pattern_lens.server",
        "--port",
        "8080",
        # "--path",
        # "test_path",
        "--rewrite-index",
    ]

    with (
        mock.patch.object(sys, "argv", test_args),
        mock.patch("socketserver.TCPServer") as mock_server,
        # NOTE(review): with the write_html_index patch below commented out,
        # `--rewrite-index` may write a real index file during the test -- confirm.
        # mock.patch("pattern_lens.server.write_html_index") as mock_write_html,
        # mock.patch("os.chdir") as mock_chdir,
    ):
        # Make the context-managed server instance raise KeyboardInterrupt as
        # soon as serve_forever is called, so the test never blocks.
        mock_server_instance = mock_server.return_value.__enter__.return_value
        mock_server_instance.serve_forever.side_effect = KeyboardInterrupt()

        with pytest.raises(SystemExit) as exc_info:
            server_main()

    # argparse reports usage errors (e.g. an unrecognized flag) by raising
    # SystemExit with status 2. Without this check, a broken argument parser
    # would satisfy `pytest.raises(SystemExit)` on its own, and the test would
    # pass without ever reaching the server.
    assert exc_info.value.code != 2, "server CLI exited with an argparse usage error"

    # TODO: re-enable the intended mock checks below once the patch targets /
    # call shapes are confirmed against pattern_lens.server.
    # # Check that write_html_index was called
    # mock_write_html.assert_called_once()
    #
    # # Check that chdir was called with the right path
    # mock_chdir.assert_called_once_with("test_path")
    #
    # # Check that server was started with the right port
    # mock_server.assert_called_once_with(("", 8080), mock.ANY)
    #
    # # Check that serve_forever was called
    # mock_server_instance.serve_forever.assert_called_once()