1# Copyright 2019 The gRPC Authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Tests behavior of the Call classes."""
15
16import asyncio
17import logging
18import unittest
19import datetime
20
21import grpc
22from grpc.experimental import aio
23
24from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
25from tests_aio.unit._test_base import AioTestBase
26from tests_aio.unit._test_server import start_test_server
27from tests_aio.unit._constants import UNREACHABLE_TARGET
28
# Base unit of time, in seconds, from which test deadlines and sleeps derive.
_SHORT_TIMEOUT_S = datetime.timedelta(seconds=1).total_seconds()

# Number of messages exchanged by the streaming tests.
_NUM_STREAM_RESPONSES = 5
# Payload sizes requested from / sent to the test server.
_RESPONSE_PAYLOAD_SIZE = 42
_REQUEST_PAYLOAD_SIZE = 7
# Status details expected once the application has cancelled a call locally.
_LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!'
# One base unit of time expressed in microseconds.
_RESPONSE_INTERVAL_US = int(_SHORT_TIMEOUT_S * 10**6)
# Largest signed 32-bit value; used as a response interval that makes the
# server method block for the whole duration of a test.
_INFINITE_INTERVAL_US = (1 << 31) - 1
37
38
class _MulticallableTestMixin:
    """Fixture mixin providing a test server, a channel and a stub.

    The test classes in this module combine it with ``AioTestBase``. After
    ``setUp`` the following attributes are available:

      _server: the server returned by ``start_test_server()``.
      _channel: an insecure AsyncIO channel connected to that server.
      _stub: a ``TestServiceStub`` bound to ``_channel``.
    """

    async def setUp(self):
        """Starts the test server and connects a channel and a stub to it."""
        address, self._server = await start_test_server()
        self._channel = aio.insecure_channel(address)
        self._stub = test_pb2_grpc.TestServiceStub(self._channel)

    async def tearDown(self):
        """Closes the channel, then stops the server.

        The server is stopped even if closing the channel raises, so that a
        failing close cannot leave the server running; the original exception
        still propagates afterwards.
        """
        try:
            await self._channel.close()
        finally:
            # None is passed as the grace argument, as in the original code.
            await self._server.stop(None)
