1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10#   http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18#
19
20import struct
21import zlib
22
23from thrift.compat import BufferIO, byte_index
24from thrift.protocol.TBinaryProtocol import TBinaryProtocol
25from thrift.protocol.TCompactProtocol import TCompactProtocol, readVarint, writeVarint
26from thrift.Thrift import TApplicationException
27from thrift.transport.TTransport import (
28    CReadableTransport,
29    TMemoryBuffer,
30    TTransportBase,
31    TTransportException,
32)
33
34
# Network-byte-order codecs for the fixed-width THeader frame fields.
U16 = struct.Struct("!H")  # unsigned 16-bit
I32 = struct.Struct("!i")  # signed 32-bit
# Value of the high 16 bits of the word that opens a THeader frame.
HEADER_MAGIC = 4095
# Default (and upper bound of the configurable) maximum frame size: 2**30 - 1.
HARD_MAX_FRAME_SIZE = (1 << 30) - 1
39
40
class THeaderClientType(object):
    """Wire formats a THeaderTransport can detect on read and reply with.

    ``HEADERS`` is the THeader format itself; the remaining values are
    legacy binary or compact protocol clients, with or without a 4-byte
    frame-length prefix.
    """

    HEADERS = 0

    FRAMED_BINARY, UNFRAMED_BINARY = 1, 2

    FRAMED_COMPACT, UNFRAMED_COMPACT = 3, 4
49
50
class THeaderSubprotocolID(object):
    """Protocol identifiers written into a THeader frame's header section."""

    BINARY = 0  # binary protocol
    COMPACT = 2  # compact protocol
54
55
class TInfoHeaderType(object):
    """Kinds of info-header blocks that may follow the transform list."""

    KEY_VALUE = 1  # a counted list of length-prefixed key/value strings
58
59
class THeaderTransformID(object):
    """Identifiers of payload transforms a THeader frame may list."""

    ZLIB = 1  # zlib compression


# transform id -> callable that undoes the transform on a received payload
READ_TRANSFORMS_BY_ID = {THeaderTransformID.ZLIB: zlib.decompress}


# transform id -> callable that applies the transform to an outgoing payload
WRITE_TRANSFORMS_BY_ID = {THeaderTransformID.ZLIB: zlib.compress}
72
73
def _readString(trans):
    """Read a varint-length-prefixed byte string from ``trans``.

    The body is read with ``readAll`` rather than ``read``. A declared length
    that runs past the end of the available data is then treated as an error
    instead of silently yielding a truncated string. This assumes ``trans``
    is a Thrift transport whose ``readAll`` fails on a short read.

    :raises TTransportException: NEGATIVE_SIZE if the decoded length is < 0.
    """
    size = readVarint(trans)
    if size < 0:
        raise TTransportException(
            TTransportException.NEGATIVE_SIZE,
            "Negative length"
        )
    return trans.readAll(size)
82
83
def _writeString(trans, value):
    """Write ``value`` to ``trans`` as a varint length followed by its bytes."""
    length = len(value)
    writeVarint(trans, length)
    trans.write(value)
