Coverage for tests/test_ssh_agent_client.py: 100.000%

166 statements  

« prev     ^ index     » next       coverage.py v7.6.0, created at 2024-07-14 11:39 +0200

1# SPDX-FileCopyrightText: 2024 Marco Ricci <m@the13thletter.info> 

2# 

3# SPDX-License-Identifier: MIT 

4 

5"""Test OpenSSH key loading and signing.""" 

6 

7from __future__ import annotations 

8 

9import base64 

10import io 

11import os 

12import socket 

13import subprocess 

14from typing_extensions import Any 

15 

16import click 

17import click.testing 

18import derivepassphrase 

19import derivepassphrase.cli 

20import pytest 

21import ssh_agent_client 

22import tests 

23 

class TestStaticFunctionality:
    """Tests of SSH agent client functionality that needs no running agent.

    Covers the recorded public key fixtures, the constructor's behaviour
    when `SSH_AUTH_SOCK` is unset, and the static wire-format helpers
    `uint32`, `string`, `unstring` and `unstring_prefix`.
    """

    @pytest.mark.parametrize(['public_key', 'public_key_data'],
                             [(val['public_key'], val['public_key_data'])
                              for val in tests.SUPPORTED_KEYS.values()])
    def test_100_key_decoding(self, public_key, public_key_data):
        """The recorded key blob equals the base64 field of the public key."""
        # The public key line is split on whitespace; its second field is
        # the base64-encoded key blob.
        keydata = base64.b64decode(public_key.split(None, 2)[1])
        assert (
            keydata == public_key_data
        ), "recorded public key data doesn't match"

    def test_200_constructor_no_running_agent(self, monkeypatch):
        """Without `SSH_AUTH_SOCK`, the constructor raises `KeyError`."""
        monkeypatch.delenv('SSH_AUTH_SOCK', raising=False)
        # Close the socket explicitly: the constructor raises, so nothing
        # else would close it and it would leak (ResourceWarning).
        with socket.socket(family=socket.AF_UNIX) as sock:
            with pytest.raises(KeyError,
                               match='SSH_AUTH_SOCK environment variable'):
                ssh_agent_client.SSHAgentClient(socket=sock)

    @pytest.mark.parametrize(['value', 'expected'], [
        (16777216, b'\x01\x00\x00\x00'),
    ])
    def test_210_uint32(self, value, expected):
        """`uint32` encodes an integer as 4 big-endian bytes."""
        uint32 = ssh_agent_client.SSHAgentClient.uint32
        assert uint32(value) == expected

    @pytest.mark.parametrize(['data', 'expected'], [
        (b'ssh-rsa', b'\x00\x00\x00\x07ssh-rsa'),
        (b'ssh-ed25519', b'\x00\x00\x00\x0bssh-ed25519'),
        (
            ssh_agent_client.SSHAgentClient.string(b'ssh-ed25519'),
            b'\x00\x00\x00\x0f\x00\x00\x00\x0bssh-ed25519',
        ),
    ])
    def test_211_string(self, data, expected):
        """`string` prefixes the payload with its uint32 length."""
        string = ssh_agent_client.SSHAgentClient.string
        assert bytes(string(data)) == expected

    @pytest.mark.parametrize(['data', 'expected'], [
        (b'\x00\x00\x00\x07ssh-rsa', b'ssh-rsa'),
        (
            ssh_agent_client.SSHAgentClient.string(b'ssh-ed25519'),
            b'ssh-ed25519',
        ),
    ])
    def test_212_unstring(self, data, expected):
        """`unstring` strips the length prefix from a whole SSH string.

        `unstring_prefix` returns the payload together with the remaining
        bytes, which are empty for these well-formed inputs.
        """
        unstring = ssh_agent_client.SSHAgentClient.unstring
        unstring_prefix = ssh_agent_client.SSHAgentClient.unstring_prefix
        assert bytes(unstring(data)) == expected
        assert tuple(
            bytes(x) for x in unstring_prefix(data)
        ) == (expected, b'')

    @pytest.mark.parametrize(['value', 'exc_type', 'exc_pattern'], [
        (10000000000000000, OverflowError, 'int too big to convert'),
        (-1, OverflowError, "can't convert negative int to unsigned"),
    ])
    def test_310_uint32_exceptions(self, value, exc_type, exc_pattern):
        """`uint32` rejects values outside the unsigned 32-bit range."""
        uint32 = ssh_agent_client.SSHAgentClient.uint32
        with pytest.raises(exc_type, match=exc_pattern):
            uint32(value)

    @pytest.mark.parametrize(['data', 'exc_type', 'exc_pattern'], [
        ('some string', TypeError, 'invalid payload type'),
    ])
    def test_311_string_exceptions(self, data, exc_type, exc_pattern):
        """`string` rejects non-bytes payloads such as `str`."""
        string = ssh_agent_client.SSHAgentClient.string
        with pytest.raises(exc_type, match=exc_pattern):
            string(data)

    @pytest.mark.parametrize(
        ['data', 'exc_type', 'exc_pattern', 'has_trailer', 'parts'], [
            (b'ssh', ValueError, 'malformed SSH byte string', False, None),
            (
                b'\x00\x00\x00\x08ssh-rsa',
                ValueError, 'malformed SSH byte string',
                False, None,
            ),
            (
                b'\x00\x00\x00\x04XXX trailing text',
                ValueError, 'malformed SSH byte string',
                True, (b'XXX ', b'trailing text'),
            ),
        ])
    def test_312_unstring_exceptions(self, data, exc_type, exc_pattern,
                                     has_trailer, parts):
        """`unstring` rejects malformed or overlong SSH strings.

        The inputs are too short for a length header, shorter than their
        announced length, or followed by trailing data.  `unstring_prefix`
        rejects the first two kinds too, but splits trailing data off and
        returns it as the second element.
        """
        unstring = ssh_agent_client.SSHAgentClient.unstring
        unstring_prefix = ssh_agent_client.SSHAgentClient.unstring_prefix
        with pytest.raises(exc_type, match=exc_pattern):
            unstring(data)
        if has_trailer:
            assert tuple(bytes(x) for x in unstring_prefix(data)) == parts
        else:
            with pytest.raises(exc_type, match=exc_pattern):
                unstring_prefix(data)

