1# Copyright 2019 The gRPC Authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Tests behavior of the Call classes."""
15
16import asyncio
17import datetime
18import logging
19import unittest
20
21import grpc
22from grpc.experimental import aio
23
24from src.proto.grpc.testing import messages_pb2
25from src.proto.grpc.testing import test_pb2_grpc
26from tests_aio.unit._constants import UNREACHABLE_TARGET
27from tests_aio.unit._test_base import AioTestBase
28from tests_aio.unit._test_server import start_test_server
29
# Base unit of wait time, in seconds; timeouts and intervals derive from it.
_SHORT_TIMEOUT_S = datetime.timedelta(seconds=1).total_seconds()

# Message counts and payload sizes (in bytes) used by the streaming tests.
_NUM_STREAM_RESPONSES = 5
_REQUEST_PAYLOAD_SIZE = 7
_RESPONSE_PAYLOAD_SIZE = 42

# Details string expected on a call that was cancelled by the application.
_LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!'

# One unit of wait time expressed in microseconds.
_RESPONSE_INTERVAL_US = int(_SHORT_TIMEOUT_S * 10**6)
# Largest signed 32-bit integer; the tests use it as a response interval to
# make the server method block forever.
_INFINITE_INTERVAL_US = (1 << 31) - 1
38
39
class _MulticallableTestMixin:
    """Gives every test its own test server, insecure channel and stub."""

    async def setUp(self):
        """Starts the test server and connects a TestService stub to it."""
        server_address, server = await start_test_server()
        self._server = server
        self._channel = aio.insecure_channel(server_address)
        self._stub = test_pb2_grpc.TestServiceStub(self._channel)

    async def tearDown(self):
        """Closes the channel first, then stops the server."""
        await self._channel.close()
        await self._server.stop(None)
