1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10#   http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18#
19
20from struct import pack, unpack
21from thrift.Thrift import TException
22from ..compat import BufferIO
23
24
class TTransportException(TException):
    """Exception raised by Thrift transport implementations.

    ``type`` is one of the integer codes defined on this class (UNKNOWN by
    default).  ``inner`` is stored as given -- presumably the underlying
    exception that caused this one; confirm against raising sites.
    """

    # Error codes stored in ``self.type``.
    UNKNOWN = 0
    NOT_OPEN = 1
    ALREADY_OPEN = 2
    TIMED_OUT = 3
    END_OF_FILE = 4
    NEGATIVE_SIZE = 5
    SIZE_LIMIT = 6
    INVALID_CLIENT_TYPE = 7

    def __init__(self, type=UNKNOWN, message=None, inner=None):
        # The parameter name ``type`` shadows the builtin but is kept because
        # callers may pass it by keyword.
        TException.__init__(self, message)
        self.inner = inner
        self.type = type
41
42
class TTransportBase(object):
    """Base class for Thrift transport layer.

    Subclasses override the no-op methods below; ``readAll`` is built on top
    of ``read`` and is normally inherited unchanged.
    """

    def isOpen(self):
        """Report whether the transport is open (base: returns None)."""
        pass

    def open(self):
        """Open the transport (base: no-op)."""
        pass

    def close(self):
        """Close the transport (base: no-op)."""
        pass

    def read(self, sz):
        """Read up to ``sz`` bytes (base: returns None; override this)."""
        pass

    def readAll(self, sz):
        """Read exactly ``sz`` bytes by calling ``read`` repeatedly.

        Returns ``b''`` when ``sz`` is zero or negative.

        Raises EOFError if ``read`` returns an empty chunk before ``sz``
        bytes have been collected.
        """
        # Collect chunks and join once: repeated ``bytes +=`` copies the
        # accumulated data on every iteration (quadratic for many chunks).
        chunks = []
        have = 0
        while have < sz:
            chunk = self.read(sz - have)
            chunkLen = len(chunk)
            if chunkLen == 0:
                raise EOFError()
            chunks.append(chunk)
            have += chunkLen

        return b''.join(chunks)

    def write(self, buf):
        """Write ``buf`` to the transport (base: no-op)."""
        pass

    def flush(self):
        """Flush pending writes (base: no-op)."""
        pass
77
78
79# This class should be thought of as an interface.
class CReadableTransport(object):
    """Interface for transports whose read buffer can be consumed from C.

    Implementations expose the buffer currently being read through the
    ``cstringio_buf`` property, and swap in a fresh buffer through
    ``cstringio_refill`` when the C reader runs out of bytes.
    """

    # TODO(dreiss): Consider letting this interface use a (Python, not C)
    #               StringIO instead, since that would allow writing after
    #               reading.

    # NOTE: ``cstringio_buf`` is a read-only property here; no setter is
    #       defined, so implementations replace the underlying attribute.
    @property
    def cstringio_buf(self):
        """The buffer object holding the chunk currently being read."""

    def cstringio_refill(self, partialread, reqlen):
        """Refill cstringio_buf and return the buffer now in use.

        ``partialread`` holds the bytes the C code already consumed from the
        old buffer; they must come first in the new buffer, ahead of any
        further data.  The returned buffer may or may not be the same object
        as the old cstringio_buf, and must be a new (not borrowed) reference
        -- something like ``self._buf`` is fine.

        Raise EOFError if ``reqlen`` bytes cannot be made available.
        """
106
107
class TServerTransportBase(object):
    """Base class for Thrift server transports; every method is a no-op stub."""

    def listen(self):
        """Start listening for connections (base: does nothing)."""

    def accept(self):
        """Accept a connection (base: returns None)."""

    def close(self):
        """Stop listening (base: does nothing)."""
119
120
class TTransportFactoryBase(object):
    """Base class for a Transport Factory.

    The default factory performs no wrapping: the given transport is
    handed back unchanged.
    """

    def getTransport(self, trans):
        """Return ``trans`` itself."""
        wrapped = trans
        return wrapped
126
127
class TBufferedTransportFactory(object):
    """Factory that wraps each transport in a TBufferedTransport."""

    def getTransport(self, trans):
        """Return a new TBufferedTransport around ``trans`` (default buffer size)."""
        return TBufferedTransport(trans)
