Coverage for tests/test_ssh_agent_client.py: 100.000%
166 statements
« prev ^ index » next coverage.py v7.6.0, created at 2024-07-14 11:39 +0200
« prev ^ index » next coverage.py v7.6.0, created at 2024-07-14 11:39 +0200
1# SPDX-FileCopyrightText: 2024 Marco Ricci <m@the13thletter.info>
2#
3# SPDX-License-Identifier: MIT
5"""Test OpenSSH key loading and signing."""
7from __future__ import annotations
9import base64
10import io
11import os
12import socket
13import subprocess
14from typing_extensions import Any
16import click
17import click.testing
18import derivepassphrase
19import derivepassphrase.cli
20import pytest
21import ssh_agent_client
22import tests
class TestStaticFunctionality:
    """Tests of SSHAgentClient parts that need no running SSH agent.

    Covers the recorded test key data, the constructor's behaviour when
    ``SSH_AUTH_SOCK`` is unset, and the static wire-format helpers
    ``uint32``, ``string``, ``unstring`` and ``unstring_prefix``.
    """

    @pytest.mark.parametrize(['public_key', 'public_key_data'],
                             [(val['public_key'], val['public_key_data'])
                              for val in tests.SUPPORTED_KEYS.values()])
    def test_100_key_decoding(self, public_key, public_key_data):
        """The recorded key blob equals the base64 field of the key line."""
        # Split on whitespace at most twice: "<type> <base64 blob> [comment]";
        # the second field is the base64-encoded key blob.
        keydata = base64.b64decode(public_key.split(None, 2)[1])
        assert (
            keydata == public_key_data
        ), "recorded public key data doesn't match"

    def test_200_constructor_no_running_agent(self, monkeypatch):
        """Constructing a client without SSH_AUTH_SOCK raises KeyError."""
        monkeypatch.delenv('SSH_AUTH_SOCK', raising=False)
        # Use the socket as a context manager so it is closed even though
        # the constructor fails; otherwise it leaks until garbage
        # collection (ResourceWarning: unclosed socket).
        with socket.socket(family=socket.AF_UNIX) as sock:
            with pytest.raises(KeyError,
                               match='SSH_AUTH_SOCK environment variable'):
                ssh_agent_client.SSHAgentClient(socket=sock)

    @pytest.mark.parametrize(['input', 'expected'], [
        (16777216, b'\x01\x00\x00\x00'),
    ])
    def test_210_uint32(self, input, expected):
        """uint32 encodes an integer as 4 big-endian bytes."""
        uint32 = ssh_agent_client.SSHAgentClient.uint32
        assert uint32(input) == expected

    @pytest.mark.parametrize(['input', 'expected'], [
        (b'ssh-rsa', b'\x00\x00\x00\x07ssh-rsa'),
        (b'ssh-ed25519', b'\x00\x00\x00\x0bssh-ed25519'),
        (
            ssh_agent_client.SSHAgentClient.string(b'ssh-ed25519'),
            b'\x00\x00\x00\x0f\x00\x00\x00\x0bssh-ed25519',
        ),
    ])
    def test_211_string(self, input, expected):
        """string prefixes the payload with its length as a uint32.

        The last case checks that an already-encoded string is itself
        accepted as a payload and wrapped again.
        """
        string = ssh_agent_client.SSHAgentClient.string
        assert bytes(string(input)) == expected

    @pytest.mark.parametrize(['input', 'expected'], [
        (b'\x00\x00\x00\x07ssh-rsa', b'ssh-rsa'),
        (
            ssh_agent_client.SSHAgentClient.string(b'ssh-ed25519'),
            b'ssh-ed25519',
        ),
    ])
    def test_212_unstring(self, input, expected):
        """unstring strips the length prefix.

        unstring_prefix additionally returns the remainder after the
        string, which is empty for these well-formed inputs.
        """
        unstring = ssh_agent_client.SSHAgentClient.unstring
        unstring_prefix = ssh_agent_client.SSHAgentClient.unstring_prefix
        assert bytes(unstring(input)) == expected
        assert tuple(
            bytes(x) for x in unstring_prefix(input)
        ) == (expected, b'')

    @pytest.mark.parametrize(['value', 'exc_type', 'exc_pattern'], [
        (10000000000000000, OverflowError, 'int too big to convert'),
        (-1, OverflowError, "can't convert negative int to unsigned"),
    ])
    def test_310_uint32_exceptions(self, value, exc_type, exc_pattern):
        """uint32 rejects values that are too large or negative."""
        uint32 = ssh_agent_client.SSHAgentClient.uint32
        with pytest.raises(exc_type, match=exc_pattern):
            uint32(value)

    @pytest.mark.parametrize(['input', 'exc_type', 'exc_pattern'], [
        ('some string', TypeError, 'invalid payload type'),
    ])
    def test_311_string_exceptions(self, input, exc_type, exc_pattern):
        """string rejects a ``str`` payload with a TypeError."""
        string = ssh_agent_client.SSHAgentClient.string
        with pytest.raises(exc_type, match=exc_pattern):
            string(input)

    @pytest.mark.parametrize(
        ['input', 'exc_type', 'exc_pattern', 'has_trailer', 'parts'], [
            (b'ssh', ValueError, 'malformed SSH byte string', False, None),
            (
                b'\x00\x00\x00\x08ssh-rsa',
                ValueError, 'malformed SSH byte string',
                False, None,
            ),
            (
                b'\x00\x00\x00\x04XXX trailing text',
                ValueError, 'malformed SSH byte string',
                True, (b'XXX ', b'trailing text'),
            ),
        ])
    def test_312_unstring_exceptions(self, input, exc_type, exc_pattern,
                                     has_trailer, parts):
        """Malformed SSH strings are rejected.

        The cases are: a too-short input, a declared length exceeding the
        available data, and a valid string followed by trailing data.
        unstring rejects all three.  unstring_prefix rejects the first two
        but returns trailing data as the remainder instead of failing.
        """
        unstring = ssh_agent_client.SSHAgentClient.unstring
        unstring_prefix = ssh_agent_client.SSHAgentClient.unstring_prefix
        with pytest.raises(exc_type, match=exc_pattern):
            unstring(input)
        if has_trailer:
            assert tuple(bytes(x) for x in unstring_prefix(input)) == parts
        else:
            with pytest.raises(exc_type, match=exc_pattern):
                unstring_prefix(input)