50
51
class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase):
    """Tests the call object returned by invoking a unary-unary RPC."""

    async def test_call_to_string(self):
        """str() and repr() are usable before and after the RPC finishes."""
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        self.assertIsNotNone(str(call))
        self.assertIsNotNone(repr(call))

        await call

        self.assertIsNotNone(str(call))
        self.assertIsNotNone(repr(call))

    async def test_call_ok(self):
        """A successful call is done, returns a response and has status OK."""
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        self.assertFalse(call.done())

        response = await call

        self.assertTrue(call.done())
        self.assertIsInstance(response, messages_pb2.SimpleResponse)
        self.assertEqual(await call.code(), grpc.StatusCode.OK)

        # The response is cached at the call object level, so awaiting the
        # call again returns the very same response object.
        response_retry = await call
        self.assertIs(response, response_retry)

    async def test_call_rpc_error(self):
        """A call to an unreachable target fails with UNAVAILABLE."""
        async with aio.insecure_channel(UNREACHABLE_TARGET) as channel:
            stub = test_pb2_grpc.TestServiceStub(channel)

            call = stub.UnaryCall(messages_pb2.SimpleRequest())

            with self.assertRaises(aio.AioRpcError) as exception_context:
                await call

            self.assertEqual(grpc.StatusCode.UNAVAILABLE,
                             exception_context.exception.code())

            self.assertTrue(call.done())
            self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code())

    async def test_call_code_awaitable(self):
        """code() can be awaited without awaiting the call itself."""
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_call_details_awaitable(self):
        """details() can be awaited without awaiting the call itself."""
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
        self.assertEqual('', await call.details())

    async def test_call_initial_metadata_awaitable(self):
        """initial_metadata() can be awaited without awaiting the call."""
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
        self.assertEqual(aio.Metadata(), await call.initial_metadata())

    async def test_call_trailing_metadata_awaitable(self):
        """trailing_metadata() can be awaited without awaiting the call."""
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
        self.assertEqual(aio.Metadata(), await call.trailing_metadata())

    async def test_call_initial_metadata_cancelable(self):
        """Cancelling one waiter of initial_metadata() leaves it usable."""
        coro_started = asyncio.Event()
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        async def coro():
            coro_started.set()
            await call.initial_metadata()

        task = self.loop.create_task(coro())
        await coro_started.wait()
        task.cancel()

        # Test that the initial metadata can still be requested even though
        # the previous task waiting for it was cancelled.
        self.assertEqual(aio.Metadata(), await call.initial_metadata())

    async def test_call_initial_metadata_multiple_waiters(self):
        """Several tasks awaiting initial_metadata() all get the result."""
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        async def coro():
            return await call.initial_metadata()

        task1 = self.loop.create_task(coro())
        task2 = self.loop.create_task(coro())

        await call
        expected = [aio.Metadata() for _ in range(2)]
        self.assertEqual(expected, await asyncio.gather(*[task1, task2]))

    async def test_call_code_cancelable(self):
        """Cancelling one waiter of code() leaves code() usable."""
        coro_started = asyncio.Event()
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        async def coro():
            coro_started.set()
            await call.code()

        task = self.loop.create_task(coro())
        await coro_started.wait()
        task.cancel()

        # Test that the code can still be requested even though the previous
        # task waiting for it was cancelled.
        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_call_code_multiple_waiters(self):
        """Several tasks awaiting code() all get the status code."""
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        async def coro():
            return await call.code()

        task1 = self.loop.create_task(coro())
        task2 = self.loop.create_task(coro())

        await call

        self.assertEqual([grpc.StatusCode.OK, grpc.StatusCode.OK], await
                         asyncio.gather(task1, task2))

    async def test_cancel_unary_unary(self):
        """cancel() succeeds once and the call reports the cancellation."""
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        self.assertFalse(call.cancelled())

        # Only the first cancel() is expected to report success.
        self.assertTrue(call.cancel())
        self.assertFalse(call.cancel())

        with self.assertRaises(asyncio.CancelledError):
            await call

        # The call object must reflect the local cancellation.
        self.assertTrue(call.cancelled())
        self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
        self.assertEqual(await call.details(),
                         _LOCAL_CANCEL_DETAILS_EXPECTATION)

    async def test_cancel_unary_unary_in_task(self):
        """Cancelling the task that awaits the call cancels the RPC too."""
        coro_started = asyncio.Event()
        call = self._stub.EmptyCall(messages_pb2.SimpleRequest())

        async def another_coro():
            coro_started.set()
            await call

        task = self.loop.create_task(another_coro())
        await coro_started.wait()

        self.assertFalse(task.done())
        task.cancel()

        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())

        with self.assertRaises(asyncio.CancelledError):
            await task

    async def test_passing_credentials_fails_over_insecure_channel(self):
        """Call credentials are rejected with UsageError on this channel."""
        call_credentials = grpc.composite_call_credentials(
            grpc.access_token_call_credentials("abc"),
            grpc.access_token_call_credentials("def"),
        )
        with self.assertRaisesRegex(
                aio.UsageError,
                "Call credentials are only valid on secure channels"):
            self._stub.UnaryCall(messages_pb2.SimpleRequest(),
                                 credentials=call_credentials)
