1# Copyright 2015-2017 the openage authors. See copying.md for legal info.
2
3"""
Provides some classes designed to expand the functionality of struct.Struct
5"""
6
7
8from collections import OrderedDict
9from struct import Struct
10
11from ..util.files import read_guaranteed
12
13
class NamedStructMeta(type):
    """
    Metaclass for NamedStruct.

    Not unlike the meta-class for Enum, processes all the member attributes
    at class-creation time:

    - the 'endianness' member and the specstr fragments of all following
      members are concatenated, in definition order, to one format string
      that is compiled to a struct.Struct and stored as '_struct'
      (only if an 'endianness' member was given).
    - the names of the field members are stored as the list '_attributes'.
    - members that carry a 'specstr' attribute are stored in the dict
      '_postprocessors', keyed by member name.
    """
    @classmethod
    def __prepare__(mcs, name, bases, **kwds):
        del mcs, name, bases, kwds  # unused variables

        # the definition order of the members is the order of the
        # struct fields, so the class namespace must remember it.
        return OrderedDict()

    def __new__(mcs, name, bases, classdict, **kwds):
        """
        Builds the class and attaches '_attributes', '_postprocessors' and
        (if 'endianness' was given) '_struct' to it.

        Raises TypeError if 'endianness' or a field member has an
        unsupported type, ValueError if 'endianness' is not exactly one of
        the characters @=<>!, and Exception if a field member is defined
        before 'endianness'.
        """
        del kwds  # unused variable

        # None until the 'endianness' member has been seen; afterwards the
        # format string that has been accumulated so far.
        specstr = None
        attributes = []
        postprocessors = {}

        for membername, value in classdict.items():
            # ignore hidden and None members
            if membername.startswith('_') or value is None:
                continue

            valuehasspecstr = hasattr(value, "specstr")

            # ignore member methods
            # (callables that carry a specstr are post-processors, not
            # methods, and must not be skipped)
            if not valuehasspecstr:
                if callable(value) or isinstance(value, classmethod):
                    continue

            if membername == 'endianness':
                if specstr is not None:
                    raise Exception("endianness has been given multiple times")

                if not isinstance(value, str):
                    raise TypeError(
                        "endianness: expected str, but got " + repr(value))

                # 'value in "@=<>!"' alone is a substring test which would
                # also accept "" and sequences such as "<>"; require
                # exactly one of the allowed characters.
                if len(value) != 1 or value not in "@=<>!":
                    raise ValueError("endianess: expected one of @=<>!")

                specstr = value
                continue

            if specstr is None:
                raise Exception("NamedStruct: endianness expected before "
                                "attribute " + membername)

            if valuehasspecstr:
                # remember the object for post-processing, and use its
                # specstr as the format fragment for this field.
                postprocessors[membername], value = value, value.specstr
            elif isinstance(value, str):
                pass
            else:
                raise TypeError(
                    "NamedStruct member %s: expected str, but got %s"
                    % (membername, repr(value)))

            specstr += value

            attributes.append(membername)

        classdict["_attributes"] = attributes
        classdict["_postprocessors"] = postprocessors
        if specstr:
            classdict["_struct"] = Struct(specstr)

        return type.__new__(mcs, name, bases, dict(classdict))
