Coverage for tests/integration/test_clis.py: 100%
52 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-22 18:15 -0700
1import sys
2from unittest import mock
4import pytest
6from pattern_lens.activations import main as activations_main
7from pattern_lens.figures import main as figures_main
8from pattern_lens.server import main as server_main
def test_activations_cli():
    """Exercise the activations CLI entry point with a full set of flags.

    In this test module ``activations_main`` is ``pattern_lens.activations.main``
    (imported under an alias). The module-level function
    ``pattern_lens.activations.activations_main`` is replaced by a mock. The
    test therefore checks only how the command line is parsed and which
    keyword arguments are forwarded.
    """
    # Options that take a value, in the order they appear on the command line.
    valued_options = [
        ("--model", "gpt2"),
        ("--prompts", "test_prompts.jsonl"),
        ("--save-path", "test_data"),
        ("--min-chars", "100"),
        ("--max-chars", "1000"),
        ("--n-samples", "5"),
        ("--batch-size", "16"),
    ]
    cli_argv = [
        "pattern_lens.activations",
        *(token for option in valued_options for token in option),
        "--force",
        "--raw-prompts",
        "--shuffle",
        "--device",
        "cpu",
    ]

    with (
        mock.patch.object(sys, "argv", cli_argv),
        mock.patch(
            "pattern_lens.activations.activations_main",
        ) as fake_worker,
        # Replace SpinnerContext so no spinner is drawn during the test.
        mock.patch("pattern_lens.activations.SpinnerContext"),
    ):
        activations_main()

    fake_worker.assert_called_once()
    forwarded = fake_worker.call_args.kwargs

    # Numeric options are compared as ints, which also checks that the
    # string-to-int conversion happened during parsing.
    expected_values = {
        "model_name": "gpt2",
        "prompts_path": "test_prompts.jsonl",
        "save_path": "test_data",
        "min_chars": 100,
        "max_chars": 1000,
        "n_samples": 5,
        "device": "cpu",
        "batch_size": 16,
    }
    for key, expected in expected_values.items():
        assert forwarded[key] == expected

    # Bare switches on the command line must arrive as the literal ``True``.
    for switch in ("force", "raw_prompts", "shuffle"):
        assert forwarded[switch] is True
def test_figures_cli():
    """Exercise the figures CLI entry point for a single model.

    ``pattern_lens.figures.figures_main`` is mocked out. The test checks that
    ``figures_main()`` (``pattern_lens.figures.main`` in this module) forwards
    the parsed options as keyword arguments.
    """
    cli_argv = ["pattern_lens.figures"]
    cli_argv += ["--model", "gpt2"]
    cli_argv += ["--save-path", "test_data"]
    cli_argv += ["--n-samples", "5"]
    # Unlike the activations CLI, --force is followed by an explicit value
    # here; presumably the figures parser accepts one. TODO confirm.
    cli_argv += ["--force", "True"]

    with (
        mock.patch.object(sys, "argv", cli_argv),
        mock.patch("pattern_lens.figures.figures_main") as fake_figures,
        # Replace SpinnerContext so no spinner is drawn during the test.
        mock.patch("pattern_lens.figures.SpinnerContext"),
    ):
        figures_main()

    fake_figures.assert_called_once()
    forwarded = fake_figures.call_args.kwargs

    assert forwarded["model_name"] == "gpt2"
    assert forwarded["save_path"] == "test_data"
    assert forwarded["n_samples"] == 5
    assert forwarded["force"] is True
def test_figures_cli_with_multiple_models():
    """A comma-separated ``--model`` value yields one mocked call per model.

    The calls are expected in the same order as the models are listed on the
    command line.
    """
    expected_models = ("gpt2", "pythia-70m")
    cli_argv = [
        "pattern_lens.figures",
        "--model",
        "gpt2,pythia-70m",
        "--save-path",
        "test_data",
    ]

    with (
        mock.patch.object(sys, "argv", cli_argv),
        mock.patch("pattern_lens.figures.figures_main") as fake_figures,
        # Replace SpinnerContext so no spinner is drawn during the test.
        mock.patch("pattern_lens.figures.SpinnerContext"),
    ):
        figures_main()

    assert fake_figures.call_count == len(expected_models)

    # Pair each recorded call with the model it should have been made for.
    for recorded, model in zip(fake_figures.call_args_list, expected_models):
        assert recorded.kwargs["model_name"] == model
def test_server_cli():
    """Run the server CLI with ``socketserver.TCPServer`` replaced by a mock.

    The mocked server's ``serve_forever`` raises ``KeyboardInterrupt``. The
    test only asserts that ``server_main()`` then ends with ``SystemExit``.
    """
    # A "--path test_path" option was intentionally left out of this command
    # line (it was commented out in an earlier version of the test).
    cli_argv = ["pattern_lens.server", "--port", "8080", "--rewrite-index"]

    with (
        mock.patch.object(sys, "argv", cli_argv),
        mock.patch("socketserver.TCPServer") as tcp_server_cls,
    ):
        # The interrupt is configured on the object returned by ``__enter__``.
        # This assumes the server is used as a context manager; confirm
        # against pattern_lens.server.
        running_server = tcp_server_cls.return_value.__enter__.return_value
        running_server.serve_forever.side_effect = KeyboardInterrupt()

        # NOTE(review): an argparse usage error also raises SystemExit
        # (code 2), so this check alone cannot tell a clean shutdown from a
        # rejected command line. Consider asserting on excinfo.value.code
        # once the expected exit code is confirmed.
        with pytest.raises(SystemExit):
            server_main()

    # TODO: the intended mock checks have never been made to pass:
    #   - pattern_lens.server.write_html_index is called exactly once;
    #   - os.chdir is called once with the --path value;
    #   - TCPServer is constructed as TCPServer(("", 8080), <handler>);
    #   - serve_forever is called exactly once on the entered server.