217
218
class TestUnaryStreamCall(_MulticallableTestMixin, AioTestBase):
    """Tests the call object returned by invoking a unary-stream RPC."""

    @staticmethod
    def _streaming_output_request(num_responses, **response_parameters):
        """Builds a StreamingOutputCallRequest with identical entries.

        Args:
          num_responses: Number of ResponseParameters entries to append.
          **response_parameters: Keyword arguments forwarded unchanged to
            every messages_pb2.ResponseParameters (e.g. size, interval_us).

        Returns:
          The populated messages_pb2.StreamingOutputCallRequest.
        """
        request = messages_pb2.StreamingOutputCallRequest()
        for _ in range(num_responses):
            request.response_parameters.append(
                messages_pb2.ResponseParameters(**response_parameters))
        return request

    async def test_call_rpc_error(self):
        """Iterating a call to an unreachable target fails with UNAVAILABLE."""
        request = messages_pb2.StreamingOutputCallRequest()
        # The context manager closes the channel even when an assertion
        # below fails.
        async with aio.insecure_channel(UNREACHABLE_TARGET) as channel:
            stub = test_pb2_grpc.TestServiceStub(channel)
            call = stub.StreamingOutputCall(request)

            with self.assertRaises(aio.AioRpcError) as exception_context:
                async for _ in call:
                    pass

            self.assertEqual(grpc.StatusCode.UNAVAILABLE,
                             exception_context.exception.code())

            self.assertTrue(call.done())
            self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code())

    async def test_cancel_unary_stream(self):
        """Test cancellation after the first message was received."""
        request = self._streaming_output_request(
            _NUM_STREAM_RESPONSES,
            size=_RESPONSE_PAYLOAD_SIZE,
            interval_us=_RESPONSE_INTERVAL_US,
        )

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)
        self.assertFalse(call.cancelled())

        response = await call.read()
        self.assertIs(type(response), messages_pb2.StreamingOutputCallResponse)
        self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        self.assertTrue(call.cancel())
        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
        self.assertEqual(_LOCAL_CANCEL_DETAILS_EXPECTATION, await
                         call.details())
        self.assertFalse(call.cancel())

        with self.assertRaises(asyncio.CancelledError):
            await call.read()
        self.assertTrue(call.cancelled())

    async def test_multiple_cancel_unary_stream(self):
        """Test that only the first of several cancel() calls succeeds."""
        request = self._streaming_output_request(
            _NUM_STREAM_RESPONSES,
            size=_RESPONSE_PAYLOAD_SIZE,
            interval_us=_RESPONSE_INTERVAL_US,
        )

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)
        self.assertFalse(call.cancelled())

        response = await call.read()
        self.assertIs(type(response), messages_pb2.StreamingOutputCallResponse)
        self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        self.assertTrue(call.cancel())
        self.assertFalse(call.cancel())
        self.assertFalse(call.cancel())
        self.assertFalse(call.cancel())

        with self.assertRaises(asyncio.CancelledError):
            await call.read()

    async def test_early_cancel_unary_stream(self):
        """Test cancellation before receiving messages."""
        request = self._streaming_output_request(
            _NUM_STREAM_RESPONSES,
            size=_RESPONSE_PAYLOAD_SIZE,
            interval_us=_RESPONSE_INTERVAL_US,
        )

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)

        self.assertFalse(call.cancelled())
        self.assertTrue(call.cancel())
        self.assertFalse(call.cancel())

        with self.assertRaises(asyncio.CancelledError):
            await call.read()

        self.assertTrue(call.cancelled())

        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
        self.assertEqual(_LOCAL_CANCEL_DETAILS_EXPECTATION, await
                         call.details())

    async def test_late_cancel_unary_stream(self):
        """Test cancellation after all messages were received."""
        request = self._streaming_output_request(_NUM_STREAM_RESPONSES,
                                                 size=_RESPONSE_PAYLOAD_SIZE)

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)

        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertIs(type(response),
                          messages_pb2.StreamingOutputCallResponse)
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        # After all messages are received, the final status may already have
        # arrived or may still be on its way. That is a data race, so the
        # only expectation here is that nothing crashes.
        call.cancel()
        self.assertIn(await call.code(),
                      [grpc.StatusCode.OK, grpc.StatusCode.CANCELLED])

    async def test_too_many_reads_unary_stream(self):
        """Test that reading after all messages were received yields EOF."""
        request = self._streaming_output_request(_NUM_STREAM_RESPONSES,
                                                 size=_RESPONSE_PAYLOAD_SIZE)

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)

        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertIs(type(response),
                          messages_pb2.StreamingOutputCallResponse)
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
        self.assertIs(await call.read(), aio.EOF)

        # After the RPC is finished, further reads keep returning EOF.
        self.assertEqual(await call.code(), grpc.StatusCode.OK)
        self.assertIs(await call.read(), aio.EOF)

    async def test_unary_stream_async_generator(self):
        """Sunny day test case for unary_stream."""
        request = self._streaming_output_request(_NUM_STREAM_RESPONSES,
                                                 size=_RESPONSE_PAYLOAD_SIZE)

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)
        self.assertFalse(call.cancelled())

        async for response in call:
            self.assertIs(type(response),
                          messages_pb2.StreamingOutputCallResponse)
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_cancel_unary_stream_in_task_using_read(self):
        """Cancelling a task blocked in read() cancels the RPC."""
        coro_started = asyncio.Event()

        # Configs the server method to block forever
        request = self._streaming_output_request(
            1,
            size=_RESPONSE_PAYLOAD_SIZE,
            interval_us=_INFINITE_INTERVAL_US,
        )

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)

        async def another_coro():
            coro_started.set()
            await call.read()

        task = self.loop.create_task(another_coro())
        await coro_started.wait()

        self.assertFalse(task.done())
        task.cancel()

        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())

        with self.assertRaises(asyncio.CancelledError):
            await task

    async def test_cancel_unary_stream_in_task_using_async_for(self):
        """Cancelling a task blocked in `async for` cancels the RPC."""
        coro_started = asyncio.Event()

        # Configs the server method to block forever
        request = self._streaming_output_request(
            1,
            size=_RESPONSE_PAYLOAD_SIZE,
            interval_us=_INFINITE_INTERVAL_US,
        )

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)

        async def another_coro():
            coro_started.set()
            async for _ in call:
                pass

        task = self.loop.create_task(another_coro())
        await coro_started.wait()

        self.assertFalse(task.done())
        task.cancel()

        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())

        with self.assertRaises(asyncio.CancelledError):
            await task

    async def test_time_remaining(self):
        """time_remaining() shrinks as the call's timeout is consumed."""
        request = messages_pb2.StreamingOutputCallRequest()
        # First message comes back immediately
        request.response_parameters.append(
            messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE,))
        # Second message comes back after a unit of wait time
        request.response_parameters.append(
            messages_pb2.ResponseParameters(
                size=_RESPONSE_PAYLOAD_SIZE,
                interval_us=_RESPONSE_INTERVAL_US,
            ))

        call = self._stub.StreamingOutputCall(request,
                                              timeout=_SHORT_TIMEOUT_S * 2)

        response = await call.read()
        self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        # Should be around the same as the timeout
        remained_time = call.time_remaining()
        self.assertGreater(remained_time, _SHORT_TIMEOUT_S * 3 / 2)
        self.assertLess(remained_time, _SHORT_TIMEOUT_S * 5 / 2)

        response = await call.read()
        self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        # Should be around the timeout minus a unit of wait time
        remained_time = call.time_remaining()
        self.assertGreater(remained_time, _SHORT_TIMEOUT_S / 2)
        self.assertLess(remained_time, _SHORT_TIMEOUT_S * 3 / 2)

        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_empty_responses(self):
        """Responses that serialize to zero bytes are still delivered."""
        request = self._streaming_output_request(_NUM_STREAM_RESPONSES)

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)

        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertIs(type(response),
                          messages_pb2.StreamingOutputCallResponse)
            self.assertEqual(b'', response.SerializeToString())

        self.assertEqual(grpc.StatusCode.OK, await call.code())
