#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

import struct
import zlib

from thrift.compat import BufferIO, byte_index
from thrift.protocol.TBinaryProtocol import TBinaryProtocol
from thrift.protocol.TCompactProtocol import TCompactProtocol, readVarint, writeVarint
from thrift.Thrift import TApplicationException
from thrift.transport.TTransport import (
    CReadableTransport,
    TMemoryBuffer,
    TTransportBase,
    TTransportException,
)


# Network byte order ("!") codecs for the fixed-width header fields.
U16 = struct.Struct("!H")  # unsigned 16-bit: magic, flags, header length
I32 = struct.Struct("!i")  # signed 32-bit: frame length, sequence id
# Value of the top 16 bits of the first word of a header-format frame.
HEADER_MAGIC = 0x0FFF
# Default maximum frame size; set_max_frame_size() only accepts smaller values.
HARD_MAX_FRAME_SIZE = 0x3FFFFFFF


class THeaderClientType(object):
    """Framing/protocol combinations this transport can detect and emit."""

    # THeader framing; the sub-protocol ID is carried inside the header.
    HEADERS = 0x00

    FRAMED_BINARY = 0x01
    UNFRAMED_BINARY = 0x02

    FRAMED_COMPACT = 0x03
    UNFRAMED_COMPACT = 0x04


# Client types whose messages carry only a 4-byte length prefix.
_FRAMED_CLIENT_TYPES = (THeaderClientType.FRAMED_BINARY, THeaderClientType.FRAMED_COMPACT)
# Client types whose messages are sent with no length prefix at all.
_UNFRAMED_CLIENT_TYPES = (THeaderClientType.UNFRAMED_BINARY, THeaderClientType.UNFRAMED_COMPACT)


class THeaderSubprotocolID(object):
    """Protocol IDs written into (and read from) header-format frames."""

    BINARY = 0x00
    COMPACT = 0x02


class TInfoHeaderType(object):
    """Info-block types that may follow the transform list in a header."""

    # A varint count followed by that many (key, value) length-prefixed strings.
    KEY_VALUE = 0x01


class THeaderTransformID(object):
    """IDs of payload transforms that can be listed in a header."""

    ZLIB = 0x01


# Transform ID -> function that undoes the transform on an incoming payload.
READ_TRANSFORMS_BY_ID = {
    THeaderTransformID.ZLIB: zlib.decompress,
}


# Transform ID -> function that applies the transform to an outgoing payload.
WRITE_TRANSFORMS_BY_ID = {
    THeaderTransformID.ZLIB: zlib.compress,
}


def _readString(trans):
    """Read a varint-length-prefixed byte string from *trans*.

    Raises:
        TTransportException: NEGATIVE_SIZE if the decoded length is negative.
    """
    size = readVarint(trans)
    if size < 0:
        raise TTransportException(
            TTransportException.NEGATIVE_SIZE,
            "Negative length"
        )
    return trans.read(size)


def _writeString(trans, value):
    """Write *value* to *trans*, prefixed with its length as a varint."""
    writeVarint(trans, len(value))
    trans.write(value)


def _is_compact_protocol_word(word):
    """Return True if *word* starts like a TCompactProtocol message.

    The first byte must equal the compact protocol ID, and the version bits
    of the second byte must equal the compact protocol version.
    """
    return (byte_index(word, 0) == TCompactProtocol.PROTOCOL_ID and
            byte_index(word, 1) & TCompactProtocol.VERSION_MASK == TCompactProtocol.VERSION)


