# Source code for tlslite.keyexchange

# Authors:
#   Hubert Kario (2015)
#
# See the LICENSE file for legal information regarding use of this file.
"""Handling of cryptographic operations for key exchange"""

import ecdsa
from .mathtls import goodGroupParameters, makeK, makeU, makeX, \
        paramStrength, RFC7919_GROUPS, calc_key
from .errors import TLSInsufficientSecurity, TLSUnknownPSKIdentity, \
        TLSIllegalParameterException, TLSDecryptionFailed, TLSInternalError, \
        TLSDecodeError
from .messages import ServerKeyExchange, ClientKeyExchange, CertificateVerify
from .constants import SignatureAlgorithm, HashAlgorithm, CipherSuite, \
        ExtensionType, GroupName, ECCurveType, SignatureScheme, ECPointFormat
from .utils.ecc import getCurveByName, getPointByteSize
from .utils.rsakey import RSAKey
from .utils.cryptomath import bytesToNumber, getRandomBytes, powMod, \
        numBits, numberToByteArray, divceil, numBytes, secureHash
from .utils.lists import getFirstMatching
from .utils import tlshashlib as hashlib
from .utils.x25519 import x25519, x448, X25519_G, X448_G, X25519_ORDER_SIZE, \
        X448_ORDER_SIZE
from .utils.compat import int_types, ML_KEM_AVAILABLE
from .utils.codec import DecodeError

if ML_KEM_AVAILABLE:
    from kyber_py.ml_kem import ML_KEM_768, ML_KEM_1024


