# Copyright 2019 The gRPC Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests behavior of the Call classes."""

import asyncio
import datetime
import logging
import unittest

import grpc
from grpc.experimental import aio

from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
from tests_aio.unit._test_base import AioTestBase
from tests_aio.unit._test_server import start_test_server
from tests_aio.unit._constants import UNREACHABLE_TARGET

# One "unit" of wait time, in seconds (a float, 1.0).
_SHORT_TIMEOUT_S = datetime.timedelta(seconds=1).total_seconds()

# Sizing knobs shared by the streaming tests.
_NUM_STREAM_RESPONSES = 5
_RESPONSE_PAYLOAD_SIZE = 42
_REQUEST_PAYLOAD_SIZE = 7

# Status details a call is expected to report after call.cancel().
_LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!'
# Interval that makes the server wait one "unit" between streamed responses.
_RESPONSE_INTERVAL_US = int(_SHORT_TIMEOUT_S * 1000 * 1000)
# Interval large enough that the server effectively blocks forever.
_INFINITE_INTERVAL_US = 2**31 - 1


class _MulticallableTestMixin:
    """Per-test fixture: starts a test server and a stub over a channel."""

    async def setUp(self):
        address, self._server = await start_test_server()
        self._channel = aio.insecure_channel(address)
        self._stub = test_pb2_grpc.TestServiceStub(self._channel)

    async def tearDown(self):
        await self._channel.close()
        await self._server.stop(None)


class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase):

    async def test_call_to_string(self):
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        # str()/repr() must work while the RPC is in flight...
        self.assertIsNotNone(str(call))
        self.assertIsNotNone(repr(call))

        await call

        # ...and after it has finished.
        self.assertIsNotNone(str(call))
        self.assertIsNotNone(repr(call))

    async def test_call_ok(self):
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        self.assertFalse(call.done())

        response = await call

        self.assertTrue(call.done())
        self.assertIsInstance(response, messages_pb2.SimpleResponse)
        self.assertEqual(await call.code(), grpc.StatusCode.OK)

        # Response is cached at call object level, reentrance
        # returns again the same response
        response_retry = await call
        self.assertIs(response, response_retry)

    async def test_call_rpc_error(self):
        async with aio.insecure_channel(UNREACHABLE_TARGET) as channel:
            stub = test_pb2_grpc.TestServiceStub(channel)

            call = stub.UnaryCall(messages_pb2.SimpleRequest())

            with self.assertRaises(aio.AioRpcError) as exception_context:
                await call

            self.assertEqual(grpc.StatusCode.UNAVAILABLE,
                             exception_context.exception.code())

            self.assertTrue(call.done())
            self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code())

    async def test_call_code_awaitable(self):
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_call_details_awaitable(self):
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
        self.assertEqual('', await call.details())

    async def test_call_initial_metadata_awaitable(self):
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
        self.assertEqual((), await call.initial_metadata())

    async def test_call_trailing_metadata_awaitable(self):
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
        self.assertEqual((), await call.trailing_metadata())

    async def test_call_initial_metadata_cancelable(self):
        coro_started = asyncio.Event()
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        async def coro():
            coro_started.set()
            await call.initial_metadata()

        task = self.loop.create_task(coro())
        await coro_started.wait()
        task.cancel()

        # Test that initial metadata can still be asked for even though
        # a cancellation happened with the previous task
        self.assertEqual((), await call.initial_metadata())

    async def test_call_initial_metadata_multiple_waiters(self):
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        async def coro():
            return await call.initial_metadata()

        task1 = self.loop.create_task(coro())
        task2 = self.loop.create_task(coro())

        await call

        self.assertEqual([(), ()], await asyncio.gather(task1, task2))

    async def test_call_code_cancelable(self):
        coro_started = asyncio.Event()
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        async def coro():
            coro_started.set()
            await call.code()

        task = self.loop.create_task(coro())
        await coro_started.wait()
        task.cancel()

        # Test that code can still be asked for even though
        # a cancellation happened with the previous task
        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_call_code_multiple_waiters(self):
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        async def coro():
            return await call.code()

        task1 = self.loop.create_task(coro())
        task2 = self.loop.create_task(coro())

        await call

        self.assertEqual([grpc.StatusCode.OK, grpc.StatusCode.OK], await
                         asyncio.gather(task1, task2))

    async def test_cancel_unary_unary(self):
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        self.assertFalse(call.cancelled())

        self.assertTrue(call.cancel())
        self.assertFalse(call.cancel())

        with self.assertRaises(asyncio.CancelledError):
            await call

        # The call object should reflect the local cancellation.
        self.assertTrue(call.cancelled())
        self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
        self.assertEqual(await call.details(),
                         _LOCAL_CANCEL_DETAILS_EXPECTATION)

    async def test_cancel_unary_unary_in_task(self):
        # EmptyCall takes an Empty request; imported locally so this test
        # stays self-contained.
        from src.proto.grpc.testing import empty_pb2

        coro_started = asyncio.Event()
        call = self._stub.EmptyCall(empty_pb2.Empty())

        async def another_coro():
            coro_started.set()
            await call

        task = self.loop.create_task(another_coro())
        await coro_started.wait()

        self.assertFalse(task.done())
        task.cancel()

        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())

        with self.assertRaises(asyncio.CancelledError):
            await task

    async def test_passing_credentials_fails_over_insecure_channel(self):
        call_credentials = grpc.composite_call_credentials(
            grpc.access_token_call_credentials("abc"),
            grpc.access_token_call_credentials("def"),
        )
        with self.assertRaisesRegex(
                aio.UsageError,
                "Call credentials are only valid on secure channels"):
            self._stub.UnaryCall(messages_pb2.SimpleRequest(),
                                 credentials=call_credentials)


