Coverage for tests/test_key_signing.py: 100.000%

179 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-06-23 19:40 +0200

1# SPDX-FileCopyrightText: 2024 Marco Ricci <m@the13thletter.info> 

2# 

3# SPDX-License-Identifier: MIT 

4 

5"""Test OpenSSH key loading and signing.""" 

6 

7from __future__ import annotations 

8 

9import click 

10import pytest 

11 

12import derivepassphrase 

13import derivepassphrase.cli 

14import ssh_agent_client 

15import tests 

16 

17import base64 

18import errno 

19import io 

20import os 

21import socket 

22import subprocess 

23 

def test_client_uint32():
    """Check that `uint32` encodes 16777216 as the bytes 01 00 00 00."""
    encode = ssh_agent_client.SSHAgentClient.uint32
    encoded = encode(16777216)
    assert encoded == b'\x01\x00\x00\x00'

27 

@pytest.mark.parametrize(
    ['value', 'exc_type', 'exc_pattern'],
    [
        (10000000000000000, OverflowError, 'int too big to convert'),
        (-1, OverflowError, "can't convert negative int to unsigned"),
    ],
)
def test_client_uint32_exceptions(value, exc_type, exc_pattern):
    """Check that `uint32` rejects values it cannot encode.

    Each case supplies an out-of-range integer together with the
    exception type and message pattern the encoder is expected to raise.
    """
    encode = ssh_agent_client.SSHAgentClient.uint32
    with pytest.raises(exc_type, match=exc_pattern):
        encode(value)

36 

@pytest.mark.parametrize(
    ['input', 'expected'],
    [
        (b'ssh-rsa', b'\x00\x00\x00\x07ssh-rsa'),
        (b'ssh-ed25519', b'\x00\x00\x00\x0bssh-ed25519'),
        (
            ssh_agent_client.SSHAgentClient.string(b'ssh-ed25519'),
            b'\x00\x00\x00\x0f\x00\x00\x00\x0bssh-ed25519',
        ),
    ],
)
def test_client_string(input, expected):
    """Check that `string` produces the expected length-prefixed bytes.

    The last case feeds an already encoded string back into the encoder
    and expects it to be wrapped a second time.
    """
    encode = ssh_agent_client.SSHAgentClient.string
    encoded = encode(input)
    assert bytes(encoded) == expected

48 

@pytest.mark.parametrize(
    ['input', 'exc_type', 'exc_pattern'],
    [
        ('some string', TypeError, 'invalid payload type'),
    ],
)
def test_client_string_exceptions(input, exc_type, exc_pattern):
    """Check that `string` rejects a text (non-bytes) payload."""
    encode = ssh_agent_client.SSHAgentClient.string
    with pytest.raises(exc_type, match=exc_pattern):
        encode(input)

56 

@pytest.mark.parametrize(
    ['input', 'expected'],
    [
        (b'\x00\x00\x00\x07ssh-rsa', b'ssh-rsa'),
        (
            ssh_agent_client.SSHAgentClient.string(b'ssh-ed25519'),
            b'ssh-ed25519',
        ),
    ],
)
def test_client_unstring(input, expected):
    """Check that `unstring` and `unstring_prefix` decode valid input.

    For input without trailing data, `unstring_prefix` is expected to
    return the decoded payload together with an empty remainder.
    """
    agent_cls = ssh_agent_client.SSHAgentClient
    decode = agent_cls.unstring
    decode_prefix = agent_cls.unstring_prefix
    payload = decode(input)
    assert bytes(payload) == expected
    pieces = decode_prefix(input)
    assert tuple(bytes(x) for x in pieces) == (expected, b'')

69 

@pytest.mark.parametrize(
    ['input', 'exc_type', 'exc_pattern', 'has_trailer', 'parts'],
    [
        (b'ssh', ValueError, 'malformed SSH byte string', False, None),
        (
            b'\x00\x00\x00\x08ssh-rsa',
            ValueError,
            'malformed SSH byte string',
            False,
            None,
        ),
        (
            b'\x00\x00\x00\x04XXX trailing text',
            ValueError,
            'malformed SSH byte string',
            True,
            (b'XXX ', b'trailing text'),
        ),
    ],
)
def test_client_unstring_exceptions(input, exc_type, exc_pattern,
                                    has_trailer, parts):
    """Check how the decoders react to malformed byte strings.

    `unstring` must always raise.  `unstring_prefix` must raise as well
    unless the input is a valid string followed by trailing data
    (`has_trailer`), in which case it must split the input into `parts`.
    """
    agent_cls = ssh_agent_client.SSHAgentClient
    decode = agent_cls.unstring
    decode_prefix = agent_cls.unstring_prefix
    with pytest.raises(exc_type, match=exc_pattern):
        decode(input)
    if not has_trailer:
        with pytest.raises(exc_type, match=exc_pattern):
            decode_prefix(input)
        return
    pieces = decode_prefix(input)
    assert tuple(bytes(x) for x in pieces) == parts