134
135
class TBufferedTransport(TTransportBase, CReadableTransport):
    """Class that wraps another transport and buffers its I/O.

    Reads are served from an in-memory read buffer that is refilled from
    the wrapped transport with requests of at least ``rbuf_size`` bytes.
    Writes accumulate in memory until ``flush`` sends them in one call.
    """
    DEFAULT_BUFFER = 4096

    def __init__(self, trans, rbuf_size=DEFAULT_BUFFER):
        """
        trans: the underlying transport to wrap
        rbuf_size: minimum number of bytes requested from ``trans`` when the
                   read buffer is refilled
        """
        self.__trans = trans
        self.__wbuf = BufferIO()
        # Start with an empty read buffer so the first read() triggers a refill.
        self.__rbuf = BufferIO(b'')
        self.__rbuf_size = rbuf_size

    def isOpen(self):
        return self.__trans.isOpen()

    def open(self):
        return self.__trans.open()

    def close(self):
        return self.__trans.close()

    def read(self, sz):
        """Return up to ``sz`` bytes, refilling the read buffer when it is empty.

        May return fewer than ``sz`` bytes; use ``readAll`` for an exact count.
        """
        ret = self.__rbuf.read(sz)
        if len(ret) != 0:
            return ret
        # Buffer exhausted: ask the wrapped transport for at least
        # rbuf_size bytes so later small reads are served from memory.
        self.__rbuf = BufferIO(self.__trans.read(max(sz, self.__rbuf_size)))
        return self.__rbuf.read(sz)

    def write(self, buf):
        """Append ``buf`` to the write buffer; nothing is sent until flush()."""
        try:
            self.__wbuf.write(buf)
        except Exception:
            # on exception reset wbuf so it doesn't contain a partial function call
            self.__wbuf = BufferIO()
            # Bare raise re-raises the active exception unchanged.
            raise

    def flush(self):
        """Send all buffered writes to the wrapped transport, then flush it."""
        out = self.__wbuf.getvalue()
        # reset wbuf before write/flush to preserve state on underlying failure
        self.__wbuf = BufferIO()
        self.__trans.write(out)
        self.__trans.flush()

    # Implement the CReadableTransport interface.
    @property
    def cstringio_buf(self):
        return self.__rbuf

    def cstringio_refill(self, partialread, reqlen):
        """Build a new read buffer starting with ``partialread``.

        The buffer holds at least ``reqlen`` bytes. Raises whatever the wrapped
        transport's ``readAll`` raises when it cannot supply enough bytes.
        """
        retstring = partialread
        if reqlen < self.__rbuf_size:
            # try to make a read of as much as we can.
            retstring += self.__trans.read(self.__rbuf_size)

        # but make sure we do read reqlen bytes.
        if len(retstring) < reqlen:
            retstring += self.__trans.readAll(reqlen - len(retstring))

        self.__rbuf = BufferIO(retstring)
        return self.__rbuf
199
200
class TMemoryBuffer(TTransportBase, CReadableTransport):
    """A Thrift transport backed by an in-memory BufferIO object.

    NOTE: Unlike the C++ version of this class, you cannot write to it and
          then immediately read back what was written.  To read from a
          TMemoryBuffer, pass the bytes to read to the constructor.
    TODO(dreiss): Make this work like the C++ version.
    """

    def __init__(self, value=None, offset=0):
        """value -- bytes to read from.  When given, the transport is meant
        for reading; when omitted, it starts empty and is meant for writing.

        offset -- if non-zero, the initial position within the buffer."""
        self._buffer = BufferIO() if value is None else BufferIO(value)
        if offset:
            self._buffer.seek(offset)

    def isOpen(self):
        closed = self._buffer.closed
        return not closed

    def open(self):
        """Nothing to open; the buffer exists from construction."""

    def close(self):
        self._buffer.close()

    def read(self, sz):
        return self._buffer.read(sz)

    def write(self, buf):
        self._buffer.write(buf)

    def flush(self):
        """Nothing to flush; writes land in memory immediately."""

    def getvalue(self):
        """Return the entire buffer contents, regardless of position."""
        return self._buffer.getvalue()

    # CReadableTransport interface.
    @property
    def cstringio_buf(self):
        return self._buffer

    def cstringio_refill(self, partialread, reqlen):
        # The buffer cannot grow after construction, so a refill request
        # always means the data has run out.
        raise EOFError()
251
252
class TFramedTransportFactory(object):
    """Factory that wraps each transport in a TFramedTransport."""

    def getTransport(self, trans):
        """Return a new TFramedTransport around ``trans``."""
        return TFramedTransport(trans)
