summaryrefslogtreecommitdiffstats
path: root/src/lib/tlslite/messages.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/lib/tlslite/messages.py')
-rwxr-xr-xsrc/lib/tlslite/messages.py561
1 files changed, 561 insertions, 0 deletions
diff --git a/src/lib/tlslite/messages.py b/src/lib/tlslite/messages.py
new file mode 100755
index 000000000..afccc793a
--- /dev/null
+++ b/src/lib/tlslite/messages.py
@@ -0,0 +1,561 @@
+"""Classes representing TLS messages."""
+
+from utils.compat import *
+from utils.cryptomath import *
+from errors import *
+from utils.codec import *
+from constants import *
+from X509 import X509
+from X509CertChain import X509CertChain
+
+import sha
+import md5
+
+class RecordHeader3:
+ def __init__(self):
+ self.type = 0
+ self.version = (0,0)
+ self.length = 0
+ self.ssl2 = False
+
+ def create(self, version, type, length):
+ self.type = type
+ self.version = version
+ self.length = length
+ return self
+
+ def write(self):
+ w = Writer(5)
+ w.add(self.type, 1)
+ w.add(self.version[0], 1)
+ w.add(self.version[1], 1)
+ w.add(self.length, 2)
+ return w.bytes
+
+ def parse(self, p):
+ self.type = p.get(1)
+ self.version = (p.get(1), p.get(1))
+ self.length = p.get(2)
+ self.ssl2 = False
+ return self
+
+class RecordHeader2:
+ def __init__(self):
+ self.type = 0
+ self.version = (0,0)
+ self.length = 0
+ self.ssl2 = True
+
+ def parse(self, p):
+ if p.get(1)!=128:
+ raise SyntaxError()
+ self.type = ContentType.handshake
+ self.version = (2,0)
+ #We don't support 2-byte-length-headers; could be a problem
+ self.length = p.get(1)
+ return self
+
+
+class Msg:
+ def preWrite(self, trial):
+ if trial:
+ w = Writer()
+ else:
+ length = self.write(True)
+ w = Writer(length)
+ return w
+
+ def postWrite(self, w, trial):
+ if trial:
+ return w.index
+ else:
+ return w.bytes
+
+class Alert(Msg):
+ def __init__(self):
+ self.contentType = ContentType.alert
+ self.level = 0
+ self.description = 0
+
+ def create(self, description, level=AlertLevel.fatal):
+ self.level = level
+ self.description = description
+ return self
+
+ def parse(self, p):
+ p.setLengthCheck(2)
+ self.level = p.get(1)
+ self.description = p.get(1)
+ p.stopLengthCheck()
+ return self
+
+ def write(self):
+ w = Writer(2)
+ w.add(self.level, 1)
+ w.add(self.description, 1)
+ return w.bytes
+
+
+class HandshakeMsg(Msg):
+ def preWrite(self, handshakeType, trial):
+ if trial:
+ w = Writer()
+ w.add(handshakeType, 1)
+ w.add(0, 3)
+ else:
+ length = self.write(True)
+ w = Writer(length)
+ w.add(handshakeType, 1)
+ w.add(length-4, 3)
+ return w
+
+
+class ClientHello(HandshakeMsg):
+ def __init__(self, ssl2=False):
+ self.contentType = ContentType.handshake
+ self.ssl2 = ssl2
+ self.client_version = (0,0)
+ self.random = createByteArrayZeros(32)
+ self.session_id = createByteArraySequence([])
+ self.cipher_suites = [] # a list of 16-bit values
+ self.certificate_types = [CertificateType.x509]
+ self.compression_methods = [] # a list of 8-bit values
+ self.srp_username = None # a string
+
+ def create(self, version, random, session_id, cipher_suites,
+ certificate_types=None, srp_username=None):
+ self.client_version = version
+ self.random = random
+ self.session_id = session_id
+ self.cipher_suites = cipher_suites
+ self.certificate_types = certificate_types
+ self.compression_methods = [0]
+ self.srp_username = srp_username
+ return self
+
+ def parse(self, p):
+ if self.ssl2:
+ self.client_version = (p.get(1), p.get(1))
+ cipherSpecsLength = p.get(2)
+ sessionIDLength = p.get(2)
+ randomLength = p.get(2)
+ self.cipher_suites = p.getFixList(3, int(cipherSpecsLength/3))
+ self.session_id = p.getFixBytes(sessionIDLength)
+ self.random = p.getFixBytes(randomLength)
+ if len(self.random) < 32:
+ zeroBytes = 32-len(self.random)
+ self.random = createByteArrayZeros(zeroBytes) + self.random
+ self.compression_methods = [0]#Fake this value
+
+ #We're not doing a stopLengthCheck() for SSLv2, oh well..
+ else:
+ p.startLengthCheck(3)
+ self.client_version = (p.get(1), p.get(1))
+ self.random = p.getFixBytes(32)
+ self.session_id = p.getVarBytes(1)
+ self.cipher_suites = p.getVarList(2, 2)
+ self.compression_methods = p.getVarList(1, 1)
+ if not p.atLengthCheck():
+ totalExtLength = p.get(2)
+ soFar = 0
+ while soFar != totalExtLength:
+ extType = p.get(2)
+ extLength = p.get(2)
+ if extType == 6:
+ self.srp_username = bytesToString(p.getVarBytes(1))
+ elif extType == 7:
+ self.certificate_types = p.getVarList(1, 1)
+ else:
+ p.getFixBytes(extLength)
+ soFar += 4 + extLength
+ p.stopLengthCheck()
+ return self
+
+ def write(self, trial=False):
+ w = HandshakeMsg.preWrite(self, HandshakeType.client_hello, trial)
+ w.add(self.client_version[0], 1)
+ w.add(self.client_version[1], 1)
+ w.addFixSeq(self.random, 1)
+ w.addVarSeq(self.session_id, 1, 1)
+ w.addVarSeq(self.cipher_suites, 2, 2)
+ w.addVarSeq(self.compression_methods, 1, 1)
+
+ extLength = 0
+ if self.certificate_types and self.certificate_types != \
+ [CertificateType.x509]:
+ extLength += 5 + len(self.certificate_types)
+ if self.srp_username:
+ extLength += 5 + len(self.srp_username)
+ if extLength > 0:
+ w.add(extLength, 2)
+
+ if self.certificate_types and self.certificate_types != \
+ [CertificateType.x509]:
+ w.add(7, 2)
+ w.add(len(self.certificate_types)+1, 2)
+ w.addVarSeq(self.certificate_types, 1, 1)
+ if self.srp_username:
+ w.add(6, 2)
+ w.add(len(self.srp_username)+1, 2)
+ w.addVarSeq(stringToBytes(self.srp_username), 1, 1)
+
+ return HandshakeMsg.postWrite(self, w, trial)
+
+
+class ServerHello(HandshakeMsg):
+ def __init__(self):
+ self.contentType = ContentType.handshake
+ self.server_version = (0,0)
+ self.random = createByteArrayZeros(32)
+ self.session_id = createByteArraySequence([])
+ self.cipher_suite = 0
+ self.certificate_type = CertificateType.x509
+ self.compression_method = 0
+
+ def create(self, version, random, session_id, cipher_suite,
+ certificate_type):
+ self.server_version = version
+ self.random = random
+ self.session_id = session_id
+ self.cipher_suite = cipher_suite
+ self.certificate_type = certificate_type
+ self.compression_method = 0
+ return self
+
+ def parse(self, p):
+ p.startLengthCheck(3)
+ self.server_version = (p.get(1), p.get(1))
+ self.random = p.getFixBytes(32)
+ self.session_id = p.getVarBytes(1)
+ self.cipher_suite = p.get(2)
+ self.compression_method = p.get(1)
+ if not p.atLengthCheck():
+ totalExtLength = p.get(2)
+ soFar = 0
+ while soFar != totalExtLength:
+ extType = p.get(2)
+ extLength = p.get(2)
+ if extType == 7:
+ self.certificate_type = p.get(1)
+ else:
+ p.getFixBytes(extLength)
+ soFar += 4 + extLength
+ p.stopLengthCheck()
+ return self
+
+ def write(self, trial=False):
+ w = HandshakeMsg.preWrite(self, HandshakeType.server_hello, trial)
+ w.add(self.server_version[0], 1)
+ w.add(self.server_version[1], 1)
+ w.addFixSeq(self.random, 1)
+ w.addVarSeq(self.session_id, 1, 1)
+ w.add(self.cipher_suite, 2)
+ w.add(self.compression_method, 1)
+
+ extLength = 0
+ if self.certificate_type and self.certificate_type != \
+ CertificateType.x509:
+ extLength += 5
+
+ if extLength != 0:
+ w.add(extLength, 2)
+
+ if self.certificate_type and self.certificate_type != \
+ CertificateType.x509:
+ w.add(7, 2)
+ w.add(1, 2)
+ w.add(self.certificate_type, 1)
+
+ return HandshakeMsg.postWrite(self, w, trial)
+
+class Certificate(HandshakeMsg):
+ def __init__(self, certificateType):
+ self.certificateType = certificateType
+ self.contentType = ContentType.handshake
+ self.certChain = None
+
+ def create(self, certChain):
+ self.certChain = certChain
+ return self
+
+ def parse(self, p):
+ p.startLengthCheck(3)
+ if self.certificateType == CertificateType.x509:
+ chainLength = p.get(3)
+ index = 0
+ certificate_list = []
+ while index != chainLength:
+ certBytes = p.getVarBytes(3)
+ x509 = X509()
+ x509.parseBinary(certBytes)
+ certificate_list.append(x509)
+ index += len(certBytes)+3
+ if certificate_list:
+ self.certChain = X509CertChain(certificate_list)
+ elif self.certificateType == CertificateType.cryptoID:
+ s = bytesToString(p.getVarBytes(2))
+ if s:
+ try:
+ import cryptoIDlib.CertChain
+ except ImportError:
+ raise SyntaxError(\
+ "cryptoID cert chain received, cryptoIDlib not present")
+ self.certChain = cryptoIDlib.CertChain.CertChain().parse(s)
+ else:
+ raise AssertionError()
+
+ p.stopLengthCheck()
+ return self
+
+ def write(self, trial=False):
+ w = HandshakeMsg.preWrite(self, HandshakeType.certificate, trial)
+ if self.certificateType == CertificateType.x509:
+ chainLength = 0
+ if self.certChain:
+ certificate_list = self.certChain.x509List
+ else:
+ certificate_list = []
+ #determine length
+ for cert in certificate_list:
+ bytes = cert.writeBytes()
+ chainLength += len(bytes)+3
+ #add bytes
+ w.add(chainLength, 3)
+ for cert in certificate_list:
+ bytes = cert.writeBytes()
+ w.addVarSeq(bytes, 1, 3)
+ elif self.certificateType == CertificateType.cryptoID:
+ if self.certChain:
+ bytes = stringToBytes(self.certChain.write())
+ else:
+ bytes = createByteArraySequence([])
+ w.addVarSeq(bytes, 1, 2)
+ else:
+ raise AssertionError()
+ return HandshakeMsg.postWrite(self, w, trial)
+
+class CertificateRequest(HandshakeMsg):
+ def __init__(self):
+ self.contentType = ContentType.handshake
+ self.certificate_types = []
+ #treat as opaque bytes for now
+ self.certificate_authorities = createByteArraySequence([])
+
+ def create(self, certificate_types, certificate_authorities):
+ self.certificate_types = certificate_types
+ self.certificate_authorities = certificate_authorities
+ return self
+
+ def parse(self, p):
+ p.startLengthCheck(3)
+ self.certificate_types = p.getVarList(1, 1)
+ self.certificate_authorities = p.getVarBytes(2)
+ p.stopLengthCheck()
+ return self
+
+ def write(self, trial=False):
+ w = HandshakeMsg.preWrite(self, HandshakeType.certificate_request,
+ trial)
+ w.addVarSeq(self.certificate_types, 1, 1)
+ w.addVarSeq(self.certificate_authorities, 1, 2)
+ return HandshakeMsg.postWrite(self, w, trial)
+
+class ServerKeyExchange(HandshakeMsg):
+ def __init__(self, cipherSuite):
+ self.cipherSuite = cipherSuite
+ self.contentType = ContentType.handshake
+ self.srp_N = 0L
+ self.srp_g = 0L
+ self.srp_s = createByteArraySequence([])
+ self.srp_B = 0L
+ self.signature = createByteArraySequence([])
+
+ def createSRP(self, srp_N, srp_g, srp_s, srp_B):
+ self.srp_N = srp_N
+ self.srp_g = srp_g
+ self.srp_s = srp_s
+ self.srp_B = srp_B
+ return self
+
+ def parse(self, p):
+ p.startLengthCheck(3)
+ self.srp_N = bytesToNumber(p.getVarBytes(2))
+ self.srp_g = bytesToNumber(p.getVarBytes(2))
+ self.srp_s = p.getVarBytes(1)
+ self.srp_B = bytesToNumber(p.getVarBytes(2))
+ if self.cipherSuite in CipherSuite.srpRsaSuites:
+ self.signature = p.getVarBytes(2)
+ p.stopLengthCheck()
+ return self
+
+ def write(self, trial=False):
+ w = HandshakeMsg.preWrite(self, HandshakeType.server_key_exchange,
+ trial)
+ w.addVarSeq(numberToBytes(self.srp_N), 1, 2)
+ w.addVarSeq(numberToBytes(self.srp_g), 1, 2)
+ w.addVarSeq(self.srp_s, 1, 1)
+ w.addVarSeq(numberToBytes(self.srp_B), 1, 2)
+ if self.cipherSuite in CipherSuite.srpRsaSuites:
+ w.addVarSeq(self.signature, 1, 2)
+ return HandshakeMsg.postWrite(self, w, trial)
+
+ def hash(self, clientRandom, serverRandom):
+ oldCipherSuite = self.cipherSuite
+ self.cipherSuite = None
+ try:
+ bytes = clientRandom + serverRandom + self.write()[4:]
+ s = bytesToString(bytes)
+ return stringToBytes(md5.md5(s).digest() + sha.sha(s).digest())
+ finally:
+ self.cipherSuite = oldCipherSuite
+
+class ServerHelloDone(HandshakeMsg):
+ def __init__(self):
+ self.contentType = ContentType.handshake
+
+ def create(self):
+ return self
+
+ def parse(self, p):
+ p.startLengthCheck(3)
+ p.stopLengthCheck()
+ return self
+
+ def write(self, trial=False):
+ w = HandshakeMsg.preWrite(self, HandshakeType.server_hello_done, trial)
+ return HandshakeMsg.postWrite(self, w, trial)
+
+class ClientKeyExchange(HandshakeMsg):
+ def __init__(self, cipherSuite, version=None):
+ self.cipherSuite = cipherSuite
+ self.version = version
+ self.contentType = ContentType.handshake
+ self.srp_A = 0
+ self.encryptedPreMasterSecret = createByteArraySequence([])
+
+ def createSRP(self, srp_A):
+ self.srp_A = srp_A
+ return self
+
+ def createRSA(self, encryptedPreMasterSecret):
+ self.encryptedPreMasterSecret = encryptedPreMasterSecret
+ return self
+
+ def parse(self, p):
+ p.startLengthCheck(3)
+ if self.cipherSuite in CipherSuite.srpSuites + \
+ CipherSuite.srpRsaSuites:
+ self.srp_A = bytesToNumber(p.getVarBytes(2))
+ elif self.cipherSuite in CipherSuite.rsaSuites:
+ if self.version in ((3,1), (3,2)):
+ self.encryptedPreMasterSecret = p.getVarBytes(2)
+ elif self.version == (3,0):
+ self.encryptedPreMasterSecret = \
+ p.getFixBytes(len(p.bytes)-p.index)
+ else:
+ raise AssertionError()
+ else:
+ raise AssertionError()
+ p.stopLengthCheck()
+ return self
+
+ def write(self, trial=False):
+ w = HandshakeMsg.preWrite(self, HandshakeType.client_key_exchange,
+ trial)
+ if self.cipherSuite in CipherSuite.srpSuites + \
+ CipherSuite.srpRsaSuites:
+ w.addVarSeq(numberToBytes(self.srp_A), 1, 2)
+ elif self.cipherSuite in CipherSuite.rsaSuites:
+ if self.version in ((3,1), (3,2)):
+ w.addVarSeq(self.encryptedPreMasterSecret, 1, 2)
+ elif self.version == (3,0):
+ w.addFixSeq(self.encryptedPreMasterSecret, 1)
+ else:
+ raise AssertionError()
+ else:
+ raise AssertionError()
+ return HandshakeMsg.postWrite(self, w, trial)
+
+class CertificateVerify(HandshakeMsg):
+ def __init__(self):
+ self.contentType = ContentType.handshake
+ self.signature = createByteArraySequence([])
+
+ def create(self, signature):
+ self.signature = signature
+ return self
+
+ def parse(self, p):
+ p.startLengthCheck(3)
+ self.signature = p.getVarBytes(2)
+ p.stopLengthCheck()
+ return self
+
+ def write(self, trial=False):
+ w = HandshakeMsg.preWrite(self, HandshakeType.certificate_verify,
+ trial)
+ w.addVarSeq(self.signature, 1, 2)
+ return HandshakeMsg.postWrite(self, w, trial)
+
+class ChangeCipherSpec(Msg):
+ def __init__(self):
+ self.contentType = ContentType.change_cipher_spec
+ self.type = 1
+
+ def create(self):
+ self.type = 1
+ return self
+
+ def parse(self, p):
+ p.setLengthCheck(1)
+ self.type = p.get(1)
+ p.stopLengthCheck()
+ return self
+
+ def write(self, trial=False):
+ w = Msg.preWrite(self, trial)
+ w.add(self.type,1)
+ return Msg.postWrite(self, w, trial)
+
+
+class Finished(HandshakeMsg):
+ def __init__(self, version):
+ self.contentType = ContentType.handshake
+ self.version = version
+ self.verify_data = createByteArraySequence([])
+
+ def create(self, verify_data):
+ self.verify_data = verify_data
+ return self
+
+ def parse(self, p):
+ p.startLengthCheck(3)
+ if self.version == (3,0):
+ self.verify_data = p.getFixBytes(36)
+ elif self.version in ((3,1), (3,2)):
+ self.verify_data = p.getFixBytes(12)
+ else:
+ raise AssertionError()
+ p.stopLengthCheck()
+ return self
+
+ def write(self, trial=False):
+ w = HandshakeMsg.preWrite(self, HandshakeType.finished, trial)
+ w.addFixSeq(self.verify_data, 1)
+ return HandshakeMsg.postWrite(self, w, trial)
+
+class ApplicationData(Msg):
+ def __init__(self):
+ self.contentType = ContentType.application_data
+ self.bytes = createByteArraySequence([])
+
+ def create(self, bytes):
+ self.bytes = bytes
+ return self
+
+ def parse(self, p):
+ self.bytes = p.bytes
+ return self
+
+ def write(self):
+ return self.bytes \ No newline at end of file