Package tlslite :: Module messages
[hide private]
[frames] | no frames]

Source Code for Module tlslite.messages

  1  # Authors:  
  2  #   Trevor Perrin 
  3  #   Google - handling CertificateRequest.certificate_types 
  4  #   Google (adapted by Sam Rushing) - NPN support 
  5  #   Dimitris Moraitis - Anon ciphersuites 
  6  # 
  7  # See the LICENSE file for legal information regarding use of this file. 
  8   
  9  """Classes representing TLS messages.""" 
 10   
 11  from .utils.compat import * 
 12  from .utils.cryptomath import * 
 13  from .errors import * 
 14  from .utils.codec import * 
 15  from .constants import * 
 16  from .x509 import X509 
 17  from .x509certchain import X509CertChain 
 18  from .utils.tackwrapper import * 
 19   
class RecordHeader3:
    """Record header carrying a content type, a two-part version and a length.

    On the wire it is five bytes: type (1), version major (1),
    version minor (1) and length (2).
    """

    def __init__(self):
        self.ssl2 = False
        self.length = 0
        self.version = (0,0)
        self.type = 0

    def create(self, version, type, length):
        """Fill in all header fields and return self so calls can be chained."""
        self.type, self.version, self.length = type, version, length
        return self

    def write(self):
        """Return the serialised header bytes."""
        out = Writer()
        fields = (
            (self.type, 1),
            (self.version[0], 1),
            (self.version[1], 1),
            (self.length, 2),
        )
        for value, width in fields:
            out.add(value, width)
        return out.bytes

    def parse(self, p):
        """Read the header from parser `p` and return self."""
        self.type = p.get(1)
        # Major byte is read before minor byte; order matters on the wire.
        major = p.get(1)
        minor = p.get(1)
        self.version = (major, minor)
        self.length = p.get(2)
        self.ssl2 = False
        return self
47
class RecordHeader2:
    """Record header for the SSLv2-style framing (``ssl2`` is always True).

    Only parsing is provided; there is no ``write`` method.
    """

    def __init__(self):
        self.ssl2 = True
        self.length = 0
        self.version = (0,0)
        self.type = 0

    def parse(self, p):
        """Read a two-byte header from parser `p` and return self.

        Raises SyntaxError when the first byte is not 128.
        """
        marker = p.get(1)
        if marker != 128:
            raise SyntaxError()
        self.type = ContentType.handshake
        self.version = (2,0)
        #We don't support 2-byte-length-headers; could be a problem
        self.length = p.get(1)
        return self
63 64
class Alert:
    """Alert message: a one-byte level followed by a one-byte description."""

    def __init__(self):
        self.contentType = ContentType.alert
        self.level = 0
        self.description = 0

    def create(self, description, level=AlertLevel.fatal):
        """Set description and level (fatal by default); return self."""
        self.level, self.description = level, description
        return self

    def parse(self, p):
        """Read level and description from parser `p`; return self.

        The parser's length check is set to exactly two bytes.
        """
        p.setLengthCheck(2)
        self.level = p.get(1)
        self.description = p.get(1)
        p.stopLengthCheck()
        return self

    def write(self):
        """Return the two serialised bytes: level, then description."""
        out = Writer()
        for field in (self.level, self.description):
            out.add(field, 1)
        return out.bytes
88 89
class HandshakeMsg:
    """Common base for handshake messages.

    Stores the handshake type and provides the shared header framing
    used by every subclass's ``write``.
    """

    def __init__(self, handshakeType):
        self.contentType = ContentType.handshake
        self.handshakeType = handshakeType

    def postWrite(self, w):
        """Prefix the body held in writer `w` with the handshake header.

        The header is the handshake type (1 byte) followed by the body
        length (3 bytes).  Returns header bytes + body bytes.
        """
        body = w.bytes
        header = Writer()
        header.add(self.handshakeType, 1)
        header.add(len(body), 3)
        return header.bytes + body
