1# Copyright 2016 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from concurrent import futures
16import multiprocessing
17import random
18import threading
19import time
20
21import grpc
22
23from src.proto.grpc.testing import benchmark_service_pb2_grpc
24from src.proto.grpc.testing import control_pb2
25from src.proto.grpc.testing import stats_pb2
26from src.proto.grpc.testing import worker_service_pb2_grpc
27from tests.qps import benchmark_client
28from tests.qps import benchmark_server
29from tests.qps import client_runner
30from tests.qps import histogram
31from tests.unit import resources
32from tests.unit import test_common
33
34
class WorkerServer(worker_service_pb2_grpc.WorkerServiceServicer):
    """Python Worker Server implementation.

    Serves the WorkerService RPCs used by the QPS benchmark driver:
    ``RunServer`` / ``RunClient`` start a benchmark server or a set of
    benchmark clients and stream status messages back, ``CoreCount`` reports
    the local CPU count, and ``QuitWorker`` releases ``wait_for_quit``.
    """

    def __init__(self, server_port=None):
        """Create the worker.

        Args:
          server_port: Port used for benchmark servers whose ServerConfig has
            ``port == 0``. When ``None``, the config's port is used as-is.
        """
        self._quit_event = threading.Event()
        self._server_port = server_port

    def RunServer(self, request_iterator, context):
        """Run a benchmark server and stream its status to the driver.

        The first request's ``setup`` field is the ServerConfig. Every later
        request produces one ServerStatus; if its ``mark.reset`` is set, the
        elapsed-time window restarts at the moment that status was taken.

        The benchmark server is stopped when the request stream ends, and
        also if this generator is closed early or an exception propagates.

        Args:
          request_iterator: Iterator of ServerArgs messages from the driver.
          context: gRPC servicer context (unused).

        Yields:
          control_pb2.ServerStatus messages.
        """
        config = next(request_iterator).setup  #pylint: disable=stop-iteration-return
        server, port = self._create_server(config)
        cores = multiprocessing.cpu_count()
        server.start()
        try:
            start_time = time.time()
            # Initial status covers an empty window.
            yield self._get_server_status(start_time, start_time, port, cores)

            for request in request_iterator:
                end_time = time.time()
                status = self._get_server_status(start_time, end_time, port,
                                                 cores)
                if request.mark.reset:
                    # The next window starts exactly where this one ended.
                    start_time = end_time
                yield status
        finally:
            # Runs on normal completion, on generator close and on errors,
            # so the benchmark server is never left running.
            server.stop(None)

    def _get_server_status(self, start_time, end_time, port, cores):
        """Build a ServerStatus for the window [start_time, end_time].

        User and system CPU time are not measured separately here; both are
        reported as the wall-clock elapsed time.

        Args:
          start_time: Window start, in ``time.time()`` seconds.
          end_time: Window end, in ``time.time()`` seconds.
          port: Port the benchmark server is bound to.
          cores: CPU count to report.

        Returns:
          control_pb2.ServerStatus.
        """
        elapsed_time = end_time - start_time
        stats = stats_pb2.ServerStats(time_elapsed=elapsed_time,
                                      time_user=elapsed_time,
                                      time_system=elapsed_time)
        return control_pb2.ServerStatus(stats=stats, port=port, cores=cores)

    def _create_server(self, config):
        """Create (but do not start) a benchmark server from a ServerConfig.

        Args:
          config: control_pb2.ServerConfig from the driver.

        Returns:
          A ``(server, port)`` tuple, where ``port`` is the value returned by
          ``add_secure_port`` / ``add_insecure_port``.

        Raises:
          Exception: if ``config.server_type`` is not ASYNC_SERVER or
            ASYNC_GENERIC_SERVER.
        """
        if config.async_server_threads == 0:
            # This is the default concurrent.futures thread pool size, but
            # None doesn't seem to work
            server_threads = multiprocessing.cpu_count() * 5
        else:
            server_threads = config.async_server_threads
        server = test_common.test_server(max_workers=server_threads)
        if config.server_type == control_pb2.ASYNC_SERVER:
            servicer = benchmark_server.BenchmarkServer()
            benchmark_service_pb2_grpc.add_BenchmarkServiceServicer_to_server(
                servicer, server)
        elif config.server_type == control_pb2.ASYNC_GENERIC_SERVER:
            resp_size = config.payload_config.bytebuf_params.resp_size
            servicer = benchmark_server.GenericBenchmarkServer(resp_size)
            method_implementations = {
                'StreamingCall':
                    grpc.stream_stream_rpc_method_handler(servicer.StreamingCall
                                                         ),
                'UnaryCall':
                    grpc.unary_unary_rpc_method_handler(servicer.UnaryCall),
            }
            handler = grpc.method_handlers_generic_handler(
                'grpc.testing.BenchmarkService', method_implementations)
            server.add_generic_rpc_handlers((handler,))
        else:
            raise Exception('Unsupported server type {}'.format(
                config.server_type))

        # A worker-level port overrides the config only when the config
        # leaves the port unspecified (0).
        if self._server_port is not None and config.port == 0:
            server_port = self._server_port
        else:
            server_port = config.port

        if config.HasField('security_params'):  # Use SSL
            server_creds = grpc.ssl_server_credentials(
                ((resources.private_key(), resources.certificate_chain()),))
            port = server.add_secure_port('[::]:{}'.format(server_port),
                                          server_creds)
        else:
            port = server.add_insecure_port('[::]:{}'.format(server_port))

        return (server, port)

    def RunClient(self, request_iterator, context):
        """Run benchmark clients and stream their status to the driver.

        The first request's ``setup`` field is the ClientConfig. One client
        runner is created and started per channel, with channels assigned to
        ``server_targets`` round-robin. Every later request produces one
        ClientStatus; if its ``mark.reset`` is set, the latency histogram and
        the elapsed-time window are reset.

        All started runners are stopped when the request stream ends, and
        also if this generator is closed early or an exception propagates.

        Args:
          request_iterator: Iterator of ClientArgs messages from the driver.
          context: gRPC servicer context (unused).

        Yields:
          control_pb2.ClientStatus messages.
        """
        config = next(request_iterator).setup  #pylint: disable=stop-iteration-return
        client_runners = []
        qps_data = histogram.Histogram(config.histogram_params.resolution,
                                       config.histogram_params.max_possible)
        start_time = time.time()
        try:
            # Create a client for each channel
            for i in range(config.client_channels):
                server = config.server_targets[i % len(config.server_targets)]
                runner = self._create_client_runner(server, config, qps_data)
                runner.start()
                # Track only runners that started, so cleanup stops exactly
                # what was started.
                client_runners.append(runner)

            end_time = time.time()
            yield self._get_client_status(start_time, end_time, qps_data)

            # Respond to stat requests
            for request in request_iterator:
                end_time = time.time()
                status = self._get_client_status(start_time, end_time,
                                                 qps_data)
                if request.mark.reset:
                    qps_data.reset()
                    start_time = time.time()
                yield status
        finally:
            # Cleanup the clients, even when the stream ends abnormally or a
            # later runner failed to be created.
            for runner in client_runners:
                runner.stop()

    def _get_client_status(self, start_time, end_time, qps_data):
        """Build a ClientStatus for the window [start_time, end_time].

        User and system CPU time are not measured separately here; both are
        reported as the wall-clock elapsed time.

        Args:
          start_time: Window start, in ``time.time()`` seconds.
          end_time: Window end, in ``time.time()`` seconds.
          qps_data: Histogram whose ``get_data()`` supplies the latencies.

        Returns:
          control_pb2.ClientStatus.
        """
        latencies = qps_data.get_data()
        elapsed_time = end_time - start_time
        stats = stats_pb2.ClientStats(latencies=latencies,
                                      time_elapsed=elapsed_time,
                                      time_user=elapsed_time,
                                      time_system=elapsed_time)
        return control_pb2.ClientStatus(stats=stats)

    def _create_client_runner(self, server, config, qps_data):
        """Create (but do not start) a client runner for one channel.

        Args:
          server: Target address string for this channel.
          config: control_pb2.ClientConfig from the driver.
          qps_data: Histogram shared by all runners to record latencies.

        Returns:
          A ClosedLoopClientRunner for closed-loop load, otherwise an
          OpenLoopClientRunner driven by exponentially distributed
          inter-arrival times.

        Raises:
          Exception: for unsupported client type / rpc type combinations.
        """
        no_ping_pong = False
        if config.client_type == control_pb2.SYNC_CLIENT:
            if config.rpc_type == control_pb2.UNARY:
                client = benchmark_client.UnarySyncBenchmarkClient(
                    server, config, qps_data)
            elif config.rpc_type == control_pb2.STREAMING:
                client = benchmark_client.StreamingSyncBenchmarkClient(
                    server, config, qps_data)
            elif config.rpc_type == control_pb2.STREAMING_FROM_SERVER:
                no_ping_pong = True
                client = benchmark_client.ServerStreamingSyncBenchmarkClient(
                    server, config, qps_data)
            else:
                # Previously fell through with `client` unbound and failed
                # with an UnboundLocalError; report the real problem instead.
                raise Exception('Unsupported rpc type {} for sync client'.format(
                    config.rpc_type))
        elif config.client_type == control_pb2.ASYNC_CLIENT:
            if config.rpc_type == control_pb2.UNARY:
                client = benchmark_client.UnaryAsyncBenchmarkClient(
                    server, config, qps_data)
            else:
                raise Exception('Async streaming client not supported')
        else:
            raise Exception('Unsupported client type {}'.format(
                config.client_type))

        # In multi-channel tests, we split the load across all channels
        load_factor = float(config.client_channels)
        if config.load_params.WhichOneof('load') == 'closed_loop':
            runner = client_runner.ClosedLoopClientRunner(
                client, config.outstanding_rpcs_per_channel, no_ping_pong)
        else:  # Open loop Poisson
            # Per-channel rate: the offered load divided evenly by channel.
            alpha = config.load_params.poisson.offered_load / load_factor

            def poisson():
                # Endless stream of exponential inter-arrival delays.
                while True:
                    yield random.expovariate(alpha)

            runner = client_runner.OpenLoopClientRunner(client, poisson())

        return runner

    def CoreCount(self, request, context):
        """Report this machine's CPU count to the driver."""
        return control_pb2.CoreResponse(cores=multiprocessing.cpu_count())

    def QuitWorker(self, request, context):
        """Signal the quit event so ``wait_for_quit`` returns."""
        self._quit_event.set()
        return control_pb2.Void()

    def wait_for_quit(self):
        """Block until ``QuitWorker`` has been called."""
        self._quit_event.wait()
198