118 

119@tests.skip_if_no_agent 

120class TestAgentInteraction: 

121 

122 @pytest.mark.parametrize(['keytype', 'data_dict'], 

123 list(tests.SUPPORTED_KEYS.items())) 

124 def test_200_sign_data_via_agent(self, keytype, data_dict): 

125 private_key = data_dict['private_key'] 

126 try: 

127 result = subprocess.run(['ssh-add', '-t', '30', '-q', '-'], 

128 input=private_key, check=True, 

129 capture_output=True) 

130 except subprocess.CalledProcessError as e: 

131 pytest.skip( 

132 f"uploading test key: {e!r}, stdout={e.stdout!r}, " 

133 f"stderr={e.stderr!r}" 

134 ) 

135 else: 

136 try: 

137 client = ssh_agent_client.SSHAgentClient() 

138 except OSError: # pragma: no cover 

139 pytest.skip('communication error with the SSH agent') 

140 with client: 

141 key_comment_pairs = {bytes(k): bytes(c) 

142 for k, c in client.list_keys()} 

143 public_key_data = data_dict['public_key_data'] 

144 expected_signature = data_dict['expected_signature'] 

145 derived_passphrase = data_dict['derived_passphrase'] 

146 if public_key_data not in key_comment_pairs: # pragma: no cover 

147 pytest.skip('prerequisite SSH key not loaded') 

148 signature = bytes(client.sign( 

149 payload=derivepassphrase.Vault._UUID, key=public_key_data)) 

150 assert signature == expected_signature, 'SSH signature mismatch' 

151 signature2 = bytes(client.sign( 

152 payload=derivepassphrase.Vault._UUID, key=public_key_data)) 

153 assert signature2 == expected_signature, 'SSH signature mismatch' 

154 assert ( 

155 derivepassphrase.Vault.phrase_from_key(public_key_data) == 

156 derived_passphrase 

157 ), 'SSH signature mismatch' 

