1# -*- coding: utf-8 -*-
2# Copyright (C) 2005-2006  Joe Wreschnig
3# Copyright (C) 2006-2007  Lukas Lalinsky
4#
5# This program is free software; you can redistribute it and/or modify
6# it under the terms of the GNU General Public License as published by
7# the Free Software Foundation; either version 2 of the License, or
8# (at your option) any later version.
9
10import sys
11import struct
12
13from mutagen._compat import swap_to_string, text_type, PY2, reraise
14from mutagen._util import total_ordering
15
16from ._util import ASFError
17
18
class ASFBaseAttribute(object):
    """Generic attribute.

    Base class for the typed ASF attribute values. Each subclass sets
    ``TYPE`` to its numeric type code and provides ``parse`` (raw bytes ->
    Python value), ``_render`` (Python value -> raw bytes), ``_validate``
    and ``data_size``. Subclasses are recorded in a shared
    ``TYPE -> class`` registry via the ``_register`` class decorator.
    """

    TYPE = None

    # Shared registry mapping a type code to its attribute subclass; filled
    # in by ``_register`` and read by ``_get_type``.
    _TYPES = {}

    value = None
    """The Python value of this attribute (type depends on the class)"""

    language = None
    """Language"""

    stream = None
    """Stream"""

    def __init__(self, value=None, data=None, language=None,
                 stream=None, **kwargs):
        """Create an attribute from a Python ``value`` or raw ``data``.

        If ``data`` is given it is decoded with ``parse`` (extra keyword
        arguments are forwarded to it) and ``value`` is ignored. Otherwise
        a non-None ``value`` is checked/converted by ``_validate``.
        """
        self.language = language
        self.stream = stream
        if data is not None:
            self.value = self.parse(data, **kwargs)
        else:
            if value is None:
                # we used to support not passing any args and instead assign
                # them later, keep that working..
                self.value = None
            else:
                self.value = self._validate(value)

    @classmethod
    def _register(cls, other):
        """Class decorator: store ``other`` in the registry under its
        ``TYPE`` and return it unchanged."""
        cls._TYPES[other.TYPE] = other
        return other

    @classmethod
    def _get_type(cls, type_):
        """Return the subclass registered for ``type_``.

        Raises KeyError if no subclass is registered for that type code.
        """
        return cls._TYPES[type_]

    def _validate(self, value):
        """Raises TypeError or ValueError in case the user supplied value
        isn't valid.

        Returns the (possibly converted) value to store; the base
        implementation accepts any value unchanged.
        """
        return value

    def data_size(self):
        """Return the size in bytes of the rendered value.

        Subclasses must override this.
        """
        raise NotImplementedError

    def __repr__(self):
        name = "%s(%r" % (type(self).__name__, self.value)
        if self.language:
            name += ", language=%d" % self.language
        if self.stream:
            name += ", stream=%d" % self.stream
        name += ")"
        return name

    def render(self, name):
        """Serialize as a record with 16-bit length fields.

        Layout: name length, UTF-16-LE NUL-terminated name, type code,
        value length, value bytes.
        """
        name = name.encode("utf-16-le") + b"\x00\x00"
        data = self._render()
        return (struct.pack("<H", len(name)) + name +
                struct.pack("<HH", self.TYPE, len(data)) + data)

    def _render_described(self, name, language):
        """Shared serializer for ``render_m`` and ``render_ml``.

        Layout: language index, stream number, name length, type code
        (all 16-bit), 32-bit value length, then the UTF-16-LE
        NUL-terminated name and the value bytes.
        """
        name = name.encode("utf-16-le") + b"\x00\x00"
        if self.TYPE == 2:
            # Type 2 (bool) is written in its 16-bit form in these records
            # rather than the default 32-bit form used by ``render``.
            data = self._render(dword=False)
        else:
            data = self._render()
        return (struct.pack("<HHHHI", language, self.stream or 0, len(name),
                            self.TYPE, len(data)) + name + data)

    def render_m(self, name):
        """Serialize with the language index always written as 0."""
        return self._render_described(name, 0)

    def render_ml(self, name):
        """Serialize including this attribute's language index."""
        return self._render_described(name, self.language or 0)
