# Copyright 2019 The gRPC Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests behavior of the Call classes."""

import asyncio
import datetime
import logging
import unittest

import grpc
from grpc.experimental import aio

from src.proto.grpc.testing import messages_pb2
from src.proto.grpc.testing import test_pb2_grpc
from tests_aio.unit._constants import UNREACHABLE_TARGET
from tests_aio.unit._test_base import AioTestBase
from tests_aio.unit._test_server import start_test_server

# One "unit" of wait time, in seconds; timeouts and intervals derive from it.
_SHORT_TIMEOUT_S = datetime.timedelta(seconds=1).total_seconds()

_NUM_STREAM_RESPONSES = 5
_RESPONSE_PAYLOAD_SIZE = 42
_REQUEST_PAYLOAD_SIZE = 7
_LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!'
# The same unit of wait time expressed in microseconds.
_RESPONSE_INTERVAL_US = int(_SHORT_TIMEOUT_S * 1000 * 1000)
# Interval used by the tests that want the server method to block forever.
_INFINITE_INTERVAL_US = 2**31 - 1


class _MulticallableTestMixin:
    """Provides a test server plus a channel and stub connected to it."""

    async def setUp(self):
        """Starts the test server and opens an insecure channel to it."""
        address, self._server = await start_test_server()
        self._channel = aio.insecure_channel(address)
        self._stub = test_pb2_grpc.TestServiceStub(self._channel)

    async def tearDown(self):
        """Closes the channel first, then stops the server."""
        await self._channel.close()
        await self._server.stop(None)


class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase):
    """Tests for the call object returned by a unary-unary multicallable."""

    async def test_call_to_string(self):
        """str() and repr() work both before and after the call finishes."""
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        self.assertIsNotNone(str(call))
        self.assertIsNotNone(repr(call))

        await call

        self.assertIsNotNone(str(call))
        self.assertIsNotNone(repr(call))

    async def test_call_ok(self):
        """A successful call reports done, OK, and a cached response."""
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        self.assertFalse(call.done())

        response = await call

        self.assertTrue(call.done())
        self.assertIsInstance(response, messages_pb2.SimpleResponse)
        self.assertEqual(await call.code(), grpc.StatusCode.OK)

        # Response is cached at call object level, reentrance
        # returns again the same response
        response_retry = await call
        self.assertIs(response, response_retry)

    async def test_call_rpc_error(self):
        """Calling an unreachable target fails with UNAVAILABLE."""
        async with aio.insecure_channel(UNREACHABLE_TARGET) as channel:
            stub = test_pb2_grpc.TestServiceStub(channel)

            call = stub.UnaryCall(messages_pb2.SimpleRequest())

            with self.assertRaises(aio.AioRpcError) as exception_context:
                await call

            self.assertEqual(grpc.StatusCode.UNAVAILABLE,
                             exception_context.exception.code())

            self.assertTrue(call.done())
            self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code())

    async def test_call_code_awaitable(self):
        """code() can be awaited without awaiting the call itself."""
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_call_details_awaitable(self):
        """details() can be awaited without awaiting the call itself."""
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
        self.assertEqual('', await call.details())

    async def test_call_initial_metadata_awaitable(self):
        """initial_metadata() can be awaited directly."""
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
        self.assertEqual(aio.Metadata(), await call.initial_metadata())

    async def test_call_trailing_metadata_awaitable(self):
        """trailing_metadata() can be awaited directly."""
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
        self.assertEqual(aio.Metadata(), await call.trailing_metadata())

    async def test_call_initial_metadata_cancelable(self):
        """Cancelling one initial_metadata() waiter does not break others."""
        coro_started = asyncio.Event()
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        async def coro():
            coro_started.set()
            await call.initial_metadata()

        task = self.loop.create_task(coro())
        await coro_started.wait()
        task.cancel()

        # Test that initial metadata can still be asked even though
        # a cancellation happened with the previous task
        self.assertEqual(aio.Metadata(), await call.initial_metadata())

    async def test_call_initial_metadata_multiple_waiters(self):
        """Several tasks can wait on initial_metadata() concurrently."""
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        async def coro():
            return await call.initial_metadata()

        task1 = self.loop.create_task(coro())
        task2 = self.loop.create_task(coro())

        await call
        expected = [aio.Metadata() for _ in range(2)]
        self.assertEqual(expected, await asyncio.gather(task1, task2))

    async def test_call_code_cancelable(self):
        """Cancelling one code() waiter does not break later waiters."""
        coro_started = asyncio.Event()
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        async def coro():
            coro_started.set()
            await call.code()

        task = self.loop.create_task(coro())
        await coro_started.wait()
        task.cancel()

        # Test that code can still be asked even though
        # a cancellation happened with the previous task
        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_call_code_multiple_waiters(self):
        """Several tasks can wait on code() concurrently."""
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        async def coro():
            return await call.code()

        task1 = self.loop.create_task(coro())
        task2 = self.loop.create_task(coro())

        await call

        self.assertEqual([grpc.StatusCode.OK, grpc.StatusCode.OK], await
                         asyncio.gather(task1, task2))

    async def test_cancel_unary_unary(self):
        """cancel() succeeds once and awaiting then raises CancelledError."""
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        self.assertFalse(call.cancelled())

        # Only the first cancel() reports success.
        self.assertTrue(call.cancel())
        self.assertFalse(call.cancel())

        with self.assertRaises(asyncio.CancelledError):
            await call

        # The info in the RpcError should match the info in Call object.
        self.assertTrue(call.cancelled())
        self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
        self.assertEqual(await call.details(),
                         _LOCAL_CANCEL_DETAILS_EXPECTATION)

    async def test_cancel_unary_unary_in_task(self):
        """Cancelling the task awaiting the call cancels the call too."""
        coro_started = asyncio.Event()
        call = self._stub.EmptyCall(messages_pb2.SimpleRequest())

        async def another_coro():
            coro_started.set()
            await call

        task = self.loop.create_task(another_coro())
        await coro_started.wait()

        self.assertFalse(task.done())
        task.cancel()

        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())

        with self.assertRaises(asyncio.CancelledError):
            await task

    async def test_passing_credentials_fails_over_insecure_channel(self):
        """Call credentials on an insecure channel raise UsageError."""
        call_credentials = grpc.composite_call_credentials(
            grpc.access_token_call_credentials("abc"),
            grpc.access_token_call_credentials("def"),
        )
        with self.assertRaisesRegex(
                aio.UsageError,
                "Call credentials are only valid on secure channels"):
            self._stub.UnaryCall(messages_pb2.SimpleRequest(),
                                 credentials=call_credentials)


