1# Copyright (c) 2015-2021 by Rocky Bernstein
2# Copyright (c) 2000-2002 by hartmut Goebel <h.goebel@crazy-compilers.com>
3#
4#  This program is free software; you can redistribute it and/or
5#  modify it under the terms of the GNU General Public License
6#  as published by the Free Software Foundation; either version 2
7#  of the License, or (at your option) any later version.
8#
9#  This program is distributed in the hope that it will be useful,
10#  but WITHOUT ANY WARRANTY; without even the implied warranty of
11#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12#  GNU General Public License for more details.
13#
14#  You should have received a copy of the GNU General Public License
15#  along with this program; if not, write to the Free Software
16#  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
17
18"""CPython magic- and version-independent Python object
19deserialization (unmarshal).
20
21This is needed when the bytecode extracted is from
22a different version than the currently-running Python.
23
When the running interpreter and the bytecode being read are the same
version, you can simply use Python's built-in marshal.loads() to produce
a code object.
27"""
28
29import io
30import sys
31from struct import unpack
32
33from xdis.magics import magic_int2float
34from xdis.codetype import to_portable
35from xdis.version_info import PYTHON3, PYTHON_VERSION, IS_PYPY
36
if not PYTHON3:
    # Python 2 already has a builtin ``long``; only compat_u2s() needs
    # unicodedata, and it only uses it when running under Python 2.
    import unicodedata

    # FIXME: we should write a bytes() class with a repr
    # that prints the b'' prefix so that Python2 can
    # print out Python3 code correctly
else:

    def long(n):
        """Python 3 stand-in for Python 2's ``long``: return *n* unchanged."""
        return n
48
# marshal.c calls this bit FLAG_REF.  When it is set on a type byte, the
# decoded object is recorded in internObjects so that a later "r"
# (object_reference) code can refer back to it.  The other 7 bits of the
# byte are the type code itself.
FLAG_REF = 1 << 7
53
54
# The keys in the following dictionary are unmarshal type codes, like "s",
# "c", "<", etc.  The values are suffixes of the methods that unmarshal the
# data: code "x" is handled by the method named "t_" + value.
#
# Note: we could eliminate the parameters, if this were all inside a
# class.  This might be good from an efficiency standpoint, and bad
# from a functional-programming standpoint. Pick your poison.
# EDIT: I'm choosing efficiency over functional-programming.
UNMARSHAL_DISPATCH_TABLE = {
    "0": "C_NULL",
    "N": "None",
    "S": "stopIteration",
    ".": "Elipsis",  # spelling must match the method name t_Elipsis
    "F": "False",
    "T": "True",
    "i": "int32",
    "l": "long",
    "I": "int64",
    "f": "float",
    "g": "binary_float",
    "x": "complex",
    "y": "binary_complex",
    "s": "string",
    "A": "ASCII_interned",
    "a": "ASCII",
    "z": "short_ASCII",
    "Z": "short_ASCII_interned",
    "t": "interned",
    "u": "unicode",
    ")": "small_tuple",
    "(": "tuple",
    "[": "list",
    "<": "frozenset",
    ">": "set",
    "{": "dict",
    "R": "python2_string_reference",
    "c": "code",
    "C": "code",  # Older Python code
    "r": "object_reference",
    "?": "unknown",
}
96
97
def compat_str(s: "bytes | str") -> "str | bytes":
    """
    Convert a string read from marshal data into a native ``str``.

    On Python 3, bytes are decoded as UTF-8.  If they are not valid UTF-8,
    the bytes are returned unchanged (they get converted to str when
    needed).  A value that is already a ``str`` is returned as-is.
    On Python 2 this is simply ``str(s)``.
    """
    if not PYTHON3:
        return str(s)
    if isinstance(s, str):
        # Already text; calling .decode() on it would raise AttributeError.
        return s
    try:
        return s.decode("utf-8")
    except UnicodeDecodeError:
        # If not Unicode, return bytes
        # and it will get converted to str when needed
        return s