class KeyExchange(object):
    """
    Common API for calculating Premaster secret

    NOT stable, will get moved from this file
    """

    def __init__(self, cipherSuite, clientHello, serverHello, privateKey=None):
        """Initialize KeyExchange. privateKey is the signing private key"""
        self.cipherSuite = cipherSuite
        self.clientHello = clientHello
        self.serverHello = serverHello
        self.privateKey = privateKey

    def makeServerKeyExchange(self, sigHash=None):
        """
        Create a ServerKeyExchange object

        Returns a ServerKeyExchange object for the server's initial leg in the
        handshake. If the key exchange method does not send ServerKeyExchange
        (e.g. RSA), it returns None.
        """
        raise NotImplementedError()

    def makeClientKeyExchange(self):
        """
        Create a ClientKeyExchange object

        Returns a ClientKeyExchange for the second flight from client in the
        handshake.
        """
        return ClientKeyExchange(self.cipherSuite,
                                 self.serverHello.server_version)

    def processClientKeyExchange(self, clientKeyExchange):
        """
        Process ClientKeyExchange and return premaster secret

        Processes the client's ClientKeyExchange message and returns the
        premaster secret. Raises TLSLocalAlert on error.
        """
        raise NotImplementedError()

    def processServerKeyExchange(self, srvPublicKey, serverKeyExchange):
        """Process the server KEX and return premaster secret"""
        raise NotImplementedError()

    def _tls12_sign_ecdsa_SKE(self, serverKeyExchange, sigHash=None):
        """
        Sign a TLSv1.2 SKE message with the ECDSA private key.

        :param sigHash: either a SignatureScheme name or a bare hash name
            (the latter is paired with SignatureAlgorithm.ecdsa)
        :raises TLSInternalError: if signing produced an empty or
            self-inconsistent signature
        """
        try:
            serverKeyExchange.hashAlg, serverKeyExchange.signAlg = \
                    getattr(SignatureScheme, sigHash)
            hashName = SignatureScheme.getHash(sigHash)
        except AttributeError:
            # sigHash is not a SignatureScheme name, treat it as a hash name
            serverKeyExchange.hashAlg = getattr(HashAlgorithm, sigHash)
            serverKeyExchange.signAlg = SignatureAlgorithm.ecdsa
            hashName = sigHash

        hash_bytes = serverKeyExchange.hash(self.clientHello.random,
                                            self.serverHello.random)

        # truncate the digest to the size of the curve's base
        hash_bytes = hash_bytes[:self.privateKey.private_key.curve.baselen]

        serverKeyExchange.signature = \
            self.privateKey.sign(hash_bytes, hashAlg=hashName)

        if not serverKeyExchange.signature:
            raise TLSInternalError("Empty signature")

        # self-check the freshly created signature before sending it
        if not self.privateKey.verify(serverKeyExchange.signature,
                                      hash_bytes,
                                      ecdsa.util.sigdecode_der):
            raise TLSInternalError("signature validation failure")

    def _tls12_sign_dsa_SKE(self, serverKeyExchange, sigHash=None):
        """
        Sign a TLSv1.2 SKE message with the DSA private key.

        :param sigHash: either a SignatureScheme name or a bare hash name
            (the latter is paired with SignatureAlgorithm.dsa)
        :raises TLSInternalError: if signing produced an empty or
            self-inconsistent signature
        """
        try:
            serverKeyExchange.hashAlg, serverKeyExchange.signAlg = \
                    getattr(SignatureScheme, sigHash)
        except AttributeError:
            serverKeyExchange.signAlg = SignatureAlgorithm.dsa
            serverKeyExchange.hashAlg = getattr(HashAlgorithm, sigHash)

        hashBytes = serverKeyExchange.hash(self.clientHello.random,
                                           self.serverHello.random)

        serverKeyExchange.signature = \
            self.privateKey.sign(hashBytes)

        if not serverKeyExchange.signature:
            raise TLSInternalError("Empty signature")

        if not self.privateKey.verify(serverKeyExchange.signature,
                                      hashBytes):
            raise TLSInternalError("Server Key Exchange signature invalid")

    def _tls12_sign_eddsa_ske(self, server_key_exchange, sig_hash):
        """
        Sign a TLSv1.2 SKE message with an EdDSA (Ed25519/Ed448) key.

        :param sig_hash: SignatureScheme name of the EdDSA scheme
        :raises TLSInternalError: if signing produced an empty or
            self-inconsistent signature
        """
        server_key_exchange.hashAlg, server_key_exchange.signAlg = \
                getattr(SignatureScheme, sig_hash)

        # EdDSA hashes internally, no padding, external hash or salt
        pad_type = None
        hash_name = None
        salt_len = None

        hash_bytes = server_key_exchange.hash(self.clientHello.random,
                                              self.serverHello.random)

        server_key_exchange.signature = \
            self.privateKey.hashAndSign(hash_bytes,
                                        pad_type,
                                        hash_name,
                                        salt_len)

        if not server_key_exchange.signature:
            raise TLSInternalError("Empty signature")

        if not self.privateKey.hashAndVerify(
                server_key_exchange.signature,
                hash_bytes,
                pad_type,
                hash_name,
                salt_len):
            raise TLSInternalError("Server Key Exchange signature invalid")

    def _tls12_signSKE(self, serverKeyExchange, sigHash=None):
        """
        Sign a TLSv1.2 SKE message with an RSA key.

        :param sigHash: either a SignatureScheme name (PKCS#1 v1.5 or PSS)
            or a bare hash name (PKCS#1 v1.5 with that hash)
        :raises TLSInternalError: if the selected scheme is not an RSA one,
            or signing produced an empty or self-inconsistent signature
        """
        try:
            serverKeyExchange.hashAlg, serverKeyExchange.signAlg = \
                    getattr(SignatureScheme, sigHash)
            keyType = SignatureScheme.getKeyType(sigHash)
            padType = SignatureScheme.getPadding(sigHash)
            hashName = SignatureScheme.getHash(sigHash)
            saltLen = getattr(hashlib, hashName)().digest_size
        except AttributeError:
            # legacy hash-name selection, always PKCS#1 v1.5
            serverKeyExchange.signAlg = SignatureAlgorithm.rsa
            serverKeyExchange.hashAlg = getattr(HashAlgorithm, sigHash)
            keyType = 'rsa'
            padType = 'pkcs1'
            hashName = sigHash
            saltLen = 0

        # explicit check instead of `assert`, which is stripped under -O
        if keyType != 'rsa':
            raise TLSInternalError("non-RSA sigs are not supported")

        hashBytes = serverKeyExchange.hash(self.clientHello.random,
                                           self.serverHello.random)

        serverKeyExchange.signature = \
            self.privateKey.sign(hashBytes,
                                 padding=padType,
                                 hashAlg=hashName,
                                 saltLen=saltLen)

        if not serverKeyExchange.signature:
            raise TLSInternalError("Empty signature")

        if not self.privateKey.verify(serverKeyExchange.signature,
                                      hashBytes,
                                      padding=padType,
                                      hashAlg=hashName,
                                      saltLen=saltLen):
            raise TLSInternalError("Server Key Exchange signature invalid")

    def signServerKeyExchange(self, serverKeyExchange, sigHash=None):
        """
        Sign a server key exchange using default or specified algorithm

        :type sigHash: str
        :param sigHash: name of the signature hash to be used for signing
        """
        if self.serverHello.server_version < (3, 3):
            # before TLS 1.2 the algorithm is implied by the key type
            if self.privateKey.key_type == "ecdsa":
                serverKeyExchange.signAlg = SignatureAlgorithm.ecdsa
            elif self.privateKey.key_type == "dsa":
                serverKeyExchange.signAlg = SignatureAlgorithm.dsa
            hashBytes = serverKeyExchange.hash(self.clientHello.random,
                                               self.serverHello.random)

            serverKeyExchange.signature = self.privateKey.sign(hashBytes)

            if not serverKeyExchange.signature:
                raise TLSInternalError("Empty signature")

            if not self.privateKey.verify(serverKeyExchange.signature,
                                          hashBytes):
                raise TLSInternalError("Server Key Exchange signature invalid")
        else:
            if self.privateKey.key_type == "ecdsa":
                self._tls12_sign_ecdsa_SKE(serverKeyExchange, sigHash)
            elif self.privateKey.key_type == "dsa":
                self._tls12_sign_dsa_SKE(serverKeyExchange, sigHash)
            elif self.privateKey.key_type in ("Ed25519", "Ed448"):
                self._tls12_sign_eddsa_ske(serverKeyExchange, sigHash)
            else:
                self._tls12_signSKE(serverKeyExchange, sigHash)

    @staticmethod
    def _tls12_verify_ecdsa_SKE(serverKeyExchange, publicKey, clientRandom,
                                serverRandom, validSigAlgs):
        """
        Verify a TLSv1.2 SKE message signed with ECDSA.

        :raises TLSIllegalParameterException: unknown hash or empty signature
        :raises TLSDecryptionFailed: signature does not verify
        """
        hashName = HashAlgorithm.toRepr(serverKeyExchange.hashAlg)
        if not hashName:
            raise TLSIllegalParameterException("Unknown hash algorithm")

        # same empty-signature handling as the RSA and EdDSA paths
        if not serverKeyExchange.signature:
            raise TLSIllegalParameterException("Empty signature")

        hashBytes = serverKeyExchange.hash(clientRandom, serverRandom)

        # truncate the digest to the size of the curve's base
        hashBytes = hashBytes[:publicKey.public_key.curve.baselen]

        if not publicKey.verify(serverKeyExchange.signature, hashBytes,
                                padding=None,
                                hashAlg=hashName,
                                saltLen=None):
            raise TLSDecryptionFailed("Server Key Exchange signature "
                                      "invalid")

    @staticmethod
    def _tls12_verify_eddsa_ske(server_key_exchange, public_key, client_random,
                                server_random, valid_sig_algs):
        """
        Verify ServerKeyExchange messages with EdDSA signatures.

        :raises TLSIllegalParameterException: empty signature
        :raises TLSDecryptionFailed: signature does not verify
        """
        del valid_sig_algs
        sig_bytes = server_key_exchange.signature
        if not sig_bytes:
            raise TLSIllegalParameterException("Empty signature")

        hash_bytes = server_key_exchange.hash(client_random, server_random)

        if not public_key.hashAndVerify(sig_bytes, hash_bytes):
            raise TLSDecryptionFailed("Server Key Exchange signature invalid")

    @staticmethod
    def _tls12_verify_dsa_SKE(serverKeyExchange, publicKey, clientRandom,
                              serverRandom, validSigAlgs):
        """
        Verify a TLSv1.2 SKE message signed with DSA.

        :raises TLSIllegalParameterException: empty signature
        :raises TLSDecryptionFailed: signature does not verify
        """
        # same empty-signature handling as the RSA and EdDSA paths
        if not serverKeyExchange.signature:
            raise TLSIllegalParameterException("Empty signature")

        hashBytes = serverKeyExchange.hash(clientRandom, serverRandom)

        if not publicKey.verify(serverKeyExchange.signature, hashBytes):
            raise TLSDecryptionFailed("Server Key Exchange signature "
                                      "invalid")

    @staticmethod
    def _tls12_verify_SKE(serverKeyExchange, publicKey, clientRandom,
                          serverRandom, validSigAlgs):
        """
        Verify TLSv1.2 version of SKE.

        Dispatches to the EdDSA, ECDSA or DSA verifiers, otherwise verifies
        an RSA (PKCS#1 v1.5 or PSS) signature.

        :raises TLSIllegalParameterException: algorithm not in validSigAlgs,
            unknown hash ID or empty signature
        :raises TLSInternalError: non-RSA algorithm reached the RSA path
        :raises TLSDecryptionFailed: signature does not verify
        """
        if (serverKeyExchange.hashAlg, serverKeyExchange.signAlg) not in \
                validSigAlgs:
            raise TLSIllegalParameterException("Server selected "
                                               "invalid signature "
                                               "algorithm")
        if (serverKeyExchange.hashAlg, serverKeyExchange.signAlg) in (
                SignatureScheme.ed25519, SignatureScheme.ed448):
            return KeyExchange._tls12_verify_eddsa_ske(serverKeyExchange,
                                                       publicKey,
                                                       clientRandom,
                                                       serverRandom,
                                                       validSigAlgs)
        if serverKeyExchange.signAlg == SignatureAlgorithm.ecdsa:
            return KeyExchange._tls12_verify_ecdsa_SKE(serverKeyExchange,
                                                       publicKey,
                                                       clientRandom,
                                                       serverRandom,
                                                       validSigAlgs)
        elif serverKeyExchange.signAlg == SignatureAlgorithm.dsa:
            return KeyExchange._tls12_verify_dsa_SKE(serverKeyExchange,
                                                     publicKey,
                                                     clientRandom,
                                                     serverRandom,
                                                     validSigAlgs)

        schemeID = (serverKeyExchange.hashAlg, serverKeyExchange.signAlg)
        scheme = SignatureScheme.toRepr(schemeID)
        if scheme is not None:
            keyType = SignatureScheme.getKeyType(scheme)
            padType = SignatureScheme.getPadding(scheme)
            hashName = SignatureScheme.getHash(scheme)
            saltLen = getattr(hashlib, hashName)().digest_size
        else:
            # legacy (hash, rsa) pair that is not a named scheme
            if serverKeyExchange.signAlg != SignatureAlgorithm.rsa:
                raise TLSInternalError("non-RSA sigs are not supported")
            keyType = 'rsa'
            padType = 'pkcs1'
            saltLen = 0
            hashName = HashAlgorithm.toRepr(serverKeyExchange.hashAlg)
            if hashName is None:
                msg = "Unknown hash ID: {0}"\
                        .format(serverKeyExchange.hashAlg)
                raise TLSIllegalParameterException(msg)

        # explicit check instead of `assert`, which is stripped under -O
        if keyType != 'rsa':
            raise TLSInternalError("non-RSA sigs are not supported")

        hashBytes = serverKeyExchange.hash(clientRandom, serverRandom)

        sigBytes = serverKeyExchange.signature
        if not sigBytes:
            raise TLSIllegalParameterException("Empty signature")

        if not publicKey.verify(sigBytes, hashBytes,
                                padding=padType,
                                hashAlg=hashName,
                                saltLen=saltLen):
            raise TLSDecryptionFailed("Server Key Exchange signature "
                                      "invalid")

    @staticmethod
    def verifyServerKeyExchange(serverKeyExchange, publicKey, clientRandom,
                                serverRandom, validSigAlgs):
        """Verify signature on the Server Key Exchange message

        the only acceptable signature algorithms are specified by validSigAlgs
        """
        if serverKeyExchange.version < (3, 3):
            hashBytes = serverKeyExchange.hash(clientRandom, serverRandom)
            sigBytes = serverKeyExchange.signature

            if not sigBytes:
                raise TLSIllegalParameterException("Empty signature")

            if not publicKey.verify(sigBytes, hashBytes):
                raise TLSDecryptionFailed("Server Key Exchange signature "
                                          "invalid")
        else:
            KeyExchange._tls12_verify_SKE(serverKeyExchange, publicKey,
                                          clientRandom, serverRandom,
                                          validSigAlgs)

    @staticmethod
    def calcVerifyBytes(version, handshakeHashes, signatureAlg,
                        premasterSecret, clientRandom, serverRandom,
                        prf_name=None, peer_tag=b'client', key_type="rsa"):
        """
        Calculate signed bytes for Certificate Verify

        :param version: protocol version tuple, (3, 0) through (3, 4)
        :param signatureAlg: (hash, signature) algorithm pair; used for
            TLS 1.2 and TLS 1.3 only
        :param prf_name: hash used for the transcript in TLS 1.3
        :param peer_tag: b'client' or b'server' label for TLS 1.3
        :param key_type: key type of the signer; selects SHA-1 only digest
            for ECDSA in TLS 1.0 and 1.1
        :raises ValueError: for an unsupported protocol version
        """
        if version == (3, 0):
            masterSecret = calc_key(version, premasterSecret,
                                    0, b"master secret",
                                    client_random=clientRandom,
                                    server_random=serverRandom,
                                    output_length=48)
            verifyBytes = handshakeHashes.digestSSL(masterSecret, b"")
        elif version in ((3, 1), (3, 2)):
            if key_type != "ecdsa":
                verifyBytes = handshakeHashes.digest()
            else:
                verifyBytes = handshakeHashes.digest("sha1")
        elif version == (3, 3):
            if signatureAlg in (SignatureScheme.ed25519,
                                SignatureScheme.ed448):
                hashName = "intrinsic"
                padding = None
            elif signatureAlg[1] == SignatureAlgorithm.dsa:
                hashName = HashAlgorithm.toRepr(signatureAlg[0])
                padding = None
            elif signatureAlg[1] != SignatureAlgorithm.ecdsa:
                scheme = SignatureScheme.toRepr(signatureAlg)
                if scheme is None:
                    hashName = HashAlgorithm.toRepr(signatureAlg[0])
                    padding = 'pkcs1'
                else:
                    hashName = SignatureScheme.getHash(scheme)
                    padding = SignatureScheme.getPadding(scheme)
            else:
                padding = None
                hashName = HashAlgorithm.toRepr(signatureAlg[0])
            verifyBytes = handshakeHashes.digest(hashName)
            if padding == 'pkcs1':
                verifyBytes = RSAKey.addPKCS1Prefix(verifyBytes, hashName)
        elif version == (3, 4):
            scheme = SignatureScheme.toRepr(signatureAlg)
            if scheme:
                hash_name = SignatureScheme.getHash(scheme)
            else:
                # handles negative test cases when we try to pass in
                # schemes that are not supported in TLS1.3
                hash_name = HashAlgorithm.toRepr(signatureAlg[0])
            verifyBytes = bytearray(b'\x20' * 64 +
                                    b'TLS 1.3, ' +
                                    peer_tag +
                                    b' CertificateVerify' +
                                    b'\x00') + \
                handshakeHashes.digest(prf_name)
            if hash_name != "intrinsic":
                verifyBytes = secureHash(verifyBytes, hash_name)
        else:
            raise ValueError("Unsupported TLS version {0}".format(version))
        return verifyBytes

    @staticmethod
    def makeCertificateVerify(version, handshakeHashes, validSigAlgs,
                              privateKey, certificateRequest, premasterSecret,
                              clientRandom, serverRandom):
        """Create a Certificate Verify message

        :param version: protocol version in use
        :param handshakeHashes: the running hash of all handshake messages
        :param validSigAlgs: acceptable signature algorithms for client side,
            applicable only to TLSv1.2 (or later)
        :param certificateRequest: the server provided Certificate Request
            message
        :param premasterSecret: the premaster secret, needed only for SSLv3
        :param clientRandom: client provided random value, needed only for
            SSLv3
        :param serverRandom: server provided random value, needed only for
            SSLv3
        :raises TLSInternalError: if the created signature does not verify
        """
        signatureAlgorithm = None
        if privateKey.key_type == "ecdsa" and version < (3, 3):
            signatureAlgorithm = (HashAlgorithm.sha1, SignatureAlgorithm.ecdsa)
        # in TLS 1.2 we must decide which algorithm to use for signing
        if version == (3, 3):
            serverSigAlgs = certificateRequest.supported_signature_algs
            signatureAlgorithm = getFirstMatching(validSigAlgs, serverSigAlgs)
            # if none acceptable, do a last resort:
            if signatureAlgorithm is None:
                signatureAlgorithm = validSigAlgs[0]
        verifyBytes = KeyExchange.calcVerifyBytes(version, handshakeHashes,
                                                  signatureAlgorithm,
                                                  premasterSecret,
                                                  clientRandom,
                                                  serverRandom,
                                                  key_type=privateKey.key_type)

        if signatureAlgorithm and signatureAlgorithm in (
                SignatureScheme.ed25519, SignatureScheme.ed448):
            padding = None
            hashName = "intrinsic"
            saltLen = None
            sig_func = privateKey.hashAndSign
            ver_func = privateKey.hashAndVerify
        elif signatureAlgorithm and \
                signatureAlgorithm[1] == SignatureAlgorithm.ecdsa:
            padding = None
            hashName = HashAlgorithm.toRepr(signatureAlgorithm[0])
            saltLen = None
            # truncate the digest to the size of the curve's base
            verifyBytes = verifyBytes[:privateKey.private_key.curve.baselen]
            sig_func = privateKey.sign
            ver_func = privateKey.verify
        elif signatureAlgorithm and \
                signatureAlgorithm[1] == SignatureAlgorithm.dsa:
            padding = None
            hashName = HashAlgorithm.toRepr(signatureAlgorithm[0])
            saltLen = None
            sig_func = privateKey.sign
            ver_func = privateKey.verify
        else:
            scheme = SignatureScheme.toRepr(signatureAlgorithm)
            # for pkcs1 signatures hash is used to add PKCS#1 prefix, but
            # that was already done by calcVerifyBytes
            hashName = None
            saltLen = 0
            if scheme is None:
                padding = 'pkcs1'
            else:
                padding = SignatureScheme.getPadding(scheme)
                if padding == 'pss':
                    hashName = SignatureScheme.getHash(scheme)
                    saltLen = getattr(hashlib, hashName)().digest_size
            sig_func = privateKey.sign
            ver_func = privateKey.verify

        signedBytes = sig_func(verifyBytes, padding, hashName, saltLen)
        # self-check the freshly created signature before sending it
        if not ver_func(signedBytes, verifyBytes, padding, hashName, saltLen):
            raise TLSInternalError("Certificate Verify signature invalid")

        certificateVerify = CertificateVerify(version)
        certificateVerify.create(signedBytes, signatureAlgorithm)

        return certificateVerify