49
50
class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase):
    """Tests the call object returned by unary-unary multicallables."""

    async def test_call_to_string(self):
        """str() and repr() work both before and after the call finishes."""
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        self.assertIsNotNone(str(call))
        self.assertIsNotNone(repr(call))

        await call

        self.assertIsNotNone(str(call))
        self.assertIsNotNone(repr(call))

    async def test_call_ok(self):
        """A successful call reports OK and caches its response."""
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        self.assertFalse(call.done())

        response = await call

        self.assertTrue(call.done())
        self.assertIsInstance(response, messages_pb2.SimpleResponse)
        self.assertEqual(await call.code(), grpc.StatusCode.OK)

        # The response is cached at call object level, so awaiting the call
        # again returns the very same response object.
        response_retry = await call
        self.assertIs(response, response_retry)

    async def test_call_rpc_error(self):
        """A call to an unreachable target fails with UNAVAILABLE."""
        async with aio.insecure_channel(UNREACHABLE_TARGET) as channel:
            stub = test_pb2_grpc.TestServiceStub(channel)

            call = stub.UnaryCall(messages_pb2.SimpleRequest())

            with self.assertRaises(aio.AioRpcError) as exception_context:
                await call

            self.assertEqual(grpc.StatusCode.UNAVAILABLE,
                             exception_context.exception.code())

            self.assertTrue(call.done())
            self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code())

    async def test_call_code_awaitable(self):
        """code() can be awaited without awaiting the call first."""
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_call_details_awaitable(self):
        """details() can be awaited without awaiting the call first."""
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
        self.assertEqual('', await call.details())

    async def test_call_initial_metadata_awaitable(self):
        """initial_metadata() can be awaited without awaiting the call."""
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
        self.assertEqual((), await call.initial_metadata())

    async def test_call_trailing_metadata_awaitable(self):
        """trailing_metadata() can be awaited without awaiting the call."""
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
        self.assertEqual((), await call.trailing_metadata())

    async def test_call_initial_metadata_cancelable(self):
        """Cancelling a task waiting on initial metadata keeps the call usable."""
        coro_started = asyncio.Event()
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        async def coro():
            coro_started.set()
            await call.initial_metadata()

        task = self.loop.create_task(coro())
        await coro_started.wait()
        task.cancel()

        # Test that the initial metadata can still be asked for even though
        # the previous task waiting on it was cancelled.
        self.assertEqual((), await call.initial_metadata())

    async def test_call_initial_metadata_multiple_waiters(self):
        """Several tasks can wait for the initial metadata concurrently."""
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        async def coro():
            return await call.initial_metadata()

        task1 = self.loop.create_task(coro())
        task2 = self.loop.create_task(coro())

        await call

        self.assertEqual([(), ()], await asyncio.gather(*[task1, task2]))

    async def test_call_code_cancelable(self):
        """Cancelling a task waiting on the code keeps the call usable."""
        coro_started = asyncio.Event()
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        async def coro():
            coro_started.set()
            await call.code()

        task = self.loop.create_task(coro())
        await coro_started.wait()
        task.cancel()

        # Test that the code can still be asked for even though the previous
        # task waiting on it was cancelled.
        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_call_code_multiple_waiters(self):
        """Several tasks can wait for the status code concurrently."""
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        async def coro():
            return await call.code()

        task1 = self.loop.create_task(coro())
        task2 = self.loop.create_task(coro())

        await call

        self.assertEqual([grpc.StatusCode.OK, grpc.StatusCode.OK], await
                         asyncio.gather(task1, task2))

    async def test_cancel_unary_unary(self):
        """Only the first cancel() succeeds; the call ends as CANCELLED."""
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        self.assertFalse(call.cancelled())

        self.assertTrue(call.cancel())
        self.assertFalse(call.cancel())

        with self.assertRaises(asyncio.CancelledError):
            await call

        # The call object must report the local cancellation.
        self.assertTrue(call.cancelled())
        self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
        self.assertEqual(await call.details(),
                         _LOCAL_CANCEL_DETAILS_EXPECTATION)

    async def test_cancel_unary_unary_in_task(self):
        """Cancelling the task that awaits a call cancels the call as well."""
        coro_started = asyncio.Event()
        # NOTE(review): EmptyCall is invoked with a SimpleRequest here, as in
        # the original test -- confirm that this request type is intended.
        call = self._stub.EmptyCall(messages_pb2.SimpleRequest())

        async def another_coro():
            coro_started.set()
            await call

        task = self.loop.create_task(another_coro())
        await coro_started.wait()

        self.assertFalse(task.done())
        task.cancel()

        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())

        with self.assertRaises(asyncio.CancelledError):
            await task

    async def test_passing_credentials_fails_over_insecure_channel(self):
        """Call credentials are rejected on an insecure channel."""
        call_credentials = grpc.composite_call_credentials(
            grpc.access_token_call_credentials("abc"),
            grpc.access_token_call_credentials("def"),
        )
        with self.assertRaisesRegex(
                aio.UsageError,
                "Call credentials are only valid on secure channels"):
            self._stub.UnaryCall(messages_pb2.SimpleRequest(),
                                 credentials=call_credentials)