100
class ClientHello(HandshakeMsg):
    """ClientHello handshake message.

    Can be parsed either from the regular layout or, when constructed with
    ``ssl2=True``, from the SSLv2-style layout.  Supported extensions are
    cert_type, srp, supports_npn, server_name and tack.
    """

    def __init__(self, ssl2=False):
        HandshakeMsg.__init__(self, HandshakeType.client_hello)
        self.ssl2 = ssl2
        self.client_version = (0,0)
        self.random = createByteArrayZeros(32)
        self.session_id = createByteArraySequence([])
        self.cipher_suites = []         # a list of 16-bit values
        self.certificate_types = [CertificateType.x509]
        self.compression_methods = []   # a list of 8-bit values
        self.srp_username = None        # a string
        self.tack = False
        self.supports_npn = False
        self.server_name = ""

    def create(self, version, random, session_id, cipher_suites,
               certificate_types=None, srp_username=None,
               tack=False, supports_npn=False, server_name=""):
        """Populate every field for an outgoing message and return self.

        ``compression_methods`` is always set to ``[0]``.
        """
        self.client_version = version
        self.random = random
        self.session_id = session_id
        self.cipher_suites = cipher_suites
        self.certificate_types = certificate_types
        self.compression_methods = [0]
        self.srp_username = srp_username
        self.tack = tack
        self.supports_npn = supports_npn
        self.server_name = server_name
        return self

    def parse(self, p):
        """Read the message body from parser `p` and return self.

        Raises SyntaxError("Bad length for extension_data") when the bytes
        consumed for an extension differ from its declared length.
        """
        if self.ssl2:
            self.client_version = (p.get(1), p.get(1))
            cipherSpecsLength = p.get(2)
            sessionIDLength = p.get(2)
            randomLength = p.get(2)
            self.cipher_suites = p.getFixList(3, cipherSpecsLength//3)
            self.session_id = p.getFixBytes(sessionIDLength)
            self.random = p.getFixBytes(randomLength)
            if len(self.random) < 32:
                # Left-pad a short challenge with zeros up to 32 bytes.
                zeroBytes = 32-len(self.random)
                self.random = createByteArrayZeros(zeroBytes) + self.random
            self.compression_methods = [0]#Fake this value

            #We're not doing a stopLengthCheck() for SSLv2, oh well..
        else:
            p.startLengthCheck(3)
            self.client_version = (p.get(1), p.get(1))
            self.random = p.getFixBytes(32)
            self.session_id = p.getVarBytes(1)
            self.cipher_suites = p.getVarList(2, 2)
            self.compression_methods = p.getVarList(1, 1)
            if not p.atLengthCheck():
                totalExtLength = p.get(2)
                soFar = 0
                while soFar != totalExtLength:
                    extType = p.get(2)
                    extLength = p.get(2)
                    index1 = p.index
                    if extType == ExtensionType.srp:
                        self.srp_username = bytesToString(p.getVarBytes(1))
                    elif extType == ExtensionType.cert_type:
                        self.certificate_types = p.getVarList(1, 1)
                    elif extType == ExtensionType.tack:
                        self.tack = True
                    elif extType == ExtensionType.supports_npn:
                        self.supports_npn = True
                    elif extType == ExtensionType.server_name:
                        serverNameListBytes = p.getFixBytes(extLength)
                        p2 = Parser(serverNameListBytes)
                        p2.startLengthCheck(2)
                        while 1:
                            if p2.atLengthCheck():
                                break # no host_name, oh well
                            name_type = p2.get(1)
                            hostNameBytes = p2.getVarBytes(2)
                            if name_type == NameType.host_name:
                                self.server_name = bytesToString(hostNameBytes)
                                break
                    else:
                        # Unknown extension: skip over its payload.
                        _ = p.getFixBytes(extLength)
                    index2 = p.index
                    if index2 - index1 != extLength:
                        raise SyntaxError("Bad length for extension_data")
                    # 4 = two bytes of type plus two bytes of length.
                    soFar += 4 + extLength
            p.stopLengthCheck()
        return self

    def write(self):
        """Serialise the message (regular layout only), header included."""
        w = Writer()
        w.add(self.client_version[0], 1)
        w.add(self.client_version[1], 1)
        w.addFixSeq(self.random, 1)
        w.addVarSeq(self.session_id, 1, 1)
        w.addVarSeq(self.cipher_suites, 2, 2)
        w.addVarSeq(self.compression_methods, 1, 1)

        w2 = Writer() # For Extensions
        if self.certificate_types and self.certificate_types != \
                [CertificateType.x509]:
            w2.add(ExtensionType.cert_type, 2)
            w2.add(len(self.certificate_types)+1, 2)
            w2.addVarSeq(self.certificate_types, 1, 1)
        if self.srp_username:
            w2.add(ExtensionType.srp, 2)
            w2.add(len(self.srp_username)+1, 2)
            w2.addVarSeq(stringToBytes(self.srp_username), 1, 1)
        if self.supports_npn:
            # Previously accepted by create() and understood by parse() but
            # never written; sent as an empty extension, which is the only
            # form parse() accepts.
            w2.add(ExtensionType.supports_npn, 2)
            w2.add(0, 2)
        if self.server_name:
            w2.add(ExtensionType.server_name, 2)
            # +5: list length (2) + name type (1) + name length (2).
            w2.add(len(self.server_name)+5, 2)
            # +3: name type (1) + name length (2).
            w2.add(len(self.server_name)+3, 2)
            w2.add(NameType.host_name, 1)
            w2.addVarSeq(stringToBytes(self.server_name), 1, 2)
        if self.tack:
            w2.add(ExtensionType.tack, 2)
            w2.add(0, 2)
        if len(w2.bytes):
            # The extensions block is omitted entirely when empty.
            w.add(len(w2.bytes), 2)
            w.bytes += w2.bytes
        return self.postWrite(w)
221
class BadNextProtos(Exception):
    """Raised when a next-protocol entry has a length that cannot be encoded.

    The offending length is kept in ``self.length``.
    """

    def __init__(self, l):
        self.length = l

    def __str__(self):
        template = ('Cannot encode a list of next protocols because it '
                    'contains an element with invalid length %d. '
                    'Element lengths must be 0 < x < 256')
        return template % self.length
228
class ServerHello(HandshakeMsg):
    """ServerHello handshake message.

    Extensions handled: cert_type and tack when parsing; cert_type, tack
    and supports_npn (the advertised protocol list) when writing.
    """

    def __init__(self):
        HandshakeMsg.__init__(self, HandshakeType.server_hello)
        self.server_version = (0,0)
        self.random = createByteArrayZeros(32)
        self.session_id = createByteArraySequence([])
        self.cipher_suite = 0
        self.certificate_type = CertificateType.x509
        self.compression_method = 0
        self.tackExt = None
        self.next_protos_advertised = None

    def create(self, version, random, session_id, cipher_suite,
               certificate_type, tackExt, next_protos_advertised):
        """Populate every field for an outgoing message and return self.

        ``compression_method`` is always set to 0.
        """
        self.server_version = version
        self.random = random
        self.session_id = session_id
        self.cipher_suite = cipher_suite
        self.certificate_type = certificate_type
        self.compression_method = 0
        self.tackExt = tackExt
        self.next_protos_advertised = next_protos_advertised
        return self

    def parse(self, p):
        """Read the message body from parser `p` and return self.

        Raises SyntaxError when a cert_type extension is not exactly one
        byte long.
        """
        p.startLengthCheck(3)
        major = p.get(1)
        minor = p.get(1)
        self.server_version = (major, minor)
        self.random = p.getFixBytes(32)
        self.session_id = p.getVarBytes(1)
        self.cipher_suite = p.get(2)
        self.compression_method = p.get(1)
        if not p.atLengthCheck():
            extensionsTotal = p.get(2)
            consumed = 0
            while consumed != extensionsTotal:
                extType = p.get(2)
                extLength = p.get(2)
                if extType == ExtensionType.cert_type:
                    if extLength != 1:
                        raise SyntaxError()
                    self.certificate_type = p.get(1)
                elif extType == ExtensionType.tack and tackpyLoaded:
                    self.tackExt = TackExtension(p.getFixBytes(extLength))
                else:
                    # Anything else is skipped.
                    p.getFixBytes(extLength)
                # 4 = two bytes of type plus two bytes of length.
                consumed += 4 + extLength
        p.stopLengthCheck()
        return self

    def __next_protos_encoded(self):
        """Return the advertised protocols as a flat list of byte values.

        Each protocol is emitted as a one-byte length followed by its
        characters.  Raises BadNextProtos for an empty or over-255 entry.
        """
        pieces = []
        for proto in self.next_protos_advertised:
            size = len(proto)
            if not 0 < size < 256:
                raise BadNextProtos(size)
            pieces.append(chr(size))
            pieces.append(proto)

        flattened = ''.join(pieces)
        return [ord(ch) for ch in flattened]

    def write(self):
        """Serialise the message, header included."""
        body = Writer()
        body.add(self.server_version[0], 1)
        body.add(self.server_version[1], 1)
        body.addFixSeq(self.random, 1)
        body.addVarSeq(self.session_id, 1, 1)
        body.add(self.cipher_suite, 2)
        body.add(self.compression_method, 1)

        extensions = Writer()
        if self.certificate_type and self.certificate_type != \
                CertificateType.x509:
            extensions.add(ExtensionType.cert_type, 2)
            extensions.add(1, 2)
            extensions.add(self.certificate_type, 1)
        if self.tackExt:
            tackBytes = self.tackExt.serialize()
            extensions.add(ExtensionType.tack, 2)
            extensions.add(len(tackBytes), 2)
            extensions.bytes += tackBytes
        if self.next_protos_advertised is not None:
            npnBytes = self.__next_protos_encoded()
            extensions.add(ExtensionType.supports_npn, 2)
            extensions.add(len(npnBytes), 2)
            extensions.addFixSeq(npnBytes, 1)
        if len(extensions.bytes):
            # The extensions block is omitted entirely when empty.
            body.add(len(extensions.bytes), 2)
            body.bytes += extensions.bytes
        return self.postWrite(body)
317 318
class Certificate(HandshakeMsg):
    """Certificate handshake message carrying an X.509 chain.

    Only ``CertificateType.x509`` is supported; any other certificate type
    raises AssertionError from ``parse`` and ``write``.
    """

    def __init__(self, certificateType):
        HandshakeMsg.__init__(self, HandshakeType.certificate)
        self.certificateType = certificateType
        self.certChain = None

    def create(self, certChain):
        """Attach the certificate chain to send and return self."""
        self.certChain = certChain
        return self

    def parse(self, p):
        """Read the chain from parser `p` and return self.

        ``certChain`` is only assigned when at least one certificate is
        present.
        """
        p.startLengthCheck(3)
        if self.certificateType == CertificateType.x509:
            chainLength = p.get(3)
            index = 0
            certificate_list = []
            while index != chainLength:
                certBytes = p.getVarBytes(3)
                cert = X509()
                cert.parseBinary(certBytes)
                certificate_list.append(cert)
                # 3 = the length prefix in front of each certificate.
                index += len(certBytes)+3
            if certificate_list:
                self.certChain = X509CertChain(certificate_list)
        else:
            raise AssertionError()

        p.stopLengthCheck()
        return self

    def write(self):
        """Serialise the message, header included.

        An unset or empty chain is written as a zero-length list.
        """
        w = Writer()
        if self.certificateType == CertificateType.x509:
            if self.certChain:
                certificate_list = self.certChain.x509List
            else:
                certificate_list = []
            # Encode each certificate once; the encodings are needed both
            # for the total length and for the payload.
            encodedCerts = [cert.writeBytes() for cert in certificate_list]
            chainLength = sum(len(certBytes) + 3
                              for certBytes in encodedCerts)
            w.add(chainLength, 3)
            for certBytes in encodedCerts:
                w.addVarSeq(certBytes, 1, 3)
        else:
            raise AssertionError()
        return self.postWrite(w)
369
class CertificateRequest(HandshakeMsg):
    """CertificateRequest handshake message.

    Holds the acceptable client certificate types and a list of
    certificate authority names (kept as raw bytes).
    """

    def __init__(self):
        HandshakeMsg.__init__(self, HandshakeType.certificate_request)
        #Apple's Secure Transport library rejects empty certificate_types, so
        #default to rsa_sign.
        self.certificate_types = [ClientCertificateType.rsa_sign]
        self.certificate_authorities = []

    def create(self, certificate_types, certificate_authorities):
        """Set both lists and return self."""
        self.certificate_authorities = certificate_authorities
        self.certificate_types = certificate_types
        return self

    def parse(self, p):
        """Read the message body from parser `p` and return self."""
        p.startLengthCheck(3)
        self.certificate_types = p.getVarList(1, 1)
        expected = p.get(2)
        consumed = 0
        self.certificate_authorities = []
        while consumed != expected:
            name = p.getVarBytes(2)
            self.certificate_authorities.append(name)
            # 2 = the length prefix in front of each name.
            consumed += 2 + len(name)
        p.stopLengthCheck()
        return self

    def write(self):
        """Serialise the message, header included."""
        out = Writer()
        out.addVarSeq(self.certificate_types, 1, 1)
        # The total size of the authorities list precedes its entries.
        total = 0
        for name in self.certificate_authorities:
            total += len(name)+2
        out.add(total, 2)
        for name in self.certificate_authorities:
            out.addVarSeq(name, 1, 2)
        return self.postWrite(out)
408
class ServerKeyExchange(HandshakeMsg):
    """ServerKeyExchange handshake message for SRP and anonymous DH suites.

    Which parameters are read and written depends on ``cipherSuite``.
    """

    def __init__(self, cipherSuite):
        HandshakeMsg.__init__(self, HandshakeType.server_key_exchange)
        self.cipherSuite = cipherSuite
        # Plain 0 instead of the Python-2-only 0L literal; matches
        # ClientKeyExchange.srp_A.
        self.srp_N = 0
        self.srp_g = 0
        self.srp_s = createByteArraySequence([])
        self.srp_B = 0
        # Anon DH params:
        self.dh_p = 0
        self.dh_g = 0
        self.dh_Ys = 0
        self.signature = createByteArraySequence([])

    def createSRP(self, srp_N, srp_g, srp_s, srp_B):
        """Set the SRP parameters and return self."""
        self.srp_N = srp_N
        self.srp_g = srp_g
        self.srp_s = srp_s
        self.srp_B = srp_B
        return self

    def createDH(self, dh_p, dh_g, dh_Ys):
        """Set the anonymous DH parameters and return self."""
        self.dh_p = dh_p
        self.dh_g = dh_g
        self.dh_Ys = dh_Ys
        return self

    def parse(self, p):
        """Read the message body from parser `p` and return self."""
        p.startLengthCheck(3)
        if self.cipherSuite in CipherSuite.srpAllSuites:
            self.srp_N = bytesToNumber(p.getVarBytes(2))
            self.srp_g = bytesToNumber(p.getVarBytes(2))
            self.srp_s = p.getVarBytes(1)
            self.srp_B = bytesToNumber(p.getVarBytes(2))
            if self.cipherSuite in CipherSuite.srpCertSuites:
                self.signature = p.getVarBytes(2)
        elif self.cipherSuite in CipherSuite.anonSuites:
            self.dh_p = bytesToNumber(p.getVarBytes(2))
            self.dh_g = bytesToNumber(p.getVarBytes(2))
            self.dh_Ys = bytesToNumber(p.getVarBytes(2))
        p.stopLengthCheck()
        return self

    def _writeParams(self):
        """Return a Writer holding the key-exchange parameters only.

        The signature and the handshake header are not included.
        """
        w = Writer()
        if self.cipherSuite in CipherSuite.srpAllSuites:
            w.addVarSeq(numberToBytes(self.srp_N), 1, 2)
            w.addVarSeq(numberToBytes(self.srp_g), 1, 2)
            w.addVarSeq(self.srp_s, 1, 1)
            w.addVarSeq(numberToBytes(self.srp_B), 1, 2)
        elif self.cipherSuite in CipherSuite.anonSuites:
            w.addVarSeq(numberToBytes(self.dh_p), 1, 2)
            w.addVarSeq(numberToBytes(self.dh_g), 1, 2)
            w.addVarSeq(numberToBytes(self.dh_Ys), 1, 2)
        return w

    def write(self):
        """Serialise the message, header included.

        The signature is appended only for SRP suites that use a
        certificate.
        """
        w = self._writeParams()
        if (self.cipherSuite in CipherSuite.srpAllSuites and
                self.cipherSuite in CipherSuite.srpCertSuites):
            w.addVarSeq(self.signature, 1, 2)
        # TODO support for signed_params with the DH suites
        return self.postWrite(w)

    def hash(self, clientRandom, serverRandom):
        """Return MD5 + SHA-1 over both randoms and the server parameters.

        The old code cleared ``cipherSuite`` before calling ``write()`` to
        drop the signature, but that also dropped every parameter, so the
        digest covered only the two randoms.  The parameters are now
        serialised directly.

        NOTE(review): this changes the digest for the SRP certificate
        suites, so it will not match peers that still compute it the old
        way.
        """
        paramBytes = self._writeParams().bytes
        s = bytesToString(clientRandom + serverRandom + paramBytes)
        return stringToBytes(md5(s).digest() + sha1(s).digest())
478
class ServerHelloDone(HandshakeMsg):
    """ServerHelloDone handshake message; it has an empty body."""

    def __init__(self):
        # NOTE(review): the constructor body was missing from the extracted
        # source.  Restored by analogy with the sibling handshake classes --
        # confirm that HandshakeType.server_hello_done is the right constant.
        HandshakeMsg.__init__(self, HandshakeType.server_hello_done)

    def create(self):
        """Nothing to set; return self for symmetry with other messages."""
        return self

    def parse(self, p):
        """Check via parser `p` that the body is empty and return self."""
        p.startLengthCheck(3)
        p.stopLengthCheck()
        return self

    def write(self):
        """Serialise the message: the handshake header with no body."""
        w = Writer()
        return self.postWrite(w)
494
class ClientKeyExchange(HandshakeMsg):
    """ClientKeyExchange handshake message.

    Depending on ``cipherSuite`` the body is an SRP value, an encrypted
    premaster secret, or an anonymous DH public value.  For the certificate
    suites the encoding also depends on ``version``.
    """

    def __init__(self, cipherSuite, version=None):
        HandshakeMsg.__init__(self, HandshakeType.client_key_exchange)
        self.cipherSuite = cipherSuite
        self.version = version
        self.srp_A = 0
        # Initialised so write() on an anon suite cannot hit a missing
        # attribute when createDH()/parse() has not run yet.
        self.dh_Yc = 0
        self.encryptedPreMasterSecret = createByteArraySequence([])

    def createSRP(self, srp_A):
        """Set the SRP value and return self."""
        self.srp_A = srp_A
        return self

    def createRSA(self, encryptedPreMasterSecret):
        """Set the encrypted premaster secret and return self."""
        self.encryptedPreMasterSecret = encryptedPreMasterSecret
        return self

    def createDH(self, dh_Yc):
        """Set the anonymous DH public value and return self."""
        self.dh_Yc = dh_Yc
        return self

    def parse(self, p):
        """Read the message body from parser `p` and return self.

        Raises AssertionError for an unsupported cipher suite or version.
        """
        p.startLengthCheck(3)
        if self.cipherSuite in CipherSuite.srpAllSuites:
            self.srp_A = bytesToNumber(p.getVarBytes(2))
        elif self.cipherSuite in CipherSuite.certSuites:
            if self.version in ((3,1), (3,2)):
                self.encryptedPreMasterSecret = p.getVarBytes(2)
            elif self.version == (3,0):
                # No length prefix for (3,0): take everything that is left.
                self.encryptedPreMasterSecret = \
                    p.getFixBytes(len(p.bytes)-p.index)
            else:
                raise AssertionError()
        elif self.cipherSuite in CipherSuite.anonSuites:
            self.dh_Yc = bytesToNumber(p.getVarBytes(2))
        else:
            raise AssertionError()
        p.stopLengthCheck()
        return self

    def write(self):
        """Serialise the message, header included.

        Raises AssertionError for an unsupported cipher suite or version.
        """
        w = Writer()
        if self.cipherSuite in CipherSuite.srpAllSuites:
            w.addVarSeq(numberToBytes(self.srp_A), 1, 2)
        elif self.cipherSuite in CipherSuite.certSuites:
            if self.version in ((3,1), (3,2)):
                w.addVarSeq(self.encryptedPreMasterSecret, 1, 2)
            elif self.version == (3,0):
                w.addFixSeq(self.encryptedPreMasterSecret, 1)
            else:
                raise AssertionError()
        elif self.cipherSuite in CipherSuite.anonSuites:
            w.addVarSeq(numberToBytes(self.dh_Yc), 1, 2)
        else:
            raise AssertionError()
        return self.postWrite(w)
550
class CertificateVerify(HandshakeMsg):
    """CertificateVerify handshake message carrying a single signature."""

    def __init__(self):
        # NOTE(review): the constructor body was missing from the extracted
        # source.  Restored by analogy with the sibling handshake classes --
        # confirm that HandshakeType.certificate_verify is the right constant.
        HandshakeMsg.__init__(self, HandshakeType.certificate_verify)
        # Empty default so write() works before create()/parse() is called.
        self.signature = createByteArraySequence([])

    def create(self, signature):
        """Set the signature and return self."""
        self.signature = signature
        return self

    def parse(self, p):
        """Read the signature from parser `p` and return self."""
        p.startLengthCheck(3)
        self.signature = p.getVarBytes(2)
        p.stopLengthCheck()
        return self

    def write(self):
        """Serialise the message, header included."""
        w = Writer()
        w.addVarSeq(self.signature, 1, 2)
        return self.postWrite(w)
570
class ChangeCipherSpec:
    """ChangeCipherSpec message: a single byte whose value is normally 1."""

    def __init__(self):
        self.contentType = ContentType.change_cipher_spec
        self.type = 1

    def create(self):
        """Reset the type byte to 1 and return self."""
        self.type = 1
        return self

    def parse(self, p):
        """Read the single type byte from parser `p` and return self."""
        p.setLengthCheck(1)
        value = p.get(1)
        self.type = value
        p.stopLengthCheck()
        return self

    def write(self):
        """Return the one-byte serialised message."""
        out = Writer()
        out.add(self.type, 1)
        return out.bytes
590 591
class NextProtocol(HandshakeMsg):
    """NextProtocol handshake message (Next Protocol Negotiation).

    The body is the selected protocol followed by padding, each with a
    one-byte length prefix.
    """

    def __init__(self):
        # Go through the base constructor so handshakeType is set;
        # postWrite() needs it, and without it write() raised
        # AttributeError.
        # NOTE(review): 67 is the NextProtocol handshake type from the NPN
        # draft, used only if HandshakeType lacks next_protocol -- confirm
        # against constants.
        HandshakeMsg.__init__(self,
                              getattr(HandshakeType, 'next_protocol', 67))
        self.next_proto = None

    def create(self, next_proto):
        """Set the selected protocol and return self.

        Returning self matches every other message's create().
        """
        self.next_proto = next_proto
        return self

    def parse(self, p):
        """Read the selected protocol from parser `p`, skip the padding,
        and return self."""
        p.startLengthCheck(3)
        self.next_proto = p.getVarBytes(1)
        _ = p.getVarBytes(1)
        p.stopLengthCheck()
        return self

    def write(self):
        """Serialise the message, header included.

        The padding length is 32 - ((len(next_proto) + 2) % 32), as in the
        original expression.  It is written as that many zero bytes behind
        a one-byte length, which is the layout parse() reads back.
        """
        w = Writer()
        w.addVarSeq(self.next_proto, 1, 1)
        paddingLen = 32 - ((len(self.next_proto) + 2) % 32)
        w.addVarSeq(createByteArrayZeros(paddingLen), 1, 1)
        return self.postWrite(w)
612
class Finished(HandshakeMsg):
    """Finished handshake message holding the verify data.

    The expected verify-data size when parsing depends on ``version``.
    """

    def __init__(self, version):
        HandshakeMsg.__init__(self, HandshakeType.finished)
        self.version = version
        self.verify_data = createByteArraySequence([])

    def create(self, verify_data):
        """Set the verify data and return self."""
        self.verify_data = verify_data
        return self

    def parse(self, p):
        """Read the verify data from parser `p` and return self.

        Reads 36 bytes for version (3,0) and 12 bytes for (3,1) or (3,2);
        raises AssertionError for any other version.
        """
        p.startLengthCheck(3)
        if self.version == (3,0):
            size = 36
        elif self.version in ((3,1), (3,2)):
            size = 12
        else:
            raise AssertionError()
        self.verify_data = p.getFixBytes(size)
        p.stopLengthCheck()
        return self

    def write(self):
        """Serialise the message, header included."""
        out = Writer()
        out.addFixSeq(self.verify_data, 1)
        return self.postWrite(out)
638
class ApplicationData:
    """Application data message; the payload is kept as-is in ``bytes``."""

    def __init__(self):
        self.contentType = ContentType.application_data
        self.bytes = createByteArraySequence([])

    def create(self, bytes):
        """Set the payload and return self."""
        self.bytes = bytes
        return self

    def splitFirstByte(self):
        """Detach the first payload byte into a new ApplicationData.

        The new message holds ``bytes[:1]``; this one keeps the remainder.
        Returns the new message.
        """
        head = self.bytes[:1]
        firstByteMsg = ApplicationData().create(head)
        self.bytes = self.bytes[1:]
        return firstByteMsg

    def parse(self, p):
        """Take the whole buffer of parser `p` as the payload; return self."""
        self.bytes = p.bytes
        return self

    def write(self):
        """Return the payload unchanged."""
        return self.bytes
659