1
2
3
4
5
6
7
8 """cryptomath module
9
10 This module has basic math/crypto code."""
11 from __future__ import print_function
12 import os
13 import math
14 import base64
15 import binascii
16
17 from .compat import *
18
19
20
21
22
23
24
25 try:
26 from M2Crypto import m2
27 m2cryptoLoaded = True
28
29 except ImportError:
30 m2cryptoLoaded = False
31
32
33 try:
34 import gmpy
35 gmpyLoaded = True
36 except ImportError:
37 gmpyLoaded = False
38
39
40 try:
41 import Crypto.Cipher.AES
42 pycryptoLoaded = True
43 except ImportError:
44 pycryptoLoaded = False
45
46
47
48
49
50
51
# Import-time sanity check of os.urandom: 1000 random bytes should be
# essentially incompressible, so a short zlib output means the entropy
# source is broken. An explicit raise (instead of `assert`) keeps the check
# active when Python runs with -O.
import zlib
if len(zlib.compress(os.urandom(1000))) <= 900:
    raise AssertionError("os.urandom output is unexpectedly compressible")


def getRandomBytes(howMany):
    """Return a bytearray of `howMany` bytes read from os.urandom.

    :raises AssertionError: if os.urandom returns a different length.
    """
    b = bytearray(os.urandom(howMany))
    # explicit check rather than `assert`, so it also runs under -O
    if len(b) != howMany:
        raise AssertionError("os.urandom returned %d bytes, expected %d"
                             % (len(b), howMany))
    return b


prngName = "os.urandom"
63
64
65
66
67
68 import hmac
69 import hashlib
70
73
76
81
86
91
96
97
98
99
100
def bytesToNumber(b):
    """Interpret the bytes of `b` as an unsigned big-endian integer.

    An empty input yields 0.
    """
    value = 0
    for octet in b:
        # Horner step: shift what we have up one byte, then add this byte
        value = value * 256 + octet
    return value
109
def numberToByteArray(n, howManyBytes=None):
    """Convert an integer into a big-endian bytearray.

    :param int n: the integer to encode.
    :param int howManyBytes: length of the result; when omitted,
        ``numBytes(n)`` (the minimal length that holds `n`) is used.
    :returns: a bytearray of exactly `howManyBytes` bytes holding the
        big-endian encoding of `n`, left-padded with zero bytes. If `n`
        does not fit, only its `howManyBytes` least significant bytes are
        kept; the high-order bytes are silently dropped.
    """
    if howManyBytes is None:
        howManyBytes = numBytes(n)
    b = bytearray(howManyBytes)
    # fill from the least significant (rightmost) byte towards the left
    for count in range(howManyBytes - 1, -1, -1):
        b[count] = int(n % 256)
        n >>= 8
    return b
124
def mpiToNumber(mpi):
    """Convert an MPI byte string back to a non-negative integer.

    The input is expected to start with a 4-byte length prefix (as written
    by numberToMPI); those four bytes are skipped and the remainder is
    decoded as an unsigned big-endian integer. The prefix itself is not
    validated.

    :raises AssertionError: if the first magnitude byte has its top bit
        set, i.e. the MPI would encode a negative number.
    """
    # Indexing a bytearray yields an int on both Python 2 and 3. The old
    # ord(mpi[4]) raised TypeError for bytes/bytearray input on Python 3.
    b = bytearray(mpi)
    if b[4] & 0x80:
        raise AssertionError("MPI encodes a negative number")
    return bytesToNumber(b[4:])
130
def numberToMPI(n):
    """Encode an integer (expected non-negative) as an MPI byte string.

    Layout: a 4-byte big-endian length, then the big-endian magnitude.
    When the bit length of `n` is a multiple of 8 (this includes n == 0),
    a zero byte is inserted before the magnitude and counted in the length,
    so the top bit of the first magnitude byte is clear.
    """
    magnitude = numberToByteArray(n)
    pad = 1 if numBits(n) % 8 == 0 else 0
    size = numBytes(n) + pad
    header = bytearray([(size >> shift) & 0xFF for shift in (24, 16, 8, 0)])
    return bytes(header + bytearray(pad) + magnitude)