95 

def test_key_decoding():
    """Check the recorded ed25519 test key for internal consistency.

    The base64 field of the recorded public key line must decode to the
    separately recorded raw public key data.
    """
    entry = tests.SUPPORTED_KEYS['ed25519']
    # The second whitespace-separated field of the public key line is
    # the base64 encoded key blob.
    encoded_blob = entry['public_key'].split(None, 2)[1]
    decoded_blob = base64.b64decode(encoded_blob)
    assert (
        decoded_blob == entry['public_key_data']
    ), "recorded public key data doesn't match"

103 

@tests.skip_if_no_agent
@pytest.mark.parametrize(['keytype', 'data_dict'],
                         list(tests.SUPPORTED_KEYS.items()))
def test_sign_data_via_agent(keytype, data_dict):
    """Sign the vault UUID with a supported key via the SSH agent.

    The private key is uploaded to the agent with ``ssh-add`` (with a
    30 second lifetime).  The test is skipped if the upload fails, if
    the agent cannot be contacted, or if the key does not show up in
    the agent's key list afterwards.

    The signature must equal the recorded one on two consecutive
    requests, and the passphrase derived from the key must equal the
    recorded derived passphrase.
    """
    try:
        subprocess.run(['ssh-add', '-t', '30', '-q', '-'],
                       input=data_dict['private_key'], check=True,
                       capture_output=True)
    except subprocess.CalledProcessError as e:
        pytest.skip(
            f"uploading test key: {e!r}, stdout={e.stdout!r}, "
            f"stderr={e.stderr!r}"
        )
    # pytest.skip() raises, so from here on the upload has succeeded.
    try:
        client = ssh_agent_client.SSHAgentClient()
    except OSError:  # pragma: no cover
        pytest.skip('communication error with the SSH agent')
    public_key_data = data_dict['public_key_data']
    expected_signature = data_dict['expected_signature']
    derived_passphrase = data_dict['derived_passphrase']
    with client:
        loaded_keys = {bytes(k) for k, _comment in client.list_keys()}
        if public_key_data not in loaded_keys:  # pragma: no cover
            pytest.skip('prerequisite SSH key not loaded')
        signature = bytes(client.sign(
            payload=derivepassphrase.Vault._UUID, key=public_key_data))
        assert signature == expected_signature, 'SSH signature mismatch'
        # Sign a second time: the signature must be repeatable.
        signature2 = bytes(client.sign(
            payload=derivepassphrase.Vault._UUID, key=public_key_data))
        assert signature2 == expected_signature, 'SSH signature mismatch'
        assert (
            derivepassphrase.Vault.phrase_from_key(public_key_data) ==
            derived_passphrase
        ), 'derived passphrase mismatch'

141 

@tests.skip_if_no_agent
@pytest.mark.parametrize(['keytype', 'data_dict'],
                         list(tests.UNSUITABLE_KEYS.items()))
def test_sign_data_via_agent_unsupported(keytype, data_dict):
    """Check that unsuitable keys are rejected for passphrase derivation.

    The private key is uploaded to the agent with ``ssh-add`` (with a
    30 second lifetime).  The test is skipped if the upload fails, if
    the agent cannot be contacted, or if the key does not show up in
    the agent's key list afterwards.

    Two consecutive signatures of the vault UUID must differ, and
    deriving a passphrase from the key must raise `ValueError`.
    """
    try:
        subprocess.run(['ssh-add', '-t', '30', '-q', '-'],
                       input=data_dict['private_key'], check=True,
                       capture_output=True)
    except subprocess.CalledProcessError as e:  # pragma: no cover
        pytest.skip(
            f"uploading test key: {e!r}, stdout={e.stdout!r}, "
            f"stderr={e.stderr!r}"
        )
    # pytest.skip() raises, so from here on the upload has succeeded.
    try:
        client = ssh_agent_client.SSHAgentClient()
    except OSError:  # pragma: no cover
        pytest.skip('communication error with the SSH agent')
    public_key_data = data_dict['public_key_data']
    with client:
        loaded_keys = {bytes(k) for k, _comment in client.list_keys()}
        if public_key_data not in loaded_keys:  # pragma: no cover
            pytest.skip('prerequisite SSH key not loaded')
        signature = bytes(client.sign(
            payload=derivepassphrase.Vault._UUID, key=public_key_data))
        signature2 = bytes(client.sign(
            payload=derivepassphrase.Vault._UUID, key=public_key_data))
        assert signature != signature2, 'SSH signature repeatable?!'
        with pytest.raises(ValueError, match='unsuitable SSH key'):
            derivepassphrase.Vault.phrase_from_key(public_key_data)