216
217
class TestUnaryStreamCall(_MulticallableTestMixin, AioTestBase):
    """Tests the call object returned by unary-stream multicallables."""

    @staticmethod
    def _build_request(num_responses, interval_us=None):
        """Builds a request asking the server for `num_responses` messages.

        Args:
          num_responses: Number of response parameters added to the request.
          interval_us: Optional interval, in microseconds, set on every
            response parameter. The field is left unset when None.

        Returns:
          The populated messages_pb2.StreamingOutputCallRequest.
        """
        parameter_kwargs = {'size': _RESPONSE_PAYLOAD_SIZE}
        if interval_us is not None:
            parameter_kwargs['interval_us'] = interval_us

        request = messages_pb2.StreamingOutputCallRequest()
        for _ in range(num_responses):
            request.response_parameters.append(
                messages_pb2.ResponseParameters(**parameter_kwargs))
        return request

    async def test_call_rpc_error(self):
        """A call to an unreachable target fails with UNAVAILABLE."""
        # The context manager closes the channel even if an assertion fails.
        async with aio.insecure_channel(UNREACHABLE_TARGET) as channel:
            request = messages_pb2.StreamingOutputCallRequest()
            stub = test_pb2_grpc.TestServiceStub(channel)
            call = stub.StreamingOutputCall(request)

            with self.assertRaises(aio.AioRpcError) as exception_context:
                async for response in call:
                    pass

            self.assertEqual(grpc.StatusCode.UNAVAILABLE,
                             exception_context.exception.code())

            self.assertTrue(call.done())
            self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code())

    async def test_cancel_unary_stream(self):
        """Test cancellation after the first message was received."""
        # Prepares the request
        request = self._build_request(_NUM_STREAM_RESPONSES,
                                      interval_us=_RESPONSE_INTERVAL_US)

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)
        self.assertFalse(call.cancelled())

        response = await call.read()
        self.assertIs(type(response), messages_pb2.StreamingOutputCallResponse)
        self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        self.assertTrue(call.cancel())
        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
        self.assertEqual(_LOCAL_CANCEL_DETAILS_EXPECTATION, await
                         call.details())
        self.assertFalse(call.cancel())

        with self.assertRaises(asyncio.CancelledError):
            await call.read()
        self.assertTrue(call.cancelled())

    async def test_multiple_cancel_unary_stream(self):
        """Test that only the first of several cancel() calls succeeds."""
        # Prepares the request
        request = self._build_request(_NUM_STREAM_RESPONSES,
                                      interval_us=_RESPONSE_INTERVAL_US)

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)
        self.assertFalse(call.cancelled())

        response = await call.read()
        self.assertIs(type(response), messages_pb2.StreamingOutputCallResponse)
        self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        self.assertTrue(call.cancel())
        self.assertFalse(call.cancel())
        self.assertFalse(call.cancel())
        self.assertFalse(call.cancel())

        with self.assertRaises(asyncio.CancelledError):
            await call.read()

    async def test_early_cancel_unary_stream(self):
        """Test cancellation before receiving messages."""
        # Prepares the request
        request = self._build_request(_NUM_STREAM_RESPONSES,
                                      interval_us=_RESPONSE_INTERVAL_US)

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)

        self.assertFalse(call.cancelled())
        self.assertTrue(call.cancel())
        self.assertFalse(call.cancel())

        with self.assertRaises(asyncio.CancelledError):
            await call.read()

        self.assertTrue(call.cancelled())

        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
        self.assertEqual(_LOCAL_CANCEL_DETAILS_EXPECTATION, await
                         call.details())

    async def test_late_cancel_unary_stream(self):
        """Test cancellation after received all messages."""
        # Prepares the request
        request = self._build_request(_NUM_STREAM_RESPONSES)

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)

        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertIs(type(response),
                          messages_pb2.StreamingOutputCallResponse)
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        # After all messages received, it is possible that the final state
        # is received or on its way. It's basically a data race, so our
        # expectation here is do not crash :)
        call.cancel()
        self.assertIn(await call.code(),
                      [grpc.StatusCode.OK, grpc.StatusCode.CANCELLED])

    async def test_too_many_reads_unary_stream(self):
        """Test that reading after all messages were received yields EOF."""
        # Prepares the request
        request = self._build_request(_NUM_STREAM_RESPONSES)

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)

        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertIs(type(response),
                          messages_pb2.StreamingOutputCallResponse)
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
        self.assertIs(await call.read(), aio.EOF)

        # After the RPC is finished, further reads keep producing EOF.
        self.assertEqual(await call.code(), grpc.StatusCode.OK)
        self.assertIs(await call.read(), aio.EOF)

    async def test_unary_stream_async_generator(self):
        """Sunny day test case for unary_stream."""
        # Prepares the request
        request = self._build_request(_NUM_STREAM_RESPONSES)

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)
        self.assertFalse(call.cancelled())

        async for response in call:
            self.assertIs(type(response),
                          messages_pb2.StreamingOutputCallResponse)
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_cancel_unary_stream_in_task_using_read(self):
        """Cancelling a task blocked in read() cancels the call."""
        coro_started = asyncio.Event()

        # Configs the server method to block forever
        request = self._build_request(1, interval_us=_INFINITE_INTERVAL_US)

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)

        async def another_coro():
            coro_started.set()
            await call.read()

        task = self.loop.create_task(another_coro())
        await coro_started.wait()

        self.assertFalse(task.done())
        task.cancel()

        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())

        with self.assertRaises(asyncio.CancelledError):
            await task

    async def test_cancel_unary_stream_in_task_using_async_for(self):
        """Cancelling a task iterating over the call cancels the call."""
        coro_started = asyncio.Event()

        # Configs the server method to block forever
        request = self._build_request(1, interval_us=_INFINITE_INTERVAL_US)

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)

        async def another_coro():
            coro_started.set()
            async for _ in call:
                pass

        task = self.loop.create_task(another_coro())
        await coro_started.wait()

        self.assertFalse(task.done())
        task.cancel()

        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())

        with self.assertRaises(asyncio.CancelledError):
            await task

    async def test_time_remaining(self):
        """time_remaining() shrinks as the call's deadline approaches."""
        request = messages_pb2.StreamingOutputCallRequest()
        # First message comes back immediately
        request.response_parameters.append(
            messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE,))
        # Second message comes back after a unit of wait time
        request.response_parameters.append(
            messages_pb2.ResponseParameters(
                size=_RESPONSE_PAYLOAD_SIZE,
                interval_us=_RESPONSE_INTERVAL_US,
            ))

        call = self._stub.StreamingOutputCall(request,
                                              timeout=_SHORT_TIMEOUT_S * 2)

        response = await call.read()
        self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        # Should be around the same as the timeout
        remained_time = call.time_remaining()
        self.assertGreater(remained_time, _SHORT_TIMEOUT_S * 3 / 2)
        self.assertLess(remained_time, _SHORT_TIMEOUT_S * 5 / 2)

        response = await call.read()
        self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        # Should be around the timeout minus a unit of wait time
        remained_time = call.time_remaining()
        self.assertGreater(remained_time, _SHORT_TIMEOUT_S / 2)
        self.assertLess(remained_time, _SHORT_TIMEOUT_S * 3 / 2)

        self.assertEqual(grpc.StatusCode.OK, await call.code())