145
146
147
148
149
150
def numBits(n):
    """Return the number of bits needed to represent the integer `n`.

    Returns 0 for ``n == 0``. Only non-negative inputs are supported: a
    negative `n` formats with a leading '-' and the lookup raises KeyError.
    """
    if n == 0:
        return 0
    # Every hex digit after the first contributes exactly 4 bits; the
    # leading digit contributes only the bits it actually uses. (A second,
    # unreachable float-log based return used to follow here; it was dead
    # code and would have been inexact for large n.)
    hexDigits = "%x" % n
    leadingBits = {'0': 0, '1': 1, '2': 2, '3': 2,
                   '4': 3, '5': 3, '6': 3, '7': 3,
                   '8': 4, '9': 4, 'a': 4, 'b': 4,
                   'c': 4, 'd': 4, 'e': 4, 'f': 4}[hexDigits[0]]
    return (len(hexDigits) - 1) * 4 + leadingBits
162
def numBytes(n):
    """Return how many whole bytes are needed to hold the integer `n`.

    Returns 0 for ``n == 0``.
    """
    if n == 0:
        return 0
    # round the bit count up to a whole number of bytes
    return (numBits(n) + 7) // 8
168
169
170
171
172
def getRandomNumber(low, high):
    """Return a random integer n with ``low <= n < high``.

    Rejection sampling: draw as many random bytes as `high` needs, clear
    the bits above ``numBits(high)``, and retry until the value lands in
    range. The result is uniform over the range provided getRandomBytes
    returns uniform bytes.

    :raises AssertionError: if ``low >= high`` (empty range).
    """
    if low >= high:
        raise AssertionError("getRandomNumber: low must be less than high")
    howManyBits = numBits(high)
    howManyBytes = numBytes(high)
    lastBits = howManyBits % 8
    while True:
        # named randomBytes to avoid shadowing the builtin `bytes`
        randomBytes = getRandomBytes(howManyBytes)
        if lastBits:
            # keep candidates below 2**numBits(high) (at most 2*high), so
            # rejection retries stay rare
            randomBytes[0] = randomBytes[0] % (1 << lastBits)
        n = bytesToNumber(randomBytes)
        if low <= n < high:
            return n
186
def gcd(a, b):
    """Return the greatest common divisor of `a` and `b` (Euclid's algorithm)."""
    # put the larger operand first
    if a < b:
        a, b = b, a
    while b != 0:
        a, b = b, a % b
    return a
192
def lcm(a, b):
    """Return the least common multiple of `a` and `b`."""
    divisor = gcd(a, b)
    return a * b // divisor
195
196
197
def invMod(a, b):
    """Return the multiplicative inverse of `a` modulo `b`, or 0 if none exists.

    Extended Euclidean algorithm that tracks only the coefficient of `a`;
    once the gcd reaches 1 that coefficient is the inverse.
    """
    remainder, previous = a, b
    coeff, prevCoeff = 1, 0
    while remainder:
        quotient = previous // remainder
        remainder, previous = previous - quotient * remainder, remainder
        coeff, prevCoeff = prevCoeff - quotient * coeff, coeff
    if previous != 1:
        # gcd(a, b) != 1: no inverse
        return 0
    return prevCoeff % b
208
209
if gmpyLoaded:
    def powMod(base, power, modulus):
        """Compute ``base ** power % modulus`` with gmpy big integers.

        The gmpy result is converted back with compatLong.
        """
        result = pow(gmpy.mpz(base), gmpy.mpz(power), gmpy.mpz(modulus))
        return compatLong(result)

else:
    def powMod(base, power, modulus):
        """Compute ``base ** power % modulus``.

        A negative `power` is handled by raising `base` to ``-power`` and
        returning the modular inverse (via invMod) of that result.
        """
        if power >= 0:
            return pow(base, power, modulus)
        return invMod(pow(base, -power, modulus), modulus)
