1/* 2 * 3 * Copyright 2016 gRPC authors. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 */ 18 19// client starts an interop client to do stress test and a metrics server to report qps. 20package main 21 22import ( 23 "context" 24 "flag" 25 "fmt" 26 "math/rand" 27 "net" 28 "strconv" 29 "strings" 30 "sync" 31 "time" 32 33 "google.golang.org/grpc" 34 "google.golang.org/grpc/codes" 35 "google.golang.org/grpc/credentials" 36 "google.golang.org/grpc/grpclog" 37 "google.golang.org/grpc/interop" 38 testpb "google.golang.org/grpc/interop/grpc_testing" 39 "google.golang.org/grpc/status" 40 metricspb "google.golang.org/grpc/stress/grpc_testing" 41 "google.golang.org/grpc/testdata" 42) 43 44var ( 45 serverAddresses = flag.String("server_addresses", "localhost:8080", "a list of server addresses") 46 testCases = flag.String("test_cases", "", "a list of test cases along with the relative weights") 47 testDurationSecs = flag.Int("test_duration_secs", -1, "test duration in seconds") 48 numChannelsPerServer = flag.Int("num_channels_per_server", 1, "Number of channels (i.e connections) to each server") 49 numStubsPerChannel = flag.Int("num_stubs_per_channel", 1, "Number of client stubs per each connection to server") 50 metricsPort = flag.Int("metrics_port", 8081, "The port at which the stress client exposes QPS metrics") 51 useTLS = flag.Bool("use_tls", false, "Connection uses TLS if true, else plain TCP") 52 testCA = flag.Bool("use_test_ca", false, "Whether to replace platform root CAs with test CA as the CA root") 53 tlsServerName = flag.String("server_host_override", "foo.test.google.fr", "The server name use to verify the hostname returned by TLS handshake if it is not empty. Otherwise, --server_host is used.") 54 caFile = flag.String("ca_file", "", "The file containing the CA root cert file") 55 56 logger = grpclog.Component("stress") 57) 58 59// testCaseWithWeight contains the test case type and its weight. 60type testCaseWithWeight struct { 61 name string 62 weight int 63} 64 65// parseTestCases converts test case string to a list of struct testCaseWithWeight. 66func parseTestCases(testCaseString string) []testCaseWithWeight { 67 testCaseStrings := strings.Split(testCaseString, ",") 68 testCases := make([]testCaseWithWeight, len(testCaseStrings)) 69 for i, str := range testCaseStrings { 70 testCase := strings.Split(str, ":") 71 if len(testCase) != 2 { 72 panic(fmt.Sprintf("invalid test case with weight: %s", str)) 73 } 74 // Check if test case is supported. 75 switch testCase[0] { 76 case 77 "empty_unary", 78 "large_unary", 79 "client_streaming", 80 "server_streaming", 81 "ping_pong", 82 "empty_stream", 83 "timeout_on_sleeping_server", 84 "cancel_after_begin", 85 "cancel_after_first_response", 86 "status_code_and_message", 87 "custom_metadata": 88 default: 89 panic(fmt.Sprintf("unknown test type: %s", testCase[0])) 90 } 91 testCases[i].name = testCase[0] 92 w, err := strconv.Atoi(testCase[1]) 93 if err != nil { 94 panic(fmt.Sprintf("%v", err)) 95 } 96 testCases[i].weight = w 97 } 98 return testCases 99} 100 101// weightedRandomTestSelector defines a weighted random selector for test case types. 102type weightedRandomTestSelector struct { 103 tests []testCaseWithWeight 104 totalWeight int 105} 106 107// newWeightedRandomTestSelector constructs a weightedRandomTestSelector with the given list of testCaseWithWeight. 108func newWeightedRandomTestSelector(tests []testCaseWithWeight) *weightedRandomTestSelector { 109 var totalWeight int 110 for _, t := range tests { 111 totalWeight += t.weight 112 } 113 rand.Seed(time.Now().UnixNano()) 114 return &weightedRandomTestSelector{tests, totalWeight} 115} 116 117func (selector weightedRandomTestSelector) getNextTest() string { 118 random := rand.Intn(selector.totalWeight) 119 var weightSofar int 120 for _, test := range selector.tests { 121 weightSofar += test.weight 122 if random < weightSofar { 123 return test.name 124 } 125 } 126 panic("no test case selected by weightedRandomTestSelector") 127} 128 129// gauge stores the qps of one interop client (one stub). 130type gauge struct { 131 mutex sync.RWMutex 132 val int64 133} 134 135func (g *gauge) set(v int64) { 136 g.mutex.Lock() 137 defer g.mutex.Unlock() 138 g.val = v 139} 140 141func (g *gauge) get() int64 { 142 g.mutex.RLock() 143 defer g.mutex.RUnlock() 144 return g.val 145} 146 147// server implements metrics server functions. 148type server struct { 149 metricspb.UnimplementedMetricsServiceServer 150 mutex sync.RWMutex 151 // gauges is a map from /stress_test/server_<n>/channel_<n>/stub_<n>/qps to its qps gauge. 152 gauges map[string]*gauge 153} 154 155// newMetricsServer returns a new metrics server. 156func newMetricsServer() *server { 157 return &server{gauges: make(map[string]*gauge)} 158} 159 160// GetAllGauges returns all gauges. 161func (s *server) GetAllGauges(in *metricspb.EmptyMessage, stream metricspb.MetricsService_GetAllGaugesServer) error { 162 s.mutex.RLock() 163 defer s.mutex.RUnlock() 164 165 for name, gauge := range s.gauges { 166 if err := stream.Send(&metricspb.GaugeResponse{Name: name, Value: &metricspb.GaugeResponse_LongValue{LongValue: gauge.get()}}); err != nil { 167 return err 168 } 169 } 170 return nil 171} 172 173// GetGauge returns the gauge for the given name. 174func (s *server) GetGauge(ctx context.Context, in *metricspb.GaugeRequest) (*metricspb.GaugeResponse, error) { 175 s.mutex.RLock() 176 defer s.mutex.RUnlock() 177 178 if g, ok := s.gauges[in.Name]; ok { 179 return &metricspb.GaugeResponse{Name: in.Name, Value: &metricspb.GaugeResponse_LongValue{LongValue: g.get()}}, nil 180 } 181 return nil, status.Errorf(codes.InvalidArgument, "gauge with name %s not found", in.Name) 182} 183 184// createGauge creates a gauge using the given name in metrics server. 185func (s *server) createGauge(name string) *gauge { 186 s.mutex.Lock() 187 defer s.mutex.Unlock() 188 189 if _, ok := s.gauges[name]; ok { 190 // gauge already exists. 191 panic(fmt.Sprintf("gauge %s already exists", name)) 192 } 193 var g gauge 194 s.gauges[name] = &g 195 return &g 196} 197 198func startServer(server *server, port int) { 199 lis, err := net.Listen("tcp", ":"+strconv.Itoa(port)) 200 if err != nil { 201 logger.Fatalf("failed to listen: %v", err) 202 } 203 204 s := grpc.NewServer() 205 metricspb.RegisterMetricsServiceServer(s, server) 206 s.Serve(lis) 207 208} 209 210// performRPCs uses weightedRandomTestSelector to select test case and runs the tests. 211func performRPCs(gauge *gauge, conn *grpc.ClientConn, selector *weightedRandomTestSelector, stop <-chan bool) { 212 client := testpb.NewTestServiceClient(conn) 213 var numCalls int64 214 startTime := time.Now() 215 for { 216 test := selector.getNextTest() 217 switch test { 218 case "empty_unary": 219 interop.DoEmptyUnaryCall(client, grpc.WaitForReady(true)) 220 case "large_unary": 221 interop.DoLargeUnaryCall(client, grpc.WaitForReady(true)) 222 case "client_streaming": 223 interop.DoClientStreaming(client, grpc.WaitForReady(true)) 224 case "server_streaming": 225 interop.DoServerStreaming(client, grpc.WaitForReady(true)) 226 case "ping_pong": 227 interop.DoPingPong(client, grpc.WaitForReady(true)) 228 case "empty_stream": 229 interop.DoEmptyStream(client, grpc.WaitForReady(true)) 230 case "timeout_on_sleeping_server": 231 interop.DoTimeoutOnSleepingServer(client, grpc.WaitForReady(true)) 232 case "cancel_after_begin": 233 interop.DoCancelAfterBegin(client, grpc.WaitForReady(true)) 234 case "cancel_after_first_response": 235 interop.DoCancelAfterFirstResponse(client, grpc.WaitForReady(true)) 236 case "status_code_and_message": 237 interop.DoStatusCodeAndMessage(client, grpc.WaitForReady(true)) 238 case "custom_metadata": 239 interop.DoCustomMetadata(client, grpc.WaitForReady(true)) 240 } 241 numCalls++ 242 gauge.set(int64(float64(numCalls) / time.Since(startTime).Seconds())) 243 244 select { 245 case <-stop: 246 return 247 default: 248 } 249 } 250} 251 252func logParameterInfo(addresses []string, tests []testCaseWithWeight) { 253 logger.Infof("server_addresses: %s", *serverAddresses) 254 logger.Infof("test_cases: %s", *testCases) 255 logger.Infof("test_duration_secs: %d", *testDurationSecs) 256 logger.Infof("num_channels_per_server: %d", *numChannelsPerServer) 257 logger.Infof("num_stubs_per_channel: %d", *numStubsPerChannel) 258 logger.Infof("metrics_port: %d", *metricsPort) 259 logger.Infof("use_tls: %t", *useTLS) 260 logger.Infof("use_test_ca: %t", *testCA) 261 logger.Infof("server_host_override: %s", *tlsServerName) 262 263 logger.Infoln("addresses:") 264 for i, addr := range addresses { 265 logger.Infof("%d. %s\n", i+1, addr) 266 } 267 logger.Infoln("tests:") 268 for i, test := range tests { 269 logger.Infof("%d. %v\n", i+1, test) 270 } 271} 272 273func newConn(address string, useTLS, testCA bool, tlsServerName string) (*grpc.ClientConn, error) { 274 var opts []grpc.DialOption 275 if useTLS { 276 var sn string 277 if tlsServerName != "" { 278 sn = tlsServerName 279 } 280 var creds credentials.TransportCredentials 281 if testCA { 282 var err error 283 if *caFile == "" { 284 *caFile = testdata.Path("x509/server_ca_cert.pem") 285 } 286 creds, err = credentials.NewClientTLSFromFile(*caFile, sn) 287 if err != nil { 288 logger.Fatalf("Failed to create TLS credentials %v", err) 289 } 290 } else { 291 creds = credentials.NewClientTLSFromCert(nil, sn) 292 } 293 opts = append(opts, grpc.WithTransportCredentials(creds)) 294 } else { 295 opts = append(opts, grpc.WithInsecure()) 296 } 297 return grpc.Dial(address, opts...) 298} 299 300func main() { 301 flag.Parse() 302 addresses := strings.Split(*serverAddresses, ",") 303 tests := parseTestCases(*testCases) 304 logParameterInfo(addresses, tests) 305 testSelector := newWeightedRandomTestSelector(tests) 306 metricsServer := newMetricsServer() 307 308 var wg sync.WaitGroup 309 wg.Add(len(addresses) * *numChannelsPerServer * *numStubsPerChannel) 310 stop := make(chan bool) 311 312 for serverIndex, address := range addresses { 313 for connIndex := 0; connIndex < *numChannelsPerServer; connIndex++ { 314 conn, err := newConn(address, *useTLS, *testCA, *tlsServerName) 315 if err != nil { 316 logger.Fatalf("Fail to dial: %v", err) 317 } 318 defer conn.Close() 319 for clientIndex := 0; clientIndex < *numStubsPerChannel; clientIndex++ { 320 name := fmt.Sprintf("/stress_test/server_%d/channel_%d/stub_%d/qps", serverIndex+1, connIndex+1, clientIndex+1) 321 go func() { 322 defer wg.Done() 323 g := metricsServer.createGauge(name) 324 performRPCs(g, conn, testSelector, stop) 325 }() 326 } 327 328 } 329 } 330 go startServer(metricsServer, *metricsPort) 331 if *testDurationSecs > 0 { 332 time.Sleep(time.Duration(*testDurationSecs) * time.Second) 333 close(stop) 334 } 335 wg.Wait() 336 logger.Infof(" ===== ALL DONE ===== ") 337 338} 339