1
2
3
4
5
6
7
8
9
10 """Classes representing TLS messages."""
11
12 from .utils.compat import *
13 from .utils.cryptomath import *
14 from .errors import *
15 from .utils.codec import *
16 from .constants import *
17 from .x509 import X509
18 from .x509certchain import X509CertChain
19 from .utils.tackwrapper import *
20
23 self.type = 0
24 self.version = (0,0)
25 self.length = 0
26 self.ssl2 = False
27
29 self.type = type
30 self.version = version
31 self.length = length
32 return self
33
35 w = Writer()
36 w.add(self.type, 1)
37 w.add(self.version[0], 1)
38 w.add(self.version[1], 1)
39 w.add(self.length, 2)
40 return w.bytes
41
43 self.type = p.get(1)
44 self.version = (p.get(1), p.get(1))
45 self.length = p.get(2)
46 self.ssl2 = False
47 return self
48
51 self.type = 0
52 self.version = (0,0)
53 self.length = 0
54 self.ssl2 = True
55
64
65
68 self.contentType = ContentType.alert
69 self.level = 0
70 self.description = 0
71
73 self.level = level
74 self.description = description
75 return self
76
83
85 w = Writer()
86 w.add(self.level, 1)
87 w.add(self.description, 1)
88 return w.bytes
89
90
95
def postWrite(self, w):
    """Return the bytes of ``w`` preceded by a handshake header.

    The header is built in a separate Writer: ``self.handshakeType`` is
    added with size 1, then ``len(w.bytes)`` with size 3.

    :param w: Writer holding the already-serialized message body.
    :returns: header bytes concatenated with ``w.bytes``.
    """
    body = w.bytes
    header = Writer()
    header.add(self.handshakeType, 1)
    header.add(len(body), 3)
    return header.bytes + body
101
116
def create(self, version, random, session_id, cipher_suites,
           certificate_types=None, srpUsername=None,
           tack=False, supports_npn=False, serverName=None):
    """Fill in the hello fields from the given values and return self.

    ``srpUsername`` and ``serverName`` are stored UTF-8 encoded as
    bytearrays, and only when they are truthy; otherwise the existing
    ``srp_username`` / ``server_name`` attributes are left untouched.
    ``compression_methods`` is always set to ``[0]``.

    :returns: this object, so calls can be chained.
    """
    self.client_version, self.random = version, random
    self.session_id, self.cipher_suites = session_id, cipher_suites
    self.certificate_types = certificate_types
    self.compression_methods = [0]
    if srpUsername:
        encoded_user = bytearray(srpUsername, "utf-8")
        self.srp_username = encoded_user
    self.tack, self.supports_npn = tack, supports_npn
    if serverName:
        encoded_host = bytearray(serverName, "utf-8")
        self.server_name = encoded_host
    return self
133
191
193 w = Writer()
194 w.add(self.client_version[0], 1)
195 w.add(self.client_version[1], 1)
196 w.addFixSeq(self.random, 1)
197 w.addVarSeq(self.session_id, 1, 1)
198 w.addVarSeq(self.cipher_suites, 2, 2)
199 w.addVarSeq(self.compression_methods, 1, 1)
200
201 w2 = Writer()
202 if self.certificate_types and self.certificate_types != \
203 [CertificateType.x509]:
204 w2.add(ExtensionType.cert_type, 2)
205 w2.add(len(self.certificate_types)+1, 2)
206 w2.addVarSeq(self.certificate_types, 1, 1)
207 if self.srp_username:
208 w2.add(ExtensionType.srp, 2)
209 w2.add(len(self.srp_username)+1, 2)
210 w2.addVarSeq(self.srp_username, 1, 1)
211 if self.supports_npn:
212 w2.add(ExtensionType.supports_npn, 2)
213 w2.add(0, 2)
214 if self.server_name:
215 w2.add(ExtensionType.server_name, 2)
216 w2.add(len(self.server_name)+5, 2)
217 w2.add(len(self.server_name)+3, 2)
218 w2.add(NameType.host_name, 1)
219 w2.addVarSeq(self.server_name, 1, 2)
220 if self.tack:
221 w2.add(ExtensionType.tack, 2)
222 w2.add(0, 2)
223 if len(w2.bytes):
224 w.add(len(w2.bytes), 2)
225 w.bytes += w2.bytes
226 return self.postWrite(w)
227
231
233 return 'Cannot encode a list of next protocols because it contains an element with invalid length %d. Element lengths must be 0 < x < 256' % self.length
234
237 HandshakeMsg.__init__(self, HandshakeType.server_hello)
238 self.server_version = (0,0)
239 self.random = bytearray(32)
240 self.session_id = bytearray(0)
241 self.cipher_suite = 0
242 self.certificate_type = CertificateType.x509
243 self.compression_method = 0
244 self.tackExt = None
245 self.next_protos_advertised = None
246 self.next_protos = None
247
def create(self, version, random, session_id, cipher_suite,
           certificate_type, tackExt, next_protos_advertised):
    """Fill in the hello fields from the given values and return self.

    ``compression_method`` is not a parameter; it is always set to 0.

    :returns: this object, so calls can be chained.
    """
    self.server_version, self.random = version, random
    self.session_id, self.cipher_suite = session_id, cipher_suite
    self.certificate_type, self.compression_method = certificate_type, 0
    self.tackExt = tackExt
    self.next_protos_advertised = next_protos_advertised
    return self