class AuthenticatedKeyExchange(KeyExchange):
    """
    Mixin for key exchanges whose Server Key Exchange must be signed.

    The parameters are produced by the next class in the MRO and then signed
    with the server's private key.
    """

    def makeServerKeyExchange(self, sigHash=None):
        """
        Build the Server Key Exchange and sign it.

        :param sigHash: signature scheme or hash name passed on to
            signServerKeyExchange
        """
        parent = super(AuthenticatedKeyExchange, self)
        server_key_exchange = parent.makeServerKeyExchange()
        self.signServerKeyExchange(server_key_exchange, sigHash)
        return server_key_exchange
class RSAKeyExchange(KeyExchange):
    """
    Handling of RSA key exchange

    NOT stable API, do NOT use
    """

    def __init__(self, cipherSuite, clientHello, serverHello, privateKey):
        super(RSAKeyExchange, self).__init__(cipherSuite, clientHello,
                                             serverHello, privateKey)
        # set on the client side by processServerKeyExchange()
        self.encPremasterSecret = None

    def makeServerKeyExchange(self, sigHash=None):
        """RSA key exchange sends no Server Key Exchange, so return None."""
        return None

    def processClientKeyExchange(self, clientKeyExchange):
        """
        Decrypt the client's encrypted premaster secret and return it.

        Any malformed result is silently replaced with random bytes.
        """
        decrypted = self.privateKey.decrypt(
            clientKeyExchange.encryptedPreMasterSecret)

        # On decryption failure randomize premaster secret to avoid
        # Bleichenbacher's "million message" attack; the substitute is drawn
        # every time, whether it ends up being used or not
        substitute = getRandomBytes(48)

        if not decrypted or len(decrypted) != 48:
            return substitute

        embedded_version = (decrypted[0], decrypted[1])
        # the server version is accepted too, to tolerate buggy IE clients
        acceptable = (self.clientHello.client_version,
                      self.serverHello.server_version)
        if embedded_version not in acceptable:
            return substitute
        return decrypted

    def processServerKeyExchange(self, srvPublicKey, serverKeyExchange):
        """Create a premaster secret and encrypt it to the server's key."""
        del serverKeyExchange  # RSA key exchange has no SKE message
        secret = getRandomBytes(48)
        offered = self.clientHello.client_version
        secret[0] = offered[0]
        secret[1] = offered[1]
        self.encPremasterSecret = srvPublicKey.encrypt(secret)
        return secret

    def makeClientKeyExchange(self):
        """Build the Client Key Exchange carrying the encrypted secret."""
        message = super(RSAKeyExchange, self).makeClientKeyExchange()
        message.createRSA(self.encPremasterSecret)
        return message