103
104
@ASFBaseAttribute._register
@swap_to_string
@total_ordering
class ASFUnicodeAttribute(ASFBaseAttribute):
    """Unicode string attribute.

    ::

        ASFUnicodeAttribute(u'some text')
    """

    TYPE = 0x0000

    def parse(self, data):
        """Decode UTF-16-LE ``data``, stripping NUL characters from both
        ends; decoding errors are re-raised as ASFError."""
        try:
            return data.decode("utf-16-le").strip("\x00")
        except UnicodeDecodeError as e:
            reraise(ASFError, e, sys.exc_info()[2])

    def _validate(self, value):
        """Return ``value`` as text.

        Under Python 2, byte strings are also accepted and decoded as
        UTF-8 (a UnicodeDecodeError, which is a ValueError, propagates on
        failure). Any other non-text value raises TypeError.
        """
        if isinstance(value, text_type):
            return value
        if PY2 and isinstance(value, (bytes, bytearray)):
            return value.decode("utf-8")
        # Wrap in a tuple so tuple values don't break the %-formatting.
        raise TypeError("%r not str" % (value,))

    def _render(self):
        # UTF-16-LE text followed by a 2-byte NUL terminator.
        return self.value.encode("utf-16-le") + b"\x00\x00"

    def data_size(self):
        return len(self._render())

    def __bytes__(self):
        return self.value.encode("utf-16-le")

    def __str__(self):
        return self.value

    def __eq__(self, other):
        return text_type(self) == other

    def __lt__(self, other):
        return text_type(self) < other

    __hash__ = ASFBaseAttribute.__hash__
151
152
@ASFBaseAttribute._register
@swap_to_string
@total_ordering
class ASFByteArrayAttribute(ASFBaseAttribute):
    """Byte array attribute.

    ::

        ASFByteArrayAttribute(b'1234')
    """
    TYPE = 0x0001

    def parse(self, data):
        """Return the raw bytes unchanged."""
        assert isinstance(data, bytes)
        return data

    def _render(self):
        assert isinstance(self.value, bytes)
        return self.value

    def _validate(self, value):
        """Accept only ``bytes``; raise TypeError otherwise."""
        if not isinstance(value, bytes):
            # Wrap in a tuple so tuple values don't break the %-formatting.
            raise TypeError("must be bytes/str: %r" % (value,))
        return value

    def data_size(self):
        return len(self.value)

    def __bytes__(self):
        return self.value

    def __str__(self):
        return "[binary data (%d bytes)]" % len(self.value)

    def __eq__(self, other):
        return self.value == other

    def __lt__(self, other):
        return self.value < other

    __hash__ = ASFBaseAttribute.__hash__
194
195
@ASFBaseAttribute._register
@swap_to_string
@total_ordering
class ASFBoolAttribute(ASFBaseAttribute):
    """Bool attribute.

    ::

        ASFBoolAttribute(True)
    """

    TYPE = 0x0002

    def parse(self, data, dword=True):
        """Decode a little-endian 32-bit (``dword=True``) or 16-bit flag.

        Only the exact stored value 1 is treated as True.
        """
        if dword:
            return struct.unpack("<I", data)[0] == 1
        else:
            return struct.unpack("<H", data)[0] == 1

    def _render(self, dword=True):
        """Encode the flag as a little-endian 32-bit or 16-bit 0/1 value."""
        if dword:
            return struct.pack("<I", bool(self.value))
        else:
            return struct.pack("<H", bool(self.value))

    def _validate(self, value):
        return bool(value)

    def data_size(self):
        # Size of the default 32-bit rendering. The 16-bit form
        # (dword=False) is 2 bytes; callers that render it measure the
        # rendered bytes instead of using this.
        return 4

    def __bool__(self):
        return bool(self.value)

    if PY2:
        # Python 2 truth testing uses __nonzero__, not __bool__; without
        # this alias every instance (even False ones) would be truthy.
        __nonzero__ = __bool__

    def __bytes__(self):
        return text_type(self.value).encode('utf-8')

    def __str__(self):
        return text_type(self.value)

    def __eq__(self, other):
        return bool(self.value) == other

    def __lt__(self, other):
        return bool(self.value) < other

    __hash__ = ASFBaseAttribute.__hash__
243
244
@ASFBaseAttribute._register
@swap_to_string
@total_ordering
class ASFDWordAttribute(ASFBaseAttribute):
    """DWORD attribute.

    ::

        ASFDWordAttribute(42)
    """

    TYPE = 0x0003

    def parse(self, data):
        """Decode an unsigned little-endian 32-bit integer."""
        (number,) = struct.unpack("<L", data)
        return number

    def _render(self):
        """Encode the value as an unsigned little-endian 32-bit integer."""
        packed = struct.pack("<L", self.value)
        return packed

    def _validate(self, value):
        """Coerce to int and require it to fit in 32 unsigned bits."""
        number = int(value)
        if number < 0 or number > 0xFFFFFFFF:
            raise ValueError("Out of range")
        return number

    def data_size(self):
        return 4

    def __int__(self):
        return self.value

    def __bytes__(self):
        text = text_type(self.value)
        return text.encode('utf-8')

    def __str__(self):
        return text_type(self.value)

    def __eq__(self, other):
        own = int(self.value)
        return own == other

    def __lt__(self, other):
        own = int(self.value)
        return own < other

    __hash__ = ASFBaseAttribute.__hash__
