1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10#   http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18#
19
20from struct import pack, unpack
21from thrift.Thrift import TException
22from ..compat import BufferIO
23
24
class TTransportException(TException):
    """Exception raised by Thrift transports.

    ``type`` holds one of the integer codes declared on this class and
    ``message`` is forwarded to ``TException``.
    """

    # Integer codes that may be stored in ``self.type``.
    UNKNOWN = 0
    NOT_OPEN = 1
    ALREADY_OPEN = 2
    TIMED_OUT = 3
    END_OF_FILE = 4
    NEGATIVE_SIZE = 5
    SIZE_LIMIT = 6
    INVALID_CLIENT_TYPE = 7

    def __init__(self, type=UNKNOWN, message=None):
        # ``type`` shadows the builtin, but it is part of the public
        # signature (callers may pass it by keyword), so it is kept.
        TException.__init__(self, message)
        self.type = type
40
41
class TTransportBase(object):
    """Base class for Thrift transport layer.

    The lifecycle and I/O hooks below are no-ops meant to be overridden by
    concrete transports; ``readAll`` is implemented on top of ``read``.
    """

    def isOpen(self):
        pass

    def open(self):
        pass

    def close(self):
        pass

    def read(self, sz):
        pass

    def readAll(self, sz):
        """Read exactly ``sz`` bytes by calling ``read`` until satisfied.

        :param sz: number of bytes to read; a value <= 0 returns ``b''``.
        :returns: the bytes read, of length ``sz``.
        :raises EOFError: if ``read`` returns an empty chunk before ``sz``
            bytes have been collected.
        """
        chunks = []
        have = 0
        while have < sz:
            chunk = self.read(sz - have)
            chunk_len = len(chunk)
            if chunk_len == 0:
                # The underlying transport has no more data to give.
                raise EOFError()
            chunks.append(chunk)
            have += chunk_len
        # Join once at the end; repeated ``bytes +=`` is quadratic.
        return b''.join(chunks)

    def write(self, buf):
        pass

    def flush(self):
        pass
76
77
78# This class should be thought of as an interface.
class CReadableTransport(object):
    """Interface for transports whose read buffer can be consumed from C.

    Implementations expose their current read buffer through
    ``cstringio_buf`` and rebuild it on request through ``cstringio_refill``.
    The placeholders defined here simply return ``None``.
    """

    # TODO(dreiss): Think about changing this interface to allow us to use
    #               a (Python, not c) StringIO instead, because it allows
    #               you to write after reading.

    @property
    def cstringio_buf(self):
        """A cStringIO buffer that contains the current chunk we are reading."""
        return None

    def cstringio_refill(self, partialread, reqlen):
        """Refills cstringio_buf.

        Returns the currently used buffer (which can but need not be the same as
        the old cstringio_buf). partialread is what the C code has read from the
        buffer, and should be inserted into the buffer before any more reads.  The
        return value must be a new, not borrowed reference.  Something along the
        lines of self._buf should be fine.

        If reqlen bytes can't be read, throw EOFError.
        """
        return None
105
106
class TServerTransportBase(object):
    """Base class for Thrift server transports.

    Every hook is a placeholder returning ``None``; concrete server
    transports override them.
    """

    def listen(self):
        return None

    def accept(self):
        return None

    def close(self):
        return None
118
119
class TTransportFactoryBase(object):
    """Base class for a Transport Factory.

    The default factory performs no wrapping: ``getTransport`` hands back
    the very object it was given.
    """

    def getTransport(self, trans):
        unchanged = trans
        return unchanged
125
126
class TBufferedTransportFactory(object):
    """Transport factory whose products are ``TBufferedTransport`` wrappers."""

    def getTransport(self, trans):
        """Return a new ``TBufferedTransport`` wrapping ``trans``."""
        return TBufferedTransport(trans)