259
260
class TFramedTransport(TTransportBase, CReadableTransport):
    """Class that wraps another transport and frames its I/O.

    Each flush() sends the buffered bytes as one frame: a 4-byte big-endian
    signed length followed by the payload.  Reads consume frames in the same
    format, one frame at a time.
    """

    def __init__(self, trans):
        self.__trans = trans
        self.__rbuf = BufferIO(b'')
        self.__wbuf = BufferIO()

    def isOpen(self):
        return self.__trans.isOpen()

    def open(self):
        return self.__trans.open()

    def close(self):
        return self.__trans.close()

    def read(self, sz):
        """Return up to ``sz`` bytes from the current frame.

        When the current frame is exhausted, the next frame is read first.
        May return fewer than ``sz`` bytes; use ``readAll`` for an exact count.
        """
        ret = self.__rbuf.read(sz)
        if len(ret) != 0:
            return ret

        self.readFrame()
        return self.__rbuf.read(sz)

    def readFrame(self):
        """Read one complete frame from the wrapped transport into the read buffer.

        Raises TTransportException(NEGATIVE_SIZE) if the length header is
        negative.  That indicates a corrupt or non-framed stream; previously
        it was silently treated as an empty frame, desynchronizing the stream.
        """
        buff = self.__trans.readAll(4)
        sz, = unpack('!i', buff)
        if sz < 0:
            raise TTransportException(
                TTransportException.NEGATIVE_SIZE,
                "Negative frame size: %d" % sz)
        self.__rbuf = BufferIO(self.__trans.readAll(sz))

    def write(self, buf):
        """Append ``buf`` to the pending frame; nothing is sent until flush()."""
        try:
            self.__wbuf.write(buf)
        except Exception:
            # Reset so the next frame does not contain a partial message.
            self.__wbuf = BufferIO()
            raise

    def flush(self):
        """Send the buffered bytes as one length-prefixed frame, then flush."""
        wout = self.__wbuf.getvalue()
        wsz = len(wout)
        # reset wbuf before write/flush to preserve state on underlying failure
        self.__wbuf = BufferIO()
        # N.B.: Doing this string concatenation is WAY cheaper than making
        # two separate calls to the underlying socket object. Socket writes in
        # Python turn out to be REALLY expensive, but it seems to do a pretty
        # good job of managing string buffer operations without excessive copies
        buf = pack("!i", wsz) + wout
        self.__trans.write(buf)
        self.__trans.flush()

    # Implement the CReadableTransport interface.
    @property
    def cstringio_buf(self):
        return self.__rbuf

    def cstringio_refill(self, prefix, reqlen):
        # self.__rbuf will already be empty here because fastbinary doesn't
        # ask for a refill until the previous buffer is empty.  Therefore,
        # we can start reading new frames immediately.
        while len(prefix) < reqlen:
            self.readFrame()
            prefix += self.__rbuf.getvalue()
        self.__rbuf = BufferIO(prefix)
        return self.__rbuf
321
322
class TFileObjectTransport(TTransportBase):
    """Thrift transport that forwards every operation to a file-like object.

    isOpen() always reports True, regardless of the file object's state.
    """

    def __init__(self, fileobj):
        self.fileobj = fileobj

    def isOpen(self):
        return True

    def read(self, sz):
        """Delegate to ``fileobj.read``; may return fewer than ``sz`` bytes."""
        return self.fileobj.read(sz)

    def write(self, buf):
        """Delegate to ``fileobj.write``."""
        self.fileobj.write(buf)

    def flush(self):
        """Delegate to ``fileobj.flush``."""
        self.fileobj.flush()

    def close(self):
        """Delegate to ``fileobj.close``."""
        self.fileobj.close()
