1/*
2 *
3 * Copyright 2016 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19//go:generate protoc -I ../grpc_testing --go_out=plugins=grpc:../grpc_testing ../grpc_testing/metrics.proto
20
// client starts interop clients to stress test gRPC servers and a metrics server to report their qps.
22package main
23
24import (
25	"context"
26	"flag"
27	"fmt"
28	"math/rand"
29	"net"
30	"strconv"
31	"strings"
32	"sync"
33	"time"
34
35	"google.golang.org/grpc"
36	"google.golang.org/grpc/codes"
37	"google.golang.org/grpc/credentials"
38	"google.golang.org/grpc/grpclog"
39	"google.golang.org/grpc/interop"
40	testpb "google.golang.org/grpc/interop/grpc_testing"
41	"google.golang.org/grpc/status"
42	metricspb "google.golang.org/grpc/stress/grpc_testing"
43	"google.golang.org/grpc/testdata"
44)
45
// Command-line flags configuring the stress client.
//
// server_addresses is split on "," and test_cases is parsed as a
// comma-separated list of "<name>:<weight>" entries. The test only stops on
// its own when test_duration_secs is positive. ca_file is consulted only when
// both use_tls and use_test_ca are set; when it is empty the bundled test CA
// ("ca.pem" from the testdata package) is used instead.
var (
	serverAddresses      = flag.String("server_addresses", "localhost:8080", "a list of server addresses")
	testCases            = flag.String("test_cases", "", "a list of test cases along with the relative weights")
	testDurationSecs     = flag.Int("test_duration_secs", -1, "test duration in seconds")
	numChannelsPerServer = flag.Int("num_channels_per_server", 1, "Number of channels (i.e connections) to each server")
	numStubsPerChannel   = flag.Int("num_stubs_per_channel", 1, "Number of client stubs per each connection to server")
	metricsPort          = flag.Int("metrics_port", 8081, "The port at which the stress client exposes QPS metrics")
	useTLS               = flag.Bool("use_tls", false, "Connection uses TLS if true, else plain TCP")
	testCA               = flag.Bool("use_test_ca", false, "Whether to replace platform root CAs with test CA as the CA root")
	tlsServerName        = flag.String("server_host_override", "foo.test.google.fr", "The server name used to verify the hostname returned by TLS handshake if it is not empty. Otherwise, the host from --server_addresses is used.")
	caFile               = flag.String("ca_file", "", "The file containing the CA root cert file")
)
58
// testCaseWithWeight contains the test case type and its weight.
type testCaseWithWeight struct {
	name   string // interop test case name, e.g. "empty_unary"
	weight int    // relative weight used when randomly selecting a test case
}

// parseTestCases converts a test case string into a list of testCaseWithWeight.
//
// testCaseString is a comma-separated list of "<name>:<weight>" entries, for
// example "empty_unary:20,large_unary:10". Whitespace surrounding a name or a
// weight is ignored. The returned slice preserves the order of the entries.
//
// It panics if an entry does not have exactly one ":" separator, names an
// unsupported test case, or carries a weight that is not a non-negative
// integer.
func parseTestCases(testCaseString string) []testCaseWithWeight {
	entries := strings.Split(testCaseString, ",")
	parsed := make([]testCaseWithWeight, len(entries))
	for i, entry := range entries {
		parts := strings.Split(entry, ":")
		if len(parts) != 2 {
			panic(fmt.Sprintf("invalid test case with weight: %s", entry))
		}
		name := strings.TrimSpace(parts[0])
		// Check if test case is supported.
		switch name {
		case
			"empty_unary",
			"large_unary",
			"client_streaming",
			"server_streaming",
			"ping_pong",
			"empty_stream",
			"timeout_on_sleeping_server",
			"cancel_after_begin",
			"cancel_after_first_response",
			"status_code_and_message",
			"custom_metadata":
		default:
			panic(fmt.Sprintf("unknown test type: %s", name))
		}
		w, err := strconv.Atoi(strings.TrimSpace(parts[1]))
		if err != nil {
			panic(fmt.Sprintf("invalid weight for test case %s: %v", name, err))
		}
		// A negative weight has no meaning for weighted selection and would
		// corrupt the cumulative sums the selector relies on.
		if w < 0 {
			panic(fmt.Sprintf("negative weight for test case %s: %d", name, w))
		}
		parsed[i] = testCaseWithWeight{name: name, weight: w}
	}
	return parsed
}
100
101// weightedRandomTestSelector defines a weighted random selector for test case types.
102type weightedRandomTestSelector struct {
103	tests       []testCaseWithWeight
104	totalWeight int
105}
106
107// newWeightedRandomTestSelector constructs a weightedRandomTestSelector with the given list of testCaseWithWeight.
108func newWeightedRandomTestSelector(tests []testCaseWithWeight) *weightedRandomTestSelector {
109	var totalWeight int
110	for _, t := range tests {
111		totalWeight += t.weight
112	}
113	rand.Seed(time.Now().UnixNano())
114	return &weightedRandomTestSelector{tests, totalWeight}
115}
116
117func (selector weightedRandomTestSelector) getNextTest() string {
118	random := rand.Intn(selector.totalWeight)
119	var weightSofar int
120	for _, test := range selector.tests {
121		weightSofar += test.weight
122		if random < weightSofar {
123			return test.name
124		}
125	}
126	panic("no test case selected by weightedRandomTestSelector")
127}
128
// gauge holds the qps value of a single interop client (one stub). All access
// to val goes through mutex.
type gauge struct {
	mutex sync.RWMutex
	val   int64
}

// set replaces the gauge's value with v.
func (g *gauge) set(v int64) {
	g.mutex.Lock()
	g.val = v
	g.mutex.Unlock()
}

// get reports the value most recently stored in the gauge.
func (g *gauge) get() int64 {
	g.mutex.RLock()
	current := g.val
	g.mutex.RUnlock()
	return current
}
146
147// server implements metrics server functions.
148type server struct {
149	mutex sync.RWMutex
150	// gauges is a map from /stress_test/server_<n>/channel_<n>/stub_<n>/qps to its qps gauge.
151	gauges map[string]*gauge
152}
153
154// newMetricsServer returns a new metrics server.
155func newMetricsServer() *server {
156	return &server{gauges: make(map[string]*gauge)}
157}
158
159// GetAllGauges returns all gauges.
160func (s *server) GetAllGauges(in *metricspb.EmptyMessage, stream metricspb.MetricsService_GetAllGaugesServer) error {
161	s.mutex.RLock()
162	defer s.mutex.RUnlock()
163
164	for name, gauge := range s.gauges {
165		if err := stream.Send(&metricspb.GaugeResponse{Name: name, Value: &metricspb.GaugeResponse_LongValue{LongValue: gauge.get()}}); err != nil {
166			return err
167		}
168	}
169	return nil
170}
171
172// GetGauge returns the gauge for the given name.
173func (s *server) GetGauge(ctx context.Context, in *metricspb.GaugeRequest) (*metricspb.GaugeResponse, error) {
174	s.mutex.RLock()
175	defer s.mutex.RUnlock()
176
177	if g, ok := s.gauges[in.Name]; ok {
178		return &metricspb.GaugeResponse{Name: in.Name, Value: &metricspb.GaugeResponse_LongValue{LongValue: g.get()}}, nil
179	}
180	return nil, status.Errorf(codes.InvalidArgument, "gauge with name %s not found", in.Name)
181}
182
183// createGauge creates a gauge using the given name in metrics server.
184func (s *server) createGauge(name string) *gauge {
185	s.mutex.Lock()
186	defer s.mutex.Unlock()
187
188	if _, ok := s.gauges[name]; ok {
189		// gauge already exists.
190		panic(fmt.Sprintf("gauge %s already exists", name))
191	}
192	var g gauge
193	s.gauges[name] = &g
194	return &g
195}
196
197func startServer(server *server, port int) {
198	lis, err := net.Listen("tcp", ":"+strconv.Itoa(port))
199	if err != nil {
200		grpclog.Fatalf("failed to listen: %v", err)
201	}
202
203	s := grpc.NewServer()
204	metricspb.RegisterMetricsServiceServer(s, server)
205	s.Serve(lis)
206
207}
208
209// performRPCs uses weightedRandomTestSelector to select test case and runs the tests.
210func performRPCs(gauge *gauge, conn *grpc.ClientConn, selector *weightedRandomTestSelector, stop <-chan bool) {
211	client := testpb.NewTestServiceClient(conn)
212	var numCalls int64
213	startTime := time.Now()
214	for {
215		test := selector.getNextTest()
216		switch test {
217		case "empty_unary":
218			interop.DoEmptyUnaryCall(client, grpc.WaitForReady(true))
219		case "large_unary":
220			interop.DoLargeUnaryCall(client, grpc.WaitForReady(true))
221		case "client_streaming":
222			interop.DoClientStreaming(client, grpc.WaitForReady(true))
223		case "server_streaming":
224			interop.DoServerStreaming(client, grpc.WaitForReady(true))
225		case "ping_pong":
226			interop.DoPingPong(client, grpc.WaitForReady(true))
227		case "empty_stream":
228			interop.DoEmptyStream(client, grpc.WaitForReady(true))
229		case "timeout_on_sleeping_server":
230			interop.DoTimeoutOnSleepingServer(client, grpc.WaitForReady(true))
231		case "cancel_after_begin":
232			interop.DoCancelAfterBegin(client, grpc.WaitForReady(true))
233		case "cancel_after_first_response":
234			interop.DoCancelAfterFirstResponse(client, grpc.WaitForReady(true))
235		case "status_code_and_message":
236			interop.DoStatusCodeAndMessage(client, grpc.WaitForReady(true))
237		case "custom_metadata":
238			interop.DoCustomMetadata(client, grpc.WaitForReady(true))
239		}
240		numCalls++
241		gauge.set(int64(float64(numCalls) / time.Since(startTime).Seconds()))
242
243		select {
244		case <-stop:
245			return
246		default:
247		}
248	}
249}
250
251func logParameterInfo(addresses []string, tests []testCaseWithWeight) {
252	grpclog.Infof("server_addresses: %s", *serverAddresses)
253	grpclog.Infof("test_cases: %s", *testCases)
254	grpclog.Infof("test_duration_secs: %d", *testDurationSecs)
255	grpclog.Infof("num_channels_per_server: %d", *numChannelsPerServer)
256	grpclog.Infof("num_stubs_per_channel: %d", *numStubsPerChannel)
257	grpclog.Infof("metrics_port: %d", *metricsPort)
258	grpclog.Infof("use_tls: %t", *useTLS)
259	grpclog.Infof("use_test_ca: %t", *testCA)
260	grpclog.Infof("server_host_override: %s", *tlsServerName)
261
262	grpclog.Infoln("addresses:")
263	for i, addr := range addresses {
264		grpclog.Infof("%d. %s\n", i+1, addr)
265	}
266	grpclog.Infoln("tests:")
267	for i, test := range tests {
268		grpclog.Infof("%d. %v\n", i+1, test)
269	}
270}
271
272func newConn(address string, useTLS, testCA bool, tlsServerName string) (*grpc.ClientConn, error) {
273	var opts []grpc.DialOption
274	if useTLS {
275		var sn string
276		if tlsServerName != "" {
277			sn = tlsServerName
278		}
279		var creds credentials.TransportCredentials
280		if testCA {
281			var err error
282			if *caFile == "" {
283				*caFile = testdata.Path("ca.pem")
284			}
285			creds, err = credentials.NewClientTLSFromFile(*caFile, sn)
286			if err != nil {
287				grpclog.Fatalf("Failed to create TLS credentials %v", err)
288			}
289		} else {
290			creds = credentials.NewClientTLSFromCert(nil, sn)
291		}
292		opts = append(opts, grpc.WithTransportCredentials(creds))
293	} else {
294		opts = append(opts, grpc.WithInsecure())
295	}
296	return grpc.Dial(address, opts...)
297}
298
299func main() {
300	flag.Parse()
301	addresses := strings.Split(*serverAddresses, ",")
302	tests := parseTestCases(*testCases)
303	logParameterInfo(addresses, tests)
304	testSelector := newWeightedRandomTestSelector(tests)
305	metricsServer := newMetricsServer()
306
307	var wg sync.WaitGroup
308	wg.Add(len(addresses) * *numChannelsPerServer * *numStubsPerChannel)
309	stop := make(chan bool)
310
311	for serverIndex, address := range addresses {
312		for connIndex := 0; connIndex < *numChannelsPerServer; connIndex++ {
313			conn, err := newConn(address, *useTLS, *testCA, *tlsServerName)
314			if err != nil {
315				grpclog.Fatalf("Fail to dial: %v", err)
316			}
317			defer conn.Close()
318			for clientIndex := 0; clientIndex < *numStubsPerChannel; clientIndex++ {
319				name := fmt.Sprintf("/stress_test/server_%d/channel_%d/stub_%d/qps", serverIndex+1, connIndex+1, clientIndex+1)
320				go func() {
321					defer wg.Done()
322					g := metricsServer.createGauge(name)
323					performRPCs(g, conn, testSelector, stop)
324				}()
325			}
326
327		}
328	}
329	go startServer(metricsServer, *metricsPort)
330	if *testDurationSecs > 0 {
331		time.Sleep(time.Duration(*testDurationSecs) * time.Second)
332		close(stop)
333	}
334	wg.Wait()
335	grpclog.Infof(" ===== ALL DONE ===== ")
336
337}
338