Package tlslite :: Module messages
[hide private]
[frames | no frames]

Source Code for Module tlslite.messages

  1  # Authors:  
  2  #   Trevor Perrin 
  3  #   Google - handling CertificateRequest.certificate_types 
  4  #   Google (adapted by Sam Rushing and Marcelo Fernandez) - NPN support 
  5  #   Dimitris Moraitis - Anon ciphersuites 
  6  #   Yngve Pettersen (ported by Paul Sokolovsky) - TLS 1.2 
  7  # 
  8  # See the LICENSE file for legal information regarding use of this file. 
  9   
 10  """Classes representing TLS messages.""" 
 11   
 12  from .utils.compat import * 
 13  from .utils.cryptomath import * 
 14  from .errors import * 
 15  from .utils.codec import * 
 16  from .constants import * 
 17  from .x509 import X509 
 18  from .x509certchain import X509CertChain 
 19  from .utils.tackwrapper import * 
 20   
class RecordHeader3(object):
    """Record header carrying a content type, a two-part version and a length.

    On the wire it is five bytes: type (1), version major (1),
    version minor (1) and length (2).
    """

    def __init__(self):
        # Start out empty; create() or parse() supplies the real values.
        self.ssl2 = False
        self.length = 0
        self.version = (0, 0)
        self.type = 0

    def create(self, version, type, length):
        """Store the given header fields and return self for chaining."""
        self.version = version
        self.type = type
        self.length = length
        return self

    def write(self):
        """Return the serialised header bytes produced by a Writer."""
        out = Writer()
        fields = (
            (self.type, 1),
            (self.version[0], 1),
            (self.version[1], 1),
            (self.length, 2),
        )
        for value, width in fields:
            out.add(value, width)
        return out.bytes

    def parse(self, p):
        """Read the header fields from parser p and return self."""
        self.type = p.get(1)
        major = p.get(1)
        minor = p.get(1)
        self.version = (major, minor)
        self.length = p.get(2)
        self.ssl2 = False
        return self
48
class RecordHeader2(object):
    """Record header flagged as SSLv2 (ssl2 is True, version is (2, 0))."""

    def __init__(self):
        # Start out empty; parse() supplies the real values.
        self.ssl2 = True
        self.length = 0
        self.version = (0, 0)
        self.type = 0

    def parse(self, p):
        """Read a two-byte header from parser p and return self.

        Raises SyntaxError unless the first byte is exactly 128.
        """
        firstByte = p.get(1)
        if firstByte != 128:
            raise SyntaxError()
        self.type = ContentType.handshake
        self.version = (2, 0)
        # We don't support 2-byte-length-headers; could be a problem.
        # Only a single length byte is read here.
        self.length = p.get(1)
        return self
64 65
class Alert(object):
    """Alert message: a one-byte level followed by a one-byte description."""

    def __init__(self):
        self.contentType = ContentType.alert
        self.level = 0
        self.description = 0

    def create(self, description, level=AlertLevel.fatal):
        """Set description and level (fatal by default); return self."""
        self.description = description
        self.level = level
        return self

    def parse(self, p):
        """Read level and description from parser p; return self.

        The parser's length check is set to exactly two bytes.
        """
        p.setLengthCheck(2)
        self.level = p.get(1)
        self.description = p.get(1)
        p.stopLengthCheck()
        return self

    def write(self):
        """Return the two serialised bytes: level, then description."""
        out = Writer()
        for value in (self.level, self.description):
            out.add(value, 1)
        return out.bytes
89 90
class HandshakeMsg(object):
    """Common base for handshake messages; records the handshake type."""

    def __init__(self, handshakeType):
        self.handshakeType = handshakeType
        self.contentType = ContentType.handshake

    def postWrite(self, w):
        """Prefix the body held in Writer w with the handshake header.

        The header is the handshake type (1 byte) followed by the body
        length (3 bytes). Returns header bytes + body bytes.
        """
        body = w.bytes
        header = Writer()
        header.add(self.handshakeType, 1)
        header.add(len(body), 3)
        return header.bytes + body