class TestUnaryStreamCall(_MulticallableTestMixin, AioTestBase):
    """Tests for the call object returned by a unary-stream multicallable."""

    async def test_call_rpc_error(self):
        """Iterating a call to an unreachable target fails with UNAVAILABLE."""
        # The context manager closes the channel even if an assertion fails.
        async with aio.insecure_channel(UNREACHABLE_TARGET) as channel:
            request = messages_pb2.StreamingOutputCallRequest()
            stub = test_pb2_grpc.TestServiceStub(channel)
            call = stub.StreamingOutputCall(request)

            with self.assertRaises(aio.AioRpcError) as exception_context:
                async for response in call:
                    pass

            self.assertEqual(grpc.StatusCode.UNAVAILABLE,
                             exception_context.exception.code())

            self.assertTrue(call.done())
            self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code())

    async def test_cancel_unary_stream(self):
        """Cancelling after the first message yields CANCELLED status."""
        # Prepares the request
        request = messages_pb2.StreamingOutputCallRequest()
        for _ in range(_NUM_STREAM_RESPONSES):
            request.response_parameters.append(
                messages_pb2.ResponseParameters(
                    size=_RESPONSE_PAYLOAD_SIZE,
                    interval_us=_RESPONSE_INTERVAL_US,
                ))

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)
        self.assertFalse(call.cancelled())

        response = await call.read()
        self.assertIs(type(response), messages_pb2.StreamingOutputCallResponse)
        self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        self.assertTrue(call.cancel())
        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
        self.assertEqual(_LOCAL_CANCEL_DETAILS_EXPECTATION, await
                         call.details())
        self.assertFalse(call.cancel())

        with self.assertRaises(asyncio.CancelledError):
            await call.read()
        self.assertTrue(call.cancelled())

    async def test_multiple_cancel_unary_stream(self):
        """Only the first of several cancel() calls reports success."""
        # Prepares the request
        request = messages_pb2.StreamingOutputCallRequest()
        for _ in range(_NUM_STREAM_RESPONSES):
            request.response_parameters.append(
                messages_pb2.ResponseParameters(
                    size=_RESPONSE_PAYLOAD_SIZE,
                    interval_us=_RESPONSE_INTERVAL_US,
                ))

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)
        self.assertFalse(call.cancelled())

        response = await call.read()
        self.assertIs(type(response), messages_pb2.StreamingOutputCallResponse)
        self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        self.assertTrue(call.cancel())
        self.assertFalse(call.cancel())
        self.assertFalse(call.cancel())
        self.assertFalse(call.cancel())

        with self.assertRaises(asyncio.CancelledError):
            await call.read()

    async def test_early_cancel_unary_stream(self):
        """Test cancellation before receiving messages."""
        # Prepares the request
        request = messages_pb2.StreamingOutputCallRequest()
        for _ in range(_NUM_STREAM_RESPONSES):
            request.response_parameters.append(
                messages_pb2.ResponseParameters(
                    size=_RESPONSE_PAYLOAD_SIZE,
                    interval_us=_RESPONSE_INTERVAL_US,
                ))

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)

        self.assertFalse(call.cancelled())
        self.assertTrue(call.cancel())
        self.assertFalse(call.cancel())

        with self.assertRaises(asyncio.CancelledError):
            await call.read()

        self.assertTrue(call.cancelled())

        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
        self.assertEqual(_LOCAL_CANCEL_DETAILS_EXPECTATION, await
                         call.details())

    async def test_late_cancel_unary_stream(self):
        """Test cancellation after received all messages."""
        # Prepares the request
        request = messages_pb2.StreamingOutputCallRequest()
        for _ in range(_NUM_STREAM_RESPONSES):
            request.response_parameters.append(
                messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE,))

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)

        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertIs(type(response),
                          messages_pb2.StreamingOutputCallResponse)
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        # After all messages received, it is possible that the final state
        # is received or on its way. It's basically a data race, so our
        # expectation here is do not crash :)
        call.cancel()
        self.assertIn(await call.code(),
                      [grpc.StatusCode.OK, grpc.StatusCode.CANCELLED])

    async def test_too_many_reads_unary_stream(self):
        """Test calling read after received all messages returns EOF."""
        # Prepares the request
        request = messages_pb2.StreamingOutputCallRequest()
        for _ in range(_NUM_STREAM_RESPONSES):
            request.response_parameters.append(
                messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE,))

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)

        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertIs(type(response),
                          messages_pb2.StreamingOutputCallResponse)
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
        self.assertIs(await call.read(), aio.EOF)

        # After the RPC is finished, further reads keep producing EOF.
        self.assertEqual(await call.code(), grpc.StatusCode.OK)
        self.assertIs(await call.read(), aio.EOF)

    async def test_unary_stream_async_generator(self):
        """Sunny day test case for unary_stream."""
        # Prepares the request
        request = messages_pb2.StreamingOutputCallRequest()
        for _ in range(_NUM_STREAM_RESPONSES):
            request.response_parameters.append(
                messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE,))

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)
        self.assertFalse(call.cancelled())

        async for response in call:
            self.assertIs(type(response),
                          messages_pb2.StreamingOutputCallResponse)
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_cancel_unary_stream_in_task_using_read(self):
        """Cancelling a task blocked in read() cancels the call."""
        coro_started = asyncio.Event()

        # Configs the server method to block forever
        request = messages_pb2.StreamingOutputCallRequest()
        request.response_parameters.append(
            messages_pb2.ResponseParameters(
                size=_RESPONSE_PAYLOAD_SIZE,
                interval_us=_INFINITE_INTERVAL_US,
            ))

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)

        async def another_coro():
            coro_started.set()
            await call.read()

        task = self.loop.create_task(another_coro())
        await coro_started.wait()

        self.assertFalse(task.done())
        task.cancel()

        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())

        with self.assertRaises(asyncio.CancelledError):
            await task

    async def test_cancel_unary_stream_in_task_using_async_for(self):
        """Cancelling a task blocked in `async for` cancels the call."""
        coro_started = asyncio.Event()

        # Configs the server method to block forever
        request = messages_pb2.StreamingOutputCallRequest()
        request.response_parameters.append(
            messages_pb2.ResponseParameters(
                size=_RESPONSE_PAYLOAD_SIZE,
                interval_us=_INFINITE_INTERVAL_US,
            ))

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)

        async def another_coro():
            coro_started.set()
            async for _ in call:
                pass

        task = self.loop.create_task(another_coro())
        await coro_started.wait()

        self.assertFalse(task.done())
        task.cancel()

        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())

        with self.assertRaises(asyncio.CancelledError):
            await task

    async def test_time_remaining(self):
        """time_remaining() shrinks as the call waits on the server."""
        request = messages_pb2.StreamingOutputCallRequest()
        # First message comes back immediately
        request.response_parameters.append(
            messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE,))
        # Second message comes back after a unit of wait time
        request.response_parameters.append(
            messages_pb2.ResponseParameters(
                size=_RESPONSE_PAYLOAD_SIZE,
                interval_us=_RESPONSE_INTERVAL_US,
            ))

        call = self._stub.StreamingOutputCall(request,
                                              timeout=_SHORT_TIMEOUT_S * 2)

        response = await call.read()
        self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        # Should be around the same as the timeout
        remained_time = call.time_remaining()
        self.assertGreater(remained_time, _SHORT_TIMEOUT_S * 3 / 2)
        self.assertLess(remained_time, _SHORT_TIMEOUT_S * 5 / 2)

        response = await call.read()
        self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        # Should be around the timeout minus a unit of wait time
        remained_time = call.time_remaining()
        self.assertGreater(remained_time, _SHORT_TIMEOUT_S / 2)
        self.assertLess(remained_time, _SHORT_TIMEOUT_S * 3 / 2)

        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_empty_responses(self):
        """Responses that serialize to zero bytes are still delivered."""
        # Prepares the request
        request = messages_pb2.StreamingOutputCallRequest()
        for _ in range(_NUM_STREAM_RESPONSES):
            request.response_parameters.append(
                messages_pb2.ResponseParameters())

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)

        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertIs(type(response),
                          messages_pb2.StreamingOutputCallResponse)
            self.assertEqual(b'', response.SerializeToString())

        self.assertEqual(grpc.StatusCode.OK, await call.code())


