# Copyright 2020 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The Python AsyncIO Benchmark Clients."""

import abc
import asyncio
import logging
import random
import time

import grpc
from grpc.experimental import aio

from src.proto.grpc.testing import benchmark_service_pb2_grpc
from src.proto.grpc.testing import control_pb2
from src.proto.grpc.testing import messages_pb2
from tests.qps import histogram
from tests.unit import resources


class GenericStub(object):
    """Untyped stub for grpc.testing.BenchmarkService.

    The multicallables are created without request serializers or response
    deserializers. Callers are therefore expected to pass pre-encoded
    payloads; `BenchmarkClient` sends `bytes` when it selects this stub.
    """

    def __init__(self, channel: aio.Channel):
        self.UnaryCall = channel.unary_unary(
            '/grpc.testing.BenchmarkService/UnaryCall')
        self.StreamingFromServer = channel.unary_stream(
            '/grpc.testing.BenchmarkService/StreamingFromServer')
        self.StreamingCall = channel.stream_stream(
            '/grpc.testing.BenchmarkService/StreamingCall')


class BenchmarkClient(abc.ABC):
    """Benchmark client interface that exposes a non-blocking send_request()."""

    def __init__(self, address: str, config: control_pb2.ClientConfig,
                 hist: histogram.Histogram):
        """Creates the channel, stub and request payload described by config.

        Args:
          address: Target address of the benchmark server.
          config: ClientConfig. Supplies the channel arguments, optional
            security params, payload config and outstanding RPCs per channel.
          hist: Histogram that receives per-query latencies in nanoseconds.
        """
        # Disables underlying reuse of subchannels
        unique_option = (('iv', random.random()),)

        # Parses the channel argument from config
        channel_args = tuple(
            (arg.name, arg.str_value) if arg.HasField('str_value') else (
                arg.name, int(arg.int_value)) for arg in config.channel_args)

        # Creates the channel
        if config.HasField('security_params'):
            channel_credentials = grpc.ssl_channel_credentials(
                resources.test_root_certificates(),)
            server_host_override_option = ((
                'grpc.ssl_target_name_override',
                config.security_params.server_host_override,
            ),)
            self._channel = aio.secure_channel(
                address, channel_credentials,
                unique_option + channel_args + server_host_override_option)
        else:
            self._channel = aio.insecure_channel(address,
                                                 options=unique_option +
                                                 channel_args)

        # Creates the stub
        if config.payload_config.WhichOneof('payload') == 'simple_params':
            self._generic = False
            self._stub = benchmark_service_pb2_grpc.BenchmarkServiceStub(
                self._channel)
            payload = messages_pb2.Payload(
                body=b'\0' * config.payload_config.simple_params.req_size)
            self._request = messages_pb2.SimpleRequest(
                payload=payload,
                response_size=config.payload_config.simple_params.resp_size)
        else:
            # Any other payload oneof is treated as raw bytes and sent through
            # the untyped GenericStub.
            self._generic = True
            self._stub = GenericStub(self._channel)
            self._request = b'\0' * config.payload_config.bytebuf_params.req_size

        self._hist = hist
        # Not read by any client in this module.
        self._response_callbacks = []
        self._concurrency = config.outstanding_rpcs_per_channel

    async def run(self) -> None:
        """Waits until the channel is ready."""
        await self._channel.channel_ready()

    async def stop(self) -> None:
        """Closes the underlying channel."""
        await self._channel.close()

    def _record_query_time(self, query_time: float) -> None:
        """Records a latency given in seconds into the histogram, in ns."""
        self._hist.add(query_time * 1e9)


