# summary | refs | log | tree | commit | diff | stats
# path: root/src/lib/tlslite/utils/codec.py
# blob: 13022a0b93261992801629c4b828ee19feb5d3c1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
"""Classes for reading/writing binary data (such as TLS records)."""

from compat import *

class Writer:
    """Serialise integers into a pre-sized buffer, most significant byte first.

    A writer built with the default length of zero owns an empty buffer; in
    that mode nothing is stored and only ``index`` advances, so the final
    ``index`` tells the caller how large a real buffer has to be.
    """

    def __init__(self, length=0):
        # A zero length means this is a "trial run" used only to measure size.
        # NOTE(review): createByteArrayZeros comes from compat; presumably it
        # returns a zero-filled mutable byte sequence of `length` items.
        self.index = 0
        self.bytes = createByteArrayZeros(length)

    def add(self, x, length):
        """Store the low `length` bytes of `x` at the cursor and advance it.

        Higher-order bits of `x` that do not fit in `length` bytes are
        dropped. With an empty buffer only the cursor moves.
        """
        start = self.index
        if self.bytes:
            # Fill from the last byte backwards so the least significant
            # byte of x lands at the highest offset.
            for pos in range(start + length - 1, start - 1, -1):
                self.bytes[pos] = x & 0xFF
                x >>= 8
        self.index = start + length

    def addFixSeq(self, seq, length):
        """Write every element of `seq` as a `length`-byte integer."""
        if not self.bytes:
            # Measuring only: account for the space the items would take.
            self.index += len(seq)*length
            return
        for item in seq:
            self.add(item, length)

    def addVarSeq(self, seq, length, lengthLength):
        """Write `seq` preceded by its total byte size.

        The prefix occupies `lengthLength` bytes and holds
        ``len(seq) * length``; each element then takes `length` bytes.
        """
        if not self.bytes:
            # Measuring only: prefix plus payload.
            self.index += lengthLength + (len(seq)*length)
            return
        self.add(len(seq)*length, lengthLength)
        for item in seq:
            self.add(item, length)


class Parser:
    """Sequential reader of big-endian integers and byte runs from a buffer.

    ``bytes`` is the indexable buffer being read and ``index`` is the offset
    of the next unread byte. Every read that would run past the end of the
    buffer raises SyntaxError, which is this class's signal for malformed
    input.
    """

    def __init__(self, bytes):
        # The parameter name shadows the builtin; it is kept unchanged so
        # existing keyword callers keep working.
        self.bytes = bytes
        self.index = 0

    def get(self, length):
        """Read `length` bytes as a single big-endian unsigned integer.

        Raises SyntaxError if fewer than `length` bytes remain.
        """
        if self.index + length > len(self.bytes):
            raise SyntaxError()
        x = 0
        for count in range(length):
            x <<= 8
            x |= self.bytes[self.index]
            self.index += 1
        return x

    def getFixBytes(self, lengthBytes):
        """Return the next `lengthBytes` bytes as a slice of the buffer.

        Raises SyntaxError if fewer than `lengthBytes` bytes remain. Without
        this check the slice would silently come back short and `index`
        would move past the end of the buffer.
        """
        end = self.index + lengthBytes
        if end > len(self.bytes):
            raise SyntaxError()
        data = self.bytes[self.index:end]
        self.index = end
        return data

    def getVarBytes(self, lengthLength):
        """Read a `lengthLength`-byte size prefix, then that many bytes.

        Raises SyntaxError if the prefix or the payload is truncated.
        """
        lengthBytes = self.get(lengthLength)
        return self.getFixBytes(lengthBytes)

    def getFixList(self, length, lengthList):
        """Read `lengthList` integers of `length` bytes each into a list.

        Raises SyntaxError if the buffer does not hold that many bytes.
        """
        # Check the total up front so an oversized count taken from the
        # input is rejected before any list is built.
        if self.index + length * lengthList > len(self.bytes):
            raise SyntaxError()
        return [self.get(length) for _ in range(lengthList)]

    def getVarList(self, length, lengthLength):
        """Read a size-prefixed list of `length`-byte integers.

        The `lengthLength`-byte prefix gives the payload size in bytes.
        Raises SyntaxError if that size is not a multiple of `length` or if
        the payload is truncated.
        """
        lengthList = self.get(lengthLength)
        if lengthList % length != 0:
            raise SyntaxError()
        # Floor division keeps the element count an exact integer.
        return self.getFixList(length, lengthList // length)

    def startLengthCheck(self, lengthLength):
        """Read a `lengthLength`-byte size and start checking against it."""
        self.lengthCheck = self.get(lengthLength)
        self.indexCheck = self.index

    def setLengthCheck(self, length):
        """Start a length check from the current position for `length` bytes."""
        self.lengthCheck = length
        self.indexCheck = self.index

    def stopLengthCheck(self):
        """Raise SyntaxError unless exactly the checked length was consumed."""
        if (self.index - self.indexCheck) != self.lengthCheck:
            raise SyntaxError()

    def atLengthCheck(self):
        """Report whether the checked region has been fully consumed.

        Returns False while bytes remain and True when exactly the checked
        length has been read; raises SyntaxError if more was read.
        """
        consumed = self.index - self.indexCheck
        if consumed > self.lengthCheck:
            raise SyntaxError()
        return consumed == self.lengthCheck