1# Copyright 2017 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Tests server responding with RESOURCE_EXHAUSTED."""
15
16import threading
17import unittest
18import logging
19
20import grpc
21from grpc import _channel
22from grpc.framework.foundation import logging_pool
23
24from tests.unit import test_common
25from tests.unit.framework.common import test_constants
26
# Payloads used by every RPC in this module. The method handlers install no
# request deserializer or response serializer, so these are raw bytes.
_REQUEST, _RESPONSE = b'\x00\x00\x00', b'\x00\x00\x00'

# Fully qualified method paths routed by _GenericHandler, one for each
# request/response streaming combination.
(_UNARY_UNARY, _UNARY_STREAM, _STREAM_UNARY,
 _STREAM_STREAM) = ('/test/UnaryUnary', '/test/UnaryStream',
                    '/test/StreamUnary', '/test/StreamStream')
34
35
class _TestTrigger(object):
    """Rendezvous between blocked RPC handlers and the test thread.

    Handlers call await_trigger() to announce their arrival and then block
    until the test calls trigger(). The test calls await_calls() to wait
    until the expected number of handlers have arrived.
    """

    def __init__(self, total_call_count):
        """Create a trigger that await_calls() considers full at
        total_call_count arrivals."""
        self._total_call_count = total_call_count
        # Number of handlers that have entered await_trigger(); guarded by
        # _start_condition.
        self._pending_calls = 0
        # Set once by trigger(); guarded by _finish_condition.
        self._triggered = False
        self._finish_condition = threading.Condition()
        self._start_condition = threading.Condition()

    def await_calls(self):
        """Block until at least total_call_count handlers have arrived."""
        with self._start_condition:
            while self._pending_calls < self._total_call_count:
                self._start_condition.wait()

    def await_trigger(self):
        """Record this handler's arrival, then block until trigger() is called.

        Returns without blocking if trigger() has already been called.
        """
        with self._start_condition:
            self._pending_calls += 1
            self._start_condition.notify_all()

        with self._finish_condition:
            # Re-test the predicate after every wakeup (as await_calls()
            # does) so a handler is never released before trigger() runs.
            while not self._triggered:
                self._finish_condition.wait()

    def trigger(self):
        """Release every handler blocked in await_trigger(), now and later."""
        with self._finish_condition:
            self._triggered = True
            self._finish_condition.notify_all()
66
67
def handle_unary_unary(trigger, request, servicer_context):
    """Serve a unary-unary RPC: park on the trigger, then reply once.

    The request and servicer context are ignored.
    """
    trigger.await_trigger()
    response = _RESPONSE
    return response
71
72
def handle_unary_stream(trigger, request, servicer_context):
    """Serve a unary-stream RPC: park on the trigger, then stream replies.

    This is a generator, so the wait happens on the first next() call. It
    yields test_constants.STREAM_LENGTH copies of the response; the request
    and servicer context are ignored.
    """
    trigger.await_trigger()
    for response in [_RESPONSE] * test_constants.STREAM_LENGTH:
        yield response
77
78
def handle_stream_unary(trigger, request_iterator, servicer_context):
    """Serve a stream-unary RPC: park on the trigger, drain requests, reply.

    Request contents and the servicer context are ignored.
    """
    trigger.await_trigger()
    # TODO(issue:#6891) We should be able to remove this loop
    for _ in request_iterator:
        continue
    return _RESPONSE
85
86
def handle_stream_stream(trigger, request_iterator, servicer_context):
    """Serve a stream-stream RPC: park on the trigger, then echo one
    response per received request.

    This is a generator, so the wait happens on the first next() call.
    Request contents and the servicer context are ignored.
    """
    trigger.await_trigger()
    # TODO(issue:#6891) We should be able to remove this loop,
    # and replace with return; yield
    for _unused_request in request_iterator:
        yield _RESPONSE