class UnaryAsyncBenchmarkClient(BenchmarkClient):
    """Keeps `outstanding_rpcs_per_channel` unary calls in flight until stopped."""

    def __init__(self, address: str, config: control_pb2.ClientConfig,
                 hist: histogram.Histogram):
        super().__init__(address, config, hist)
        # Cleared by stop() and never set again. It starts as True, rather
        # than being set inside run() after the channel is ready, so that a
        # stop() issued while run() is still connecting is not overwritten.
        self._running = True
        # Set by run() once it has finished, successfully or not.
        self._stopped = asyncio.Event()

    async def _send_request(self):
        """Issues one unary call and records its latency."""
        start_time = time.monotonic()
        await self._stub.UnaryCall(self._request)
        self._record_query_time(time.monotonic() - start_time)

    async def _send_indefinitely(self) -> None:
        """Sends unary calls back-to-back until stop() clears the flag."""
        while self._running:
            await self._send_request()

    async def run(self) -> None:
        """Runs `_concurrency` concurrent senders until stopped."""
        try:
            await super().run()
            senders = (
                self._send_indefinitely() for _ in range(self._concurrency))
            await asyncio.gather(*senders)
        finally:
            # Always release stop(). If the channel never became ready or a
            # sender raised, stop() would otherwise wait forever.
            self._stopped.set()

    async def stop(self) -> None:
        """Signals the senders to finish, waits for run(), closes the channel."""
        self._running = False
        await self._stopped.wait()
        await super().stop()


class StreamingAsyncBenchmarkClient(BenchmarkClient):
    """Runs ping-pong bidirectional streams and records each round trip."""

    def __init__(self, address: str, config: control_pb2.ClientConfig,
                 hist: histogram.Histogram):
        super().__init__(address, config, hist)
        # Cleared by stop() and never set again. It starts as True so that a
        # stop() issued before or while run() connects is not lost.
        self._running = True
        # Set by run() once it has finished, successfully or not.
        self._stopped = asyncio.Event()

    async def _one_streaming_call(self):
        """Writes one request and reads one response per iteration."""
        call = self._stub.StreamingCall()
        while self._running:
            # Use the monotonic clock, as the unary client does, so that
            # wall-clock adjustments cannot corrupt the recorded latencies.
            start_time = time.monotonic()
            await call.write(self._request)
            response = await call.read()
            if response is aio.EOF:
                # The server ended the stream, so there is no round trip to
                # record. Further reads would return EOF immediately.
                break
            self._record_query_time(time.monotonic() - start_time)
        await call.done_writing()

    async def run(self) -> None:
        """Runs `_concurrency` concurrent streams until stopped."""
        try:
            await super().run()
            senders = (
                self._one_streaming_call() for _ in range(self._concurrency))
            await asyncio.gather(*senders)
        finally:
            # Always release stop(), even if a stream failed.
            self._stopped.set()

    async def stop(self) -> None:
        """Signals the streams to finish, waits for run(), closes the channel."""
        self._running = False
        await self._stopped.wait()
        await super().stop()


class ServerStreamingAsyncBenchmarkClient(BenchmarkClient):
    """Reads server-streamed responses and records inter-message latency."""

    def __init__(self, address: str, config: control_pb2.ClientConfig,
                 hist: histogram.Histogram):
        super().__init__(address, config, hist)
        # Cleared by stop() and never set again. It starts as True so that a
        # stop() issued before or while run() connects is not lost.
        self._running = True
        # Set by run() once it has finished, successfully or not.
        self._stopped = asyncio.Event()

    async def _one_server_streaming_call(self):
        """Sends one request, then times each read from the response stream."""
        call = self._stub.StreamingFromServer(self._request)
        while self._running:
            start_time = time.monotonic()
            response = await call.read()
            if response is aio.EOF:
                # The server finished the stream. Without this check the loop
                # would spin on immediate EOF reads and record bogus samples.
                break
            self._record_query_time(time.monotonic() - start_time)

    async def run(self) -> None:
        """Runs `_concurrency` concurrent server streams until stopped."""
        try:
            await super().run()
            senders = (self._one_server_streaming_call()
                       for _ in range(self._concurrency))
            await asyncio.gather(*senders)
        finally:
            # Always release stop(), even if a stream failed.
            self._stopped.set()

    async def stop(self) -> None:
        """Signals the readers to finish, waits for run(), closes the channel."""
        self._running = False
        await self._stopped.wait()
        await super().stop()