133
134
class TBufferedTransport(TTransportBase, CReadableTransport):
    """Class that wraps another transport and buffers its I/O.

    The implementation uses a (configurable) fixed-size read buffer
    but buffers all writes until a flush is performed.
    """
    DEFAULT_BUFFER = 4096

    def __init__(self, trans, rbuf_size=DEFAULT_BUFFER):
        """
        trans: the underlying transport being wrapped
        rbuf_size: minimum number of bytes requested from ``trans`` whenever
                   the read buffer is refilled
        """
        self.__trans = trans
        self.__wbuf = BufferIO()
        # Start with an empty read buffer; the first read() refills it.
        self.__rbuf = BufferIO(b'')
        self.__rbuf_size = rbuf_size

    def isOpen(self):
        return self.__trans.isOpen()

    def open(self):
        return self.__trans.open()

    def close(self):
        return self.__trans.close()

    def read(self, sz):
        """Return up to ``sz`` bytes, refilling from ``trans`` when empty.

        A refill issues a single ``trans.read(max(sz, rbuf_size))`` call, so
        the result may be shorter than ``sz``.
        """
        ret = self.__rbuf.read(sz)
        if len(ret) != 0:
            return ret
        self.__rbuf = BufferIO(self.__trans.read(max(sz, self.__rbuf_size)))
        return self.__rbuf.read(sz)

    def write(self, buf):
        """Append ``buf`` to the write buffer; nothing is sent until flush()."""
        try:
            self.__wbuf.write(buf)
        except Exception:
            # On exception reset wbuf so it doesn't contain a partial function
            # call, then re-raise with a bare ``raise`` so the original
            # traceback is preserved.
            self.__wbuf = BufferIO()
            raise

    def flush(self):
        """Send all buffered writes to ``trans`` and flush it."""
        out = self.__wbuf.getvalue()
        # reset wbuf before write/flush to preserve state on underlying failure
        self.__wbuf = BufferIO()
        self.__trans.write(out)
        self.__trans.flush()

    # Implement the CReadableTransport interface.
    @property
    def cstringio_buf(self):
        return self.__rbuf

    def cstringio_refill(self, partialread, reqlen):
        """Rebuild the read buffer from ``partialread`` plus fresh data.

        The new buffer holds at least ``reqlen`` bytes; ``trans.readAll``
        raises EOFError if that many cannot be obtained.
        """
        retstring = partialread
        if reqlen < self.__rbuf_size:
            # try to make a read of as much as we can.
            retstring += self.__trans.read(self.__rbuf_size)

        # but make sure we do read reqlen bytes.
        if len(retstring) < reqlen:
            retstring += self.__trans.readAll(reqlen - len(retstring))

        self.__rbuf = BufferIO(retstring)
        return self.__rbuf
198
199
class TMemoryBuffer(TTransportBase, CReadableTransport):
    """Wraps an in-memory bytes buffer (``BufferIO``) as a TTransport.

    NOTE: Unlike the C++ version of this class, you cannot write to it
          then immediately read from it.  If you want to read from a
          TMemoryBuffer, you must pass the data to read to the constructor.
    TODO(dreiss): Make this work like the C++ version.
    """

    def __init__(self, value=None, offset=0):
        """value -- a value to read from for stringio
        offset -- initial read position within ``value`` (ignored when 0)

        If value is set, this will be a transport for reading,
        otherwise, it is for writing"""
        if value is not None:
            self._buffer = BufferIO(value)
        else:
            self._buffer = BufferIO()
        if offset:
            self._buffer.seek(offset)

    def isOpen(self):
        """True until close() has been called on the underlying buffer."""
        return not self._buffer.closed

    def open(self):
        # Nothing to open for an in-memory buffer.
        pass

    def close(self):
        self._buffer.close()

    def read(self, sz):
        return self._buffer.read(sz)

    def write(self, buf):
        self._buffer.write(buf)

    def flush(self):
        # Writes go straight into the buffer, so there is nothing to flush.
        pass

    def getvalue(self):
        """Return the entire contents of the underlying buffer."""
        return self._buffer.getvalue()

    # Implement the CReadableTransport interface.
    @property
    def cstringio_buf(self):
        return self._buffer

    def cstringio_refill(self, partialread, reqlen):
        # only one shot at reading...
        raise EOFError()
250
251
class TFramedTransportFactory(object):
    """Transport factory whose products are ``TFramedTransport`` wrappers."""

    def getTransport(self, trans):
        """Return a new ``TFramedTransport`` wrapping ``trans``."""
        return TFramedTransport(trans)