175 

def _params():
    """Yield ``(key, single)`` pairs for `test_ssh_key_selector`.

    Every supported public key is yielded once with ``single=False``.
    Afterwards, each supported key equal to the key of the first entry
    of `tests.list_keys_singleton` is yielded again with ``single=True``.
    """
    supported = [
        entry['public_key_data'] for entry in tests.SUPPORTED_KEYS.values()
    ]
    yield from ((candidate, False) for candidate in supported)
    singleton_key = tests.list_keys_singleton()[0].key
    yield from (
        (candidate, True)
        for candidate in supported
        if candidate == singleton_key
    )

@pytest.mark.parametrize(['key', 'single'], list(_params()))
def test_ssh_key_selector(monkeypatch, key, single):
    """Drive the interactive SSH key selection through a click command.

    `SSHAgentClient.list_keys` is replaced by a canned key list: the
    singleton list if `single` is true, the full list otherwise.  The
    prompt is answered with ``yes`` (singleton) or with the 1-based
    position of `key` among the suitable keys.  The command must list
    the suitable keys, echo the prompt and the answer, and finally print
    the selected key in base64.
    """
    # Imported locally because ``import click`` alone does not guarantee
    # that the ``click.testing`` submodule has been loaded.
    from click.testing import CliRunner

    # Build the lookup set once instead of once per candidate key.
    suitable_keys = {v['public_key_data']
                     for v in tests.SUPPORTED_KEYS.values()}

    if single:
        monkeypatch.setattr(ssh_agent_client.SSHAgentClient,
                            'list_keys', tests.list_keys_singleton)
        user_input = 'yes\n'
        text = 'Use this key? yes\n'
    else:
        monkeypatch.setattr(ssh_agent_client.SSHAgentClient,
                            'list_keys', tests.list_keys)
        keys = [pair.key for pair in tests.list_keys()
                if pair.key in suitable_keys]
        index = str(1 + keys.index(key))
        n = len(keys)
        user_input = f'{index}\n'
        text = f'Your selection? (1-{n}, leave empty to abort): {index}\n'
    b64_key = base64.standard_b64encode(key).decode('ASCII')

    @click.command()
    def driver():
        selected = derivepassphrase.cli._select_ssh_key()
        click.echo(base64.standard_b64encode(selected).decode('ASCII'))

    runner = CliRunner(mix_stderr=True)
    result = runner.invoke(driver, [], input=user_input,
                           catch_exceptions=True)
    assert result.stdout.startswith('Suitable SSH keys:\n'), (
        'missing expected output'
    )
    assert text in result.stdout, 'missing expected output'
    assert result.stdout.endswith(f'\n{b64_key}\n'), 'missing expected output'
    assert result.exit_code == 0, 'driver program failed?!'

222 

223@tests.skip_if_no_agent 

224@pytest.mark.parametrize(['conn_hint'], [('none',), ('socket',), ('client',)]) 

225def test_get_suitable_ssh_keys(conn_hint): 

226 hint: ssh_agent_client.SSHAgentClient | socket.socket | None 

227 match conn_hint: 

228 case 'client': 

229 hint = ssh_agent_client.SSHAgentClient() 

230 case 'socket': 

231 hint = socket.socket(family=socket.AF_UNIX) 

232 hint.connect(os.environ['SSH_AUTH_SOCK']) 

233 case _: 

234 assert conn_hint == 'none' 

235 hint = None 

236 exception: type[Exception] | None = None 

237 try: 

238 list(derivepassphrase.cli._get_suitable_ssh_keys(hint)) 

239 except RuntimeError: # pragma: no cover 

240 pass 

241 except Exception as e: # pragma: no cover 

242 exception = e 

243 finally: 

244 assert exception == None, 'exception querying suitable SSH keys' 

245 

