1
2
3
4 """Abstract class for RSA."""
5
6 from .cryptomath import *
7 from .poly1305 import Poly1305
8 from . import tlshashlib as hashlib
9 from ..errors import MaskTooLongError, MessageTooLongError, EncodingError, \
10 InvalidSignature, UnknownRSAType
class RSAKey(object):
    """This is an abstract base class for RSA keys.

    Particular implementations of RSA keys, such as
    L{openssl_rsakey.OpenSSL_RSAKey},
    L{python_rsakey.Python_RSAKey}, and
    L{pycrypto_rsakey.PyCrypto_RSAKey},
    inherit from this.

    To create or parse an RSA key, don't use one of these classes
    directly.  Instead, use the factory functions in
    L{tlslite.utils.keyfactory}.

    The padding and encoding logic in this class relies on subclasses
    providing the modulus as C{self.n} together with L{hasPrivateKey},
    L{_rawPrivateKeyOp} and L{_rawPublicKeyOp}.
    """

    def __init__(self, n=0, e=0):
        """Create a new RSA key.

        If n and e are passed in, the new key will be initialized.

        @type n: int
        @param n: RSA modulus.

        @type e: int
        @param e: RSA public exponent.

        @raise NotImplementedError: always; concrete subclasses must
            provide their own constructor.
        """
        raise NotImplementedError()

    def __len__(self):
        """Return the length of this key in bits.

        @rtype: int
        """
        return numBits(self.n)

    def hasPrivateKey(self):
        """Return whether or not this key has a private component.

        @rtype: bool
        """
        raise NotImplementedError()

    def hashAndSign(self, bytes, rsaScheme='PKCS1', hAlg='sha1', sLen=0):
        """Hash and sign the passed-in bytes.

        This requires the key to have a private component.  It performs
        a PKCS1 or PSS signature on the passed-in data with selected hash
        algorithm.

        @type bytes: str or L{bytearray} of unsigned bytes
        @param bytes: The value which will be hashed and signed.

        @type rsaScheme: str
        @param rsaScheme: The type of RSA scheme that will be applied,
            "PKCS1" for RSASSA-PKCS#1 v1.5 signature and "PSS"
            for RSASSA-PSS with MGF1 signature method

        @type hAlg: str
        @param hAlg: The hash algorithm that will be used

        @type sLen: int
        @param sLen: The length of intended salt value, applicable only
            for RSASSA-PSS signatures

        @rtype: L{bytearray} of unsigned bytes.
        @return: A PKCS1 or PSS signature on the passed-in data.

        @raise UnknownRSAType: if rsaScheme is neither "PKCS1" nor "PSS".
        """
        if rsaScheme == "PKCS1":
            hashBytes = secureHash(bytearray(bytes), hAlg)
            prefixedHashBytes = self.addPKCS1Prefix(hashBytes, hAlg)
            sigBytes = self.sign(prefixedHashBytes)
        elif rsaScheme == "PSS":
            sigBytes = self.RSASSA_PSS_sign(bytearray(bytes), hAlg, sLen)
        else:
            raise UnknownRSAType("Unknown RSA algorithm type")
        return sigBytes

    def hashAndVerify(self, sigBytes, bytes, rsaScheme='PKCS1', hAlg='sha1',
                      sLen=0):
        """Hash and verify the passed-in bytes with the signature.

        This verifies a PKCS1 or PSS signature on the passed-in data
        with selected hash algorithm.

        @type sigBytes: L{bytearray} of unsigned bytes
        @param sigBytes: A PKCS1 or PSS signature.

        @type bytes: str or L{bytearray} of unsigned bytes
        @param bytes: The value which will be hashed and verified.

        @type rsaScheme: str
        @param rsaScheme: The type of RSA scheme that will be applied,
            "PKCS1" for RSASSA-PKCS#1 v1.5 signature and "PSS"
            for RSASSA-PSS with MGF1 signature method

        @type hAlg: str
        @param hAlg: The hash algorithm that will be used

        @type sLen: int
        @param sLen: The length of intended salt value, applicable only
            for RSASSA-PSS signatures

        @rtype: bool
        @return: Whether the signature matches the passed-in data.  For
            "PKCS1" a mismatch yields False; for "PSS" a mismatch is
            reported by raising L{InvalidSignature} instead.

        @raise InvalidSignature: for "PSS", when the signature does not
            verify.
        @raise UnknownRSAType: if rsaScheme is neither "PKCS1" nor "PSS".
        """
        if rsaScheme == "PKCS1" and hAlg == 'sha1':
            # Both the NULL-less and the NULL-parameter encodings of the
            # SHA-1 AlgorithmIdentifier are tried, so a signature over
            # either DigestInfo form is accepted.
            hashBytes = secureHash(bytearray(bytes), hAlg)
            prefixedHashBytes1 = self.addPKCS1SHA1Prefix(hashBytes, False)
            prefixedHashBytes2 = self.addPKCS1SHA1Prefix(hashBytes, True)
            result1 = self.verify(sigBytes, prefixedHashBytes1)
            result2 = self.verify(sigBytes, prefixedHashBytes2)
            return (result1 or result2)
        elif rsaScheme == 'PKCS1':
            hashBytes = secureHash(bytearray(bytes), hAlg)
            prefixedHashBytes = self.addPKCS1Prefix(hashBytes, hAlg)
            r = self.verify(sigBytes, prefixedHashBytes)
            return r
        elif rsaScheme == "PSS":
            r = self.RSASSA_PSS_verify(bytearray(bytes), sigBytes, hAlg, sLen)
            return r
        else:
            raise UnknownRSAType("Unknown RSA algorithm type")

    def MGF1(self, mgfSeed, maskLen, hAlg):
        """Generate mask from passed-in seed.

        This generates mask based on passed-in seed and output maskLen
        by concatenating hashes of mgfSeed || 4-byte counter.

        @type mgfSeed: L{bytearray}
        @param mgfSeed: Seed from which mask will be generated.

        @type maskLen: int
        @param maskLen: Wished length of the mask, in octets

        @type hAlg: str
        @param hAlg: Name of the hash algorithm to use

        @rtype: L{bytearray}
        @return: Mask

        @raise MaskTooLongError: if maskLen exceeds 2^32 hash outputs.
        """
        hashLen = getattr(hashlib, hAlg)().digest_size
        if maskLen > (2 ** 32) * hashLen:
            raise MaskTooLongError("Incorrect parameter maskLen")
        T = bytearray()
        end = Poly1305.divceil(maskLen, hashLen)
        for counter in range(0, end):
            C = numberToByteArray(counter, 4)
            T += secureHash(mgfSeed + C, hAlg)
        return T[:maskLen]

    def EMSA_PSS_encode(self, M, emBits, hAlg, sLen=0):
        """Encode the passed in message

        This encodes the message using selected hash algorithm and a
        freshly generated random salt.

        @type M: bytearray
        @param M: Message to be encoded

        @type emBits: int
        @param emBits: maximal bit length of the integer representation
            of the returned EM

        @type hAlg: str
        @param hAlg: hash algorithm to be used

        @type sLen: int
        @param sLen: length of salt

        @rtype: bytearray
        @return: Encoded message EM of ceil(emBits/8) octets

        @raise EncodingError: if emBits is too small for the selected
            hash and salt length.
        """
        hashLen = getattr(hashlib, hAlg)().digest_size
        mHash = secureHash(M, hAlg)
        emLen = Poly1305.divceil(emBits, 8)
        if emLen < hashLen + sLen + 2:
            raise EncodingError("The ending limit too short for " +
                                "selected hash and salt length")
        salt = getRandomBytes(sLen)
        # M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt
        M2 = bytearray(8) + mHash + salt
        H = secureHash(M2, hAlg)
        PS = bytearray(emLen - sLen - hashLen - 2)
        DB = PS + bytearray(b'\x01') + salt
        dbMask = self.MGF1(H, emLen - hashLen - 1, hAlg)
        maskedDB = bytearray(i ^ j for i, j in zip(DB, dbMask))
        # Clear the high-order bits of the first octet that lie beyond
        # emBits so the integer value of EM fits in emBits bits.
        mLen = emLen * 8 - emBits
        mask = (1 << (8 - mLen)) - 1
        maskedDB[0] &= mask
        EM = maskedDB + H + bytearray(b'\xbc')
        return EM

    def RSASSA_PSS_sign(self, M, hAlg, sLen=0):
        """Sign the passed in message

        This signs the message using selected hash algorithm

        @type M: bytearray
        @param M: Message to be signed

        @type hAlg: str
        @param hAlg: hash algorithm to be used

        @type sLen: int
        @param sLen: length of salt

        @rtype: bytearray
        @return: Signature, numBytes(n) octets long

        @raise MessageTooLongError: if the encoded message is not smaller
            than the modulus.
        """
        EM = self.EMSA_PSS_encode(M, numBits(self.n) - 1, hAlg, sLen)
        m = bytesToNumber(EM)
        if m >= self.n:
            raise MessageTooLongError("Encode output too long")
        s = self._rawPrivateKeyOp(m)
        S = numberToByteArray(s, numBytes(self.n))
        return S

    def EMSA_PSS_verify(self, M, EM, emBits, hAlg, sLen=0):
        """Verify signature in passed in encoded message

        This verifies the signature in encoded message

        @type M: bytearray
        @param M: Original not signed message

        @type EM: bytearray
        @param EM: Encoded message

        @type emBits: int
        @param emBits: Length of the encoded message in bits

        @type hAlg: str
        @param hAlg: hash algorithm to be used

        @type sLen: int
        @param sLen: Length of salt

        @rtype: bool
        @return: True when the encoding is consistent with M

        @raise InvalidSignature: when the encoding is inconsistent.
        """
        hashLen = getattr(hashlib, hAlg)().digest_size
        mHash = secureHash(M, hAlg)
        emLen = Poly1305.divceil(emBits, 8)
        if emLen < hashLen + sLen + 2:
            raise InvalidSignature("Invalid signature")
        if EM[-1] != 0xbc:
            raise InvalidSignature("Invalid signature")
        dbLen = emLen - hashLen - 1
        maskedDB = EM[0:dbLen]
        H = EM[dbLen:dbLen + hashLen]
        # Bits of the first octet above emBits must be zero; lowBitsMask
        # keeps only the (8 - unusedBits) bits that are allowed to be set.
        unusedBits = 8 * emLen - emBits
        lowBitsMask = (1 << (8 - unusedBits)) - 1
        if maskedDB[0] & (~lowBitsMask & 0xff) != 0:
            raise InvalidSignature("Invalid signature")
        dbMask = self.MGF1(H, dbLen, hAlg)
        DB = bytearray(i ^ j for i, j in zip(maskedDB, dbMask))
        DB[0] &= lowBitsMask
        # DB = PS || 0x01 || salt where PS is psLen zero octets; every
        # octet of PS must be checked (RFC 8017, section 9.1.2, step 10).
        psLen = emLen - hashLen - sLen - 2
        if any(x != 0 for x in DB[0:psLen]):
            raise InvalidSignature("Invalid signature")
        if DB[psLen] != 0x01:
            raise InvalidSignature("Invalid signature")
        if sLen != 0:
            salt = DB[-sLen:]
        else:
            salt = bytearray()
        newM = bytearray(8) + mHash + salt
        newH = secureHash(newM, hAlg)
        if H == newH:
            return True
        else:
            raise InvalidSignature("Invalid signature")

    def RSASSA_PSS_verify(self, M, S, hAlg, sLen=0):
        """Verify the signature in passed in message

        This verifies the signature in the signed message

        @type M: bytearray
        @param M: Original message

        @type S: bytearray
        @param S: Signed message

        @type hAlg: str
        @param hAlg: Hash algorithm to be used

        @type sLen: int
        @param sLen: Length of salt

        @rtype: bool
        @return: True when the signature is valid

        @raise InvalidSignature: when the signature has the wrong length,
            is out of range for the modulus, or does not verify.
        """
        modBits = numBits(self.n)
        if len(bytearray(S)) != numBytes(self.n):
            raise InvalidSignature("Invalid signature")
        s = bytesToNumber(S)
        # RSAVP1 requires the signature representative to be below n.
        if s >= self.n:
            raise InvalidSignature("Invalid signature")
        m = self._rawPublicKeyOp(s)
        emLen = Poly1305.divceil(modBits - 1, 8)
        # I2OSP(m, emLen) must fail rather than silently drop high
        # octets when m does not fit (RFC 8017, section 8.1.2, step 2c).
        if numBytes(m) > emLen:
            raise InvalidSignature("Invalid signature")
        EM = numberToByteArray(m, emLen)
        result = self.EMSA_PSS_verify(M, EM, modBits - 1, hAlg, sLen)
        if result:
            return True
        else:
            raise InvalidSignature("Invalid signature")

    def sign(self, bytes):
        """Sign the passed-in bytes.

        This requires the key to have a private component.  It performs
        a PKCS1 signature on the passed-in data.

        @type bytes: L{bytearray} of unsigned bytes
        @param bytes: The value which will be signed.

        @rtype: L{bytearray} of unsigned bytes.
        @return: A PKCS1 signature on the passed-in data.

        @raise AssertionError: if the key has no private component.
        @raise ValueError: if the padded value is not below the modulus.
        """
        if not self.hasPrivateKey():
            raise AssertionError()
        paddedBytes = self._addPKCS1Padding(bytes, 1)
        m = bytesToNumber(paddedBytes)
        if m >= self.n:
            raise ValueError()
        c = self._rawPrivateKeyOp(m)
        sigBytes = numberToByteArray(c, numBytes(self.n))
        return sigBytes

    def verify(self, sigBytes, bytes):
        """Verify the passed-in bytes with the signature.

        This verifies a PKCS1 signature on the passed-in data.

        @type sigBytes: L{bytearray} of unsigned bytes
        @param sigBytes: A PKCS1 signature.

        @type bytes: L{bytearray} of unsigned bytes
        @param bytes: The value which will be verified.

        @rtype: bool
        @return: Whether the signature matches the passed-in data.
        """
        if len(sigBytes) != numBytes(self.n):
            return False
        paddedBytes = self._addPKCS1Padding(bytes, 1)
        c = bytesToNumber(sigBytes)
        if c >= self.n:
            return False
        m = self._rawPublicKeyOp(c)
        # Re-encode the expected block and compare it with the recovered
        # one, instead of parsing the recovered padding.
        checkBytes = numberToByteArray(m, numBytes(self.n))
        return checkBytes == paddedBytes

    def encrypt(self, bytes):
        """Encrypt the passed-in bytes.

        This performs PKCS1 encryption of the passed-in data.

        @type bytes: L{bytearray} of unsigned bytes
        @param bytes: The value which will be encrypted.

        @rtype: L{bytearray} of unsigned bytes.
        @return: A PKCS1 encryption of the passed-in data.

        @raise ValueError: if the padded value is not below the modulus.
        """
        paddedBytes = self._addPKCS1Padding(bytes, 2)
        m = bytesToNumber(paddedBytes)
        if m >= self.n:
            raise ValueError()
        c = self._rawPublicKeyOp(m)
        encBytes = numberToByteArray(c, numBytes(self.n))
        return encBytes

    def decrypt(self, encBytes):
        """Decrypt the passed-in bytes.

        This requires the key to have a private component.  It performs
        PKCS1 decryption of the passed-in data.

        @type encBytes: L{bytearray} of unsigned bytes
        @param encBytes: The value which will be decrypted.

        @rtype: L{bytearray} of unsigned bytes or None.
        @return: A PKCS1 decryption of the passed-in data or None if
            the data is not properly formatted (wrong length, out of
            range, wrong header, no separator, or padding string shorter
            than 8 octets).

        @raise AssertionError: if the key has no private component.
        """
        if not self.hasPrivateKey():
            raise AssertionError()
        if len(encBytes) != numBytes(self.n):
            return None
        c = bytesToNumber(encBytes)
        if c >= self.n:
            return None
        m = self._rawPrivateKeyOp(c)
        decBytes = numberToByteArray(m, numBytes(self.n))

        # Expected layout: 0x00 || 0x02 || PS || 0x00 || M
        if decBytes[0] != 0 or decBytes[1] != 2:
            return None

        # PS starts at index 2 and must be at least 8 octets long, so the
        # separator can be no earlier than index 10 (RFC 8017, section
        # 7.2.2, step 3).  find() returns -1 when there is no separator,
        # which also fails this check.
        separator = decBytes.find(b'\x00', 2)
        if separator < 10:
            return None
        return decBytes[separator + 1:]

    def _rawPrivateKeyOp(self, m):
        """Apply the raw private-key operation to integer m.

        @raise NotImplementedError: always; implemented by subclasses.
        """
        raise NotImplementedError()

    def _rawPublicKeyOp(self, c):
        """Apply the raw public-key operation to integer c.

        @raise NotImplementedError: always; implemented by subclasses.
        """
        raise NotImplementedError()

    def acceptsPassword(self):
        """Return True if the write() method accepts a password for use
        in encrypting the private key.

        @rtype: bool
        """
        raise NotImplementedError()

    def write(self, password=None):
        """Return a string containing the key.

        @rtype: str
        @return: A string describing the key, in whichever format (PEM)
            is native to the implementation.
        """
        raise NotImplementedError()

    @staticmethod
    def generate(bits):
        """Generate a new key with the specified bit length.

        @rtype: L{tlslite.utils.RSAKey.RSAKey}
        """
        raise NotImplementedError()

    @classmethod
    def addPKCS1SHA1Prefix(cls, hashBytes, withNULL=True):
        """Add PKCS#1 v1.5 algorithm identifier prefix to SHA1 hash bytes

        @type hashBytes: L{bytearray}
        @param hashBytes: SHA-1 digest to prefix

        @type withNULL: bool
        @param withNULL: when True the AlgorithmIdentifier carries an
            explicit NULL parameter; when False the parameter is omitted.
            hashAndVerify tries both encodings.

        @rtype: L{bytearray}
        @return: DER DigestInfo prefix followed by hashBytes
        """
        if not withNULL:
            # Same DigestInfo as the 'sha1' table entry, but with the
            # "05 00" NULL parameter omitted and the lengths shortened.
            prefixBytes = bytearray([0x30, 0x1f, 0x30, 0x07, 0x06, 0x05, 0x2b,
                                     0x0e, 0x03, 0x02, 0x1a, 0x04, 0x14])
        else:
            prefixBytes = cls._pkcs1Prefixes['sha1']
        prefixedBytes = prefixBytes + hashBytes
        return prefixedBytes

    # DER-encoded DigestInfo prefixes (AlgorithmIdentifier with a NULL
    # parameter, then the OCTET STRING header whose final byte is the
    # digest length), as listed in RFC 8017, section 9.2, note 1.
    _pkcs1Prefixes = {'md5' : bytearray([0x30, 0x20, 0x30, 0x0c, 0x06, 0x08,
                                         0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d,
                                         0x02, 0x05, 0x05, 0x00, 0x04, 0x10]),
                      'sha1' : bytearray([0x30, 0x21, 0x30, 0x09, 0x06, 0x05,
                                          0x2b, 0x0e, 0x03, 0x02, 0x1a, 0x05,
                                          0x00, 0x04, 0x14]),
                      'sha224' : bytearray([0x30, 0x2d, 0x30, 0x0d, 0x06, 0x09,
                                            0x60, 0x86, 0x48, 0x01, 0x65, 0x03,
                                            0x04, 0x02, 0x04, 0x05, 0x00, 0x04,
                                            0x1c]),
                      'sha256' : bytearray([0x30, 0x31, 0x30, 0x0d, 0x06, 0x09,
                                            0x60, 0x86, 0x48, 0x01, 0x65, 0x03,
                                            0x04, 0x02, 0x01, 0x05, 0x00, 0x04,
                                            0x20]),
                      'sha384' : bytearray([0x30, 0x41, 0x30, 0x0d, 0x06, 0x09,
                                            0x60, 0x86, 0x48, 0x01, 0x65, 0x03,
                                            0x04, 0x02, 0x02, 0x05, 0x00, 0x04,
                                            0x30]),
                      'sha512' : bytearray([0x30, 0x51, 0x30, 0x0d, 0x06, 0x09,
                                            0x60, 0x86, 0x48, 0x01, 0x65, 0x03,
                                            0x04, 0x02, 0x03, 0x05, 0x00, 0x04,
                                            0x40])}

    @classmethod
    def addPKCS1Prefix(cls, data, hashName):
        """Add the PKCS#1 v1.5 algorithm identifier prefix to hash bytes

        @type data: L{bytearray}
        @param data: digest to prefix

        @type hashName: str
        @param hashName: case-insensitive name of the hash that produced
            data; must be a key of the prefix table

        @rtype: L{bytearray}
        @return: DER DigestInfo prefix followed by data

        @raise AssertionError: if hashName is not a supported hash.
        """
        hashName = hashName.lower()
        # Explicit check (not a bare assert) so validation survives -O.
        if hashName not in cls._pkcs1Prefixes:
            raise AssertionError("Unsupported hash: {0}".format(hashName))
        prefixBytes = cls._pkcs1Prefixes[hashName]
        return prefixBytes + data

    def _addPKCS1Padding(self, bytes, blockType):
        """Wrap bytes in a PKCS#1 v1.5 block of numBytes(n) octets.

        Layout: 0x00 || blockType || PS || 0x00 || bytes, where PS is 0xFF
        octets for blockType 1 (signing) and random non-zero octets for
        blockType 2 (encryption).  The minimum PS length is not enforced
        here; an over-long input yields a block longer than the modulus,
        which callers reject.

        @type bytes: L{bytearray}
        @param bytes: payload to pad

        @type blockType: int
        @param blockType: 1 or 2

        @rtype: L{bytearray}

        @raise AssertionError: for any other blockType.
        """
        padLength = (numBytes(self.n) - (len(bytes)+3))
        if blockType == 1:
            pad = [0xFF] * padLength
        elif blockType == 2:
            # Start from a list so that the concatenation below also works
            # when no padding is needed (padLength <= 0).
            pad = []
            while len(pad) < padLength:
                # Over-fetch random bytes and drop zeros, which are not
                # allowed in PS; retry in the unlikely case too few remain.
                padBytes = getRandomBytes(padLength * 2)
                pad = [b for b in padBytes if b != 0]
            pad = pad[:padLength]
        else:
            raise AssertionError()

        padding = bytearray([0, blockType] + pad + [0])
        paddedBytes = padding + bytes
        return paddedBytes
499