1
2
3
4
5
6
7
8 """cryptomath module
9
10 This module has basic math/crypto code."""
11 from __future__ import print_function
12 import os
13 import math
14 import base64
15 import binascii
16 import sys
17
18 from .compat import compat26Str, compatHMAC, compatLong
19
20
21
22
23
24
25
# Optional M2Crypto backend. m2cryptoLoaded ends up True only when the
# import succeeds and the FIPS flag file /proc/sys/crypto/fips_enabled does
# not contain '1'; if that file cannot be read, M2Crypto stays enabled.
try:
    from M2Crypto import m2
except ImportError:
    m2cryptoLoaded = False
else:
    m2cryptoLoaded = True
    try:
        with open('/proc/sys/crypto/fips_enabled', 'r') as fipsFile:
            if '1' in fipsFile.read():
                # FIPS mode reported: do not use M2Crypto.
                m2cryptoLoaded = False
    except (IOError, OSError):
        # Flag file missing or unreadable: treat the system as non-FIPS.
        m2cryptoLoaded = True
40
41
# Optional gmpy acceleration; gmpyLoaded records whether the import worked.
try:
    import gmpy
except ImportError:
    gmpyLoaded = False
else:
    gmpyLoaded = True
47
48
# Optional PyCrypto AES support; pycryptoLoaded records whether the import worked.
try:
    import Crypto.Cipher.AES
except ImportError:
    pycryptoLoaded = False
else:
    pycryptoLoaded = True
54
55
56
57
58
59
60
# Import-time sanity check of os.urandom: 1000 random bytes should not
# compress below 900 bytes. An explicit raise (instead of `assert`) keeps the
# check active when Python runs with -O, and no temporary global is left.
import zlib
if len(zlib.compress(os.urandom(1000))) <= 900:
    raise AssertionError("os.urandom output is unexpectedly compressible")
65
def getRandomBytes(howMany):
    """Return `howMany` bytes from os.urandom() as a mutable bytearray."""
    randomData = bytearray(os.urandom(howMany))
    # Sanity check: the buffer must hold exactly the requested byte count.
    assert len(randomData) == howMany
    return randomData
70
# Name of the random source used by this module, as a plain string.
prngName = 'os.urandom'
72
73
74
75
76
77 import hmac
78 from . import tlshashlib as hashlib
79
def MD5(b):
    """Compute the MD5 digest of `b` (returned as a bytearray by secureHash)."""
    algorithmName = 'md5'
    return secureHash(b, algorithmName)
83
def SHA1(b):
    """Compute the SHA-1 digest of `b` (returned as a bytearray by secureHash)."""
    algorithmName = 'sha1'
    return secureHash(b, algorithmName)
87
def secureHash(data, algorithm):
    """Hash `data` with the algorithm named by `algorithm`.

    `data` is normalised with compat26Str() before hashing, and the digest
    is returned as a bytearray.
    """
    hasher = hashlib.new(algorithm)
    hasher.update(compat26Str(data))
    digest = hasher.digest()
    return bytearray(digest)
93
98
103
108
113
114
115
116
117
def bytesToNumber(b):
    """Interpret `b` as an unsigned big-endian integer.

    `b` may be anything bytearray() accepts (bytearray, bytes, or a list of
    ints in 0-255). An empty sequence yields 0.
    """
    data = bytearray(b)
    if not data:
        # int('', 16) would raise, but an empty byte string encodes zero.
        return 0
    # Parsing the hex form is linear in len(b), unlike accumulating with a
    # bignum multiplier that grows by one byte per iteration (quadratic).
    return int(binascii.hexlify(data), 16)
126
def numberToByteArray(n, howManyBytes=None):
    """Convert an integer into a big-endian bytearray of a fixed length.

    When howManyBytes is None it defaults to numBytes(n), the minimal
    length; for n == 0 that is 0, so the result is an empty bytearray.

    The returned bytearray is always exactly howManyBytes long. If n needs
    fewer bytes it is left-padded with zero bytes. If n needs more, only
    its howManyBytes least significant bytes are kept, and the high-order
    bytes are silently dropped.
    """
    if howManyBytes is None:
        howManyBytes = numBytes(n)
    b = bytearray(howManyBytes)
    # Fill from the least significant (rightmost) byte towards the left.
    for count in range(howManyBytes-1, -1, -1):
        b[count] = int(n % 256)
        n >>= 8
    return b