class TestUnaryStreamCall(_MulticallableTestMixin, AioTestBase):

    async def test_call_rpc_error(self):
        # The context manager closes the channel even if an assertion fails.
        async with aio.insecure_channel(UNREACHABLE_TARGET) as channel:
            request = messages_pb2.StreamingOutputCallRequest()
            stub = test_pb2_grpc.TestServiceStub(channel)
            call = stub.StreamingOutputCall(request)

            with self.assertRaises(aio.AioRpcError) as exception_context:
                async for _ in call:
                    pass

            self.assertEqual(grpc.StatusCode.UNAVAILABLE,
                             exception_context.exception.code())

            self.assertTrue(call.done())
            self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code())

    async def test_cancel_unary_stream(self):
        # Prepares the request
        request = messages_pb2.StreamingOutputCallRequest()
        for _ in range(_NUM_STREAM_RESPONSES):
            request.response_parameters.append(
                messages_pb2.ResponseParameters(
                    size=_RESPONSE_PAYLOAD_SIZE,
                    interval_us=_RESPONSE_INTERVAL_US,
                ))

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)
        self.assertFalse(call.cancelled())

        response = await call.read()
        self.assertIs(type(response), messages_pb2.StreamingOutputCallResponse)
        self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        self.assertTrue(call.cancel())
        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
        self.assertEqual(_LOCAL_CANCEL_DETAILS_EXPECTATION, await
                         call.details())
        self.assertFalse(call.cancel())

        with self.assertRaises(asyncio.CancelledError):
            await call.read()
        self.assertTrue(call.cancelled())

    async def test_multiple_cancel_unary_stream(self):
        # Prepares the request
        request = messages_pb2.StreamingOutputCallRequest()
        for _ in range(_NUM_STREAM_RESPONSES):
            request.response_parameters.append(
                messages_pb2.ResponseParameters(
                    size=_RESPONSE_PAYLOAD_SIZE,
                    interval_us=_RESPONSE_INTERVAL_US,
                ))

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)
        self.assertFalse(call.cancelled())

        response = await call.read()
        self.assertIs(type(response), messages_pb2.StreamingOutputCallResponse)
        self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        # Only the first cancel() takes effect.
        self.assertTrue(call.cancel())
        self.assertFalse(call.cancel())
        self.assertFalse(call.cancel())
        self.assertFalse(call.cancel())

        with self.assertRaises(asyncio.CancelledError):
            await call.read()

    async def test_early_cancel_unary_stream(self):
        """Test cancellation before receiving messages."""
        # Prepares the request
        request = messages_pb2.StreamingOutputCallRequest()
        for _ in range(_NUM_STREAM_RESPONSES):
            request.response_parameters.append(
                messages_pb2.ResponseParameters(
                    size=_RESPONSE_PAYLOAD_SIZE,
                    interval_us=_RESPONSE_INTERVAL_US,
                ))

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)

        self.assertFalse(call.cancelled())
        self.assertTrue(call.cancel())
        self.assertFalse(call.cancel())

        with self.assertRaises(asyncio.CancelledError):
            await call.read()

        self.assertTrue(call.cancelled())

        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
        self.assertEqual(_LOCAL_CANCEL_DETAILS_EXPECTATION, await
                         call.details())

    async def test_late_cancel_unary_stream(self):
        """Test cancellation after received all messages."""
        # Prepares the request
        request = messages_pb2.StreamingOutputCallRequest()
        for _ in range(_NUM_STREAM_RESPONSES):
            request.response_parameters.append(
                messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE,))

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)

        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertIs(type(response),
                          messages_pb2.StreamingOutputCallResponse)
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        # After all messages received, it is possible that the final state
        # is received or on its way. It's basically a data race, so our
        # expectation here is do not crash :)
        call.cancel()
        self.assertIn(await call.code(),
                      [grpc.StatusCode.OK, grpc.StatusCode.CANCELLED])

    async def test_too_many_reads_unary_stream(self):
        """Test that reading past the last message returns EOF."""
        # Prepares the request
        request = messages_pb2.StreamingOutputCallRequest()
        for _ in range(_NUM_STREAM_RESPONSES):
            request.response_parameters.append(
                messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE,))

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)

        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertIs(type(response),
                          messages_pb2.StreamingOutputCallResponse)
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
        self.assertIs(await call.read(), aio.EOF)

        # After the RPC is finished, further reads keep returning EOF.
        self.assertEqual(await call.code(), grpc.StatusCode.OK)
        self.assertIs(await call.read(), aio.EOF)

    async def test_unary_stream_async_generator(self):
        """Sunny day test case for unary_stream."""
        # Prepares the request
        request = messages_pb2.StreamingOutputCallRequest()
        for _ in range(_NUM_STREAM_RESPONSES):
            request.response_parameters.append(
                messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE,))

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)
        self.assertFalse(call.cancelled())

        async for response in call:
            self.assertIs(type(response),
                          messages_pb2.StreamingOutputCallResponse)
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_cancel_unary_stream_in_task_using_read(self):
        coro_started = asyncio.Event()

        # Configs the server method to block forever
        request = messages_pb2.StreamingOutputCallRequest()
        request.response_parameters.append(
            messages_pb2.ResponseParameters(
                size=_RESPONSE_PAYLOAD_SIZE,
                interval_us=_INFINITE_INTERVAL_US,
            ))

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)

        async def another_coro():
            coro_started.set()
            await call.read()

        task = self.loop.create_task(another_coro())
        await coro_started.wait()

        self.assertFalse(task.done())
        task.cancel()

        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())

        with self.assertRaises(asyncio.CancelledError):
            await task

    async def test_cancel_unary_stream_in_task_using_async_for(self):
        coro_started = asyncio.Event()

        # Configs the server method to block forever
        request = messages_pb2.StreamingOutputCallRequest()
        request.response_parameters.append(
            messages_pb2.ResponseParameters(
                size=_RESPONSE_PAYLOAD_SIZE,
                interval_us=_INFINITE_INTERVAL_US,
            ))

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)

        async def another_coro():
            coro_started.set()
            async for _ in call:
                pass

        task = self.loop.create_task(another_coro())
        await coro_started.wait()

        self.assertFalse(task.done())
        task.cancel()

        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())

        with self.assertRaises(asyncio.CancelledError):
            await task

    async def test_time_remaining(self):
        request = messages_pb2.StreamingOutputCallRequest()
        # First message comes back immediately
        request.response_parameters.append(
            messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE,))
        # Second message comes back after a unit of wait time
        request.response_parameters.append(
            messages_pb2.ResponseParameters(
                size=_RESPONSE_PAYLOAD_SIZE,
                interval_us=_RESPONSE_INTERVAL_US,
            ))

        call = self._stub.StreamingOutputCall(request,
                                              timeout=_SHORT_TIMEOUT_S * 2)

        response = await call.read()
        self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        # Should be around the same as the timeout
        remained_time = call.time_remaining()
        self.assertGreater(remained_time, _SHORT_TIMEOUT_S * 3 / 2)
        self.assertLess(remained_time, _SHORT_TIMEOUT_S * 5 / 2)

        response = await call.read()
        self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        # Should be around the timeout minus a unit of wait time
        remained_time = call.time_remaining()
        self.assertGreater(remained_time, _SHORT_TIMEOUT_S / 2)
        self.assertLess(remained_time, _SHORT_TIMEOUT_S * 3 / 2)

        self.assertEqual(grpc.StatusCode.OK, await call.code())