class ADHKeyExchange(KeyExchange):
    """
    Handling of anonymous Diffie-Hellman Key exchange

    FFDHE without signing serverKeyExchange useful for anonymous DH
    """

    def __init__(self, cipherSuite, clientHello, serverHello,
                 dhParams=None, dhGroups=None):
        """
        Set up the anonymous FFDHE key exchange.

        :param dhParams: (generator, prime) tuple used when no RFC 7919
            group is negotiated; defaults to goodGroupParameters[2]
        :param dhGroups: RFC 7919 group IDs the server is willing to use
        """
        super(ADHKeyExchange, self).__init__(cipherSuite, clientHello,
                                             serverHello)
        self.dh_Xs = None
        self.dh_Yc = None
        if dhParams:
            self.dh_g, self.dh_p = dhParams
        else:
            # 2048-bit MODP Group (RFC 5054, group 3)
            self.dh_g, self.dh_p = goodGroupParameters[2]
        self.dhGroups = dhGroups

    def makeServerKeyExchange(self, sigHash=None):
        """
        Prepare server side of anonymous key exchange with selected parameters

        :param sigHash: ignored; accepted so the signature matches
            KeyExchange.makeServerKeyExchange (anonymous SKE is never signed)
        :raises TLSInternalError: client advertised only RFC 7919 groups and
            none of them overlaps with the configured dhGroups
        """
        del sigHash
        # Check for RFC 7919 support
        ext = self.clientHello.getExtension(ExtensionType.supported_groups)
        if ext and self.dhGroups:
            commonGroup = getFirstMatching(ext.groups, self.dhGroups)
            if commonGroup:
                self.dh_g, self.dh_p = RFC7919_GROUPS[commonGroup - 256]
            elif getFirstMatching(ext.groups, range(256, 512)):
                raise TLSInternalError("DHE key exchange attempted despite no "
                                       "overlap between supported groups")

        # for TLS < 1.3 we need special algorithm to select params (see above)
        # so do not pass in the group, if we selected one
        kex = FFDHKeyExchange(None, self.serverHello.server_version,
                              self.dh_g, self.dh_p)
        self.dh_Xs = kex.get_random_private_key()
        dh_Ys = kex.calc_public_value(self.dh_Xs)

        version = self.serverHello.server_version
        serverKeyExchange = ServerKeyExchange(self.cipherSuite, version)
        serverKeyExchange.createDH(self.dh_p, self.dh_g, dh_Ys)
        # No sign for anonymous ServerKeyExchange.
        return serverKeyExchange

    def processClientKeyExchange(self, clientKeyExchange):
        """Use client provided parameters to establish premaster secret"""
        dh_Yc = clientKeyExchange.dh_Yc
        kex = FFDHKeyExchange(None, self.serverHello.server_version,
                              self.dh_g, self.dh_p)
        return kex.calc_shared_key(self.dh_Xs, dh_Yc)

    def processServerKeyExchange(self, srvPublicKey, serverKeyExchange):
        """
        Process the server key exchange, return premaster secret.

        :raises TLSInsufficientSecurity: server's prime is below 2**1023
        """
        del srvPublicKey
        dh_p = serverKeyExchange.dh_p
        # TODO make the minimum changeable
        if dh_p < 2**1023:
            raise TLSInsufficientSecurity("DH prime too small")

        dh_g = serverKeyExchange.dh_g
        dh_Ys = serverKeyExchange.dh_Ys

        kex = FFDHKeyExchange(None, self.serverHello.server_version,
                              dh_g, dh_p)

        dh_Xc = kex.get_random_private_key()
        self.dh_Yc = kex.calc_public_value(dh_Xc)
        return kex.calc_shared_key(dh_Xc, dh_Ys)

    def makeClientKeyExchange(self):
        """Create client key share for the key exchange"""
        cke = super(ADHKeyExchange, self).makeClientKeyExchange()
        cke.createDH(self.dh_Yc)
        return cke