259
286
288 protos = []
289 while True:
290 if len(b) == 0:
291 break
292 l = b[0]
293 b = b[1:]
294 if len(b) < l:
295 raise BadNextProtos(len(b))
296 protos.append(b[:l])
297 b = b[l:]
298 return protos
299
301 b = bytearray()
302 for e in self.next_protos_advertised:
303 if len(e) > 255 or len(e) == 0:
304 raise BadNextProtos(len(e))
305 b += bytearray( [len(e)] ) + bytearray(e)
306 return b
307
309 w = Writer()
310 w.add(self.server_version[0], 1)
311 w.add(self.server_version[1], 1)
312 w.addFixSeq(self.random, 1)
313 w.addVarSeq(self.session_id, 1, 1)
314 w.add(self.cipher_suite, 2)
315 w.add(self.compression_method, 1)
316
317 w2 = Writer()
318 if self.certificate_type and self.certificate_type != \
319 CertificateType.x509:
320 w2.add(ExtensionType.cert_type, 2)
321 w2.add(1, 2)
322 w2.add(self.certificate_type, 1)
323 if self.tackExt:
324 b = self.tackExt.serialize()
325 w2.add(ExtensionType.tack, 2)
326 w2.add(len(b), 2)
327 w2.bytes += b
328 if self.next_protos_advertised is not None:
329 encoded_next_protos_advertised = self.__next_protos_encoded()
330 w2.add(ExtensionType.supports_npn, 2)
331 w2.add(len(encoded_next_protos_advertised), 2)
332 w2.addFixSeq(encoded_next_protos_advertised, 1)
333 if len(w2.bytes):
334 w.add(len(w2.bytes), 2)
335 w.bytes += w2.bytes
336 return self.postWrite(w)
337
338
344
346 self.certChain = certChain
347 return self
348
368
370 w = Writer()
371 if self.certificateType == CertificateType.x509:
372 chainLength = 0
373 if self.certChain:
374 certificate_list = self.certChain.x509List
375 else:
376 certificate_list = []
377
378 for cert in certificate_list:
379 bytes = cert.writeBytes()
380 chainLength += len(bytes)+3
381
382 w.add(chainLength, 3)
383 for cert in certificate_list:
384 bytes = cert.writeBytes()
385 w.addVarSeq(bytes, 1, 3)
386 else:
387 raise AssertionError()
388 return self.postWrite(w)
389
399
def create(self, certificate_types, certificate_authorities,
           sig_algs=(), version=(3,0)):
    """Store the request fields on this message and return self.

    ``sig_algs`` is stored as ``supported_signature_algs``; the other
    arguments are stored under their own names.

    :returns: this object, so calls can be chained.
    """
    self.certificate_types, self.certificate_authorities = (
        certificate_types, certificate_authorities)
    self.version, self.supported_signature_algs = version, sig_algs
    return self
406
408 p.startLengthCheck(3)
409 self.certificate_types = p.getVarList(1, 1)
410 if self.version >= (3,3):
411 self.supported_signature_algs = p.getVarList(2, 2)
412 ca_list_length = p.get(2)
413 index = 0
414 self.certificate_authorities = []
415 while index != ca_list_length:
416 ca_bytes = p.getVarBytes(2)
417 self.certificate_authorities.append(ca_bytes)
418 index += len(ca_bytes)+2
419 p.stopLengthCheck()
420 return self
421
423 w = Writer()
424 w.addVarSeq(self.certificate_types, 1, 1)
425 if self.version >= (3,3):
426 w.addVarSeq(self.supported_signature_algs, 2, 2)
427 caLength = 0
428
429 for ca_dn in self.certificate_authorities:
430 caLength += len(ca_dn)+2
431 w.add(caLength, 2)
432
433 for ca_dn in self.certificate_authorities:
434 w.addVarSeq(ca_dn, 1, 2)
435 return self.postWrite(w)
436
439 HandshakeMsg.__init__(self, HandshakeType.server_key_exchange)
440 self.cipherSuite = cipherSuite
441 self.srp_N = 0
442 self.srp_g = 0
443 self.srp_s = bytearray(0)
444 self.srp_B = 0
445
446 self.dh_p = 0
447 self.dh_g = 0
448 self.dh_Ys = 0
449 self.signature = bytearray(0)
450
def createSRP(self, srp_N, srp_g, srp_s, srp_B):
    """Store the four SRP values on this message and return self.

    :returns: this object, so calls can be chained.
    """
    self.srp_N, self.srp_g = srp_N, srp_g
    self.srp_s, self.srp_B = srp_s, srp_B
    return self