289
290
@ASFBaseAttribute._register
@swap_to_string
@total_ordering
class ASFQWordAttribute(ASFBaseAttribute):
    """QWORD attribute.

    ::

        ASFQWordAttribute(42)
    """

    TYPE = 0x0004

    def parse(self, data):
        """Decode an unsigned little-endian 64-bit integer."""
        (number,) = struct.unpack("<Q", data)
        return number

    def _render(self):
        """Encode the value as an unsigned little-endian 64-bit integer."""
        packed = struct.pack("<Q", self.value)
        return packed

    def _validate(self, value):
        """Coerce to int and require it to fit in 64 unsigned bits."""
        number = int(value)
        if number < 0 or number > 0xFFFFFFFFFFFFFFFF:
            raise ValueError("Out of range")
        return number

    def data_size(self):
        return 8

    def __int__(self):
        return self.value

    def __bytes__(self):
        text = text_type(self.value)
        return text.encode('utf-8')

    def __str__(self):
        return text_type(self.value)

    def __eq__(self, other):
        own = int(self.value)
        return own == other

    def __lt__(self, other):
        own = int(self.value)
        return own < other

    __hash__ = ASFBaseAttribute.__hash__
335
336
@ASFBaseAttribute._register
@swap_to_string
@total_ordering
class ASFWordAttribute(ASFBaseAttribute):
    """WORD attribute.

    ::

        ASFWordAttribute(42)
    """

    TYPE = 0x0005

    def parse(self, data):
        """Decode an unsigned little-endian 16-bit integer."""
        (number,) = struct.unpack("<H", data)
        return number

    def _render(self):
        """Encode the value as an unsigned little-endian 16-bit integer."""
        packed = struct.pack("<H", self.value)
        return packed

    def _validate(self, value):
        """Coerce to int and require it to fit in 16 unsigned bits."""
        number = int(value)
        if number < 0 or number > 0xFFFF:
            raise ValueError("Out of range")
        return number

    def data_size(self):
        return 2

    def __int__(self):
        return self.value

    def __bytes__(self):
        text = text_type(self.value)
        return text.encode('utf-8')

    def __str__(self):
        return text_type(self.value)

    def __eq__(self, other):
        own = int(self.value)
        return own == other

    def __lt__(self, other):
        own = int(self.value)
        return own < other

    __hash__ = ASFBaseAttribute.__hash__
381
382
@ASFBaseAttribute._register
@swap_to_string
@total_ordering
class ASFGUIDAttribute(ASFBaseAttribute):
    """GUID attribute.

    The value is kept as raw bytes; its length is not checked here.
    """

    TYPE = 0x0006

    def parse(self, data):
        """Return the raw bytes unchanged."""
        assert isinstance(data, bytes)
        return data

    def _render(self):
        assert isinstance(self.value, bytes)
        return self.value

    def _validate(self, value):
        """Accept only ``bytes``; raise TypeError otherwise."""
        if not isinstance(value, bytes):
            # Wrap in a tuple so tuple values don't break the %-formatting.
            raise TypeError("must be bytes/str: %r" % (value,))
        return value

    def data_size(self):
        return len(self.value)

    def __bytes__(self):
        return self.value

    def __str__(self):
        return repr(self.value)

    def __eq__(self, other):
        return self.value == other

    def __lt__(self, other):
        return self.value < other

    __hash__ = ASFBaseAttribute.__hash__
420
421
def ASFValue(value, kind, **kwargs):
    """Create a tag value of a specific kind.

    ::

        ASFValue(u"My Value", UNICODE)

    ``kind`` is an attribute type code. The attribute class registered for
    it is instantiated with ``value`` plus any extra keyword arguments
    (e.g. ``language`` or ``stream``).

    :rtype: ASFBaseAttribute
    :raises TypeError: in case a wrong type was passed
    :raises ValueError: in case the value can't be represented as ASFValue,
        or ``kind`` is not a registered type code.
    """

    lookup = ASFBaseAttribute._get_type
    try:
        attr_cls = lookup(kind)
    except KeyError:
        raise ValueError("Unknown value type %r" % kind)
    return attr_cls(value=value, **kwargs)
440