1
2
3
4
5 """Implementation of the TLS Record Layer protocol"""
6
7 import socket
8 import errno
9 import hashlib
10 from .constants import ContentType, CipherSuite
11 from .messages import RecordHeader3, RecordHeader2, Message
12 from .utils.cipherfactory import createAESGCM, createAES, createRC4, \
13 createTripleDES, createCHACHA20
14 from .utils.codec import Parser, Writer
15 from .utils.compat import compatHMAC
16 from .utils.cryptomath import getRandomBytes
17 from .utils.constanttime import ct_compare_digest, ct_check_cbc_mac_and_pad
18 from .errors import TLSRecordOverflow, TLSIllegalParameterException,\
19 TLSAbruptCloseError, TLSDecryptionFailed, TLSBadRecordMAC
20 from .mathtls import createMAC_SSL, createHMAC, PRF_SSL, PRF, PRF_1_2, \
21 PRF_1_2_SHA384
24
25 """Socket wrapper for reading and writing TLS Records"""
26
28 """
29 Assign socket to wrapper
30
31 @type sock: socket.socket
32 """
33 self.sock = sock
34 self.version = (0, 0)
35
def _sockSendAll(self, data):
    """
    Send all data through socket

    @type data: bytearray
    @param data: data to send
    @rtype: generator
    @return: generator that yields 1 each time the socket would block or
        only part of the data was written; it finishes once everything
        has been sent
    @raise socket.error: when write to socket failed
    """
    while True:
        try:
            bytesSent = self.sock.send(data)
        except socket.error as why:
            # non-blocking socket is not ready: tell the caller and retry
            if why.args[0] in (errno.EWOULDBLOCK, errno.EAGAIN):
                yield 1
                continue
            raise

        if bytesSent == len(data):
            return
        # partial write: keep only the part that was not sent yet
        data = data[bytesSent:]
        yield 1
57
def send(self, msg):
    """
    Send the message through socket.

    The payload returned by C{msg.write()} is prefixed with a
    L{RecordHeader3} built from the current version, the message content
    type and the payload length.

    @type msg: object providing C{write()} and C{contentType}
    @param msg: TLS message to send
    @rtype: generator
    @return: generator relaying the values yielded by L{_sockSendAll}
    @raise socket.error: when write to socket failed
    """
    payload = msg.write()

    header = RecordHeader3().create(self.version,
                                    msg.contentType,
                                    len(payload))

    for result in self._sockSendAll(header.write() + payload):
        yield result
76
def _sockRecvAll(self, length):
    """
    Read exactly the amount of bytes specified in L{length} from raw socket.

    @type length: int
    @param length: number of bytes to read
    @rtype: generator
    @return: generator that will return 0 or 1 in case the socket is non
        blocking and would block and bytearray in case the read finished
    @raise TLSAbruptCloseError: when the socket closed
    """
    buf = bytearray(0)

    if length == 0:
        yield buf
        # nothing to read; stop here so that exhausting the generator
        # does not fall through to a zero-length recv()
        return

    while True:
        try:
            socketBytes = self.sock.recv(length - len(buf))
        except socket.error as why:
            # non-blocking socket has no data yet: signal and retry
            if why.args[0] in (errno.EWOULDBLOCK, errno.EAGAIN):
                yield 0
                continue
            raise

        # an empty read means the other side closed the connection
        if len(socketBytes) == 0:
            raise TLSAbruptCloseError()

        buf += bytearray(socketBytes)
        if len(buf) == length:
            yield buf
            # done; without this the loop would call recv(0) and report
            # a closed connection to a caller that keeps iterating
            return
