1# Copyright 2020 The gRPC Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""The Python AsyncIO Benchmark Clients."""
15
16import abc
17import asyncio
18import logging
19import random
20import time
21
22import grpc
23from grpc.experimental import aio
24
25from src.proto.grpc.testing import benchmark_service_pb2_grpc
26from src.proto.grpc.testing import control_pb2
27from src.proto.grpc.testing import messages_pb2
28from tests.qps import histogram
29from tests.unit import resources
30
31
class GenericStub:
    """BenchmarkService stub whose multi-callables carry raw bytes.

    Unlike the generated ``BenchmarkServiceStub``, no request serializer or
    response deserializer is registered here. Payloads therefore pass through
    unconverted, so callers send and receive ``bytes``.

    Args:
      channel: The AsyncIO channel that the multi-callables are bound to.
    """

    def __init__(self, channel: aio.Channel):
        make_unary_unary = channel.unary_unary
        make_unary_stream = channel.unary_stream
        make_stream_stream = channel.stream_stream

        # One multi-callable per BenchmarkService method, keyed by its full
        # method path.
        self.UnaryCall = make_unary_unary(
            '/grpc.testing.BenchmarkService/UnaryCall')
        self.StreamingFromServer = make_unary_stream(
            '/grpc.testing.BenchmarkService/StreamingFromServer')
        self.StreamingCall = make_stream_stream(
            '/grpc.testing.BenchmarkService/StreamingCall')
41
42
class BenchmarkClient(abc.ABC):
    """Base class for the AsyncIO benchmark clients.

    Builds the channel and the stub/request pair described by a
    ``ClientConfig``, and records per-query latencies into a histogram.
    Subclasses drive the RPCs by overriding ``run()`` and ``stop()``.
    """

    def __init__(self, address: str, config: control_pb2.ClientConfig,
                 hist: histogram.Histogram):
        """Creates the channel, stub and request payload.

        Args:
          address: Target address of the benchmark server.
          config: Client configuration. Its channel args, security params,
            payload config and outstanding RPCs per channel are read here.
          hist: Histogram that receives query latencies in nanoseconds.
        """
        # A random-valued option makes this channel's arguments unique, which
        # disables underlying reuse of subchannels.
        options = (('iv', random.random()),)

        # Append the channel arguments carried by the config, in order. Each
        # one is either a string-valued or an integer-valued argument.
        options += tuple(
            (arg.name,
             arg.str_value if arg.HasField('str_value') else int(arg.int_value))
            for arg in config.channel_args)

        if not config.HasField('security_params'):
            self._channel = aio.insecure_channel(address, options=options)
        else:
            credentials = grpc.ssl_channel_credentials(
                resources.test_root_certificates())
            # The test certificates are checked against the configured
            # override host name rather than against the target address.
            options += ((
                'grpc.ssl_target_name_override',
                config.security_params.server_host_override,
            ),)
            self._channel = aio.secure_channel(address, credentials, options)

        payload_config = config.payload_config
        # "simple_params" means protobuf messages; any other payload kind is
        # sent as raw bytes through GenericStub.
        self._generic = payload_config.WhichOneof('payload') != 'simple_params'
        if self._generic:
            self._stub = GenericStub(self._channel)
            self._request = b'\0' * payload_config.bytebuf_params.req_size
        else:
            simple = payload_config.simple_params
            self._stub = benchmark_service_pb2_grpc.BenchmarkServiceStub(
                self._channel)
            self._request = messages_pb2.SimpleRequest(
                payload=messages_pb2.Payload(body=b'\0' * simple.req_size),
                response_size=simple.resp_size)

        self._hist = hist
        # NOTE(review): not read by any client visible in this file; kept in
        # case other code relies on the attribute.
        self._response_callbacks = []
        # Number of concurrent senders/streams each subclass starts.
        self._concurrency = config.outstanding_rpcs_per_channel

    async def run(self) -> None:
        """Waits until the channel is ready to carry RPCs."""
        await self._channel.channel_ready()

    async def stop(self) -> None:
        """Closes the channel."""
        await self._channel.close()

    def _record_query_time(self, query_time: float) -> None:
        """Adds one latency sample, given in seconds, as nanoseconds."""
        self._hist.add(query_time * 1e9)
99
100
class UnaryAsyncBenchmarkClient(BenchmarkClient):
    """Benchmark client that issues back-to-back unary RPCs.

    ``outstanding_rpcs_per_channel`` coroutines each send one ``UnaryCall``
    at a time, in a loop, until ``stop()`` is called.
    """

    def __init__(self, address: str, config: control_pb2.ClientConfig,
                 hist: histogram.Histogram):
        super().__init__(address, config, hist)
        # None before run(); True while sending; False once stopping.
        self._running = None
        # Set when run() returns (normally or by raising).
        self._stopped = asyncio.Event()

    async def _send_request(self):
        """Sends one unary RPC and records its latency."""
        start_time = time.monotonic()
        await self._stub.UnaryCall(self._request)
        self._record_query_time(time.monotonic() - start_time)

    async def _send_indefinitely(self) -> None:
        """Sends unary RPCs one after another while the client is running."""
        while self._running:
            await self._send_request()

    async def run(self) -> None:
        """Waits for the channel, then runs concurrent senders until stopped.

        ``_stopped`` is set in a ``finally`` clause. As a result, a failing
        sender or a failed ``channel_ready()`` cannot leave ``stop()``
        waiting forever.
        """
        try:
            await super().run()
            self._running = True
            senders = (
                self._send_indefinitely() for _ in range(self._concurrency))
            await asyncio.gather(*senders)
        finally:
            # If gather raised, the other senders are still looping; clearing
            # the flag lets them exit after their in-flight RPC.
            self._running = False
            self._stopped.set()

    async def stop(self) -> None:
        """Signals the senders to finish, waits for run(), closes the channel.

        Only ``run()`` sets ``_stopped``. Awaiting this before ``run()`` has
        been started therefore blocks until ``run()`` returns.
        """
        self._running = False
        await self._stopped.wait()
        await super().stop()