457
459 self.dh_p = dh_p
460 self.dh_g = dh_g
461 self.dh_Ys = dh_Ys
462 return self
463
479
481 w = Writer()
482 if self.cipherSuite in CipherSuite.srpAllSuites:
483 w.addVarSeq(numberToByteArray(self.srp_N), 1, 2)
484 w.addVarSeq(numberToByteArray(self.srp_g), 1, 2)
485 w.addVarSeq(self.srp_s, 1, 1)
486 w.addVarSeq(numberToByteArray(self.srp_B), 1, 2)
487 if self.cipherSuite in CipherSuite.srpCertSuites:
488 w.addVarSeq(self.signature, 1, 2)
489 elif self.cipherSuite in CipherSuite.anonSuites:
490 w.addVarSeq(numberToByteArray(self.dh_p), 1, 2)
491 w.addVarSeq(numberToByteArray(self.dh_g), 1, 2)
492 w.addVarSeq(numberToByteArray(self.dh_Ys), 1, 2)
493 if self.cipherSuite in []:
494 w.addVarSeq(self.signature, 1, 2)
495 return self.postWrite(w)
496
def hash(self, clientRandom, serverRandom):
    """Return MD5 + SHA1 over both randoms and the serialized parameters.

    The digest input is ``clientRandom + serverRandom + params``, where
    ``params`` is the output of ``self.write()`` without its 4-byte
    handshake header and without the trailing signature.

    ``write()`` only emits the key-exchange parameters when
    ``self.cipherSuite`` selects one of its branches. The suite therefore
    has to stay set while serializing; clearing it would yield an empty
    body and a digest that covers nothing but the two randoms.

    :param clientRandom: client random bytes.
    :param serverRandom: server random bytes.
    :returns: ``MD5(data) + SHA1(data)`` for the data described above.
    """
    params = self.write()[4:]
    if self.cipherSuite in CipherSuite.srpCertSuites:
        # For these suites write() appends the signature last, as a
        # vector with a 2-byte length prefix; it must not be hashed.
        trailer = len(self.signature) + 2
        params = params[:len(params) - trailer]
    data = clientRandom + serverRandom + params
    return MD5(data) + SHA1(data)
505
521
523 - def __init__(self, cipherSuite, version=None):
529
531 self.srp_A = srp_A
532 return self
533
def createRSA(self, encryptedPreMasterSecret):
    """Store the encrypted premaster secret and return self.

    :returns: this object, so calls can be chained.
    """
    secret = encryptedPreMasterSecret
    self.encryptedPreMasterSecret = secret
    return self
537
539 self.dh_Yc = dh_Yc
540 return self
541
560
562 w = Writer()
563 if self.cipherSuite in CipherSuite.srpAllSuites:
564 w.addVarSeq(numberToByteArray(self.srp_A), 1, 2)
565 elif self.cipherSuite in CipherSuite.certSuites:
566 if self.version in ((3,1), (3,2), (3,3)):
567 w.addVarSeq(self.encryptedPreMasterSecret, 1, 2)
568 elif self.version == (3,0):
569 w.addFixSeq(self.encryptedPreMasterSecret, 1)
570 else:
571 raise AssertionError()
572 elif self.cipherSuite in CipherSuite.anonSuites:
573 w.addVarSeq(numberToByteArray(self.dh_Yc), 1, 2)
574 else:
575 raise AssertionError()
576 return self.postWrite(w)
577
582
584 self.signature = signature
585 return self
586
592
597
602
604 self.type = 1
605 return self
606
612
614 w = Writer()
615 w.add(self.type,1)
616 return w.bytes
617
618
623
def create(self, next_proto):
    """Store ``next_proto`` on this message and return self.

    :returns: this object, so calls can be chained.
    """
    chosen = next_proto
    self.next_proto = chosen
    return self
627
634
def write(self, trial=False):
    """Serialize ``next_proto`` followed by zero padding.

    Both fields are added with ``addVarSeq(..., 1, 1)``. The padding is
    ``32 - ((len(next_proto) + 2) % 32)`` zero bytes. ``trial`` is
    accepted but not used here.

    :returns: the result of ``self.postWrite`` on the assembled Writer.
    """
    out = Writer()
    proto = self.next_proto
    out.addVarSeq(proto, 1, 1)
    # The +2 presumably accounts for the two length prefixes - confirm.
    pad = bytearray(32 - (len(proto) + 2) % 32)
    out.addVarSeq(pad, 1, 1)
    return self.postWrite(out)
641
647
def create(self, verify_data):
    """Store ``verify_data`` on this message and return self.

    :returns: this object, so calls can be chained.
    """
    data = verify_data
    self.verify_data = data
    return self
651
662
667
672
674 self.bytes = bytes
675 return self
676
681
683 self.bytes = p.bytes
684 return self
685
688