def test_constructor_no_running_agent(monkeypatch):
    """Check that the constructor fails without ``SSH_AUTH_SOCK``.

    With the environment variable removed, constructing a client from
    an unconnected socket must raise `KeyError`.
    """
    monkeypatch.delenv('SSH_AUTH_SOCK', raising=False)
    # The context manager closes the socket even though the constructor
    # raises, so the test does not leak a file descriptor.
    with socket.socket(family=socket.AF_UNIX) as sock:
        with pytest.raises(KeyError,
                           match='SSH_AUTH_SOCK environment variable'):
            ssh_agent_client.SSHAgentClient(socket=sock)

251 

@tests.skip_if_no_agent
def test_constructor_bad_running_agent(monkeypatch):
    """Check that the constructor fails for a bogus agent socket path.

    ``SSH_AUTH_SOCK`` is pointed at the real socket path with a ``~``
    appended; constructing a client must then raise `OSError`.
    """
    monkeypatch.setenv('SSH_AUTH_SOCK', os.environ['SSH_AUTH_SOCK'] + '~')
    # The context manager closes the socket even though the constructor
    # raises, so the test does not leak a file descriptor.
    with socket.socket(family=socket.AF_UNIX) as sock:
        with pytest.raises(OSError):
            ssh_agent_client.SSHAgentClient(socket=sock)

258 

# Constructing the client below needs a reachable agent, so guard this
# test the same way as the other tests that build an `SSHAgentClient`.
@tests.skip_if_no_agent
@pytest.mark.parametrize(['response'], [
    (b'\x00\x00',),
    (b'\x00\x00\x00\x1f some bytes missing',),
])
def test_truncated_server_response(monkeypatch, response):
    """Check that a truncated agent response raises `EOFError`.

    The client's connection is replaced by a stand-in object whose
    `sendall` discards its arguments and whose `recv` reads from the
    canned `response` bytes.
    """
    client = ssh_agent_client.SSHAgentClient()
    response_stream = io.BytesIO(response)

    class PseudoSocket:
        """Bare attribute holder standing in for a socket."""

    pseudo_socket = PseudoSocket()
    pseudo_socket.sendall = lambda *a, **kw: None
    pseudo_socket.recv = response_stream.read
    monkeypatch.setattr(client, '_connection', pseudo_socket)
    with pytest.raises(EOFError):
        client.request(255, b'')

274 

@tests.skip_if_no_agent
@pytest.mark.parametrize(
    ['response_code', 'response', 'exc_type', 'exc_pattern'],
    [
        (255, b'', RuntimeError, 'error return from SSH agent:'),
        (12, b'\x00\x00\x00\x01', EOFError, 'truncated response'),
        (12, b'\x00\x00\x00\x00abc', RuntimeError, 'overlong response'),
    ],
)
def test_list_keys_error_responses(monkeypatch, response_code, response,
                                   exc_type, exc_pattern):
    """Check how `list_keys` reacts to faulty agent responses.

    The client's `request` method is replaced by a stub that returns
    the canned ``(response_code, response)`` pair regardless of its
    arguments; `list_keys` must then raise the expected exception.
    """
    def canned_request(*args, **kwargs):
        return (response_code, response)

    client = ssh_agent_client.SSHAgentClient()
    monkeypatch.setattr(client, 'request', canned_request)
    with pytest.raises(exc_type, match=exc_pattern):
        client.list_keys()

291 

@tests.skip_if_no_agent
@pytest.mark.parametrize(
    ['key', 'check', 'response', 'exc_type', 'exc_pattern'],
    [
        (
            b'invalid-key',
            True,
            (255, b''),
            KeyError,
            'target SSH key not loaded into agent',
        ),
        (
            tests.SUPPORTED_KEYS['ed25519']['public_key_data'],
            True,
            (255, b''),
            RuntimeError,
            'signing data failed:',
        ),
    ],
)
def test_sign_error_responses(monkeypatch, key, check, response, exc_type,
                              exc_pattern):
    """Check how `sign` reacts to unknown keys and agent errors.

    The client's `request` method is stubbed to return the canned
    `response`, and `list_keys` is stubbed to report all supported test
    keys as loaded.  Signing with `key` must then raise the expected
    exception.
    """
    pair_type = ssh_agent_client.types.KeyCommentPair

    def canned_request(_code, _payload):
        return response

    def canned_list_keys():
        return loaded_keys

    client = ssh_agent_client.SSHAgentClient()
    monkeypatch.setattr(client, 'request', canned_request)
    loaded_keys = [
        pair_type(entry['public_key_data'], b'no comment')
        for entry in tests.SUPPORTED_KEYS.values()
    ]
    monkeypatch.setattr(client, 'list_keys', canned_list_keys)
    with pytest.raises(exc_type, match=exc_pattern):
        client.sign(key, b'abc', check_if_key_loaded=check)