1#!/usr/bin/env python
2
3#
4# Licensed to the Apache Software Foundation (ASF) under one
5# or more contributor license agreements. See the NOTICE file
6# distributed with this work for additional information
7# regarding copyright ownership. The ASF licenses this file
8# to you under the Apache License, Version 2.0 (the
9# "License"); you may not use this file except in compliance
10# with the License. You may obtain a copy of the License at
11#
12#   http://www.apache.org/licenses/LICENSE-2.0
13#
14# Unless required by applicable law or agreed to in writing,
15# software distributed under the License is distributed on an
16# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
17# KIND, either express or implied. See the License for the
18# specific language governing permissions and limitations
19# under the License.
20#
21from __future__ import division
22import logging
23import os
24import signal
25import sys
26import time
27from optparse import OptionParser
28
29from util import local_libpath
30sys.path.insert(0, local_libpath())
31from thrift.protocol import TProtocol, TProtocolDecorator
32
# Absolute path of the directory that contains this script. It is used to find
# the generated code (--genpydir) and, through its parent directory, the
# "keys" directory holding the SSL certificates.
SCRIPT_DIR = os.path.abspath(os.path.dirname(__file__))
34
35
36class TestHandler(object):
37    def testVoid(self):
38        if options.verbose > 1:
39            logging.info('testVoid()')
40
41    def testString(self, str):
42        if options.verbose > 1:
43            logging.info('testString(%s)' % str)
44        return str
45
46    def testBool(self, boolean):
47        if options.verbose > 1:
48            logging.info('testBool(%s)' % str(boolean).lower())
49        return boolean
50
51    def testByte(self, byte):
52        if options.verbose > 1:
53            logging.info('testByte(%d)' % byte)
54        return byte
55
56    def testI16(self, i16):
57        if options.verbose > 1:
58            logging.info('testI16(%d)' % i16)
59        return i16
60
61    def testI32(self, i32):
62        if options.verbose > 1:
63            logging.info('testI32(%d)' % i32)
64        return i32
65
66    def testI64(self, i64):
67        if options.verbose > 1:
68            logging.info('testI64(%d)' % i64)
69        return i64
70
71    def testDouble(self, dub):
72        if options.verbose > 1:
73            logging.info('testDouble(%f)' % dub)
74        return dub
75
76    def testBinary(self, thing):
77        if options.verbose > 1:
78            logging.info('testBinary()')  # TODO: hex output
79        return thing
80
81    def testStruct(self, thing):
82        if options.verbose > 1:
83            logging.info('testStruct({%s, %s, %s, %s})' % (thing.string_thing, thing.byte_thing, thing.i32_thing, thing.i64_thing))
84        return thing
85
86    def testException(self, arg):
87        # if options.verbose > 1:
88        logging.info('testException(%s)' % arg)
89        if arg == 'Xception':
90            raise Xception(errorCode=1001, message=arg)
91        elif arg == 'TException':
92            raise TException(message='This is a TException')
93
94    def testMultiException(self, arg0, arg1):
95        if options.verbose > 1:
96            logging.info('testMultiException(%s, %s)' % (arg0, arg1))
97        if arg0 == 'Xception':
98            raise Xception(errorCode=1001, message='This is an Xception')
99        elif arg0 == 'Xception2':
100            raise Xception2(
101                errorCode=2002,
102                struct_thing=Xtruct(string_thing='This is an Xception2'))
103        return Xtruct(string_thing=arg1)
104
105    def testOneway(self, seconds):
106        if options.verbose > 1:
107            logging.info('testOneway(%d) => sleeping...' % seconds)
108        time.sleep(seconds / 3)  # be quick
109        if options.verbose > 1:
110            logging.info('done sleeping')
111
112    def testNest(self, thing):
113        if options.verbose > 1:
114            logging.info('testNest(%s)' % thing)
115        return thing
116
117    def testMap(self, thing):
118        if options.verbose > 1:
119            logging.info('testMap(%s)' % thing)
120        return thing
121
122    def testStringMap(self, thing):
123        if options.verbose > 1:
124            logging.info('testStringMap(%s)' % thing)
125        return thing
126
127    def testSet(self, thing):
128        if options.verbose > 1:
129            logging.info('testSet(%s)' % thing)
130        return thing
131
132    def testList(self, thing):
133        if options.verbose > 1:
134            logging.info('testList(%s)' % thing)
135        return thing
136
137    def testEnum(self, thing):
138        if options.verbose > 1:
139            logging.info('testEnum(%s)' % thing)
140        return thing
141
142    def testTypedef(self, thing):
143        if options.verbose > 1:
144            logging.info('testTypedef(%s)' % thing)
145        return thing
146
147    def testMapMap(self, thing):
148        if options.verbose > 1:
149            logging.info('testMapMap(%s)' % thing)
150        return {
151            -4: {
152                -4: -4,
153                -3: -3,
154                -2: -2,
155                -1: -1,
156            },
157            4: {
158                4: 4,
159                3: 3,
160                2: 2,
161                1: 1,
162            },
163        }
164
165    def testInsanity(self, argument):
166        if options.verbose > 1:
167            logging.info('testInsanity(%s)' % argument)
168        return {
169            1: {
170                2: argument,
171                3: argument,
172            },
173            2: {6: Insanity()},
174        }
175
176    def testMulti(self, arg0, arg1, arg2, arg3, arg4, arg5):
177        if options.verbose > 1:
178            logging.info('testMulti(%s, %s, %s, %s, %s, %s)' % (arg0, arg1, arg2, arg3, arg4, arg5))
179        return Xtruct(string_thing='Hello2',
180                      byte_thing=arg0, i32_thing=arg1, i64_thing=arg2)
181
182
class SecondHandler(object):
    """Handler for SecondService, registered only for the multiplexed protocols."""

    def secondtestString(self, argument):
        """Return ``argument`` wrapped as ``testString("<argument>")``."""
        prefix, suffix = 'testString("', '")'
        return prefix + argument + suffix