# the DHE_RSA part comes from IETF ciphersuite names, we want to keep it
#pylint: disable = invalid-name
class DHE_RSAKeyExchange(AuthenticatedKeyExchange, ADHKeyExchange):
    #pylint: enable = invalid-name
    """
    Authenticated ephemeral Diffie-Hellman key exchange.

    Same parameters as the anonymous variant, but the Server Key Exchange is
    signed with privateKey.
    """

    def __init__(self, cipherSuite, clientHello, serverHello, privateKey,
                 dhParams=None, dhGroups=None):
        """
        Create helper object for Diffie-Hellman key exchange.

        :param dhParams: Diffie-Hellman parameters that will be used by
            server. First element of the tuple is the generator, the second
            is the prime. If not specified it will use a secure set
            (currently a 2048-bit safe prime).
        :type dhParams: 2-element tuple of int
        """
        parent = super(DHE_RSAKeyExchange, self)
        parent.__init__(cipherSuite, clientHello, serverHello,
                        dhParams=dhParams, dhGroups=dhGroups)
        self.privateKey = privateKey
class AECDHKeyExchange(KeyExchange):
    """
    Handling of anonymous Elliptic curve Diffie-Hellman Key exchange

    ECDHE without signing serverKeyExchange useful for anonymous ECDH
    """

    def __init__(self, cipherSuite, clientHello, serverHello, acceptedCurves,
                 defaultCurve=GroupName.secp256r1):
        """
        :param acceptedCurves: group IDs this side is willing to use
        :param defaultCurve: group used when the client sent no
            supported_groups extension
        """
        super(AECDHKeyExchange, self).__init__(cipherSuite, clientHello,
                                               serverHello)
        self.ecdhXs = None
        self.acceptedCurves = acceptedCurves
        self.group_id = None
        self.ecdhYc = None
        self.defaultCurve = defaultCurve

    @staticmethod
    def _point_format_name(point_format):
        """Map an ECPointFormat ID to the name used by ECDHKeyExchange."""
        if point_format == ECPointFormat.uncompressed:
            return 'uncompressed'
        return 'compressed'

    def _common_point_formats(self):
        """
        Return the EC point format IDs acceptable to both hellos.

        The list keeps the order of the client's ec_point_formats extension.
        When either hello lacks that extension only the uncompressed format
        is assumed.

        :raises TLSIllegalParameterException: both hellos carry the extension
            but the lists share no format
        """
        ext_c = self.clientHello.getExtension(ExtensionType.ec_point_formats)
        ext_s = self.serverHello.getExtension(ExtensionType.ec_point_formats)
        if not (ext_c and ext_s):
            return [ECPointFormat.uncompressed]
        common = [fmt for fmt in ext_c.formats if fmt in ext_s.formats]
        if not common:
            raise TLSIllegalParameterException("No common EC point format")
        return common

    def makeServerKeyExchange(self, sigHash=None):
        """
        Create AECDHE version of Server Key Exchange

        :param sigHash: ignored, anonymous Server Key Exchange is not signed
        :raises TLSInternalError: client sent an empty supported_groups
        :raises TLSInsufficientSecurity: no group in common with the client
        :raises TLSIllegalParameterException: no common EC point format
        """
        del sigHash
        #Get client supported groups
        client_curves = self.clientHello.getExtension(
            ExtensionType.supported_groups)
        if client_curves is None:
            # in case there is no extension, we can pick any curve,
            # use the configured one
            client_curves = [self.defaultCurve]
        elif not client_curves.groups:
            # extension should have been validated before
            raise TLSInternalError("Can't do ECDHE with no client curves")
        else:
            client_curves = client_curves.groups

        #Pick first client preferred group we support
        self.group_id = getFirstMatching(client_curves, self.acceptedCurves)
        if self.group_id is None:
            raise TLSInsufficientSecurity("No mutual groups")

        kex = ECDHKeyExchange(self.group_id, self.serverHello.server_version)
        self.ecdhXs = kex.get_random_private_key()

        # encode our share in the client's most preferred common format
        point_format = self._point_format_name(
            self._common_point_formats()[0])
        ecdhYs = kex.calc_public_value(self.ecdhXs, point_format)

        version = self.serverHello.server_version
        serverKeyExchange = ServerKeyExchange(self.cipherSuite, version)
        serverKeyExchange.createECDH(ECCurveType.named_curve,
                                     named_curve=self.group_id,
                                     point=ecdhYs)
        # No sign for anonymous ServerKeyExchange
        return serverKeyExchange

    def processClientKeyExchange(self, clientKeyExchange):
        """
        Calculate premaster secret from previously generated SKE and CKE

        :raises TLSDecodeError: client key share is empty
        :raises TLSIllegalParameterException: no common EC point format
        """
        ecdhYc = clientKeyExchange.ecdh_Yc

        if not ecdhYc:
            raise TLSDecodeError("No key share")

        kex = ECDHKeyExchange(self.group_id, self.serverHello.server_version)
        valid_formats = set(self._point_format_name(fmt)
                            for fmt in self._common_point_formats())
        return kex.calc_shared_key(self.ecdhXs, ecdhYc, valid_formats)

    def processServerKeyExchange(self, srvPublicKey, serverKeyExchange):
        """
        Process the server key exchange, return premaster secret

        :raises TLSIllegalParameterException: server picked a curve we did
            not advertise, or no common EC point format
        :raises TLSDecodeError: server key share is empty
        """
        del srvPublicKey

        if serverKeyExchange.curve_type != ECCurveType.named_curve \
            or serverKeyExchange.named_curve not in self.acceptedCurves:
            raise TLSIllegalParameterException("Server picked curve we "
                                               "didn't advertise")

        ecdh_Ys = serverKeyExchange.ecdh_Ys
        if not ecdh_Ys:
            raise TLSDecodeError("Empty server key share")

        kex = ECDHKeyExchange(serverKeyExchange.named_curve,
                              self.serverHello.server_version)
        ecdhXc = kex.get_random_private_key()

        common_formats = [ECPointFormat.uncompressed]
        if self.clientHello:
            common_formats = self._common_point_formats()

        point_format = self._point_format_name(common_formats[0])
        self.ecdhYc = kex.calc_public_value(ecdhXc, point_format)
        valid_formats = set(self._point_format_name(fmt)
                            for fmt in common_formats)
        return kex.calc_shared_key(ecdhXc, ecdh_Ys, valid_formats)

    def makeClientKeyExchange(self):
        """Make client key exchange for ECDHE"""
        cke = super(AECDHKeyExchange, self).makeClientKeyExchange()
        cke.createECDH(self.ecdhYc)
        return cke
