diff options
Diffstat (limited to 'src/lib/tlslite/utils/cryptomath.py')
-rwxr-xr-x | src/lib/tlslite/utils/cryptomath.py | 400 |
1 files changed, 400 insertions, 0 deletions
diff --git a/src/lib/tlslite/utils/cryptomath.py b/src/lib/tlslite/utils/cryptomath.py new file mode 100755 index 000000000..51d6dff7c --- /dev/null +++ b/src/lib/tlslite/utils/cryptomath.py @@ -0,0 +1,400 @@ +"""cryptomath module + +This module has basic math/crypto code.""" + +import os +import math +import base64 +import binascii +import sha + +from compat import * + + +# ************************************************************************** +# Load Optional Modules +# ************************************************************************** + +# Try to load M2Crypto/OpenSSL +try: + from M2Crypto import m2 + m2cryptoLoaded = True + +except ImportError: + m2cryptoLoaded = False + + +# Try to load cryptlib +try: + import cryptlib_py + try: + cryptlib_py.cryptInit() + except cryptlib_py.CryptException, e: + #If tlslite and cryptoIDlib are both present, + #they might each try to re-initialize this, + #so we're tolerant of that. + if e[0] != cryptlib_py.CRYPT_ERROR_INITED: + raise + cryptlibpyLoaded = True + +except ImportError: + cryptlibpyLoaded = False + +#Try to load GMPY +try: + import gmpy + gmpyLoaded = True +except ImportError: + gmpyLoaded = False + +#Try to load pycrypto +try: + import Crypto.Cipher.AES + pycryptoLoaded = True +except ImportError: + pycryptoLoaded = False + + +# ************************************************************************** +# PRNG Functions +# ************************************************************************** + +# Get os.urandom PRNG +try: + os.urandom(1) + def getRandomBytes(howMany): + return stringToBytes(os.urandom(howMany)) + prngName = "os.urandom" + +except: + # Else get cryptlib PRNG + if cryptlibpyLoaded: + def getRandomBytes(howMany): + randomKey = cryptlib_py.cryptCreateContext(cryptlib_py.CRYPT_UNUSED, + cryptlib_py.CRYPT_ALGO_AES) + cryptlib_py.cryptSetAttribute(randomKey, + cryptlib_py.CRYPT_CTXINFO_MODE, + cryptlib_py.CRYPT_MODE_OFB) + cryptlib_py.cryptGenerateKey(randomKey) + bytes = createByteArrayZeros(howMany) + cryptlib_py.cryptEncrypt(randomKey, bytes) + return bytes + prngName = "cryptlib" + + else: + #Else get UNIX /dev/urandom PRNG + try: + devRandomFile = open("/dev/urandom", "rb") + def getRandomBytes(howMany): + return stringToBytes(devRandomFile.read(howMany)) + prngName = "/dev/urandom" + except IOError: + #Else get Win32 CryptoAPI PRNG + try: + import win32prng + def getRandomBytes(howMany): + s = win32prng.getRandomBytes(howMany) + if len(s) != howMany: + raise AssertionError() + return stringToBytes(s) + prngName ="CryptoAPI" + except ImportError: + #Else no PRNG :-( + def getRandomBytes(howMany): + raise NotImplementedError("No Random Number Generator "\ + "available.") + prngName = "None" + +# ************************************************************************** +# Converter Functions +# ************************************************************************** + +def bytesToNumber(bytes): + total = 0L + multiplier = 1L + for count in range(len(bytes)-1, -1, -1): + byte = bytes[count] + total += multiplier * byte + multiplier *= 256 + return total + +def numberToBytes(n): + howManyBytes = numBytes(n) + bytes = createByteArrayZeros(howManyBytes) + for count in range(howManyBytes-1, -1, -1): + bytes[count] = int(n % 256) + n >>= 8 + return bytes + +def bytesToBase64(bytes): + s = bytesToString(bytes) + return stringToBase64(s) + +def base64ToBytes(s): + s = base64ToString(s) + return stringToBytes(s) + +def numberToBase64(n): + bytes = numberToBytes(n) + return bytesToBase64(bytes) + +def base64ToNumber(s): + bytes = base64ToBytes(s) + return bytesToNumber(bytes) + +def stringToNumber(s): + bytes = stringToBytes(s) + return bytesToNumber(bytes) + +def numberToString(s): + bytes = numberToBytes(s) + return bytesToString(bytes) + +def base64ToString(s): + try: + return base64.decodestring(s) + except binascii.Error, e: + raise SyntaxError(e) + except binascii.Incomplete, e: + raise SyntaxError(e) + +def stringToBase64(s): + return base64.encodestring(s).replace("\n", "") + +def mpiToNumber(mpi): #mpi is an openssl-format bignum string + if (ord(mpi[4]) & 0x80) !=0: #Make sure this is a positive number + raise AssertionError() + bytes = stringToBytes(mpi[4:]) + return bytesToNumber(bytes) + +def numberToMPI(n): + bytes = numberToBytes(n) + ext = 0 + #If the high-order bit is going to be set, + #add an extra byte of zeros + if (numBits(n) & 0x7)==0: + ext = 1 + length = numBytes(n) + ext + bytes = concatArrays(createByteArrayZeros(4+ext), bytes) + bytes[0] = (length >> 24) & 0xFF + bytes[1] = (length >> 16) & 0xFF + bytes[2] = (length >> 8) & 0xFF + bytes[3] = length & 0xFF + return bytesToString(bytes) + + + +# ************************************************************************** +# Misc. Utility Functions +# ************************************************************************** + +def numBytes(n): + if n==0: + return 0 + bits = numBits(n) + return int(math.ceil(bits / 8.0)) + +def hashAndBase64(s): + return stringToBase64(sha.sha(s).digest()) + +def getBase64Nonce(numChars=22): #defaults to an 132 bit nonce + bytes = getRandomBytes(numChars) + bytesStr = "".join([chr(b) for b in bytes]) + return stringToBase64(bytesStr)[:numChars] + + +# ************************************************************************** +# Big Number Math +# ************************************************************************** + +def getRandomNumber(low, high): + if low >= high: + raise AssertionError() + howManyBits = numBits(high) + howManyBytes = numBytes(high) + lastBits = howManyBits % 8 + while 1: + bytes = getRandomBytes(howManyBytes) + if lastBits: + bytes[0] = bytes[0] % (1 << lastBits) + n = bytesToNumber(bytes) + if n >= low and n < high: + return n + +def gcd(a,b): + a, b = max(a,b), min(a,b) + while b: + a, b = b, a % b + return a + +def lcm(a, b): + #This will break when python division changes, but we can't use // cause + #of Jython + return (a * b) / gcd(a, b) + +#Returns inverse of a mod b, zero if none +#Uses Extended Euclidean Algorithm +def invMod(a, b): + c, d = a, b + uc, ud = 1, 0 + while c != 0: + #This will break when python division changes, but we can't use // + #cause of Jython + q = d / c + c, d = d-(q*c), c + uc, ud = ud - (q * uc), uc + if d == 1: + return ud % b + return 0 + + +if gmpyLoaded: + def powMod(base, power, modulus): + base = gmpy.mpz(base) + power = gmpy.mpz(power) + modulus = gmpy.mpz(modulus) + result = pow(base, power, modulus) + return long(result) + +else: + #Copied from Bryan G. Olson's post to comp.lang.python + #Does left-to-right instead of pow()'s right-to-left, + #thus about 30% faster than the python built-in with small bases + def powMod(base, power, modulus): + nBitScan = 5 + + """ Return base**power mod modulus, using multi bit scanning + with nBitScan bits at a time.""" + + #TREV - Added support for negative exponents + negativeResult = False + if (power < 0): + power *= -1 + negativeResult = True + + exp2 = 2**nBitScan + mask = exp2 - 1 + + # Break power into a list of digits of nBitScan bits. + # The list is recursive so easy to read in reverse direction. + nibbles = None + while power: + nibbles = int(power & mask), nibbles + power = power >> nBitScan + + # Make a table of powers of base up to 2**nBitScan - 1 + lowPowers = [1] + for i in xrange(1, exp2): + lowPowers.append((lowPowers[i-1] * base) % modulus) + + # To exponentiate by the first nibble, look it up in the table + nib, nibbles = nibbles + prod = lowPowers[nib] + + # For the rest, square nBitScan times, then multiply by + # base^nibble + while nibbles: + nib, nibbles = nibbles + for i in xrange(nBitScan): + prod = (prod * prod) % modulus + if nib: prod = (prod * lowPowers[nib]) % modulus + + #TREV - Added support for negative exponents + if negativeResult: + prodInv = invMod(prod, modulus) + #Check to make sure the inverse is correct + if (prod * prodInv) % modulus != 1: + raise AssertionError() + return prodInv + return prod + + +#Pre-calculate a sieve of the ~100 primes < 1000: +def makeSieve(n): + sieve = range(n) + for count in range(2, int(math.sqrt(n))): + if sieve[count] == 0: + continue + x = sieve[count] * 2 + while x < len(sieve): + sieve[x] = 0 + x += sieve[count] + sieve = [x for x in sieve[2:] if x] + return sieve + +sieve = makeSieve(1000) + +def isPrime(n, iterations=5, display=False): + #Trial division with sieve + for x in sieve: + if x >= n: return True + if n % x == 0: return False + #Passed trial division, proceed to Rabin-Miller + #Rabin-Miller implemented per Ferguson & Schneier + #Compute s, t for Rabin-Miller + if display: print "*", + s, t = n-1, 0 + while s % 2 == 0: + s, t = s/2, t+1 + #Repeat Rabin-Miller x times + a = 2 #Use 2 as a base for first iteration speedup, per HAC + for count in range(iterations): + v = powMod(a, s, n) + if v==1: + continue + i = 0 + while v != n-1: + if i == t-1: + return False + else: + v, i = powMod(v, 2, n), i+1 + a = getRandomNumber(2, n) + return True + +def getRandomPrime(bits, display=False): + if bits < 10: + raise AssertionError() + #The 1.5 ensures the 2 MSBs are set + #Thus, when used for p,q in RSA, n will have its MSB set + # + #Since 30 is lcm(2,3,5), we'll set our test numbers to + #29 % 30 and keep them there + low = (2L ** (bits-1)) * 3/2 + high = 2L ** bits - 30 + p = getRandomNumber(low, high) + p += 29 - (p % 30) + while 1: + if display: print ".", + p += 30 + if p >= high: + p = getRandomNumber(low, high) + p += 29 - (p % 30) + if isPrime(p, display=display): + return p + +#Unused at the moment... +def getRandomSafePrime(bits, display=False): + if bits < 10: + raise AssertionError() + #The 1.5 ensures the 2 MSBs are set + #Thus, when used for p,q in RSA, n will have its MSB set + # + #Since 30 is lcm(2,3,5), we'll set our test numbers to + #29 % 30 and keep them there + low = (2 ** (bits-2)) * 3/2 + high = (2 ** (bits-1)) - 30 + q = getRandomNumber(low, high) + q += 29 - (q % 30) + while 1: + if display: print ".", + q += 30 + if (q >= high): + q = getRandomNumber(low, high) + q += 29 - (q % 30) + #Ideas from Tom Wu's SRP code + #Do trial division on p and q before Rabin-Miller + if isPrime(q, 0, display=display): + p = (2 * q) + 1 + if isPrime(p, display=display): + if isPrime(q, display=display): + return p |