1#!/usr/bin/env python
2
3#
4# Licensed to the Apache Software Foundation (ASF) under one
5# or more contributor license agreements. See the NOTICE file
6# distributed with this work for additional information
7# regarding copyright ownership. The ASF licenses this file
8# to you under the Apache License, Version 2.0 (the
9# "License"); you may not use this file except in compliance
10# with the License. You may obtain a copy of the License at
11#
12#   http://www.apache.org/licenses/LICENSE-2.0
13#
14# Unless required by applicable law or agreed to in writing,
15# software distributed under the License is distributed on an
16# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
17# KIND, either express or implied. See the License for the
18# specific language governing permissions and limitations
19# under the License.
20#
21
22import glob
23import os
24import sys
25import time
26
basepath = os.path.abspath(os.path.dirname(__file__))
# Generated ThriftTest bindings for the Twisted flavour of the Python generator.
sys.path.insert(0, os.path.join(basepath, 'gen-py.twisted'))

# Put the locally built Thrift library ahead of everything else (including the
# generated code inserted above) so the tests exercise the in-tree library.
# glob() returns matches in arbitrary order; sorting makes the choice
# deterministic when more than one build directory exists.
_lib_dirs = sorted(glob.glob(os.path.join(basepath, '../../lib/py/build/lib.*')))
if not _lib_dirs:
    # Fail with an actionable message instead of a bare IndexError.
    raise ImportError(
        "No built Thrift Python library found under %s; build lib/py first"
        % os.path.normpath(os.path.join(basepath, '../../lib/py/build')))
sys.path.insert(0, _lib_dirs[0])
30
31from thrift.Thrift import TApplicationException
32
33from ThriftTest import ThriftTest
34from ThriftTest.ttypes import Xception, Xtruct
35from thrift.transport import TTwisted
36from thrift.protocol import TBinaryProtocol
37
38from twisted.trial import unittest
39from twisted.internet import defer, reactor
40from twisted.internet.protocol import ClientCreator
41
42from zope.interface import implementer
43
44
@implementer(ThriftTest.Iface)
class TestHandler:
    """In-process ThriftTest service used as the server side of the suite.

    Almost every method echoes its argument straight back to the caller;
    ``testException`` and ``testOneway`` exercise the error and one-way paths.
    """

    def __init__(self):
        # testOneway's delayed callback puts a
        # (scheduled_at, fired_at, seconds) tuple here so a test can wait on it.
        self.onewaysQueue = defer.DeferredQueue()

    def testVoid(self):
        """Do nothing and return nothing."""
        return None

    def testString(self, s):
        """Echo the string argument."""
        return s

    def testByte(self, b):
        """Echo the byte argument."""
        return b

    def testI16(self, i16):
        """Echo the 16-bit integer argument."""
        return i16

    def testI32(self, i32):
        """Echo the 32-bit integer argument."""
        return i32

    def testI64(self, i64):
        """Echo the 64-bit integer argument."""
        return i64

    def testDouble(self, dub):
        """Echo the double argument."""
        return dub

    def testBinary(self, thing):
        """Echo the binary argument."""
        return thing

    def testStruct(self, thing):
        """Echo the struct argument."""
        return thing

    def testException(self, s):
        """Raise (or not) depending on the value of *s*.

        - ``"throw_undeclared"``: raise ``ValueError("foo")``.
        - ``'Xception'``: raise ``Xception`` with errorCode 1001 and
          message *s*.
        - anything else: return normally.
        """
        if s == "throw_undeclared":
            raise ValueError("foo")
        if s != 'Xception':
            return None
        error = Xception()
        error.errorCode = 1001
        error.message = s
        raise error

    def testOneway(self, seconds):
        """Record into onewaysQueue after *seconds*, then raise immediately.

        The trailing raise looks intentional: the one-way call has no reply,
        and the client-side test issues another call afterwards, presumably to
        check the connection survives a failing one-way handler.
        NOTE(review): confirm that intent.
        """
        scheduled_at = time.time()

        def record(started):
            self.onewaysQueue.put((started, time.time(), seconds))

        reactor.callLater(seconds, record, scheduled_at)
        raise Exception('')

    def testNest(self, thing):
        """Echo the nested struct argument."""
        return thing

    def testMap(self, thing):
        """Echo the map argument."""
        return thing

    def testSet(self, thing):
        """Echo the set argument."""
        return thing

    def testList(self, thing):
        """Echo the list argument."""
        return thing

    def testEnum(self, thing):
        """Echo the enum argument."""
        return thing

    def testTypedef(self, thing):
        """Echo the typedef'd argument."""
        return thing