class TestStreamUnaryCall(_MulticallableTestMixin, AioTestBase):

    async def test_cancel_stream_unary(self):
        call = self._stub.StreamingInputCall()

        # Prepares the request
        payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
        request = messages_pb2.StreamingInputCallRequest(payload=payload)

        # Sends out requests
        for _ in range(_NUM_STREAM_RESPONSES):
            await call.write(request)

        # Cancels the RPC
        self.assertFalse(call.done())
        self.assertFalse(call.cancelled())
        self.assertTrue(call.cancel())
        self.assertTrue(call.cancelled())

        await call.done_writing()

        with self.assertRaises(asyncio.CancelledError):
            await call

    async def test_early_cancel_stream_unary(self):
        call = self._stub.StreamingInputCall()

        # Cancels the RPC
        self.assertFalse(call.done())
        self.assertFalse(call.cancelled())
        self.assertTrue(call.cancel())
        self.assertTrue(call.cancelled())

        with self.assertRaises(asyncio.InvalidStateError):
            await call.write(messages_pb2.StreamingInputCallRequest())

        # Should be no-op
        await call.done_writing()

        with self.assertRaises(asyncio.CancelledError):
            await call

    async def test_write_after_done_writing(self):
        call = self._stub.StreamingInputCall()

        # Prepares the request
        payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
        request = messages_pb2.StreamingInputCallRequest(payload=payload)

        # Sends out requests
        for _ in range(_NUM_STREAM_RESPONSES):
            await call.write(request)

        # Half-closes the request stream
        await call.done_writing()

        # Any further write is rejected
        with self.assertRaises(asyncio.InvalidStateError):
            await call.write(messages_pb2.StreamingInputCallRequest())

        response = await call
        self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse)
        self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
                         response.aggregated_payload_size)

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_error_in_async_generator(self):
        # StreamingInputCall consumes StreamingInputCallRequest messages
        payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
        request = messages_pb2.StreamingInputCallRequest(payload=payload)

        # We expect the request iterator to receive the exception
        request_iterator_received_the_exception = asyncio.Event()

        async def request_iterator():
            with self.assertRaises(asyncio.CancelledError):
                for _ in range(_NUM_STREAM_RESPONSES):
                    yield request
                    # Pauses between requests so the RPC is still in
                    # flight when it gets cancelled
                    await asyncio.sleep(_SHORT_TIMEOUT_S)
            request_iterator_received_the_exception.set()

        call = self._stub.StreamingInputCall(request_iterator())

        # Cancel the RPC after at least one request has been sent
        async def cancel_later():
            await asyncio.sleep(_SHORT_TIMEOUT_S * 2)
            call.cancel()

        cancel_later_task = self.loop.create_task(cancel_later())

        with self.assertRaises(asyncio.CancelledError):
            await call

        await request_iterator_received_the_exception.wait()

        # No failures in the cancel later task!
        await cancel_later_task

    async def test_normal_iterable_requests(self):
        # Prepares the request
        payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
        request = messages_pb2.StreamingInputCallRequest(payload=payload)
        requests = [request] * _NUM_STREAM_RESPONSES

        # Sends out requests
        call = self._stub.StreamingInputCall(requests)

        # RPC should succeed
        response = await call
        self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse)
        self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
                         response.aggregated_payload_size)

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_call_rpc_error(self):
        async with aio.insecure_channel(UNREACHABLE_TARGET) as channel:
            stub = test_pb2_grpc.TestServiceStub(channel)

            # The error should be raised automatically without any traffic.
            call = stub.StreamingInputCall()
            with self.assertRaises(aio.AioRpcError) as exception_context:
                await call

            self.assertEqual(grpc.StatusCode.UNAVAILABLE,
                             exception_context.exception.code())

            self.assertTrue(call.done())
            self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code())

    async def test_timeout(self):
        call = self._stub.StreamingInputCall(timeout=_SHORT_TIMEOUT_S)

        # The error should be raised automatically without any traffic.
        with self.assertRaises(aio.AioRpcError) as exception_context:
            await call

        rpc_error = exception_context.exception
        self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, rpc_error.code())
        self.assertTrue(call.done())
        self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, await call.code())


