1from __future__ import absolute_import
2
3import struct
4from struct import error
5
6from kafka.protocol.abstract import AbstractType
7
8
9def _pack(f, value):
10    try:
11        return f(value)
12    except error as e:
13        raise ValueError("Error encountered when attempting to convert value: "
14                        "{!r} to struct format: '{}', hit error: {}"
15                        .format(value, f, e))
16
17
18def _unpack(f, data):
19    try:
20        (value,) = f(data)
21        return value
22    except error as e:
23        raise ValueError("Error encountered when attempting to convert value: "
24                        "{!r} to struct format: '{}', hit error: {}"
25                        .format(data, f, e))
26
27
class Int8(AbstractType):
    """Signed 8-bit integer, big-endian struct format '>b'."""
    _pack = struct.Struct('>b').pack
    _unpack = struct.Struct('>b').unpack

    @classmethod
    def decode(cls, data):
        """Read one byte via ``data.read(1)`` and return it as a signed int.

        Raises ValueError (from the module-level ``_unpack`` helper) when
        the read does not yield exactly one byte.
        """
        raw = data.read(1)
        return _unpack(cls._unpack, raw)

    @classmethod
    def encode(cls, value):
        """Return ``value`` packed as a single signed byte."""
        packer = cls._pack
        return _pack(packer, value)
39
40
class Int16(AbstractType):
    """Signed 16-bit integer, big-endian struct format '>h'."""
    _pack = struct.Struct('>h').pack
    _unpack = struct.Struct('>h').unpack

    @classmethod
    def decode(cls, data):
        """Read two bytes via ``data.read(2)`` and return them as a signed int.

        Raises ValueError (from the module-level ``_unpack`` helper) when
        the read does not yield exactly two bytes.
        """
        raw = data.read(2)
        return _unpack(cls._unpack, raw)

    @classmethod
    def encode(cls, value):
        """Return ``value`` packed as two big-endian bytes."""
        packer = cls._pack
        return _pack(packer, value)
52
53
class Int32(AbstractType):
    """Signed 32-bit integer, big-endian struct format '>i'."""
    _pack = struct.Struct('>i').pack
    _unpack = struct.Struct('>i').unpack

    @classmethod
    def decode(cls, data):
        """Read four bytes via ``data.read(4)`` and return them as a signed int.

        Raises ValueError (from the module-level ``_unpack`` helper) when
        the read does not yield exactly four bytes.
        """
        raw = data.read(4)
        return _unpack(cls._unpack, raw)

    @classmethod
    def encode(cls, value):
        """Return ``value`` packed as four big-endian bytes."""
        packer = cls._pack
        return _pack(packer, value)
65
66
class Int64(AbstractType):
    """Signed 64-bit integer, big-endian struct format '>q'."""
    _pack = struct.Struct('>q').pack
    _unpack = struct.Struct('>q').unpack

    @classmethod
    def decode(cls, data):
        """Read eight bytes via ``data.read(8)`` and return them as a signed int.

        Raises ValueError (from the module-level ``_unpack`` helper) when
        the read does not yield exactly eight bytes.
        """
        raw = data.read(8)
        return _unpack(cls._unpack, raw)

    @classmethod
    def encode(cls, value):
        """Return ``value`` packed as eight big-endian bytes."""
        packer = cls._pack
        return _pack(packer, value)
78
79
class String(AbstractType):
    """Text prefixed by its encoded byte length as an Int16.

    A length prefix of -1 encodes ``None``.
    """
    def __init__(self, encoding='utf-8'):
        # Codec used to convert between text and wire bytes.
        self.encoding = encoding

    def encode(self, value):
        """Return the Int16 length prefix followed by the encoded bytes.

        ``None`` is written as a bare -1 length. ``bytes`` values are treated
        as already encoded and written unchanged; anything else is converted
        with ``str()`` and encoded with ``self.encoding``.

        Raises:
            ValueError: if the encoded length does not fit in an Int16
                (raised by ``Int16.encode``).
        """
        if value is None:
            return Int16.encode(-1)
        if isinstance(value, bytes):
            # On Python 3, str(b'abc') is the text "b'abc'", which would put
            # the repr on the wire instead of the intended payload.
            encoded = value
        else:
            encoded = str(value).encode(self.encoding)
        return Int16.encode(len(encoded)) + encoded

    def decode(self, data):
        """Read an Int16 length and that many bytes from ``data``; return text.

        Returns ``None`` for a negative length.

        Raises:
            ValueError: if fewer than ``length`` bytes are available.
        """
        length = Int16.decode(data)
        if length < 0:
            return None
        value = data.read(length)
        if len(value) != length:
            raise ValueError('Buffer underrun decoding string')
        return value.decode(self.encoding)