101
class ClientHello(HandshakeMsg):
    """ClientHello handshake message.

    When constructed with ssl2=True, parse() reads the SSLv2 layout;
    otherwise it reads the SSLv3/TLS layout including extensions.
    """

    def __init__(self, ssl2=False):
        HandshakeMsg.__init__(self, HandshakeType.client_hello)
        self.ssl2 = ssl2
        self.client_version = (0, 0)
        self.random = bytearray(32)
        self.session_id = bytearray(0)
        self.cipher_suites = []          # a list of 16-bit values
        self.certificate_types = [CertificateType.x509]
        self.compression_methods = []    # a list of 8-bit values
        self.srp_username = None         # a string
        self.tack = False
        self.supports_npn = False
        self.server_name = bytearray(0)

    def create(self, version, random, session_id, cipher_suites,
               certificate_types=None, srpUsername=None,
               tack=False, supports_npn=False, serverName=None):
        """Populate the message fields and return self.

        srpUsername and serverName, when truthy, are stored UTF-8 encoded
        as bytearrays. compression_methods is always set to [0].
        """
        self.client_version = version
        self.random = random
        self.session_id = session_id
        self.cipher_suites = cipher_suites
        self.certificate_types = certificate_types
        self.compression_methods = [0]
        if srpUsername:
            self.srp_username = bytearray(srpUsername, "utf-8")
        self.tack = tack
        self.supports_npn = supports_npn
        if serverName:
            self.server_name = bytearray(serverName, "utf-8")
        return self

    def parse(self, p):
        """Read the message from parser p and return self.

        Raises SyntaxError("Bad length for extension_data") when an
        extension consumes a different number of bytes than it declares.
        """
        if self.ssl2:
            self.client_version = (p.get(1), p.get(1))
            suitesLen = p.get(2)
            sidLen = p.get(2)
            rndLen = p.get(2)
            self.cipher_suites = p.getFixList(3, suitesLen // 3)
            self.session_id = p.getFixBytes(sidLen)
            self.random = p.getFixBytes(rndLen)
            missing = 32 - len(self.random)
            if missing > 0:
                # Left-pad a short challenge with zero bytes up to 32.
                self.random = bytearray(missing) + self.random
            self.compression_methods = [0]  # Fake this value

            # We're not doing a stopLengthCheck() for SSLv2, oh well..
            return self

        p.startLengthCheck(3)
        self.client_version = (p.get(1), p.get(1))
        self.random = p.getFixBytes(32)
        self.session_id = p.getVarBytes(1)
        self.cipher_suites = p.getVarList(2, 2)
        self.compression_methods = p.getVarList(1, 1)
        if not p.atLengthCheck():
            extensionsLen = p.get(2)
            consumed = 0
            while consumed != extensionsLen:
                extType = p.get(2)
                extLength = p.get(2)
                startIndex = p.index
                if extType == ExtensionType.srp:
                    self.srp_username = p.getVarBytes(1)
                elif extType == ExtensionType.cert_type:
                    self.certificate_types = p.getVarList(1, 1)
                elif extType == ExtensionType.tack:
                    self.tack = True
                elif extType == ExtensionType.supports_npn:
                    self.supports_npn = True
                elif extType == ExtensionType.server_name:
                    nameParser = Parser(p.getFixBytes(extLength))
                    nameParser.startLengthCheck(2)
                    # Take the first host_name entry; an absent host_name
                    # is tolerated and leaves server_name unchanged.
                    while not nameParser.atLengthCheck():
                        name_type = nameParser.get(1)
                        hostNameBytes = nameParser.getVarBytes(2)
                        if name_type == NameType.host_name:
                            self.server_name = hostNameBytes
                            break
                else:
                    # Unknown extension: skip over its payload.
                    p.getFixBytes(extLength)
                if p.index - startIndex != extLength:
                    raise SyntaxError("Bad length for extension_data")
                # 4 = two bytes of type plus two bytes of length.
                consumed += 4 + extLength
        p.stopLengthCheck()
        return self

    def write(self):
        """Serialise the message (SSLv3/TLS layout) with handshake header."""
        body = Writer()
        body.add(self.client_version[0], 1)
        body.add(self.client_version[1], 1)
        body.addFixSeq(self.random, 1)
        body.addVarSeq(self.session_id, 1, 1)
        body.addVarSeq(self.cipher_suites, 2, 2)
        body.addVarSeq(self.compression_methods, 1, 1)

        ext = Writer()  # For Extensions
        certTypes = self.certificate_types
        if certTypes and certTypes != [CertificateType.x509]:
            ext.add(ExtensionType.cert_type, 2)
            ext.add(len(certTypes) + 1, 2)
            ext.addVarSeq(certTypes, 1, 1)
        if self.srp_username:
            ext.add(ExtensionType.srp, 2)
            ext.add(len(self.srp_username) + 1, 2)
            ext.addVarSeq(self.srp_username, 1, 1)
        if self.supports_npn:
            ext.add(ExtensionType.supports_npn, 2)
            ext.add(0, 2)
        if self.server_name:
            nameLen = len(self.server_name)
            ext.add(ExtensionType.server_name, 2)
            ext.add(nameLen + 5, 2)   # extension_data length
            ext.add(nameLen + 3, 2)   # name list length
            ext.add(NameType.host_name, 1)
            ext.addVarSeq(self.server_name, 1, 2)
        if self.tack:
            ext.add(ExtensionType.tack, 2)
            ext.add(0, 2)
        if ext.bytes:
            body.add(len(ext.bytes), 2)
            body.bytes += ext.bytes
        return self.postWrite(body)
227
class BadNextProtos(Exception):
    """Raised when a next-protocol element has an unusable length.

    The offending length is available as the `length` attribute.
    """

    def __init__(self, l):
        # Pass the value to Exception so that `args` is populated; the
        # exception can then be rebuilt from `args` (as pickle and copy do).
        Exception.__init__(self, l)
        self.length = l

    def __str__(self):
        return 'Cannot encode a list of next protocols because it contains an element with invalid length %d. Element lengths must be 0 < x < 256' % self.length
234
class ServerHello(HandshakeMsg):
    """ServerHello handshake message, including the extensions handled here
    (certificate type, TACK and next-protocol negotiation)."""

    def __init__(self):
        HandshakeMsg.__init__(self, HandshakeType.server_hello)
        self.server_version = (0, 0)
        self.random = bytearray(32)
        self.session_id = bytearray(0)
        self.cipher_suite = 0
        self.certificate_type = CertificateType.x509
        self.compression_method = 0
        self.tackExt = None
        self.next_protos_advertised = None
        self.next_protos = None

    def create(self, version, random, session_id, cipher_suite,
               certificate_type, tackExt, next_protos_advertised):
        """Populate the message fields and return self.

        compression_method is always set to 0.
        """
        self.server_version = version
        self.random = random
        self.session_id = session_id
        self.cipher_suite = cipher_suite
        self.certificate_type = certificate_type
        self.compression_method = 0
        self.tackExt = tackExt
        self.next_protos_advertised = next_protos_advertised
        return self

    def parse(self, p):
        """Read the message from parser p and return self.

        Raises SyntaxError when a cert_type extension is not exactly one
        byte long.
        """
        p.startLengthCheck(3)
        self.server_version = (p.get(1), p.get(1))
        self.random = p.getFixBytes(32)
        self.session_id = p.getVarBytes(1)
        self.cipher_suite = p.get(2)
        self.compression_method = p.get(1)
        if not p.atLengthCheck():
            extensionsLen = p.get(2)
            consumed = 0
            while consumed != extensionsLen:
                extType = p.get(2)
                extLength = p.get(2)
                if extType == ExtensionType.cert_type:
                    if extLength != 1:
                        raise SyntaxError()
                    self.certificate_type = p.get(1)
                elif extType == ExtensionType.tack and tackpyLoaded:
                    self.tackExt = TackExtension(p.getFixBytes(extLength))
                elif extType == ExtensionType.supports_npn:
                    payload = p.getFixBytes(extLength)
                    self.next_protos = self.__parse_next_protos(payload)
                else:
                    # Unhandled extension: skip over its payload.
                    p.getFixBytes(extLength)
                # 4 = two bytes of type plus two bytes of length.
                consumed += 4 + extLength
        p.stopLengthCheck()
        return self

    def __parse_next_protos(self, b):
        """Split b into a list of length-prefixed (one byte) elements.

        Raises BadNextProtos, carrying the number of bytes left, when an
        element declares more bytes than remain.
        """
        protos = []
        total = len(b)
        pos = 0
        while pos < total:
            size = b[pos]
            pos += 1
            remaining = total - pos
            if remaining < size:
                raise BadNextProtos(remaining)
            protos.append(b[pos:pos + size])
            pos += size
        return protos

    def __next_protos_encoded(self):
        """Encode next_protos_advertised as length-prefixed elements.

        Raises BadNextProtos for an element that is empty or longer
        than 255 bytes.
        """
        encoded = bytearray()
        for proto in self.next_protos_advertised:
            size = len(proto)
            if size == 0 or size > 255:
                raise BadNextProtos(size)
            encoded += bytearray([size]) + bytearray(proto)
        return encoded

    def write(self):
        """Serialise the message, with extensions, plus handshake header."""
        body = Writer()
        body.add(self.server_version[0], 1)
        body.add(self.server_version[1], 1)
        body.addFixSeq(self.random, 1)
        body.addVarSeq(self.session_id, 1, 1)
        body.add(self.cipher_suite, 2)
        body.add(self.compression_method, 1)

        ext = Writer()  # For Extensions
        certType = self.certificate_type
        if certType and certType != CertificateType.x509:
            ext.add(ExtensionType.cert_type, 2)
            ext.add(1, 2)
            ext.add(certType, 1)
        if self.tackExt:
            tackBytes = self.tackExt.serialize()
            ext.add(ExtensionType.tack, 2)
            ext.add(len(tackBytes), 2)
            ext.bytes += tackBytes
        if self.next_protos_advertised is not None:
            npnBytes = self.__next_protos_encoded()
            ext.add(ExtensionType.supports_npn, 2)
            ext.add(len(npnBytes), 2)
            ext.addFixSeq(npnBytes, 1)
        if ext.bytes:
            body.add(len(ext.bytes), 2)
            body.bytes += ext.bytes
        return self.postWrite(body)
337 338
class Certificate(HandshakeMsg):
    """Certificate handshake message carrying a certificate chain.

    Only CertificateType.x509 is handled; parse() and write() raise
    AssertionError for any other certificate type.
    """

    def __init__(self, certificateType):
        HandshakeMsg.__init__(self, HandshakeType.certificate)
        self.certificateType = certificateType
        self.certChain = None

    def create(self, certChain):
        """Store the chain to send and return self."""
        self.certChain = certChain
        return self

    def parse(self, p):
        """Read the chain from parser p and return self.

        certChain is set only when at least one certificate is present.
        """
        p.startLengthCheck(3)
        if self.certificateType == CertificateType.x509:
            chainLength = p.get(3)
            index = 0
            certificate_list = []
            while index != chainLength:
                certBytes = p.getVarBytes(3)
                x509 = X509()
                x509.parseBinary(certBytes)
                certificate_list.append(x509)
                # 3 = size of each certificate's length prefix.
                index += len(certBytes) + 3
            if certificate_list:
                self.certChain = X509CertChain(certificate_list)
        else:
            raise AssertionError()

        p.stopLengthCheck()
        return self

    def write(self):
        """Serialise the chain (possibly empty) plus handshake header."""
        w = Writer()
        if self.certificateType != CertificateType.x509:
            raise AssertionError()

        if self.certChain:
            certificate_list = self.certChain.x509List
        else:
            certificate_list = []
        # Serialise every certificate exactly once; the same bytes are used
        # both for the total length and for the output.
        encodedCerts = [cert.writeBytes() for cert in certificate_list]
        chainLength = sum(len(certBytes) + 3 for certBytes in encodedCerts)
        w.add(chainLength, 3)
        for certBytes in encodedCerts:
            w.addVarSeq(certBytes, 1, 3)
        return self.postWrite(w)
389
class CertificateRequest(HandshakeMsg):
    """CertificateRequest handshake message.

    The signature-algorithm list is read and written only when the
    version is (3,3) or later.
    """

    def __init__(self, version):
        HandshakeMsg.__init__(self, HandshakeType.certificate_request)
        # Apple's Secure Transport library rejects empty certificate_types,
        # so default to rsa_sign.
        self.certificate_types = [ClientCertificateType.rsa_sign]
        self.certificate_authorities = []
        self.version = version
        self.supported_signature_algs = []

    def create(self, certificate_types, certificate_authorities, sig_algs=(), version=(3,0)):
        """Populate the message fields and return self."""
        self.certificate_types = certificate_types
        self.certificate_authorities = certificate_authorities
        self.version = version
        self.supported_signature_algs = sig_algs
        return self

    def parse(self, p):
        """Read the message from parser p and return self."""
        p.startLengthCheck(3)
        self.certificate_types = p.getVarList(1, 1)
        if self.version >= (3, 3):
            self.supported_signature_algs = p.getVarList(2, 2)
        listLength = p.get(2)
        self.certificate_authorities = []
        consumed = 0
        while consumed != listLength:
            ca_bytes = p.getVarBytes(2)
            self.certificate_authorities.append(ca_bytes)
            # 2 = size of each entry's length prefix.
            consumed += len(ca_bytes) + 2
        p.stopLengthCheck()
        return self

    def write(self):
        """Serialise the message plus handshake header."""
        out = Writer()
        out.addVarSeq(self.certificate_types, 1, 1)
        if self.version >= (3, 3):
            out.addVarSeq(self.supported_signature_algs, 2, 2)
        # Total size of the CA list: each entry plus its 2-byte prefix.
        authorities = self.certificate_authorities
        out.add(sum(len(ca_dn) + 2 for ca_dn in authorities), 2)
        for ca_dn in authorities:
            out.addVarSeq(ca_dn, 1, 2)
        return self.postWrite(out)
436
class ServerKeyExchange(HandshakeMsg):
    """ServerKeyExchange handshake message for SRP and anonymous DH suites.

    Which fields are read and written depends on cipherSuite: SRP
    parameters for CipherSuite.srpAllSuites (plus a signature for
    CipherSuite.srpCertSuites), DH parameters for CipherSuite.anonSuites.
    """

    def __init__(self, cipherSuite):
        HandshakeMsg.__init__(self, HandshakeType.server_key_exchange)
        self.cipherSuite = cipherSuite
        self.srp_N = 0
        self.srp_g = 0
        self.srp_s = bytearray(0)
        self.srp_B = 0
        # Anon DH params:
        self.dh_p = 0
        self.dh_g = 0
        self.dh_Ys = 0
        self.signature = bytearray(0)

    def createSRP(self, srp_N, srp_g, srp_s, srp_B):
        """Store the SRP parameters and return self."""
        self.srp_N = srp_N
        self.srp_g = srp_g
        self.srp_s = srp_s
        self.srp_B = srp_B
        return self

    def createDH(self, dh_p, dh_g, dh_Ys):
        """Store the DH parameters and return self."""
        self.dh_p = dh_p
        self.dh_g = dh_g
        self.dh_Ys = dh_Ys
        return self

    def parse(self, p):
        """Read the parameters matching cipherSuite from p; return self."""
        p.startLengthCheck(3)
        if self.cipherSuite in CipherSuite.srpAllSuites:
            self.srp_N = bytesToNumber(p.getVarBytes(2))
            self.srp_g = bytesToNumber(p.getVarBytes(2))
            self.srp_s = p.getVarBytes(1)
            self.srp_B = bytesToNumber(p.getVarBytes(2))
            if self.cipherSuite in CipherSuite.srpCertSuites:
                self.signature = p.getVarBytes(2)
        elif self.cipherSuite in CipherSuite.anonSuites:
            self.dh_p = bytesToNumber(p.getVarBytes(2))
            self.dh_g = bytesToNumber(p.getVarBytes(2))
            self.dh_Ys = bytesToNumber(p.getVarBytes(2))
        p.stopLengthCheck()
        return self

    def _writeParams(self):
        """Return a Writer holding only the key exchange parameters.

        Neither the signature nor the handshake header is included.
        """
        w = Writer()
        if self.cipherSuite in CipherSuite.srpAllSuites:
            w.addVarSeq(numberToByteArray(self.srp_N), 1, 2)
            w.addVarSeq(numberToByteArray(self.srp_g), 1, 2)
            w.addVarSeq(self.srp_s, 1, 1)
            w.addVarSeq(numberToByteArray(self.srp_B), 1, 2)
        elif self.cipherSuite in CipherSuite.anonSuites:
            w.addVarSeq(numberToByteArray(self.dh_p), 1, 2)
            w.addVarSeq(numberToByteArray(self.dh_g), 1, 2)
            w.addVarSeq(numberToByteArray(self.dh_Ys), 1, 2)
        return w

    def write(self):
        """Serialise parameters (and signature, if any) plus header."""
        w = self._writeParams()
        if (self.cipherSuite in CipherSuite.srpAllSuites and
                self.cipherSuite in CipherSuite.srpCertSuites):
            w.addVarSeq(self.signature, 1, 2)
        # TODO support for signed_params with DH suites; no DH suite
        # writes a signature at present.
        return self.postWrite(w)

    def hash(self, clientRandom, serverRandom):
        """Return MD5 + SHA1 of clientRandom + serverRandom + parameters.

        The parameters are serialised directly, without the signature.
        Previously cipherSuite was temporarily set to None before calling
        write(), which made write() emit no parameters at all, so the
        digest covered only the two randoms.

        NOTE(review): the digest now differs from what the previous code
        computed, so signatures will not verify against a peer still
        running that code.
        """
        payload = clientRandom + serverRandom + self._writeParams().bytes
        return MD5(payload) + SHA1(payload)
505
class ServerHelloDone(HandshakeMsg):
    """ServerHelloDone handshake message; it has an empty body."""

    def __init__(self):
        # The constructor body was missing. Every other handshake class in
        # this module initialises the base class with its own type.
        # NOTE(review): confirm the constant is named
        # HandshakeType.server_hello_done in the constants module.
        HandshakeMsg.__init__(self, HandshakeType.server_hello_done)

    def create(self):
        """Nothing to populate; return self."""
        return self

    def parse(self, p):
        """Consume the (empty) body from parser p and return self."""
        p.startLengthCheck(3)
        p.stopLengthCheck()
        return self

    def write(self):
        """Serialise the empty body plus handshake header."""
        w = Writer()
        return self.postWrite(w)
521
class ClientKeyExchange(HandshakeMsg):
    """ClientKeyExchange handshake message.

    The payload depends on cipherSuite: an SRP value for
    CipherSuite.srpAllSuites, an encrypted premaster secret for
    CipherSuite.certSuites, a DH value for CipherSuite.anonSuites.
    parse() and write() raise AssertionError for any other suite, and
    for a cert suite with an unsupported version.
    """

    def __init__(self, cipherSuite, version=None):
        HandshakeMsg.__init__(self, HandshakeType.client_key_exchange)
        self.cipherSuite = cipherSuite
        self.version = version
        self.srp_A = 0
        # Initialised like srp_A so that write() on an anon suite does not
        # raise AttributeError before createDH() or parse() has run.
        self.dh_Yc = 0
        self.encryptedPreMasterSecret = bytearray(0)

    def createSRP(self, srp_A):
        """Store the SRP value and return self."""
        self.srp_A = srp_A
        return self

    def createRSA(self, encryptedPreMasterSecret):
        """Store the encrypted premaster secret and return self."""
        self.encryptedPreMasterSecret = encryptedPreMasterSecret
        return self

    def createDH(self, dh_Yc):
        """Store the DH value and return self."""
        self.dh_Yc = dh_Yc
        return self

    def parse(self, p):
        """Read the payload matching cipherSuite from p; return self."""
        p.startLengthCheck(3)
        if self.cipherSuite in CipherSuite.srpAllSuites:
            self.srp_A = bytesToNumber(p.getVarBytes(2))
        elif self.cipherSuite in CipherSuite.certSuites:
            if self.version in ((3,1), (3,2), (3,3)):
                self.encryptedPreMasterSecret = p.getVarBytes(2)
            elif self.version == (3,0):
                # No length prefix here: take every remaining byte.
                self.encryptedPreMasterSecret = \
                    p.getFixBytes(len(p.bytes)-p.index)
            else:
                raise AssertionError()
        elif self.cipherSuite in CipherSuite.anonSuites:
            self.dh_Yc = bytesToNumber(p.getVarBytes(2))
        else:
            raise AssertionError()
        p.stopLengthCheck()
        return self

    def write(self):
        """Serialise the payload matching cipherSuite plus header."""
        w = Writer()
        if self.cipherSuite in CipherSuite.srpAllSuites:
            w.addVarSeq(numberToByteArray(self.srp_A), 1, 2)
        elif self.cipherSuite in CipherSuite.certSuites:
            if self.version in ((3,1), (3,2), (3,3)):
                w.addVarSeq(self.encryptedPreMasterSecret, 1, 2)
            elif self.version == (3,0):
                # Written without a length prefix for this version.
                w.addFixSeq(self.encryptedPreMasterSecret, 1)
            else:
                raise AssertionError()
        elif self.cipherSuite in CipherSuite.anonSuites:
            w.addVarSeq(numberToByteArray(self.dh_Yc), 1, 2)
        else:
            raise AssertionError()
        return self.postWrite(w)
577
class CertificateVerify(HandshakeMsg):
    """CertificateVerify handshake message holding a single signature."""

    def __init__(self):
        HandshakeMsg.__init__(self, HandshakeType.certificate_verify)
        self.signature = bytearray(0)

    def create(self, signature):
        """Store the signature and return self."""
        self.signature = signature
        return self

    def parse(self, p):
        """Read the signature (2-byte length prefix) from p; return self."""
        p.startLengthCheck(3)
        signature = p.getVarBytes(2)
        self.signature = signature
        p.stopLengthCheck()
        return self

    def write(self):
        """Serialise the signature plus handshake header."""
        out = Writer()
        out.addVarSeq(self.signature, 1, 2)
        return self.postWrite(out)
597
class ChangeCipherSpec(object):
    """ChangeCipherSpec message: a single one-byte type field."""

    def __init__(self):
        self.contentType = ContentType.change_cipher_spec
        self.type = 1

    def create(self):
        """Set the type field to 1 and return self."""
        self.type = 1
        return self

    def parse(self, p):
        """Read the one-byte type from parser p and return self.

        The parser's length check is set to exactly one byte.
        """
        p.setLengthCheck(1)
        value = p.get(1)
        self.type = value
        p.stopLengthCheck()
        return self

    def write(self):
        """Return the single serialised type byte."""
        out = Writer()
        out.add(self.type, 1)
        return out.bytes
617 618
class NextProtocol(HandshakeMsg):
    """NextProtocol handshake message: a protocol name plus padding."""

    def __init__(self):
        HandshakeMsg.__init__(self, HandshakeType.next_protocol)
        self.next_proto = None

    def create(self, next_proto):
        """Store the protocol name and return self."""
        self.next_proto = next_proto
        return self

    def parse(self, p):
        """Read the protocol name from p, discard the padding; return self."""
        p.startLengthCheck(3)
        self.next_proto = p.getVarBytes(1)
        p.getVarBytes(1)  # padding, ignored
        p.stopLengthCheck()
        return self

    def write(self, trial=False):
        """Serialise name and zero padding plus handshake header.

        The padding length is 32 - ((len(name) + 2) % 32). The `trial`
        argument is accepted but not used.
        """
        out = Writer()
        out.addVarSeq(self.next_proto, 1, 1)
        # + 2 accounts for the two one-byte length prefixes.
        usedInBlock = (len(self.next_proto) + 2) % 32
        padding = bytearray(32 - usedInBlock)
        out.addVarSeq(padding, 1, 1)
        return self.postWrite(out)
641
class Finished(HandshakeMsg):
    """Finished handshake message carrying the verify_data bytes."""

    def __init__(self, version):
        HandshakeMsg.__init__(self, HandshakeType.finished)
        self.version = version
        self.verify_data = bytearray(0)

    def create(self, verify_data):
        """Store verify_data and return self."""
        self.verify_data = verify_data
        return self

    def parse(self, p):
        """Read verify_data from parser p and return self.

        36 bytes are read for version (3,0) and 12 bytes for (3,1),
        (3,2) and (3,3); any other version raises AssertionError.
        """
        p.startLengthCheck(3)
        if self.version == (3, 0):
            size = 36
        elif self.version in ((3, 1), (3, 2), (3, 3)):
            size = 12
        else:
            raise AssertionError()
        self.verify_data = p.getFixBytes(size)
        p.stopLengthCheck()
        return self

    def write(self):
        """Serialise verify_data plus handshake header."""
        out = Writer()
        out.addFixSeq(self.verify_data, 1)
        return self.postWrite(out)
667
class ApplicationData(object):
    """Application data message: an opaque run of bytes."""

    def __init__(self):
        self.contentType = ContentType.application_data
        self.bytes = bytearray(0)

    def create(self, bytes):
        """Store the payload and return self."""
        self.bytes = bytes
        return self

    def splitFirstByte(self):
        """Remove the first byte and return it as a new ApplicationData.

        This message keeps everything after the first byte.
        """
        firstMsg = ApplicationData().create(self.bytes[:1])
        remainder = self.bytes[1:]
        self.bytes = remainder
        return firstMsg

    def parse(self, p):
        """Take the whole buffer of parser p as the payload; return self."""
        self.bytes = p.bytes
        return self

    def write(self):
        """Return the payload unchanged."""
        return self.bytes
688