class TestStreamUnaryCall(_MulticallableTestMixin, AioTestBase):
    """Tests for the call object returned by a stream-unary multicallable."""

    async def test_cancel_stream_unary(self):
        """Cancelling after some writes makes awaiting the call raise."""
        call = self._stub.StreamingInputCall()

        # Prepares the request
        payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
        request = messages_pb2.StreamingInputCallRequest(payload=payload)

        # Sends out requests
        for _ in range(_NUM_STREAM_RESPONSES):
            await call.write(request)

        # Cancels the RPC
        self.assertFalse(call.done())
        self.assertFalse(call.cancelled())
        self.assertTrue(call.cancel())
        self.assertTrue(call.cancelled())

        await call.done_writing()

        with self.assertRaises(asyncio.CancelledError):
            await call

    async def test_early_cancel_stream_unary(self):
        """Writing to a call cancelled before any write is rejected."""
        call = self._stub.StreamingInputCall()

        # Cancels the RPC
        self.assertFalse(call.done())
        self.assertFalse(call.cancelled())
        self.assertTrue(call.cancel())
        self.assertTrue(call.cancelled())

        with self.assertRaises(asyncio.InvalidStateError):
            await call.write(messages_pb2.StreamingInputCallRequest())

        # Should be no-op
        await call.done_writing()

        with self.assertRaises(asyncio.CancelledError):
            await call

    async def test_write_after_done_writing(self):
        """write() after done_writing() raises; the RPC still succeeds."""
        call = self._stub.StreamingInputCall()

        # Prepares the request
        payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
        request = messages_pb2.StreamingInputCallRequest(payload=payload)

        # Sends out requests
        for _ in range(_NUM_STREAM_RESPONSES):
            await call.write(request)

        # Signals that no more requests will be written
        await call.done_writing()

        with self.assertRaises(asyncio.InvalidStateError):
            await call.write(messages_pb2.StreamingInputCallRequest())

        response = await call
        self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse)
        self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
                         response.aggregated_payload_size)

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_error_in_async_generator(self):
        """Cancelling the call propagates CancelledError to the generator."""
        # Server will pause between responses
        request = messages_pb2.StreamingOutputCallRequest()
        request.response_parameters.append(
            messages_pb2.ResponseParameters(
                size=_RESPONSE_PAYLOAD_SIZE,
                interval_us=_RESPONSE_INTERVAL_US,
            ))

        # We expect the request iterator to receive the exception
        request_iterator_received_the_exception = asyncio.Event()

        async def request_iterator():
            with self.assertRaises(asyncio.CancelledError):
                for _ in range(_NUM_STREAM_RESPONSES):
                    yield request
                    await asyncio.sleep(_SHORT_TIMEOUT_S)
            request_iterator_received_the_exception.set()

        call = self._stub.StreamingInputCall(request_iterator())

        # Cancel the RPC after at least one response
        async def cancel_later():
            await asyncio.sleep(_SHORT_TIMEOUT_S * 2)
            call.cancel()

        cancel_later_task = self.loop.create_task(cancel_later())

        with self.assertRaises(asyncio.CancelledError):
            await call

        await request_iterator_received_the_exception.wait()

        # No failures in the cancel later task!
        await cancel_later_task

    async def test_normal_iterable_requests(self):
        """A plain list of requests is accepted as the request stream."""
        # Prepares the request
        payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
        request = messages_pb2.StreamingInputCallRequest(payload=payload)
        requests = [request] * _NUM_STREAM_RESPONSES

        # Sends out requests
        call = self._stub.StreamingInputCall(requests)

        # RPC should succeed
        response = await call
        self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse)
        self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
                         response.aggregated_payload_size)

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_call_rpc_error(self):
        """Calling an unreachable target fails with UNAVAILABLE."""
        async with aio.insecure_channel(UNREACHABLE_TARGET) as channel:
            stub = test_pb2_grpc.TestServiceStub(channel)

            # The error should be raised automatically without any traffic.
            call = stub.StreamingInputCall()
            with self.assertRaises(aio.AioRpcError) as exception_context:
                await call

            self.assertEqual(grpc.StatusCode.UNAVAILABLE,
                             exception_context.exception.code())

            self.assertTrue(call.done())
            self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code())

    async def test_timeout(self):
        """A call that never writes hits its deadline."""
        call = self._stub.StreamingInputCall(timeout=_SHORT_TIMEOUT_S)

        # The error should be raised automatically without any traffic.
        with self.assertRaises(aio.AioRpcError) as exception_context:
            await call

        rpc_error = exception_context.exception
        self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, rpc_error.code())
        self.assertTrue(call.done())
        self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, await call.code())