129
130
class StreamingAsyncBenchmarkClient(BenchmarkClient):
    """Benchmark client doing ping-pong over bidirectional streams.

    ``outstanding_rpcs_per_channel`` streams are opened. Each one repeatedly
    writes a request and reads one response, recording the round-trip time.
    """

    def __init__(self, address: str, config: control_pb2.ClientConfig,
                 hist: histogram.Histogram):
        super().__init__(address, config, hist)
        # None before run(); True while streaming; False once stopping.
        self._running = None
        # Set when run() returns (normally or by raising).
        self._stopped = asyncio.Event()

    async def _one_streaming_call(self):
        """Runs one bidi stream of write/read round trips while running."""
        call = self._stub.StreamingCall()
        while self._running:
            # Monotonic clock: wall-clock adjustments must not skew latency,
            # and it matches the unary client.
            start_time = time.monotonic()
            await call.write(self._request)
            response = await call.read()
            if response is aio.EOF:
                # The server ended the stream. There is no latency to record,
                # and further writes on the finished call would fail.
                logging.warning(
                    'StreamingCall ended by the server before the client '
                    'stopped')
                break
            self._record_query_time(time.monotonic() - start_time)
        await call.done_writing()

    async def run(self):
        """Waits for the channel, then runs concurrent streams until stopped.

        ``_stopped`` is set in a ``finally`` clause. As a result, a failing
        stream or a failed ``channel_ready()`` cannot leave ``stop()``
        waiting forever.
        """
        try:
            await super().run()
            self._running = True
            senders = (
                self._one_streaming_call() for _ in range(self._concurrency))
            await asyncio.gather(*senders)
        finally:
            # If gather raised, the other streams are still looping; clearing
            # the flag lets them wind down.
            self._running = False
            self._stopped.set()

    async def stop(self):
        """Signals the streams to finish, waits for run(), closes the channel.

        Only ``run()`` sets ``_stopped``. Awaiting this before ``run()`` has
        been started therefore blocks until ``run()`` returns.
        """
        self._running = False
        await self._stopped.wait()
        await super().stop()
159
160
class ServerStreamingAsyncBenchmarkClient(BenchmarkClient):
    """Benchmark client that measures inter-message time on server streams.

    ``outstanding_rpcs_per_channel`` ``StreamingFromServer`` calls are opened
    with the same request. The time taken by each ``read()`` is recorded.
    """

    def __init__(self, address: str, config: control_pb2.ClientConfig,
                 hist: histogram.Histogram):
        super().__init__(address, config, hist)
        # None before run(); True while reading; False once stopping.
        self._running = None
        # Set when run() returns (normally or by raising).
        self._stopped = asyncio.Event()

    async def _one_server_streaming_call(self):
        """Reads from one server stream while running, then cancels it."""
        call = self._stub.StreamingFromServer(self._request)
        try:
            while self._running:
                # Monotonic clock: wall-clock adjustments must not skew
                # latency, and it matches the unary client.
                start_time = time.monotonic()
                response = await call.read()
                if response is aio.EOF:
                    # The server ended the stream. Reading a finished call
                    # keeps returning EOF, possibly without yielding, which
                    # would spin the event loop and record bogus samples.
                    logging.warning(
                        'StreamingFromServer ended by the server before the '
                        'client stopped')
                    break
                self._record_query_time(time.monotonic() - start_time)
        finally:
            # Without this, the RPC stays open until the channel is closed.
            # cancel() is a no-op on an already finished call.
            call.cancel()

    async def run(self):
        """Waits for the channel, then runs concurrent streams until stopped.

        ``_stopped`` is set in a ``finally`` clause. As a result, a failing
        stream or a failed ``channel_ready()`` cannot leave ``stop()``
        waiting forever.
        """
        try:
            await super().run()
            self._running = True
            senders = (self._one_server_streaming_call()
                       for _ in range(self._concurrency))
            await asyncio.gather(*senders)
        finally:
            # If gather raised, the other streams are still looping; clearing
            # the flag lets them wind down.
            self._running = False
            self._stopped.set()

    async def stop(self):
        """Signals the streams to finish, waits for run(), closes the channel.

        Only ``run()`` sets ``_stopped``. Awaiting this before ``run()`` has
        been started therefore blocks until ``run()`` returns.
        """
        self._running = False
        await self._stopped.wait()
        await super().stop()
188