Package tlslite :: Package utils :: Module cryptomath
[hide private]
[frames] | [no frames]

Source Code for Module tlslite.utils.cryptomath

  1  # Author: Trevor Perrin 
  2  # See the LICENSE file for legal information regarding use of this file. 
  3   
  4  """cryptomath module 
  5   
  6  This module has basic math/crypto code.""" 
  7   
  8  import os 
  9  import math 
 10  import base64 
 11  import binascii 
 12   
 13  from .compat import * 
 14   
 15   
 16  # ************************************************************************** 
 17  # Load Optional Modules 
 18  # ************************************************************************** 
 19   
 20  # Try to load M2Crypto/OpenSSL 
 21  try: 
 22      from M2Crypto import m2 
 23      m2cryptoLoaded = True 
 24   
 25  except ImportError: 
 26      m2cryptoLoaded = False 
 27   
 28  #Try to load GMPY 
 29  try: 
 30      import gmpy 
 31      gmpyLoaded = True 
 32  except ImportError: 
 33      gmpyLoaded = False 
 34   
 35  #Try to load pycrypto 
 36  try: 
 37      import Crypto.Cipher.AES 
 38      pycryptoLoaded = True 
 39  except ImportError: 
 40      pycryptoLoaded = False 
 41   
 42   
 43  # ************************************************************************** 
 44  # PRNG Functions 
 45  # ************************************************************************** 
 46   
 47  # Check that os.urandom works 
 48  import zlib 
 49  length = len(zlib.compress(os.urandom(1000))) 
 50  assert(length > 900) 
 51   
def getRandomBytes(howMany):
    """Return howMany random bytes read from os.urandom.

    The raw os.urandom result is passed through stringToBytes before being
    returned.
    """
    randomString = os.urandom(howMany)
    assert len(randomString) == howMany
    return stringToBytes(randomString)

# Name of the randomness source that getRandomBytes draws from.
prngName = "os.urandom"


# **************************************************************************
# Converter Functions
# **************************************************************************
def bytesToNumber(bytes):
    """Convert a big-endian sequence of byte values into an integer.

    bytes is any sequence of integers (e.g. a bytearray), most significant
    first. An empty sequence yields 0.
    """
    # Horner's rule: identical result to summing byte * 256**position, but
    # without the Python-2-only `0L`/`1L` literals.
    total = 0
    for byte in bytes:
        total = (total << 8) + byte
    return total
71
def numberToBytes(n, howManyBytes=None):
    """Convert an integer into a bytearray, zero-padded to howManyBytes.

    The result is the buffer from createByteArrayZeros(howManyBytes) with
    positions 0..howManyBytes-1 holding the big-endian encoding of n.
    howManyBytes defaults to numBytes(n), the minimal length. If n needs
    fewer bytes, the leading positions stay zero. If n needs more, its
    high-order bytes are silently dropped, so only n mod 256**howManyBytes
    is encoded.
    """
    if howManyBytes is None:
        howManyBytes = numBytes(n)
    bytes = createByteArrayZeros(howManyBytes)
    # Fill from the least significant (rightmost) position leftwards.
    for count in range(howManyBytes-1, -1, -1):
        bytes[count] = int(n % 256)
        n >>= 8
    return bytes
86
def bytesToBase64(bytes):
    """Return the base64 encoding of a byte array (via bytesToString)."""
    return stringToBase64(bytesToString(bytes))
90
def base64ToBytes(s):
    """Decode base64 text s into a byte array (via stringToBytes)."""
    return stringToBytes(base64ToString(s))
94
def numberToBase64(n):
    """Return the base64 encoding of the minimal big-endian bytes of n."""
    return bytesToBase64(numberToBytes(n))
98
def base64ToNumber(s):
    """Decode base64 text s and interpret the bytes as a big-endian integer."""
    return bytesToNumber(base64ToBytes(s))
102
def stringToNumber(s):
    """Interpret the string s (via stringToBytes) as a big-endian integer."""
    return bytesToNumber(stringToBytes(s))
106
def numberToString(s):
    """Encode the integer s as a string of its minimal big-endian bytes.

    Note: despite its name, the parameter s is the integer to encode.
    """
    return bytesToString(numberToBytes(s))
110
def base64ToString(s):
    """Decode base64 data s and return the raw decoded data.

    Raises SyntaxError if the decoder raises binascii.Error or
    binascii.Incomplete (i.e. the input is malformed).
    """
    # base64.decodestring was removed in Python 3.9; decodebytes replaces it
    # on Python 3. Fall back to decodestring on Python 2.
    decode = getattr(base64, "decodebytes", None) or base64.decodestring
    try:
        return decode(s)
    except (binascii.Error, binascii.Incomplete) as e:
        raise SyntaxError(e)
