1# Copyright (c) Facebook, Inc. and its affiliates.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# pyre-unsafe
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20from __future__ import unicode_literals
21
22import sys
23if sys.version_info[0] >= 3:
24    from http import server
25    # pyre-fixme[11]: Annotation `server` is not defined as a type.
26    BaseHTTPServer = server
27    xrange = range
28    from io import BytesIO as StringIO
29    PY3 = True
30else:
31    import BaseHTTPServer  # @manual
32    from cStringIO import StringIO
33    PY3 = False
34
35from struct import pack, unpack
36import zlib
37
38from thrift.Thrift import TApplicationException
39from thrift.protocol.TBinaryProtocol import TBinaryProtocol
40from thrift.transport.TTransport import (
41    TTransportException, TTransportBase, CReadableTransport
42)
43from thrift.protocol.TCompactProtocol import (
44    getVarint, readVarint, TCompactProtocol
45)
46
# Import the snappy module if it is available
try:
    import snappy
except ImportError:
    # snappy is optional: rather than failing at import time, install a
    # stand-in whose methods raise only when snappy compression is actually
    # requested.
    class DummySnappy(object):
        """Placeholder for the snappy module; every operation raises."""

        @staticmethod
        def _unavailable():
            raise TTransportException(TTransportException.INVALID_TRANSFORM,
                                      'snappy module not available')

        def compress(self, buf):
            self._unavailable()

        def decompress(self, buf):
            self._unavailable()

    snappy = DummySnappy()
63
# Import the zstd module if it is available
try:
    import zstd  # @manual
except ImportError:
    # zstd is optional: rather than failing at import time, install a
    # stand-in whose factories raise only when zstd compression is actually
    # requested.
    class DummyZstd(object):
        """Placeholder for the zstd module; every factory call raises."""

        @staticmethod
        def _unavailable():
            raise TTransportException(TTransportException.INVALID_TRANSFORM,
                                      'zstd module not available')

        def ZstdCompressor(self, write_content_size):
            self._unavailable()

        def ZstdDecompressor(self):
            self._unavailable()

    zstd = DummyZstd()
80
81
82# Definitions from THeader.h
83
84
class CLIENT_TYPE:
    """Wire formats THeaderTransport distinguishes when reading a message."""

    # Length-prefixed frame carrying THeader metadata.
    HEADER = 0
    # Length-prefixed binary-protocol frame.
    FRAMED_DEPRECATED = 1
    # Binary protocol with no length prefix.
    UNFRAMED_DEPRECATED = 2
    # HTTP request whose first four bytes were "POST".
    HTTP_SERVER = 3
    HTTP_CLIENT = 4
    # Length-prefixed compact-protocol frame.
    FRAMED_COMPACT = 5
    HTTP_GET = 7
    # Not yet detected, or detection failed.
    UNKNOWN = 8
    # Compact protocol with no length prefix.
    UNFRAMED_COMPACT_DEPRECATED = 9
95
96
class HEADER_FLAG:
    """Bits for the 16-bit flags field that follows the header magic."""

    SUPPORT_OUT_OF_ORDER = 0x01  # bit 0
    DUPLEX_REVERSE = 0x08  # bit 3
100
101
class TRANSFORM:
    """Transform ids listed in a header frame.

    THeaderTransport applies ZLIB, SNAPPY and ZSTD; an incoming HMAC or any
    other id (including QLZ) is rejected when reading.
    """

    NONE = 0x00
    ZLIB = 0x01
    HMAC = 0x02
    SNAPPY = 0x03
    QLZ = 0x04
    ZSTD = 0x05
109
110
class INFO:
    """Info-header block types inside a header frame.

    NORMAL headers are replaced on every header frame read.  PERSISTENT
    headers are kept across frames and merged into each frame's headers.
    """

    NORMAL = 1
    PERSISTENT = 2
114
115
# Protocol ids carried in the header's protocol field.
T_BINARY_PROTOCOL = 0
T_COMPACT_PROTOCOL = 2