# Prepares a request that asks for exactly one response, so the stream can
# be driven in a ping-pong manner.
_STREAM_OUTPUT_REQUEST_ONE_RESPONSE = messages_pb2.StreamingOutputCallRequest()
_STREAM_OUTPUT_REQUEST_ONE_RESPONSE.response_parameters.append(
    messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE))


class TestStreamStreamCall(_MulticallableTestMixin, AioTestBase):

    async def test_cancel(self):
        # Invokes the actual RPC
        call = self._stub.FullDuplexCall()

        for _ in range(_NUM_STREAM_RESPONSES):
            await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)
            response = await call.read()
            self.assertIsInstance(response,
                                  messages_pb2.StreamingOutputCallResponse)
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        # Cancels the RPC
        self.assertFalse(call.done())
        self.assertFalse(call.cancelled())
        self.assertTrue(call.cancel())
        self.assertTrue(call.cancelled())
        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())

    async def test_cancel_with_pending_read(self):
        call = self._stub.FullDuplexCall()

        await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)

        # Cancels the RPC
        self.assertFalse(call.done())
        self.assertFalse(call.cancelled())
        self.assertTrue(call.cancel())
        self.assertTrue(call.cancelled())
        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())

    async def test_cancel_with_ongoing_read(self):
        call = self._stub.FullDuplexCall()
        coro_started = asyncio.Event()

        async def read_coro():
            coro_started.set()
            await call.read()

        read_task = self.loop.create_task(read_coro())
        await coro_started.wait()
        self.assertFalse(read_task.done())

        # Cancels the RPC
        self.assertFalse(call.done())
        self.assertFalse(call.cancelled())
        self.assertTrue(call.cancel())
        self.assertTrue(call.cancelled())
        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())

    async def test_early_cancel(self):
        call = self._stub.FullDuplexCall()

        # Cancels the RPC
        self.assertFalse(call.done())
        self.assertFalse(call.cancelled())
        self.assertTrue(call.cancel())
        self.assertTrue(call.cancelled())
        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())

    async def test_cancel_after_done_writing(self):
        call = self._stub.FullDuplexCall()
        await call.done_writing()

        # Cancels the RPC
        self.assertFalse(call.done())
        self.assertFalse(call.cancelled())
        self.assertTrue(call.cancel())
        self.assertTrue(call.cancelled())
        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())

    async def test_late_cancel(self):
        call = self._stub.FullDuplexCall()
        await call.done_writing()
        self.assertEqual(grpc.StatusCode.OK, await call.code())

        # Cancels the RPC
        self.assertTrue(call.done())
        self.assertFalse(call.cancelled())
        self.assertFalse(call.cancel())
        self.assertFalse(call.cancelled())

        # Status is still OK
        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_async_generator(self):

        async def request_generator():
            yield _STREAM_OUTPUT_REQUEST_ONE_RESPONSE
            yield _STREAM_OUTPUT_REQUEST_ONE_RESPONSE

        call = self._stub.FullDuplexCall(request_generator())
        async for response in call:
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_too_many_reads(self):

        async def request_generator():
            for _ in range(_NUM_STREAM_RESPONSES):
                yield _STREAM_OUTPUT_REQUEST_ONE_RESPONSE

        call = self._stub.FullDuplexCall(request_generator())
        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
        self.assertIs(await call.read(), aio.EOF)

        self.assertEqual(await call.code(), grpc.StatusCode.OK)
        # After the RPC finished, the read should also produce EOF
        self.assertIs(await call.read(), aio.EOF)

    async def test_read_write_after_done_writing(self):
        call = self._stub.FullDuplexCall()

        # Writes two requests; two responses are expected in return
        await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)
        await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)
        await call.done_writing()

        # Further write should fail
        with self.assertRaises(asyncio.InvalidStateError):
            await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)

        # But read should be unaffected
        response = await call.read()
        self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
        response = await call.read()
        self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_error_in_async_generator(self):
        # Server will pause between responses
        request = messages_pb2.StreamingOutputCallRequest()
        request.response_parameters.append(
            messages_pb2.ResponseParameters(
                size=_RESPONSE_PAYLOAD_SIZE,
                interval_us=_RESPONSE_INTERVAL_US,
            ))

        # We expect the request iterator to receive the exception
        request_iterator_received_the_exception = asyncio.Event()

        async def request_iterator():
            with self.assertRaises(asyncio.CancelledError):
                for _ in range(_NUM_STREAM_RESPONSES):
                    yield request
                    await asyncio.sleep(_SHORT_TIMEOUT_S)
            request_iterator_received_the_exception.set()

        call = self._stub.FullDuplexCall(request_iterator())

        # Cancel the RPC after at least one response
        async def cancel_later():
            await asyncio.sleep(_SHORT_TIMEOUT_S * 2)
            call.cancel()

        cancel_later_task = self.loop.create_task(cancel_later())

        with self.assertRaises(asyncio.CancelledError):
            async for response in call:
                self.assertEqual(_RESPONSE_PAYLOAD_SIZE,
                                 len(response.payload.body))

        await request_iterator_received_the_exception.wait()

        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
        # No failures in the cancel later task!
        await cancel_later_task

    async def test_normal_iterable_requests(self):
        requests = [_STREAM_OUTPUT_REQUEST_ONE_RESPONSE] * _NUM_STREAM_RESPONSES

        call = self._stub.FullDuplexCall(iter(requests))
        async for response in call:
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        self.assertEqual(await call.code(), grpc.StatusCode.OK)


if __name__ == '__main__':
    logging.basicConfig(level=logging.DEBUG)
    unittest.main(verbosity=2)