186
187
# LAST_SEQID is a module-level global because one transport may carry several
# protocols (when multiplexed), so the check must be shared between them.
# TPedanticSequenceIdProtocolWrapper.readMessageBegin stores the last seqid it
# read here. The value is None until the first message has been read.
LAST_SEQID = None
191
192
class TPedanticSequenceIdProtocolWrapper(TProtocolDecorator.TProtocolDecorator):
    """
    Protocol decorator that rejects an inbound message whose sequence ID is
    identical to the one read immediately before it.

    Only ``readMessageBegin`` is checked. The previously seen ID is kept in
    the module-level ``LAST_SEQID``, so the check is shared by every wrapped
    protocol.
    """
    def __init__(self, protocol):
        # Intentionally empty: TProtocolDecorator.__new__ does all the heavy
        # lifting of wrapping ``protocol``.
        pass

    def readMessageBegin(self):
        global LAST_SEQID
        parent = super(TPedanticSequenceIdProtocolWrapper, self)
        name, mtype, seqid = parent.readMessageBegin()
        repeated = LAST_SEQID is not None and LAST_SEQID == seqid
        if repeated:
            raise TProtocol.TProtocolException(
                TProtocol.TProtocolException.INVALID_DATA,
                "We received the same seqid {0} twice in a row".format(seqid))
        LAST_SEQID = seqid
        return name, mtype, seqid
212
213
def make_pedantic(proto):
    """
    Return ``proto`` as-is; sequence-ID checking is currently switched off.

    Wrapping in TPedanticSequenceIdProtocolWrapper is disabled because many
    clients send a seqid of zero, which is acceptable. Re-enabling it would
    need a way to identify clients that MUST send unique seqids, or (the
    preferred route) all implementations sending unique seqids.
    """
    # To re-enable the check: return TPedanticSequenceIdProtocolWrapper(proto)
    unwrapped = proto
    return unwrapped
221
222
class TPedanticSequenceIdProtocolFactory(TProtocol.TProtocolFactory):
    """
    Protocol factory that delegates to ``encapsulated`` and passes each
    protocol it builds through make_pedantic().
    """
    def __init__(self, encapsulated):
        super(TPedanticSequenceIdProtocolFactory, self).__init__()
        self.encapsulated = encapsulated

    def getProtocol(self, trans):
        inner = self.encapsulated.getProtocol(trans)
        return make_pedantic(inner)