109
110
class ThriftTestCase(unittest.TestCase):
    """End-to-end tests of the Twisted Thrift transport over loopback TCP.

    Each test gets a fresh TestHandler served on an OS-assigned 127.0.0.1 port
    and a connected ThriftTest client speaking the binary protocol.
    """

    @defer.inlineCallbacks
    def setUp(self):
        self.handler = TestHandler()
        self.processor = ThriftTest.Processor(self.handler)
        self.pfactory = TBinaryProtocol.TBinaryProtocolFactory()

        # Port 0 lets the OS choose a free port; it is read back below.
        self.server = reactor.listenTCP(
            0, TTwisted.ThriftServerFactory(self.processor, self.pfactory), interface="127.0.0.1")
        # Release the listening port via a cleanup rather than in tearDown.
        # Cleanups still run when setUp fails, for example if the client
        # connection below cannot be made. tearDown is skipped in that case,
        # so doing it there would leak the port.
        self.addCleanup(self.server.stopListening)

        self.portNo = self.server.getHost().port

        self.txclient = yield ClientCreator(reactor,
                                            TTwisted.ThriftClientProtocol,
                                            ThriftTest.Client,
                                            self.pfactory).connectTCP("127.0.0.1", self.portNo)
        self.client = self.txclient.client

    def tearDown(self):
        # The listening port is closed by the cleanup registered in setUp.
        self.txclient.transport.loseConnection()

    @defer.inlineCallbacks
    def testVoid(self):
        self.assertEqual((yield self.client.testVoid()), None)

    @defer.inlineCallbacks
    def testString(self):
        self.assertEqual((yield self.client.testString('Python')), 'Python')

    @defer.inlineCallbacks
    def testByte(self):
        self.assertEqual((yield self.client.testByte(63)), 63)

    @defer.inlineCallbacks
    def testI32(self):
        self.assertEqual((yield self.client.testI32(-1)), -1)
        self.assertEqual((yield self.client.testI32(0)), 0)

    @defer.inlineCallbacks
    def testI64(self):
        # Needs more than 32 bits to represent.
        self.assertEqual((yield self.client.testI64(-34359738368)), -34359738368)

    @defer.inlineCallbacks
    def testDouble(self):
        self.assertEqual((yield self.client.testDouble(-5.235098235)), -5.235098235)

    # TODO: def testBinary(self) ...

    @defer.inlineCallbacks
    def testStruct(self):
        x = Xtruct()
        x.string_thing = "Zero"
        x.byte_thing = 1
        x.i32_thing = -3
        x.i64_thing = -5
        y = yield self.client.testStruct(x)

        self.assertEqual(y.string_thing, "Zero")
        self.assertEqual(y.byte_thing, 1)
        self.assertEqual(y.i32_thing, -3)
        self.assertEqual(y.i64_thing, -5)

    @defer.inlineCallbacks
    def testException(self):
        # A declared exception reaches the client with its fields intact.
        try:
            yield self.client.testException('Xception')
        except Xception as x:
            self.assertEqual(x.errorCode, 1001)
            self.assertEqual(x.message, 'Xception')
        else:
            self.fail("should have gotten exception")

        # An undeclared server-side error surfaces as TApplicationException.
        try:
            yield self.client.testException("throw_undeclared")
        except TApplicationException:
            pass
        else:
            self.fail("should have gotten exception")

        # Any other argument must complete without raising.
        yield self.client.testException('Safe')

    @defer.inlineCallbacks
    def testOneway(self):
        yield self.client.testOneway(1)
        start, end, seconds = yield self.handler.onewaysQueue.get()
        # The handler's delayed callback should fire about `seconds` after it
        # was scheduled.
        self.assertAlmostEqual(seconds, (end - start), places=1)
        # The handler raises after scheduling; ordinary calls on the same
        # connection must still work afterwards.
        self.assertEqual((yield self.client.testI32(-1)), -1)
199