108
def _recvHeader(self):
    """
    Read a single record header from socket

    @rtype: generator
    @return: generator that yields 0 or 1 while the read would block and
        finally the parsed header (L{RecordHeader3} or L{RecordHeader2})
    @raise TLSIllegalParameterException: when the first byte is neither a
        known content type nor 128
    """
    # the first byte decides which header format follows
    result = None
    for result in self._sockRecvAll(1):
        if result in (0, 1):
            yield result
        else:
            break
    assert result is not None

    buf = bytearray(0)
    buf += result

    if buf[0] in ContentType.all:
        # parsed as RecordHeader3 below: four more header bytes follow
        ssl2 = False
        remaining = 4
    elif buf[0] == 128:
        # parsed as RecordHeader2 below: one more header byte follows
        ssl2 = True
        remaining = 1
    else:
        raise TLSIllegalParameterException(
            "Record header type doesn't specify known type")

    result = None
    for result in self._sockRecvAll(remaining):
        if result in (0, 1):
            yield result
        else:
            break
    assert result is not None
    buf += result

    if ssl2:
        record = RecordHeader2().parse(Parser(buf))
    else:
        record = RecordHeader3().parse(Parser(buf))

    yield record
158
def recv(self):
    """
    Read a single record from socket, handle SSLv2 and SSLv3 record layer

    @rtype: generator
    @return: generator that returns 0 or 1 in case the read would be
        blocking or a tuple containing record header (object) and record
        data (bytearray) read from socket
    @raise socket.error: In case of network error
    @raise TLSAbruptCloseError: When the socket was closed on the other
        side in middle of record receiving
    @raise TLSRecordOverflow: When the received record was longer than
        allowed by TLS
    @raise TLSIllegalParameterException: When the record header was
        malformed
    """
    record = None
    for record in self._recvHeader():
        if record in (0, 1):
            yield record
        else:
            break
    assert record is not None

    # 18432 == 2**14 + 2048; refuse to read anything longer than that
    if record.length > 18432:
        raise TLSRecordOverflow()

    result = None
    for result in self._sockRecvAll(record.length):
        if result in (0, 1):
            yield result
        else:
            break
    assert result is not None

    buf = bytearray(0)
    buf += result

    yield (record, buf)
201
203
class ConnectionState(object):

    """Preserve the connection state for reading and writing data to records"""

    def __init__(self):
        """Create an instance with empty encryption and MACing contexts"""
        self.macContext = None
        self.encContext = None
        self.fixedNonce = None
        self.seqnum = 0

    def getSeqNumBytes(self):
        """
        Return encoded sequence number and increment it.

        @return: the C{bytes} attribute of a L{Writer} to which the
            current sequence number was added with a length of 8
        """
        writer = Writer()
        writer.add(self.seqnum, 8)
        # every call consumes one sequence number
        self.seqnum += 1
        return writer.bytes
219
221
    """
    Implementation of TLS record layer protocol

    @ivar version: the TLS version to use (tuple encoded as on the wire)
    @ivar sock: underlying socket
    @ivar client: whether the connection is the client side; selects which
        of the client/server pending states is used for writing and which
        for reading
    @ivar encryptThenMAC: use the encrypt-then-MAC mechanism for record
        integrity
    """
231
246
@property
def version(self):
    """Return the TLS version used by record layer"""
    return self._version

@version.setter
def version(self, val):
    """Set the TLS version used by record layer"""
    self._version = val
    # keep the wrapped record socket on the same version
    self._recordSocket.version = val
257
259 """
260 Return the name of the bulk cipher used by this connection
261
262 @rtype: str
263 @return: The name of the cipher, like 'aes128', 'rc4', etc.
264 """
265 if self._writeState.encContext is None:
266 return None
267 return self._writeState.encContext.name
268
270 """
271 Return the name of the implementation used for the connection
272
273 'python' for tlslite internal implementation, 'openssl' for M2crypto
274 and 'pycrypto' for pycrypto
275 @rtype: str
276 @return: Name of cipher implementation used, None if not initialised
277 """
278 if self._writeState.encContext is None:
279 return None
280 return self._writeState.encContext.implementation
281
288
290 """Returns true if cipher uses CBC mode"""
291 if self._writeState and self._writeState.encContext and \
292 self._writeState.encContext.isBlockCipher:
293 return True
294 else:
295 return False
296
297
298
299
def addPadding(self, data):
    """
    Add padding to data so that it is multiple of block size

    Between 1 and C{block_size} bytes are appended, each holding the
    value "number of appended bytes minus one".

    @type data: bytearray
    @param data: data to pad; extended in place
    @rtype: bytearray
    @return: the padded data
    """
    blockLength = self._writeState.encContext.block_size
    # value stored in every padding byte, including the last one
    paddingLength = blockLength - 1 - (len(data) % blockLength)

    data += bytearray([paddingLength] * (paddingLength + 1))
    return data