230
231
def main(options):
    """Build the ThriftTest server selected on the command line and run it.

    Args:
        options: parsed ``optparse`` options (``proto``, ``trans``, ``port``,
            ``ssl``, ``zlib``, ``domain_socket``, ``string_limit``,
            ``container_limit``, ``verbose``).

    The module-level ``args`` (positional command-line arguments) names the
    server class, e.g. ``TSimpleServer``. This function assigns the
    module-level ``server`` so that signal handlers can reach it. For
    THttpServer it calls ``sys.exit(0)`` once ``serve()`` returns.

    Raises:
        AssertionError: raised in any of these cases:
            - an unknown --protocol or --transport value;
            - more than one positional server type;
            - no server type at all while --transport is not ``http``.
    """
    global server

    # common header allowed client types
    allowed_client_types = [
        THeaderTransport.THeaderClientType.HEADERS,
        THeaderTransport.THeaderClientType.FRAMED_BINARY,
        THeaderTransport.THeaderClientType.UNFRAMED_BINARY,
        THeaderTransport.THeaderClientType.FRAMED_COMPACT,
        THeaderTransport.THeaderClientType.UNFRAMED_COMPACT,
    ]

    # Set up the protocol factory from the --protocol option. Each "multi*"
    # name uses the same factory as its plain counterpart; the prefix also
    # enables the multiplexed processor below.
    prot_factories = {
        'accel': TBinaryProtocol.TBinaryProtocolAcceleratedFactory(),
        'multia': TBinaryProtocol.TBinaryProtocolAcceleratedFactory(),
        'accelc': TCompactProtocol.TCompactProtocolAcceleratedFactory(),
        'multiac': TCompactProtocol.TCompactProtocolAcceleratedFactory(),
        'binary': TPedanticSequenceIdProtocolFactory(TBinaryProtocol.TBinaryProtocolFactory()),
        'multi': TPedanticSequenceIdProtocolFactory(TBinaryProtocol.TBinaryProtocolFactory()),
        'compact': TCompactProtocol.TCompactProtocolFactory(),
        'multic': TCompactProtocol.TCompactProtocolFactory(),
        'header': THeaderProtocol.THeaderProtocolFactory(allowed_client_types),
        'multih': THeaderProtocol.THeaderProtocolFactory(allowed_client_types),
        'json': TJSONProtocol.TJSONProtocolFactory(),
        'multij': TJSONProtocol.TJSONProtocolFactory(),
    }
    pfactory = prot_factories.get(options.proto, None)
    if pfactory is None:
        raise AssertionError('Unknown --protocol option: %s' % options.proto)
    try:
        pfactory.string_length_limit = options.string_limit
        pfactory.container_length_limit = options.container_limit
    except Exception:
        # Best effort: ignore errors for protocols that do not support length limits
        pass

    # get the server type (TSimpleServer, TNonblockingServer, etc...)
    if len(args) > 1:
        raise AssertionError('Only one server type may be specified, not multiple types.')
    if options.trans == 'http':
        # The HTTP transport is always served by THttpServer, so no
        # positional server type is required in that case.
        server_type = 'THttpServer'
    elif args:
        server_type = args[0]
    else:
        raise AssertionError('A server type (e.g. TSimpleServer) must be specified.')

    # Set up the handler and processor objects
    processor = ThriftTest.Processor(TestHandler())

    if options.proto.startswith('multi'):
        second_processor = SecondService.Processor(SecondHandler())

        multiplexed = TMultiplexedProcessor()
        multiplexed.registerDefault(processor)
        multiplexed.registerProcessor('ThriftTest', processor)
        multiplexed.registerProcessor('SecondService', second_processor)
        processor = multiplexed

    # Certificates and keys live in a "keys" directory next to SCRIPT_DIR.
    keys_dir = os.path.join(os.path.dirname(SCRIPT_DIR), 'keys')

    # Handle THttpServer as a special case
    if server_type == 'THttpServer':
        if options.ssl:
            server = THttpServer.THttpServer(
                processor, ('', options.port), pfactory,
                cert_file=os.path.join(keys_dir, "server.crt"),
                key_file=os.path.join(keys_dir, "server.key"))
        else:
            server = THttpServer.THttpServer(processor, ('', options.port), pfactory)
        server.serve()
        sys.exit(0)

    # set up server transport and transport factory
    abs_key_path = os.path.join(keys_dir, 'server.pem')

    host = None
    if options.ssl:
        from thrift.transport import TSSLSocket
        transport = TSSLSocket.TSSLServerSocket(host, options.port, certfile=abs_key_path)
    else:
        transport = TSocket.TServerSocket(host, options.port, options.domain_socket)

    # Pick the per-connection transport factory from --transport. None (the
    # option was never set) keeps the buffered default. Any other value not
    # listed here is rejected rather than silently treated as buffered.
    transport_factories = {
        'buffered': TTransport.TBufferedTransportFactory,
        'framed': TTransport.TFramedTransportFactory,
    }
    trans_name = 'buffered' if options.trans is None else options.trans
    if trans_name not in transport_factories:
        raise AssertionError('Unknown --transport option: %s' % options.trans)
    tfactory = transport_factories[trans_name]()

    # if --zlib, then wrap server transport, and use a different transport factory
    if options.zlib:
        transport = TZlibTransport.TZlibTransport(transport)  # wrap with zlib
        tfactory = TZlibTransport.TZlibTransportFactory()

    # do server-specific setup here:
    if server_type == "TNonblockingServer":
        server = TNonblockingServer.TNonblockingServer(processor, transport, inputProtocolFactory=pfactory)
    elif server_type == "TProcessPoolServer":
        from thrift.server import TProcessPoolServer
        server = TProcessPoolServer.TProcessPoolServer(processor, transport, tfactory, pfactory)
        server.setNumWorkers(5)

        def set_alarm():
            # After 4 seconds SIGALRM fires clean_shutdown, which terminates
            # the workers and asks the pool server to stop.
            def clean_shutdown(signum, frame):
                for worker in server.workers:
                    if options.verbose > 0:
                        logging.info('Terminating worker: %s' % worker)
                    worker.terminate()
                if options.verbose > 0:
                    logging.info('Requesting server to stop()')
                try:
                    server.stop()
                except Exception:
                    pass
            signal.signal(signal.SIGALRM, clean_shutdown)
            signal.alarm(4)
        set_alarm()
    else:
        # look up server class dynamically to instantiate server
        ServerClass = getattr(TServer, server_type)
        server = ServerClass(processor, transport, tfactory, pfactory)
    # enter server main loop
    server.serve()
