Coverage for tests / integration / test_clis.py: 100%

52 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-22 18:15 -0700

1import sys 

2from unittest import mock 

3 

4import pytest 

5 

6from pattern_lens.activations import main as activations_main 

7from pattern_lens.figures import main as figures_main 

8from pattern_lens.server import main as server_main 

9 

10 

def test_activations_cli():
    """Verify the activations CLI forwards every parsed flag to ``activations_main``."""
    cli_argv = [
        "pattern_lens.activations",
        "--model",
        "gpt2",
        "--prompts",
        "test_prompts.jsonl",
        "--save-path",
        "test_data",
        "--min-chars",
        "100",
        "--max-chars",
        "1000",
        "--n-samples",
        "5",
        "--batch-size",
        "16",
        "--force",
        "--raw-prompts",
        "--shuffle",
        "--device",
        "cpu",
    ]

    # Keyword arguments the CLI is expected to pass on, compared by equality.
    expected_values = {
        "model_name": "gpt2",
        "prompts_path": "test_prompts.jsonl",
        "save_path": "test_data",
        "min_chars": 100,
        "max_chars": 1000,
        "n_samples": 5,
        "device": "cpu",
        "batch_size": 16,
    }
    # Switches that must arrive as the literal ``True`` (identity, not truthiness).
    expected_flags = ("force", "raw_prompts", "shuffle")

    with (
        mock.patch.object(sys, "argv", cli_argv),
        mock.patch(
            "pattern_lens.activations.activations_main",
        ) as patched_main,
        # SpinnerContext is patched out so no real spinner runs during the test.
        mock.patch("pattern_lens.activations.SpinnerContext"),
    ):
        activations_main()

    # The CLI entry point must delegate exactly once, using keyword arguments.
    patched_main.assert_called_once()
    passed = patched_main.call_args.kwargs

    for key, value in expected_values.items():
        assert passed[key] == value
    for flag in expected_flags:
        assert passed[flag] is True

61 

62 

def test_figures_cli():
    """Verify the figures CLI hands its parsed options to ``figures_main``."""
    cli_argv = [
        "pattern_lens.figures",
        "--model",
        "gpt2",
        "--save-path",
        "test_data",
        "--n-samples",
        "5",
        "--force",
        "True",
    ]

    with (
        mock.patch.object(sys, "argv", cli_argv),
        mock.patch("pattern_lens.figures.figures_main") as patched_main,
        # SpinnerContext is patched out so no real spinner runs during the test.
        mock.patch("pattern_lens.figures.SpinnerContext"),
    ):
        figures_main()

    # A single model name on the command line means a single delegated call.
    patched_main.assert_called_once()
    passed = patched_main.call_args.kwargs

    expected_values = {
        "model_name": "gpt2",
        "save_path": "test_data",
        "n_samples": 5,
    }
    for key, value in expected_values.items():
        assert passed[key] == value
    # ``force`` must be the literal ``True``, not merely a truthy value.
    assert passed["force"] is True

93 

94 

def test_figures_cli_with_multiple_models():
    """Verify a comma-separated ``--model`` value yields one ``figures_main`` call per model."""
    cli_argv = [
        "pattern_lens.figures",
        "--model",
        "gpt2,pythia-70m",
        "--save-path",
        "test_data",
    ]

    with (
        mock.patch.object(sys, "argv", cli_argv),
        mock.patch("pattern_lens.figures.figures_main") as patched_main,
        # SpinnerContext is patched out so no real spinner runs during the test.
        mock.patch("pattern_lens.figures.SpinnerContext"),
    ):
        figures_main()

    # One delegated call per model listed on the command line.
    assert patched_main.call_count == 2

    # Calls must arrive in the order the models were listed: gpt2, then pythia-70m.
    requested_models = [
        recorded.kwargs["model_name"] for recorded in patched_main.call_args_list
    ]
    assert requested_models == ["gpt2", "pythia-70m"]

123 

124 

def test_server_cli():
    """Verify the server CLI exits via ``SystemExit`` when serving is interrupted."""
    cli_argv = [
        "pattern_lens.server",
        "--port",
        "8080",
        # "--path",
        # "test_path",
        "--rewrite-index",
    ]

    with (
        mock.patch.object(sys, "argv", cli_argv),
        mock.patch("socketserver.TCPServer") as patched_server,
        # mock.patch("pattern_lens.server.write_html_index") as mock_write_html,
        # mock.patch("os.chdir") as mock_chdir,
    ):
        # The object bound by ``with TCPServer(...) as httpd`` is what __enter__ returns;
        # make its serve_forever raise KeyboardInterrupt so the call never blocks.
        httpd = patched_server.return_value.__enter__.return_value
        httpd.serve_forever.side_effect = KeyboardInterrupt()

        # Run the CLI entry point; the test only requires that it ends in SystemExit.
        with pytest.raises(SystemExit):
            server_main()

        # TODO: make these mock checks work -- I have no idea how to use mock properly
        # Disabled checks, kept for reference:
        #   write_html_index was called:
        #     mock_write_html.assert_called_once()
        #   chdir was called with the right path:
        #     mock_chdir.assert_called_once_with("test_path")
        #   server was started with the right port:
        #     mock_server.assert_called_once_with(("", 8080), mock.ANY)
        #   serve_forever was called:
        #     mock_server_instance.serve_forever.assert_called_once()