474
475
class TestStreamUnaryCall(_MulticallableTestMixin, AioTestBase):
    """Tests the call object returned by stream-unary multicallables."""

    async def test_cancel_stream_unary(self):
        """Cancelling after some writes makes awaiting the call raise."""
        call = self._stub.StreamingInputCall()

        # Prepares the request
        payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
        request = messages_pb2.StreamingInputCallRequest(payload=payload)

        # Sends out requests
        for _ in range(_NUM_STREAM_RESPONSES):
            await call.write(request)

        # Cancels the RPC
        self.assertFalse(call.done())
        self.assertFalse(call.cancelled())
        self.assertTrue(call.cancel())
        self.assertTrue(call.cancelled())

        await call.done_writing()

        with self.assertRaises(asyncio.CancelledError):
            await call

    async def test_early_cancel_stream_unary(self):
        """Cancelling before any write makes further writes invalid."""
        call = self._stub.StreamingInputCall()

        # Cancels the RPC
        self.assertFalse(call.done())
        self.assertFalse(call.cancelled())
        self.assertTrue(call.cancel())
        self.assertTrue(call.cancelled())

        with self.assertRaises(asyncio.InvalidStateError):
            await call.write(messages_pb2.StreamingInputCallRequest())

        # Should be no-op
        await call.done_writing()

        with self.assertRaises(asyncio.CancelledError):
            await call

    async def test_write_after_done_writing(self):
        """Writing after done_writing() fails, yet the RPC still succeeds."""
        call = self._stub.StreamingInputCall()

        # Prepares the request
        payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
        request = messages_pb2.StreamingInputCallRequest(payload=payload)

        # Sends out requests
        for _ in range(_NUM_STREAM_RESPONSES):
            await call.write(request)

        # Signals that no more requests will be written
        await call.done_writing()

        with self.assertRaises(asyncio.InvalidStateError):
            await call.write(messages_pb2.StreamingInputCallRequest())

        response = await call
        self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse)
        self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
                         response.aggregated_payload_size)

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_error_in_async_generator(self):
        """Cancelling the call propagates CancelledError into the iterator."""
        # Uses the request type StreamingInputCall expects, built the same way
        # as in the other tests of this class.
        payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
        request = messages_pb2.StreamingInputCallRequest(payload=payload)

        # We expect the request iterator to receive the exception
        request_iterator_received_the_exception = asyncio.Event()

        async def request_iterator():
            with self.assertRaises(asyncio.CancelledError):
                for _ in range(_NUM_STREAM_RESPONSES):
                    yield request
                    await asyncio.sleep(_SHORT_TIMEOUT_S)
            request_iterator_received_the_exception.set()

        call = self._stub.StreamingInputCall(request_iterator())

        # Cancel the RPC while the request iterator is still producing requests
        async def cancel_later():
            await asyncio.sleep(_SHORT_TIMEOUT_S * 2)
            call.cancel()

        cancel_later_task = self.loop.create_task(cancel_later())

        with self.assertRaises(asyncio.CancelledError):
            await call

        await request_iterator_received_the_exception.wait()

        # No failures in the cancel later task!
        await cancel_later_task

    async def test_normal_iterable_requests(self):
        """A plain (non-async) iterable of requests is accepted."""
        # Prepares the request
        payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
        request = messages_pb2.StreamingInputCallRequest(payload=payload)
        requests = [request] * _NUM_STREAM_RESPONSES

        # Sends out requests
        call = self._stub.StreamingInputCall(requests)

        # RPC should succeed
        response = await call
        self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse)
        self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
                         response.aggregated_payload_size)

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_call_rpc_error(self):
        """A call to an unreachable target fails with UNAVAILABLE."""
        async with aio.insecure_channel(UNREACHABLE_TARGET) as channel:
            stub = test_pb2_grpc.TestServiceStub(channel)

            # The error should be raised automatically without any traffic.
            call = stub.StreamingInputCall()
            with self.assertRaises(aio.AioRpcError) as exception_context:
                await call

            self.assertEqual(grpc.StatusCode.UNAVAILABLE,
                             exception_context.exception.code())

            self.assertTrue(call.done())
            self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code())

    async def test_timeout(self):
        """A call that never finishes writing ends with DEADLINE_EXCEEDED."""
        call = self._stub.StreamingInputCall(timeout=_SHORT_TIMEOUT_S)

        # The error should be raised automatically without any traffic.
        with self.assertRaises(aio.AioRpcError) as exception_context:
            await call

        rpc_error = exception_context.exception
        self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, rpc_error.code())
        self.assertTrue(call.done())
        self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, await call.code())
