Package tlslite :: Module messages
[hide private]
[frames] | [no frames]

Source Code for Module tlslite.messages

  1  # Authors:  
  2  #   Trevor Perrin 
  3  #   Google - handling CertificateRequest.certificate_types 
  4  #   Google (adapted by Sam Rushing and Marcelo Fernandez) - NPN support 
  5  #   Dimitris Moraitis - Anon ciphersuites 
  6  # 
  7  # See the LICENSE file for legal information regarding use of this file. 
  8   
  9  """Classes representing TLS messages.""" 
 10   
 11  from .utils.compat import * 
 12  from .utils.cryptomath import * 
 13  from .errors import * 
 14  from .utils.codec import * 
 15  from .constants import * 
 16  from .x509 import X509 
 17  from .x509certchain import X509CertChain 
 18  from .utils.tackwrapper import * 
 19   
class RecordHeader3(object):
    """Five-byte SSLv3/TLS record header: content type, version, length."""

    def __init__(self):
        self.type = 0
        self.version = (0, 0)
        self.length = 0
        self.ssl2 = False

    def create(self, version, type, length):
        """Populate the header fields and return self for chaining."""
        self.type, self.version, self.length = type, version, length
        return self

    def write(self):
        """Serialise as 1-byte type, 1-byte major, 1-byte minor, 2-byte length."""
        fields = (
            (self.type, 1),
            (self.version[0], 1),
            (self.version[1], 1),
            (self.length, 2),
        )
        writer = Writer()
        for value, width in fields:
            writer.add(value, width)
        return writer.bytes

    def parse(self, p):
        """Read the five header bytes from parser *p* and return self."""
        self.type = p.get(1)
        major = p.get(1)
        minor = p.get(1)
        self.version = (major, minor)
        self.length = p.get(2)
        self.ssl2 = False
        return self

class RecordHeader2(object):
    """Two-byte SSLv2 record header (used by SSLv2-compatible ClientHellos).

    Only the two-byte header form is supported, i.e. the one whose first
    byte has its high bit set. The three-byte form, which carries padding,
    is rejected.
    """

    def __init__(self):
        self.type = 0
        self.version = (0, 0)
        self.length = 0
        self.ssl2 = True

    def parse(self, p):
        """Read a two-byte SSLv2 record header from parser *p*; return self.

        The record length is the low 15 bits of the two header bytes. A
        first byte of 0x80 gives the same result as a 1-byte length.

        :raises SyntaxError: if the high bit of the first byte is clear
            (three-byte header form, not supported).
        """
        firstByte = p.get(1)
        if not firstByte & 0x80:
            raise SyntaxError()
        secondByte = p.get(1)
        self.type = ContentType.handshake
        self.version = (2, 0)
        self.length = ((firstByte & 0x7f) << 8) | secondByte
        return self

class Alert(object):
    """Alert record: a one-byte level followed by a one-byte description."""

    def __init__(self):
        self.contentType = ContentType.alert
        self.level = 0
        self.description = 0

    def create(self, description, level=AlertLevel.fatal):
        """Set the description and level (fatal unless given); return self."""
        self.description = description
        self.level = level
        return self

    def parse(self, p):
        """Read exactly two bytes (level, then description); return self."""
        p.setLengthCheck(2)
        self.level, self.description = p.get(1), p.get(1)
        p.stopLengthCheck()
        return self

    def write(self):
        """Serialise as two bytes: level then description."""
        writer = Writer()
        for field in (self.level, self.description):
            writer.add(field, 1)
        return writer.bytes

class HandshakeMsg(object):
    """Common base for handshake messages.

    Subclasses serialise their body into a Writer and hand it to
    postWrite(), which prepends the handshake header.
    """

    def __init__(self, handshakeType):
        self.contentType = ContentType.handshake
        self.handshakeType = handshakeType

    def postWrite(self, w):
        """Return header (1-byte type, 3-byte body length) + body bytes of *w*."""
        body = w.bytes
        header = Writer()
        header.add(self.handshakeType, 1)
        header.add(len(body), 3)
        return header.bytes + body