98
99
class Bytes(AbstractType):
    """Raw bytes prefixed by their length as an Int32; -1 encodes ``None``."""

    @classmethod
    def encode(cls, value):
        """Return the Int32 length prefix followed by ``value`` itself.

        ``None`` is written as a bare -1 length.
        """
        if value is not None:
            return Int32.encode(len(value)) + value
        return Int32.encode(-1)

    @classmethod
    def decode(cls, data):
        """Read an Int32 length and that many bytes from ``data``.

        Returns ``None`` for a negative length.

        Raises:
            ValueError: if fewer than ``length`` bytes are available.
        """
        size = Int32.decode(data)
        if size < 0:
            return None
        payload = data.read(size)
        if len(payload) == size:
            return payload
        raise ValueError('Buffer underrun decoding Bytes')

    @classmethod
    def repr(cls, value):
        """Return ``repr(value)``, truncated to 100 bytes plus b'...' when longer."""
        if value is not None and len(value) > 100:
            value = value[:100] + b'...'
        return repr(value)
121
122
class Boolean(AbstractType):
    """Single-byte boolean, struct format '>?'."""
    _pack = struct.Struct('>?').pack
    _unpack = struct.Struct('>?').unpack

    @classmethod
    def decode(cls, data):
        """Read one byte via ``data.read(1)`` and return it as a bool.

        Raises ValueError (from the module-level ``_unpack`` helper) when
        the read does not yield exactly one byte.
        """
        raw = data.read(1)
        return _unpack(cls._unpack, raw)

    @classmethod
    def encode(cls, value):
        """Return the truth value of ``value`` packed as a single byte."""
        packer = cls._pack
        return _pack(packer, value)
134
135
class Schema(AbstractType):
    """Ordered group of named field types encoded back-to-back.

    Constructed from ``(name, type)`` pairs; the names and types are kept
    as parallel tuples in ``self.names`` and ``self.fields``.
    """
    def __init__(self, *fields):
        if not fields:
            self.names, self.fields = (), ()
            return
        # Transpose [(name, type), ...] into (names...) and (types...).
        self.names, self.fields = zip(*fields)

    def encode(self, item):
        """Encode ``item[i]`` with ``self.fields[i]`` for every field and concatenate.

        Raises:
            ValueError: if ``item`` does not have exactly one element per field.
        """
        field_count = len(self.fields)
        if field_count != len(item):
            raise ValueError('Item field count does not match Schema')
        return b''.join([self.fields[idx].encode(item[idx])
                         for idx in range(field_count)])

    def decode(self, data):
        """Decode one value per field, in declaration order, into a tuple."""
        return tuple([field.decode(data) for field in self.fields])

    def __len__(self):
        return len(self.fields)

    def repr(self, value):
        """Render ``value`` as ``(name=repr, ...)``.

        Each field is looked up first as an attribute of ``value`` and, if
        that attribute is missing, by position. Any failure while rendering
        falls back to plain ``repr(value)``.
        """
        try:
            parts = []
            for idx, (name, field) in enumerate(zip(self.names, self.fields)):
                try:
                    field_val = getattr(value, name)
                except AttributeError:
                    field_val = value[idx]
                parts.append('%s=%s' % (name, field.repr(field_val)))
        except Exception:
            return repr(value)
        return '(' + ', '.join(parts) + ')'
169
170
class Array(AbstractType):
    """Int32 element count followed by that many encoded elements.

    A count of -1 encodes ``None``.
    """
    def __init__(self, *array_of):
        """Choose the element type.

        Several ``(name, type)`` pairs, or a single such pair, are wrapped in
        a ``Schema``. A single ``AbstractType`` instance or subclass is used
        directly.

        Raises:
            ValueError: if no usable element type was supplied.
        """
        if len(array_of) > 1:
            self.array_of = Schema(*array_of)
        elif len(array_of) == 1:
            element = array_of[0]
            # Guard issubclass() with a type check: calling it on a non-class
            # raises TypeError instead of the intended ValueError.
            if isinstance(element, AbstractType) or (
                    isinstance(element, type) and issubclass(element, AbstractType)):
                self.array_of = element
            elif isinstance(element, tuple):
                # A single (name, type) field: build a one-field Schema, just
                # like the multi-field branch above.
                self.array_of = Schema(element)
            else:
                raise ValueError('Array instantiated with no array_of type')
        else:
            raise ValueError('Array instantiated with no array_of type')

    def encode(self, items):
        """Return the Int32 count followed by each encoded item.

        ``None`` is written as a bare -1 count.
        """
        if items is None:
            return Int32.encode(-1)
        return b''.join(
            [Int32.encode(len(items))] +
            [self.array_of.encode(item) for item in items]
        )

    def decode(self, data):
        """Read an Int32 count, then decode that many elements from ``data``.

        Returns ``None`` for any negative count.
        """
        length = Int32.decode(data)
        # Treat every negative count as null, matching String/Bytes; a count
        # such as -2 would otherwise silently decode as an empty list.
        if length < 0:
            return None
        return [self.array_of.decode(data) for _ in range(length)]

    def repr(self, list_of_items):
        """Render the items as ``[...]`` using the element type's repr; ``None`` -> 'NULL'."""
        if list_of_items is None:
            return 'NULL'
        return '[' + ', '.join([self.array_of.repr(item) for item in list_of_items]) + ']'
199