79
80
class NamedStruct(metaclass=NamedStructMeta):
    """
    Designed to be inherited from, similar to Enum.

    Specify all fields of the struct, as 'membername = specstr',
    where specstr is a string describing the field, as in struct.Struct.
    NamedStructMeta translates those individual specstr fragments to
    a complete specstr.

    Alternatively to a specstr, a callable object with a specstr member
    may be passed; the specstr is used as usual, but afterwards the
    callable is invoked to post-process the extracted data.
    One example for such a callable is the Flags class.

    Alternatively, attributes may be set to None; those are ignored,
    and may be set manually at some later point.

    The first member must be 'endianness'.

    Example:

    class MyStruct(NamedStruct):
        endianness = "<"

        mgck = "4s"
        test = "I"
        rofl = "H"
        flag = MyFlagType

    The constructor takes a bytes object of the appropriate length, and fills
    in all the members with the struct's actual values.

    Instances support len(), indexing and iteration over their field
    values, in declaration order.
    """

    # those values are set by the metaclass.
    _postprocessors = None
    _struct = None
    _attributes = None

    def __init__(self, data):
        """
        Unpacks data and stores every field as an attribute of its name.

        Raises NotImplementedError if this class has no struct layout
        (i.e. no 'endianness' member was declared).
        """
        if self._struct is None:
            raise NotImplementedError(
                "Abstract NamedStruct can not be instantiated")

        values = self._struct.unpack(data)

        # sanity check: a single specstr fragment may describe more than
        # one struct field (or none), which would break the 1:1 mapping
        # between attribute names and unpacked values.
        if len(self._attributes) != len(values):
            raise Exception("internal error: "
                            "number of attributes differs from number of "
                            "struct fields")

        for name, value in zip(self._attributes, values):
            # pylint: disable=unsupported-membership-test
            if name in self._postprocessors:
                # pylint: disable=unsubscriptable-object
                value = self._postprocessors[name](value)

            setattr(self, name, value)

    @classmethod
    def unpack(cls, data):
        """
        Unpacks data and returns a NamedStruct object that holds the fields.
        """
        return cls(data)

    @classmethod
    def size(cls):
        """
        Returns the size of the struct, in bytes.
        """
        return cls._struct.size

    @classmethod
    def read(cls, fileobj):
        """
        Reads the appropriate amount of data from fileobj, and unpacks it.
        """
        data = read_guaranteed(fileobj, cls._struct.size)
        return cls.unpack(data)

    @classmethod
    def from_nullbytes(cls):
        """
        Decodes nullbytes (sort of a 'default' value).
        """
        data = b"\x00" * cls._struct.size
        return cls.unpack(data)

    # nobody has needed .pack() and .write() functions so far; implement
    # them if you need them.

    def __len__(self):
        """
        Returns the number of fields.
        """
        return len(self._attributes)

    def __getitem__(self, index):
        """
        Returns the n-th field, or raises IndexError.
        """
        # pylint: disable=unsubscriptable-object
        return getattr(self, self._attributes[index])

    def as_dict(self):
        """
        Returns a key-value dict for all attributes.
        """
        # pylint: disable=not-an-iterable
        return {attr: getattr(self, attr) for attr in self._attributes}

    def __iter__(self):
        """
        Iterates over the field values, in declaration order.
        """
        # iter(self) would call this very method again and recurse
        # endlessly; yield the values explicitly, consistent with
        # __len__ and __getitem__.
        # pylint: disable=not-an-iterable
        return (getattr(self, attr) for attr in self._attributes)

    def __repr__(self):
        return str(type(self)) + ": " + repr(self.as_dict())

    def __str__(self):
        # one 'name = value' line per field, sorted by field name
        return type(self).__name__ + ":\n\t" + "\n\t".join(
            str(key).ljust(20) + " = " + str(value)
            for key, value in sorted(self.as_dict().items())
        )
203
204
class FlagsMeta(type):
    """
    Metaclass for Flags. Compare to NamedStructMeta.

    At class-creation time, every public integer member is interpreted as
    a bit number; the resulting mapping of bit mask to flag name is stored
    on the class as '_flags'.
    """
    def __new__(mcs, name, bases, classdict, **kwds):
        del kwds  # unused variable

        # the order of the flags is irrelevant here, so unlike
        # NamedStructMeta no ordered class namespace is needed.

        # bit mask (1 << bit number) -> flag name
        mask_to_name = {}
        has_specstr = False

        for member, content in classdict.items():
            # hidden members are never flags
            if member.startswith('_'):
                continue

            if member == "specstr":
                has_specstr = True

                if isinstance(content, str):
                    continue

                raise TypeError(
                    "expected str as value for specstr, "
                    "but got " + repr(content))

            # methods are not flags either
            if callable(content) or isinstance(content, classmethod):
                continue

            if not isinstance(content, int):
                raise TypeError(
                    "expected int as value for flag " + member + ", "
                    "but got " + repr(content))

            mask_to_name[1 << content] = member

        # a class that declares flags must also say how they are stored
        if mask_to_name and not has_specstr:
            raise Exception("expected a 'specstr' attribute")

        classdict["_flags"] = mask_to_name

        return type.__new__(mcs, name, bases, classdict)
250
251
class Flags(metaclass=FlagsMeta):
    """
    Designed to be inherited from, similar to Enum.

    Used to generate flag parsers (for boolean flags that
    are stored in an integer value).

    Specify the bit numbers of all possible flags as attributes,
    together with a 'specstr' string attribute (the metaclass refuses
    classes that declare flags without one), e.g.:

    class MyFlags(Flags):
        specstr = "B"

        thisflag = 0
        thatflag = 1

    The constructor of the class takes an integer argument,
    which is parsed; all the boolean values are stored in the
    attributes.
    If any unknown bits are set, self.unknown() is called.
    """

    # set by the metaclass
    _flags = None

    def __init__(self, val):
        # bits that have not been matched to a known flag yet
        remaining = val

        for mask, name in self._flags.items():
            if remaining & mask:
                setattr(self, name, True)
                # clear the bit, so only unknown bits are left at the end
                remaining &= ~mask
                continue

            setattr(self, name, False)

        if remaining:
            self.unknown(remaining)

    def unknown(self, unknownflags):
        """
        Default handler for any unknown bits. Overload if needed.

        This default implementation raises ValueError.
        """
        known = str(self.as_dict())
        raise ValueError(
            "unknown flag values: " + bin(unknownflags) + " "
            "in addition to existing flags: " + known)

    def as_dict(self):
        """
        Returns a key-value dict for all flags.
        """
        result = {}
        for name in self._flags.values():
            result[name] = getattr(self, name)
        return result

    def __repr__(self):
        return "{!r}: {!r}".format(type(self), self.as_dict())
303