1# Copyright (c) 2003-2018 CORE Security Technologies
2#
# This software is provided under a slightly modified version
4# of the Apache Software License. See the accompanying LICENSE file
5# for more information.
6#
7from __future__ import division
8from __future__ import print_function
9from struct import pack, unpack, calcsize
10from six import b
11
class Structure:
    """ Subclasses can define commonHdr and/or structure.
        Each of them is a tuple of field descriptors, each either (fieldName, format)
        or (fieldName, format, extra), where extra is the class used when unpacking
        ':' fields, or the unpack code of a '_' field.
        [it can't be a dictionary, because order is important]

        format specifies how the data in the field will be converted to/from bytes.

        each field can only contain one value (or an array of values for *)
           i.e. struct.pack('Hl',1,2) is valid, but format specifier 'Hl' is not (you must use 2 different fields)

        format specifiers:
          specifiers from module struct can be used with the same format
          see struct.__doc__ (pack/unpack is finally called)
            x       [padding byte]
            c       [character]
            b       [signed byte]
            B       [unsigned byte]
            h       [signed short]
            H       [unsigned short]
            l       [signed long]
            L       [unsigned long]
            i       [signed integer]
            I       [unsigned integer]
            q       [signed long long (quad)]
            Q       [unsigned long long (quad)]
            s       [string (array of chars), must be preceded with length in format specifier, padded with zeros]
            p       [pascal string (includes byte count), must be preceded with length in format specifier, padded with zeros]
            f       [float]
            d       [double]
            @       [native byte ordering, native size and alignment]
            =       [native byte ordering, standard size, no alignment]
            !       [network byte ordering]
            <       [little endian]
            >       [big endian]

          usual printf like specifiers can be used (if started with %)
          [not recommended, there is no way to unpack this]

            %08x    will output 8 hex digits
            %s      will output a string
            %s\\x00  will output a NUL terminated string
            %d%d    will output 2 decimal digits (against the very same specification of Structure)
            ...

          some additional format specifiers:
            :       just copy the bytes from the field into the output (input may be bytes, bytearray, text, or another Structure) (for unpacking, all what's left is returned)
            z       same as :, but adds a NUL byte at the end (asciiz) (for unpacking the first NUL byte is used as terminator and the result is decoded as latin-1)  [asciiz string]
            u       same as z, but adds two NUL bytes at the end (after padding to an even size with NULs). (for unpacking, the first NUL-NUL pair at an even offset is the terminator) [unicode string]
            w       DCE-RPC/NDR string: '<L' count, '<L' count, 4 NUL bytes, then the data padded with a NUL to an even size
                    (an empty value is packed as 2 NULs); count is the padded data length / 2
            ?-field length of field named 'field', formatted as specified with ? ('?' may be '!H' for example). The input value overrides the real length
            ?1*?2   array of elements. Each formatted as '?2', the number of elements in the array is stored as specified by '?1' (?1 is optional, or can also be a constant (number), for unpacking)
            'xxxx   literal xxxx (field's value doesn't change the output. quotes must not be closed or escaped)
            "xxxx   literal xxxx (field's value doesn't change the output. quotes must not be closed or escaped)
            _       will not pack the field. Accepts a third argument, which is an unpack code. See _Test_UnpackCode for an example
            ?=packcode  will evaluate packcode in the context of the structure, and pack the result as specified by ?. Unpacking is made plain.
                        The first '=' after the first character separates ? from packcode, so packcode may itself contain '='
                        and ? may start with the '=' byte-order character.
            ?&fieldname "Address of field fieldname".
                        For packing it will simply pack the id() of fieldname. Or use 0 if fieldname doesn't exist.
                        For unpacking, it's used to know whether fieldname has to be unpacked or not, i.e. by adding a & field you turn another field (fieldname) in an optional field.

    """
    commonHdr = ()
    structure = ()
    debug = 0

    def __init__(self, data = None, alignment = 0):
        # A subclass may fix 'alignment' as a class attribute; it then takes
        # precedence over the constructor argument.
        if not hasattr(self, 'alignment'):
            self.alignment = alignment

        self.fields    = {}
        self.rawData   = data
        # Initialise the packed-data cache before parsing: fromString() only
        # resets it through __setitem__, which never runs for a structure
        # without fields, and getData() reads it unconditionally.
        self.data      = None
        if data is not None:
            self.fromString(data)

    @classmethod
    def fromFile(cls, file):
        """Read one instance of *cls* from the file-like object *file*.

        The number of bytes read is len() of a freshly constructed, empty
        instance, so this only works when such an instance can be packed; a
        plain struct field without a value makes pack() raise.
        """
        answer = cls()
        answer.fromString(file.read(len(answer)))
        return answer

    def setAlignment(self, alignment):
        """Set the per-field alignment used by getData() and fromString()."""
        self.alignment = alignment

    def setData(self, data):
        """Install *data* as the cached packed form returned by getData().

        The cache is dropped as soon as a field is set or deleted.
        """
        self.data = data

    def packField(self, fieldName, format = None):
        """Pack the current value of *fieldName* (None when the field is unset).

        *format* defaults to the field's format from commonHdr/structure.
        """
        if self.debug:
            print("packField( %s | %s )" % (fieldName, format))

        if format is None:
            format = self.formatForField(fieldName)

        ans = self.pack(format, self.fields.get(fieldName), field = fieldName)

        if self.debug:
            print("\tanswer %r" % ans)

        return ans

    def getData(self):
        """Return the packed bytes of this structure.

        A buffer installed with setData() is returned verbatim while it is
        still cached. Otherwise every field of commonHdr + structure is packed
        in order, zero-padding after each field to a multiple of
        self.alignment when alignment is non-zero.
        """
        if self.data is not None:
            return self.data
        data = b''
        for field in self.commonHdr+self.structure:
            try:
                data += self.packField(field[0], field[1])
            except Exception as e:
                # Annotate the error with the failing field, then re-raise it.
                if field[0] in self.fields:
                    e.args += ("When packing field '%s | %s | %r' in %s" % (field[0], field[1], self[field[0]], self.__class__),)
                else:
                    e.args += ("When packing field '%s | %s' in %s" % (field[0], field[1], self.__class__),)
                raise
            if self.alignment:
                remainder = len(data) % self.alignment
                if remainder:
                    data += b'\x00' * (self.alignment - remainder)
        return data

    def fromString(self, data):
        """Parse *data* into the fields of this structure and return self.

        Fields are decoded in commonHdr + structure order. After each field
        the input is advanced by calcPackSize() of the decoded value (rounded
        up to self.alignment when set), i.e. by its re-packed size.
        """
        self.rawData = data
        for field in self.commonHdr+self.structure:
            name, format = field[0], field[1]
            if self.debug:
                print("fromString( %s | %s | %r )" % (name, format, data))
            size = self.calcUnpackSize(format, data, name)
            if self.debug:
                print("  size = %d" % size)
            # 'b' is the "no class / no code given" sentinel understood by unpack().
            dataClassOrCode = field[2] if len(field) > 2 else b
            try:
                self[name] = self.unpack(format, data[:size], dataClassOrCode = dataClassOrCode, field = name)
            except Exception as e:
                e.args += ("When unpacking field '%s | %s | %r[:%d]'" % (name, format, data, size),)
                raise

            size = self.calcPackSize(format, self[name], name)
            if self.alignment and size % self.alignment:
                size += self.alignment - (size % self.alignment)
            data = data[size:]

        return self

    def __setitem__(self, key, value):
        self.fields[key] = value
        self.data = None        # force recompute

    def __getitem__(self, key):
        return self.fields[key]

    def __delitem__(self, key):
        del self.fields[key]
        self.data = None        # force recompute, as in __setitem__

    def __str__(self):
        # NOTE: getData() builds bytes, so on Python 3 str() of a Structure
        # raises TypeError; use bytes(obj) or obj.getData() there.
        return self.getData()

    def __bytes__(self):
        return self.getData()

    def __len__(self):
        # Computed by packing the whole structure.
        return len(self.getData())

    @staticmethod
    def _toBytes(data):
        """Return *data* as bytes; text is converted with six.b() (latin-1 on Python 3)."""
        if isinstance(data, bytes):
            return data
        if isinstance(data, bytearray):
            return bytes(data)
        return b(data)

    @staticmethod
    def _splitCode(format):
        """Split a '?=code' specifier into (packFormat, code), or return None.

        The search starts at index 1 so a leading '=' (struct's byte-order
        prefix) is not taken as the separator, and only the first separator
        is used so the code itself may contain '=' (e.g. '==').
        """
        pos = format.find('=', 1)
        if pos == -1:
            return None
        return format[:pos], format[pos+1:]

    def pack(self, format, data, field = None):
        """Pack *data* according to *format* (see the class docstring) and return bytes.

        *field* is the name of the field being packed. If another field holds
        its address ('?&field') and *data* is None, the field is optional and
        packs to nothing.
        """
        if self.debug:
            print("  pack( %s | %r | %s)" %  (format, data, field))

        if field:
            addressField = self.findAddressFieldFor(field)
            if (addressField is not None) and (data is None):
                return b''

        # void specifier
        if format[:1] == '_':
            return b''

        # quote specifier
        if format[:1] == "'" or format[:1] == '"':
            return b(format[1:])

        # code specifier: an explicit value wins, otherwise evaluate the code.
        # The code comes from the class definition, not from packed input.
        code = self._splitCode(format)
        if code is not None:
            packFormat, expression = code
            try:
                return self.pack(packFormat, data)
            except Exception:
                fields = {'self':self}
                fields.update(self.fields)
                return self.pack(packFormat, eval(expression, {}, fields))

        # address specifier: an explicit value wins, otherwise pack id() of the
        # target (masked to the field width) or 0 when the target is unset/None
        two = format.split('&')
        if len(two) == 2:
            try:
                return self.pack(two[0], data)
            except Exception:
                if (two[1] in self.fields) and (self[two[1]] is not None):
                    return self.pack(two[0], id(self[two[1]]) & ((1<<(calcsize(two[0])*8))-1) )
                return self.pack(two[0], 0)

        # length specifier: an explicit value wins, otherwise pack the size
        # of the referenced field
        two = format.split('-')
        if len(two) == 2:
            try:
                return self.pack(two[0], data)
            except Exception:
                return self.pack(two[0], self.calcPackFieldSize(two[1]))

        # array specifier
        two = format.split('*')
        if len(two) == 2:
            answer = b''.join([self.pack(two[1], each) for each in data])
            if two[0]:
                if two[0].isdigit():
                    if int(two[0]) != len(data):
                        raise Exception("Array field has a constant size, and it doesn't match the actual value")
                else:
                    return self.pack(two[0], len(data))+answer
            return answer

        # "printf" string specifier
        if format[:1] == '%':
            return b(format % data)

        if data is None:
            raise Exception("Trying to pack None")

        # asciiz specifier
        if format[:1] == 'z':
            return self._toBytes(data) + b'\0'

        # unicode specifier: pad odd-length data to an even size, then NUL-NUL
        if format[:1] == 'u':
            data = self._toBytes(data)
            return data + b'\0\0' + (b'\0' if len(data) & 1 else b'')

        # DCE-RPC/NDR string specifier
        if format[:1] == 'w':
            data = self._toBytes(data)
            if len(data) == 0:
                data = b'\0\0'
            elif len(data) % 2:
                data += b'\0'
            l = pack('<L', len(data)//2)
            return b''.join([l, l, b'\0\0\0\0', data])

        # literal specifier
        if format[:1] == ':':
            if isinstance(data, Structure):
                return data.getData()
            return self._toBytes(data)

        # struct like specifier
        return pack(format, data)

    def unpack(self, format, data, dataClassOrCode = b, field = None):
        """Decode *data* according to *format* and return the value.

        *dataClassOrCode* is the optional third element of the field
        descriptor: the class used to build ':' fields, or, for '_' fields, a
        Python expression evaluated with the structure's fields, 'self' and
        'inputDataLeft' (= *data*) in scope. The default, six's b, is the
        "not given" sentinel. If *field* is optional and its address field is
        falsy, None is returned without decoding.
        """
        if self.debug:
            print("  unpack( %s | %r )" %  (format, data))

        if field:
            addressField = self.findAddressFieldFor(field)
            if addressField is not None:
                if not self[addressField]:
                    return None

        # void specifier
        if format[:1] == '_':
            if dataClassOrCode is b:
                return None
            fields = {'self':self, 'inputDataLeft':data}
            fields.update(self.fields)
            return eval(dataClassOrCode, {}, fields)

        # quote specifier
        if format[:1] == "'" or format[:1] == '"':
            answer = format[1:]
            if b(answer) != data:
                raise Exception("Unpacked data doesn't match constant value '%r' should be '%r'" % (data, answer))
            return answer

        # address specifier
        two = format.split('&')
        if len(two) == 2:
            return self.unpack(two[0],data)

        # code specifier
        code = self._splitCode(format)
        if code is not None:
            return self.unpack(code[0],data)

        # length specifier
        two = format.split('-')
        if len(two) == 2:
            return self.unpack(two[0],data)

        # array specifier
        two = format.split('*')
        if len(two) == 2:
            answer = []
            sofar = 0
            if two[0].isdigit():
                number = int(two[0])
            elif two[0]:
                sofar += self.calcUnpackSize(two[0], data)
                number = self.unpack(two[0], data[:sofar])
            else:
                number = -1     # no count given: consume all of data

            while number and sofar < len(data):
                nsofar = sofar + self.calcUnpackSize(two[1],data[sofar:])
                answer.append(self.unpack(two[1], data[sofar:nsofar], dataClassOrCode))
                number -= 1
                sofar = nsofar
            return answer

        # "printf" string specifier
        if format[:1] == '%':
            return format % data

        # asciiz specifier: latin-1 mirrors the six.b() encoding used by pack()
        if format == 'z':
            if data[-1:] != b'\x00':
                raise Exception("%s 'z' field is not NUL terminated: %r" % (field, data))
            return data[:-1].decode('latin-1') # remove trailing NUL

        # unicode specifier
        if format == 'u':
            if data[-2:] != b'\x00\x00':
                raise Exception("%s 'u' field is not NUL-NUL terminated: %r" % (field, data))
            return data[:-2] # remove trailing NUL-NUL

        # DCE-RPC/NDR string specifier
        if format == 'w':
            l = unpack('<L', data[:4])[0]
            return data[12:12+l*2]

        # literal specifier
        if format == ':':
            if isinstance(data, bytes) and dataClassOrCode is b:
                return data
            return dataClassOrCode(data)

        # struct like specifier
        return unpack(format, data)[0]

    def calcPackSize(self, format, data, field = None):
        """Return the number of bytes *data* occupies when packed with *format*.

        fromString() uses this to skip past a field it has just decoded. An
        optional field whose address field is falsy takes 0 bytes.
        """
        if field:
            addressField = self.findAddressFieldFor(field)
            if addressField is not None:
                if not self[addressField]:
                    return 0

        # void specifier
        if format[:1] == '_':
            return 0

        # quote specifier
        if format[:1] == "'" or format[:1] == '"':
            return len(format)-1

        # address specifier
        two = format.split('&')
        if len(two) == 2:
            return self.calcPackSize(two[0], data)

        # code specifier
        code = self._splitCode(format)
        if code is not None:
            return self.calcPackSize(code[0], data)

        # length specifier
        two = format.split('-')
        if len(two) == 2:
            return self.calcPackSize(two[0], data)

        # array specifier
        two = format.split('*')
        if len(two) == 2:
            answer = 0
            if two[0].isdigit():
                if int(two[0]) != len(data):
                    raise Exception("Array field has a constant size, and it doesn't match the actual value")
            elif two[0]:
                answer += self.calcPackSize(two[0], len(data))

            for each in data:
                answer += self.calcPackSize(two[1], each)
            return answer

        # "printf" string specifier
        if format[:1] == '%':
            return len(format % data)

        # asciiz specifier
        if format[:1] == 'z':
            return len(data)+1

        # unicode specifier
        if format[:1] == 'u':
            l = len(data)
            return l + (3 if l & 1 else 2)

        # DCE-RPC/NDR string specifier
        # NOTE(review): an empty value gives 12 here (what a zero count on the
        # wire occupies), although pack() emits 14 bytes for an empty value.
        if format[:1] == 'w':
            l = len(data)
            return 12+l+l % 2

        # literal specifier
        if format[:1] == ':':
            return len(data)

        # struct like specifier
        return calcsize(format)

    def calcUnpackSize(self, format, data, field = None):
        """Return how many bytes at the start of *data* belong to a *format* field.

        An optional field whose address field is falsy takes 0 bytes. When
        *field* has a '-' length field whose value converts to int, that value
        is returned as-is.
        """
        if self.debug:
            print("  calcUnpackSize( %s | %s | %r)" %  (field, format, data))

        # void specifier
        if format[:1] == '_':
            return 0

        addressField = self.findAddressFieldFor(field)
        if addressField is not None:
            if not self[addressField]:
                return 0

        # Best effort: a missing, unset or non-numeric length field is
        # ignored and the size is derived from the format below.
        try:
            lengthField = self.findLengthFieldFor(field)
            return int(self[lengthField])
        except Exception:
            pass

        # XXX: Try to match to actual values, raise if no match

        # quote specifier
        if format[:1] == "'" or format[:1] == '"':
            return len(format)-1

        # address specifier
        two = format.split('&')
        if len(two) == 2:
            return self.calcUnpackSize(two[0], data)

        # code specifier
        code = self._splitCode(format)
        if code is not None:
            return self.calcUnpackSize(code[0], data)

        # length specifier
        two = format.split('-')
        if len(two) == 2:
            return self.calcUnpackSize(two[0], data)

        # array specifier
        two = format.split('*')
        if len(two) == 2:
            answer = 0
            if two[0]:
                if two[0].isdigit():
                    number = int(two[0])
                else:
                    answer += self.calcUnpackSize(two[0], data)
                    number = self.unpack(two[0], data[:answer])

                while number:
                    number -= 1
                    answer += self.calcUnpackSize(two[1], data[answer:])
            else:
                while answer < len(data):
                    answer += self.calcUnpackSize(two[1], data[answer:])
            return answer

        # "printf" string specifier
        if format[:1] == '%':
            raise Exception("Can't guess the size of a printf like specifier for unpacking")

        # asciiz specifier
        if format[:1] == 'z':
            return data.index(b'\x00')+1

        # unicode specifier: only look for the NUL-NUL terminator at even
        # offsets, so a 0x00 high byte followed by the 0x00 low byte of the
        # next code unit is not mistaken for it. pack() pads odd-length data
        # to an even size, so its terminator also lands on an even offset.
        if format[:1] == 'u':
            for end in range(0, len(data) - 1, 2):
                if data[end:end+2] == b'\x00\x00':
                    return end + 2
            raise ValueError("%s 'u' field is not NUL-NUL terminated: %r" % (field, data))

        # DCE-RPC/NDR string specifier
        if format[:1] == 'w':
            l = unpack('<L', data[:4])[0]
            return 12+l*2

        # literal specifier
        if format[:1] == ':':
            return len(data)

        # struct like specifier
        return calcsize(format)

    def calcPackFieldSize(self, fieldName, format = None):
        """Return the packed size of the current value of *fieldName*."""
        if format is None:
            format = self.formatForField(fieldName)

        return self.calcPackSize(format, self[fieldName])

    def formatForField(self, fieldName):
        """Return the format of *fieldName*; raise Exception if there is no such field."""
        for field in self.commonHdr+self.structure:
            if field[0] == fieldName:
                return field[1]
        raise Exception("Field %s not found" % fieldName)

    def findAddressFieldFor(self, fieldName):
        """Return the name of the first field whose format ends in '&fieldName', or None."""
        descriptor = '&%s' % fieldName
        for field in self.commonHdr+self.structure:
            if field[1].endswith(descriptor):
                return field[0]
        return None

    def findLengthFieldFor(self, fieldName):
        """Return the name of the first field whose format ends in '-fieldName', or None."""
        descriptor = '-%s' % fieldName
        for field in self.commonHdr+self.structure:
            if field[1].endswith(descriptor):
                return field[0]
        return None

    def zeroValue(self, format):
        """Return the value clear() assigns to a field with *format*.

        Constant-count arrays give a tuple of per-element zero values, other
        arrays give (), 'z'/':'/'u' give b'', 'w' gives two NUL bytes, and
        everything else gives 0.
        NOTE(review): the 's' test matches the whole format string, so formats
        such as '<L-size' also get b'' instead of 0.
        """
        two = format.split('*')
        if len(two) == 2:
            if two[0].isdigit():
                return (self.zeroValue(two[1]),)*int(two[0])

        if '*' in format: return ()
        if 's' in format: return b''
        if format in ['z',':','u']: return b''
        if format == 'w': return b'\x00\x00'

        return 0

    def clear(self):
        """Set every declared field to its zeroValue()."""
        for field in self.commonHdr + self.structure:
            self[field[0]] = self.zeroValue(field[1])

    def dump(self, msg = None, indent = 0):
        """Print every field, recursing into nested Structures.

        Declared fields come first in declaration order, followed by any
        other set fields in insertion order.
        """
        if msg is None: msg = self.__class__.__name__
        ind = ' '*indent
        print("\n%s" % msg)

        def dumpField(name):
            value = self[name]
            if isinstance(value, Structure):
                value.dump('%s%s:{' % (ind,name), indent = indent + 4)
                print("%s}" % ind)
            else:
                print("%s%s: {%r}" % (ind,name,value))

        declared = set()
        for field in self.commonHdr+self.structure:
            name = field[0]
            if name in self.fields:
                declared.add(name)
                dumpField(name)
        # Fields set on the instance but not declared in the structures.
        for name in [k for k in self.fields if k not in declared]:
            dumpField(name)
590
def pretty_print(x):
    """Return the character for byte value *x* if it is printable ASCII, else '.'.

    Printable here means ' ' (0x20) through '~' (0x7E) inclusive.
    """
    char = chr(x)
    if ' ' <= char <= '~':
        return char
    return u'.'
596
def hexdump(data, indent = ''):
    """Print *data* as a hex dump, 16 bytes per line.

    Each line shows the offset (4 hex digits, after *indent*), the bytes in
    hex split into two groups of 8, and their pretty_print() rendering.
    *data* may be bytes or bytearray; anything else is converted with six.b().
    """
    # six.b() only accepts text on Python 3, so binary input is used directly.
    if isinstance(data, (bytes, bytearray)):
        raw = bytearray(data)
    else:
        raw = bytearray(b(data))
    for offset in range(0, len(raw), 16):
        chunk = raw[offset:offset+16]
        line = " %s%04x   " % (indent, offset)
        for j in range(16):
            if j < len(chunk):
                line += "%02X " % chunk[j]
            else:
                line += u"   "
            if j == 7:
                line += " "     # gap between the two groups of 8 bytes
        line += "  "
        line += ''.join(pretty_print(byte) for byte in chunk)
        print(line)
614
class _StructureTest:
    """Pack/unpack round-trip harness for a Structure subclass.

    Subclasses set ``theClass`` to the Structure subclass under test and
    implement ``populate(a)`` to fill a freshly created instance. ``run()``
    packs it, unpacks the result into a new instance, repacks that, and
    reports when the two packed byte strings differ.
    """
    alignment = 0

    def create(self, data = None):
        """Instantiate theClass with this test's alignment, parsing *data* if given."""
        if data is None:
            return self.theClass(alignment = self.alignment)
        return self.theClass(data, alignment = self.alignment)

    def run(self):
        print()
        print("-"*70)
        testName = self.__class__.__name__
        print("starting test: %s....." % testName)
        original = self.create()
        self.populate(original)
        original.dump("packing.....")
        packed = original.getData()
        print("packed: %r" % packed)
        print("unpacking.....")
        parsed = self.create(packed)
        parsed.dump("unpacked.....")
        print("repacking.....")
        repacked = parsed.getData()
        hexdump(repacked)
        if repacked != packed:
            print("ERROR: original packed and repacked don't match")
            print("repacked: %r" % repacked)
642
class _Test_simple(_StructureTest):
    """Round trip of integers, a length-prefixed asciiz string, a UTF-16
    string, quoted literals, a raw ':' field, an array and a computed field."""
    class theClass(Structure):
        commonHdr = ()
        structure = (
                ('int1', '!L'),
                ('len1','!L-z1'),
                ('arr1','B*<L'),
                ('z1', 'z'),
                ('u1','u'),
                ('', '"COCA'),
                ('len2','!H-:1'),
                ('', '"COCA'),
                (':1', ':'),
                ('int3','>L'),
                ('code1','>L=len(arr1)*2+0x1000'),
                )

    def populate(self, a):
        values = (
            ('default', 'hola'),
            ('int1', 0x3131),
            ('int3', 0x45444342),
            ('z1', 'hola'),
            ('u1', 'hola'.encode('utf_16_le')),
            (':1', ':1234:'),
            ('arr1', (0x12341234, 0x88990077, 0x41414141)),
        )
        for name, value in values:
            a[name] = value
        # len1 is left unset so it is derived from z1 when packing.
669
class _Test_fixedLength(_Test_simple):
    """_Test_simple with an explicit, bogus value forced into len1.

    Unpacking trusts that length for z1, so the round trip is not expected
    to succeed.
    """
    def populate(self, a):
        base_populate = _Test_simple.populate
        base_populate(self, a)
        a['len1'] = 0x42424242
674
class _Test_simple_aligned4(_Test_simple):
    """_Test_simple with every packed field padded to a 4-byte boundary."""
    alignment = 4   # bytes
677
class _Test_nested(_StructureTest):
    """Round trip of Structures nested inside ':' fields."""
    class theClass(Structure):
        class _Inner(Structure):
            structure = (('data', 'z'),)

        structure = (
            ('nest1', ':', _Inner),
            ('nest2', ':', _Inner),
            ('int', '<L'),
        )

    def populate(self, a):
        inner_class = _Test_nested.theClass._Inner
        for name, text in (('nest1', 'hola manola'), ('nest2', 'chau loco')):
            a[name] = inner_class()
            a[name]['data'] = text
        a['int'] = 0x12345678
695
class _Test_Optional(_StructureTest):
    """Round trip of fields made optional by '&' address fields (pName, pList)."""
    class theClass(Structure):
        structure = (
                ('pName','<L&Name'),
                ('pList','<L&List'),
                ('Name','w'),
                ('List','<H*<L'),
            )

    def populate(self, a):
        a['Name'], a['List'] = 'Optional test', (1, 2, 3, 4)
708
class _Test_Optional_sparse(_Test_Optional):
    """_Test_Optional with Name removed, so that optional field is absent."""
    def populate(self, a):
        base_populate = _Test_Optional.populate
        base_populate(self, a)
        del a['Name']
713
class _Test_AsciiZArray(_StructureTest):
    """Round trip of a counted array of asciiz strings between two integers."""
    class theClass(Structure):
        structure = (
            ('head','<L'),
            ('array','B*z'),
            ('tail','<L'),
        )

    def populate(self, a):
        a['head'], a['tail'] = 0x1234, 0xabcd
        a['array'] = ('hola', 'manola', 'te traje')
726
class _Test_UnpackCode(_StructureTest):
    """Round trip where a '_' field's unpack code supplies the length of 'uno'."""
    class theClass(Structure):
        structure = (
            ('leni','<L=len(uno)*2'),
            ('cuchi','_-uno','leni//2'),
            ('uno',':'),
            ('dos',':'),
        )

    def populate(self, a):
        for name, text in (('uno', 'soy un loco!'), ('dos', 'que haces fiera')):
            a[name] = text
739
class _Test_AAA(_StructureTest):
    """Round trip of a computed bit-packed header plus void fields decoded by unpack code."""
    class theClass(Structure):
        commonHdr = ()
        structure = (
          ('iv', '!L=((init_vector & 0xFFFFFF) << 8) | ((pad & 0x3f) << 2) | (keyid & 3)'),
          ('init_vector',   '_','(iv >> 8)'),
          ('pad',           '_','((iv >>2) & 0x3F)'),
          ('keyid',         '_','( iv & 0x03 )'),
          ('dataLen',       '_-data', 'len(inputDataLeft)-4'),
          ('data',':'),
          ('icv','>L'),
        )

    def populate(self, a):
        a['init_vector'] = 0x01020304
        a['pad'] = 0b010101             # same value as int('010101', 2)
        a['keyid'] = 0x07
        a['data'] = "\xA0\xA1\xA2\xA3\xA4\xA5\xA6\xA7\xA8\xA9"
        a['icv'] = 0x05060708
        # 'iv' stays unset so it is computed from the fields above when packing.
761
if __name__ == '__main__':
    _Test_simple().run()

    try:
        _Test_fixedLength().run()
    except Exception:
        # The bogus explicit len1 makes unpacking fail; this is the expected outcome.
        print("cannot repack because length is bogus")

    for test in (_Test_simple_aligned4, _Test_nested, _Test_Optional,
                 _Test_Optional_sparse, _Test_AsciiZArray, _Test_UnpackCode,
                 _Test_AAA):
        test().run()
777