141
def mpiToNumber(mpi):
    """Convert an MPI byte string into a non-negative integer.

    The input layout matches numberToMPI(): a 4-byte big-endian length
    followed by the big-endian magnitude. bytes, bytearray, and Python 2
    str inputs are accepted. The length prefix itself is not validated.
    A header-only (4-byte) MPI decodes to 0.

    Raises AssertionError if the number is negative, i.e. the top bit of
    the first magnitude byte is set.
    """
    data = bytearray(mpi)
    # Indexing a bytearray yields an int on both Python 2 and 3;
    # ord() on an element of Python 3 bytes raises TypeError.
    if len(data) > 4 and data[4] & 0x80:
        raise AssertionError("MPI encodes a negative number")
    return bytesToNumber(data[4:])
147
def numberToMPI(n):
    """Encode an integer as an MPI byte string.

    Layout: a 4-byte big-endian length, then the big-endian magnitude.
    When the magnitude's bit length is a multiple of 8 (including n == 0),
    a zero byte is prepended. This keeps the top bit clear so that the value
    does not read as negative.
    """
    payload = numberToByteArray(n)
    if numBits(n) % 8 == 0:
        payload = bytearray(1) + payload
    # Length field: the payload size, truncated to 4 big-endian bytes.
    header = numberToByteArray(len(payload), 4)
    return bytes(header + payload)
162
163
164
165
166
167
def numBits(n):
    """Return how many binary digits are needed to write n (0 for 0)."""
    if n == 0:
        return 0
    if sys.version_info >= (2, 7):
        return n.bit_length()
    # int.bit_length() is unavailable before Python 2.7: count the digits
    # of bin(n) after its two-character '0b' prefix instead.
    return len(bin(n)) - 2
178
def numBytes(n):
    """Return how many whole bytes are needed to hold n (0 for 0)."""
    if n == 0:
        return 0
    # Round the bit count up to the next multiple of 8.
    whole, rest = divmod(numBits(n), 8)
    return whole + (1 if rest else 0)
185
186
187
188
189
def getRandomNumber(low, high):
    """Return a random integer n with low <= n < high.

    Uses rejection sampling: draw numBytes(high) random bytes and mask the
    leading byte down to the bit length of high. Retry until the value falls
    inside [low, high). Raises AssertionError if low >= high.
    """
    if low >= high:
        raise AssertionError()
    byteCount = numBytes(high)
    topByteBits = numBits(high) % 8
    while True:
        candidate = getRandomBytes(byteCount)
        if topByteBits:
            # Keep only as many bits as high uses in its leading byte.
            candidate[0] %= 1 << topByteBits
        n = bytesToNumber(candidate)
        if low <= n < high:
            return n
203
def gcd(a, b):
    """Return the greatest common divisor of a and b (Euclid's algorithm).

    The arguments are ordered larger-first before the remainder loop.
    """
    if a >= b:
        larger, smaller = a, b
    else:
        larger, smaller = b, a
    while smaller:
        larger, smaller = smaller, larger % smaller
    return larger
209
def lcm(a, b):
    """Return the least common multiple of a and b.

    If either argument is 0 the result is 0. This includes lcm(0, 0), where
    gcd(0, 0) == 0 would otherwise cause a ZeroDivisionError.
    """
    if a == 0 or b == 0:
        return 0
    return (a * b) // gcd(a, b)
212
213
214
def invMod(a, b):
    """Return the inverse of a modulo b, or 0 if no inverse exists.

    Extended Euclidean algorithm. The coefficient of a is tracked alongside
    the remainders; when the final remainder is 1 that coefficient, reduced
    mod b, is the inverse.
    """
    remainder, previous = a, b
    coeff, prevCoeff = 1, 0
    while remainder:
        quotient = previous // remainder
        remainder, previous = previous - quotient * remainder, remainder
        coeff, prevCoeff = prevCoeff - quotient * coeff, coeff
    return prevCoeff % b if previous == 1 else 0
225
226
if gmpyLoaded:
    def powMod(base, power, modulus):
        """Return base**power mod modulus, computed on gmpy integers.

        The gmpy result is converted back with compatLong() before returning.
        """
        result = pow(gmpy.mpz(base), gmpy.mpz(power), gmpy.mpz(modulus))
        return compatLong(result)
