1/*
2 *
3 * Copyright 2016 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19// client starts an interop client to do stress test and a metrics server to report qps.
20package main
21
22import (
23	"context"
24	"flag"
25	"fmt"
26	"math/rand"
27	"net"
28	"strconv"
29	"strings"
30	"sync"
31	"time"
32
33	"google.golang.org/grpc"
34	"google.golang.org/grpc/codes"
35	"google.golang.org/grpc/credentials"
36	"google.golang.org/grpc/grpclog"
37	"google.golang.org/grpc/interop"
38	"google.golang.org/grpc/status"
39	"google.golang.org/grpc/testdata"
40
41	testgrpc "google.golang.org/grpc/interop/grpc_testing"
42	metricspb "google.golang.org/grpc/stress/grpc_testing"
43)
44
45var (
46	serverAddresses      = flag.String("server_addresses", "localhost:8080", "a list of server addresses")
47	testCases            = flag.String("test_cases", "", "a list of test cases along with the relative weights")
48	testDurationSecs     = flag.Int("test_duration_secs", -1, "test duration in seconds")
49	numChannelsPerServer = flag.Int("num_channels_per_server", 1, "Number of channels (i.e connections) to each server")
50	numStubsPerChannel   = flag.Int("num_stubs_per_channel", 1, "Number of client stubs per each connection to server")
51	metricsPort          = flag.Int("metrics_port", 8081, "The port at which the stress client exposes QPS metrics")
52	useTLS               = flag.Bool("use_tls", false, "Connection uses TLS if true, else plain TCP")
53	testCA               = flag.Bool("use_test_ca", false, "Whether to replace platform root CAs with test CA as the CA root")
54	tlsServerName        = flag.String("server_host_override", "foo.test.google.fr", "The server name use to verify the hostname returned by TLS handshake if it is not empty. Otherwise, --server_host is used.")
55	caFile               = flag.String("ca_file", "", "The file containing the CA root cert file")
56
57	logger = grpclog.Component("stress")
58)
59
// testCaseWithWeight pairs an interop test case name with its relative
// selection weight.
type testCaseWithWeight struct {
	name   string
	weight int
}

// parseTestCases converts a comma-separated list of "name:weight" entries
// (for example "empty_unary:20,large_unary:10") into a slice of
// testCaseWithWeight, preserving input order.
//
// Whitespace around names and weights is ignored. It panics if an entry is
// not of the form name:weight, names an unsupported test case, or has a
// weight that is not a non-negative integer.
func parseTestCases(testCaseString string) []testCaseWithWeight {
	entries := strings.Split(testCaseString, ",")
	parsed := make([]testCaseWithWeight, len(entries))
	for i, str := range entries {
		testCase := strings.Split(str, ":")
		if len(testCase) != 2 {
			panic(fmt.Sprintf("invalid test case with weight: %s", str))
		}
		name := strings.TrimSpace(testCase[0])
		// Check if test case is supported.
		switch name {
		case
			"empty_unary",
			"large_unary",
			"client_streaming",
			"server_streaming",
			"ping_pong",
			"empty_stream",
			"timeout_on_sleeping_server",
			"cancel_after_begin",
			"cancel_after_first_response",
			"status_code_and_message",
			"custom_metadata":
		default:
			panic(fmt.Sprintf("unknown test type: %s", name))
		}
		w, err := strconv.Atoi(strings.TrimSpace(testCase[1]))
		if err != nil {
			panic(fmt.Sprintf("invalid weight in test case %q: %v", str, err))
		}
		// A negative weight would corrupt the cumulative ranges used by
		// weightedRandomTestSelector, so reject it up front.
		if w < 0 {
			panic(fmt.Sprintf("negative weight in test case %q", str))
		}
		parsed[i] = testCaseWithWeight{name: name, weight: w}
	}
	return parsed
}

// weightedRandomTestSelector picks test case names at random, each with
// probability proportional to its weight.
type weightedRandomTestSelector struct {
	tests       []testCaseWithWeight
	totalWeight int
}

// newWeightedRandomTestSelector constructs a weightedRandomTestSelector over
// the given tests. A test with weight 0 is never selected.
func newWeightedRandomTestSelector(tests []testCaseWithWeight) *weightedRandomTestSelector {
	var totalWeight int
	for _, t := range tests {
		totalWeight += t.weight
	}
	// Seeding the global source is only needed on Go releases before 1.20,
	// which otherwise produce the same sequence on every run.
	rand.Seed(time.Now().UnixNano())
	return &weightedRandomTestSelector{tests, totalWeight}
}

// getNextTest returns the name of a randomly selected test case. It reads
// only the selector's fields and uses the global math/rand source, which is
// safe for concurrent use. It panics if the total weight is not positive.
func (selector weightedRandomTestSelector) getNextTest() string {
	if selector.totalWeight <= 0 {
		panic("weightedRandomTestSelector: total weight of test cases must be positive")
	}
	random := rand.Intn(selector.totalWeight)
	var weightSofar int
	for _, test := range selector.tests {
		weightSofar += test.weight
		if random < weightSofar {
			return test.name
		}
	}
	panic("no test case selected by weightedRandomTestSelector")
}
129
// gauge holds the latest QPS value reported by one interop client (one
// stub). Reads and writes are serialized by mutex.
type gauge struct {
	mutex sync.RWMutex // guards val
	val   int64
}

// set records v as the gauge's current value.
func (g *gauge) set(v int64) {
	g.mutex.Lock()
	g.val = v
	g.mutex.Unlock()
}

// get returns the gauge's current value.
func (g *gauge) get() int64 {
	g.mutex.RLock()
	current := g.val
	g.mutex.RUnlock()
	return current
}
147
148// server implements metrics server functions.
149type server struct {
150	metricspb.UnimplementedMetricsServiceServer
151	mutex sync.RWMutex
152	// gauges is a map from /stress_test/server_<n>/channel_<n>/stub_<n>/qps to its qps gauge.
153	gauges map[string]*gauge
154}
155
156// newMetricsServer returns a new metrics server.
157func newMetricsServer() *server {
158	return &server{gauges: make(map[string]*gauge)}
159}
160
161// GetAllGauges returns all gauges.
162func (s *server) GetAllGauges(in *metricspb.EmptyMessage, stream metricspb.MetricsService_GetAllGaugesServer) error {
163	s.mutex.RLock()
164	defer s.mutex.RUnlock()
165
166	for name, gauge := range s.gauges {
167		if err := stream.Send(&metricspb.GaugeResponse{Name: name, Value: &metricspb.GaugeResponse_LongValue{LongValue: gauge.get()}}); err != nil {
168			return err
169		}
170	}
171	return nil
172}
173
174// GetGauge returns the gauge for the given name.
175func (s *server) GetGauge(ctx context.Context, in *metricspb.GaugeRequest) (*metricspb.GaugeResponse, error) {
176	s.mutex.RLock()
177	defer s.mutex.RUnlock()
178
179	if g, ok := s.gauges[in.Name]; ok {
180		return &metricspb.GaugeResponse{Name: in.Name, Value: &metricspb.GaugeResponse_LongValue{LongValue: g.get()}}, nil
181	}
182	return nil, status.Errorf(codes.InvalidArgument, "gauge with name %s not found", in.Name)
183}
184
185// createGauge creates a gauge using the given name in metrics server.
186func (s *server) createGauge(name string) *gauge {
187	s.mutex.Lock()
188	defer s.mutex.Unlock()
189
190	if _, ok := s.gauges[name]; ok {
191		// gauge already exists.
192		panic(fmt.Sprintf("gauge %s already exists", name))
193	}
194	var g gauge
195	s.gauges[name] = &g
196	return &g
197}
198
199func startServer(server *server, port int) {
200	lis, err := net.Listen("tcp", ":"+strconv.Itoa(port))
201	if err != nil {
202		logger.Fatalf("failed to listen: %v", err)
203	}
204
205	s := grpc.NewServer()
206	metricspb.RegisterMetricsServiceServer(s, server)
207	s.Serve(lis)
208
209}
210
211// performRPCs uses weightedRandomTestSelector to select test case and runs the tests.
212func performRPCs(gauge *gauge, conn *grpc.ClientConn, selector *weightedRandomTestSelector, stop <-chan bool) {
213	client := testgrpc.NewTestServiceClient(conn)
214	var numCalls int64
215	startTime := time.Now()
216	for {
217		test := selector.getNextTest()
218		switch test {
219		case "empty_unary":
220			interop.DoEmptyUnaryCall(client, grpc.WaitForReady(true))
221		case "large_unary":
222			interop.DoLargeUnaryCall(client, grpc.WaitForReady(true))
223		case "client_streaming":
224			interop.DoClientStreaming(client, grpc.WaitForReady(true))
225		case "server_streaming":
226			interop.DoServerStreaming(client, grpc.WaitForReady(true))
227		case "ping_pong":
228			interop.DoPingPong(client, grpc.WaitForReady(true))
229		case "empty_stream":
230			interop.DoEmptyStream(client, grpc.WaitForReady(true))
231		case "timeout_on_sleeping_server":
232			interop.DoTimeoutOnSleepingServer(client, grpc.WaitForReady(true))
233		case "cancel_after_begin":
234			interop.DoCancelAfterBegin(client, grpc.WaitForReady(true))
235		case "cancel_after_first_response":
236			interop.DoCancelAfterFirstResponse(client, grpc.WaitForReady(true))
237		case "status_code_and_message":
238			interop.DoStatusCodeAndMessage(client, grpc.WaitForReady(true))
239		case "custom_metadata":
240			interop.DoCustomMetadata(client, grpc.WaitForReady(true))
241		}
242		numCalls++
243		gauge.set(int64(float64(numCalls) / time.Since(startTime).Seconds()))
244
245		select {
246		case <-stop:
247			return
248		default:
249		}
250	}
251}
252
253func logParameterInfo(addresses []string, tests []testCaseWithWeight) {
254	logger.Infof("server_addresses: %s", *serverAddresses)
255	logger.Infof("test_cases: %s", *testCases)
256	logger.Infof("test_duration_secs: %d", *testDurationSecs)
257	logger.Infof("num_channels_per_server: %d", *numChannelsPerServer)
258	logger.Infof("num_stubs_per_channel: %d", *numStubsPerChannel)
259	logger.Infof("metrics_port: %d", *metricsPort)
260	logger.Infof("use_tls: %t", *useTLS)
261	logger.Infof("use_test_ca: %t", *testCA)
262	logger.Infof("server_host_override: %s", *tlsServerName)
263
264	logger.Infoln("addresses:")
265	for i, addr := range addresses {
266		logger.Infof("%d. %s\n", i+1, addr)
267	}
268	logger.Infoln("tests:")
269	for i, test := range tests {
270		logger.Infof("%d. %v\n", i+1, test)
271	}
272}
273
274func newConn(address string, useTLS, testCA bool, tlsServerName string) (*grpc.ClientConn, error) {
275	var opts []grpc.DialOption
276	if useTLS {
277		var sn string
278		if tlsServerName != "" {
279			sn = tlsServerName
280		}
281		var creds credentials.TransportCredentials
282		if testCA {
283			var err error
284			if *caFile == "" {
285				*caFile = testdata.Path("x509/server_ca_cert.pem")
286			}
287			creds, err = credentials.NewClientTLSFromFile(*caFile, sn)
288			if err != nil {
289				logger.Fatalf("Failed to create TLS credentials %v", err)
290			}
291		} else {
292			creds = credentials.NewClientTLSFromCert(nil, sn)
293		}
294		opts = append(opts, grpc.WithTransportCredentials(creds))
295	} else {
296		opts = append(opts, grpc.WithInsecure())
297	}
298	return grpc.Dial(address, opts...)
299}
300
301func main() {
302	flag.Parse()
303	addresses := strings.Split(*serverAddresses, ",")
304	tests := parseTestCases(*testCases)
305	logParameterInfo(addresses, tests)
306	testSelector := newWeightedRandomTestSelector(tests)
307	metricsServer := newMetricsServer()
308
309	var wg sync.WaitGroup
310	wg.Add(len(addresses) * *numChannelsPerServer * *numStubsPerChannel)
311	stop := make(chan bool)
312
313	for serverIndex, address := range addresses {
314		for connIndex := 0; connIndex < *numChannelsPerServer; connIndex++ {
315			conn, err := newConn(address, *useTLS, *testCA, *tlsServerName)
316			if err != nil {
317				logger.Fatalf("Fail to dial: %v", err)
318			}
319			defer conn.Close()
320			for clientIndex := 0; clientIndex < *numStubsPerChannel; clientIndex++ {
321				name := fmt.Sprintf("/stress_test/server_%d/channel_%d/stub_%d/qps", serverIndex+1, connIndex+1, clientIndex+1)
322				go func() {
323					defer wg.Done()
324					g := metricsServer.createGauge(name)
325					performRPCs(g, conn, testSelector, stop)
326				}()
327			}
328
329		}
330	}
331	go startServer(metricsServer, *metricsPort)
332	if *testDurationSecs > 0 {
333		time.Sleep(time.Duration(*testDurationSecs) * time.Second)
334		close(stop)
335	}
336	wg.Wait()
337	logger.Infof(" ===== ALL DONE ===== ")
338
339}
340