1#!/usr/bin/env python
2
3#
4# Licensed to the Apache Software Foundation (ASF) under one
5# or more contributor license agreements. See the NOTICE file
6# distributed with this work for additional information
7# regarding copyright ownership. The ASF licenses this file
8# to you under the Apache License, Version 2.0 (the
9# "License"); you may not use this file except in compliance
10# with the License. You may obtain a copy of the License at
11#
12#   http://www.apache.org/licenses/LICENSE-2.0
13#
14# Unless required by applicable law or agreed to in writing,
15# software distributed under the License is distributed on an
16# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
17# KIND, either express or implied. See the License for the
18# specific language governing permissions and limitations
19# under the License.
20#
21
22import datetime
23import glob
24import os
25import sys
26import time
27import unittest
28
basepath = os.path.abspath(os.path.dirname(__file__))
# Make the generated Tornado bindings and the locally built thrift library
# importable; this has to run before the thrift/ThriftTest imports below.
sys.path.insert(0, basepath + '/gen-py.tornado')
_thrift_lib_dirs = glob.glob(os.path.join(basepath, '../../lib/py/build/lib*'))
if not _thrift_lib_dirs:
    # Fail with an explicit message rather than an opaque IndexError from
    # indexing an empty glob result when lib/py has not been built.
    sys.exit("thrift python library not found under %s; build lib/py first"
             % os.path.normpath(os.path.join(basepath, '../../lib/py/build')))
# NOTE(review): glob order is arbitrary, so if several build/lib* directories
# exist (e.g. one per Python version) an arbitrary one is picked here.
sys.path.insert(0, _thrift_lib_dirs[0])

# Tornado is optional: when it is missing, report it and exit with status 0
# so the test counts as skipped rather than failed.
try:
    __import__('tornado')
except ImportError:
    print("module `tornado` not found, skipping test")
    sys.exit(0)
38
from tornado import gen
from tornado.testing import AsyncTestCase, bind_unused_port, get_unused_port, gen_test

from thrift import TTornado
from thrift.Thrift import TApplicationException
from thrift.protocol import TBinaryProtocol

from ThriftTest import ThriftTest
from ThriftTest.ttypes import Xception, Xtruct
48
49
class TestHandler(object):
    """Server-side ThriftTest implementation exercised by ThriftTestCase.

    Almost every method simply echoes its argument.  The exceptions are:

    * ``testString('unexpected_error')`` raises a plain ``Exception``.
    * ``testException('Xception')`` raises the declared ``Xception`` with
      errorCode 1001; ``testException('throw_undeclared')`` raises an
      undeclared ``ValueError``; any other argument returns ``None``.
    * ``testOneway`` schedules a call to the owning test case's ``stop()``
      after ``seconds`` and then raises on purpose.
    * ``testMap`` is a Tornado coroutine that yields once before returning.
    """

    def __init__(self, test_instance):
        # The owning test case; testOneway uses its io_loop and stop().
        self.test_instance = test_instance

    def testVoid(self):
        return None

    def testString(self, s):
        if s != 'unexpected_error':
            return s
        raise Exception(s)

    def testByte(self, b):
        return b

    def testI16(self, i16):
        return i16

    def testI32(self, i32):
        return i32

    def testI64(self, i64):
        return i64

    def testDouble(self, dub):
        return dub

    def testBinary(self, thing):
        return thing

    def testStruct(self, thing):
        return thing

    def testException(self, s):
        # The two magic arguments are mutually exclusive, so the checks can
        # be ordered freely; everything else falls through to None.
        if s == 'throw_undeclared':
            raise ValueError('testing undeclared exception')
        if s == 'Xception':
            raise Xception(errorCode=1001, message=s)

    def testOneway(self, seconds):
        started = time.time()
        owner = self.test_instance
        # After `seconds`, hand (start, fire time, seconds) to the test case.
        owner.io_loop.add_timeout(
            datetime.timedelta(seconds=seconds),
            lambda: owner.stop((started, time.time(), seconds)))
        raise Exception('testing exception in oneway method')

    def testNest(self, thing):
        return thing

    @gen.coroutine
    def testMap(self, thing):
        # Yield to the IOLoop once so the reply is produced asynchronously.
        yield gen.moment
        raise gen.Return(thing)

    def testSet(self, thing):
        return thing

    def testList(self, thing):
        return thing

    def testEnum(self, thing):
        return thing

    def testTypedef(self, thing):
        return thing