309
310 - def calculateMAC(self, mac, seqnumBytes, contentType, data):
322
def _macThenEncrypt(self, data, contentType):
    """
    MAC, pad then encrypt data

    @type data: bytearray
    @param data: record payload
    @param contentType: content type passed to L{calculateMAC}
    @rtype: bytearray
    @return: protected record payload
    """
    writeState = self._writeState

    if writeState.macContext:
        seqnumBytes = writeState.getSeqNumBytes()
        mac = writeState.macContext.copy()
        data += self.calculateMAC(mac, seqnumBytes, contentType, data)

    if writeState.encContext:
        # only block ciphers need the IV block and the padding
        if writeState.encContext.isBlockCipher:
            if self.version >= (3, 2):
                data = self.fixedIVBlock + data
            data = self.addPadding(data)

        data = writeState.encContext.encrypt(data)

    return data
346
def _encryptThenMAC(self, buf, contentType):
    """
    Pad, encrypt and then MAC the data

    @type buf: bytearray
    @param buf: record payload
    @param contentType: content type passed to L{calculateMAC}
    @rtype: bytearray
    @return: encrypted payload followed by the MAC
    """
    writeState = self._writeState

    if writeState.encContext:
        if self.version >= (3, 2):
            buf = self.fixedIVBlock + buf

        buf = self.addPadding(buf)
        buf = writeState.encContext.encrypt(buf)

    if writeState.macContext:
        seqnumBytes = writeState.getSeqNumBytes()
        mac = writeState.macContext.copy()
        # here the MAC is calculated over the already encrypted data
        buf += self.calculateMAC(mac, seqnumBytes, contentType, buf)

    return buf
368
def _encryptThenSeal(self, buf, contentType):
    """
    Encrypt with AEAD cipher

    @type buf: bytearray
    @param buf: record payload
    @param contentType: content type, included in the authenticated data
    @rtype: bytearray
    @return: sealed payload; when the cipher name contains "aes" it is
        prefixed with the sequence number bytes used in the nonce
    """
    writeState = self._writeState
    seqNumBytes = writeState.getSeqNumBytes()

    # authenticated data: sequence number, type, version and the
    # payload length as two bytes, most significant first
    lengthHigh, lengthLow = divmod(len(buf), 256)
    authData = seqNumBytes + bytearray([contentType,
                                        self.version[0],
                                        self.version[1],
                                        lengthHigh,
                                        lengthLow])

    nonce = writeState.fixedNonce + seqNumBytes
    assert len(nonce) == writeState.encContext.nonceLength

    buf = writeState.encContext.seal(nonce, buf, authData)

    # the matching read path takes the nonce tail from the record for
    # "aes" ciphers, so it has to be sent along
    if "aes" in writeState.encContext.name:
        buf = seqNumBytes + buf

    return buf
391
393 """
394 Encrypt, MAC and send arbitrary message as-is through socket.
395
396 Note that if the message was not fragmented to below 2**14 bytes
397 it will be rejected by the other connection side.
398
399 @param msg: TLS message to send
400 @type msg: ApplicationData, HandshakeMessage, etc.
401 """
402 data = msg.write()
403 contentType = msg.contentType
404
405 if self._writeState and \
406 self._writeState.encContext and \
407 self._writeState.encContext.isAEAD:
408 data = self._encryptThenSeal(data, contentType)
409 elif self.encryptThenMAC:
410 data = self._encryptThenMAC(data, contentType)
411 else:
412 data = self._macThenEncrypt(data, contentType)
413
414 encryptedMessage = Message(contentType, data)
415
416 for result in self._recordSocket.send(encryptedMessage):
417 yield result
418
419
420
421
422
def _decryptStreamThenMAC(self, recordType, data):
    """
    Decrypt a stream cipher and check MAC

    @param recordType: content type of the record, passed to
        L{calculateMAC}
    @type data: bytearray
    @param data: record payload as received
    @rtype: bytearray
    @return: payload with the MAC removed
    @raise TLSBadRecordMAC: when the data is shorter than the MAC or the
        MAC does not match
    """
    readState = self._readState

    if readState.encContext:
        assert self.version in ((3, 0), (3, 1), (3, 2), (3, 3))
        data = readState.encContext.decrypt(data)

    if readState.macContext:
        macLength = readState.macContext.digest_size
        if macLength > len(data):
            raise TLSBadRecordMAC()

        # the MAC occupies the last macLength bytes of the record
        macStart = len(data) - macLength
        checkBytes = data[macStart:]
        data = data[:macStart]

        seqnumBytes = readState.getSeqNumBytes()
        mac = readState.macContext.copy()
        macBytes = self.calculateMAC(mac, seqnumBytes, recordType, data)

        if not ct_compare_digest(macBytes, checkBytes):
            raise TLSBadRecordMAC()

    return data