258
259
class TFramedTransport(TTransportBase, CReadableTransport):
    """Class that wraps another transport and frames its I/O when writing.

    Each frame on the wire is a 4-byte big-endian signed length followed by
    that many payload bytes.
    """

    def __init__(self, trans):
        self.__trans = trans
        self.__rbuf = BufferIO(b'')
        self.__wbuf = BufferIO()

    def isOpen(self):
        return self.__trans.isOpen()

    def open(self):
        return self.__trans.open()

    def close(self):
        return self.__trans.close()

    def read(self, sz):
        """Return up to ``sz`` bytes, reading a new frame when the current
        one is exhausted."""
        ret = self.__rbuf.read(sz)
        if len(ret) != 0:
            return ret

        self.readFrame()
        return self.__rbuf.read(sz)

    def readFrame(self):
        """Read one length-prefixed frame into the read buffer.

        :raises TTransportException: with type ``NEGATIVE_SIZE`` if the
            frame header declares a negative length (corrupt stream).
        """
        buff = self.__trans.readAll(4)
        sz, = unpack('!i', buff)
        if sz < 0:
            # readAll() of a negative size would silently yield an empty
            # frame and leave the stream desynchronised.
            raise TTransportException(
                TTransportException.NEGATIVE_SIZE,
                "Negative frame size (%d)" % sz)
        self.__rbuf = BufferIO(self.__trans.readAll(sz))

    def write(self, buf):
        """Append ``buf`` to the pending frame; sent on flush()."""
        self.__wbuf.write(buf)

    def flush(self):
        """Send the pending writes as a single frame and flush ``trans``."""
        wout = self.__wbuf.getvalue()
        wsz = len(wout)
        # reset wbuf before write/flush to preserve state on underlying failure
        self.__wbuf = BufferIO()
        # N.B.: Doing this string concatenation is WAY cheaper than making
        # two separate calls to the underlying socket object. Socket writes in
        # Python turn out to be REALLY expensive, but it seems to do a pretty
        # good job of managing string buffer operations without excessive copies
        buf = pack("!i", wsz) + wout
        self.__trans.write(buf)
        self.__trans.flush()

    # Implement the CReadableTransport interface.
    @property
    def cstringio_buf(self):
        return self.__rbuf

    def cstringio_refill(self, prefix, reqlen):
        # self.__rbuf will already be empty here because fastbinary doesn't
        # ask for a refill until the previous buffer is empty.  Therefore,
        # we can start reading new frames immediately.
        while len(prefix) < reqlen:
            self.readFrame()
            prefix += self.__rbuf.getvalue()
        self.__rbuf = BufferIO(prefix)
        return self.__rbuf
320
321
class TFileObjectTransport(TTransportBase):
    """Adapter that exposes a file-like object through the transport API.

    Reads, writes, flushes and close are forwarded to the wrapped
    ``fileobj``; ``isOpen`` always reports ``True``.
    """

    def __init__(self, fileobj):
        self.fileobj = fileobj

    def read(self, sz):
        return self.fileobj.read(sz)

    def write(self, buf):
        self.fileobj.write(buf)

    def flush(self):
        self.fileobj.flush()

    def isOpen(self):
        # The wrapped object's state is not inspected.
        return True

    def close(self):
        self.fileobj.close()