343
344
class TSaslClientTransport(TTransportBase, CReadableTransport):
    """
    SASL transport

    Wraps another transport.  open() runs a SASL negotiation over it; after
    that, each flush() sends the buffered bytes as one frame (4-byte
    big-endian signed length + payload).  The payload is passed through
    ``sasl.wrap`` before sending, and incoming frame payloads are passed
    through ``sasl.unwrap``.
    """

    # Negotiation status codes, sent as the first byte of each SASL message.
    START = 1
    OK = 2
    BAD = 3
    ERROR = 4
    COMPLETE = 5

    def __init__(self, transport, host, service, mechanism='GSSAPI',
                 **sasl_kwargs):
        """
        transport: an underlying transport to use, typically just a TSocket
        host: the name of the server, from a SASL perspective
        service: the name of the server's service, from a SASL perspective
        mechanism: the name of the preferred mechanism to use

        All other kwargs will be passed to the puresasl.client.SASLClient
        constructor.
        """

        from puresasl.client import SASLClient

        self.transport = transport
        self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)

        self.__wbuf = BufferIO()
        self.__rbuf = BufferIO(b'')

    def open(self):
        """Open the underlying transport if needed, then run SASL negotiation.

        Raises TTransportException(NOT_OPEN) in two cases: the server sends a
        status other than OK or COMPLETE, or it reports COMPLETE while the
        local SASL client is not complete.
        """
        if not self.transport.isOpen():
            self.transport.open()

        self.send_sasl_msg(self.START, bytes(self.sasl.mechanism, 'ascii'))
        self.send_sasl_msg(self.OK, self.sasl.process())

        while True:
            status, challenge = self.recv_sasl_msg()
            if status == self.OK:
                self.send_sasl_msg(self.OK, self.sasl.process(challenge))
            elif status == self.COMPLETE:
                if not self.sasl.complete:
                    raise TTransportException(
                        TTransportException.NOT_OPEN,
                        "The server erroneously indicated "
                        "that SASL negotiation was complete")
                else:
                    break
            else:
                raise TTransportException(
                    TTransportException.NOT_OPEN,
                    "Bad SASL negotiation status: %d (%s)"
                    % (status, challenge))

    def send_sasl_msg(self, status, body):
        """Send one negotiation message and flush it.

        The message is a 1-byte status, a 4-byte big-endian unsigned body
        length, then ``body``, which must be bytes.
        """
        header = pack(">BI", status, len(body))
        self.transport.write(header + body)
        self.transport.flush()

    def recv_sasl_msg(self):
        """Read one negotiation message and return ``(status, payload)``.

        ``payload`` is always bytes; it is ``b""`` for an empty body.
        """
        header = self.transport.readAll(5)
        status, length = unpack(">BI", header)
        if length > 0:
            payload = self.transport.readAll(length)
        else:
            # Must be bytes, not str, to match the non-empty case and what
            # sasl.process() is given elsewhere.
            payload = b""
        return status, payload

    def write(self, data):
        """Append ``data`` to the pending frame; nothing is sent until flush()."""
        try:
            self.__wbuf.write(data)
        except Exception:
            # Reset so the next frame does not contain a partial message.
            self.__wbuf = BufferIO()
            raise

    def flush(self):
        """Wrap the buffered bytes with the SASL client and send them as one frame."""
        data = self.__wbuf.getvalue()
        # Reset before writing, as the other buffered transports in this module
        # do.  Otherwise a failed write leaves stale data to be resent on the
        # next flush.
        self.__wbuf = BufferIO()
        encoded = self.sasl.wrap(data)
        self.transport.write(pack("!i", len(encoded)) + encoded)
        self.transport.flush()

    def read(self, sz):
        """Return up to ``sz`` unwrapped bytes, reading a new frame when needed."""
        ret = self.__rbuf.read(sz)
        if len(ret) != 0:
            return ret

        self._read_frame()
        return self.__rbuf.read(sz)

    def _read_frame(self):
        """Read one frame, unwrap its payload, and make it the read buffer.

        Raises TTransportException(NEGATIVE_SIZE) on a negative length header.
        """
        header = self.transport.readAll(4)
        length, = unpack('!i', header)
        if length < 0:
            raise TTransportException(
                TTransportException.NEGATIVE_SIZE,
                "Negative frame size: %d" % length)
        encoded = self.transport.readAll(length)
        self.__rbuf = BufferIO(self.sasl.unwrap(encoded))

    def close(self):
        """Dispose of the SASL client and close the underlying transport.

        The transport is closed even if ``sasl.dispose()`` raises.
        """
        try:
            self.sasl.dispose()
        finally:
            self.transport.close()

    # based on TFramedTransport
    @property
    def cstringio_buf(self):
        return self.__rbuf

    def cstringio_refill(self, prefix, reqlen):
        # self.__rbuf will already be empty here because fastbinary doesn't
        # ask for a refill until the previous buffer is empty.  Therefore,
        # we can start reading new frames immediately.
        while len(prefix) < reqlen:
            self._read_frame()
            prefix += self.__rbuf.getvalue()
        self.__rbuf = BufferIO(prefix)
        return self.__rbuf
457