1# Copyright 2016 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import unittest
16import logging
17
18import grpc
19
20from tests.unit import test_common
21from tests.unit.framework.common import test_constants
22from tests.unit.framework.common import test_control
23
def _SERIALIZE_REQUEST(bytestring):
    """Encode a request by repeating its bytes twice."""
    return bytestring * 2


def _DESERIALIZE_REQUEST(bytestring):
    """Decode a request by keeping the second half of the bytes."""
    return bytestring[len(bytestring) // 2:]


def _SERIALIZE_RESPONSE(bytestring):
    """Encode a response by repeating its bytes three times."""
    return bytestring * 3


def _DESERIALIZE_RESPONSE(bytestring):
    """Decode a response by keeping the first third of the bytes."""
    return bytestring[:len(bytestring) // 3]


# Fully-qualified method names dispatched by _GenericHandler.service.
_UNARY_UNARY = '/test/UnaryUnary'
_UNARY_STREAM = '/test/UnaryStream'
_STREAM_UNARY = '/test/StreamUnary'
_STREAM_STREAM = '/test/StreamStream'
_DEFECTIVE_GENERIC_RPC_HANDLER = '/test/DefectiveGenericRpcHandler'
34
35
class _Handler(object):
    """Server-side implementations of the test RPCs.

    Every handler calls ``control()`` on the supplied control object at
    fixed checkpoints so a test can pause or fail the RPC there. Whenever a
    servicer context is provided, the trailing metadata
    ``(('testkey', 'testvalue'),)`` is attached to the call.
    """

    def __init__(self, control):
        # Object whose control() method is invoked at every checkpoint.
        self._control = control

    def _attach_trailing_metadata(self, servicer_context):
        """Set the fixed trailing metadata when a servicer context exists."""
        if servicer_context is None:
            return
        servicer_context.set_trailing_metadata((('testkey', 'testvalue'),))

    def _gated(self, request_iterator):
        """Yield each request, hitting the control checkpoint before each."""
        for item in request_iterator:
            self._control.control()
            yield item

    def handle_unary_unary(self, request, servicer_context):
        """Echo the single request back unchanged."""
        self._control.control()
        self._attach_trailing_metadata(servicer_context)
        return request

    def handle_unary_stream(self, request, servicer_context):
        """Yield the request test_constants.STREAM_LENGTH times."""
        for _ in range(test_constants.STREAM_LENGTH):
            self._control.control()
            yield request
        self._control.control()
        self._attach_trailing_metadata(servicer_context)

    def handle_stream_unary(self, request_iterator, servicer_context):
        """Concatenate every streamed request into a single response."""
        if servicer_context is not None:
            # Exercise the invocation-metadata accessor; result is unused.
            servicer_context.invocation_metadata()
        self._control.control()
        collected = list(self._gated(request_iterator))
        self._control.control()
        self._attach_trailing_metadata(servicer_context)
        return b''.join(collected)

    def handle_stream_stream(self, request_iterator, servicer_context):
        """Echo each streamed request back as it arrives."""
        self._control.control()
        self._attach_trailing_metadata(servicer_context)
        for item in self._gated(request_iterator):
            yield item
        self._control.control()

    def defective_generic_rpc_handler(self):
        """Simulate a broken generic handler by raising test_control.Defect."""
        raise test_control.Defect()
91
92
class _MethodHandler(grpc.RpcMethodHandler):
    """Attribute-only ``grpc.RpcMethodHandler`` built from explicit fields.

    In this file each instance is given exactly one non-None behavior
    callable, matching its two streaming flags; the rest are None.
    """

    def __init__(self, request_streaming, response_streaming,
                 request_deserializer, response_serializer, unary_unary,
                 unary_stream, stream_unary, stream_stream):
        # Whether the request side and the response side are streams.
        self.request_streaming, self.response_streaming = (
            request_streaming, response_streaming)
        # Optional codecs; None when the method does not use one.
        self.request_deserializer, self.response_serializer = (
            request_deserializer, response_serializer)
        # Behavior callables, one slot per RPC shape.
        (self.unary_unary, self.unary_stream, self.stream_unary,
         self.stream_stream) = (unary_unary, unary_stream, stream_unary,
                                stream_stream)
106
107
class _GenericHandler(grpc.GenericRpcHandler):
    """Maps the test method names onto the behaviors of a ``_Handler``."""

    def __init__(self, handler):
        self._handler = handler

    def service(self, handler_call_details):
        """Return a ``_MethodHandler`` for the requested method name.

        Unrecognized names yield None. The defective method name instead
        returns the result of the wrapped handler's
        ``defective_generic_rpc_handler`` (which, in this file, raises).
        """
        method = handler_call_details.method
        impl = self._handler

        if method == _UNARY_UNARY:
            return _MethodHandler(False, False, None, None,
                                  impl.handle_unary_unary, None, None, None)
        if method == _UNARY_STREAM:
            return _MethodHandler(False, True, _DESERIALIZE_REQUEST,
                                  _SERIALIZE_RESPONSE, None,
                                  impl.handle_unary_stream, None, None)
        if method == _STREAM_UNARY:
            return _MethodHandler(True, False, _DESERIALIZE_REQUEST,
                                  _SERIALIZE_RESPONSE, None, None,
                                  impl.handle_stream_unary, None)
        if method == _STREAM_STREAM:
            return _MethodHandler(True, True, None, None, None, None, None,
                                  impl.handle_stream_stream)
        if method == _DEFECTIVE_GENERIC_RPC_HANDLER:
            return impl.defective_generic_rpc_handler()
        return None
133
134
class FailAfterFewIterationsCounter(object):
    """Iterator that yields the same bytes ``high`` times and then raises.

    Once ``high`` items have been produced, every further ``next`` call
    raises ``test_control.Defect``. This simulates a request iterator that
    fails part-way through a stream.
    """

    def __init__(self, high, bytestring):
        self._high = high
        self._bytestring = bytestring
        self._current = 0  # Items produced so far.

    def __iter__(self):
        return self

    def __next__(self):
        if self._current >= self._high:
            raise test_control.Defect()
        self._current += 1
        return self._bytestring

    # Python 2 spelling of the iterator protocol.
    next = __next__
153
154
def _unary_unary_multi_callable(channel):
    """Build the unary-unary callable; no custom (de)serializers are set."""
    factory = channel.unary_unary
    return factory(_UNARY_UNARY)
157
158
def _unary_stream_multi_callable(channel):
    """Build the unary-stream callable wired to the test request/response codecs."""
    codecs = {
        'request_serializer': _SERIALIZE_REQUEST,
        'response_deserializer': _DESERIALIZE_RESPONSE,
    }
    return channel.unary_stream(_UNARY_STREAM, **codecs)
163
164
def _stream_unary_multi_callable(channel):
    """Build the stream-unary callable wired to the test request/response codecs."""
    codecs = {
        'request_serializer': _SERIALIZE_REQUEST,
        'response_deserializer': _DESERIALIZE_RESPONSE,
    }
    return channel.stream_unary(_STREAM_UNARY, **codecs)
169
170
def _stream_stream_multi_callable(channel):
    """Build the stream-stream callable; no custom (de)serializers are set."""
    factory = channel.stream_stream
    return factory(_STREAM_STREAM)
173
174
def _defective_handler_multi_callable(channel):
    """Build a unary-unary callable for the method whose server lookup raises."""
    factory = channel.unary_unary
    return factory(_DEFECTIVE_GENERIC_RPC_HANDLER)
177
178
class InvocationDefectsTest(unittest.TestCase):
    """Checks how defects in user-supplied code surface to the RPC caller.

    Each test drives an RPC whose request iterable/iterator or generic RPC
    handler raises. It then asserts that the caller observes a
    ``grpc.RpcError`` whose status code is ``grpc.StatusCode.UNKNOWN``.
    """

    def setUp(self):
        self._control = test_control.PauseFailControl()
        self._handler = _Handler(self._control)

        self._server = test_common.test_server()
        # Bind to port 0 and connect to whichever port the server reports.
        bound_port = self._server.add_insecure_port('[::]:0')
        self._server.add_generic_rpc_handlers(
            (_GenericHandler(self._handler),))
        self._server.start()

        self._channel = grpc.insecure_channel('localhost:%d' % bound_port)

    def tearDown(self):
        self._server.stop(0)
        self._channel.close()

    def _assert_status_unknown(self, exception_context):
        """Assert the captured RpcError carries StatusCode.UNKNOWN."""
        self.assertIs(grpc.StatusCode.UNKNOWN,
                      exception_context.exception.code())

    def testIterableStreamRequestBlockingUnaryResponse(self):
        # object() is not iterable, so consuming the requests fails.
        not_iterable = object()
        call = _stream_unary_multi_callable(self._channel)

        with self.assertRaises(grpc.RpcError) as raised:
            call(not_iterable,
                 metadata=(('test',
                            'IterableStreamRequestBlockingUnaryResponse'),))

        self._assert_status_unknown(raised)

    def testIterableStreamRequestFutureUnaryResponse(self):
        not_iterable = object()
        call = _stream_unary_multi_callable(self._channel)
        future = call.future(
            not_iterable,
            metadata=(('test', 'IterableStreamRequestFutureUnaryResponse'),))

        with self.assertRaises(grpc.RpcError) as raised:
            future.result()

        self._assert_status_unknown(raised)

    def testIterableStreamRequestStreamResponse(self):
        not_iterable = object()
        call = _stream_stream_multi_callable(self._channel)
        responses = call(
            not_iterable,
            metadata=(('test', 'IterableStreamRequestStreamResponse'),))

        with self.assertRaises(grpc.RpcError) as raised:
            next(responses)

        self._assert_status_unknown(raised)

    def testIteratorStreamRequestStreamResponse(self):
        half_stream = test_constants.STREAM_LENGTH // 2
        # The request iterator raises after producing half_stream items.
        failing_requests = FailAfterFewIterationsCounter(
            half_stream, b'\x07\x08')
        call = _stream_stream_multi_callable(self._channel)
        responses = call(
            failing_requests,
            metadata=(('test', 'IteratorStreamRequestStreamResponse'),))

        with self.assertRaises(grpc.RpcError) as raised:
            # Reading one past the last echoed request must hit the failure.
            for _ in range(half_stream + 1):
                next(responses)

        self._assert_status_unknown(raised)

    def testDefectiveGenericRpcHandlerUnaryResponse(self):
        call = _defective_handler_multi_callable(self._channel)

        with self.assertRaises(grpc.RpcError) as raised:
            call(b'\x07\x08',
                 metadata=(('test', 'DefectiveGenericRpcHandlerUnary'),))

        self._assert_status_unknown(raised)
262
263
if __name__ == '__main__':
    # Configure the root logger so log records emitted during the run are
    # printed, then run every test in this module with verbose output.
    logging.basicConfig()
    unittest.main(verbosity=2)
267