342
343
class TSaslClientTransport(TTransportBase, CReadableTransport):
    """
    SASL transport

    Performs a SASL negotiation on open() and then exchanges data as
    4-byte big-endian length-prefixed frames, passed through the SASL
    client's ``wrap``/``unwrap``.
    """

    # Negotiation status codes (first byte of each negotiation message).
    START = 1
    OK = 2
    BAD = 3
    ERROR = 4
    COMPLETE = 5

    def __init__(self, transport, host, service, mechanism='GSSAPI',
                 **sasl_kwargs):
        """
        transport: an underlying transport to use, typically just a TSocket
        host: the name of the server, from a SASL perspective
        service: the name of the server's service, from a SASL perspective
        mechanism: the name of the preferred mechanism to use

        All other kwargs will be passed to the puresasl.client.SASLClient
        constructor.
        """

        from puresasl.client import SASLClient

        self.transport = transport
        self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)

        self.__wbuf = BufferIO()
        self.__rbuf = BufferIO(b'')

    def open(self):
        """Open the underlying transport if needed and run SASL negotiation.

        :raises TTransportException: (NOT_OPEN) if the server reports an
            unexpected status or claims completion prematurely.
        """
        if not self.transport.isOpen():
            self.transport.open()

        self.send_sasl_msg(self.START, self.sasl.mechanism)
        self.send_sasl_msg(self.OK, self.sasl.process())

        while True:
            status, challenge = self.recv_sasl_msg()
            if status == self.OK:
                self.send_sasl_msg(self.OK, self.sasl.process(challenge))
            elif status == self.COMPLETE:
                if not self.sasl.complete:
                    raise TTransportException(
                        TTransportException.NOT_OPEN,
                        "The server erroneously indicated "
                        "that SASL negotiation was complete")
                else:
                    break
            else:
                raise TTransportException(
                    TTransportException.NOT_OPEN,
                    "Bad SASL negotiation status: %d (%s)"
                    % (status, challenge))

    def send_sasl_msg(self, status, body):
        """Send one negotiation message: 1-byte status, 4-byte length, body.

        ``body`` is normalised to bytes first: ``None`` becomes ``b''`` and
        text (e.g. ``self.sasl.mechanism`` when it is a text string) is
        UTF-8 encoded, since the header is bytes and cannot be
        concatenated with text on Python 3.
        """
        if body is None:
            body = b''
        elif not isinstance(body, (bytes, bytearray)):
            body = body.encode('utf-8')
        header = pack(">BI", status, len(body))
        self.transport.write(header + body)
        self.transport.flush()

    def recv_sasl_msg(self):
        """Receive one negotiation message; returns ``(status, payload)``
        with ``payload`` always as bytes."""
        header = self.transport.readAll(5)
        status, length = unpack(">BI", header)
        if length > 0:
            payload = self.transport.readAll(length)
        else:
            payload = b""
        return status, payload

    def write(self, data):
        """Append ``data`` to the write buffer; sent on flush()."""
        self.__wbuf.write(data)

    def flush(self):
        """Wrap the buffered data with SASL and send it as one frame."""
        data = self.__wbuf.getvalue()
        encoded = self.sasl.wrap(data)
        # reset wbuf before write/flush to preserve state on underlying failure
        self.__wbuf = BufferIO()
        # The length prefix is bytes, so join with a bytes separator.
        self.transport.write(b''.join((pack("!i", len(encoded)), encoded)))
        self.transport.flush()

    def read(self, sz):
        """Return up to ``sz`` bytes, reading a new frame when the current
        one is exhausted."""
        ret = self.__rbuf.read(sz)
        if len(ret) != 0:
            return ret

        self._read_frame()
        return self.__rbuf.read(sz)

    def _read_frame(self):
        """Read one length-prefixed frame, unwrap it, and buffer the result.

        :raises TTransportException: with type ``NEGATIVE_SIZE`` if the
            frame header declares a negative length.
        """
        header = self.transport.readAll(4)
        length, = unpack('!i', header)
        if length < 0:
            raise TTransportException(
                TTransportException.NEGATIVE_SIZE,
                "Negative frame size (%d)" % length)
        encoded = self.transport.readAll(length)
        self.__rbuf = BufferIO(self.sasl.unwrap(encoded))

    def close(self):
        """Dispose of the SASL context and close the underlying transport.

        The transport is closed even if ``dispose`` raises.
        """
        try:
            self.sasl.dispose()
        finally:
            self.transport.close()

    # based on TFramedTransport
    @property
    def cstringio_buf(self):
        return self.__rbuf

    def cstringio_refill(self, prefix, reqlen):
        # self.__rbuf will already be empty here because fastbinary doesn't
        # ask for a refill until the previous buffer is empty.  Therefore,
        # we can start reading new frames immediately.
        while len(prefix) < reqlen:
            self._read_frame()
            prefix += self.__rbuf.getvalue()
        self.__rbuf = BufferIO(prefix)
        return self.__rbuf
456