1# Copyright (c) Facebook, Inc. and its affiliates. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14 15# pyre-unsafe 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20from __future__ import unicode_literals 21 22import sys 23if sys.version_info[0] >= 3: 24 from http import server 25 # pyre-fixme[11]: Annotation `server` is not defined as a type. 26 BaseHTTPServer = server 27 xrange = range 28 from io import BytesIO as StringIO 29 PY3 = True 30else: 31 import BaseHTTPServer # @manual 32 from cStringIO import StringIO 33 PY3 = False 34 35from struct import pack, unpack 36import zlib 37 38from thrift.Thrift import TApplicationException 39from thrift.protocol.TBinaryProtocol import TBinaryProtocol 40from thrift.transport.TTransport import ( 41 TTransportException, TTransportBase, CReadableTransport 42) 43from thrift.protocol.TCompactProtocol import ( 44 getVarint, readVarint, TCompactProtocol 45) 46 47# Import the snappy module if it is available 48try: 49 import snappy 50except ImportError: 51 # If snappy is not available, don't fail immediately. 52 # Only raise an error if we actually ever need to perform snappy 53 # compression. 54 class DummySnappy(object): 55 def compress(self, buf): 56 raise TTransportException(TTransportException.INVALID_TRANSFORM, 57 'snappy module not available') 58 59 def decompress(self, buf): 60 raise TTransportException(TTransportException.INVALID_TRANSFORM, 61 'snappy module not available') 62 snappy = DummySnappy() 63 64# Import the zstd module if it is available 65try: 66 import zstd # @manual 67except ImportError: 68 # If zstd is not available, don't fail immediately. 69 # Only raise an error if we actually ever need to perform zstd 70 # compression. 71 class DummyZstd(object): 72 def ZstdCompressor(self, write_content_size): 73 raise TTransportException(TTransportException.INVALID_TRANSFORM, 74 'zstd module not available') 75 76 def ZstdDecompressor(self): 77 raise TTransportException(TTransportException.INVALID_TRANSFORM, 78 'zstd module not available') 79 zstd = DummyZstd() 80 81 82# Definitions from THeader.h 83 84 85class CLIENT_TYPE: 86 HEADER = 0 87 FRAMED_DEPRECATED = 1 88 UNFRAMED_DEPRECATED = 2 89 HTTP_SERVER = 3 90 HTTP_CLIENT = 4 91 FRAMED_COMPACT = 5 92 HTTP_GET = 7 93 UNKNOWN = 8 94 UNFRAMED_COMPACT_DEPRECATED = 9 95 96 97class HEADER_FLAG: 98 SUPPORT_OUT_OF_ORDER = 0x01 99 DUPLEX_REVERSE = 0x08 100 101 102class TRANSFORM: 103 NONE = 0x00 104 ZLIB = 0x01 105 HMAC = 0x02 106 SNAPPY = 0x03 107 QLZ = 0x04 108 ZSTD = 0x05 109 110 111class INFO: 112 NORMAL = 1 113 PERSISTENT = 2 114 115 116T_BINARY_PROTOCOL = 0 117T_COMPACT_PROTOCOL = 2 118HEADER_MAGIC = 0x0FFF0000 119PACKED_HEADER_MAGIC = pack(b'!H', HEADER_MAGIC >> 16) 120HEADER_MASK = 0xFFFF0000 121FLAGS_MASK = 0x0000FFFF 122HTTP_SERVER_MAGIC = 0x504F5354 # POST 123HTTP_CLIENT_MAGIC = 0x48545450 # HTTP 124HTTP_GET_CLIENT_MAGIC = 0x47455420 # GET 125HTTP_HEAD_CLIENT_MAGIC = 0x48454144 # HEAD 126BIG_FRAME_MAGIC = 0x42494746 # BIGF 127MAX_FRAME_SIZE = 0x3FFFFFFF 128MAX_BIG_FRAME_SIZE = 2 ** 61 - 1 129 130 131class THeaderTransport(TTransportBase, CReadableTransport): 132 """Transport that sends headers. Also understands framed/unframed/HTTP 133 transports and will do the right thing""" 134 135 __max_frame_size = MAX_FRAME_SIZE 136 137 # Defaults to current user, but there is also a setter below. 138 __identity = None 139 __first_request = True 140 IDENTITY_HEADER = "identity" 141 ID_VERSION_HEADER = "id_version" 142 ID_VERSION = "1" 143 CLIENT_METADATA_HEADER = "client_metadata"; 144 145 def __init__(self, trans, client_types=None, client_type=None): 146 self.__trans = trans 147 self.__rbuf = StringIO() 148 self.__rbuf_frame = False 149 self.__wbuf = StringIO() 150 self.seq_id = 0 151 self.__flags = 0 152 self.__read_transforms = [] 153 self.__write_transforms = [] 154 self.__supported_client_types = set(client_types or 155 (CLIENT_TYPE.HEADER,)) 156 self.__proto_id = T_COMPACT_PROTOCOL # default to compact like c++ 157 self.__client_type = client_type or CLIENT_TYPE.HEADER 158 self.__read_headers = {} 159 self.__read_persistent_headers = {} 160 self.__write_headers = {} 161 self.__write_persistent_headers = {} 162 163 self.__supported_client_types.add(self.__client_type) 164 165 # If we support unframed binary / framed binary also support compact 166 if CLIENT_TYPE.UNFRAMED_DEPRECATED in self.__supported_client_types: 167 self.__supported_client_types.add( 168 CLIENT_TYPE.UNFRAMED_COMPACT_DEPRECATED) 169 if CLIENT_TYPE.FRAMED_DEPRECATED in self.__supported_client_types: 170 self.__supported_client_types.add( 171 CLIENT_TYPE.FRAMED_COMPACT) 172 173 def set_header_flag(self, flag): 174 self.__flags |= flag 175 176 def clear_header_flag(self, flag): 177 self.__flags &= ~ flag 178 179 def header_flags(self): 180 return self.__flags 181 182 def set_max_frame_size(self, size): 183 if size > MAX_BIG_FRAME_SIZE: 184 raise TTransportException(TTransportException.INVALID_FRAME_SIZE, 185 "Cannot set max frame size > %s" % 186 MAX_BIG_FRAME_SIZE) 187 if size > MAX_FRAME_SIZE and self.__client_type != CLIENT_TYPE.HEADER: 188 raise TTransportException( 189 TTransportException.INVALID_FRAME_SIZE, 190 "Cannot set max frame size > %s for clients other than HEADER" 191 % MAX_FRAME_SIZE) 192 self.__max_frame_size = size 193 194 def get_peer_identity(self): 195 if self.IDENTITY_HEADER in self.__read_headers: 196 if self.__read_headers[self.ID_VERSION_HEADER] == self.ID_VERSION: 197 return self.__read_headers[self.IDENTITY_HEADER] 198 return None 199 200 def set_identity(self, identity): 201 self.__identity = identity 202 203 def get_protocol_id(self): 204 return self.__proto_id 205 206 def set_protocol_id(self, proto_id): 207 self.__proto_id = proto_id 208 209 def set_header(self, str_key, str_value): 210 self.__write_headers[str_key] = str_value 211 212 def get_write_headers(self): 213 return self.__write_headers 214 215 def get_headers(self): 216 return self.__read_headers 217 218 def clear_headers(self): 219 self.__write_headers.clear() 220 221 def set_persistent_header(self, str_key, str_value): 222 self.__write_persistent_headers[str_key] = str_value 223 224 def get_write_persistent_headers(self): 225 return self.__write_persistent_headers 226 227 def clear_persistent_headers(self): 228 self.__write_persistent_headers.clear() 229 230 def add_transform(self, trans_id): 231 self.__write_transforms.append(trans_id) 232 233 def _reset_protocol(self): 234 # HTTP calls that are one way need to flush here. 235 if self.__client_type == CLIENT_TYPE.HTTP_SERVER: 236 self.flush() 237 # set to anything except unframed 238 self.__client_type = CLIENT_TYPE.UNKNOWN 239 # Read header bytes to check which protocol to decode 240 self.readFrame(0) 241 242 def getTransport(self): 243 return self.__trans 244 245 def isOpen(self): 246 return self.getTransport().isOpen() 247 248 def open(self): 249 return self.getTransport().open() 250 251 def close(self): 252 return self.getTransport().close() 253 254 def read(self, sz): 255 ret = self.__rbuf.read(sz) 256 if len(ret) == sz: 257 return ret 258 259 if self.__client_type in (CLIENT_TYPE.UNFRAMED_DEPRECATED, 260 CLIENT_TYPE.UNFRAMED_COMPACT_DEPRECATED): 261 return ret + self.getTransport().readAll(sz - len(ret)) 262 263 self.readFrame(sz - len(ret)) 264 return ret + self.__rbuf.read(sz - len(ret)) 265 266 readAll = read # TTransportBase.readAll does a needless copy here. 267 268 def readFrame(self, req_sz): 269 self.__rbuf_frame = True 270 word1 = self.getTransport().readAll(4) 271 sz = unpack('!I', word1)[0] 272 proto_id = word1[0] if PY3 else ord(word1[0]) 273 if proto_id == TBinaryProtocol.PROTOCOL_ID: 274 # unframed 275 self.__client_type = CLIENT_TYPE.UNFRAMED_DEPRECATED 276 self.__proto_id = T_BINARY_PROTOCOL 277 if req_sz <= 4: # check for reads < 0. 278 self.__rbuf = StringIO(word1) 279 else: 280 self.__rbuf = StringIO(word1 + self.getTransport().read( 281 req_sz - 4)) 282 elif proto_id == TCompactProtocol.PROTOCOL_ID: 283 self.__client_type = CLIENT_TYPE.UNFRAMED_COMPACT_DEPRECATED 284 self.__proto_id = T_COMPACT_PROTOCOL 285 if req_sz <= 4: # check for reads < 0. 286 self.__rbuf = StringIO(word1) 287 else: 288 self.__rbuf = StringIO(word1 + self.getTransport().read( 289 req_sz - 4)) 290 elif sz == HTTP_SERVER_MAGIC: 291 self.__client_type = CLIENT_TYPE.HTTP_SERVER 292 mf = self.getTransport().handle.makefile('rb', -1) 293 294 self.handler = RequestHandler(mf, 295 'client_address:port', '') 296 self.header = self.handler.wfile 297 self.__rbuf = StringIO(self.handler.data) 298 else: 299 if sz == BIG_FRAME_MAGIC: 300 sz = unpack('!Q', self.getTransport().readAll(8))[0] 301 # could be header format or framed. Check next two bytes. 302 magic = self.getTransport().readAll(2) 303 proto_id = magic[0] if PY3 else ord(magic[0]) 304 if proto_id == TCompactProtocol.PROTOCOL_ID: 305 self.__client_type = CLIENT_TYPE.FRAMED_COMPACT 306 self.__proto_id = T_COMPACT_PROTOCOL 307 _frame_size_check(sz, self.__max_frame_size, header=False) 308 self.__rbuf = StringIO(magic + self.getTransport().readAll( 309 sz - 2)) 310 elif proto_id == TBinaryProtocol.PROTOCOL_ID: 311 self.__client_type = CLIENT_TYPE.FRAMED_DEPRECATED 312 self.__proto_id = T_BINARY_PROTOCOL 313 _frame_size_check(sz, self.__max_frame_size, header=False) 314 self.__rbuf = StringIO(magic + self.getTransport().readAll( 315 sz - 2)) 316 elif magic == PACKED_HEADER_MAGIC: 317 self.__client_type = CLIENT_TYPE.HEADER 318 _frame_size_check(sz, self.__max_frame_size) 319 # flags(2), seq_id(4), header_size(2) 320 n_header_meta = self.getTransport().readAll(8) 321 self.__flags, self.seq_id, header_size = unpack('!HIH', 322 n_header_meta) 323 data = StringIO() 324 data.write(magic) 325 data.write(n_header_meta) 326 data.write(self.getTransport().readAll(sz - 10)) 327 data.seek(10) 328 self.read_header_format(sz - 10, header_size, data) 329 else: 330 self.__client_type = CLIENT_TYPE.UNKNOWN 331 raise TTransportException( 332 TTransportException.INVALID_CLIENT_TYPE, 333 "Could not detect client transport type") 334 335 if self.__client_type not in self.__supported_client_types: 336 raise TTransportException(TTransportException.INVALID_CLIENT_TYPE, 337 "Client type {} not supported on server" 338 .format(self.__client_type)) 339 340 def read_header_format(self, sz, header_size, data): 341 # clear out any previous transforms 342 self.__read_transforms = [] 343 344 header_size = header_size * 4 345 if header_size > sz: 346 raise TTransportException(TTransportException.INVALID_FRAME_SIZE, 347 "Header size is larger than frame") 348 end_header = header_size + data.tell() 349 350 self.__proto_id = readVarint(data) 351 num_headers = readVarint(data) 352 353 if self.__proto_id == 1 and self.__client_type != \ 354 CLIENT_TYPE.HTTP_SERVER: 355 raise TTransportException(TTransportException.INVALID_CLIENT_TYPE, 356 "Trying to recv JSON encoding over binary") 357 358 # Read the headers. Data for each header varies. 359 for _ in range(0, num_headers): 360 trans_id = readVarint(data) 361 if trans_id in (TRANSFORM.ZLIB, TRANSFORM.SNAPPY, TRANSFORM.ZSTD): 362 self.__read_transforms.insert(0, trans_id) 363 elif trans_id == TRANSFORM.HMAC: 364 raise TApplicationException( 365 TApplicationException.INVALID_TRANSFORM, 366 "Hmac transform is no longer supported: %i" % trans_id) 367 else: 368 # TApplicationException will be sent back to client 369 raise TApplicationException( 370 TApplicationException.INVALID_TRANSFORM, 371 "Unknown transform in client request: %i" % trans_id) 372 373 # Clear out previous info headers. 374 self.__read_headers.clear() 375 376 # Read the info headers. 377 while data.tell() < end_header: 378 info_id = readVarint(data) 379 if info_id == INFO.NORMAL: 380 _read_info_headers( 381 data, end_header, self.__read_headers) 382 elif info_id == INFO.PERSISTENT: 383 _read_info_headers( 384 data, end_header, self.__read_persistent_headers) 385 else: 386 break # Unknown header. Stop info processing. 387 388 if self.__read_persistent_headers: 389 self.__read_headers.update(self.__read_persistent_headers) 390 391 # Skip the rest of the header 392 data.seek(end_header) 393 394 payload = data.read(sz - header_size) 395 396 # Read the data section. 397 self.__rbuf = StringIO(self.untransform(payload)) 398 399 def write(self, buf): 400 self.__wbuf.write(buf) 401 402 def transform(self, buf): 403 for trans_id in self.__write_transforms: 404 if trans_id == TRANSFORM.ZLIB: 405 buf = zlib.compress(buf) 406 elif trans_id == TRANSFORM.SNAPPY: 407 buf = snappy.compress(buf) 408 elif trans_id == TRANSFORM.ZSTD: 409 buf = zstd.ZstdCompressor(write_content_size=True).compress(buf) 410 else: 411 raise TTransportException(TTransportException.INVALID_TRANSFORM, 412 "Unknown transform during send") 413 return buf 414 415 def untransform(self, buf): 416 for trans_id in self.__read_transforms: 417 if trans_id == TRANSFORM.ZLIB: 418 buf = zlib.decompress(buf) 419 elif trans_id == TRANSFORM.SNAPPY: 420 buf = snappy.decompress(buf) 421 elif trans_id == TRANSFORM.ZSTD: 422 buf = zstd.ZstdDecompressor().decompress(buf) 423 if trans_id not in self.__write_transforms: 424 self.__write_transforms.append(trans_id) 425 return buf 426 427 def disable_client_metadata(self): 428 self.__first_request = False 429 430 def flush(self): 431 self.flushImpl(False) 432 433 def onewayFlush(self): 434 self.flushImpl(True) 435 436 def _flushHeaderMessage(self, buf, wout, wsz): 437 """Write a message for CLIENT_TYPE.HEADER 438 439 @param buf(StringIO): Buffer to write message to 440 @param wout(str): Payload 441 @param wsz(int): Payload length 442 """ 443 transform_data = StringIO() 444 # For now, all transforms don't require data. 445 num_transforms = len(self.__write_transforms) 446 for trans_id in self.__write_transforms: 447 transform_data.write(getVarint(trans_id)) 448 449 # Add in special flags. 450 if self.__identity: 451 self.__write_headers[self.ID_VERSION_HEADER] = self.ID_VERSION 452 self.__write_headers[self.IDENTITY_HEADER] = self.__identity 453 454 if self.__first_request: 455 self.__first_request = False 456 self.__write_headers[self.CLIENT_METADATA_HEADER] = \ 457 "{\"agent\":\"THeaderTransport.py\"}" 458 459 460 info_data = StringIO() 461 462 # Write persistent kv-headers 463 _flush_info_headers(info_data, 464 self.get_write_persistent_headers(), 465 INFO.PERSISTENT) 466 467 # Write non-persistent kv-headers 468 _flush_info_headers(info_data, 469 self.__write_headers, 470 INFO.NORMAL) 471 472 header_data = StringIO() 473 header_data.write(getVarint(self.__proto_id)) 474 header_data.write(getVarint(num_transforms)) 475 476 header_size = transform_data.tell() + header_data.tell() + \ 477 info_data.tell() 478 479 padding_size = 4 - (header_size % 4) 480 header_size = header_size + padding_size 481 482 # MAGIC(2) | FLAGS(2) + SEQ_ID(4) + HEADER_SIZE(2) 483 wsz += header_size + 10 484 if wsz > MAX_FRAME_SIZE: 485 buf.write(pack("!I", BIG_FRAME_MAGIC)) 486 buf.write(pack("!Q", wsz)) 487 else: 488 buf.write(pack("!I", wsz)) 489 buf.write(pack("!HH", HEADER_MAGIC >> 16, self.__flags)) 490 buf.write(pack("!I", self.seq_id)) 491 buf.write(pack("!H", header_size // 4)) 492 493 buf.write(header_data.getvalue()) 494 buf.write(transform_data.getvalue()) 495 buf.write(info_data.getvalue()) 496 497 # Pad out the header with 0x00 498 for _ in range(0, padding_size, 1): 499 buf.write(pack("!c", b'\0')) 500 501 # Send data section 502 buf.write(wout) 503 504 def flushImpl(self, oneway): 505 wout = self.__wbuf.getvalue() 506 wout = self.transform(wout) 507 wsz = len(wout) 508 509 # reset wbuf before write/flush to preserve state on underlying failure 510 self.__wbuf.seek(0) 511 self.__wbuf.truncate() 512 513 if self.__proto_id == 1 and self.__client_type != CLIENT_TYPE.HTTP_SERVER: 514 raise TTransportException(TTransportException.INVALID_CLIENT_TYPE, 515 "Trying to send JSON encoding over binary") 516 517 buf = StringIO() 518 if self.__client_type == CLIENT_TYPE.HEADER: 519 self._flushHeaderMessage(buf, wout, wsz) 520 elif self.__client_type in (CLIENT_TYPE.FRAMED_DEPRECATED, 521 CLIENT_TYPE.FRAMED_COMPACT): 522 buf.write(pack("!i", wsz)) 523 buf.write(wout) 524 elif self.__client_type in (CLIENT_TYPE.UNFRAMED_DEPRECATED, 525 CLIENT_TYPE.UNFRAMED_COMPACT_DEPRECATED): 526 buf.write(wout) 527 elif self.__client_type == CLIENT_TYPE.HTTP_SERVER: 528 # Reset the client type if we sent something - 529 # oneway calls via HTTP expect a status response otherwise 530 buf.write(self.header.getvalue()) 531 buf.write(wout) 532 self.__client_type == CLIENT_TYPE.HEADER 533 elif self.__client_type == CLIENT_TYPE.UNKNOWN: 534 raise TTransportException(TTransportException.INVALID_CLIENT_TYPE, 535 "Unknown client type") 536 537 # We don't include the framing bytes as part of the frame size check 538 frame_size = buf.tell() - (4 if wsz < MAX_FRAME_SIZE else 12) 539 _frame_size_check(frame_size, 540 self.__max_frame_size, 541 header=self.__client_type == CLIENT_TYPE.HEADER) 542 self.getTransport().write(buf.getvalue()) 543 if oneway: 544 self.getTransport().onewayFlush() 545 else: 546 self.getTransport().flush() 547 548 # Implement the CReadableTransport interface. 549 @property 550 def cstringio_buf(self): 551 if not self.__rbuf_frame: 552 self.readFrame(0) 553 return self.__rbuf 554 555 def cstringio_refill(self, prefix, reqlen): 556 # self.__rbuf will already be empty here because fastproto doesn't 557 # ask for a refill until the previous buffer is empty. Therefore, 558 # we can start reading new frames immediately. 559 560 # On unframed clients, there is a chance there is something left 561 # in rbuf, and the read pointer is not advanced by fastproto 562 # so seek to the end to be safe 563 self.__rbuf.seek(0, 2) 564 while len(prefix) < reqlen: 565 prefix += self.read(reqlen) 566 self.__rbuf = StringIO(prefix) 567 return self.__rbuf 568 569 570def _serialize_string(str_): 571 if PY3 and not isinstance(str_, bytes): 572 str_ = str_.encode() 573 return getVarint(len(str_)) + str_ 574 575 576def _flush_info_headers(info_data, write_headers, type): 577 if (len(write_headers) > 0): 578 info_data.write(getVarint(type)) 579 info_data.write(getVarint(len(write_headers))) 580 write_headers_iter = write_headers.items() 581 for str_key, str_value in write_headers_iter: 582 info_data.write(_serialize_string(str_key)) 583 info_data.write(_serialize_string(str_value)) 584 write_headers.clear() 585 586 587def _read_string(bufio, buflimit): 588 str_sz = readVarint(bufio) 589 if str_sz + bufio.tell() > buflimit: 590 raise TTransportException(TTransportException.INVALID_FRAME_SIZE, 591 "String read too big") 592 return bufio.read(str_sz) 593 594 595def _read_info_headers(data, end_header, read_headers): 596 num_keys = readVarint(data) 597 for _ in xrange(num_keys): 598 str_key = _read_string(data, end_header) 599 str_value = _read_string(data, end_header) 600 read_headers[str_key] = str_value 601 602 603def _frame_size_check(sz, set_max_size, header=True): 604 if sz > set_max_size or (not header and sz > MAX_FRAME_SIZE): 605 raise TTransportException( 606 TTransportException.INVALID_FRAME_SIZE, 607 "%s transport frame was too large" % 'Header' if header else 'Framed' 608 ) 609 610 611class RequestHandler(BaseHTTPServer.BaseHTTPRequestHandler): 612 613 # Same as superclass function, but append 'POST' because we 614 # stripped it in the calling function. Would be nice if 615 # we had an ungetch instead 616 def handle_one_request(self): 617 self.raw_requestline = self.rfile.readline() 618 if not self.raw_requestline: 619 self.close_connection = 1 620 return 621 self.raw_requestline = "POST" + self.raw_requestline 622 if not self.parse_request(): 623 # An error code has been sent, just exit 624 return 625 mname = 'do_' + self.command 626 if not hasattr(self, mname): 627 self.send_error(501, "Unsupported method (%r)" % self.command) 628 return 629 method = getattr(self, mname) 630 method() 631 632 def setup(self): 633 self.rfile = self.request 634 self.wfile = StringIO() # New output buffer 635 636 def finish(self): 637 if not self.rfile.closed: 638 self.rfile.close() 639 # leave wfile open for reading. 640 641 def do_POST(self): 642 if int(self.headers['Content-Length']) > 0: 643 self.data = self.rfile.read(int(self.headers['Content-Length'])) 644 else: 645 self.data = "" 646 647 # Prepare a response header, to be sent later. 648 self.send_response(200) 649 self.send_header("content-type", "application/x-thrift") 650 self.end_headers() 651