123
124
class ThriftTestCase(AsyncTestCase):
    """End-to-end tests of the Tornado Thrift client against TestHandler.

    setUp starts a TTornadoServer and connects a binary-protocol client to
    it over loopback TCP, both running on this test case's io_loop.
    """

    def setUp(self):
        super(ThriftTestCase, self).setUp()

        # Ask the OS for a free port and keep the socket bound.
        # get_unused_port() only hands out a port number without checking
        # that it is free, so the server could collide with another listener.
        sock, self.port = bind_unused_port()

        # server
        self.handler = TestHandler(self)
        self.processor = ThriftTest.Processor(self.handler)
        self.pfactory = TBinaryProtocol.TBinaryProtocolFactory()

        self.server = TTornado.TTornadoServer(self.processor, self.pfactory, io_loop=self.io_loop)
        self.server.add_sockets([sock])

        # client: bind_unused_port() listens on IPv4 loopback only, so
        # connect to 127.0.0.1 rather than 'localhost' (which may be ::1).
        transport = TTornado.TTornadoStreamTransport('127.0.0.1', self.port, io_loop=self.io_loop)
        pfactory = TBinaryProtocol.TBinaryProtocolFactory()
        self.io_loop.run_sync(transport.open)
        self.client = ThriftTest.Client(transport, pfactory)

    @gen_test
    def test_void(self):
        v = yield self.client.testVoid()
        self.assertEqual(v, None)

    @gen_test
    def test_string(self):
        v = yield self.client.testString('Python')
        self.assertEqual(v, 'Python')

    @gen_test
    def test_byte(self):
        v = yield self.client.testByte(63)
        self.assertEqual(v, 63)

    @gen_test
    def test_i32(self):
        v = yield self.client.testI32(-1)
        self.assertEqual(v, -1)

        v = yield self.client.testI32(0)
        self.assertEqual(v, 0)

    @gen_test
    def test_i64(self):
        v = yield self.client.testI64(-34359738368)
        self.assertEqual(v, -34359738368)

    @gen_test
    def test_double(self):
        v = yield self.client.testDouble(-5.235098235)
        self.assertEqual(v, -5.235098235)

    @gen_test
    def test_struct(self):
        x = Xtruct()
        x.string_thing = "Zero"
        x.byte_thing = 1
        x.i32_thing = -3
        x.i64_thing = -5
        y = yield self.client.testStruct(x)

        self.assertEqual(y.string_thing, "Zero")
        self.assertEqual(y.byte_thing, 1)
        self.assertEqual(y.i32_thing, -3)
        self.assertEqual(y.i64_thing, -5)

    @gen_test
    def test_oneway(self):
        # The oneway call is fired without waiting for a reply (the handler
        # raises on purpose); a normal call on the same client must still work.
        self.client.testOneway(1)
        v = yield self.client.testI32(-1)
        self.assertEqual(v, -1)

    @gen_test
    def test_map(self):
        """
        TestHandler.testMap is a coroutine, this test checks if gen.Return() from a coroutine works.
        """
        expected = {1: 1}
        res = yield self.client.testMap(expected)
        self.assertEqual(res, expected)

    @gen_test
    def test_exception(self):
        # A declared exception reaches the client as Xception itself.
        with self.assertRaises(Xception) as cm:
            yield self.client.testException('Xception')
        self.assertEqual(cm.exception.errorCode, 1001)
        self.assertEqual(cm.exception.message, 'Xception')

        # An undeclared server-side exception surfaces as TApplicationException.
        with self.assertRaises(TApplicationException):
            yield self.client.testException('throw_undeclared')

        # Any other argument must complete without raising.
        yield self.client.testException('Safe')
225
226
def suite():
    """Return a TestSuite holding every test method of ThriftTestCase."""
    loader = unittest.TestLoader()
    tests = unittest.TestSuite()
    tests.addTest(loader.loadTestsFromTestCase(ThriftTestCase))
    return tests
232
233
if __name__ == '__main__':
    # Run suite() from the command line with a low-verbosity text runner.
    # unittest.main is the standard alias of unittest.TestProgram.
    text_runner = unittest.TextTestRunner(verbosity=1)
    unittest.main(defaultTest='suite', testRunner=text_runner)
237