493
494
class TestStreamUnaryCall(_MulticallableTestMixin, AioTestBase):
    """Exercises the call object produced by invoking a stream-unary RPC."""

    @staticmethod
    def _new_input_request():
        """Returns a StreamingInputCallRequest carrying a zero-filled payload."""
        zero_payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
        return messages_pb2.StreamingInputCallRequest(payload=zero_payload)

    async def test_cancel_stream_unary(self):
        """Cancelling after several writes makes awaiting the call raise."""
        rpc = self._stub.StreamingInputCall()
        message = self._new_input_request()

        # Push a few requests before cancelling.
        for _ in range(_NUM_STREAM_RESPONSES):
            await rpc.write(message)

        # The call is still in flight, so the cancellation must succeed.
        self.assertFalse(rpc.done())
        self.assertFalse(rpc.cancelled())
        self.assertTrue(rpc.cancel())
        self.assertTrue(rpc.cancelled())

        await rpc.done_writing()

        with self.assertRaises(asyncio.CancelledError):
            await rpc

    async def test_early_cancel_stream_unary(self):
        """Cancelling before any write rejects later writes."""
        rpc = self._stub.StreamingInputCall()

        # Cancel right away, before any traffic.
        self.assertFalse(rpc.done())
        self.assertFalse(rpc.cancelled())
        self.assertTrue(rpc.cancel())
        self.assertTrue(rpc.cancelled())

        with self.assertRaises(asyncio.InvalidStateError):
            await rpc.write(messages_pb2.StreamingInputCallRequest())

        # Expected to be a no-op on the cancelled call.
        await rpc.done_writing()

        with self.assertRaises(asyncio.CancelledError):
            await rpc

    async def test_write_after_done_writing(self):
        """Writes after done_writing() are rejected; the RPC still succeeds."""
        rpc = self._stub.StreamingInputCall()
        message = self._new_input_request()

        for _ in range(_NUM_STREAM_RESPONSES):
            await rpc.write(message)

        # Signal that no more requests will be written.
        await rpc.done_writing()

        with self.assertRaises(asyncio.InvalidStateError):
            await rpc.write(messages_pb2.StreamingInputCallRequest())

        reply = await rpc
        self.assertIsInstance(reply, messages_pb2.StreamingInputCallResponse)
        self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
                         reply.aggregated_payload_size)

        self.assertEqual(await rpc.code(), grpc.StatusCode.OK)

    async def test_error_in_async_generator(self):
        """Cancelling the RPC raises CancelledError inside the request iterator."""
        # NOTE(review): this builds a StreamingOutputCallRequest although the
        # RPC invoked below is StreamingInputCall — confirm this is intended.
        outgoing = messages_pb2.StreamingOutputCallRequest()
        outgoing.response_parameters.append(
            messages_pb2.ResponseParameters(
                size=_RESPONSE_PAYLOAD_SIZE,
                interval_us=_RESPONSE_INTERVAL_US,
            ))

        # Set once the iterator has observed the cancellation.
        iterator_cancelled = asyncio.Event()

        async def throttled_requests():
            with self.assertRaises(asyncio.CancelledError):
                for _ in range(_NUM_STREAM_RESPONSES):
                    yield outgoing
                    await asyncio.sleep(_SHORT_TIMEOUT_S)
            iterator_cancelled.set()

        rpc = self._stub.StreamingInputCall(throttled_requests())

        # Cancel while the iterator is still producing requests.
        async def delayed_cancel():
            await asyncio.sleep(_SHORT_TIMEOUT_S * 2)
            rpc.cancel()

        delayed_cancel_task = self.loop.create_task(delayed_cancel())

        with self.assertRaises(asyncio.CancelledError):
            await rpc

        await iterator_cancelled.wait()

        # Awaiting the task surfaces any failure raised inside it.
        await delayed_cancel_task

    async def test_normal_iterable_requests(self):
        """A plain list of requests is accepted as the request iterator."""
        batch = [self._new_input_request()] * _NUM_STREAM_RESPONSES

        rpc = self._stub.StreamingInputCall(batch)

        # The RPC is expected to succeed and report the total payload size.
        reply = await rpc
        self.assertIsInstance(reply, messages_pb2.StreamingInputCallResponse)
        self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
                         reply.aggregated_payload_size)

        self.assertEqual(await rpc.code(), grpc.StatusCode.OK)

    async def test_call_rpc_error(self):
        """A call to an unreachable target fails with UNAVAILABLE."""
        async with aio.insecure_channel(UNREACHABLE_TARGET) as unreachable:
            unreachable_stub = test_pb2_grpc.TestServiceStub(unreachable)

            # No request is written; the failure must surface on its own.
            rpc = unreachable_stub.StreamingInputCall()
            with self.assertRaises(aio.AioRpcError) as raised:
                await rpc

            self.assertEqual(grpc.StatusCode.UNAVAILABLE,
                             raised.exception.code())

            self.assertTrue(rpc.done())
            self.assertEqual(grpc.StatusCode.UNAVAILABLE, await rpc.code())

    async def test_timeout(self):
        """A call left idle past its timeout fails with DEADLINE_EXCEEDED."""
        rpc = self._stub.StreamingInputCall(timeout=_SHORT_TIMEOUT_S)

        # No request is written; the deadline must fire on its own.
        with self.assertRaises(aio.AioRpcError) as raised:
            await rpc

        failure = raised.exception
        self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, failure.code())
        self.assertTrue(rpc.done())
        self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, await rpc.code())