158 

159 @pytest.mark.parametrize(['keytype', 'data_dict'], 

160 list(tests.UNSUITABLE_KEYS.items())) 

161 def test_201_sign_data_via_agent_unsupported(self, keytype, data_dict): 

162 private_key = data_dict['private_key'] 

163 try: 

164 result = subprocess.run(['ssh-add', '-t', '30', '-q', '-'], 

165 input=private_key, check=True, 

166 capture_output=True) 

167 except subprocess.CalledProcessError as e: # pragma: no cover 

168 pytest.skip( 

169 f"uploading test key: {e!r}, stdout={e.stdout!r}, " 

170 f"stderr={e.stderr!r}" 

171 ) 

172 else: 

173 try: 

174 client = ssh_agent_client.SSHAgentClient() 

175 except OSError: # pragma: no cover 

176 pytest.skip('communication error with the SSH agent') 

177 with client: 

178 key_comment_pairs = {bytes(k): bytes(c) 

179 for k, c in client.list_keys()} 

180 public_key_data = data_dict['public_key_data'] 

181 expected_signature = data_dict['expected_signature'] 

182 if public_key_data not in key_comment_pairs: # pragma: no cover 

183 pytest.skip('prerequisite SSH key not loaded') 

184 signature = bytes(client.sign( 

185 payload=derivepassphrase.Vault._UUID, key=public_key_data)) 

186 signature2 = bytes(client.sign( 

187 payload=derivepassphrase.Vault._UUID, key=public_key_data)) 

188 assert signature != signature2, 'SSH signature repeatable?!' 

189 with pytest.raises(ValueError, match='unsuitable SSH key'): 

190 derivepassphrase.Vault.phrase_from_key(public_key_data) 

191 

192 @staticmethod 

193 def _params(): 

194 for value in tests.SUPPORTED_KEYS.values(): 

195 key = value['public_key_data'] 

196 yield (key, False) 

197 singleton_key = tests.list_keys_singleton()[0].key 

198 for value in tests.SUPPORTED_KEYS.values(): 

199 key = value['public_key_data'] 

200 if key == singleton_key: 

201 yield (key, True) 

202 

203 @pytest.mark.parametrize(['key', 'single'], list(_params())) 

204 def test_210_ssh_key_selector(self, monkeypatch, key, single): 

205 def key_is_suitable(key: bytes): 

206 return key in {v['public_key_data'] 

207 for v in tests.SUPPORTED_KEYS.values()} 

208 if single: 

209 monkeypatch.setattr(ssh_agent_client.SSHAgentClient, 

210 'list_keys', tests.list_keys_singleton) 

211 keys = [pair.key for pair in tests.list_keys_singleton() 

212 if key_is_suitable(pair.key)] 

213 index = '1' 

214 text = f'Use this key? yes\n' 

215 else: 

216 monkeypatch.setattr(ssh_agent_client.SSHAgentClient, 

217 'list_keys', tests.list_keys) 

218 keys = [pair.key for pair in tests.list_keys() 

219 if key_is_suitable(pair.key)] 

220 index = str(1 + keys.index(key)) 

221 n = len(keys) 

222 text = f'Your selection? (1-{n}, leave empty to abort): {index}\n' 

223 b64_key = base64.standard_b64encode(key).decode('ASCII') 

224 

225 @click.command() 

226 def driver(): 

227 key = derivepassphrase.cli._select_ssh_key() 

228 click.echo(base64.standard_b64encode(key).decode('ASCII')) 

229 

230 runner = click.testing.CliRunner(mix_stderr=True) 

231 result = runner.invoke(driver, [], 

232 input=('yes\n' if single else f'{index}\n'), 

233 catch_exceptions=True) 

234 assert result.stdout.startswith('Suitable SSH keys:\n'), ( 

235 'missing expected output' 

236 ) 

237 assert text in result.stdout, 'missing expected output' 

238 assert ( 

239 result.stdout.endswith(f'\n{b64_key}\n') 

240 ), 'missing expected output' 

241 assert result.exit_code == 0, 'driver program failed?!' 