# The ECDHE_RSA part comes from the IETF names of ciphersuites, so we want to
# keep it
#pylint: disable = invalid-name
class ECDHE_RSAKeyExchange(AuthenticatedKeyExchange, AECDHKeyExchange):
    #pylint: enable = invalid-name
    """
    Authenticated ephemeral ECDH key exchange.

    Reuses the anonymous ECDHE logic and signs the Server Key Exchange with
    privateKey.
    """

    def __init__(self, cipherSuite, clientHello, serverHello, privateKey,
                 acceptedCurves, defaultCurve=GroupName.secp256r1):
        parent = super(ECDHE_RSAKeyExchange, self)
        parent.__init__(cipherSuite, clientHello, serverHello,
                        acceptedCurves=acceptedCurves,
                        defaultCurve=defaultCurve)
        self.privateKey = privateKey
class SRPKeyExchange(KeyExchange):
    """Key exchange based on the Secure Remote Password protocol."""

    def __init__(self, cipherSuite, clientHello, serverHello, privateKey,
                 verifierDB, srpUsername=None, password=None, settings=None):
        """
        Bind the key exchange to its verifier database and credentials.

        :param verifierDB: server side mapping of usernames to
            (N, g, salt, verifier) entries
        :param srpUsername: client user name, must be a bytearray if given
        :param password: client password, must be a bytearray if given
        :param settings: client settings with minKeySize and maxKeySize
        :raises TypeError: srpUsername or password is not a bytearray
        """
        super(SRPKeyExchange, self).__init__(cipherSuite, clientHello,
                                             serverHello, privateKey)
        # SRP state filled in during the handshake
        self.N = self.v = self.b = self.B = self.A = None
        self.verifierDB = verifierDB
        self.srpUsername = srpUsername
        self.password = password
        self.settings = settings
        for label, value in (("srpUsername", srpUsername),
                             ("password", password)):
            if value is not None and not isinstance(value, bytearray):
                raise TypeError("{0} must be a bytearray object"
                                .format(label))

    def makeServerKeyExchange(self, sigHash=None):
        """
        Build the SRP Server Key Exchange, signed for certificate suites.

        :raises TLSUnknownPSKIdentity: user name missing from verifierDB
        """
        user = bytes(self.clientHello.srp_username)

        # look up the group, salt and verifier for the user
        try:
            record = self.verifierDB[user]
        except KeyError:
            raise TLSUnknownPSKIdentity("Unknown identity")
        self.N, generator, salt, self.v = record

        # server ephemeral: B = g^b + k*v (mod N)
        self.b = bytesToNumber(getRandomBytes(32))
        multiplier = makeK(self.N, generator)
        self.B = (powMod(generator, self.b, self.N)
                  + multiplier * self.v) % self.N

        message = ServerKeyExchange(self.cipherSuite,
                                    self.serverHello.server_version)
        message.createSRP(self.N, generator, salt, self.B)
        if self.cipherSuite in CipherSuite.srpCertSuites:
            self.signServerKeyExchange(message, sigHash)
        return message

    def processClientKeyExchange(self, clientKeyExchange):
        """
        Derive the premaster secret from the client's A value.

        :raises TLSIllegalParameterException: A is a multiple of N
        """
        modulus = self.N
        client_A = clientKeyExchange.srp_A
        if client_A % modulus == 0:
            raise TLSIllegalParameterException("Invalid SRP A value")

        scrambler = makeU(modulus, client_A, self.B)

        # S = (A * v^u)^b (mod N)
        base = (client_A * powMod(self.v, scrambler, modulus)) % modulus
        return numberToByteArray(powMod(base, self.b, modulus))

    def processServerKeyExchange(self, srvPublicKey, serverKeyExchange):
        """
        Derive the premaster secret from the server's SRP parameters.

        :raises TLSInsufficientSecurity: unknown group or N outside the
            configured size limits
        :raises TLSIllegalParameterException: B is a multiple of N
        """
        del srvPublicKey  # irrelevant for SRP
        N, g = serverKeyExchange.srp_N, serverKeyExchange.srp_g
        s, B = serverKeyExchange.srp_s, serverKeyExchange.srp_B

        if (g, N) not in goodGroupParameters:
            raise TLSInsufficientSecurity("Unknown group parameters")
        n_bits = numBits(N)
        if n_bits < self.settings.minKeySize:
            raise TLSInsufficientSecurity("N value is too small: {0}"
                                          .format(n_bits))
        if n_bits > self.settings.maxKeySize:
            raise TLSInsufficientSecurity("N value is too large: {0}"
                                          .format(n_bits))
        if B % N == 0:
            raise TLSIllegalParameterException("Suspicious B value")

        # client ephemeral value
        a = bytesToNumber(getRandomBytes(32))
        self.A = powMod(g, a, N)

        # client's static values derived from the password
        x = makeX(s, self.srpUsername, self.password)
        v = powMod(g, x, N)

        u = makeU(N, self.A, B)
        k = makeK(N, g)

        # S = (B - k*v)^(a + u*x) (mod N)
        return numberToByteArray(powMod((B - k * v) % N, a + u * x, N))

    def makeClientKeyExchange(self):
        """Build the Client Key Exchange carrying the A value."""
        message = super(SRPKeyExchange, self).makeClientKeyExchange()
        message.createSRP(self.A)
        return message
class RawDHKeyExchange(object):
    """
    Abstract base for raw Diffie-Hellman key agreement.

    Gives the X25519, ECDHE and FFDHE implementations one shared interface.
    """

    def __init__(self, group, version):
        """
        Record the group the exchange runs over and the protocol version.
        """
        self.group, self.version = group, version

    def get_random_private_key(self):
        """Produce a fresh random private value usable with this exchange."""
        raise NotImplementedError("Abstract class")

    def calc_public_value(self, private, point_format=None):
        """Derive the public share belonging to the given private value."""
        raise NotImplementedError("Abstract class")

    def calc_shared_key(self, private, peer_share, valid_point_formats=None):
        """Combine our private value with the peer's share into the secret."""
        raise NotImplementedError("Abstract class")