355
356
def exit_gracefully(signum, frame):
    """SIGINT handler: try to shut the server down, then exit with status 0.

    ``server`` is the module-level global assigned by ``main``. It does not
    exist yet if the signal arrives before ``main`` builds the server. The
    server object may also lack ``shutdown()``; ``main`` stops the
    process-pool server with ``stop()`` instead. Both cases are tolerated so
    that the handler still reaches ``sys.exit(0)`` rather than raising
    NameError/AttributeError from inside the signal handler.

    Args:
        signum: signal number (unused).
        frame: current stack frame (unused).

    Raises:
        SystemExit: always, with code 0.
    """
    print("SIGINT received\n")
    shutdown = getattr(globals().get('server'), 'shutdown', None)
    if callable(shutdown):
        shutdown()   # doesn't work properly, yet
    sys.exit(0)
361
362
if __name__ == '__main__':
    signal.signal(signal.SIGINT, exit_gracefully)

    parser = OptionParser()
    # NOTE(review): options.libpydir is not read anywhere in this file; the
    # thrift library path comes from util.local_libpath() at import time.
    parser.add_option('--libpydir', type='string', dest='libpydir',
                      help='include this directory to sys.path for locating library code')
    parser.add_option('--genpydir', type='string', dest='genpydir',
                      default='gen-py',
                      help='include this directory to sys.path for locating generated code')
    parser.add_option("--port", type="int", dest="port",
                      help="port number for server to listen on")
    parser.add_option("--zlib", action="store_true", dest="zlib",
                      help="use zlib wrapper for compressed transport")
    parser.add_option("--ssl", action="store_true", dest="ssl",
                      help="use SSL for encrypted transport")
    parser.add_option('-v', '--verbose', action="store_const",
                      dest="verbose", const=2,
                      help="verbose output")
    parser.add_option('-q', '--quiet', action="store_const",
                      dest="verbose", const=0,
                      help="minimal output")
    parser.add_option('--protocol', dest="proto", type="string",
                      help="protocol to use, one of: accel, accelc, binary, compact, json, multi, multia, multiac, multic, multih, multij")
    parser.add_option('--transport', dest="trans", type="string",
                      help="transport to use, one of: buffered, framed, http")
    parser.add_option('--domain-socket', dest="domain_socket", type="string",
                      help="Unix domain socket path")
    parser.add_option('--container-limit', dest='container_limit', type='int', default=None)
    parser.add_option('--string-limit', dest='string_limit', type='int', default=None)
    # Defaults are keyed by each option's ``dest``. --transport stores into
    # "trans", so its default must be "trans"; a "transport" key would only
    # create an unused options.transport.
    parser.set_defaults(port=9090, verbose=1, proto='binary', trans='buffered')
    options, args = parser.parse_args()

    # Route the TServer log through the root logger so that the test-runner
    # can redirect it to log files. basicConfig writes to stderr by default.
    # NOTE(review): verbose is 0, 1 or 2, all below logging.DEBUG (10), so this
    # level filters nothing; TestHandler checks options.verbose on its own.
    logging.basicConfig(level=options.verbose)

    # Generated ThriftTest code is looked up in --genpydir, relative to SCRIPT_DIR.
    sys.path.insert(0, os.path.join(SCRIPT_DIR, options.genpydir))

    from ThriftTest import ThriftTest, SecondService
    from ThriftTest.ttypes import Xtruct, Xception, Xception2, Insanity
    from thrift.Thrift import TException
    from thrift.TMultiplexedProcessor import TMultiplexedProcessor
    from thrift.transport import THeaderTransport
    from thrift.transport import TTransport
    from thrift.transport import TSocket
    from thrift.transport import TZlibTransport
    from thrift.protocol import TBinaryProtocol
    from thrift.protocol import TCompactProtocol
    from thrift.protocol import THeaderProtocol
    from thrift.protocol import TJSONProtocol
    from thrift.server import TServer, TNonblockingServer, THttpServer

    sys.exit(main(options))
415