458
459
def _decryptThenMAC(self, recordType, data):
    """
    Decrypt data, check padding and MAC

    @param recordType: content type of the record, passed to the MAC
        check
    @type data: bytearray
    @param data: record payload as received
    @rtype: bytearray
    @return: payload without padding and MAC; the input unchanged when
        no encryption context is set
    @raise TLSDecryptionFailed: when the data length is not a multiple
        of the cipher block size
    @raise TLSBadRecordMAC: when the padding or MAC check fails
    """
    readState = self._readState
    if not readState.encContext:
        return data

    assert self.version in ((3, 0), (3, 1), (3, 2), (3, 3))
    assert readState.encContext.isBlockCipher
    assert readState.macContext

    # decrypt the record
    blockLength = readState.encContext.block_size
    if len(data) % blockLength != 0:
        raise TLSDecryptionFailed()
    data = readState.encContext.decrypt(data)
    if self.version >= (3, 2):
        # drop the leading IV block
        data = data[blockLength:]

    # padding and MAC are checked together by the helper
    seqnumBytes = readState.getSeqNumBytes()
    if not ct_check_cbc_mac_and_pad(data,
                                    readState.macContext,
                                    seqnumBytes,
                                    recordType,
                                    self.version):
        raise TLSBadRecordMAC()

    # strip padding (its length byte is the last one) and the MAC
    endLength = data[-1] + 1 + readState.macContext.digest_size
    return data[:-endLength]
498
def _macThenDecrypt(self, recordType, buf):
    """
    Check MAC of data, then decrypt and remove padding

    @param recordType: content type of the record, passed to
        L{calculateMAC}
    @type buf: bytearray
    @param buf: record payload as received
    @rtype: bytearray
    @return: payload without MAC, IV block and padding
    @raise TLSBadRecordMAC: when the mac value is invalid
    @raise TLSDecryptionFailed: when the data to decrypt has invalid size
    """
    readState = self._readState

    if readState.macContext:
        macLength = readState.macContext.digest_size
        if len(buf) < macLength:
            raise TLSBadRecordMAC("Truncated data")

        checkBytes = buf[-macLength:]
        buf = buf[:-macLength]

        seqnumBytes = readState.getSeqNumBytes()
        mac = readState.macContext.copy()
        macBytes = self.calculateMAC(mac, seqnumBytes, recordType, buf)

        if not ct_compare_digest(macBytes, checkBytes):
            raise TLSBadRecordMAC("MAC mismatch")

    if readState.encContext:
        blockLength = readState.encContext.block_size
        if len(buf) % blockLength != 0:
            raise TLSDecryptionFailed("data length not multiple of "
                                      "block size")

        buf = readState.encContext.decrypt(buf)

        # remove the leading IV block
        if self.version >= (3, 2):
            buf = buf[blockLength:]

        if len(buf) == 0:
            raise TLSBadRecordMAC("No data left after IV removal")

        # the last byte holds the number of padding bytes before it
        paddingLength = buf[-1]
        totalPaddingLength = paddingLength + 1
        if totalPaddingLength > len(buf):
            raise TLSBadRecordMAC("Invalid padding length")

        # padding byte values are not checked for version (3, 0)
        paddingGood = True
        if self.version != (3, 0):
            for byte in buf[-totalPaddingLength:-1]:
                if byte != paddingLength:
                    paddingGood = False

        if not paddingGood:
            raise TLSBadRecordMAC("Invalid padding byte values")

        buf = buf[:-totalPaddingLength]

    return buf