class FFDHKeyExchange(RawDHKeyExchange):
    """Finite Field Diffie-Hellman key agreement."""

    def __init__(self, group, version, generator=None, prime=None):
        """
        :param group: RFC 7919 group ID, or a falsy value to use the explicit
            generator and prime instead
        :raises ValueError: both a group and a custom prime were given
        :raises TLSIllegalParameterException: generator not in (1, prime)
        """
        super(FFDHKeyExchange, self).__init__(group, version)
        if group and prime:
            raise ValueError("Can't set the RFC7919 group and custom params"
                             " at the same time")
        self.generator, self.prime = \
            RFC7919_GROUPS[group - 256] if group else (generator, prime)

        if self.generator <= 1 or self.generator >= self.prime:
            raise TLSIllegalParameterException("Invalid DH generator")

    def get_random_private_key(self):
        """
        Pick a random private exponent for the prime in use.

        :rtype: int
        """
        # RFC 3526, Section 1: the exponent should carry twice as many bits
        # of entropy as the security strength of the group
        size = divceil(paramStrength(self.prime) * 2, 8)
        return bytesToNumber(getRandomBytes(size))

    def calc_public_value(self, private, point_format=None):
        """
        Compute g^private mod p.

        :param point_format: ignored, present for ECDH API compatibility
        :rtype: int before TLS 1.3, bytearray sized to the prime otherwise
        """
        public = powMod(self.generator, private, self.prime)
        if public in (1, self.prime - 1):
            raise TLSIllegalParameterException("Small subgroup capture")
        if self.version >= (3, 4):
            return numberToByteArray(public, numBytes(self.prime))
        return public

    def _normalise_peer_share(self, peer_share):
        """Return the peer share as an int, decoding fixed-size bytes."""
        if not isinstance(peer_share, (int_types)):
            if len(peer_share) != numBytes(self.prime):
                raise TLSIllegalParameterException(
                    "Key share does not match FFDH prime")
            peer_share = bytesToNumber(peer_share)
        return peer_share

    def calc_shared_key(self, private, peer_share, valid_point_formats=None):
        """
        Compute the shared secret peer_share^private mod p.

        :param valid_point_formats: ignored, present for ECDH API
            compatibility
        :rtype: bytearray
        """
        peer_value = self._normalise_peer_share(peer_share)
        # RFC 2631, Section 2.1.5 public key validation; with safe primes
        # the value p-1 is invalid as well
        if peer_value < 2 or peer_value >= self.prime - 1:
            raise TLSIllegalParameterException("Invalid peer key share")

        secret = powMod(peer_value, private, self.prime)

        if secret in (1, self.prime - 1):
            raise TLSIllegalParameterException("Small subgroup capture")
        if self.version >= (3, 4):
            return numberToByteArray(secret, numBytes(self.prime))
        return numberToByteArray(secret)
class ECDHKeyExchange(RawDHKeyExchange):
    """Implementation of the Elliptic Curve Diffie-Hellman key exchange."""

    # groups computed with the x25519()/x448() functions instead of ecdsa
    _x_groups = set((GroupName.x25519, GroupName.x448))

    @staticmethod
    def _non_zero_check(value):
        """
        Verify using constant time operation that the bytearray is not zero

        :raises TLSIllegalParameterException: if the value is all zero
        """
        summa = 0
        for i in value:
            summa |= i
        if summa == 0:
            raise TLSIllegalParameterException("Invalid key share")

    def __init__(self, group, version):
        """
        Set the curve and protocol version used by the key exchange.

        :param int group: GroupName identifier of the curve
        :param tuple version: protocol version
        """
        super(ECDHKeyExchange, self).__init__(group, version)

    def get_random_private_key(self):
        """
        Return random private key value for the selected curve.

        :returns: random bytes of X25519_ORDER_SIZE / X448_ORDER_SIZE length
            for X25519 / X448, an ``ecdsa.keys.SigningKey`` for other curves
        """
        if self.group in self._x_groups:
            if self.group == GroupName.x25519:
                return getRandomBytes(X25519_ORDER_SIZE)
            else:
                return getRandomBytes(X448_ORDER_SIZE)
        else:
            curve = getCurveByName(GroupName.toStr(self.group))
            return ecdsa.keys.SigningKey.generate(curve)

    def _get_fun_gen_size(self):
        """Return the function, generator and size for X25519/X448 KEX."""
        if self.group == GroupName.x25519:
            return x25519, bytearray(X25519_G), X25519_ORDER_SIZE
        else:
            return x448, bytearray(X448_G), X448_ORDER_SIZE

    def calc_public_value(self, private, point_format='uncompressed'):
        """
        Calculate public value for given private key.

        :param private: Private key for the selected key exchange group.
        :param str point_format: The point format to use for the
            ECDH public key. Applies only to NIST curves.
        :returns: the encoded public key share
        """
        if isinstance(private, ecdsa.keys.SigningKey):
            # to_string() returns bytes; convert it so that this path
            # returns the same type as the explicit-point path below
            return bytearray(private.verifying_key.to_string(point_format))
        if self.group in self._x_groups:
            fun, generator, _ = self._get_fun_gen_size()
            return fun(private, generator)
        curve = getCurveByName(GroupName.toStr(self.group))
        point = curve.generator * private
        return bytearray(point.to_bytes(point_format))

    def calc_shared_key(self, private, peer_share,
                        valid_point_formats=('uncompressed',)):
        """
        Calculate the shared key.

        :param bytearray | SigningKey private: private value
        :param bytearray peer_share: public value
        :param set(str) valid_point_formats: point formats that the
            peer share can be in; ("uncompressed",) by default. Applies
            only to NIST curves.
        :rtype: bytearray
        :returns: shared key

        :raises TLSIllegalParameterException: when the parameters of the
            point are invalid, or the X25519/X448 key share has the wrong
            size or yields an all-zero shared secret.
        :raises TLSDecodeError: when the valid_point_formats is empty.
        """
        if self.group in self._x_groups:
            fun, _, size = self._get_fun_gen_size()
            if len(peer_share) != size:
                raise TLSIllegalParameterException("Invalid key share")
            if isinstance(private, ecdsa.keys.SigningKey):
                private = bytesToNumber(private.to_string())
            S = fun(private, peer_share)
            # RFC 8446, Section 7.4.2: abort on an all-zero shared secret
            self._non_zero_check(S)
            return S

        # toStr, like the other methods of this class, for curve lookup
        curve = getCurveByName(GroupName.toStr(self.group))
        try:
            abstractPoint = ecdsa.ellipticcurve.AbstractPoint()
            point = abstractPoint.from_bytes(
                curve.curve, peer_share, valid_encodings=valid_point_formats)
            ecdhYc = ecdsa.ellipticcurve.Point(
                curve.curve, point[0], point[1])
        except AssertionError:
            raise TLSIllegalParameterException("Invalid ECC point")
        except DecodeError:
            raise TLSDecodeError("Empty point formats extension")

        if isinstance(private, ecdsa.keys.SigningKey):
            ecdh = ecdsa.ecdh.ECDH(curve=curve, private_key=private)
            ecdh.load_received_public_key_bytes(
                peer_share, valid_encodings=valid_point_formats)
            return bytearray(ecdh.generate_sharedsecret_bytes())
        # legacy path: private is an integer scalar
        S = ecdhYc * private

        return numberToByteArray(S.x(), getPointByteSize(ecdhYc))
