Coverage for tests\integration\test_clis.py: 100%
51 statements
coverage.py v7.6.12, created at 2025-03-04 01:50 -0700
1import sys
2from unittest import mock
4import pytest
6from pattern_lens.activations import main as activations_main
7from pattern_lens.figures import main as figures_main
8from pattern_lens.server import main as server_main
def test_activations_cli():
    """Drive the activations CLI entry point and check the forwarded options.

    ``sys.argv`` is replaced with a full set of command-line flags, and the
    inner ``pattern_lens.activations.activations_main`` is mocked, so only
    argument parsing is exercised. The mock must be called exactly once, with
    every option forwarded as a keyword argument of the expected value/type.
    """
    argv = [
        "pattern_lens.activations",
        "--model", "gpt2",
        "--prompts", "test_prompts.jsonl",
        "--save-path", "test_data",
        "--min-chars", "100",
        "--max-chars", "1000",
        "--n-samples", "5",
        "--force",
        "--raw-prompts",
        "--shuffle",
        "--device", "cpu",
    ]

    with (
        mock.patch.object(sys, "argv", argv),
        mock.patch(
            "pattern_lens.activations.activations_main",
        ) as fake_activations_main,
        # keep the progress spinner from actually running during the test
        mock.patch("pattern_lens.activations.SpinnerContext"),
    ):
        activations_main()

    fake_activations_main.assert_called_once()
    received = fake_activations_main.call_args.kwargs

    # options that take a value: numeric ones must arrive already converted
    expected_values = {
        "model_name": "gpt2",
        "prompts_path": "test_prompts.jsonl",
        "save_path": "test_data",
        "min_chars": 100,
        "max_chars": 1000,
        "n_samples": 5,
        "device": "cpu",
    }
    for option, value in expected_values.items():
        assert received[option] == value, option

    # bare switches must produce the actual ``True`` singleton, not just a truthy value
    for switch in ("force", "raw_prompts", "shuffle"):
        assert received[switch] is True, switch
def test_figures_cli():
    """Drive the figures CLI for a single model and check the forwarded options.

    The inner ``pattern_lens.figures.figures_main`` is mocked, so only argument
    parsing is exercised: it must be called exactly once, receiving the model
    name, save path, sample count and force flag as keyword arguments.
    """
    argv = [
        "pattern_lens.figures",
        "--model", "gpt2",
        "--save-path", "test_data",
        "--n-samples", "5",
        # NOTE(review): --force is given an explicit value here. If the
        # figures parser declares it with ``type=bool``, any non-empty string
        # (including "False") would parse as True -- confirm against the parser.
        "--force", "True",
    ]

    with (
        mock.patch.object(sys, "argv", argv),
        mock.patch("pattern_lens.figures.figures_main") as fake_figures_main,
        # keep the progress spinner from actually running during the test
        mock.patch("pattern_lens.figures.SpinnerContext"),
    ):
        figures_main()

    fake_figures_main.assert_called_once()
    received = fake_figures_main.call_args.kwargs
    assert received["model_name"] == "gpt2"
    assert received["save_path"] == "test_data"
    assert received["n_samples"] == 5
    assert received["force"] is True
def test_figures_cli_with_multiple_models():
    """Check that a comma-separated ``--model`` list runs figures once per model.

    With ``figures_main`` mocked, the CLI must invoke it once for each model
    name, in the same order the names appear on the command line.
    """
    argv = [
        "pattern_lens.figures",
        "--model", "gpt2,pythia-70m",
        "--save-path", "test_data",
    ]

    with (
        mock.patch.object(sys, "argv", argv),
        mock.patch("pattern_lens.figures.figures_main") as fake_figures_main,
        # keep the progress spinner from actually running during the test
        mock.patch("pattern_lens.figures.SpinnerContext"),
    ):
        figures_main()

    assert fake_figures_main.call_count == 2
    models_in_call_order = [
        recorded.kwargs["model_name"]
        for recorded in fake_figures_main.call_args_list
    ]
    assert models_in_call_order == ["gpt2", "pythia-70m"]
def test_server_cli():
    """Start the server CLI against a mocked TCP server and expect an exit.

    ``socketserver.TCPServer`` is patched so no real socket is bound. The
    object produced by entering the mocked server's context manager raises
    ``KeyboardInterrupt`` from ``serve_forever`` to simulate the user stopping
    the server, and the test expects ``server_main`` to end with ``SystemExit``.
    """
    # "--path test_path" is deliberately not passed; it was disabled along
    # with the os.chdir check listed in the TODO at the end of this test.
    argv = ["pattern_lens.server", "--port", "8080", "--rewrite-index"]

    with (
        mock.patch.object(sys, "argv", argv),
        mock.patch("socketserver.TCPServer") as tcp_server_cls,
    ):
        # presumably the server uses ``with TCPServer(...) as httpd`` --
        # configure the object bound by ``as`` so serving stops immediately
        fake_server = tcp_server_cls.return_value.__enter__.return_value
        fake_server.serve_forever.side_effect = KeyboardInterrupt()

        # NOTE(review): if server_main parses arguments with argparse, a usage
        # error also raises SystemExit, so this check alone cannot distinguish
        # an argument-parsing failure from the intended shutdown path.
        with pytest.raises(SystemExit):
            server_main()

    # TODO: get these mock checks working (they were disabled):
    #   - write_html_index called once
    #     (patch "pattern_lens.server.write_html_index")
    #   - os.chdir called once with "test_path" (requires passing --path)
    #   - TCPServer constructed as TCPServer(("", 8080), <handler>)
    #   - serve_forever called exactly once