557
def _decryptAndUnseal(self, recordType, buf):
    """
    Decrypt AEAD encrypted data

    @param recordType: content type of the record, included in the
        authenticated data
    @type buf: bytearray
    @param buf: record payload as received
    @rtype: bytearray
    @return: decrypted payload
    @raise TLSBadRecordMAC: when the nonce or tag is truncated or the
        cipher rejects the data
    """
    readState = self._readState
    seqnumBytes = readState.getSeqNumBytes()

    if "aes" in readState.encContext.name:
        # for these ciphers the nonce tail is carried in the record
        explicitNonceLength = 8
        if explicitNonceLength > len(buf):
            raise TLSBadRecordMAC("Truncated nonce")
        nonce = readState.fixedNonce + buf[:explicitNonceLength]
        buf = buf[explicitNonceLength:]
    else:
        nonce = readState.fixedNonce + seqnumBytes

    tagLength = readState.encContext.tagLength
    if tagLength > len(buf):
        raise TLSBadRecordMAC("Truncated tag")

    # authenticated data: sequence number, type, version and the
    # plaintext length as two bytes, most significant first
    lengthHigh, lengthLow = divmod(len(buf) - tagLength, 256)
    authData = seqnumBytes + bytearray([recordType,
                                        self.version[0],
                                        self.version[1],
                                        lengthHigh,
                                        lengthLow])

    buf = readState.encContext.open(nonce, buf, authData)
    if buf is None:
        raise TLSBadRecordMAC("Invalid tag, decryption failure")
    return buf
586
588 """
589 Read, decrypt and check integrity of a single record
590
591 @rtype: tuple
592 @return: message header and decrypted message payload
593 @raise TLSDecryptionFailed: when decryption of data failed
594 @raise TLSBadRecordMAC: when record has bad MAC or padding
595 @raise socket.error: when reading from socket was unsuccessful
596 """
597 result = None
598 for result in self._recordSocket.recv():
599 if result in (0, 1):
600 yield result
601 else: break
602 assert result is not None
603
604 (header, data) = result
605
606 if self._readState and \
607 self._readState.encContext and \
608 self._readState.encContext.isAEAD:
609 data = self._decryptAndUnseal(header.type, data)
610 elif self.encryptThenMAC:
611 data = self._macThenDecrypt(header.type, data)
612 elif self._readState and \
613 self._readState.encContext and \
614 self._readState.encContext.isBlockCipher:
615 data = self._decryptThenMAC(header.type, data)
616 else:
617 data = self._decryptStreamThenMAC(header.type, data)
618
619 yield (header, Parser(data))
620
621
622
623
624
626 """
627 Change the cipher state to the pending one for write operations.
628
629 This should be done only once after a call to L{calcPendingStates} was
630 performed and directly after sending a L{ChangeCipherSpec} message.
631 """
632 self._writeState = self._pendingWriteState
633 self._pendingWriteState = ConnectionState()
634
636 """
637 Change the cipher state to the pending one for read operations.
638
639 This should be done only once after a call to L{calcPendingStates} was
640 performed and directly after receiving a L{ChangeCipherSpec} message.
641 """
642 self._readState = self._pendingReadState
643 self._pendingReadState = ConnectionState()
644
645 @staticmethod
684
685 @staticmethod
704
@staticmethod
def _getHMACMethod(version):
    """
    Get the HMAC method

    @type version: tuple
    @param version: protocol version, one of (3, 0), (3, 1), (3, 2), (3, 3)
    @return: L{createMAC_SSL} for (3, 0), L{createHMAC} for the others
    @raise AssertionError: when the version is not one of the above
    """
    if version == (3, 0):
        return createMAC_SSL
    if version in ((3, 1), (3, 2), (3, 3)):
        return createHMAC
    # raised explicitly so the check also works when asserts are disabled
    raise AssertionError()