93
94
class _MethodHandler(grpc.RpcMethodHandler):
    """RpcMethodHandler whose behavior blocks on a shared _TestTrigger.

    Exactly one of the four behavior attributes is populated, chosen by the
    (request_streaming, response_streaming) pair; the others stay None. No
    request deserializer or response serializer is installed.
    """

    def __init__(self, trigger, request_streaming, response_streaming):
        self.request_streaming = request_streaming
        self.response_streaming = response_streaming
        self.request_deserializer = None
        self.response_serializer = None
        self.unary_unary = None
        self.unary_stream = None
        self.stream_unary = None
        self.stream_stream = None

        def bind(handler):
            # Close over the trigger so the behavior takes the two arguments
            # (request or request iterator, servicer context) gRPC passes.
            return lambda request_or_iterator, context: handler(
                trigger, request_or_iterator, context)

        if not self.request_streaming:
            if self.response_streaming:
                self.unary_stream = bind(handle_unary_stream)
            else:
                self.unary_unary = bind(handle_unary_unary)
        elif self.response_streaming:
            self.stream_stream = bind(handle_stream_stream)
        else:
            self.stream_unary = bind(handle_stream_unary)
115
116
class _GenericHandler(grpc.GenericRpcHandler):
    """Routes the four test method paths to trigger-blocked handlers."""

    # Method path -> (request_streaming, response_streaming).
    _STREAMING_BY_METHOD = {
        _UNARY_UNARY: (False, False),
        _UNARY_STREAM: (False, True),
        _STREAM_UNARY: (True, False),
        _STREAM_STREAM: (True, True),
    }

    def __init__(self, trigger):
        self._trigger = trigger

    def service(self, handler_call_details):
        """Return a new _MethodHandler for a known method path, else None."""
        streaming = self._STREAMING_BY_METHOD.get(handler_call_details.method)
        if streaming is None:
            return None
        request_streaming, response_streaming = streaming
        return _MethodHandler(self._trigger, request_streaming,
                              response_streaming)