@tests.skip_if_no_agent
class TestAgentInteraction:
    """Tests that exercise SSHAgentClient against a running SSH agent.

    The class is marked with ``tests.skip_if_no_agent``.  Test keys are
    uploaded to the agent via ``ssh-add`` where needed.  Error paths are
    driven by monkeypatching the client's ``request``/``list_keys``
    methods or its ``_connection`` attribute.
    """

    @pytest.mark.parametrize(['keytype', 'data_dict'],
                             list(tests.SUPPORTED_KEYS.items()))
    def test_200_sign_data_via_agent(self, keytype, data_dict):
        """Suitable keys produce the recorded signature, repeatably.

        Also checks that Vault.phrase_from_key derives the recorded
        passphrase from the key.
        """
        private_key = data_dict['private_key']
        try:
            # Upload the key from stdin ("-"), quietly ("-q"), with a
            # 30-second lifetime ("-t 30").  Only success/failure matters.
            subprocess.run(['ssh-add', '-t', '30', '-q', '-'],
                           input=private_key, check=True,
                           capture_output=True)
        except subprocess.CalledProcessError as e:
            pytest.skip(
                f"uploading test key: {e!r}, stdout={e.stdout!r}, "
                f"stderr={e.stderr!r}"
            )
        else:
            try:
                client = ssh_agent_client.SSHAgentClient()
            except OSError:  # pragma: no cover
                pytest.skip('communication error with the SSH agent')
            with client:
                key_comment_pairs = {bytes(k): bytes(c)
                                     for k, c in client.list_keys()}
                public_key_data = data_dict['public_key_data']
                expected_signature = data_dict['expected_signature']
                derived_passphrase = data_dict['derived_passphrase']
                if public_key_data not in key_comment_pairs:  # pragma: no cover
                    pytest.skip('prerequisite SSH key not loaded')
                signature = bytes(client.sign(
                    payload=derivepassphrase.Vault._UUID, key=public_key_data))
                assert signature == expected_signature, 'SSH signature mismatch'
                # Sign again: a suitable key must sign deterministically.
                signature2 = bytes(client.sign(
                    payload=derivepassphrase.Vault._UUID, key=public_key_data))
                assert signature2 == expected_signature, 'SSH signature mismatch'
                assert (
                    derivepassphrase.Vault.phrase_from_key(public_key_data) ==
                    derived_passphrase
                ), 'derived passphrase mismatch'

    @pytest.mark.parametrize(['keytype', 'data_dict'],
                             list(tests.UNSUITABLE_KEYS.items()))
    def test_201_sign_data_via_agent_unsupported(self, keytype, data_dict):
        """Unsuitable keys sign non-repeatably and are refused for Vault."""
        private_key = data_dict['private_key']
        try:
            # Same upload as in test_200: stdin, quiet, 30-second lifetime.
            subprocess.run(['ssh-add', '-t', '30', '-q', '-'],
                           input=private_key, check=True,
                           capture_output=True)
        except subprocess.CalledProcessError as e:  # pragma: no cover
            pytest.skip(
                f"uploading test key: {e!r}, stdout={e.stdout!r}, "
                f"stderr={e.stderr!r}"
            )
        else:
            try:
                client = ssh_agent_client.SSHAgentClient()
            except OSError:  # pragma: no cover
                pytest.skip('communication error with the SSH agent')
            with client:
                key_comment_pairs = {bytes(k): bytes(c)
                                     for k, c in client.list_keys()}
                public_key_data = data_dict['public_key_data']
                # Read for parity with the recorded test data; two
                # signatures of an unsuitable key are expected to differ,
                # so there is nothing fixed to compare against.
                expected_signature = data_dict['expected_signature']
                if public_key_data not in key_comment_pairs:  # pragma: no cover
                    pytest.skip('prerequisite SSH key not loaded')
                signature = bytes(client.sign(
                    payload=derivepassphrase.Vault._UUID, key=public_key_data))
                signature2 = bytes(client.sign(
                    payload=derivepassphrase.Vault._UUID, key=public_key_data))
                assert signature != signature2, 'SSH signature repeatable?!'
                with pytest.raises(ValueError, match='unsuitable SSH key'):
                    derivepassphrase.Vault.phrase_from_key(public_key_data)

    def _params():
        """Yield ``(key, single)`` pairs for test_210_ssh_key_selector.

        Every supported key is paired with ``single=False``.  The key
        returned by ``tests.list_keys_singleton()`` is additionally paired
        with ``single=True`` if it is one of the supported keys.

        This helper runs once while the class body executes and is then
        deleted, so it is deliberately a plain function.  A
        ``staticmethod`` object is not directly callable before
        Python 3.10.
        """
        for value in tests.SUPPORTED_KEYS.values():
            yield (value['public_key_data'], False)
        singleton_key = tests.list_keys_singleton()[0].key
        for value in tests.SUPPORTED_KEYS.values():
            key = value['public_key_data']
            if key == singleton_key:
                yield (key, True)

    @pytest.mark.parametrize(['key', 'single'], list(_params()))
    def test_210_ssh_key_selector(self, monkeypatch, key, single):
        """The interactive selector returns the key the user chose.

        With a single suitable key, the user confirms with "yes".
        Otherwise the user picks the key's 1-based index from the
        numbered list of suitable keys.
        """
        suitable_keys = {v['public_key_data']
                         for v in tests.SUPPORTED_KEYS.values()}
        if single:
            monkeypatch.setattr(ssh_agent_client.SSHAgentClient,
                                'list_keys', tests.list_keys_singleton)
            user_input = 'yes\n'
            text = 'Use this key? yes\n'
        else:
            monkeypatch.setattr(ssh_agent_client.SSHAgentClient,
                                'list_keys', tests.list_keys)
            keys = [pair.key for pair in tests.list_keys()
                    if pair.key in suitable_keys]
            index = str(1 + keys.index(key))
            n = len(keys)
            user_input = f'{index}\n'
            text = f'Your selection? (1-{n}, leave empty to abort): {index}\n'
        b64_key = base64.standard_b64encode(key).decode('ASCII')

        @click.command()
        def driver():
            selected = derivepassphrase.cli._select_ssh_key()
            click.echo(base64.standard_b64encode(selected).decode('ASCII'))

        runner = click.testing.CliRunner(mix_stderr=True)
        result = runner.invoke(driver, [], input=user_input,
                               catch_exceptions=True)
        assert result.stdout.startswith('Suitable SSH keys:\n'), (
            'missing expected output'
        )
        assert text in result.stdout, 'missing expected output'
        assert (
            result.stdout.endswith(f'\n{b64_key}\n')
        ), 'missing expected output'
        assert result.exit_code == 0, 'driver program failed?!'

    del _params

    def test_300_constructor_bad_running_agent(self, monkeypatch):
        """Pointing SSH_AUTH_SOCK at a bogus path makes the constructor fail."""
        monkeypatch.setenv('SSH_AUTH_SOCK',
                           os.environ['SSH_AUTH_SOCK'] + '~')
        # Close the socket deterministically even though construction fails.
        with socket.socket(family=socket.AF_UNIX) as sock:
            with pytest.raises(OSError):
                ssh_agent_client.SSHAgentClient(socket=sock)

    @pytest.mark.parametrize(['response'], [
        (b'\x00\x00',),
        (b'\x00\x00\x00\x1f some bytes missing',),
    ])
    def test_310_truncated_server_response(self, monkeypatch, response):
        """A truncated agent reply makes request() raise EOFError."""
        client = ssh_agent_client.SSHAgentClient()
        response_stream = io.BytesIO(response)

        class PseudoSocket:
            # Minimal socket stand-in: discards outgoing data and serves
            # the canned (truncated) response from recv().
            def sendall(self, *args: Any, **kwargs: Any) -> Any:
                return None

            def recv(self, *args: Any, **kwargs: Any) -> Any:
                return response_stream.read(*args, **kwargs)

        pseudo_socket = PseudoSocket()
        with client:
            # Scope the patch so the real connection is back in place
            # before the client's context manager closes it.
            with monkeypatch.context() as mp:
                mp.setattr(client, '_connection', pseudo_socket)
                with pytest.raises(EOFError):
                    client.request(255, b'')

    @tests.skip_if_no_agent
    @pytest.mark.parametrize(
        ['response_code', 'response', 'exc_type', 'exc_pattern'],
        [
            (255, b'', RuntimeError, 'error return from SSH agent:'),
            (12, b'\x00\x00\x00\x01', EOFError, 'truncated response'),
            (
                12,
                b'\x00\x00\x00\x00abc',
                ssh_agent_client.TrailingDataError,
                'overlong response',
            ),
        ]
    )
    def test_320_list_keys_error_responses(self, monkeypatch, response_code,
                                           response, exc_type, exc_pattern):
        """list_keys() reports error, truncated and overlong replies."""
        client = ssh_agent_client.SSHAgentClient()
        with client:
            monkeypatch.setattr(client, 'request',
                                lambda *a, **kw: (response_code, response))
            with pytest.raises(exc_type, match=exc_pattern):
                client.list_keys()

    @tests.skip_if_no_agent
    @pytest.mark.parametrize(
        ['key', 'check', 'response', 'exc_type', 'exc_pattern'],
        [
            (
                b'invalid-key',
                True,
                (255, b''),
                KeyError,
                'target SSH key not loaded into agent',
            ),
            (
                tests.SUPPORTED_KEYS['ed25519']['public_key_data'],
                True,
                (255, b''),
                RuntimeError,
                'signing data failed:',
            )
        ]
    )
    def test_330_sign_error_responses(self, monkeypatch, key, check,
                                      response, exc_type, exc_pattern):
        """sign() reports unknown keys and agent-side signing failures.

        ``list_keys`` is stubbed to report all supported keys as loaded,
        and ``request`` is stubbed to return the given error response.
        """
        client = ssh_agent_client.SSHAgentClient()
        with client:
            monkeypatch.setattr(client, 'request', lambda a, b: response)
            KeyCommentPair = ssh_agent_client.types.KeyCommentPair
            loaded_keys = [KeyCommentPair(v['public_key_data'], b'no comment')
                           for v in tests.SUPPORTED_KEYS.values()]
            monkeypatch.setattr(client, 'list_keys', lambda: loaded_keys)
            with pytest.raises(exc_type, match=exc_pattern):
                client.sign(key, b'abc', check_if_key_loaded=check)