101 -class ClientHello(HandshakeMsg):
102 - def __init__(self, ssl2=False):
103 HandshakeMsg.__init__(self, HandshakeType.client_hello) 104 self.ssl2 = ssl2 105 self.client_version = (0,0) 106 self.random = bytearray(32) 107 self.session_id = bytearray(0) 108 self.cipher_suites = [] # a list of 16-bit values 109 self.certificate_types = [CertificateType.x509] 110 self.compression_methods = [] # a list of 8-bit values 111 self.srp_username = None # a string 112 self.tack = False 113 self.supports_npn = False 114 self.server_name = bytearray(0)
115
116 - def create(self, version, random, session_id, cipher_suites, 117 certificate_types=None, srpUsername=None, 118 tack=False, supports_npn=False, serverName=None):
119 self.client_version = version 120 self.random = random 121 self.session_id = session_id 122 self.cipher_suites = cipher_suites 123 self.certificate_types = certificate_types 124 self.compression_methods = [0] 125 if srpUsername: 126 self.srp_username = bytearray(srpUsername, "utf-8") 127 self.tack = tack 128 self.supports_npn = supports_npn 129 if serverName: 130 self.server_name = bytearray(serverName, "utf-8") 131 return self
132
133 - def parse(self, p):
134 if self.ssl2: 135 self.client_version = (p.get(1), p.get(1)) 136 cipherSpecsLength = p.get(2) 137 sessionIDLength = p.get(2) 138 randomLength = p.get(2) 139 self.cipher_suites = p.getFixList(3, cipherSpecsLength//3) 140 self.session_id = p.getFixBytes(sessionIDLength) 141 self.random = p.getFixBytes(randomLength) 142 if len(self.random) < 32: 143 zeroBytes = 32-len(self.random) 144 self.random = bytearray(zeroBytes) + self.random 145 self.compression_methods = [0]#Fake this value 146 147 #We're not doing a stopLengthCheck() for SSLv2, oh well.. 148 else: 149 p.startLengthCheck(3) 150 self.client_version = (p.get(1), p.get(1)) 151 self.random = p.getFixBytes(32) 152 self.session_id = p.getVarBytes(1) 153 self.cipher_suites = p.getVarList(2, 2) 154 self.compression_methods = p.getVarList(1, 1) 155 if not p.atLengthCheck(): 156 totalExtLength = p.get(2) 157 soFar = 0 158 while soFar != totalExtLength: 159 extType = p.get(2) 160 extLength = p.get(2) 161 index1 = p.index 162 if extType == ExtensionType.srp: 163 self.srp_username = p.getVarBytes(1) 164 elif extType == ExtensionType.cert_type: 165 self.certificate_types = p.getVarList(1, 1) 166 elif extType == ExtensionType.tack: 167 self.tack = True 168 elif extType == ExtensionType.supports_npn: 169 self.supports_npn = True 170 elif extType == ExtensionType.server_name: 171 serverNameListBytes = p.getFixBytes(extLength) 172 p2 = Parser(serverNameListBytes) 173 p2.startLengthCheck(2) 174 while 1: 175 if p2.atLengthCheck(): 176 break # no host_name, oh well 177 name_type = p2.get(1) 178 hostNameBytes = p2.getVarBytes(2) 179 if name_type == NameType.host_name: 180 self.server_name = hostNameBytes 181 break 182 else: 183 _ = p.getFixBytes(extLength) 184 index2 = p.index 185 if index2 - index1 != extLength: 186 raise SyntaxError("Bad length for extension_data") 187 soFar += 4 + extLength 188 p.stopLengthCheck() 189 return self
190
191 - def write(self):
192 w = Writer() 193 w.add(self.client_version[0], 1) 194 w.add(self.client_version[1], 1) 195 w.addFixSeq(self.random, 1) 196 w.addVarSeq(self.session_id, 1, 1) 197 w.addVarSeq(self.cipher_suites, 2, 2) 198 w.addVarSeq(self.compression_methods, 1, 1) 199 200 w2 = Writer() # For Extensions 201 if self.certificate_types and self.certificate_types != \ 202 [CertificateType.x509]: 203 w2.add(ExtensionType.cert_type, 2) 204 w2.add(len(self.certificate_types)+1, 2) 205 w2.addVarSeq(self.certificate_types, 1, 1) 206 if self.srp_username: 207 w2.add(ExtensionType.srp, 2) 208 w2.add(len(self.srp_username)+1, 2) 209 w2.addVarSeq(self.srp_username, 1, 1) 210 if self.supports_npn: 211 w2.add(ExtensionType.supports_npn, 2) 212 w2.add(0, 2) 213 if self.server_name: 214 w2.add(ExtensionType.server_name, 2) 215 w2.add(len(self.server_name)+5, 2) 216 w2.add(len(self.server_name)+3, 2) 217 w2.add(NameType.host_name, 1) 218 w2.addVarSeq(self.server_name, 1, 2) 219 if self.tack: 220 w2.add(ExtensionType.tack, 2) 221 w2.add(0, 2) 222 if len(w2.bytes): 223 w.add(len(w2.bytes), 2) 224 w.bytes += w2.bytes 225 return self.postWrite(w)
226
class BadNextProtos(Exception):
    """Raised when a next-protocol (NPN) list element has an invalid length.

    :param l: the offending element length.
    """

    def __init__(self, l):
        # Forward the length to Exception so .args is populated; without it
        # pickling/copying fails because the class cannot be re-created
        # from empty args.
        Exception.__init__(self, l)
        self.length = l

    def __str__(self):
        return 'Cannot encode a list of next protocols because it contains an element with invalid length %d. Element lengths must be 0 < x < 256' % self.length

234 -class ServerHello(HandshakeMsg):
235 - def __init__(self):
236 HandshakeMsg.__init__(self, HandshakeType.server_hello) 237 self.server_version = (0,0) 238 self.random = bytearray(32) 239 self.session_id = bytearray(0) 240 self.cipher_suite = 0 241 self.certificate_type = CertificateType.x509 242 self.compression_method = 0 243 self.tackExt = None 244 self.next_protos_advertised = None 245 self.next_protos = None
246
247 - def create(self, version, random, session_id, cipher_suite, 248 certificate_type, tackExt, next_protos_advertised):
249 self.server_version = version 250 self.random = random 251 self.session_id = session_id 252 self.cipher_suite = cipher_suite 253 self.certificate_type = certificate_type 254 self.compression_method = 0 255 self.tackExt = tackExt 256 self.next_protos_advertised = next_protos_advertised 257 return self
258
259 - def parse(self, p):
260 p.startLengthCheck(3) 261 self.server_version = (p.get(1), p.get(1)) 262 self.random = p.getFixBytes(32) 263 self.session_id = p.getVarBytes(1) 264 self.cipher_suite = p.get(2) 265 self.compression_method = p.get(1) 266 if not p.atLengthCheck(): 267 totalExtLength = p.get(2) 268 soFar = 0 269 while soFar != totalExtLength: 270 extType = p.get(2) 271 extLength = p.get(2) 272 if extType == ExtensionType.cert_type: 273 if extLength != 1: 274 raise SyntaxError() 275 self.certificate_type = p.get(1) 276 elif extType == ExtensionType.tack and tackpyLoaded: 277 self.tackExt = TackExtension(p.getFixBytes(extLength)) 278 elif extType == ExtensionType.supports_npn: 279 self.next_protos = self.__parse_next_protos(p.getFixBytes(extLength)) 280 else: 281 p.getFixBytes(extLength) 282 soFar += 4 + extLength 283 p.stopLengthCheck() 284 return self
285
286 - def __parse_next_protos(self, b):
287 protos = [] 288 while True: 289 if len(b) == 0: 290 break 291 l = b[0] 292 b = b[1:] 293 if len(b) < l: 294 raise BadNextProtos(len(b)) 295 protos.append(b[:l]) 296 b = b[l:] 297 return protos
298
299 - def __next_protos_encoded(self):
300 b = bytearray() 301 for e in self.next_protos_advertised: 302 if len(e) > 255 or len(e) == 0: 303 raise BadNextProtos(len(e)) 304 b += bytearray( [len(e)] ) + bytearray(e) 305 return b
306
307 - def write(self):
308 w = Writer() 309 w.add(self.server_version[0], 1) 310 w.add(self.server_version[1], 1) 311 w.addFixSeq(self.random, 1) 312 w.addVarSeq(self.session_id, 1, 1) 313 w.add(self.cipher_suite, 2) 314 w.add(self.compression_method, 1) 315 316 w2 = Writer() # For Extensions 317 if self.certificate_type and self.certificate_type != \ 318 CertificateType.x509: 319 w2.add(ExtensionType.cert_type, 2) 320 w2.add(1, 2) 321 w2.add(self.certificate_type, 1) 322 if self.tackExt: 323 b = self.tackExt.serialize() 324 w2.add(ExtensionType.tack, 2) 325 w2.add(len(b), 2) 326 w2.bytes += b 327 if self.next_protos_advertised is not None: 328 encoded_next_protos_advertised = self.__next_protos_encoded() 329 w2.add(ExtensionType.supports_npn, 2) 330 w2.add(len(encoded_next_protos_advertised), 2) 331 w2.addFixSeq(encoded_next_protos_advertised, 1) 332 if len(w2.bytes): 333 w.add(len(w2.bytes), 2) 334 w.bytes += w2.bytes 335 return self.postWrite(w)
336 337
338 -class Certificate(HandshakeMsg):
339 - def __init__(self, certificateType):
340 HandshakeMsg.__init__(self, HandshakeType.certificate) 341 self.certificateType = certificateType 342 self.certChain = None
343
344 - def create(self, certChain):
345 self.certChain = certChain 346 return self
347
348 - def parse(self, p):
349 p.startLengthCheck(3) 350 if self.certificateType == CertificateType.x509: 351 chainLength = p.get(3) 352 index = 0 353 certificate_list = [] 354 while index != chainLength: 355 certBytes = p.getVarBytes(3) 356 x509 = X509() 357 x509.parseBinary(certBytes) 358 certificate_list.append(x509) 359 index += len(certBytes)+3 360 if certificate_list: 361 self.certChain = X509CertChain(certificate_list) 362 else: 363 raise AssertionError() 364 365 p.stopLengthCheck() 366 return self
367
368 - def write(self):
369 w = Writer() 370 if self.certificateType == CertificateType.x509: 371 chainLength = 0 372 if self.certChain: 373 certificate_list = self.certChain.x509List 374 else: 375 certificate_list = [] 376 #determine length 377 for cert in certificate_list: 378 bytes = cert.writeBytes() 379 chainLength += len(bytes)+3 380 #add bytes 381 w.add(chainLength, 3) 382 for cert in certificate_list: 383 bytes = cert.writeBytes() 384 w.addVarSeq(bytes, 1, 3) 385 else: 386 raise AssertionError() 387 return self.postWrite(w)
388
389 -class CertificateRequest(HandshakeMsg):
390 - def __init__(self, version):
391 HandshakeMsg.__init__(self, HandshakeType.certificate_request) 392 #Apple's Secure Transport library rejects empty certificate_types, so 393 #default to rsa_sign. 394 self.certificate_types = [ClientCertificateType.rsa_sign] 395 self.certificate_authorities = [] 396 self.version = version 397 self.supported_signature_algs = []
398
399 - def create(self, certificate_types, certificate_authorities, sig_algs=(), version=(3,0)):
400 self.certificate_types = certificate_types 401 self.certificate_authorities = certificate_authorities 402 self.version = version 403 self.supported_signature_algs = sig_algs 404 return self
405
406 - def parse(self, p):
407 p.startLengthCheck(3) 408 self.certificate_types = p.getVarList(1, 1) 409 if self.version >= (3,3): 410 self.supported_signature_algs = p.getVarList(2, 2) 411 ca_list_length = p.get(2) 412 index = 0 413 self.certificate_authorities = [] 414 while index != ca_list_length: 415 ca_bytes = p.getVarBytes(2) 416 self.certificate_authorities.append(ca_bytes) 417 index += len(ca_bytes)+2 418 p.stopLengthCheck() 419 return self
420
421 - def write(self):
422 w = Writer() 423 w.addVarSeq(self.certificate_types, 1, 1) 424 if self.version >= (3,3): 425 w.addVarSeq(self.supported_signature_algs, 2, 2) 426 caLength = 0 427 #determine length 428 for ca_dn in self.certificate_authorities: 429 caLength += len(ca_dn)+2 430 w.add(caLength, 2) 431 #add bytes 432 for ca_dn in self.certificate_authorities: 433 w.addVarSeq(ca_dn, 1, 2) 434 return self.postWrite(w)
435
436 -class ServerKeyExchange(HandshakeMsg):
437 - def __init__(self, cipherSuite):
438 HandshakeMsg.__init__(self, HandshakeType.server_key_exchange) 439 self.cipherSuite = cipherSuite 440 self.srp_N = 0 441 self.srp_g = 0 442 self.srp_s = bytearray(0) 443 self.srp_B = 0 444 # Anon DH params: 445 self.dh_p = 0 446 self.dh_g = 0 447 self.dh_Ys = 0 448 self.signature = bytearray(0)
449
450 - def createSRP(self, srp_N, srp_g, srp_s, srp_B):
451 self.srp_N = srp_N 452 self.srp_g = srp_g 453 self.srp_s = srp_s 454 self.srp_B = srp_B 455 return self
456
457 - def createDH(self, dh_p, dh_g, dh_Ys):
458 self.dh_p = dh_p 459 self.dh_g = dh_g 460 self.dh_Ys = dh_Ys 461 return self
462
463 - def parse(self, p):
464 p.startLengthCheck(3) 465 if self.cipherSuite in CipherSuite.srpAllSuites: 466 self.srp_N = bytesToNumber(p.getVarBytes(2)) 467 self.srp_g = bytesToNumber(p.getVarBytes(2)) 468 self.srp_s = p.getVarBytes(1) 469 self.srp_B = bytesToNumber(p.getVarBytes(2)) 470 if self.cipherSuite in CipherSuite.srpCertSuites: 471 self.signature = p.getVarBytes(2) 472 elif self.cipherSuite in CipherSuite.anonSuites: 473 self.dh_p = bytesToNumber(p.getVarBytes(2)) 474 self.dh_g = bytesToNumber(p.getVarBytes(2)) 475 self.dh_Ys = bytesToNumber(p.getVarBytes(2)) 476 p.stopLengthCheck() 477 return self
478
479 - def write(self):
480 w = Writer() 481 if self.cipherSuite in CipherSuite.srpAllSuites: 482 w.addVarSeq(numberToByteArray(self.srp_N), 1, 2) 483 w.addVarSeq(numberToByteArray(self.srp_g), 1, 2) 484 w.addVarSeq(self.srp_s, 1, 1) 485 w.addVarSeq(numberToByteArray(self.srp_B), 1, 2) 486 if self.cipherSuite in CipherSuite.srpCertSuites: 487 w.addVarSeq(self.signature, 1, 2) 488 elif self.cipherSuite in CipherSuite.anonSuites: 489 w.addVarSeq(numberToByteArray(self.dh_p), 1, 2) 490 w.addVarSeq(numberToByteArray(self.dh_g), 1, 2) 491 w.addVarSeq(numberToByteArray(self.dh_Ys), 1, 2) 492 if self.cipherSuite in []: # TODO support for signed_params 493 w.addVarSeq(self.signature, 1, 2) 494 return self.postWrite(w)
495
496 - def hash(self, clientRandom, serverRandom):
497 oldCipherSuite = self.cipherSuite 498 self.cipherSuite = None 499 try: 500 bytes = clientRandom + serverRandom + self.write()[4:] 501 return MD5(bytes) + SHA1(bytes) 502 finally: 503 self.cipherSuite = oldCipherSuite
504
505 -class ServerHelloDone(HandshakeMsg):
506 - def __init__(self):
508
509 - def create(self):
510 return self
511
512 - def parse(self, p):
513 p.startLengthCheck(3) 514 p.stopLengthCheck() 515 return self
516
517 - def write(self):
518 w = Writer() 519 return self.postWrite(w)
520
521 -class ClientKeyExchange(HandshakeMsg):
522 - def __init__(self, cipherSuite, version=None):
523 HandshakeMsg.__init__(self, HandshakeType.client_key_exchange) 524 self.cipherSuite = cipherSuite 525 self.version = version 526 self.srp_A = 0 527 self.encryptedPreMasterSecret = bytearray(0)
528
529 - def createSRP(self, srp_A):
530 self.srp_A = srp_A 531 return self
532
533 - def createRSA(self, encryptedPreMasterSecret):
534 self.encryptedPreMasterSecret = encryptedPreMasterSecret 535 return self
536
537 - def createDH(self, dh_Yc):
538 self.dh_Yc = dh_Yc 539 return self
540
541 - def parse(self, p):
542 p.startLengthCheck(3) 543 if self.cipherSuite in CipherSuite.srpAllSuites: 544 self.srp_A = bytesToNumber(p.getVarBytes(2)) 545 elif self.cipherSuite in CipherSuite.certSuites: 546 if self.version in ((3,1), (3,2), (3,3)): 547 self.encryptedPreMasterSecret = p.getVarBytes(2) 548 elif self.version == (3,0): 549 self.encryptedPreMasterSecret = \ 550 p.getFixBytes(len(p.bytes)-p.index) 551 else: 552 raise AssertionError() 553 elif self.cipherSuite in CipherSuite.anonSuites: 554 self.dh_Yc = bytesToNumber(p.getVarBytes(2)) 555 else: 556 raise AssertionError() 557 p.stopLengthCheck() 558 return self
559
560 - def write(self):
561 w = Writer() 562 if self.cipherSuite in CipherSuite.srpAllSuites: 563 w.addVarSeq(numberToByteArray(self.srp_A), 1, 2) 564 elif self.cipherSuite in CipherSuite.certSuites: 565 if self.version in ((3,1), (3,2), (3,3)): 566 w.addVarSeq(self.encryptedPreMasterSecret, 1, 2) 567 elif self.version == (3,0): 568 w.addFixSeq(self.encryptedPreMasterSecret, 1) 569 else: 570 raise AssertionError() 571 elif self.cipherSuite in CipherSuite.anonSuites: 572 w.addVarSeq(numberToByteArray(self.dh_Yc), 1, 2) 573 else: 574 raise AssertionError() 575 return self.postWrite(w)
576
577 -class CertificateVerify(HandshakeMsg):
578 - def __init__(self):
579 HandshakeMsg.__init__(self, HandshakeType.certificate_verify) 580 self.signature = bytearray(0)
581
582 - def create(self, signature):
583 self.signature = signature 584 return self
585
586 - def parse(self, p):
587 p.startLengthCheck(3) 588 self.signature = p.getVarBytes(2) 589 p.stopLengthCheck() 590 return self
591
592 - def write(self):
593 w = Writer() 594 w.addVarSeq(self.signature, 1, 2) 595 return self.postWrite(w)
596
class ChangeCipherSpec(object):
    """ChangeCipherSpec record: a single byte, normally 1."""

    def __init__(self):
        self.contentType = ContentType.change_cipher_spec
        self.type = 1

    def create(self):
        """Reset the type byte to 1 and return self."""
        self.type = 1
        return self

    def parse(self, p):
        """Read exactly one byte from parser *p*; return self."""
        p.setLengthCheck(1)
        value = p.get(1)
        self.type = value
        p.stopLengthCheck()
        return self

    def write(self):
        """Serialise as the single type byte."""
        out = Writer()
        out.add(self.type, 1)
        return out.bytes

618 -class NextProtocol(HandshakeMsg):
619 - def __init__(self):
620 HandshakeMsg.__init__(self, HandshakeType.next_protocol) 621 self.next_proto = None
622
623 - def create(self, next_proto):
624 self.next_proto = next_proto 625 return self
626
627 - def parse(self, p):
628 p.startLengthCheck(3) 629 self.next_proto = p.getVarBytes(1) 630 _ = p.getVarBytes(1) 631 p.stopLengthCheck() 632 return self
633
634 - def write(self, trial=False):
635 w = Writer() 636 w.addVarSeq(self.next_proto, 1, 1) 637 paddingLen = 32 - ((len(self.next_proto) + 2) % 32) 638 w.addVarSeq(bytearray(paddingLen), 1, 1) 639 return self.postWrite(w)
640
641 -class Finished(HandshakeMsg):
642 - def __init__(self, version):
643 HandshakeMsg.__init__(self, HandshakeType.finished) 644 self.version = version 645 self.verify_data = bytearray(0)
646
647 - def create(self, verify_data):
648 self.verify_data = verify_data 649 return self
650
651 - def parse(self, p):
652 p.startLengthCheck(3) 653 if self.version == (3,0): 654 self.verify_data = p.getFixBytes(36) 655 elif self.version in ((3,1), (3,2), (3,3)): 656 self.verify_data = p.getFixBytes(12) 657 else: 658 raise AssertionError() 659 p.stopLengthCheck() 660 return self
661
662 - def write(self):
663 w = Writer() 664 w.addFixSeq(self.verify_data, 1) 665 return self.postWrite(w)
666
class ApplicationData(object):
    """Application data record; the payload is carried verbatim."""

    def __init__(self):
        self.contentType = ContentType.application_data
        self.bytes = bytearray(0)

    def create(self, bytes):
        """Store the payload and return self."""
        self.bytes = bytes
        return self

    def splitFirstByte(self):
        """Detach the first payload byte into a new ApplicationData.

        Returns the new one-byte message; self keeps the remaining bytes.
        """
        first = self.bytes[:1]
        head = ApplicationData().create(first)
        self.bytes = self.bytes[1:]
        return head

    def parse(self, p):
        """Take the parser's whole buffer as the payload; return self."""
        self.bytes = p.bytes
        return self

    def write(self):
        """Return the payload unchanged."""
        return self.bytes