715
def _calcKeyBlock(self, cipherSuite, masterSecret, clientRandom,
                  serverRandom, outputLength):
    """
    Calculate the overall key to slice up

    The PRF is selected by C{self.version} and, for (3, 3), by whether
    the cipher suite is listed in L{CipherSuite.sha384PrfSuites}.

    @param cipherSuite: negotiated cipher suite
    @param masterSecret: master secret of the session
    @param clientRandom: random value sent by the client
    @param serverRandom: random value sent by the server
    @type outputLength: int
    @param outputLength: length requested from the PRF
    @return: result of the selected PRF
    @raise AssertionError: when C{self.version} is not supported
    """
    # all variants use the server random followed by the client random
    seed = serverRandom + clientRandom

    if self.version == (3, 0):
        return PRF_SSL(masterSecret, seed, outputLength)

    if self.version in ((3, 1), (3, 2)):
        prf = PRF
    elif self.version == (3, 3):
        if cipherSuite in CipherSuite.sha384PrfSuites:
            prf = PRF_1_2_SHA384
        else:
            prf = PRF_1_2
    else:
        raise AssertionError()

    return prf(masterSecret, b"key expansion", seed, outputLength)
743
def calcPendingStates(self, cipherSuite, masterSecret, clientRandom,
                      serverRandom, implementations):
    """
    Create pending states for encryption and decryption.

    A key block is derived and sliced, in order, into client MAC key,
    server MAC key, client key, server key, client IV and server IV. The
    resulting states are stored as pending write/read states depending on
    C{self.client}.

    @param cipherSuite: negotiated cipher suite
    @param masterSecret: master secret of the session
    @param clientRandom: random value sent by the client
    @param serverRandom: random value sent by the server
    @param implementations: passed through to the cipher factory function
    """
    keyLength, ivLength, createCipherFunc = \
        self._getCipherSettings(cipherSuite)

    macLength, digestmod = self._getMacSettings(cipherSuite)

    if not digestmod:
        createMACFunc = None
    else:
        createMACFunc = self._getHMACMethod(self.version)

    outputLength = (macLength*2) + (keyLength*2) + (ivLength*2)

    # calculate keying material
    keyBlock = self._calcKeyBlock(cipherSuite, masterSecret, clientRandom,
                                  serverRandom, outputLength)

    # slice up keying material
    clientPendingState = ConnectionState()
    serverPendingState = ConnectionState()
    parser = Parser(keyBlock)
    clientMACBlock = parser.getFixBytes(macLength)
    serverMACBlock = parser.getFixBytes(macLength)
    clientKeyBlock = parser.getFixBytes(keyLength)
    serverKeyBlock = parser.getFixBytes(keyLength)
    clientIVBlock = parser.getFixBytes(ivLength)
    serverIVBlock = parser.getFixBytes(ivLength)

    if digestmod:
        # cipher suite with a separate MAC; the cipher itself is optional
        clientPendingState.macContext = createMACFunc(
            compatHMAC(clientMACBlock), digestmod=digestmod)
        serverPendingState.macContext = createMACFunc(
            compatHMAC(serverMACBlock), digestmod=digestmod)
        if createCipherFunc is not None:
            clientPendingState.encContext = \
                createCipherFunc(clientKeyBlock,
                                 clientIVBlock,
                                 implementations)
            serverPendingState.encContext = \
                createCipherFunc(serverKeyBlock,
                                 serverIVBlock,
                                 implementations)
    else:
        # no separate MAC: the IV blocks become the fixed nonce parts
        clientPendingState.macContext = None
        serverPendingState.macContext = None
        clientPendingState.encContext = createCipherFunc(clientKeyBlock,
                                                         implementations)
        serverPendingState.encContext = createCipherFunc(serverKeyBlock,
                                                         implementations)
        clientPendingState.fixedNonce = clientIVBlock
        serverPendingState.fixedNonce = serverIVBlock

    # each side writes with its own keys and reads with the peer's
    if self.client:
        self._pendingWriteState = clientPendingState
        self._pendingReadState = serverPendingState
    else:
        self._pendingWriteState = serverPendingState
        self._pendingReadState = clientPendingState

    if self.version >= (3, 2) and ivLength:
        # random block that the encrypt methods prepend to outgoing data
        self.fixedIVBlock = getRandomBytes(ivLength)
812