118
def stringToBase64(s):
    """Return the base64 encoding of s on a single line (no newlines)."""
    # b64encode never inserts line breaks, so this equals the old
    # encodestring(s).replace("\n", "") output. encodestring was removed in
    # Python 3.9, and on Python 3 it returns bytes, which cannot be
    # .replace()d with str arguments.
    return base64.b64encode(s)
121
def mpiToNumber(mpi):
    """Convert an openssl-format bignum (MPI) string into an integer.

    The first 4 bytes of mpi (the length header) are skipped. The rest is
    read as a big-endian magnitude.

    Raises AssertionError if the magnitude's high bit is set, i.e. the MPI
    would encode a negative number, which is not supported here.
    """
    bytes = stringToBytes(mpi[4:])
    # Inspect the converted bytes rather than ord(mpi[4]). This works whether
    # mpi is a str or a bytes object, and it accepts an MPI with an empty
    # payload, which ord(mpi[4]) would reject with IndexError.
    if bytes and (bytes[0] & 0x80) != 0:
        raise AssertionError()
    return bytesToNumber(bytes)
def numberToMPI(n):
    """Encode the integer n as an openssl-format bignum (MPI) string.

    The result is a 4-byte big-endian length header followed by the
    big-endian bytes of n. When numBits(n) is a multiple of 8, the top
    bit of the magnitude would be set, so an extra zero byte is inserted.
    This keeps the value from being read back as negative.
    """
    magnitude = numberToBytes(n)
    pad = 1 if (numBits(n) & 0x7) == 0 else 0
    length = numBytes(n) + pad
    encoded = createByteArrayZeros(4 + pad) + magnitude
    # Write the payload length big-endian into the 4 header bytes.
    for position, shift in enumerate((24, 16, 8, 0)):
        encoded[position] = (length >> shift) & 0xFF
    return bytesToString(encoded)
142 143 144 145 # ************************************************************************** 146 # Misc. Utility Functions 147 # ************************************************************************** 148
def numBytes(n):
    """Return how many bytes are needed to hold numBits(n) bits (0 for n == 0)."""
    if n == 0:
        return 0
    # Integer ceiling of numBits(n) / 8, without going through a float.
    return (numBits(n) + 7) // 8
154 155 # ************************************************************************** 156 # Big Number Math 157 # ************************************************************************** 158
def getRandomNumber(low, high):
    """Return a uniformly random integer n with low <= n < high.

    Random data comes from getRandomBytes. Raises AssertionError if
    low >= high.
    """
    if low >= high:
        raise AssertionError()
    # Draw an offset in [0, span) and shift it by low. Rejection-sampling
    # against [0, 2**numBits(high)) instead needs a huge number of tries
    # when low is close to high. Drawing only numBits(span) bits keeps the
    # acceptance rate at or above 1/2 (assuming numBits returns the bit
    # length).
    span = high - low
    howManyBits = numBits(span)
    howManyBytes = numBytes(span)
    lastBits = howManyBits % 8
    while True:
        randomBytes = getRandomBytes(howManyBytes)
        if lastBits:
            # Mask the leading byte down to the bits actually needed.
            randomBytes[0] = randomBytes[0] % (1 << lastBits)
        offset = bytesToNumber(randomBytes)
        if offset < span:
            return low + offset
172
def gcd(a,b):
    """Return the greatest common divisor of a and b (Euclid's algorithm).

    The larger argument is placed first before iterating.
    """
    big, small = (a, b) if a >= b else (b, a)
    while small:
        big, small = small, big % small
    return big
178
def lcm(a, b):
    """Return the least common multiple of a and b (0 if either is 0)."""
    # gcd(0, 0) is 0, so dividing by it would raise ZeroDivisionError. The
    # lcm of anything with 0 is 0.
    if a == 0 or b == 0:
        return 0
    return (a * b) // gcd(a, b)
#Returns inverse of a mod b, zero if none
#Uses Extended Euclidean Algorithm
def invMod(a, b):
    """Return the inverse of a modulo b, or 0 if none exists.

    Uses the extended Euclidean algorithm. a is reduced modulo b first.
    Without that step, a negative a can finish with d == -1 instead of 1
    and wrongly report that no inverse exists.
    """
    c, d = a % b, b
    uc, ud = 1, 0
    while c != 0:
        q = d // c
        c, d = d - (q * c), c
        uc, ud = ud - (q * uc), uc
    # d now holds gcd(a, b); an inverse exists only when it is 1.
    if d == 1:
        return ud % b
    return 0
if gmpyLoaded:
    def powMod(base, power, modulus):
        """Return base**power mod modulus, computed with gmpy mpz values."""
        base = gmpy.mpz(base)
        power = gmpy.mpz(power)
        modulus = gmpy.mpz(modulus)
        result = pow(base, power, modulus)
        # Convert back to a built-in integer with int(), since long() does
        # not exist on Python 3. NOTE(review): on Python 2 this assumes
        # gmpy's mpz __int__ promotes to long for large values - confirm.
        return int(result)

