# Copyright 2016 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests that defective client-side invocation arguments end in RPC errors."""

import logging
import unittest

import grpc

from tests.unit import test_common
from tests.unit.framework.common import test_constants
from tests.unit.framework.common import test_control

# Toy (de)serializers: a request is doubled on the wire and the second half is
# taken back out; a response is tripled and the first third is taken back out.
_SERIALIZE_REQUEST = lambda bytestring: bytestring * 2
_DESERIALIZE_REQUEST = lambda bytestring: bytestring[len(bytestring) // 2:]
_SERIALIZE_RESPONSE = lambda bytestring: bytestring * 3
_DESERIALIZE_RESPONSE = lambda bytestring: bytestring[:len(bytestring) // 3]

# Fully-qualified method paths routed by _GenericHandler.service.
_UNARY_UNARY = '/test/UnaryUnary'
_UNARY_STREAM = '/test/UnaryStream'
_STREAM_UNARY = '/test/StreamUnary'
_STREAM_STREAM = '/test/StreamStream'
_DEFECTIVE_GENERIC_RPC_HANDLER = '/test/DefectiveGenericRpcHandler'

# Trailing metadata attached by every handler that is given a context.
_TRAILING_METADATA = ((
    'testkey',
    'testvalue',
),)


def _attach_trailing_metadata(servicer_context):
    """Set the fixed test trailing metadata unless the context is None."""
    if servicer_context is None:
        return
    servicer_context.set_trailing_metadata(_TRAILING_METADATA)


class _Handler(object):
    """Servicer behaviors for the four RPC arities plus a defective lookup.

    Every behavior calls ``control.control()`` at each step, so the test's
    control object can act at those points.
    """

    def __init__(self, control):
        self._control = control

    def handle_unary_unary(self, request, servicer_context):
        """Echo the single request back as the single response."""
        self._control.control()
        _attach_trailing_metadata(servicer_context)
        return request

    def handle_unary_stream(self, request, servicer_context):
        """Yield the request STREAM_LENGTH times, then set trailing metadata."""
        for _ in range(test_constants.STREAM_LENGTH):
            self._control.control()
            yield request
        self._control.control()
        _attach_trailing_metadata(servicer_context)

    def handle_stream_unary(self, request_iterator, servicer_context):
        """Drain the request stream and return its elements concatenated."""
        if servicer_context is not None:
            servicer_context.invocation_metadata()
        self._control.control()
        received = []
        for chunk in request_iterator:
            self._control.control()
            received.append(chunk)
        self._control.control()
        _attach_trailing_metadata(servicer_context)
        return b''.join(received)

    def handle_stream_stream(self, request_iterator, servicer_context):
        """Echo each incoming request as it arrives."""
        self._control.control()
        # Trailing metadata is set up front, before any response is produced.
        _attach_trailing_metadata(servicer_context)
        for incoming in request_iterator:
            self._control.control()
            yield incoming
        self._control.control()

    def defective_generic_rpc_handler(self):
        """Always raise, standing in for a broken handler lookup."""
        raise test_control.Defect()


class _MethodHandler(grpc.RpcMethodHandler):
    """Plain attribute holder implementing grpc.RpcMethodHandler."""

    def __init__(self, request_streaming, response_streaming,
                 request_deserializer, response_serializer, unary_unary,
                 unary_stream, stream_unary, stream_stream):
        self.request_streaming = request_streaming
        self.response_streaming = response_streaming
        self.request_deserializer = request_deserializer
        self.response_serializer = response_serializer
        self.unary_unary = unary_unary
        self.unary_stream = unary_stream
        self.stream_unary = stream_unary
        self.stream_stream = stream_stream


class _GenericHandler(grpc.GenericRpcHandler):
    """Routes the test method paths to the matching _Handler behavior."""

    def __init__(self, handler):
        self._handler = handler

    def service(self, handler_call_details):
        """Return the method handler for the called path, or None if unknown.

        The defective path calls straight into the handler's defective
        lookup, which raises instead of returning a method handler.
        """
        method = handler_call_details.method
        behaviors = self._handler

        if method == _UNARY_UNARY:
            return _MethodHandler(
                request_streaming=False,
                response_streaming=False,
                request_deserializer=None,
                response_serializer=None,
                unary_unary=behaviors.handle_unary_unary,
                unary_stream=None,
                stream_unary=None,
                stream_stream=None)
        if method == _UNARY_STREAM:
            return _MethodHandler(
                request_streaming=False,
                response_streaming=True,
                request_deserializer=_DESERIALIZE_REQUEST,
                response_serializer=_SERIALIZE_RESPONSE,
                unary_unary=None,
                unary_stream=behaviors.handle_unary_stream,
                stream_unary=None,
                stream_stream=None)
        if method == _STREAM_UNARY:
            return _MethodHandler(
                request_streaming=True,
                response_streaming=False,
                request_deserializer=_DESERIALIZE_REQUEST,
                response_serializer=_SERIALIZE_RESPONSE,
                unary_unary=None,
                unary_stream=None,
                stream_unary=behaviors.handle_stream_unary,
                stream_stream=None)
        if method == _STREAM_STREAM:
            return _MethodHandler(
                request_streaming=True,
                response_streaming=True,
                request_deserializer=None,
                response_serializer=None,
                unary_unary=None,
                unary_stream=None,
                stream_unary=None,
                stream_stream=behaviors.handle_stream_stream)
        if method == _DEFECTIVE_GENERIC_RPC_HANDLER:
            return behaviors.defective_generic_rpc_handler()
        return None


class FailAfterFewIterationsCounter(object):
    """Iterator yielding ``bytestring`` ``high`` times, then raising Defect."""

    def __init__(self, high, bytestring):
        self._current = 0
        self._high = high
        self._bytestring = bytestring

    def __iter__(self):
        return self

    def __next__(self):
        # Once the quota is used up, every further call raises.
        if self._current >= self._high:
            raise test_control.Defect()
        self._current += 1
        return self._bytestring

    # Python 2 spelling of the iterator protocol.
    next = __next__


def _unary_unary_multi_callable(channel):
    """Build the unary-unary callable (no custom serialization)."""
    return channel.unary_unary(_UNARY_UNARY)


def _unary_stream_multi_callable(channel):
    """Build the unary-stream callable with the toy (de)serializers."""
    return channel.unary_stream(
        _UNARY_STREAM,
        request_serializer=_SERIALIZE_REQUEST,
        response_deserializer=_DESERIALIZE_RESPONSE,
    )


def _stream_unary_multi_callable(channel):
    """Build the stream-unary callable with the toy (de)serializers."""
    return channel.stream_unary(
        _STREAM_UNARY,
        request_serializer=_SERIALIZE_REQUEST,
        response_deserializer=_DESERIALIZE_RESPONSE,
    )


def _stream_stream_multi_callable(channel):
    """Build the stream-stream callable (no custom serialization)."""
    return channel.stream_stream(_STREAM_STREAM)


def _defective_handler_multi_callable(channel):
    """Build a unary-unary callable aimed at the defective handler path."""
    return channel.unary_unary(_DEFECTIVE_GENERIC_RPC_HANDLER)


class InvocationDefectsTest(unittest.TestCase):
    """Checks that defective invocation-side user code yields UNKNOWN errors."""

    def setUp(self):
        self._control = test_control.PauseFailControl()
        self._handler = _Handler(self._control)

        self._server = test_common.test_server()
        port = self._server.add_insecure_port('[::]:0')
        generic_handlers = (_GenericHandler(self._handler),)
        self._server.add_generic_rpc_handlers(generic_handlers)
        self._server.start()

        target = 'localhost:%d' % port
        self._channel = grpc.insecure_channel(target)

    def tearDown(self):
        self._server.stop(0)
        self._channel.close()

    def _assert_status_unknown(self, exception_context):
        # Shared check: the captured RpcError must carry StatusCode.UNKNOWN.
        self.assertIs(grpc.StatusCode.UNKNOWN,
                      exception_context.exception.code())

    def testIterableStreamRequestBlockingUnaryResponse(self):
        # A bare object() is not iterable, so it is an invalid request stream.
        requests = object()
        multi_callable = _stream_unary_multi_callable(self._channel)
        metadata = (('test', 'IterableStreamRequestBlockingUnaryResponse'),)

        with self.assertRaises(grpc.RpcError) as exception_context:
            multi_callable(requests, metadata=metadata)

        self._assert_status_unknown(exception_context)

    def testIterableStreamRequestFutureUnaryResponse(self):
        # Same invalid request stream, but through the future-returning API.
        requests = object()
        multi_callable = _stream_unary_multi_callable(self._channel)
        metadata = (('test', 'IterableStreamRequestFutureUnaryResponse'),)
        response_future = multi_callable.future(requests, metadata=metadata)

        with self.assertRaises(grpc.RpcError) as exception_context:
            response_future.result()

        self._assert_status_unknown(exception_context)

    def testIterableStreamRequestStreamResponse(self):
        # Invalid request stream on a streaming-response call; the error is
        # expected when the first response is pulled.
        requests = object()
        multi_callable = _stream_stream_multi_callable(self._channel)
        metadata = (('test', 'IterableStreamRequestStreamResponse'),)
        response_iterator = multi_callable(requests, metadata=metadata)

        with self.assertRaises(grpc.RpcError) as exception_context:
            next(response_iterator)

        self._assert_status_unknown(exception_context)

    def testIteratorStreamRequestStreamResponse(self):
        # The request iterator raises after half a stream's worth of items;
        # reading one more response than that must hit the error.
        healthy_count = test_constants.STREAM_LENGTH // 2
        requests_iterator = FailAfterFewIterationsCounter(
            healthy_count, b'\x07\x08')
        multi_callable = _stream_stream_multi_callable(self._channel)
        metadata = (('test', 'IteratorStreamRequestStreamResponse'),)
        response_iterator = multi_callable(requests_iterator,
                                           metadata=metadata)

        with self.assertRaises(grpc.RpcError) as exception_context:
            for _ in range(test_constants.STREAM_LENGTH // 2 + 1):
                next(response_iterator)

        self._assert_status_unknown(exception_context)

    def testDefectiveGenericRpcHandlerUnaryResponse(self):
        # The server-side generic handler raises while looking up the method.
        request = b'\x07\x08'
        multi_callable = _defective_handler_multi_callable(self._channel)
        metadata = (('test', 'DefectiveGenericRpcHandlerUnary'),)

        with self.assertRaises(grpc.RpcError) as exception_context:
            multi_callable(request, metadata=metadata)

        self._assert_status_unknown(exception_context)


if __name__ == '__main__':
    logging.basicConfig()
    unittest.main(verbosity=2)