226
227
def makeSieve(n):
    """Return the ascending list of all primes strictly less than `n`.

    Sieve of Eratosthenes; returns an empty list for ``n <= 2``.
    """
    sieve = list(range(n))
    for count in range(2, int(math.sqrt(n)) + 1):
        if sieve[count] == 0:
            continue
        # Multiples below count*count were already crossed off by smaller
        # primes, so start there instead of at 2*count.
        for multiple in range(count * count, n, count):
            sieve[multiple] = 0
    # drop 0 and 1, then every crossed-off (zeroed) entry
    return [x for x in sieve[2:] if x]
239
def isPrime(n, iterations=5, display=False, sieve=makeSieve(1000)):
    """Probabilistically test whether `n` is prime.

    `n` is first trial-divided by the primes in `sieve`. If it survives,
    it goes through `iterations` rounds of Rabin-Miller: the first round
    uses base 2 and every later round uses a fresh random base in [2, n).

    :param int n: the candidate.
    :param int iterations: number of Rabin-Miller rounds; 0 means trial
        division only.
    :param bool display: print a "*" progress marker before Rabin-Miller.
    :param list sieve: ascending small primes used for trial division
        (by default the primes below 1000).
    :returns: False if `n` is less than 2 or a witness of compositeness was
        found; True otherwise.
    """
    # 0, 1 and negatives are not prime. Without this guard the first sieve
    # entry satisfied `x >= n` and such inputs were reported as prime.
    if n < 2:
        return False

    # Trial division with the sieve.
    for x in sieve:
        if x >= n:
            return True
        if n % x == 0:
            return False

    # Rabin-Miller: write n - 1 = s * 2**t with s odd.
    if display:
        print("*", end=' ')
    s, t = n - 1, 0
    while s % 2 == 0:
        s, t = s // 2, t + 1

    for count in range(iterations):
        # Pick the base at the top of every round. The new random base used
        # to be drawn at the bottom of the loop, which the `v == 1`
        # shortcut skipped, so the remaining rounds re-tested the same base.
        a = 2 if count == 0 else getRandomNumber(2, n)
        v = powMod(a, s, n)
        if v == 1:
            continue
        # Square up to t-1 times looking for n-1; never reaching it proves
        # n composite.
        i = 0
        while v != n - 1:
            if i == t - 1:
                return False
            v, i = powMod(v, 2, n), i + 1
    return True
266
def getRandomPrime(bits, display=False):
    """Return a random prime with exactly `bits` bits.

    Candidates lie in [1.5 * 2**(bits-1), 2**bits), so their top two bits
    are set. They are congruent to 29 modulo 30 and each is checked with
    isPrime.

    :raises AssertionError: if ``bits < 10``.
    """
    if bits < 10:
        raise AssertionError()
    low = ((2 ** (bits - 1)) * 3) // 2
    high = 2 ** bits - 30

    def freshCandidate():
        # random start, bumped up to the next value that is 29 mod 30
        start = getRandomNumber(low, high)
        return start + 29 - (start % 30)

    p = freshCandidate()
    while True:
        if display:
            print(".", end=' ')
        p += 30
        if p >= high:
            p = freshCandidate()
        if isPrime(p, display=display):
            return p
287
288
def getRandomSafePrime(bits, display=False):
    """Return a random safe prime p = 2*q + 1 (q also prime) of `bits` bits.

    q is searched among values congruent to 29 modulo 30 starting in
    [1.5 * 2**(bits-2), 2**(bits-1) - 30). Each q is screened cheaply
    first; p and then q get full isPrime tests.

    :raises AssertionError: if ``bits < 10``.
    """
    if bits < 10:
        raise AssertionError()
    low = (2 ** (bits - 2)) * 3 // 2
    high = (2 ** (bits - 1)) - 30

    def freshCandidate():
        # random start, bumped up to the next value that is 29 mod 30
        start = getRandomNumber(low, high)
        return start + 29 - (start % 30)

    q = freshCandidate()
    while True:
        if display:
            print(".", end=' ')
        q += 30
        if q >= high:
            q = freshCandidate()
        # iterations=0 -> trial division only: a cheap filter before the
        # expensive probabilistic tests
        if not isPrime(q, 0, display=display):
            continue
        p = 2 * q + 1
        if isPrime(p, display=display) and isPrime(q, display=display):
            return p
314