1"""Implements (a subset of) Sun XDR -- eXternal Data Representation.
2
3See: RFC 1014
4
5"""
6
7import struct
8from io import BytesIO
9from functools import wraps
10
# Names exported by a star-import of this module.
__all__ = [
    "Error",
    "Packer",
    "Unpacker",
    "ConversionError",
]
12
# exceptions
class Error(Exception):
    """Base class for the exceptions raised by this module.

    Usage::

        except xdrlib.Error as var:
            # var holds the Error instance that was raised

    Public ivars:
        msg -- the message supplied to the constructor
    """

    def __init__(self, msg):
        # Keep the message on the instance so handlers can read it directly.
        self.msg = msg

    def __str__(self):
        return str(self.msg)

    def __repr__(self):
        return repr(self.msg)
30
31
class ConversionError(Error):
    """Raised when a value cannot be packed or unpacked data is malformed."""
34
def raise_conversion_error(function):
    """Decorate a ``method(self, value)`` so struct errors become ConversionError.

    Any ``struct.error`` raised by the wrapped method is re-raised as
    ``ConversionError`` with the same message; the original exception
    context is suppressed. Other exceptions propagate unchanged.
    """

    @wraps(function)
    def wrapper(self, value):
        try:
            return function(self, value)
        except struct.error as err:
            raise ConversionError(err.args[0]) from None

    return wrapper