class KEMKeyExchange(object):
    """
    Implementation of the Hybrid KEM key exchange groups.

    Caution, KEMs are not symmetric! While the client calls the same
    get_random_private_key(), calc_public_value(), and calc_shared_key() as
    in FFDH or ECDH, the server calls just the encapsulate_key() method.
    """

    def __init__(self, group, version):
        """
        Set the hybrid group used by the key exchange.

        :param int group: one of the GroupName.allKEM identifiers
        :param tuple version: protocol version, must be (3, 4)
        :raises TLSInternalError: when kyber-py is not installed, the
            version is not (3, 4), or the group is not a KEM group
        """
        if not ML_KEM_AVAILABLE:
            raise TLSInternalError("kyber-py library not installed!")
        self.group = group
        # explicit check instead of an assert, so that it is not
        # stripped when running under "python -O"
        if version != (3, 4):
            raise TLSInternalError(
                "KEM groups can be used only with TLS 1.3")
        if self.group not in GroupName.allKEM:
            raise TLSInternalError("called with wrong group")
        if self.group == GroupName.secp256r1mlkem768:
            self._classic_group = GroupName.secp256r1
        elif self.group == GroupName.x25519mlkem768:
            self._classic_group = GroupName.x25519
        else:
            assert self.group == GroupName.secp384r1mlkem1024
            self._classic_group = GroupName.secp384r1

    def get_random_private_key(self):
        """
        Generates a random value to be used as the private key in KEM.

        To be used only to generate the KeyShare in ClientHello.

        :returns: ``((pqc_pub_key, pqc_priv_key), classic_key)`` where the
            PQC pair comes from ML-KEM keygen() and classic_key is the
            ECDH private key of the classic component
        """
        if self.group not in GroupName.allKEM:
            raise TLSInternalError("called with wrong group")
        # ML-KEM parameter set taken from the single group table
        _, _, _, _, ml_kem = self._group_to_params()
        pqc_pub_key, pqc_priv_key = ml_kem.keygen()

        classic_kex = ECDHKeyExchange(self._classic_group, (3, 4))
        classic_key = classic_kex.get_random_private_key()

        return ((pqc_pub_key, pqc_priv_key), classic_key)

    def calc_public_value(self, private, point_format='uncompressed'):
        """
        Extract public values for the private key.

        To be used only to generate the KeyShare in ClientHello.

        :param private: value returned by get_random_private_key()
        :param str point_format: Point format of the ECDH portion of the
            key exchange (effective only for NIST curves, valid is
            'uncompressed' only)
        :returns: concatenation of the ML-KEM encapsulation key and the
            classic ECDH key share, in the group's order
        """
        classic_kex = ECDHKeyExchange(self._classic_group, (3, 4))
        classic_pub_key_share = classic_kex.calc_public_value(
            private[1], point_format=point_format)

        # component order comes from the same table used for parsing
        _, _, _, pqc_first, _ = self._group_to_params()
        if pqc_first:
            return private[0][0] + classic_pub_key_share
        return classic_pub_key_share + private[0][0]

    @staticmethod
    def _split_key_shares(public, pqc_first, pqc_key_len, classic_key_len):
        """
        Split a hybrid value into its PQC and classic parts.

        :returns: tuple ``(pqc_part, classic_part)``, the classic part
            as a bytearray
        :raises TLSIllegalParameterException: if the total length does not
            equal classic_key_len + pqc_key_len
        """
        if len(public) != classic_key_len + pqc_key_len:
            raise TLSIllegalParameterException(
                "Invalid key size for the selected group. "
                "Expected: {0}, received: {1}".format(
                    classic_key_len + pqc_key_len, len(public)))

        if pqc_first:
            pqc_key = public[:pqc_key_len]
            classic_key_share = bytearray(public[pqc_key_len:])
        else:
            classic_key_share = bytearray(public[:classic_key_len])
            pqc_key = public[classic_key_len:]

        return pqc_key, classic_key_share

    def _group_to_params(self):
        """Returns a tuple:
        classic_key_len, pqc_ek_key_len, pqc_ciphertext_len, pqc_first, ML_KEM
        """
        if self.group == GroupName.secp256r1mlkem768:
            classic_key_len = 65
            pqc_key_len = 1184
            pqc_ciphertext_len = 1088
            pqc_first = False
            ml_kem = ML_KEM_768
        elif self.group == GroupName.x25519mlkem768:
            classic_key_len = 32
            pqc_key_len = 1184
            pqc_ciphertext_len = 1088
            pqc_first = True
            ml_kem = ML_KEM_768
        else:
            assert self.group == GroupName.secp384r1mlkem1024
            classic_key_len = 97
            pqc_key_len = 1568
            pqc_ciphertext_len = 1568
            pqc_first = False
            ml_kem = ML_KEM_1024

        return classic_key_len, pqc_key_len, pqc_ciphertext_len, pqc_first, \
            ml_kem

    def encapsulate_key(self, public):
        """
        Generate a random secret, encapsulate it given the public key,
        and return both the random secret and encapsulation of it.

        To be used for generation of KeyShare in ServerHello.

        :param public: the client's hybrid key share
        :returns: tuple ``(shared_secret, key_encapsulation)``
        :raises TLSIllegalParameterException: if the key share has the
            wrong size, the classic part is invalid, or ML-KEM rejects
            the encapsulation key
        """
        classic_key_len, pqc_key_len, _, pqc_first, ml_kem = \
            self._group_to_params()

        pqc_key, classic_key_share = self._split_key_shares(
            public, pqc_first, pqc_key_len, classic_key_len)

        classic_kex = ECDHKeyExchange(self._classic_group, (3, 4))
        classic_key = classic_kex.get_random_private_key()
        classic_my_key_share = classic_kex.calc_public_value(classic_key)
        classic_shared_secret = classic_kex.calc_shared_key(
            classic_key, classic_key_share)

        try:
            pqc_shared_secret, pqc_encaps = ml_kem.encaps(pqc_key)
        except ValueError:
            raise TLSIllegalParameterException(
                "Invalid PQC key from peer")

        if pqc_first:
            shared_secret = pqc_shared_secret + classic_shared_secret
            key_encapsulation = pqc_encaps + classic_my_key_share
        else:
            shared_secret = classic_shared_secret + pqc_shared_secret
            key_encapsulation = classic_my_key_share + pqc_encaps

        return shared_secret, key_encapsulation

    def calc_shared_key(self, private, key_encaps):
        """
        Decapsulate the key share received from server.

        :param private: value returned by get_random_private_key()
        :param key_encaps: the server's hybrid key share (ML-KEM
            ciphertext and classic ECDH key share)
        :returns: the concatenated hybrid shared secret
        :raises TLSIllegalParameterException: if the key share has the
            wrong size, the classic part is invalid, or decapsulation fails
        """
        # the server sends an ML-KEM ciphertext, so split on its length
        classic_key_len, _, pqc_ct_len, pqc_first, ml_kem = \
            self._group_to_params()

        pqc_key, classic_key_share = self._split_key_shares(
            key_encaps, pqc_first, pqc_ct_len, classic_key_len)

        classic_kex = ECDHKeyExchange(self._classic_group, (3, 4))
        classic_shared_secret = classic_kex.calc_shared_key(
            private[1], classic_key_share)

        try:
            pqc_shared_secret = ml_kem.decaps(private[0][1], pqc_key)
        except ValueError:
            raise TLSIllegalParameterException(
                "Error in KEM decapsulation")

        if pqc_first:
            shared_secret = pqc_shared_secret + classic_shared_secret
        else:
            shared_secret = classic_shared_secret + pqc_shared_secret

        return shared_secret