621
622
# Request used to stream in a ping-pong manner: it asks the server for exactly
# one response of _RESPONSE_PAYLOAD_SIZE bytes.
_STREAM_OUTPUT_REQUEST_ONE_RESPONSE = messages_pb2.StreamingOutputCallRequest(
    response_parameters=[
        messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE),
    ])
627
628
class TestStreamStreamCall(_MulticallableTestMixin, AioTestBase):
    """Tests the call object returned by stream-stream multicallables."""

    def _check_payload_size(self, response):
        """Asserts that `response` carries a payload of the requested size."""
        self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

    def _cancel_running_call(self, call):
        """Cancels a still-running `call`, checking its state on both sides."""
        self.assertFalse(call.done())
        self.assertFalse(call.cancelled())
        self.assertTrue(call.cancel())
        self.assertTrue(call.cancelled())

    async def test_cancel(self):
        """Cancels the call after a few ping-pong exchanges."""
        duplex_call = self._stub.FullDuplexCall()

        for _ in range(_NUM_STREAM_RESPONSES):
            await duplex_call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)
            reply = await duplex_call.read()
            self.assertIsInstance(reply,
                                  messages_pb2.StreamingOutputCallResponse)
            self._check_payload_size(reply)

        self._cancel_running_call(duplex_call)
        self.assertEqual(grpc.StatusCode.CANCELLED, await duplex_call.code())

    async def test_cancel_with_pending_read(self):
        """Cancels the call after a write whose response was never read."""
        duplex_call = self._stub.FullDuplexCall()

        await duplex_call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)

        self._cancel_running_call(duplex_call)
        self.assertEqual(grpc.StatusCode.CANCELLED, await duplex_call.code())

    async def test_cancel_with_ongoing_read(self):
        """Cancels the call while another task is blocked in read()."""
        duplex_call = self._stub.FullDuplexCall()
        reader_started = asyncio.Event()

        async def blocked_reader():
            reader_started.set()
            await duplex_call.read()

        reader_task = self.loop.create_task(blocked_reader())
        await reader_started.wait()
        self.assertFalse(reader_task.done())

        self._cancel_running_call(duplex_call)
        self.assertEqual(grpc.StatusCode.CANCELLED, await duplex_call.code())

    async def test_early_cancel(self):
        """Cancels the call before any message was exchanged."""
        duplex_call = self._stub.FullDuplexCall()

        self._cancel_running_call(duplex_call)
        self.assertEqual(grpc.StatusCode.CANCELLED, await duplex_call.code())

    async def test_cancel_after_done_writing(self):
        """Cancels the call right after the write side was closed."""
        duplex_call = self._stub.FullDuplexCall()
        await duplex_call.done_writing()

        self._cancel_running_call(duplex_call)
        self.assertEqual(grpc.StatusCode.CANCELLED, await duplex_call.code())

    async def test_late_cancel(self):
        """Cancelling a call that already finished with OK has no effect."""
        duplex_call = self._stub.FullDuplexCall()
        await duplex_call.done_writing()
        self.assertEqual(grpc.StatusCode.OK, await duplex_call.code())

        # The call is already done, so the cancellation is rejected
        self.assertTrue(duplex_call.done())
        self.assertFalse(duplex_call.cancelled())
        self.assertFalse(duplex_call.cancel())
        self.assertFalse(duplex_call.cancelled())

        # The status stays OK
        self.assertEqual(grpc.StatusCode.OK, await duplex_call.code())

    async def test_async_generator(self):
        """Requests may come from an async generator; responses are iterated."""

        async def two_requests():
            for _ in range(2):
                yield _STREAM_OUTPUT_REQUEST_ONE_RESPONSE

        duplex_call = self._stub.FullDuplexCall(two_requests())
        async for reply in duplex_call:
            self._check_payload_size(reply)

        self.assertEqual(await duplex_call.code(), grpc.StatusCode.OK)

    async def test_too_many_reads(self):
        """Reading past the last response yields EOF, repeatedly."""

        async def all_requests():
            remaining = _NUM_STREAM_RESPONSES
            while remaining > 0:
                yield _STREAM_OUTPUT_REQUEST_ONE_RESPONSE
                remaining -= 1

        duplex_call = self._stub.FullDuplexCall(all_requests())
        for _ in range(_NUM_STREAM_RESPONSES):
            reply = await duplex_call.read()
            self._check_payload_size(reply)
        self.assertIs(await duplex_call.read(), aio.EOF)

        self.assertEqual(await duplex_call.code(), grpc.StatusCode.OK)
        # Once the RPC has finished, reading again still produces EOF
        self.assertIs(await duplex_call.read(), aio.EOF)

    async def test_read_write_after_done_writing(self):
        """After done_writing() writes fail while reads keep working."""
        duplex_call = self._stub.FullDuplexCall()

        # Two requests are written, leaving two responses to be read later
        await duplex_call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)
        await duplex_call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)
        await duplex_call.done_writing()

        # Any further write must be rejected
        with self.assertRaises(asyncio.InvalidStateError):
            await duplex_call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)

        # Reads, however, are unaffected
        first_reply = await duplex_call.read()
        self._check_payload_size(first_reply)
        second_reply = await duplex_call.read()
        self._check_payload_size(second_reply)

        self.assertEqual(await duplex_call.code(), grpc.StatusCode.OK)

    async def test_error_in_async_generator(self):
        """Cancelling the call propagates CancelledError into the iterator."""
        # Each request asks for one response, sent after an interval
        slow_request = messages_pb2.StreamingOutputCallRequest()
        slow_request.response_parameters.append(
            messages_pb2.ResponseParameters(
                size=_RESPONSE_PAYLOAD_SIZE,
                interval_us=_RESPONSE_INTERVAL_US,
            ))

        # Set once the request iterator has observed the cancellation
        iterator_saw_cancellation = asyncio.Event()

        async def request_iterator():
            with self.assertRaises(asyncio.CancelledError):
                for _ in range(_NUM_STREAM_RESPONSES):
                    yield slow_request
                    await asyncio.sleep(_SHORT_TIMEOUT_S)
            iterator_saw_cancellation.set()

        duplex_call = self._stub.FullDuplexCall(request_iterator())

        # Cancels the RPC after two units of wait time
        async def cancel_after_delay():
            await asyncio.sleep(_SHORT_TIMEOUT_S * 2)
            duplex_call.cancel()

        canceller = self.loop.create_task(cancel_after_delay())

        with self.assertRaises(asyncio.CancelledError):
            async for reply in duplex_call:
                self._check_payload_size(reply)

        await iterator_saw_cancellation.wait()

        self.assertEqual(grpc.StatusCode.CANCELLED, await duplex_call.code())
        # Awaiting the task surfaces any failure raised inside it
        await canceller

    async def test_normal_iterable_requests(self):
        """A plain (non-async) iterator of requests is accepted."""
        request_list = [_STREAM_OUTPUT_REQUEST_ONE_RESPONSE
                       ] * _NUM_STREAM_RESPONSES

        duplex_call = self._stub.FullDuplexCall(iter(request_list))
        async for reply in duplex_call:
            self._check_payload_size(reply)

        self.assertEqual(await duplex_call.code(), grpc.StatusCode.OK)
810
811
if __name__ == '__main__':
    # Debug-level logging is enabled before the tests run so that their log
    # output is shown alongside the verbose test report.
    logging.basicConfig(level=logging.DEBUG)
    unittest.main(verbosity=2)
815