45
46
class Packer:
    """Pack various data representations into a buffer.

    Every ``pack_*`` method appends its big-endian XDR encoding to an
    internal buffer; retrieve the accumulated bytes with ``get_buffer()``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Discard everything packed so far and start a fresh buffer."""
        self.__buf = BytesIO()

    def get_buffer(self):
        """Return the bytes packed so far."""
        return self.__buf.getvalue()
    # backwards compatibility
    get_buf = get_buffer

    @raise_conversion_error
    def pack_uint(self, x):
        """Append *x* as a 4-byte unsigned integer (ConversionError if out of range)."""
        self.__buf.write(struct.pack('>L', x))

    @raise_conversion_error
    def pack_int(self, x):
        """Append *x* as a 4-byte signed integer (ConversionError if out of range)."""
        self.__buf.write(struct.pack('>l', x))

    pack_enum = pack_int

    def pack_bool(self, x):
        """Append the truth value of *x* as the 4-byte integer 1 or 0."""
        self.__buf.write(b'\0\0\0\1' if x else b'\0\0\0\0')

    @raise_conversion_error
    def pack_uhyper(self, x):
        """Append *x* as an 8-byte unsigned integer.

        Raises ConversionError unless *x* is an integer in [0, 2**64).
        """
        # Let struct range-check instead of masking with 0xffffffff: masking
        # silently turned e.g. 2**64 into 0 and accepted negative values.
        self.__buf.write(struct.pack('>Q', x))

    @raise_conversion_error
    def pack_hyper(self, x):
        """Append *x* as an 8-byte signed integer.

        Raises ConversionError unless *x* is an integer in [-2**63, 2**63).
        """
        # A separate signed format: sharing the unsigned encoder let values
        # >= 2**63 through, which then unpacked as negative numbers.
        self.__buf.write(struct.pack('>q', x))

    @raise_conversion_error
    def pack_float(self, x):
        """Append *x* as a 4-byte IEEE single-precision float."""
        self.__buf.write(struct.pack('>f', x))

    @raise_conversion_error
    def pack_double(self, x):
        """Append *x* as an 8-byte IEEE double-precision float."""
        self.__buf.write(struct.pack('>d', x))

    def pack_fstring(self, n, s):
        """Append the first *n* bytes of *s*, zero-padded to a multiple of 4.

        If *s* is shorter than *n*, the missing bytes are zero-filled too.
        Raises ValueError if *n* is negative.
        """
        if n < 0:
            raise ValueError('fstring size must be nonnegative')
        data = s[:n]
        # XDR items occupy a whole number of 4-byte units.
        padded = ((n + 3) // 4) * 4
        data = data + (padded - len(data)) * b'\0'
        self.__buf.write(data)

    pack_fopaque = pack_fstring

    def pack_string(self, s):
        """Append *s* as a 4-byte length followed by its padded bytes."""
        n = len(s)
        self.pack_uint(n)
        self.pack_fstring(n, s)

    pack_opaque = pack_string
    pack_bytes = pack_string

    def pack_list(self, list, pack_item):
        """Append each item preceded by the flag 1, then a terminating 0.

        *pack_item* is called once per item to encode it. (The parameter
        name shadows the builtin but is kept for keyword callers.)
        """
        for item in list:
            self.pack_uint(1)
            pack_item(item)
        self.pack_uint(0)

    def pack_farray(self, n, list, pack_item):
        """Append exactly *n* items with *pack_item*, with no length prefix.

        Raises ValueError if ``len(list) != n``.
        """
        if len(list) != n:
            raise ValueError('wrong array size')
        for item in list:
            pack_item(item)

    def pack_array(self, list, pack_item):
        """Append the item count as a uint, then each item via *pack_item*."""
        n = len(list)
        self.pack_uint(n)
        self.pack_farray(n, list, pack_item)
130
131
class Unpacker:
    """Unpacks various data representations from the given buffer.

    The buffer is read sequentially; each ``unpack_*`` method consumes its
    item and advances the current position past it.
    """

    def __init__(self, data):
        self.reset(data)

    def reset(self, data):
        """Replace the buffer with *data* and rewind to its start."""
        self.__buf = data
        self.__pos = 0

    def get_position(self):
        """Return the offset of the next byte to be read."""
        return self.__pos

    def set_position(self, position):
        """Move the read offset to *position* (not validated)."""
        self.__pos = position

    def get_buffer(self):
        """Return the buffer being unpacked."""
        return self.__buf

    def done(self):
        """Raise Error if bytes remain after the current position."""
        if self.__pos < len(self.__buf):
            raise Error('unextracted data remains')

    def _read_exact(self, size):
        """Consume and return the next *size* bytes.

        Raises EOFError if fewer than *size* bytes remain. The position is
        only advanced on success, so a failed read leaves the unpacker
        where it was (previously it moved past the end of the data).
        """
        start = self.__pos
        end = start + size
        data = self.__buf[start:end]
        if len(data) < size:
            raise EOFError
        self.__pos = end
        return data

    def unpack_uint(self):
        """Read a 4-byte unsigned integer."""
        return struct.unpack('>L', self._read_exact(4))[0]

    def unpack_int(self):
        """Read a 4-byte signed integer."""
        return struct.unpack('>l', self._read_exact(4))[0]

    unpack_enum = unpack_int

    def unpack_bool(self):
        """Read a 4-byte integer; any nonzero value is returned as True."""
        return bool(self.unpack_int())

    def unpack_uhyper(self):
        """Read an 8-byte unsigned integer."""
        # One 8-byte read keeps the operation atomic on a truncated buffer.
        return struct.unpack('>Q', self._read_exact(8))[0]

    def unpack_hyper(self):
        """Read an 8-byte two's-complement signed integer."""
        return struct.unpack('>q', self._read_exact(8))[0]

    def unpack_float(self):
        """Read a 4-byte IEEE single-precision float."""
        return struct.unpack('>f', self._read_exact(4))[0]

    def unpack_double(self):
        """Read an 8-byte IEEE double-precision float."""
        return struct.unpack('>d', self._read_exact(8))[0]

    def unpack_fstring(self, n):
        """Read *n* bytes of data that occupy a multiple of 4 bytes.

        The padding after the *n* bytes is skipped. Raises ValueError if
        *n* is negative, EOFError if the padded item runs past the buffer.
        """
        if n < 0:
            raise ValueError('fstring size must be nonnegative')
        i = self.__pos
        j = i + (n+3)//4*4
        if j > len(self.__buf):
            raise EOFError
        self.__pos = j
        return self.__buf[i:i+n]

    unpack_fopaque = unpack_fstring

    def unpack_string(self):
        """Read a uint length followed by that many (padded) bytes."""
        n = self.unpack_uint()
        return self.unpack_fstring(n)

    unpack_opaque = unpack_string
    unpack_bytes = unpack_string

    def unpack_list(self, unpack_item):
        """Read items each preceded by flag 1, until a terminating flag 0.

        Raises ConversionError if a flag is neither 0 nor 1.
        """
        items = []
        while True:
            flag = self.unpack_uint()
            if flag == 0:
                return items
            if flag != 1:
                raise ConversionError('0 or 1 expected, got %r' % (flag,))
            items.append(unpack_item())

    def unpack_farray(self, n, unpack_item):
        """Read exactly *n* items by calling *unpack_item* *n* times."""
        return [unpack_item() for _ in range(n)]

    def unpack_array(self, unpack_item):
        """Read a uint item count, then that many items."""
        n = self.unpack_uint()
        return self.unpack_farray(n, unpack_item)
242