113
114
def compat_u2s(u):
    """
    Convert the unicode value *u* to a native string.

    On Python 3 this is ``str(u)``.  On Python 2 the text is NFKD-normalized
    and encoded to ASCII.  If that encoding fails, the normalized unicode
    object is returned instead.
    """
    if PYTHON_VERSION >= 3.0:
        return str(u)

    # See also unaccent.py, which can be found using google. It and this
    # code were found via https://www.peterbe.com/plog/unicode-to-ascii,
    # where the link is now dead. That can potentially do a better job in
    # converting accents.
    decomposed = unicodedata.normalize("NFKD", u)
    try:
        return decomposed.encode("ascii")
    except UnicodeEncodeError:
        return decomposed
128
129
class _VersionIndependentUnmarshaller:
    """
    marshal.load() reimplemented in Python, so that marshal data written by
    a Python version other than the running one can be decoded.

    The type byte of each object is looked up in UNMARSHAL_DISPATCH_TABLE
    and handled by the matching ``t_<name>`` method.  Every ``t_*`` method
    takes two arguments:
      * ``save_ref``: FLAG_REF was set on the type byte, so the result must
        be recorded in ``self.internObjects``.
      * ``bytes_for_s``: return "s" strings as bytes instead of decoding them.
    """

    def __init__(self, fp, magic_int, bytes_for_s, code_objects=None):
        """
        fp: binary file-like object that the marshaled data is read from.
        magic_int: bytecode magic number.  It determines the marshal version
            and the code-object layout.
        bytes_for_s: used by load() for the top-level object.
        code_objects: dict in which each decoded code object is recorded,
            keyed by ``str(code)``.  A new dict is created when omitted.

        Marshal versions:
            0/Historical: Until 2.4/magic int 62041
            1: [2.4, 2.5) (self.magic_int: 62041 until 62071)
            2: [2.5, 3.4a0) (self.magic_int: 62071 until 3250)
            3: [3.4a0, 3.4a3) (self.magic_int: 3250 until 3280)
            4: [3.4a3, current) (self.magic_int: 3280 onwards)

        In Python 3 a bytes type is used for strings.
        """
        self.fp = fp
        self.magic_int = magic_int
        # A mutable default ({}) would be shared by every instance and keep
        # code objects alive across unrelated loads; use a fresh dict.
        self.code_objects = {} if code_objects is None else code_objects

        self.bytes_for_s = bytes_for_s
        version_float = magic_int2float(self.magic_int)
        if version_float >= 3.4:
            # The early 3.4 alphas (magic 3250-3270) still use version 3.
            if self.magic_int in (3250, 3260, 3270):
                self.marshal_version = 3
            else:
                self.marshal_version = 4
        elif version_float >= 2.5:
            self.marshal_version = 2
        elif version_float >= 2.4:
            self.marshal_version = 1
        else:
            self.marshal_version = 0

        # Strings recorded by the interned-string type codes; indexed by
        # "R" (t_python2_string_reference).
        self.internStrings = []
        # Objects recorded when FLAG_REF is set; indexed by "r"
        # (t_object_reference).
        self.internObjects = []

    def load(self):
        """
        marshal.load() written in Python. When the Python bytecode magic loaded is the
        same magic for the running Python interpreter, we can simply use the
        Python-supplied marshal.load().

        However we need to use this when versions are different since the internal
        code structures are different. Sigh.

        The top-level object is read with ``bytes_for_s=self.bytes_for_s``.
        """
        if self.marshal_version == 0:
            self.internStrings = []
        if self.marshal_version < 3:
            # Reference objects only exist from marshal version 3 (Python
            # 3.4), so nothing may have been recorded for older formats.
            assert self.internObjects == []

        return self.r_object(bytes_for_s=self.bytes_for_s)

    # Python 3.4+ support for reference objects.
    # The names follow marshal.c
    def r_ref_reserve(self, obj, save_ref):
        """
        If *save_ref* is set, append the placeholder *obj* to internObjects
        and return ``(obj, index)`` so that r_ref_insert() can fill the slot
        later.  Otherwise return ``(obj, None)``.
        """
        i = None
        if save_ref:
            i = len(self.internObjects)
            self.internObjects.append(obj)
        return obj, i

    def r_ref_insert(self, obj, i):
        """Store *obj* in the slot reserved at index *i* (if any); return *obj*."""
        if i is not None:
            self.internObjects[i] = obj
        return obj

    def r_ref(self, obj, save_ref):
        """Append *obj* to internObjects when *save_ref* is set; return *obj*."""
        if save_ref:
            self.internObjects.append(obj)
        return obj

    def _r_typed_object(self, bytes_for_s=False):
        """
        Read one marshaled object and return ``(type_code, obj)``.

        ``type_code`` is the one-character marshal type with FLAG_REF
        stripped.  Use this instead of r_object() when the TYPE_NULL ("0")
        terminator must be told apart from a real ``None`` value.
        For an unknown type code, a diagnostic is written to stderr and
        ``obj`` is None.
        """
        byte1 = ord(self.fp.read(1))

        # FLAG_REF indicates whether we save a reference to the object.
        # The byte without that flag is the marshal type code, an ASCII
        # character.
        save_ref = bool(byte1 & FLAG_REF)
        marshal_type = chr(byte1 & (FLAG_REF - 1))

        func_suffix = UNMARSHAL_DISPATCH_TABLE.get(marshal_type)
        if func_suffix is None:
            code = ord(marshal_type)
            # %x needs the int itself; passing hex(code) (a str) used to
            # raise TypeError here.
            sys.stderr.write(
                "Unknown type %i (hex %x) %c\n" % (code, code, marshal_type)
            )
            return marshal_type, None
        unmarshal_func = getattr(self, "t_" + func_suffix)
        return marshal_type, unmarshal_func(save_ref, bytes_for_s)

    # In marshal.c this is one big case statement
    def r_object(self, bytes_for_s=False):
        """
        Read and return the next marshaled object from self.fp.

        In Python3 strings are bytes type: with a true *bytes_for_s*, "s"
        strings are returned as bytes rather than decoded.  None is returned
        for TYPE_NULL, for None itself, and for unknown type codes.
        """
        return self._r_typed_object(bytes_for_s)[1]

    # In C this NULL. Not sure what it should
    # translate here. Note NULL != None which is below
    def t_C_NULL(self, save_ref, bytes_for_s=False):
        return None

    def t_None(self, save_ref, bytes_for_s=False):
        return None

    def t_stopIteration(self, save_ref, bytes_for_s=False):
        return StopIteration

    def t_Elipsis(self, save_ref, bytes_for_s=False):
        return Ellipsis

    def t_False(self, save_ref, bytes_for_s=False):
        return False

    def t_True(self, save_ref, bytes_for_s=False):
        return True

    def t_int32(self, save_ref, bytes_for_s=False):
        """Signed 32-bit little-endian integer."""
        return self.r_ref(int(unpack("<i", self.fp.read(4))[0]), save_ref)

    def t_long(self, save_ref, bytes_for_s=False):
        """
        Arbitrary-size integer: a signed 32-bit digit count *n* followed by
        |n| 15-bit digits, least significant first.  The sign of *n* is the
        sign of the number.
        """
        n = unpack("<i", self.fp.read(4))[0]
        d = long(0)
        for j in range(abs(n)):
            md = int(unpack("<h", self.fp.read(2))[0])
            d += md << (j * 15)
        if n < 0:
            d = long(-d)
        # Register the reference even for zero; returning early for n == 0
        # would skip it and shift every later reference index.
        return self.r_ref(d, save_ref)

    # Python 3.4 removed this.
    def t_int64(self, save_ref, bytes_for_s=False):
        """Signed 64-bit little-endian integer."""
        return self.r_ref(unpack("<q", self.fp.read(8))[0], save_ref)

    # float - Seems not in use after Python 2.4
    def t_float(self, save_ref, bytes_for_s=False):
        """Float stored as text with a 1-byte length prefix."""
        strsize = unpack("B", self.fp.read(1))[0]
        s = self.fp.read(strsize)
        return self.r_ref(float(s), save_ref)

    def t_binary_float(self, save_ref, bytes_for_s=False):
        """Float stored as a little-endian IEEE double."""
        return self.r_ref(float(unpack("<d", self.fp.read(8))[0]), save_ref)

    def t_complex(self, save_ref, bytes_for_s=False):
        """Complex number stored as two textual floats (real, then imag)."""
        # Length prefix of each textual float: 1 byte up to magic 62061,
        # a 4-byte int afterwards.
        if self.magic_int <= 62061:
            len_fmt, len_size = "B", 1
        else:
            len_fmt, len_size = "<i", 4

        def get_float():
            size = unpack(len_fmt, self.fp.read(len_size))[0]
            return float(self.fp.read(size))

        real = get_float()
        imag = get_float()
        return self.r_ref(complex(real, imag), save_ref)

    def t_binary_complex(self, save_ref, bytes_for_s=False):
        """Complex number stored as two little-endian doubles."""
        real = unpack("<d", self.fp.read(8))[0]
        imag = unpack("<d", self.fp.read(8))[0]
        return self.r_ref(complex(real, imag), save_ref)

    # Note: could mean bytes in Python3 processing Python2 bytecode
    def t_string(self, save_ref, bytes_for_s: bool):
        """
        In Python3 this is a bytes types. In Python2 it is a string.
        `bytes_for_s` distinguishes what we need.
        """
        strsize = unpack("<i", self.fp.read(4))[0]
        s = self.fp.read(strsize)
        if not bytes_for_s:
            s = compat_str(s)
        return self.r_ref(s, save_ref)

    # Python 3.4
    def t_ASCII_interned(self, save_ref, bytes_for_s=False):
        """
        There are true strings in Python3 as opposed to
        bytes. "interned" just means we keep a reference to
        the string.
        """
        # FIXME: check
        strsize = unpack("<i", self.fp.read(4))[0]
        interned = compat_str(self.fp.read(strsize))
        self.internStrings.append(interned)
        return self.r_ref(interned, save_ref)

    # Since Python 3.4
    def t_ASCII(self, save_ref, bytes_for_s=False):
        """
        There are true strings in Python3 as opposed to
        bytes.
        """
        strsize = unpack("<i", self.fp.read(4))[0]
        s = compat_str(self.fp.read(strsize))
        return self.r_ref(s, save_ref)

    # Since Python 3.4
    def t_short_ASCII(self, save_ref, bytes_for_s=False):
        """ASCII string with a 1-byte length prefix."""
        strsize = unpack("B", self.fp.read(1))[0]
        return self.r_ref(compat_str(self.fp.read(strsize)), save_ref)

    # Since Python 3.4
    def t_short_ASCII_interned(self, save_ref, bytes_for_s=False):
        """Interned ASCII string with a 1-byte length prefix."""
        # FIXME: check
        strsize = unpack("B", self.fp.read(1))[0]
        interned = compat_str(self.fp.read(strsize))
        self.internStrings.append(interned)
        return self.r_ref(interned, save_ref)

    # Since Python 3.4
    def t_interned(self, save_ref, bytes_for_s=False):
        """Interned string with a 4-byte length prefix."""
        strsize = unpack("<i", self.fp.read(4))[0]
        interned = compat_str(self.fp.read(strsize))
        self.internStrings.append(interned)
        return self.r_ref(interned, save_ref)

    def t_unicode(self, save_ref, bytes_for_s=False):
        """UTF-8 text; undecodable bytes are dropped rather than raising."""
        strsize = unpack("<i", self.fp.read(4))[0]
        unicodestring = self.fp.read(strsize)
        if PYTHON_VERSION == 3.2 and IS_PYPY:
            # FIXME: this isn't quite right. See
            # pypy3-2.4.0/lib-python/3/email/message.py
            # '([^\ud800-\udbff]|\A)[\udc00-\udfff]([^\udc00-\udfff]|\Z)')
            return self.r_ref(unicodestring.decode("utf-8", errors="ignore"), save_ref)
        try:
            return self.r_ref(unicodestring.decode("utf-8"), save_ref)
        except UnicodeDecodeError:
            return self.r_ref(
                unicodestring.decode("utf-8", errors="ignore"), save_ref
            )

    def _read_tuple(self, size, save_ref, bytes_for_s):
        """
        Read *size* objects and return them as a tuple.

        The reference slot is reserved before the elements are read
        (elements may register references of their own) and is filled with
        the finished tuple afterwards.  Registering an empty tuple and then
        building a new one with ``+=`` would leave back-references
        pointing at ().
        """
        i = self.r_ref_reserve(tuple(), save_ref)[1]
        items = tuple(self.r_object(bytes_for_s=bytes_for_s) for _ in range(size))
        return self.r_ref_insert(items, i)

    # Since Python 3.4
    def t_small_tuple(self, save_ref, bytes_for_s=False):
        """Tuple with a 1-byte element count (since Python 3.4)."""
        tuplesize = unpack("B", self.fp.read(1))[0]
        return self._read_tuple(tuplesize, save_ref, bytes_for_s)

    def t_tuple(self, save_ref, bytes_for_s=False):
        """Tuple with a 4-byte element count."""
        tuplesize = unpack("<i", self.fp.read(4))[0]
        return self._read_tuple(tuplesize, save_ref, bytes_for_s)

    def t_list(self, save_ref, bytes_for_s=False):
        """List with a 4-byte element count."""
        # FIXME: check me
        n = unpack("<i", self.fp.read(4))[0]
        # Lists are mutable, so the registered object is filled in place.
        ret = self.r_ref([], save_ref)
        for _ in range(n):
            ret.append(self.r_object(bytes_for_s=bytes_for_s))
        return ret

    def _read_set_items(self, save_ref, bytes_for_s):
        """
        Read a 4-byte element count and that many objects.  Returns
        ``(items, ref_index)``; the reference slot is reserved before the
        elements are read.
        """
        setsize = unpack("<i", self.fp.read(4))[0]
        i = self.r_ref_reserve(tuple(), save_ref)[1]
        items = [self.r_object(bytes_for_s=bytes_for_s) for _ in range(setsize)]
        return items, i

    def t_frozenset(self, save_ref, bytes_for_s=False):
        items, i = self._read_set_items(save_ref, bytes_for_s)
        return self.r_ref_insert(frozenset(items), i)

    def t_set(self, save_ref, bytes_for_s=False):
        items, i = self._read_set_items(save_ref, bytes_for_s)
        return self.r_ref_insert(set(items), i)

    def t_dict(self, save_ref, bytes_for_s=False):
        """
        Dictionary: alternating key/value objects terminated by a TYPE_NULL
        ("0") in key position.
        """
        ret = self.r_ref(dict(), save_ref)
        while True:
            key_type, key = self._r_typed_object(bytes_for_s=bytes_for_s)
            # Stop on the TYPE_NULL terminator, or on an unknown type code
            # since the stream can no longer be trusted.  A key of None
            # ("N") is a real key, not the terminator.
            if key_type == "0" or key_type not in UNMARSHAL_DISPATCH_TABLE:
                break
            # A value may legitimately be None, so it must not end the loop;
            # doing so would leave the terminator unread.
            ret[key] = self.r_object(bytes_for_s=bytes_for_s)
        return ret

    def t_python2_string_reference(self, save_ref, bytes_for_s=False):
        """Return a previously interned string by its index in internStrings."""
        refnum = unpack("<i", self.fp.read(4))[0]
        return self.internStrings[refnum]

    def t_code(self, save_ref, bytes_for_s=False):
        """
        Read a code object field by field, using the layout of the bytecode
        version given by ``self.magic_int``.  The fields are converted with
        to_portable(), and the result is recorded in ``self.code_objects``.
        """
        # FIXME: use tables to simplify this?
        # FIXME: Python 1.0 .. 1.3 isn't well known

        # Reserve the reference slot first: nested objects read below may
        # register references of their own.
        i = self.r_ref_reserve(None, save_ref)[1]
        version = magic_int2float(self.magic_int)

        if version >= 2.3:
            co_argcount = unpack("<i", self.fp.read(4))[0]
        elif version >= 1.3:
            co_argcount = unpack("<h", self.fp.read(2))[0]
        else:
            co_argcount = 0

        # FIXME:
        # Note we do this by magic_int, not version which is *not*
        # 3.8
        if self.magic_int in (3412, 3413, 3422, 3425):
            co_posonlyargcount = unpack("<i", self.fp.read(4))[0]
        elif version >= 3.8:
            # Must be ``elif``: a separate ``if`` would overwrite the
            # count that was just read above.
            co_posonlyargcount = 0
        else:
            co_posonlyargcount = None

        if version >= 3.0:
            kwonlyargcount = unpack("<i", self.fp.read(4))[0]
        else:
            kwonlyargcount = 0

        if version >= 2.3:
            co_nlocals = unpack("<i", self.fp.read(4))[0]
        elif version >= 1.3:
            co_nlocals = unpack("<h", self.fp.read(2))[0]
        else:
            co_nlocals = 0

        if version >= 2.3:
            co_stacksize = unpack("<i", self.fp.read(4))[0]
        elif version >= 1.5:
            co_stacksize = unpack("<h", self.fp.read(2))[0]
        else:
            co_stacksize = 0

        if version >= 2.3:
            co_flags = unpack("<i", self.fp.read(4))[0]
        elif version >= 1.3:
            co_flags = unpack("<h", self.fp.read(2))[0]
        else:
            co_flags = 0

        co_code = self.r_object(bytes_for_s=True)

        # FIXME: Check/verify that is true:
        bytes_for_s = PYTHON_VERSION >= 3.0 and version > 3.0
        co_consts = self.r_object(bytes_for_s=bytes_for_s)
        co_names = self.r_object(bytes_for_s=bytes_for_s)

        if version >= 1.3:
            co_varnames = self.r_object(bytes_for_s=False)
        else:
            co_varnames = []

        if version >= 2.0:
            co_freevars = self.r_object(bytes_for_s=bytes_for_s)
            co_cellvars = self.r_object(bytes_for_s=bytes_for_s)
        else:
            co_freevars = tuple()
            co_cellvars = tuple()

        co_filename = self.r_object(bytes_for_s=bytes_for_s)
        co_name = self.r_object(bytes_for_s=bytes_for_s)

        if version >= 1.5:
            if version >= 2.3:
                co_firstlineno = unpack("<i", self.fp.read(4))[0]
            else:
                co_firstlineno = unpack("<h", self.fp.read(2))[0]
            co_lnotab = self.r_object(bytes_for_s=bytes_for_s)
        else:
            # < 1.5 there is no lnotab, so no firstlineno.
            # SET_LINENO is used instead.
            co_firstlineno = -1  # Bogus sentinel value
            co_lnotab = ""

        code = to_portable(
            co_argcount,
            co_posonlyargcount,
            kwonlyargcount,
            co_nlocals,
            co_stacksize,
            co_flags,
            co_code,
            co_consts,
            co_names,
            co_varnames,
            co_filename,
            co_name,
            co_firstlineno,
            co_lnotab,
            co_freevars,
            co_cellvars,
            version,
        )

        self.code_objects[str(code)] = code
        return self.r_ref_insert(code, i)

    # Since Python 3.4
    def t_object_reference(self, save_ref=None, bytes_for_s=False):
        """Return a previously recorded object by its index in internObjects."""
        refnum = unpack("<i", self.fp.read(4))[0]
        return self.internObjects[refnum]

    def t_unknown(self, save_ref=None, bytes_for_s=False):
        raise KeyError("?")
552
553
554# _________________________________________________________________
555#
556# user interface
557
558
def load_code(fp, magic_int, bytes_for_s=False, code_objects=None):
    """
    Unmarshal and return one object from *fp* (as the name says, intended
    for code objects), decoding the marshal format selected by *magic_int*.

    fp: a binary file-like object, or a bytes-like object (bytes,
        bytearray, memoryview) holding the marshaled data.
    magic_int: the bytecode magic number of the data.
    bytes_for_s: passed through to the unmarshaller; if true, "s" strings
        at the top level are kept as bytes.
    code_objects: optional dict in which every decoded code object is
        recorded, keyed by its str().  A fresh dict is used when omitted.
    """
    if isinstance(fp, (bytes, bytearray, memoryview)):
        fp = io.BytesIO(fp)
    if code_objects is None:
        # A shared ``{}`` default would accumulate every code object ever
        # loaded through this function.
        code_objects = {}
    um_gen = _VersionIndependentUnmarshaller(
        fp, magic_int, bytes_for_s, code_objects=code_objects
    )
    return um_gen.load()
566