# A header frame starts with the upper 16 bits of HEADER_MAGIC followed by
# a 16-bit flags word.
HEADER_MAGIC = 0x0FFF0000
PACKED_HEADER_MAGIC = pack(b'!H', HEADER_MAGIC >> 16)
HEADER_MASK = 0xFFFF0000
FLAGS_MASK = 0x0000FFFF

# Leading four bytes of a message, read as a big-endian uint32.
HTTP_SERVER_MAGIC = 0x504F5354  # POST
HTTP_CLIENT_MAGIC = 0x48545450  # HTTP
HTTP_GET_CLIENT_MAGIC = 0x47455420  # GET
HTTP_HEAD_CLIENT_MAGIC = 0x48454144  # HEAD
BIG_FRAME_MAGIC = 0x42494746  # BIGF

# Header frames longer than MAX_FRAME_SIZE are written as BIG_FRAME_MAGIC
# followed by an 8-byte length, up to MAX_BIG_FRAME_SIZE.
MAX_FRAME_SIZE = 0x3FFFFFFF
MAX_BIG_FRAME_SIZE = (1 << 61) - 1
129
130
class THeaderTransport(TTransportBase, CReadableTransport):
    """Transport that sends headers.  Also understands framed/unframed/HTTP
    transports and will do the right thing.

    The wire format of every incoming message is detected in readFrame()
    and recorded as the current client type.  flushImpl() encodes outgoing
    messages in that same client type.  Reading a message whose client type
    is not in the supported set raises INVALID_CLIENT_TYPE.
    """

    __max_frame_size = MAX_FRAME_SIZE

    # No identity header is sent unless set_identity() is called.
    __identity = None
    # The first header-format message carries a client-metadata header
    # unless disable_client_metadata() is called beforehand.
    __first_request = True
    IDENTITY_HEADER = "identity"
    ID_VERSION_HEADER = "id_version"
    ID_VERSION = "1"
    CLIENT_METADATA_HEADER = "client_metadata"

    def __init__(self, trans, client_types=None, client_type=None):
        """
        @param trans: underlying transport that bytes are read from and
            written to
        @param client_types: iterable of CLIENT_TYPE values accepted when
            reading (defaults to HEADER only)
        @param client_type(int): CLIENT_TYPE used for writing until a
            message is read (defaults to HEADER); always accepted on read
        """
        self.__trans = trans
        self.__rbuf = StringIO()
        self.__rbuf_frame = False
        self.__wbuf = StringIO()
        self.seq_id = 0
        self.__flags = 0
        self.__read_transforms = []
        self.__write_transforms = []
        self.__supported_client_types = set(client_types or
                                            (CLIENT_TYPE.HEADER,))
        self.__proto_id = T_COMPACT_PROTOCOL  # default to compact like c++
        self.__client_type = client_type or CLIENT_TYPE.HEADER
        self.__read_headers = {}
        self.__read_persistent_headers = {}
        self.__write_headers = {}
        self.__write_persistent_headers = {}

        self.__supported_client_types.add(self.__client_type)

        # If we support unframed binary / framed binary also support compact
        if CLIENT_TYPE.UNFRAMED_DEPRECATED in self.__supported_client_types:
            self.__supported_client_types.add(
                CLIENT_TYPE.UNFRAMED_COMPACT_DEPRECATED)
        if CLIENT_TYPE.FRAMED_DEPRECATED in self.__supported_client_types:
            self.__supported_client_types.add(
                CLIENT_TYPE.FRAMED_COMPACT)

    def set_header_flag(self, flag):
        """Set HEADER_FLAG bit(s) in the flags sent with header frames."""
        self.__flags |= flag

    def clear_header_flag(self, flag):
        """Clear HEADER_FLAG bit(s) from the flags sent with header frames."""
        self.__flags &= ~flag

    def header_flags(self):
        """Return the flags; each header frame read overwrites them."""
        return self.__flags

    def set_max_frame_size(self, size):
        """Set the largest frame size accepted on read and write.

        @param size(int): new maximum frame size in bytes
        @raises TTransportException: INVALID_FRAME_SIZE if size exceeds
            MAX_BIG_FRAME_SIZE, or exceeds MAX_FRAME_SIZE while the current
            client type is not HEADER
        """
        if size > MAX_BIG_FRAME_SIZE:
            raise TTransportException(TTransportException.INVALID_FRAME_SIZE,
                                      "Cannot set max frame size > %s" %
                                      MAX_BIG_FRAME_SIZE)
        if size > MAX_FRAME_SIZE and self.__client_type != CLIENT_TYPE.HEADER:
            raise TTransportException(
                TTransportException.INVALID_FRAME_SIZE,
                "Cannot set max frame size > %s for clients other than HEADER"
                % MAX_FRAME_SIZE)
        self.__max_frame_size = size

    def get_peer_identity(self):
        """Return the identity header sent by the peer, or None.

        The identity is only returned when the peer also sent an
        ``id_version`` header equal to ID_VERSION.  A missing ``id_version``
        header yields None instead of raising KeyError.
        """
        if self.IDENTITY_HEADER in self.__read_headers:
            version = self.__read_headers.get(self.ID_VERSION_HEADER)
            if version == self.ID_VERSION:
                return self.__read_headers[self.IDENTITY_HEADER]
        return None

    def set_identity(self, identity):
        """Send identity/id_version headers with each header message."""
        self.__identity = identity

    def get_protocol_id(self):
        """Return the protocol id of the last message read (or configured)."""
        return self.__proto_id

    def set_protocol_id(self, proto_id):
        """Set the protocol id written into outgoing header frames."""
        self.__proto_id = proto_id

    def set_header(self, str_key, str_value):
        """Queue a header for the next header-format message."""
        self.__write_headers[str_key] = str_value

    def get_write_headers(self):
        """Return the dict of headers queued for the next message."""
        return self.__write_headers

    def get_headers(self):
        """Return the headers of the last header frame read.

        The result includes any persistent headers received so far.
        """
        return self.__read_headers

    def clear_headers(self):
        """Drop all queued (non-persistent) write headers."""
        self.__write_headers.clear()

    def set_persistent_header(self, str_key, str_value):
        """Queue a header to be sent as an INFO.PERSISTENT block."""
        self.__write_persistent_headers[str_key] = str_value

    def get_write_persistent_headers(self):
        """Return the dict of persistent headers queued for sending."""
        return self.__write_persistent_headers

    def clear_persistent_headers(self):
        """Drop all queued persistent write headers."""
        self.__write_persistent_headers.clear()

    def add_transform(self, trans_id):
        """Append a TRANSFORM id, applied in order to outgoing payloads."""
        self.__write_transforms.append(trans_id)

    def _reset_protocol(self):
        """Prepare for the next message by reading and detecting its frame.

        A pending HTTP response is flushed first.  Oneway HTTP calls write
        nothing while being processed, so their status response is sent
        here.
        """
        # HTTP calls that are one way need to flush here.
        if self.__client_type == CLIENT_TYPE.HTTP_SERVER:
            self.flush()
        # set to anything except unframed
        self.__client_type = CLIENT_TYPE.UNKNOWN
        # Read header bytes to check which protocol to decode
        self.readFrame(0)

    def getTransport(self):
        """Return the wrapped transport."""
        return self.__trans

    def isOpen(self):
        return self.getTransport().isOpen()

    def open(self):
        return self.getTransport().open()

    def close(self):
        return self.getTransport().close()

    def read(self, sz):
        """Return up to sz bytes, reading a new frame when the buffer is dry.

        For unframed clients, the missing bytes are read directly from the
        underlying transport instead of from a new frame.
        """
        ret = self.__rbuf.read(sz)
        if len(ret) == sz:
            return ret

        if self.__client_type in (CLIENT_TYPE.UNFRAMED_DEPRECATED,
                                  CLIENT_TYPE.UNFRAMED_COMPACT_DEPRECATED):
            return ret + self.getTransport().readAll(sz - len(ret))

        self.readFrame(sz - len(ret))
        return ret + self.__rbuf.read(sz - len(ret))

    readAll = read  # TTransportBase.readAll does a needless copy here.

    def readFrame(self, req_sz):
        """Read the next message and detect its wire format.

        The first four bytes select the client type:
        - a binary or compact protocol id in the first byte means an
          unframed message;
        - "POST" means an HTTP request;
        - anything else is a frame length, or BIG_FRAME_MAGIC followed by
          an 8-byte length.  The next two bytes then tell framed binary,
          framed compact and header frames apart.

        The message is loaded into the read buffer, and the client type and
        protocol id are updated to match.

        @param req_sz(int): bytes the caller wants; only used for unframed
            messages, which carry no length prefix
        @raises TTransportException: unrecognized format, oversized frame,
            or a client type outside the supported set
        """
        self.__rbuf_frame = True
        word1 = self.getTransport().readAll(4)
        sz = unpack('!I', word1)[0]
        proto_id = word1[0] if PY3 else ord(word1[0])
        if proto_id == TBinaryProtocol.PROTOCOL_ID:
            # unframed
            self.__client_type = CLIENT_TYPE.UNFRAMED_DEPRECATED
            self.__proto_id = T_BINARY_PROTOCOL
            if req_sz <= 4:  # nothing more to read beyond word1
                self.__rbuf = StringIO(word1)
            else:
                self.__rbuf = StringIO(word1 + self.getTransport().read(
                    req_sz - 4))
        elif proto_id == TCompactProtocol.PROTOCOL_ID:
            self.__client_type = CLIENT_TYPE.UNFRAMED_COMPACT_DEPRECATED
            self.__proto_id = T_COMPACT_PROTOCOL
            if req_sz <= 4:  # nothing more to read beyond word1
                self.__rbuf = StringIO(word1)
            else:
                self.__rbuf = StringIO(word1 + self.getTransport().read(
                    req_sz - 4))
        elif sz == HTTP_SERVER_MAGIC:
            self.__client_type = CLIENT_TYPE.HTTP_SERVER
            mf = self.getTransport().handle.makefile('rb', -1)

            self.handler = RequestHandler(mf,
                                          'client_address:port', '')
            self.header = self.handler.wfile
            self.__rbuf = StringIO(self.handler.data)
        else:
            if sz == BIG_FRAME_MAGIC:
                sz = unpack('!Q', self.getTransport().readAll(8))[0]
            # could be header format or framed.  Check next two bytes.
            magic = self.getTransport().readAll(2)
            proto_id = magic[0] if PY3 else ord(magic[0])
            if proto_id == TCompactProtocol.PROTOCOL_ID:
                self.__client_type = CLIENT_TYPE.FRAMED_COMPACT
                self.__proto_id = T_COMPACT_PROTOCOL
                _frame_size_check(sz, self.__max_frame_size, header=False)
                self.__rbuf = StringIO(magic + self.getTransport().readAll(
                    sz - 2))
            elif proto_id == TBinaryProtocol.PROTOCOL_ID:
                self.__client_type = CLIENT_TYPE.FRAMED_DEPRECATED
                self.__proto_id = T_BINARY_PROTOCOL
                _frame_size_check(sz, self.__max_frame_size, header=False)
                self.__rbuf = StringIO(magic + self.getTransport().readAll(
                    sz - 2))
            elif magic == PACKED_HEADER_MAGIC:
                self.__client_type = CLIENT_TYPE.HEADER
                _frame_size_check(sz, self.__max_frame_size)
                # flags(2), seq_id(4), header_size(2)
                n_header_meta = self.getTransport().readAll(8)
                self.__flags, self.seq_id, header_size = unpack('!HIH',
                                                                n_header_meta)
                data = StringIO()
                data.write(magic)
                data.write(n_header_meta)
                data.write(self.getTransport().readAll(sz - 10))
                data.seek(10)
                self.read_header_format(sz - 10, header_size, data)
            else:
                self.__client_type = CLIENT_TYPE.UNKNOWN
                raise TTransportException(
                    TTransportException.INVALID_CLIENT_TYPE,
                    "Could not detect client transport type")

        if self.__client_type not in self.__supported_client_types:
            raise TTransportException(TTransportException.INVALID_CLIENT_TYPE,
                                      "Client type {} not supported on server"
                                      .format(self.__client_type))

    def read_header_format(self, sz, header_size, data):
        """Parse the body of a header frame and load its payload.

        @param sz(int): bytes left in the frame after the 10-byte
            magic/flags/seq_id/header_size prefix
        @param header_size(int): header length in 4-byte words
        @param data(StringIO): frame contents, positioned just past that
            prefix
        @raises TTransportException: header larger than the frame, or a
            JSON protocol id on a non-HTTP client
        @raises TApplicationException: HMAC or unknown transform id
        """
        # clear out any previous transforms
        self.__read_transforms = []

        header_size = header_size * 4
        if header_size > sz:
            raise TTransportException(TTransportException.INVALID_FRAME_SIZE,
                                      "Header size is larger than frame")
        end_header = header_size + data.tell()

        self.__proto_id = readVarint(data)
        num_transforms = readVarint(data)

        if self.__proto_id == 1 and self.__client_type != \
                CLIENT_TYPE.HTTP_SERVER:
            raise TTransportException(TTransportException.INVALID_CLIENT_TYPE,
                                      "Trying to recv JSON encoding over binary")

        # Read the transform ids.  They are stored in reverse so that
        # untransform() undoes them in the opposite order they were applied.
        for _ in range(0, num_transforms):
            trans_id = readVarint(data)
            if trans_id in (TRANSFORM.ZLIB, TRANSFORM.SNAPPY, TRANSFORM.ZSTD):
                self.__read_transforms.insert(0, trans_id)
            elif trans_id == TRANSFORM.HMAC:
                raise TApplicationException(
                    TApplicationException.INVALID_TRANSFORM,
                    "Hmac transform is no longer supported: %i" % trans_id)
            else:
                # TApplicationException will be sent back to client
                raise TApplicationException(
                    TApplicationException.INVALID_TRANSFORM,
                    "Unknown transform in client request: %i" % trans_id)

        # Clear out previous info headers.
        self.__read_headers.clear()

        # Read the info headers.
        while data.tell() < end_header:
            info_id = readVarint(data)
            if info_id == INFO.NORMAL:
                _read_info_headers(
                    data, end_header, self.__read_headers)
            elif info_id == INFO.PERSISTENT:
                _read_info_headers(
                    data, end_header, self.__read_persistent_headers)
            else:
                break  # Unknown header (or zero padding).  Stop processing.

        if self.__read_persistent_headers:
            self.__read_headers.update(self.__read_persistent_headers)

        # Skip the rest of the header
        data.seek(end_header)

        payload = data.read(sz - header_size)

        # Read the data section.
        self.__rbuf = StringIO(self.untransform(payload))

    def write(self, buf):
        """Buffer buf until the next flush."""
        self.__wbuf.write(buf)

    def transform(self, buf):
        """Apply the write transforms to buf, in order, and return it.

        @raises TTransportException: INVALID_TRANSFORM for an unknown id
        """
        for trans_id in self.__write_transforms:
            if trans_id == TRANSFORM.ZLIB:
                buf = zlib.compress(buf)
            elif trans_id == TRANSFORM.SNAPPY:
                buf = snappy.compress(buf)
            elif trans_id == TRANSFORM.ZSTD:
                buf = zstd.ZstdCompressor(write_content_size=True).compress(buf)
            else:
                raise TTransportException(TTransportException.INVALID_TRANSFORM,
                                          "Unknown transform during send")
        return buf

    def untransform(self, buf):
        """Undo the read transforms on buf and return the result.

        Each read transform is also added to the write transforms (if not
        already present), so replies use the same transforms.
        """
        for trans_id in self.__read_transforms:
            if trans_id == TRANSFORM.ZLIB:
                buf = zlib.decompress(buf)
            elif trans_id == TRANSFORM.SNAPPY:
                buf = snappy.decompress(buf)
            elif trans_id == TRANSFORM.ZSTD:
                buf = zstd.ZstdDecompressor().decompress(buf)
            if trans_id not in self.__write_transforms:
                self.__write_transforms.append(trans_id)
        return buf

    def disable_client_metadata(self):
        """Do not send the client-metadata header with the first message."""
        self.__first_request = False

    def flush(self):
        """Send the buffered message and flush the underlying transport."""
        self.flushImpl(False)

    def onewayFlush(self):
        """Send the buffered message via the transport's onewayFlush()."""
        self.flushImpl(True)

    def _flushHeaderMessage(self, buf, wout, wsz):
        """Write a message for CLIENT_TYPE.HEADER

        Queued write headers (persistent and normal) are consumed.

        @param buf(StringIO): Buffer to write message to
        @param wout(str): Payload
        @param wsz(int): Payload length
        @return(int): Frame length, i.e. everything written to buf except
            the leading 4-byte size (or 12-byte big-frame size)
        """
        transform_data = StringIO()
        # For now, all transforms don't require data.
        num_transforms = len(self.__write_transforms)
        for trans_id in self.__write_transforms:
            transform_data.write(getVarint(trans_id))

        # Add in special flags.
        if self.__identity:
            self.__write_headers[self.ID_VERSION_HEADER] = self.ID_VERSION
            self.__write_headers[self.IDENTITY_HEADER] = self.__identity

        if self.__first_request:
            self.__first_request = False
            self.__write_headers[self.CLIENT_METADATA_HEADER] = \
                "{\"agent\":\"THeaderTransport.py\"}"

        info_data = StringIO()

        # Write persistent kv-headers
        _flush_info_headers(info_data,
                            self.get_write_persistent_headers(),
                            INFO.PERSISTENT)

        # Write non-persistent kv-headers
        _flush_info_headers(info_data,
                            self.__write_headers,
                            INFO.NORMAL)

        header_data = StringIO()
        header_data.write(getVarint(self.__proto_id))
        header_data.write(getVarint(num_transforms))

        header_size = transform_data.tell() + header_data.tell() + \
            info_data.tell()

        # The header length is sent in 4-byte words, so pad it up.  Note that
        # an already-aligned header still receives 4 bytes of padding.
        padding_size = 4 - (header_size % 4)
        header_size = header_size + padding_size

        # MAGIC(2) | FLAGS(2) + SEQ_ID(4) + HEADER_SIZE(2)
        wsz += header_size + 10
        if wsz > MAX_FRAME_SIZE:
            buf.write(pack("!I", BIG_FRAME_MAGIC))
            buf.write(pack("!Q", wsz))
        else:
            buf.write(pack("!I", wsz))
        buf.write(pack("!HH", HEADER_MAGIC >> 16, self.__flags))
        buf.write(pack("!I", self.seq_id))
        buf.write(pack("!H", header_size // 4))

        buf.write(header_data.getvalue())
        buf.write(transform_data.getvalue())
        buf.write(info_data.getvalue())

        # Pad out the header with 0x00
        buf.write(b'\0' * padding_size)

        # Send data section
        buf.write(wout)
        return wsz

    def flushImpl(self, oneway):
        """Encode the buffered message for the current client type and send.

        @param oneway(bool): use the underlying transport's onewayFlush()
            instead of flush()
        @raises TTransportException: JSON protocol on a non-HTTP client,
            UNKNOWN client type, or a frame larger than the maximum size
        """
        wout = self.transform(self.__wbuf.getvalue())
        wsz = len(wout)

        # reset wbuf before write/flush to preserve state on underlying failure
        self.__wbuf.seek(0)
        self.__wbuf.truncate()

        if self.__proto_id == 1 and \
                self.__client_type != CLIENT_TYPE.HTTP_SERVER:
            raise TTransportException(TTransportException.INVALID_CLIENT_TYPE,
                                      "Trying to send JSON encoding over binary")

        # Remember the type this message is encoded with.  The HTTP branch
        # resets self.__client_type, but the size check below must use the
        # original value.
        client_type = self.__client_type
        # Frame length excluding any length prefix.  Only the HEADER format
        # may use either a 4- or a 12-byte prefix, so it reports its own size.
        frame_size = wsz
        buf = StringIO()
        if client_type == CLIENT_TYPE.HEADER:
            frame_size = self._flushHeaderMessage(buf, wout, wsz)
        elif client_type in (CLIENT_TYPE.FRAMED_DEPRECATED,
                             CLIENT_TYPE.FRAMED_COMPACT):
            buf.write(pack("!i", wsz))
            buf.write(wout)
        elif client_type in (CLIENT_TYPE.UNFRAMED_DEPRECATED,
                             CLIENT_TYPE.UNFRAMED_COMPACT_DEPRECATED):
            buf.write(wout)
        elif client_type == CLIENT_TYPE.HTTP_SERVER:
            # Reset the client type once the HTTP response is written, so
            # that the next _reset_protocol() does not flush it a second time
            # (oneway HTTP calls rely on that flush to send their status).
            buf.write(self.header.getvalue())
            buf.write(wout)
            self.__client_type = CLIENT_TYPE.HEADER
        elif client_type == CLIENT_TYPE.UNKNOWN:
            raise TTransportException(TTransportException.INVALID_CLIENT_TYPE,
                                      "Unknown client type")

        _frame_size_check(frame_size,
                          self.__max_frame_size,
                          header=client_type == CLIENT_TYPE.HEADER)
        self.getTransport().write(buf.getvalue())
        if oneway:
            self.getTransport().onewayFlush()
        else:
            self.getTransport().flush()

    # Implement the CReadableTransport interface.
    @property
    def cstringio_buf(self):
        """Read buffer for fastproto; reads the first frame if needed."""
        if not self.__rbuf_frame:
            self.readFrame(0)
        return self.__rbuf

    def cstringio_refill(self, prefix, reqlen):
        """Return a new read buffer holding at least reqlen bytes.

        @param prefix(bytes): bytes fastproto already consumed for the
            pending read
        @param reqlen(int): total bytes needed, including prefix
        """
        # self.__rbuf will already be empty here because fastproto doesn't
        # ask for a refill until the previous buffer is empty.  Therefore,
        # we can start reading new frames immediately.

        # On unframed clients, there is a chance there is something left
        # in rbuf, and the read pointer is not advanced by fastproto
        # so seek to the end to be safe
        self.__rbuf.seek(0, 2)
        while len(prefix) < reqlen:
            # Only ask for what is still missing: over-reading an unframed
            # stream can block on bytes the peer never sends.
            prefix += self.read(reqlen - len(prefix))
        # Keep whatever remains of a frame read above instead of dropping it
        # together with the old buffer.
        prefix += self.__rbuf.read()
        self.__rbuf = StringIO(prefix)
        return self.__rbuf
568
569
def _serialize_string(str_):
    """Return str_ as a varint length prefix followed by its contents.

    On Python 3, non-bytes values are encoded with str.encode() (UTF-8)
    first, so the prefix counts bytes.
    """
    needs_encoding = PY3 and not isinstance(str_, bytes)
    payload = str_.encode() if needs_encoding else str_
    return getVarint(len(payload)) + payload
574
575
def _flush_info_headers(info_data, write_headers, type):
    """Serialize write_headers into info_data, then clear write_headers.

    Nothing is written for an empty dict.  Otherwise the block is the
    varint ``type`` (an INFO value), the varint header count, and then each
    key and value as a length-prefixed string.
    """
    if not write_headers:
        return
    for field in (type, len(write_headers)):
        info_data.write(getVarint(field))
    for key, value in write_headers.items():
        info_data.write(_serialize_string(key))
        info_data.write(_serialize_string(value))
    write_headers.clear()
585
586
def _read_string(bufio, buflimit):
    """Read a varint-length-prefixed string from bufio.

    @raises TTransportException: INVALID_FRAME_SIZE if the string would
        extend past offset buflimit
    """
    length = readVarint(bufio)
    end = bufio.tell() + length
    if end > buflimit:
        raise TTransportException(TTransportException.INVALID_FRAME_SIZE,
                                  "String read too big")
    return bufio.read(length)
593
594
def _read_info_headers(data, end_header, read_headers):
    """Read a varint count of key/value string pairs into read_headers.

    No string may extend past offset end_header.
    """
    remaining = readVarint(data)
    while remaining > 0:
        key = _read_string(data, end_header)
        value = _read_string(data, end_header)
        read_headers[key] = value
        remaining -= 1
601
602
def _frame_size_check(sz, set_max_size, header=True):
    """Raise if a frame of sz bytes is too large.

    @param sz(int): frame length, excluding its length prefix
    @param set_max_size(int): configured maximum frame size
    @param header(bool): True for header frames; non-header (framed) frames
        are additionally capped at MAX_FRAME_SIZE
    @raises TTransportException: INVALID_FRAME_SIZE when the frame is too
        large
    """
    if sz > set_max_size or (not header and sz > MAX_FRAME_SIZE):
        # Build the label first: written inline, the conditional expression
        # would bind looser than '%' and replace the whole message.
        kind = 'Header' if header else 'Framed'
        raise TTransportException(
            TTransportException.INVALID_FRAME_SIZE,
            "%s transport frame was too large" % kind
        )
609
610
class RequestHandler(BaseHTTPServer.BaseHTTPRequestHandler):
    """Minimal HTTP request parser used by THeaderTransport.

    THeaderTransport has already consumed the leading "POST" of the request
    line while detecting the wire format, so handle_one_request restores
    it.  The request body ends up in ``self.data``.  The response status
    line and headers are buffered in ``self.wfile`` (an in-memory buffer),
    which the transport writes out when it flushes the reply.
    """

    # Same as superclass function, but prepend 'POST' because we
    # stripped it in the calling function.  Would be nice if
    # we had an ungetch instead
    def handle_one_request(self):
        self.raw_requestline = self.rfile.readline()
        if not self.raw_requestline:
            self.close_connection = 1
            return
        # rfile is opened in binary mode, so the restored method must be
        # bytes as well; str + bytes raises TypeError on Python 3.
        self.raw_requestline = b"POST" + self.raw_requestline
        if not self.parse_request():
            # An error code has been sent, just exit
            return
        mname = 'do_' + self.command
        if not hasattr(self, mname):
            self.send_error(501, "Unsupported method (%r)" % self.command)
            return
        method = getattr(self, mname)
        method()

    def setup(self):
        self.rfile = self.request
        self.wfile = StringIO()  # New output buffer

    def finish(self):
        if not self.rfile.closed:
            self.rfile.close()
        # leave wfile open for reading.

    def do_POST(self):
        # A missing or empty Content-Length is treated as an empty body.
        length = int(self.headers.get('Content-Length') or 0)
        if length > 0:
            self.data = self.rfile.read(length)
        else:
            # Bytes, so the transport can wrap it in a BytesIO on Python 3.
            self.data = b""

        # Prepare a response header, to be sent later.
        self.send_response(200)
        self.send_header("content-type", "application/x-thrift")
        self.end_headers()
651