# Copyright 2017 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests server responding with RESOURCE_EXHAUSTED."""

import threading
import unittest
import logging

import grpc
from grpc import _channel
from grpc.framework.foundation import logging_pool

from tests.unit import test_common
from tests.unit.framework.common import test_constants

_REQUEST = b'\x00\x00\x00'
_RESPONSE = b'\x00\x00\x00'

_UNARY_UNARY = '/test/UnaryUnary'
_UNARY_STREAM = '/test/UnaryStream'
_STREAM_UNARY = '/test/StreamUnary'
_STREAM_STREAM = '/test/StreamStream'


def _request_stream():
    """Returns a fresh iterator over STREAM_LENGTH copies of _REQUEST.

    Each RPC must get its own iterator: if several calls share one, they
    compete for its items and later calls may send no requests at all (so a
    streaming response would be empty and per-response checks vacuous).
    """
    return iter([_REQUEST] * test_constants.STREAM_LENGTH)


class _TestTrigger(object):
    """Coordinates server handler threads with the test thread.

    Handlers call await_trigger() and block there. The test calls
    await_calls() to wait until total_call_count handlers are blocked, and
    then trigger() to release all of them.
    """

    def __init__(self, total_call_count):
        self._total_call_count = total_call_count
        self._pending_calls = 0
        self._triggered = False
        self._finish_condition = threading.Condition()
        self._start_condition = threading.Condition()

    def await_calls(self):
        """Blocks until total_call_count handlers have entered await_trigger."""
        with self._start_condition:
            while self._pending_calls < self._total_call_count:
                self._start_condition.wait()

    def await_trigger(self):
        """Records the calling handler as blocked, then waits for trigger()."""
        with self._start_condition:
            self._pending_calls += 1
            self._start_condition.notify()

        with self._finish_condition:
            # Re-check the predicate after every wakeup (the pattern the
            # threading.Condition docs recommend) instead of a single `if`.
            while not self._triggered:
                self._finish_condition.wait()

    def trigger(self):
        """Releases every handler blocked, now or later, in await_trigger.

        Calling this more than once is harmless.
        """
        with self._finish_condition:
            self._triggered = True
            self._finish_condition.notify_all()


def handle_unary_unary(trigger, request, servicer_context):
    """Unary-unary behavior: waits for the trigger, then returns _RESPONSE."""
    trigger.await_trigger()
    return _RESPONSE


def handle_unary_stream(trigger, request, servicer_context):
    """Unary-stream behavior: waits for the trigger, then yields
    STREAM_LENGTH copies of _RESPONSE."""
    trigger.await_trigger()
    for _ in range(test_constants.STREAM_LENGTH):
        yield _RESPONSE


def handle_stream_unary(trigger, request_iterator, servicer_context):
    """Stream-unary behavior: waits for the trigger, consumes every request,
    then returns _RESPONSE."""
    trigger.await_trigger()
    # TODO(issue:#6891) We should be able to remove this loop
    for _ in request_iterator:
        pass
    return _RESPONSE


def handle_stream_stream(trigger, request_iterator, servicer_context):
    """Stream-stream behavior: waits for the trigger, then yields one
    _RESPONSE per request received."""
    trigger.await_trigger()
    # TODO(issue:#6891) We should be able to remove this loop,
    # and replace with return; yield
    for _ in request_iterator:
        yield _RESPONSE


class _MethodHandler(grpc.RpcMethodHandler):
    """RpcMethodHandler for one of the four RPC shapes.

    Exactly one of unary_unary / unary_stream / stream_unary / stream_stream
    is set, chosen from the two streaming flags. Each one forwards to the
    matching handle_* function with the shared trigger. No (de)serializers
    are set, so raw bytes pass through.
    """

    def __init__(self, trigger, request_streaming, response_streaming):
        self.request_streaming = request_streaming
        self.response_streaming = response_streaming
        self.request_deserializer = None
        self.response_serializer = None
        self.unary_unary = None
        self.unary_stream = None
        self.stream_unary = None
        self.stream_stream = None
        if self.request_streaming and self.response_streaming:
            self.stream_stream = (
                lambda x, y: handle_stream_stream(trigger, x, y))
        elif self.request_streaming:
            self.stream_unary = lambda x, y: handle_stream_unary(trigger, x, y)
        elif self.response_streaming:
            self.unary_stream = lambda x, y: handle_unary_stream(trigger, x, y)
        else:
            self.unary_unary = lambda x, y: handle_unary_unary(trigger, x, y)


class _GenericHandler(grpc.GenericRpcHandler):
    """Routes the four /test/* method names to a matching _MethodHandler."""

    def __init__(self, trigger):
        self._trigger = trigger

    def service(self, handler_call_details):
        """Returns a _MethodHandler for a known method name, else None."""
        if handler_call_details.method == _UNARY_UNARY:
            return _MethodHandler(self._trigger, False, False)
        elif handler_call_details.method == _UNARY_STREAM:
            return _MethodHandler(self._trigger, False, True)
        elif handler_call_details.method == _STREAM_UNARY:
            return _MethodHandler(self._trigger, True, False)
        elif handler_call_details.method == _STREAM_STREAM:
            return _MethodHandler(self._trigger, True, True)
        else:
            return None