133
134
class ResourceExhaustedTest(unittest.TestCase):
    """Tests that RPCs beyond maximum_concurrent_rpcs get RESOURCE_EXHAUSTED.

    The server is built with maximum_concurrent_rpcs equal to
    THREAD_CONCURRENCY. Each test does the following:
    - starts that many calls, whose handlers park in
      _TestTrigger.await_trigger();
    - checks that one more call fails with RESOURCE_EXHAUSTED;
    - releases the parked handlers and checks their results;
    - checks that a new call is then served normally.
    """

    def setUp(self):
        self._server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
        self._trigger = _TestTrigger(test_constants.THREAD_CONCURRENCY)
        self._server = grpc.server(
            self._server_pool,
            handlers=(_GenericHandler(self._trigger),),
            options=(('grpc.so_reuseport', 0),),
            maximum_concurrent_rpcs=test_constants.THREAD_CONCURRENCY)
        port = self._server.add_insecure_port('[::]:0')
        self._server.start()
        self._channel = grpc.insecure_channel('localhost:%d' % port)

    def tearDown(self):
        # Release any handlers still parked in await_trigger() -- e.g. when a
        # test failed before reaching trigger() -- so their pool threads are
        # not left blocked forever. Harmless if trigger() already ran.
        self._trigger.trigger()
        self._server.stop(0)
        self._channel.close()

    @staticmethod
    def _new_request_iterator():
        """Return a fresh iterator over STREAM_LENGTH copies of _REQUEST.

        Every request-streaming call needs its own iterator. A shared one is
        drained by whichever calls read it first, leaving later calls
        (including the final "new request" check) sending no requests.
        """
        return iter([_REQUEST] * test_constants.STREAM_LENGTH)

    def _assert_full_response_stream(self, call):
        """Assert `call` yields exactly STREAM_LENGTH copies of _RESPONSE.

        Comparing the complete list, rather than checking each response in a
        loop, also fails when the stream is short or empty.
        """
        self.assertEqual([_RESPONSE] * test_constants.STREAM_LENGTH,
                         list(call))

    def testUnaryUnary(self):
        """Excess unary-unary calls are rejected; admitted ones succeed."""
        multi_callable = self._channel.unary_unary(_UNARY_UNARY)
        futures = [
            multi_callable.future(_REQUEST)
            for _ in range(test_constants.THREAD_CONCURRENCY)
        ]

        self._trigger.await_calls()

        with self.assertRaises(grpc.RpcError) as exception_context:
            multi_callable(_REQUEST)

        self.assertEqual(grpc.StatusCode.RESOURCE_EXHAUSTED,
                         exception_context.exception.code())

        future_exception = multi_callable.future(_REQUEST)
        self.assertEqual(grpc.StatusCode.RESOURCE_EXHAUSTED,
                         future_exception.exception().code())

        self._trigger.trigger()
        for future in futures:
            self.assertEqual(_RESPONSE, future.result())

        # Ensure a new request can be handled
        self.assertEqual(_RESPONSE, multi_callable(_REQUEST))

    def testUnaryStream(self):
        """Excess unary-stream calls are rejected; admitted ones stream fully."""
        multi_callable = self._channel.unary_stream(_UNARY_STREAM)
        calls = [
            multi_callable(_REQUEST)
            for _ in range(test_constants.THREAD_CONCURRENCY)
        ]

        self._trigger.await_calls()

        with self.assertRaises(grpc.RpcError) as exception_context:
            next(multi_callable(_REQUEST))

        self.assertEqual(grpc.StatusCode.RESOURCE_EXHAUSTED,
                         exception_context.exception.code())

        self._trigger.trigger()

        for call in calls:
            self._assert_full_response_stream(call)

        # Ensure a new request can be handled
        self._assert_full_response_stream(multi_callable(_REQUEST))

    def testStreamUnary(self):
        """Excess stream-unary calls are rejected; admitted ones succeed."""
        multi_callable = self._channel.stream_unary(_STREAM_UNARY)
        futures = [
            multi_callable.future(self._new_request_iterator())
            for _ in range(test_constants.THREAD_CONCURRENCY)
        ]

        self._trigger.await_calls()

        with self.assertRaises(grpc.RpcError) as exception_context:
            multi_callable(self._new_request_iterator())

        self.assertEqual(grpc.StatusCode.RESOURCE_EXHAUSTED,
                         exception_context.exception.code())

        future_exception = multi_callable.future(self._new_request_iterator())
        self.assertEqual(grpc.StatusCode.RESOURCE_EXHAUSTED,
                         future_exception.exception().code())

        self._trigger.trigger()

        for future in futures:
            self.assertEqual(_RESPONSE, future.result())

        # Ensure a new request can be handled
        self.assertEqual(_RESPONSE,
                         multi_callable(self._new_request_iterator()))

    def testStreamStream(self):
        """Excess stream-stream calls are rejected; admitted ones stream fully."""
        multi_callable = self._channel.stream_stream(_STREAM_STREAM)
        calls = [
            multi_callable(self._new_request_iterator())
            for _ in range(test_constants.THREAD_CONCURRENCY)
        ]

        self._trigger.await_calls()

        with self.assertRaises(grpc.RpcError) as exception_context:
            next(multi_callable(self._new_request_iterator()))

        self.assertEqual(grpc.StatusCode.RESOURCE_EXHAUSTED,
                         exception_context.exception.code())

        self._trigger.trigger()

        # handle_stream_stream answers once per request, so each call that
        # sent a full request stream must get a full response stream.
        for call in calls:
            self._assert_full_response_stream(call)

        # Ensure a new request can be handled
        self._assert_full_response_stream(
            multi_callable(self._new_request_iterator()))
255
256
def _run_standalone():
    """Install a default root log handler, then run the tests verbosely."""
    logging.basicConfig()
    unittest.main(verbosity=2)


if __name__ == '__main__':
    _run_standalone()
260