# Prepares the request that stream in a ping-pong manner.
_STREAM_OUTPUT_REQUEST_ONE_RESPONSE = messages_pb2.StreamingOutputCallRequest()
_STREAM_OUTPUT_REQUEST_ONE_RESPONSE.response_parameters.append(
    messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE))
_STREAM_OUTPUT_REQUEST_ONE_EMPTY_RESPONSE = (
    messages_pb2.StreamingOutputCallRequest())
_STREAM_OUTPUT_REQUEST_ONE_EMPTY_RESPONSE.response_parameters.append(
    messages_pb2.ResponseParameters())


class TestStreamStreamCall(_MulticallableTestMixin, AioTestBase):
    """Tests for the call object returned by a stream-stream multicallable."""

    async def test_cancel(self):
        """Cancelling after several ping-pongs yields CANCELLED status."""
        # Invokes the actual RPC
        call = self._stub.FullDuplexCall()

        for _ in range(_NUM_STREAM_RESPONSES):
            await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)
            response = await call.read()
            self.assertIsInstance(response,
                                  messages_pb2.StreamingOutputCallResponse)
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        # Cancels the RPC
        self.assertFalse(call.done())
        self.assertFalse(call.cancelled())
        self.assertTrue(call.cancel())
        self.assertTrue(call.cancelled())
        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())

    async def test_cancel_with_pending_read(self):
        """Cancelling works when a written request was never read back."""
        call = self._stub.FullDuplexCall()

        await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)

        # Cancels the RPC
        self.assertFalse(call.done())
        self.assertFalse(call.cancelled())
        self.assertTrue(call.cancel())
        self.assertTrue(call.cancelled())
        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())

    async def test_cancel_with_ongoing_read(self):
        """Cancelling works while another task is blocked in read()."""
        call = self._stub.FullDuplexCall()
        coro_started = asyncio.Event()

        async def read_coro():
            coro_started.set()
            await call.read()

        read_task = self.loop.create_task(read_coro())
        await coro_started.wait()
        self.assertFalse(read_task.done())

        # Cancels the RPC
        self.assertFalse(call.done())
        self.assertFalse(call.cancelled())
        self.assertTrue(call.cancel())
        self.assertTrue(call.cancelled())
        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())

    async def test_early_cancel(self):
        """Cancelling before any read or write yields CANCELLED status."""
        call = self._stub.FullDuplexCall()

        # Cancels the RPC
        self.assertFalse(call.done())
        self.assertFalse(call.cancelled())
        self.assertTrue(call.cancel())
        self.assertTrue(call.cancelled())
        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())

    async def test_cancel_after_done_writing(self):
        """Cancelling right after done_writing() still succeeds."""
        call = self._stub.FullDuplexCall()
        await call.done_writing()

        # Cancels the RPC
        self.assertFalse(call.done())
        self.assertFalse(call.cancelled())
        self.assertTrue(call.cancel())
        self.assertTrue(call.cancelled())
        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())

    async def test_late_cancel(self):
        """cancel() on an already finished call is rejected."""
        call = self._stub.FullDuplexCall()
        await call.done_writing()
        self.assertEqual(grpc.StatusCode.OK, await call.code())

        # Cancels the RPC
        self.assertTrue(call.done())
        self.assertFalse(call.cancelled())
        self.assertFalse(call.cancel())
        self.assertFalse(call.cancelled())

        # Status is still OK
        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_async_generator(self):
        """Requests from an async generator; responses via `async for`."""

        async def request_generator():
            yield _STREAM_OUTPUT_REQUEST_ONE_RESPONSE
            yield _STREAM_OUTPUT_REQUEST_ONE_RESPONSE

        call = self._stub.FullDuplexCall(request_generator())
        async for response in call:
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_too_many_reads(self):
        """Reading past the last response produces EOF."""

        async def request_generator():
            for _ in range(_NUM_STREAM_RESPONSES):
                yield _STREAM_OUTPUT_REQUEST_ONE_RESPONSE

        call = self._stub.FullDuplexCall(request_generator())
        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
        self.assertIs(await call.read(), aio.EOF)

        self.assertEqual(await call.code(), grpc.StatusCode.OK)
        # After the RPC finished, the read should also produce EOF
        self.assertIs(await call.read(), aio.EOF)

    async def test_read_write_after_done_writing(self):
        """After done_writing(), write() raises but read() still works."""
        call = self._stub.FullDuplexCall()

        # Writes two requests; their responses are read further below
        await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)
        await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)
        await call.done_writing()

        # Further write should fail
        with self.assertRaises(asyncio.InvalidStateError):
            await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)

        # But read should be unaffected
        response = await call.read()
        self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
        response = await call.read()
        self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_error_in_async_generator(self):
        """Cancelling the call propagates CancelledError to the generator."""
        # Server will pause between responses
        request = messages_pb2.StreamingOutputCallRequest()
        request.response_parameters.append(
            messages_pb2.ResponseParameters(
                size=_RESPONSE_PAYLOAD_SIZE,
                interval_us=_RESPONSE_INTERVAL_US,
            ))

        # We expect the request iterator to receive the exception
        request_iterator_received_the_exception = asyncio.Event()

        async def request_iterator():
            with self.assertRaises(asyncio.CancelledError):
                for _ in range(_NUM_STREAM_RESPONSES):
                    yield request
                    await asyncio.sleep(_SHORT_TIMEOUT_S)
            request_iterator_received_the_exception.set()

        call = self._stub.FullDuplexCall(request_iterator())

        # Cancel the RPC after at least one response
        async def cancel_later():
            await asyncio.sleep(_SHORT_TIMEOUT_S * 2)
            call.cancel()

        cancel_later_task = self.loop.create_task(cancel_later())

        with self.assertRaises(asyncio.CancelledError):
            async for response in call:
                self.assertEqual(_RESPONSE_PAYLOAD_SIZE,
                                 len(response.payload.body))

        await request_iterator_received_the_exception.wait()

        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
        # No failures in the cancel later task!
        await cancel_later_task

    async def test_normal_iterable_requests(self):
        """A plain (synchronous) iterator is accepted as the request stream."""
        requests = [_STREAM_OUTPUT_REQUEST_ONE_RESPONSE] * _NUM_STREAM_RESPONSES

        call = self._stub.FullDuplexCall(iter(requests))
        async for response in call:
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_empty_ping_pong(self):
        """Ping-pong where every response serializes to zero bytes."""
        call = self._stub.FullDuplexCall()
        for _ in range(_NUM_STREAM_RESPONSES):
            await call.write(_STREAM_OUTPUT_REQUEST_ONE_EMPTY_RESPONSE)
            response = await call.read()
            self.assertEqual(b'', response.SerializeToString())
        await call.done_writing()
        self.assertEqual(await call.code(), grpc.StatusCode.OK)


if __name__ == '__main__':
    logging.basicConfig(level=logging.DEBUG)
    unittest.main(verbosity=2)