640
641
# Prepared requests for tests that stream in a ping-pong manner; each one
# carries a single ResponseParameters entry.
_STREAM_OUTPUT_REQUEST_ONE_RESPONSE = messages_pb2.StreamingOutputCallRequest(
    response_parameters=[
        messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)
    ])
_STREAM_OUTPUT_REQUEST_ONE_EMPTY_RESPONSE = messages_pb2.StreamingOutputCallRequest(
    response_parameters=[messages_pb2.ResponseParameters()])
650
651
class TestStreamStreamCall(_MulticallableTestMixin, AioTestBase):
    """Exercises the call object produced by invoking a stream-stream RPC."""

    async def test_cancel(self):
        """Cancelling after a few ping-pong rounds yields CANCELLED."""
        duplex = self._stub.FullDuplexCall()

        for _ in range(_NUM_STREAM_RESPONSES):
            await duplex.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)
            reply = await duplex.read()
            self.assertIsInstance(reply,
                                  messages_pb2.StreamingOutputCallResponse)
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(reply.payload.body))

        # The call is still open, so cancel() must succeed.
        self.assertFalse(duplex.done())
        self.assertFalse(duplex.cancelled())
        self.assertTrue(duplex.cancel())
        self.assertTrue(duplex.cancelled())
        self.assertEqual(grpc.StatusCode.CANCELLED, await duplex.code())

    async def test_cancel_with_pending_read(self):
        """Cancelling with a written request whose response is unread."""
        duplex = self._stub.FullDuplexCall()

        await duplex.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)

        # Cancel without ever reading the response.
        self.assertFalse(duplex.done())
        self.assertFalse(duplex.cancelled())
        self.assertTrue(duplex.cancel())
        self.assertTrue(duplex.cancelled())
        self.assertEqual(grpc.StatusCode.CANCELLED, await duplex.code())

    async def test_cancel_with_ongoing_read(self):
        """Cancelling while another task is inside read()."""
        duplex = self._stub.FullDuplexCall()
        reader_started = asyncio.Event()

        async def blocked_reader():
            reader_started.set()
            await duplex.read()

        reader_task = self.loop.create_task(blocked_reader())
        await reader_started.wait()
        self.assertFalse(reader_task.done())

        # Cancel while the reader task is still pending.
        self.assertFalse(duplex.done())
        self.assertFalse(duplex.cancelled())
        self.assertTrue(duplex.cancel())
        self.assertTrue(duplex.cancelled())
        self.assertEqual(grpc.StatusCode.CANCELLED, await duplex.code())

    async def test_early_cancel(self):
        """Cancelling before any read or write yields CANCELLED."""
        duplex = self._stub.FullDuplexCall()

        # Cancel right after the call is created.
        self.assertFalse(duplex.done())
        self.assertFalse(duplex.cancelled())
        self.assertTrue(duplex.cancel())
        self.assertTrue(duplex.cancelled())
        self.assertEqual(grpc.StatusCode.CANCELLED, await duplex.code())

    async def test_cancel_after_done_writing(self):
        """Cancelling right after done_writing() still succeeds."""
        duplex = self._stub.FullDuplexCall()
        await duplex.done_writing()

        # The final status has not been awaited yet; cancel() must succeed.
        self.assertFalse(duplex.done())
        self.assertFalse(duplex.cancelled())
        self.assertTrue(duplex.cancel())
        self.assertTrue(duplex.cancelled())
        self.assertEqual(grpc.StatusCode.CANCELLED, await duplex.code())

    async def test_late_cancel(self):
        """Cancelling a call that already finished with OK has no effect."""
        duplex = self._stub.FullDuplexCall()
        await duplex.done_writing()
        self.assertEqual(grpc.StatusCode.OK, await duplex.code())

        # The call is already done, so cancel() must report failure.
        self.assertTrue(duplex.done())
        self.assertFalse(duplex.cancelled())
        self.assertFalse(duplex.cancel())
        self.assertFalse(duplex.cancelled())

        # The status remains OK afterwards.
        self.assertEqual(grpc.StatusCode.OK, await duplex.code())

    async def test_async_generator(self):
        """Requests supplied by an async generator are all answered."""

        async def two_requests():
            for _ in range(2):
                yield _STREAM_OUTPUT_REQUEST_ONE_RESPONSE

        duplex = self._stub.FullDuplexCall(two_requests())
        async for reply in duplex:
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(reply.payload.body))

        self.assertEqual(await duplex.code(), grpc.StatusCode.OK)

    async def test_too_many_reads(self):
        """Reading past the last response returns EOF, repeatedly."""

        async def all_requests():
            for _ in range(_NUM_STREAM_RESPONSES):
                yield _STREAM_OUTPUT_REQUEST_ONE_RESPONSE

        duplex = self._stub.FullDuplexCall(all_requests())
        for _ in range(_NUM_STREAM_RESPONSES):
            reply = await duplex.read()
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(reply.payload.body))
        self.assertIs(await duplex.read(), aio.EOF)

        self.assertEqual(await duplex.code(), grpc.StatusCode.OK)
        # Once the RPC has finished, read() keeps producing EOF.
        self.assertIs(await duplex.read(), aio.EOF)

    async def test_read_write_after_done_writing(self):
        """After done_writing(), writes fail while reads keep working."""
        duplex = self._stub.FullDuplexCall()

        # Two requests are written; their two responses are left unread.
        await duplex.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)
        await duplex.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)
        await duplex.done_writing()

        # Any further write must be rejected.
        with self.assertRaises(asyncio.InvalidStateError):
            await duplex.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)

        # Reading is not affected by the half-close.
        first_reply = await duplex.read()
        self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(first_reply.payload.body))
        second_reply = await duplex.read()
        self.assertEqual(_RESPONSE_PAYLOAD_SIZE,
                         len(second_reply.payload.body))

        self.assertEqual(await duplex.code(), grpc.StatusCode.OK)

    async def test_error_in_async_generator(self):
        """Cancelling the RPC raises CancelledError inside the request iterator."""
        # Each request asks for one response after a unit of wait time.
        outgoing = messages_pb2.StreamingOutputCallRequest()
        outgoing.response_parameters.append(
            messages_pb2.ResponseParameters(
                size=_RESPONSE_PAYLOAD_SIZE,
                interval_us=_RESPONSE_INTERVAL_US,
            ))

        # Set once the iterator has observed the cancellation.
        iterator_cancelled = asyncio.Event()

        async def throttled_requests():
            with self.assertRaises(asyncio.CancelledError):
                for _ in range(_NUM_STREAM_RESPONSES):
                    yield outgoing
                    await asyncio.sleep(_SHORT_TIMEOUT_S)
            iterator_cancelled.set()

        duplex = self._stub.FullDuplexCall(throttled_requests())

        # Cancel two units of wait time after the call started.
        async def delayed_cancel():
            await asyncio.sleep(_SHORT_TIMEOUT_S * 2)
            duplex.cancel()

        delayed_cancel_task = self.loop.create_task(delayed_cancel())

        with self.assertRaises(asyncio.CancelledError):
            async for reply in duplex:
                self.assertEqual(_RESPONSE_PAYLOAD_SIZE,
                                 len(reply.payload.body))

        await iterator_cancelled.wait()

        self.assertEqual(grpc.StatusCode.CANCELLED, await duplex.code())
        # Awaiting the task surfaces any failure raised inside it.
        await delayed_cancel_task

    async def test_normal_iterable_requests(self):
        """A plain (synchronous) iterator of requests is accepted."""
        batch = [_STREAM_OUTPUT_REQUEST_ONE_RESPONSE] * _NUM_STREAM_RESPONSES

        duplex = self._stub.FullDuplexCall(iter(batch))
        async for reply in duplex:
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(reply.payload.body))

        self.assertEqual(await duplex.code(), grpc.StatusCode.OK)

    async def test_empty_ping_pong(self):
        """Responses that serialize to zero bytes survive a ping-pong."""
        duplex = self._stub.FullDuplexCall()
        for _ in range(_NUM_STREAM_RESPONSES):
            await duplex.write(_STREAM_OUTPUT_REQUEST_ONE_EMPTY_RESPONSE)
            reply = await duplex.read()
            self.assertEqual(b'', reply.SerializeToString())
        await duplex.done_writing()
        self.assertEqual(await duplex.code(), grpc.StatusCode.OK)
842
843
# When executed as a script: enable DEBUG-level logging and run the tests
# with verbose (level 2) unittest output.
if __name__ == '__main__':
    logging.basicConfig(level=logging.DEBUG)
    unittest.main(verbosity=2)
847