1/* 2 * 3 * Copyright 2016 gRPC authors. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 */ 18 19//go:generate protoc -I ../grpc_testing --go_out=plugins=grpc:../grpc_testing ../grpc_testing/metrics.proto 20 21// client starts an interop client to do stress test and a metrics server to report qps. 22package main 23 24import ( 25 "context" 26 "flag" 27 "fmt" 28 "math/rand" 29 "net" 30 "strconv" 31 "strings" 32 "sync" 33 "time" 34 35 "google.golang.org/grpc" 36 "google.golang.org/grpc/codes" 37 "google.golang.org/grpc/credentials" 38 "google.golang.org/grpc/grpclog" 39 "google.golang.org/grpc/interop" 40 testpb "google.golang.org/grpc/interop/grpc_testing" 41 "google.golang.org/grpc/status" 42 metricspb "google.golang.org/grpc/stress/grpc_testing" 43 "google.golang.org/grpc/testdata" 44) 45 46var ( 47 serverAddresses = flag.String("server_addresses", "localhost:8080", "a list of server addresses") 48 testCases = flag.String("test_cases", "", "a list of test cases along with the relative weights") 49 testDurationSecs = flag.Int("test_duration_secs", -1, "test duration in seconds") 50 numChannelsPerServer = flag.Int("num_channels_per_server", 1, "Number of channels (i.e connections) to each server") 51 numStubsPerChannel = flag.Int("num_stubs_per_channel", 1, "Number of client stubs per each connection to server") 52 metricsPort = flag.Int("metrics_port", 8081, "The port at which the stress client exposes QPS metrics") 53 useTLS = flag.Bool("use_tls", false, "Connection uses TLS if true, else plain TCP") 54 testCA = flag.Bool("use_test_ca", false, "Whether to replace platform root CAs with test CA as the CA root") 55 tlsServerName = flag.String("server_host_override", "foo.test.google.fr", "The server name use to verify the hostname returned by TLS handshake if it is not empty. Otherwise, --server_host is used.") 56 caFile = flag.String("ca_file", "", "The file containing the CA root cert file") 57) 58 59// testCaseWithWeight contains the test case type and its weight. 60type testCaseWithWeight struct { 61 name string 62 weight int 63} 64 65// parseTestCases converts test case string to a list of struct testCaseWithWeight. 66func parseTestCases(testCaseString string) []testCaseWithWeight { 67 testCaseStrings := strings.Split(testCaseString, ",") 68 testCases := make([]testCaseWithWeight, len(testCaseStrings)) 69 for i, str := range testCaseStrings { 70 testCase := strings.Split(str, ":") 71 if len(testCase) != 2 { 72 panic(fmt.Sprintf("invalid test case with weight: %s", str)) 73 } 74 // Check if test case is supported. 75 switch testCase[0] { 76 case 77 "empty_unary", 78 "large_unary", 79 "client_streaming", 80 "server_streaming", 81 "ping_pong", 82 "empty_stream", 83 "timeout_on_sleeping_server", 84 "cancel_after_begin", 85 "cancel_after_first_response", 86 "status_code_and_message", 87 "custom_metadata": 88 default: 89 panic(fmt.Sprintf("unknown test type: %s", testCase[0])) 90 } 91 testCases[i].name = testCase[0] 92 w, err := strconv.Atoi(testCase[1]) 93 if err != nil { 94 panic(fmt.Sprintf("%v", err)) 95 } 96 testCases[i].weight = w 97 } 98 return testCases 99} 100 101// weightedRandomTestSelector defines a weighted random selector for test case types. 102type weightedRandomTestSelector struct { 103 tests []testCaseWithWeight 104 totalWeight int 105} 106 107// newWeightedRandomTestSelector constructs a weightedRandomTestSelector with the given list of testCaseWithWeight. 108func newWeightedRandomTestSelector(tests []testCaseWithWeight) *weightedRandomTestSelector { 109 var totalWeight int 110 for _, t := range tests { 111 totalWeight += t.weight 112 } 113 rand.Seed(time.Now().UnixNano()) 114 return &weightedRandomTestSelector{tests, totalWeight} 115} 116 117func (selector weightedRandomTestSelector) getNextTest() string { 118 random := rand.Intn(selector.totalWeight) 119 var weightSofar int 120 for _, test := range selector.tests { 121 weightSofar += test.weight 122 if random < weightSofar { 123 return test.name 124 } 125 } 126 panic("no test case selected by weightedRandomTestSelector") 127} 128 129// gauge stores the qps of one interop client (one stub). 130type gauge struct { 131 mutex sync.RWMutex 132 val int64 133} 134 135func (g *gauge) set(v int64) { 136 g.mutex.Lock() 137 defer g.mutex.Unlock() 138 g.val = v 139} 140 141func (g *gauge) get() int64 { 142 g.mutex.RLock() 143 defer g.mutex.RUnlock() 144 return g.val 145} 146 147// server implements metrics server functions. 148type server struct { 149 mutex sync.RWMutex 150 // gauges is a map from /stress_test/server_<n>/channel_<n>/stub_<n>/qps to its qps gauge. 151 gauges map[string]*gauge 152} 153 154// newMetricsServer returns a new metrics server. 155func newMetricsServer() *server { 156 return &server{gauges: make(map[string]*gauge)} 157} 158 159// GetAllGauges returns all gauges. 160func (s *server) GetAllGauges(in *metricspb.EmptyMessage, stream metricspb.MetricsService_GetAllGaugesServer) error { 161 s.mutex.RLock() 162 defer s.mutex.RUnlock() 163 164 for name, gauge := range s.gauges { 165 if err := stream.Send(&metricspb.GaugeResponse{Name: name, Value: &metricspb.GaugeResponse_LongValue{LongValue: gauge.get()}}); err != nil { 166 return err 167 } 168 } 169 return nil 170} 171 172// GetGauge returns the gauge for the given name. 173func (s *server) GetGauge(ctx context.Context, in *metricspb.GaugeRequest) (*metricspb.GaugeResponse, error) { 174 s.mutex.RLock() 175 defer s.mutex.RUnlock() 176 177 if g, ok := s.gauges[in.Name]; ok { 178 return &metricspb.GaugeResponse{Name: in.Name, Value: &metricspb.GaugeResponse_LongValue{LongValue: g.get()}}, nil 179 } 180 return nil, status.Errorf(codes.InvalidArgument, "gauge with name %s not found", in.Name) 181} 182 183// createGauge creates a gauge using the given name in metrics server. 184func (s *server) createGauge(name string) *gauge { 185 s.mutex.Lock() 186 defer s.mutex.Unlock() 187 188 if _, ok := s.gauges[name]; ok { 189 // gauge already exists. 190 panic(fmt.Sprintf("gauge %s already exists", name)) 191 } 192 var g gauge 193 s.gauges[name] = &g 194 return &g 195} 196 197func startServer(server *server, port int) { 198 lis, err := net.Listen("tcp", ":"+strconv.Itoa(port)) 199 if err != nil { 200 grpclog.Fatalf("failed to listen: %v", err) 201 } 202 203 s := grpc.NewServer() 204 metricspb.RegisterMetricsServiceServer(s, server) 205 s.Serve(lis) 206 207} 208 209// performRPCs uses weightedRandomTestSelector to select test case and runs the tests. 210func performRPCs(gauge *gauge, conn *grpc.ClientConn, selector *weightedRandomTestSelector, stop <-chan bool) { 211 client := testpb.NewTestServiceClient(conn) 212 var numCalls int64 213 startTime := time.Now() 214 for { 215 test := selector.getNextTest() 216 switch test { 217 case "empty_unary": 218 interop.DoEmptyUnaryCall(client, grpc.WaitForReady(true)) 219 case "large_unary": 220 interop.DoLargeUnaryCall(client, grpc.WaitForReady(true)) 221 case "client_streaming": 222 interop.DoClientStreaming(client, grpc.WaitForReady(true)) 223 case "server_streaming": 224 interop.DoServerStreaming(client, grpc.WaitForReady(true)) 225 case "ping_pong": 226 interop.DoPingPong(client, grpc.WaitForReady(true)) 227 case "empty_stream": 228 interop.DoEmptyStream(client, grpc.WaitForReady(true)) 229 case "timeout_on_sleeping_server": 230 interop.DoTimeoutOnSleepingServer(client, grpc.WaitForReady(true)) 231 case "cancel_after_begin": 232 interop.DoCancelAfterBegin(client, grpc.WaitForReady(true)) 233 case "cancel_after_first_response": 234 interop.DoCancelAfterFirstResponse(client, grpc.WaitForReady(true)) 235 case "status_code_and_message": 236 interop.DoStatusCodeAndMessage(client, grpc.WaitForReady(true)) 237 case "custom_metadata": 238 interop.DoCustomMetadata(client, grpc.WaitForReady(true)) 239 } 240 numCalls++ 241 gauge.set(int64(float64(numCalls) / time.Since(startTime).Seconds())) 242 243 select { 244 case <-stop: 245 return 246 default: 247 } 248 } 249} 250 251func logParameterInfo(addresses []string, tests []testCaseWithWeight) { 252 grpclog.Infof("server_addresses: %s", *serverAddresses) 253 grpclog.Infof("test_cases: %s", *testCases) 254 grpclog.Infof("test_duration_secs: %d", *testDurationSecs) 255 grpclog.Infof("num_channels_per_server: %d", *numChannelsPerServer) 256 grpclog.Infof("num_stubs_per_channel: %d", *numStubsPerChannel) 257 grpclog.Infof("metrics_port: %d", *metricsPort) 258 grpclog.Infof("use_tls: %t", *useTLS) 259 grpclog.Infof("use_test_ca: %t", *testCA) 260 grpclog.Infof("server_host_override: %s", *tlsServerName) 261 262 grpclog.Infoln("addresses:") 263 for i, addr := range addresses { 264 grpclog.Infof("%d. %s\n", i+1, addr) 265 } 266 grpclog.Infoln("tests:") 267 for i, test := range tests { 268 grpclog.Infof("%d. %v\n", i+1, test) 269 } 270} 271 272func newConn(address string, useTLS, testCA bool, tlsServerName string) (*grpc.ClientConn, error) { 273 var opts []grpc.DialOption 274 if useTLS { 275 var sn string 276 if tlsServerName != "" { 277 sn = tlsServerName 278 } 279 var creds credentials.TransportCredentials 280 if testCA { 281 var err error 282 if *caFile == "" { 283 *caFile = testdata.Path("ca.pem") 284 } 285 creds, err = credentials.NewClientTLSFromFile(*caFile, sn) 286 if err != nil { 287 grpclog.Fatalf("Failed to create TLS credentials %v", err) 288 } 289 } else { 290 creds = credentials.NewClientTLSFromCert(nil, sn) 291 } 292 opts = append(opts, grpc.WithTransportCredentials(creds)) 293 } else { 294 opts = append(opts, grpc.WithInsecure()) 295 } 296 return grpc.Dial(address, opts...) 297} 298 299func main() { 300 flag.Parse() 301 addresses := strings.Split(*serverAddresses, ",") 302 tests := parseTestCases(*testCases) 303 logParameterInfo(addresses, tests) 304 testSelector := newWeightedRandomTestSelector(tests) 305 metricsServer := newMetricsServer() 306 307 var wg sync.WaitGroup 308 wg.Add(len(addresses) * *numChannelsPerServer * *numStubsPerChannel) 309 stop := make(chan bool) 310 311 for serverIndex, address := range addresses { 312 for connIndex := 0; connIndex < *numChannelsPerServer; connIndex++ { 313 conn, err := newConn(address, *useTLS, *testCA, *tlsServerName) 314 if err != nil { 315 grpclog.Fatalf("Fail to dial: %v", err) 316 } 317 defer conn.Close() 318 for clientIndex := 0; clientIndex < *numStubsPerChannel; clientIndex++ { 319 name := fmt.Sprintf("/stress_test/server_%d/channel_%d/stub_%d/qps", serverIndex+1, connIndex+1, clientIndex+1) 320 go func() { 321 defer wg.Done() 322 g := metricsServer.createGauge(name) 323 performRPCs(g, conn, testSelector, stop) 324 }() 325 } 326 327 } 328 } 329 go startServer(metricsServer, *metricsPort) 330 if *testDurationSecs > 0 { 331 time.Sleep(time.Duration(*testDurationSecs) * time.Second) 332 close(stop) 333 } 334 wg.Wait() 335 grpclog.Infof(" ===== ALL DONE ===== ") 336 337} 338