87
88
class THeaderTransport(TTransportBase, CReadableTransport):
    """Transport implementing THeader framing with legacy-client detection.

    Every incoming message is inspected to decide whether the peer uses the
    THeader format, framed binary/compact, or unframed binary/compact. The
    detected ``THeaderClientType`` must be one of ``allowed_client_types``.
    It is remembered and determines how :meth:`flush` frames outgoing data.

    :param transport: underlying transport that bytes are read from and
        written to.
    :param allowed_client_types: container of ``THeaderClientType`` values
        that may be detected. Any other detected type raises
        ``TTransportException(INVALID_CLIENT_TYPE)``.
    :param default_protocol: ``THeaderSubprotocolID`` reported and written
        for THeader messages until a THeader frame read from the peer
        replaces it.
    """

    def __init__(self, transport, allowed_client_types, default_protocol=THeaderSubprotocolID.BINARY):
        self._transport = transport
        self._client_type = THeaderClientType.HEADERS
        self._allowed_client_types = allowed_client_types

        # decoded payload of the frame currently being consumed, and the info
        # headers that arrived with the most recent THeader frame
        self._read_buffer = BufferIO(b"")
        self._read_headers = {}

        # outgoing payload plus the headers/transforms applied by flush()
        self._write_buffer = BufferIO()
        self._write_headers = {}
        self._write_transforms = []

        # overwritten by each THeader frame read; written into THeader frames
        # produced by flush()
        self.flags = 0
        self.sequence_id = 0
        self._protocol_id = default_protocol
        self._max_frame_size = HARD_MAX_FRAME_SIZE

    def isOpen(self):
        """Return whether the underlying transport is open."""
        return self._transport.isOpen()

    def open(self):
        """Open the underlying transport."""
        return self._transport.open()

    def close(self):
        """Close the underlying transport."""
        return self._transport.close()

    def get_headers(self):
        """Return the info headers received with the most recent THeader frame."""
        return self._read_headers

    def set_header(self, key, value):
        """Queue an info header to be sent with the next THeader frame.

        :raises ValueError: if ``key`` or ``value`` is not ``bytes``.
        """
        if not isinstance(key, bytes):
            raise ValueError("header names must be bytes")
        if not isinstance(value, bytes):
            raise ValueError("header values must be bytes")
        self._write_headers[key] = value

    def clear_headers(self):
        """Discard all queued outgoing info headers."""
        self._write_headers.clear()

    def add_transform(self, transform_id):
        """Apply ``transform_id`` to the payload of outgoing THeader frames.

        :raises ValueError: if the transform is not supported for writing.
        """
        if transform_id not in WRITE_TRANSFORMS_BY_ID:
            raise ValueError("unknown transform")
        self._write_transforms.append(transform_id)

    def set_max_frame_size(self, size):
        """Limit the size of frames accepted on read and produced by flush.

        ``size`` must satisfy ``0 < size <= HARD_MAX_FRAME_SIZE``. The upper
        bound is inclusive so the default limit can be restored.

        :raises ValueError: if ``size`` is out of range.
        """
        if not 0 < size <= HARD_MAX_FRAME_SIZE:
            raise ValueError("maximum frame size should be <= %d and > 0" % HARD_MAX_FRAME_SIZE)
        self._max_frame_size = size

    @property
    def protocol_id(self):
        """``THeaderSubprotocolID`` in effect for the detected client type.

        THeader clients use the id from the last THeader frame read (or the
        constructor default). Legacy clients imply binary or compact from
        their detected type.

        :raises TTransportException: INVALID_CLIENT_TYPE for an unknown type.
        """
        if self._client_type == THeaderClientType.HEADERS:
            return self._protocol_id
        elif self._client_type in (THeaderClientType.FRAMED_BINARY, THeaderClientType.UNFRAMED_BINARY):
            return THeaderSubprotocolID.BINARY
        elif self._client_type in (THeaderClientType.FRAMED_COMPACT, THeaderClientType.UNFRAMED_COMPACT):
            return THeaderSubprotocolID.COMPACT
        else:
            raise TTransportException(
                TTransportException.INVALID_CLIENT_TYPE,
                "Protocol ID not known for client type %d" % self._client_type,
            )

    def read(self, sz):
        """Return up to ``sz`` bytes of message payload.

        Bytes already buffered from the current frame are returned first. For
        unframed clients the remainder comes straight from the underlying
        transport. Otherwise the next frame is read and decoded. The result
        may be shorter than ``sz``.
        """
        # if there are bytes left in the buffer, produce those first.
        bytes_read = self._read_buffer.read(sz)
        bytes_left_to_read = sz - len(bytes_read)
        if bytes_left_to_read == 0:
            return bytes_read

        # if we've determined this is an unframed client, just pass the read
        # through to the underlying transport until we're reset again at the
        # beginning of the next message.
        if self._client_type in (THeaderClientType.UNFRAMED_BINARY, THeaderClientType.UNFRAMED_COMPACT):
            return bytes_read + self._transport.read(bytes_left_to_read)

        # we're empty and (maybe) framed. fill the buffers with the next frame.
        self.readFrame(bytes_left_to_read)
        return bytes_read + self._read_buffer.read(bytes_left_to_read)

    def _set_client_type(self, client_type):
        """Record the detected client type.

        :raises TTransportException: INVALID_CLIENT_TYPE if the type is not allowed.
        """
        if client_type not in self._allowed_client_types:
            raise TTransportException(
                TTransportException.INVALID_CLIENT_TYPE,
                "Client type %d not allowed by server." % client_type,
            )
        self._client_type = client_type

    def readFrame(self, req_sz):
        """Read the next message from the transport and detect its client type.

        Replaces ``_read_buffer`` with the decoded payload. For unframed
        clients, the 4 bytes already consumed are kept at the front of the
        buffer, followed by up to ``req_sz - 4`` further bytes.

        :raises TTransportException: if the frame size is negative or larger
            than the configured maximum, or the client type cannot be detected
            or is not allowed.
        """
        # the first word could either be the length field of a framed message
        # or the first bytes of an unframed message.
        first_word = self._transport.readAll(I32.size)
        frame_size, = I32.unpack(first_word)
        is_unframed = False
        if frame_size & TBinaryProtocol.VERSION_MASK == TBinaryProtocol.VERSION_1:
            self._set_client_type(THeaderClientType.UNFRAMED_BINARY)
            is_unframed = True
        elif (byte_index(first_word, 0) == TCompactProtocol.PROTOCOL_ID and
              byte_index(first_word, 1) & TCompactProtocol.VERSION_MASK == TCompactProtocol.VERSION):
            self._set_client_type(THeaderClientType.UNFRAMED_COMPACT)
            is_unframed = True

        if is_unframed:
            # the first word is message data, not a length: keep it and read
            # whatever else the caller asked for.
            bytes_left_to_read = req_sz - I32.size
            if bytes_left_to_read > 0:
                rest = self._transport.read(bytes_left_to_read)
            else:
                rest = b""
            self._read_buffer = BufferIO(first_word + rest)
            return

        # ok, we're still here so we're framed.
        if frame_size < 0:
            # a negative length is never valid; reject it before readAll()
            raise TTransportException(
                TTransportException.NEGATIVE_SIZE,
                "Negative frame size.",
            )
        if frame_size > self._max_frame_size:
            raise TTransportException(
                TTransportException.SIZE_LIMIT,
                "Frame was too large.",
            )
        read_buffer = BufferIO(self._transport.readAll(frame_size))

        # the next word is either going to be the version field of a
        # binary/compact protocol message or the magic value + flags of a
        # header protocol message.
        second_word = read_buffer.read(I32.size)
        if len(second_word) < I32.size:
            # too short to hold a version/magic word; without this check the
            # unpack below would raise a bare struct.error
            raise TTransportException(
                TTransportException.INVALID_CLIENT_TYPE,
                "Could not detect client transport type.",
            )
        version, = I32.unpack(second_word)
        read_buffer.seek(0)
        if version >> 16 == HEADER_MAGIC:
            self._set_client_type(THeaderClientType.HEADERS)
            self._read_buffer = self._parse_header_format(read_buffer)
        elif version & TBinaryProtocol.VERSION_MASK == TBinaryProtocol.VERSION_1:
            self._set_client_type(THeaderClientType.FRAMED_BINARY)
            self._read_buffer = read_buffer
        elif (byte_index(second_word, 0) == TCompactProtocol.PROTOCOL_ID and
              byte_index(second_word, 1) & TCompactProtocol.VERSION_MASK == TCompactProtocol.VERSION):
            self._set_client_type(THeaderClientType.FRAMED_COMPACT)
            self._read_buffer = read_buffer
        else:
            raise TTransportException(
                TTransportException.INVALID_CLIENT_TYPE,
                "Could not detect client transport type.",
            )

    def _parse_header_format(self, buffer):
        """Decode the THeader frame in ``buffer`` and return its payload buffer.

        Side effects: updates ``flags``, ``sequence_id``, the protocol id and
        the received info headers. The frame's transforms are undone in
        reverse of the order they are listed.

        :raises TTransportException: SIZE_LIMIT if the header does not fit
            inside the frame.
        :raises TApplicationException: INVALID_TRANSFORM for an unknown
            transform id.
        """
        # make BufferIO look like TTransport for varint helpers
        buffer_transport = TMemoryBuffer()
        buffer_transport._buffer = buffer

        # fixed prefix: magic (2) + flags (2) + sequence id (4) + header length (2)
        fixed_prefix_size = 3 * U16.size + I32.size
        if len(buffer.getvalue()) < fixed_prefix_size:
            raise TTransportException(
                TTransportException.SIZE_LIMIT,
                "Header size is larger than whole frame.",
            )

        buffer.read(2)  # discard the magic bytes
        self.flags, = U16.unpack(buffer.read(U16.size))
        self.sequence_id, = I32.unpack(buffer.read(I32.size))

        # header length is transmitted as a count of 4-byte words
        header_length = U16.unpack(buffer.read(U16.size))[0] * 4
        end_of_headers = buffer.tell() + header_length
        if end_of_headers > len(buffer.getvalue()):
            raise TTransportException(
                TTransportException.SIZE_LIMIT,
                "Header size is larger than whole frame.",
            )

        self._protocol_id = readVarint(buffer_transport)

        transforms = []
        transform_count = readVarint(buffer_transport)
        for _ in range(transform_count):
            transform_id = readVarint(buffer_transport)
            if transform_id not in READ_TRANSFORMS_BY_ID:
                raise TApplicationException(
                    TApplicationException.INVALID_TRANSFORM,
                    "Unknown transform: %d" % transform_id,
                )
            transforms.append(transform_id)
        # the writer applied transforms in listed order, so undo them backwards
        transforms.reverse()

        headers = {}
        while buffer.tell() < end_of_headers:
            header_type = readVarint(buffer_transport)
            if header_type == TInfoHeaderType.KEY_VALUE:
                count = readVarint(buffer_transport)
                for _ in range(count):
                    key = _readString(buffer_transport)
                    value = _readString(buffer_transport)
                    headers[key] = value
            else:
                break  # ignore unknown headers
        self._read_headers = headers

        # skip padding / anything we didn't understand
        buffer.seek(end_of_headers)

        payload = buffer.read()
        for transform_id in transforms:
            transform_fn = READ_TRANSFORMS_BY_ID[transform_id]
            payload = transform_fn(payload)
        return BufferIO(payload)

    def write(self, buf):
        """Buffer ``buf`` until the next :meth:`flush`."""
        self._write_buffer.write(buf)

    def flush(self):
        """Frame the buffered payload for the detected client type and send it.

        THeader clients get a full header frame (transforms, queued info
        headers, padding), after which the queued info headers are cleared.
        Framed clients get a 4-byte length prefix, and unframed clients get
        the raw payload.

        :raises TTransportException: INVALID_CLIENT_TYPE for an unknown client
            type; SIZE_LIMIT if the header section overflows its 16-bit word
            count or the output exceeds the maximum frame size.
        """
        payload = self._write_buffer.getvalue()
        self._write_buffer = BufferIO()

        is_unframed = self._client_type in (
            THeaderClientType.UNFRAMED_BINARY,
            THeaderClientType.UNFRAMED_COMPACT,
        )

        buffer = BufferIO()
        if self._client_type == THeaderClientType.HEADERS:
            for transform_id in self._write_transforms:
                transform_fn = WRITE_TRANSFORMS_BY_ID[transform_id]
                payload = transform_fn(payload)

            headers = BufferIO()
            writeVarint(headers, self._protocol_id)
            writeVarint(headers, len(self._write_transforms))
            for transform_id in self._write_transforms:
                writeVarint(headers, transform_id)
            if self._write_headers:
                writeVarint(headers, TInfoHeaderType.KEY_VALUE)
                writeVarint(headers, len(self._write_headers))
                for key, value in self._write_headers.items():
                    _writeString(headers, key)
                    _writeString(headers, value)
                self._write_headers = {}
            # the header section must be a whole number of 4-byte words
            padding_needed = (4 - (len(headers.getvalue()) % 4)) % 4
            headers.write(b"\x00" * padding_needed)
            header_bytes = headers.getvalue()
            header_words = len(header_bytes) // 4
            if header_words > 0xFFFF:
                # the word count is sent as a u16; U16.pack would otherwise
                # fail with a bare struct.error
                raise TTransportException(
                    TTransportException.SIZE_LIMIT,
                    "Header size is too large.",
                )

            # 10 = magic (2) + flags (2) + sequence id (4) + header length (2)
            buffer.write(I32.pack(10 + len(header_bytes) + len(payload)))
            buffer.write(U16.pack(HEADER_MAGIC))
            buffer.write(U16.pack(self.flags))
            buffer.write(I32.pack(self.sequence_id))
            buffer.write(U16.pack(header_words))
            buffer.write(header_bytes)
            buffer.write(payload)
        elif self._client_type in (THeaderClientType.FRAMED_BINARY, THeaderClientType.FRAMED_COMPACT):
            buffer.write(I32.pack(len(payload)))
            buffer.write(payload)
        elif is_unframed:
            buffer.write(payload)
        else:
            raise TTransportException(
                TTransportException.INVALID_CLIENT_TYPE,
                "Unknown client type.",
            )

        frame_bytes = buffer.getvalue()
        if is_unframed:
            # unframed output has no length prefix to exclude
            frame_payload_size = len(frame_bytes)
        else:
            # the frame length field doesn't count towards the frame payload size
            frame_payload_size = len(frame_bytes) - I32.size
        if frame_payload_size > self._max_frame_size:
            raise TTransportException(
                TTransportException.SIZE_LIMIT,
                "Attempting to send frame that is too large.",
            )

        self._transport.write(frame_bytes)
        self._transport.flush()

    @property
    def cstringio_buf(self):
        """Current read buffer, exposed for the ``CReadableTransport`` interface."""
        return self._read_buffer

    def cstringio_refill(self, partialread, reqlen):
        """Return a buffer holding ``partialread`` followed by more payload.

        Part of the ``CReadableTransport`` interface. Reads until at least
        ``reqlen`` bytes are available; the returned buffer also becomes the
        new read buffer.

        :raises TTransportException: END_OF_FILE if an unframed peer's stream
            yields no more data before ``reqlen`` bytes are available.
        """
        result = bytearray(partialread)
        while len(result) < reqlen:
            chunk = self.read(reqlen - len(result))
            if not chunk and self._client_type in (
                THeaderClientType.UNFRAMED_BINARY,
                THeaderClientType.UNFRAMED_COMPACT,
            ):
                # an empty unframed read would otherwise loop forever; an empty
                # framed read just means a frame with no payload, so keep going
                raise TTransportException(
                    TTransportException.END_OF_FILE,
                    "End of file reading from transport",
                )
            result += chunk
        self._read_buffer = BufferIO(result)
        return self._read_buffer
353