class THeaderTransport(TTransportBase, CReadableTransport):
    """Transport for the THeader format that also accepts legacy clients.

    The first bytes of each incoming message are inspected to decide which
    ``THeaderClientType`` the peer is using. That type must be present in
    ``allowed_client_types``. Outgoing data is buffered until ``flush()``.
    ``flush()`` wraps the data in the framing of the most recently detected
    client type, which is the header format until something else is detected.
    """

    def __init__(self, transport, allowed_client_types, default_protocol=THeaderSubprotocolID.BINARY):
        """Wrap *transport*; no I/O is performed here.

        Args:
            transport: underlying transport that bytes are read from and
                written to.
            allowed_client_types: collection of ``THeaderClientType`` values
                accepted when detecting the peer's framing.
            default_protocol: ``THeaderSubprotocolID`` reported by
                ``protocol_id`` and written into header frames until an
                incoming header frame supplies a different one.
        """
        self._transport = transport
        self._client_type = THeaderClientType.HEADERS
        self._allowed_client_types = allowed_client_types

        # bytes of the current incoming message not yet handed to the caller
        self._read_buffer = BufferIO(b"")
        self._read_headers = {}

        # outgoing bytes, headers and transforms used by the next flush()
        self._write_buffer = BufferIO()
        self._write_headers = {}
        self._write_transforms = []

        # header fields: overwritten by incoming header frames, echoed on flush
        self.flags = 0
        self.sequence_id = 0
        self._protocol_id = default_protocol
        self._max_frame_size = HARD_MAX_FRAME_SIZE

    def isOpen(self):
        """Return whether the underlying transport is open."""
        return self._transport.isOpen()

    def open(self):
        """Open the underlying transport."""
        return self._transport.open()

    def close(self):
        """Close the underlying transport."""
        return self._transport.close()

    def get_headers(self):
        """Return the key/value headers of the last header-format frame read.

        The internal dict is returned, not a copy.
        """
        return self._read_headers

    def set_header(self, key, value):
        """Queue a key/value header for the next header-format flush.

        Raises:
            ValueError: if *key* or *value* is not ``bytes``.
        """
        if not isinstance(key, bytes):
            raise ValueError("header names must be bytes")
        if not isinstance(value, bytes):
            raise ValueError("header values must be bytes")
        self._write_headers[key] = value

    def clear_headers(self):
        """Discard all headers queued with ``set_header``."""
        self._write_headers.clear()

    def add_transform(self, transform_id):
        """Apply *transform_id* to the payload of every later header-format flush.

        Raises:
            ValueError: if *transform_id* has no write implementation.
        """
        if transform_id not in WRITE_TRANSFORMS_BY_ID:
            raise ValueError("unknown transform")
        self._write_transforms.append(transform_id)

    def set_max_frame_size(self, size):
        """Set the largest frame accepted on read and allowed on flush.

        Raises:
            ValueError: unless ``0 < size < HARD_MAX_FRAME_SIZE``.
        """
        if not 0 < size < HARD_MAX_FRAME_SIZE:
            raise ValueError("maximum frame size should be < %d and > 0" % HARD_MAX_FRAME_SIZE)
        self._max_frame_size = size

    @property
    def protocol_id(self):
        """``THeaderSubprotocolID`` in use for the current client type.

        Header clients report the ID from their header (or the default).
        Legacy clients report the ID implied by their framing.
        """
        if self._client_type == THeaderClientType.HEADERS:
            return self._protocol_id
        elif self._client_type in (THeaderClientType.FRAMED_BINARY, THeaderClientType.UNFRAMED_BINARY):
            return THeaderSubprotocolID.BINARY
        elif self._client_type in (THeaderClientType.FRAMED_COMPACT, THeaderClientType.UNFRAMED_COMPACT):
            return THeaderSubprotocolID.COMPACT
        else:
            raise TTransportException(
                TTransportException.INVALID_CLIENT_TYPE,
                "Protocol ID not know for client type %d" % self._client_type,
            )

    def read(self, sz):
        """Return up to *sz* bytes of the incoming message.

        Buffered bytes are returned first. If more are needed, unframed
        clients are read straight from the underlying transport. For all
        other clients, the next frame is read and decoded into the buffer.
        """
        # if there are bytes left in the buffer, produce those first.
        bytes_read = self._read_buffer.read(sz)
        bytes_left_to_read = sz - len(bytes_read)
        if bytes_left_to_read == 0:
            return bytes_read

        # an unframed client has no frame boundaries to look for, so pass the
        # read straight through. Client-type detection only happens in
        # readFrame, which is not called on this path.
        if self._client_type in _UNFRAMED_CLIENT_TYPES:
            return bytes_read + self._transport.read(bytes_left_to_read)

        # we're empty and (maybe) framed. fill the buffers with the next frame.
        self.readFrame(bytes_left_to_read)
        return bytes_read + self._read_buffer.read(bytes_left_to_read)

    def _set_client_type(self, client_type):
        """Record the detected *client_type* if it is allowed.

        Raises:
            TTransportException: INVALID_CLIENT_TYPE if it is not allowed.
        """
        if client_type not in self._allowed_client_types:
            raise TTransportException(
                TTransportException.INVALID_CLIENT_TYPE,
                "Client type %d not allowed by server." % client_type,
            )
        self._client_type = client_type

    def readFrame(self, req_sz):
        """Read the start of the next message and detect the client type.

        For unframed clients, the first word plus up to ``req_sz - 4`` more
        bytes are buffered. For framed and header clients, the whole frame
        is read and its payload is buffered.

        Args:
            req_sz: bytes the caller still needs. Only used to size the
                initial read for unframed clients.

        Raises:
            TTransportException: NEGATIVE_SIZE for a negative frame length,
                SIZE_LIMIT for a frame above the configured maximum, and
                INVALID_CLIENT_TYPE if the type cannot be detected (including
                frames too short to inspect) or is not allowed.
            TApplicationException: from header parsing, for unknown transforms.
        """
        # the first word could either be the length field of a framed message
        # or the first bytes of an unframed message.
        first_word = self._transport.readAll(I32.size)
        frame_size, = I32.unpack(first_word)
        is_unframed = False
        if frame_size & TBinaryProtocol.VERSION_MASK == TBinaryProtocol.VERSION_1:
            self._set_client_type(THeaderClientType.UNFRAMED_BINARY)
            is_unframed = True
        elif _is_compact_protocol_word(first_word):
            self._set_client_type(THeaderClientType.UNFRAMED_COMPACT)
            is_unframed = True

        if is_unframed:
            bytes_left_to_read = req_sz - I32.size
            if bytes_left_to_read > 0:
                rest = self._transport.read(bytes_left_to_read)
            else:
                rest = b""
            self._read_buffer = BufferIO(first_word + rest)
            return

        # ok, we're still here so we're framed. Validate the length before
        # reading: a negative size would silently read nothing, and a frame
        # shorter than one word can't be classified (and would otherwise
        # fail with struct.error below).
        if frame_size < 0:
            raise TTransportException(
                TTransportException.NEGATIVE_SIZE,
                "Negative frame size",
            )
        if frame_size > self._max_frame_size:
            raise TTransportException(
                TTransportException.SIZE_LIMIT,
                "Frame was too large.",
            )
        if frame_size < I32.size:
            raise TTransportException(
                TTransportException.INVALID_CLIENT_TYPE,
                "Frame too small to detect client transport type.",
            )
        read_buffer = BufferIO(self._transport.readAll(frame_size))

        # the next word is either going to be the version field of a
        # binary/compact protocol message or the magic value + flags of a
        # header protocol message.
        second_word = read_buffer.read(I32.size)
        version, = I32.unpack(second_word)
        read_buffer.seek(0)
        if version >> 16 == HEADER_MAGIC:
            self._set_client_type(THeaderClientType.HEADERS)
            self._read_buffer = self._parse_header_format(read_buffer)
        elif version & TBinaryProtocol.VERSION_MASK == TBinaryProtocol.VERSION_1:
            self._set_client_type(THeaderClientType.FRAMED_BINARY)
            self._read_buffer = read_buffer
        elif _is_compact_protocol_word(second_word):
            self._set_client_type(THeaderClientType.FRAMED_COMPACT)
            self._read_buffer = read_buffer
        else:
            raise TTransportException(
                TTransportException.INVALID_CLIENT_TYPE,
                "Could not detect client transport type.",
            )

    def _parse_header_format(self, buffer):
        """Decode a header-format frame and return its payload as a buffer.

        *buffer* holds the whole frame without its length prefix and is
        positioned at offset 0. This updates ``flags``, ``sequence_id``, the
        sub-protocol ID and the read headers. It then undoes the listed
        transforms on the payload.

        Raises:
            TTransportException: SIZE_LIMIT if the declared header length runs
                past the end of the frame.
            TApplicationException: INVALID_TRANSFORM for an unknown transform.
        """
        # make BufferIO look like TTransport for varint helpers.
        # NOTE(review): relies on TMemoryBuffer keeping its data in ``_buffer``.
        buffer_transport = TMemoryBuffer()
        buffer_transport._buffer = buffer

        buffer.read(2)  # discard the magic bytes
        self.flags, = U16.unpack(buffer.read(U16.size))
        self.sequence_id, = I32.unpack(buffer.read(I32.size))

        # the header length field counts 32-bit words
        header_length = U16.unpack(buffer.read(U16.size))[0] * 4
        end_of_headers = buffer.tell() + header_length
        if end_of_headers > len(buffer.getvalue()):
            raise TTransportException(
                TTransportException.SIZE_LIMIT,
                "Header size is larger than whole frame.",
            )

        self._protocol_id = readVarint(buffer_transport)

        transforms = []
        transform_count = readVarint(buffer_transport)
        for _ in range(transform_count):
            transform_id = readVarint(buffer_transport)
            if transform_id not in READ_TRANSFORMS_BY_ID:
                raise TApplicationException(
                    TApplicationException.INVALID_TRANSFORM,
                    "Unknown transform: %d" % transform_id,
                )
            transforms.append(transform_id)
        # writers apply transforms in listed order, so undo them in reverse
        transforms.reverse()

        headers = {}
        while buffer.tell() < end_of_headers:
            header_type = readVarint(buffer_transport)
            if header_type == TInfoHeaderType.KEY_VALUE:
                count = readVarint(buffer_transport)
                for _ in range(count):
                    key = _readString(buffer_transport)
                    value = _readString(buffer_transport)
                    headers[key] = value
            else:
                break  # ignore unknown headers
        self._read_headers = headers

        # skip padding / anything we didn't understand
        buffer.seek(end_of_headers)

        payload = buffer.read()
        for transform_id in transforms:
            transform_fn = READ_TRANSFORMS_BY_ID[transform_id]
            payload = transform_fn(payload)
        return BufferIO(payload)

    def write(self, buf):
        """Buffer *buf* for the next ``flush()``."""
        self._write_buffer.write(buf)

    def flush(self):
        """Frame the buffered writes for the current client type and send them.

        Header clients get a header-format frame carrying the queued headers
        (which are then cleared) and the registered transforms. Framed clients
        get a 4-byte length prefix. Unframed clients get the raw bytes.

        Raises:
            TTransportException: SIZE_LIMIT if the frame payload exceeds the
                maximum frame size, INVALID_CLIENT_TYPE for an unknown type.
        """
        payload = self._write_buffer.getvalue()
        self._write_buffer = BufferIO()

        buffer = BufferIO()
        # size checked against the limit; it excludes any length prefix
        if self._client_type == THeaderClientType.HEADERS:
            for transform_id in self._write_transforms:
                transform_fn = WRITE_TRANSFORMS_BY_ID[transform_id]
                payload = transform_fn(payload)

            headers = BufferIO()
            writeVarint(headers, self._protocol_id)
            writeVarint(headers, len(self._write_transforms))
            for transform_id in self._write_transforms:
                writeVarint(headers, transform_id)
            if self._write_headers:
                writeVarint(headers, TInfoHeaderType.KEY_VALUE)
                writeVarint(headers, len(self._write_headers))
                for key, value in self._write_headers.items():
                    _writeString(headers, key)
                    _writeString(headers, value)
                self._write_headers = {}
            # the header length field counts 32-bit words, so pad to a multiple of 4
            padding_needed = (4 - (len(headers.getvalue()) % 4)) % 4
            headers.write(b"\x00" * padding_needed)
            header_bytes = headers.getvalue()

            # 10 = magic (2) + flags (2) + sequence id (4) + header length (2)
            frame_payload_size = 10 + len(header_bytes) + len(payload)
            buffer.write(I32.pack(frame_payload_size))
            buffer.write(U16.pack(HEADER_MAGIC))
            buffer.write(U16.pack(self.flags))
            buffer.write(I32.pack(self.sequence_id))
            buffer.write(U16.pack(len(header_bytes) // 4))
            buffer.write(header_bytes)
            buffer.write(payload)
        elif self._client_type in _FRAMED_CLIENT_TYPES:
            frame_payload_size = len(payload)
            buffer.write(I32.pack(frame_payload_size))
            buffer.write(payload)
        elif self._client_type in _UNFRAMED_CLIENT_TYPES:
            # no length prefix here, so every byte counts against the limit
            frame_payload_size = len(payload)
            buffer.write(payload)
        else:
            raise TTransportException(
                TTransportException.INVALID_CLIENT_TYPE,
                "Unknown client type.",
            )

        if frame_payload_size > self._max_frame_size:
            raise TTransportException(
                TTransportException.SIZE_LIMIT,
                "Attempting to send frame that is too large.",
            )

        self._transport.write(buffer.getvalue())
        self._transport.flush()

    @property
    def cstringio_buf(self):
        """Buffer that ``CReadableTransport`` consumers read from directly."""
        return self._read_buffer

    def cstringio_refill(self, partialread, reqlen):
        """Replace the read buffer with *partialread* plus enough new bytes.

        Reads through ``read`` until at least *reqlen* bytes are available.

        Raises:
            TTransportException: END_OF_FILE if an unframed client's
                underlying transport returns no data before *reqlen* is met.
        """
        result = bytearray(partialread)
        while len(result) < reqlen:
            chunk = self.read(reqlen - len(result))
            # Framed reads raise on EOF inside readAll, so an empty chunk
            # there just means an empty frame. For unframed passthrough an
            # empty read is EOF; looping on it would spin forever.
            if not chunk and self._client_type in _UNFRAMED_CLIENT_TYPES:
                raise TTransportException(
                    TTransportException.END_OF_FILE,
                    "End of file reading from transport",
                )
            result += chunk
        self._read_buffer = BufferIO(result)
        return self._read_buffer