242 

    # `_params` only generates the parametrization for
    # test_210_ssh_key_selector; remove it from the class namespace once
    # that decorator has consumed it.
    del _params

244 

245 def test_300_constructor_bad_running_agent(self, monkeypatch): 

246 monkeypatch.setenv('SSH_AUTH_SOCK', 

247 os.environ['SSH_AUTH_SOCK'] + '~') 

248 sock = socket.socket(family=socket.AF_UNIX) 

249 with pytest.raises(OSError): 

250 ssh_agent_client.SSHAgentClient(socket=sock) 

251 

252 @pytest.mark.parametrize(['response'], [ 

253 (b'\x00\x00',), 

254 (b'\x00\x00\x00\x1f some bytes missing',), 

255 ]) 

256 def test_310_truncated_server_response(self, monkeypatch, response): 

257 client = ssh_agent_client.SSHAgentClient() 

258 response_stream = io.BytesIO(response) 

259 class PseudoSocket(object): 

260 def sendall(self, *args: Any, **kwargs: Any) -> Any: 

261 return None 

262 def recv(self, *args: Any, **kwargs: Any) -> Any: 

263 return response_stream.read(*args, **kwargs) 

264 pseudo_socket = PseudoSocket() 

265 monkeypatch.setattr(client, '_connection', pseudo_socket) 

266 with pytest.raises(EOFError): 

267 client.request(255, b'') 

268 

269 @tests.skip_if_no_agent 

270 @pytest.mark.parametrize( 

271 ['response_code', 'response', 'exc_type', 'exc_pattern'], 

272 [ 

273 (255, b'', RuntimeError, 'error return from SSH agent:'), 

274 (12, b'\x00\x00\x00\x01', EOFError, 'truncated response'), 

275 ( 

276 12, 

277 b'\x00\x00\x00\x00abc', 

278 ssh_agent_client.TrailingDataError, 

279 'overlong response', 

280 ), 

281 ] 

282 ) 

283 def test_320_list_keys_error_responses(self, monkeypatch, response_code, 

284 response, exc_type, exc_pattern): 

285 client = ssh_agent_client.SSHAgentClient() 

286 monkeypatch.setattr(client, 'request', 

287 lambda *a, **kw: (response_code, response)) 

288 with pytest.raises(exc_type, match=exc_pattern): 

289 client.list_keys() 

290 

    @tests.skip_if_no_agent
    @pytest.mark.parametrize(
        ['key', 'check', 'response', 'exc_type', 'exc_pattern'],
        [
            (
                b'invalid-key',
                True,
                (255, b''),
                KeyError,
                'target SSH key not loaded into agent',
            ),
            (
                tests.SUPPORTED_KEYS['ed25519']['public_key_data'],
                True,
                (255, b''),
                RuntimeError,
                'signing data failed:',
            )
        ]
    )
    def test_330_sign_error_responses(self, monkeypatch, key, check,
                                      response, exc_type, exc_pattern):
        """`sign` raises on unknown keys and on agent error replies.

        Both `request` and `list_keys` are stubbed out.  Every request
        returns the canned `response`, and the "loaded" keys are exactly
        the supported test keys.
        """
        client = ssh_agent_client.SSHAgentClient()
        # Stub the agent round trip.  NOTE(review): the lambda accepts
        # exactly two positional arguments, so this assumes `sign` calls
        # `request(code, payload)` that way; confirm against the client.
        monkeypatch.setattr(client, 'request', lambda a, b: response)
        KeyCommentPair = ssh_agent_client.types.KeyCommentPair
        # Pretend exactly the supported test keys are loaded.  Presumably
        # this makes b'invalid-key' fail the loaded-key check when
        # check_if_key_loaded is true.
        loaded_keys = [KeyCommentPair(v['public_key_data'], b'no comment')
                       for v in tests.SUPPORTED_KEYS.values()]
        monkeypatch.setattr(client, 'list_keys', lambda: loaded_keys)
        with pytest.raises(exc_type, match=exc_pattern):
            client.sign(key, b'abc', check_if_key_loaded=check)