else:
    def powMod(base, power, modulus):
        """Return base**power mod modulus.

        A negative power is computed as invMod(base**(-power) mod modulus,
        modulus). invMod returns 0 when no inverse exists.
        """
        if power < 0:
            result = pow(base, power*-1, modulus)
            result = invMod(result, modulus)
            return result
        else:
            return pow(base, power, modulus)
#Pre-calculate a sieve of the 168 primes < 1000:
def makeSieve(n):
    """Return the list of primes strictly less than n (Sieve of Eratosthenes)."""
    # list() because range objects are immutable on Python 3.
    sieve = list(range(n))
    # The bound must include floor(sqrt(n)) itself. Stopping one short
    # leaves squares of primes (e.g. 31*31 = 961 for n = 1000) in the output.
    for count in range(2, int(math.sqrt(n)) + 1):
        if sieve[count] == 0:
            continue
        # Smaller multiples were already crossed out by smaller primes.
        for x in range(count * count, n, count):
            sieve[x] = 0
    return [x for x in sieve[2:] if x]

sieve = makeSieve(1000)
def isPrime(n, iterations=5, display=False):
    """Return True if n is probably prime, False if it is definitely not.

    n is first trial-divided by the primes in the module-level sieve. If it
    survives, `iterations` rounds of Rabin-Miller follow (per Ferguson &
    Schneier). iterations=0 means trial division only. If display is true,
    a "*" is written to stdout when the Rabin-Miller stage starts.
    """
    # 0, 1 and negatives are not prime. Without this guard the trial-division
    # loop would return True for them, since the first sieve entry 2 >= n.
    if n < 2:
        return False
    #Trial division with sieve
    for x in sieve:
        if x >= n:
            return True
        if n % x == 0:
            return False
    #Passed trial division, proceed to Rabin-Miller
    if display:
        import sys
        sys.stdout.write("* ")
    #Compute s, t with n-1 == s * 2**t and s odd
    s, t = n-1, 0
    while s % 2 == 0:
        s, t = s//2, t+1
    for count in range(iterations):
        # Base 2 for the first round (a speedup, per HAC), then a fresh
        # random base each round. The base is chosen here, at the top of the
        # loop, so the `continue` below cannot skip drawing a new base. That
        # skip would leave every round testing base 2.
        if count == 0:
            a = 2
        else:
            a = getRandomNumber(2, n)
        v = powMod(a, s, n)
        if v == 1:
            continue
        i = 0
        while v != n-1:
            if i == t-1:
                return False
            v, i = powMod(v, 2, n), i+1
    return True
254
def getRandomPrime(bits, display=False):
    """Return a random probable prime of exactly `bits` bits whose two most
    significant bits are both set.

    Raises AssertionError if bits < 10. If display is true, a progress dot
    is written to stdout for each candidate.
    """
    import sys
    if bits < 10:
        raise AssertionError()
    #The 1.5 ensures the 2 MSBs are set
    #Thus, when used for p,q in RSA, n will have its MSB set
    #
    #Since 30 is lcm(2,3,5), we'll set our test numbers to
    #29 % 30 and keep them there
    low = ((2 ** (bits-1)) * 3) // 2
    high = 2 ** bits - 30
    p = getRandomNumber(low, high)
    p += 29 - (p % 30)
    while True:
        if display:
            sys.stdout.write(". ")
        p += 30
        if p >= high:
            p = getRandomNumber(low, high)
            p += 29 - (p % 30)
        if isPrime(p, display=display):
            return p
#Unused at the moment...
def getRandomSafePrime(bits, display=False):
    """Return a random probable safe prime p = 2*q + 1 (q also probably
    prime) of exactly `bits` bits with its two most significant bits set.

    Raises AssertionError if bits < 10. If display is true, a progress dot
    is written to stdout for each candidate q.
    """
    import sys
    if bits < 10:
        raise AssertionError()
    #The 1.5 ensures the 2 MSBs of q are set, so p = 2*q + 1 is a
    #`bits`-bit number that also has its 2 MSBs set
    #
    #Since 30 is lcm(2,3,5), we'll set our test numbers to
    #29 % 30 and keep them there
    low = (2 ** (bits-2)) * 3//2
    high = (2 ** (bits-1)) - 30
    q = getRandomNumber(low, high)
    q += 29 - (q % 30)
    while True:
        if display:
            sys.stdout.write(". ")
        q += 30
        if q >= high:
            q = getRandomNumber(low, high)
            q += 29 - (q % 30)
        #Ideas from Tom Wu's SRP code:
        #cheap trial division on q first (0 Rabin-Miller rounds), then the
        #full test on p, and only then the full test on q
        if isPrime(q, 0, display=display):
            p = (2 * q) + 1
            if isPrime(p, display=display):
                if isPrime(q, display=display):
                    return p
302