else:
    def powMod(base, power, modulus):
        """Return base**power mod modulus using the built-in pow().

        A negative power is computed as invMod() of base**(-power) mod
        modulus. invMod() returns 0 when no inverse exists.
        """
        if power >= 0:
            return pow(base, power, modulus)
        positive = pow(base, -power, modulus)
        return invMod(positive, modulus)
243
244
def makeSieve(n):
    """Return the list of primes below n (Sieve of Eratosthenes)."""
    candidates = list(range(n))
    for p in range(2, int(math.sqrt(n)) + 1):
        if candidates[p]:
            # Zero out every multiple of p from 2*p onwards.
            candidates[2 * p::p] = [0] * len(range(2 * p, n, p))
    return [x for x in candidates[2:] if x]
256
def isPrime(n, iterations=5, display=False, sieve=makeSieve(1000)):
    """Probabilistic primality test.

    n is first trial-divided by the small primes in `sieve`. If n survives,
    it then undergoes `iterations` rounds of Rabin-Miller. The first round
    uses base 2; each later round uses a fresh random base from
    getRandomNumber(2, n). With iterations == 0 only trial division is done.

    Returns False for n < 2. When `display` is true, prints "*" once trial
    division has passed.
    """
    if n < 2:
        # 0, 1 and negatives are not prime. Without this guard the loop
        # below would return True, since the first sieve entry exceeds them.
        return False
    # Trial division by the small-prime sieve.
    for x in sieve:
        if x >= n: return True
        if n % x == 0: return False
    # Rabin-Miller: write n - 1 as s * 2**t with s odd.
    if display: print("*", end=' ')
    s, t = n-1, 0
    while s % 2 == 0:
        s, t = s//2, t+1
    a = 2  # fixed first base; later rounds use random bases
    for _ in range(iterations):
        v = powMod(a, s, n)
        if v != 1:
            i = 0
            while v != n-1:
                if i == t-1:
                    return False
                v, i = powMod(v, 2, n), i+1
        # Draw a new base even after a round where a**s == 1 (mod n);
        # otherwise every remaining round would retest the same base.
        a = getRandomNumber(2, n)
    return True
283
def getRandomPrime(bits, display=False):
    """Return a random prime p with (3 * 2**(bits-1)) // 2 <= p < 2**bits.

    The lower bound forces the two most significant bits of p to be set.
    Candidates are kept congruent to 29 mod 30, so none is divisible by 2, 3
    or 5. They step by 30 from a random start in [low, high), and a new start
    is drawn once the step reaches high. Raises AssertionError if bits < 10.
    When `display` is true, "." is printed per candidate.
    """
    if bits < 10:
        raise AssertionError()
    low = ((2 ** (bits - 1)) * 3) // 2
    high = 2 ** bits - 30

    def _startCandidate():
        # Random number in [low, high), rounded up to the next value
        # congruent to 29 mod 30.
        draw = getRandomNumber(low, high)
        return draw + 29 - (draw % 30)

    p = _startCandidate()
    while True:
        if display:
            print(".", end=' ')
        p += 30
        if p >= high:
            p = _startCandidate()
        if isPrime(p, display=display):
            return p
304
305
def getRandomSafePrime(bits, display=False):
    """Return a random safe prime p = 2*q + 1 where q is also prime.

    q is drawn from [(3 * 2**(bits-2)) // 2, 2**(bits-1) - 30). Candidates
    are kept congruent to 29 mod 30 and step by 30; a new start is drawn
    once the step reaches the upper bound. Each candidate q first gets a
    trial-division-only check (isPrime with 0 iterations). Then p and q
    get the full test. Raises AssertionError if bits < 10. When `display`
    is true, "." is printed per candidate.
    """
    if bits < 10:
        raise AssertionError()
    low = (2 ** (bits - 2)) * 3 // 2
    high = (2 ** (bits - 1)) - 30

    def _startCandidate():
        # Random number in [low, high), rounded up to the next value
        # congruent to 29 mod 30.
        draw = getRandomNumber(low, high)
        return draw + 29 - (draw % 30)

    q = _startCandidate()
    while True:
        if display:
            print(".", end=' ')
        q += 30
        if q >= high:
            q = _startCandidate()
        # Cheap filter first: trial division only.
        if not isPrime(q, 0, display=display):
            continue
        p = (2 * q) + 1
        if isPrime(p, display=display) and isPrime(q, display=display):
            return p
331