class ResourceExhaustedTest(unittest.TestCase):
    """Server limited by maximum_concurrent_rpcs rejects excess RPCs.

    Each test fills all THREAD_CONCURRENCY slots with blocked RPCs. It checks
    that one more RPC fails with RESOURCE_EXHAUSTED, then releases the
    blocked RPCs and checks that they and a fresh RPC all succeed.
    """

    def setUp(self):
        self._server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
        self._trigger = _TestTrigger(test_constants.THREAD_CONCURRENCY)
        self._server = grpc.server(
            self._server_pool,
            handlers=(_GenericHandler(self._trigger),),
            options=(('grpc.so_reuseport', 0),),
            maximum_concurrent_rpcs=test_constants.THREAD_CONCURRENCY)
        port = self._server.add_insecure_port('[::]:0')
        self._server.start()
        self._channel = grpc.insecure_channel('localhost:%d' % port)

    def tearDown(self):
        # If a test failed before calling trigger(), handlers would stay
        # blocked in await_trigger forever; release them before shutting down.
        self._trigger.trigger()
        self._server.stop(0)
        self._channel.close()

    def testUnaryUnary(self):
        multi_callable = self._channel.unary_unary(_UNARY_UNARY)
        futures = [
            multi_callable.future(_REQUEST)
            for _ in range(test_constants.THREAD_CONCURRENCY)
        ]

        self._trigger.await_calls()

        with self.assertRaises(grpc.RpcError) as exception_context:
            multi_callable(_REQUEST)

        self.assertEqual(grpc.StatusCode.RESOURCE_EXHAUSTED,
                         exception_context.exception.code())

        future_exception = multi_callable.future(_REQUEST)
        self.assertEqual(grpc.StatusCode.RESOURCE_EXHAUSTED,
                         future_exception.exception().code())

        self._trigger.trigger()
        for future in futures:
            self.assertEqual(_RESPONSE, future.result())

        # Ensure a new request can be handled
        self.assertEqual(_RESPONSE, multi_callable(_REQUEST))

    def testUnaryStream(self):
        multi_callable = self._channel.unary_stream(_UNARY_STREAM)
        calls = [
            multi_callable(_REQUEST)
            for _ in range(test_constants.THREAD_CONCURRENCY)
        ]

        self._trigger.await_calls()

        with self.assertRaises(grpc.RpcError) as exception_context:
            next(multi_callable(_REQUEST))

        self.assertEqual(grpc.StatusCode.RESOURCE_EXHAUSTED,
                         exception_context.exception.code())

        self._trigger.trigger()

        # Compare whole streams so an empty stream cannot pass vacuously.
        expected = [_RESPONSE] * test_constants.STREAM_LENGTH
        for call in calls:
            self.assertEqual(expected, list(call))

        # Ensure a new request can be handled
        self.assertEqual(expected, list(multi_callable(_REQUEST)))

    def testStreamUnary(self):
        multi_callable = self._channel.stream_unary(_STREAM_UNARY)
        futures = [
            multi_callable.future(_request_stream())
            for _ in range(test_constants.THREAD_CONCURRENCY)
        ]

        self._trigger.await_calls()

        with self.assertRaises(grpc.RpcError) as exception_context:
            multi_callable(_request_stream())

        self.assertEqual(grpc.StatusCode.RESOURCE_EXHAUSTED,
                         exception_context.exception.code())

        future_exception = multi_callable.future(_request_stream())
        self.assertEqual(grpc.StatusCode.RESOURCE_EXHAUSTED,
                         future_exception.exception().code())

        self._trigger.trigger()

        for future in futures:
            self.assertEqual(_RESPONSE, future.result())

        # Ensure a new request can be handled
        self.assertEqual(_RESPONSE, multi_callable(_request_stream()))

    def testStreamStream(self):
        multi_callable = self._channel.stream_stream(_STREAM_STREAM)
        calls = [
            multi_callable(_request_stream())
            for _ in range(test_constants.THREAD_CONCURRENCY)
        ]

        self._trigger.await_calls()

        with self.assertRaises(grpc.RpcError) as exception_context:
            next(multi_callable(_request_stream()))

        self.assertEqual(grpc.StatusCode.RESOURCE_EXHAUSTED,
                         exception_context.exception.code())

        self._trigger.trigger()

        # The handler echoes one response per request, and every call sent
        # its own full request stream.
        expected = [_RESPONSE] * test_constants.STREAM_LENGTH
        for call in calls:
            self.assertEqual(expected, list(call))

        # Ensure a new request can be handled
        self.assertEqual(expected, list(multi_callable(_request_stream())))


if __name__ == '__main__':
    logging.basicConfig()
    unittest.main(verbosity=2)