1/*
2 *
3 * Copyright 2020 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19// Binary client for xDS interop tests.
20package main
21
22import (
23	"context"
24	"flag"
25	"fmt"
26	"log"
27	"net"
28	"strings"
29	"sync"
30	"sync/atomic"
31	"time"
32
33	"google.golang.org/grpc"
34	"google.golang.org/grpc/admin"
35	"google.golang.org/grpc/credentials/insecure"
36	"google.golang.org/grpc/credentials/xds"
37	"google.golang.org/grpc/grpclog"
38	"google.golang.org/grpc/metadata"
39	"google.golang.org/grpc/peer"
40	"google.golang.org/grpc/reflection"
41	"google.golang.org/grpc/status"
42	_ "google.golang.org/grpc/xds"
43
44	testgrpc "google.golang.org/grpc/interop/grpc_testing"
45	testpb "google.golang.org/grpc/interop/grpc_testing"
46)
47
48func init() {
49	rpcCfgs.Store([]*rpcConfig{{typ: unaryCall}})
50}
51
// statsWatcherKey identifies a watcher by the half-open range of request IDs
// [startID, endID) whose RPC results it wants to observe.
type statsWatcherKey struct {
	startID, endID int32
}
56
// rpcInfo describes one completed RPC: the internal type of the RPC that was
// issued and the hostname of the backend that answered it.
type rpcInfo struct {
	typ, hostname string
}
63
64type statsWatcher struct {
65	rpcsByPeer    map[string]int32
66	rpcsByType    map[string]map[string]int32
67	numFailures   int32
68	remainingRPCs int32
69	chanHosts     chan *rpcInfo
70}
71
72func (watcher *statsWatcher) buildResp() *testpb.LoadBalancerStatsResponse {
73	rpcsByType := make(map[string]*testpb.LoadBalancerStatsResponse_RpcsByPeer, len(watcher.rpcsByType))
74	for t, rpcsByPeer := range watcher.rpcsByType {
75		rpcsByType[t] = &testpb.LoadBalancerStatsResponse_RpcsByPeer{
76			RpcsByPeer: rpcsByPeer,
77		}
78	}
79
80	return &testpb.LoadBalancerStatsResponse{
81		NumFailures:  watcher.numFailures + watcher.remainingRPCs,
82		RpcsByPeer:   watcher.rpcsByPeer,
83		RpcsByMethod: rpcsByType,
84	}
85}
86
// accumulatedStats holds per-method RPC counters aggregated since the client
// started. The maps are keyed by the names produced by convertRPCName.
type accumulatedStats struct {
	// mu guards all of the maps below.
	mu sync.Mutex

	// Per-method counts of RPCs started, finished without error, and
	// finished with an error.
	numRPCsStartedByMethod, numRPCsSucceededByMethod, numRPCsFailedByMethod map[string]int32

	// rpcStatusByMethod counts finished RPCs per method, then per gRPC
	// status code.
	rpcStatusByMethod map[string]map[int32]int32
}
94
95func convertRPCName(in string) string {
96	switch in {
97	case unaryCall:
98		return testpb.ClientConfigureRequest_UNARY_CALL.String()
99	case emptyCall:
100		return testpb.ClientConfigureRequest_EMPTY_CALL.String()
101	}
102	logger.Warningf("unrecognized rpc type: %s", in)
103	return in
104}
105
// copyStatsMap returns an independent copy of originalMap. The result is
// always a non-nil map, even when originalMap is nil.
func copyStatsMap(originalMap map[string]int32) map[string]int32 {
	dup := make(map[string]int32, len(originalMap))
	for name, count := range originalMap {
		dup[name] = count
	}
	return dup
}
114
// copyStatsIntMap returns an independent copy of originalMap, whose keys are
// int32 (for example gRPC status codes). The result is always a non-nil map,
// even when originalMap is nil.
func copyStatsIntMap(originalMap map[int32]int32) map[int32]int32 {
	dup := make(map[int32]int32, len(originalMap))
	for key, count := range originalMap {
		dup[key] = count
	}
	return dup
}
123
124func (as *accumulatedStats) makeStatsMap() map[string]*testpb.LoadBalancerAccumulatedStatsResponse_MethodStats {
125	m := make(map[string]*testpb.LoadBalancerAccumulatedStatsResponse_MethodStats)
126	for k, v := range as.numRPCsStartedByMethod {
127		m[k] = &testpb.LoadBalancerAccumulatedStatsResponse_MethodStats{RpcsStarted: v}
128	}
129	for k, v := range as.rpcStatusByMethod {
130		if m[k] == nil {
131			m[k] = &testpb.LoadBalancerAccumulatedStatsResponse_MethodStats{}
132		}
133		m[k].Result = copyStatsIntMap(v)
134	}
135	return m
136}
137
138func (as *accumulatedStats) buildResp() *testpb.LoadBalancerAccumulatedStatsResponse {
139	as.mu.Lock()
140	defer as.mu.Unlock()
141	return &testpb.LoadBalancerAccumulatedStatsResponse{
142		NumRpcsStartedByMethod:   copyStatsMap(as.numRPCsStartedByMethod),
143		NumRpcsSucceededByMethod: copyStatsMap(as.numRPCsSucceededByMethod),
144		NumRpcsFailedByMethod:    copyStatsMap(as.numRPCsFailedByMethod),
145		StatsPerMethod:           as.makeStatsMap(),
146	}
147}
148
149func (as *accumulatedStats) startRPC(rpcType string) {
150	as.mu.Lock()
151	defer as.mu.Unlock()
152	as.numRPCsStartedByMethod[convertRPCName(rpcType)]++
153}
154
155func (as *accumulatedStats) finishRPC(rpcType string, err error) {
156	as.mu.Lock()
157	defer as.mu.Unlock()
158	name := convertRPCName(rpcType)
159	if as.rpcStatusByMethod[name] == nil {
160		as.rpcStatusByMethod[name] = make(map[int32]int32)
161	}
162	as.rpcStatusByMethod[name][int32(status.Convert(err).Code())]++
163	if err != nil {
164		as.numRPCsFailedByMethod[name]++
165		return
166	}
167	as.numRPCsSucceededByMethod[name]++
168}
169
170var (
171	failOnFailedRPC = flag.Bool("fail_on_failed_rpc", false, "Fail client if any RPCs fail after first success")
172	numChannels     = flag.Int("num_channels", 1, "Num of channels")
173	printResponse   = flag.Bool("print_response", false, "Write RPC response to stdout")
174	qps             = flag.Int("qps", 1, "QPS per channel, for each type of RPC")
175	rpc             = flag.String("rpc", "UnaryCall", "Types of RPCs to make, ',' separated string. RPCs can be EmptyCall or UnaryCall. Deprecated: Use Configure RPC to XdsUpdateClientConfigureServiceServer instead.")
176	rpcMetadata     = flag.String("metadata", "", "The metadata to send with RPC, in format EmptyCall:key1:value1,UnaryCall:key2:value2. Deprecated: Use Configure RPC to XdsUpdateClientConfigureServiceServer instead.")
177	rpcTimeout      = flag.Duration("rpc_timeout", 20*time.Second, "Per RPC timeout")
178	server          = flag.String("server", "localhost:8080", "Address of server to connect to")
179	statsPort       = flag.Int("stats_port", 8081, "Port to expose peer distribution stats service")
180	secureMode      = flag.Bool("secure_mode", false, "If true, retrieve security configuration from the management server. Else, use insecure credentials.")
181
182	rpcCfgs atomic.Value
183
184	mu               sync.Mutex
185	currentRequestID int32
186	watchers         = make(map[statsWatcherKey]*statsWatcher)
187
188	accStats = accumulatedStats{
189		numRPCsStartedByMethod:   make(map[string]int32),
190		numRPCsSucceededByMethod: make(map[string]int32),
191		numRPCsFailedByMethod:    make(map[string]int32),
192		rpcStatusByMethod:        make(map[string]map[int32]int32),
193	}
194
195	// 0 or 1 representing an RPC has succeeded. Use hasRPCSucceeded and
196	// setRPCSucceeded to access in a safe manner.
197	rpcSucceeded uint32
198
199	logger = grpclog.Component("interop")
200)
201
// statsService implements the LoadBalancerStatsService, which the test driver
// queries (GetClientStats, GetClientAccumulatedStats) to see how this client
// distributed its RPCs. The embedded generated Unimplemented server supplies
// any service method not defined in this file.
type statsService struct {
	testgrpc.UnimplementedLoadBalancerStatsServiceServer
}
205
206func hasRPCSucceeded() bool {
207	return atomic.LoadUint32(&rpcSucceeded) > 0
208}
209
210func setRPCSucceeded() {
211	atomic.StoreUint32(&rpcSucceeded, 1)
212}
213
214// Wait for the next LoadBalancerStatsRequest.GetNumRpcs to start and complete,
215// and return the distribution of remote peers. This is essentially a clientside
216// LB reporting mechanism that is designed to be queried by an external test
217// driver when verifying that the client is distributing RPCs as expected.
218func (s *statsService) GetClientStats(ctx context.Context, in *testpb.LoadBalancerStatsRequest) (*testpb.LoadBalancerStatsResponse, error) {
219	mu.Lock()
220	watcherKey := statsWatcherKey{currentRequestID, currentRequestID + in.GetNumRpcs()}
221	watcher, ok := watchers[watcherKey]
222	if !ok {
223		watcher = &statsWatcher{
224			rpcsByPeer:    make(map[string]int32),
225			rpcsByType:    make(map[string]map[string]int32),
226			numFailures:   0,
227			remainingRPCs: in.GetNumRpcs(),
228			chanHosts:     make(chan *rpcInfo),
229		}
230		watchers[watcherKey] = watcher
231	}
232	mu.Unlock()
233
234	ctx, cancel := context.WithTimeout(ctx, time.Duration(in.GetTimeoutSec())*time.Second)
235	defer cancel()
236
237	defer func() {
238		mu.Lock()
239		delete(watchers, watcherKey)
240		mu.Unlock()
241	}()
242
243	// Wait until the requested RPCs have all been recorded or timeout occurs.
244	for {
245		select {
246		case info := <-watcher.chanHosts:
247			if info != nil {
248				watcher.rpcsByPeer[info.hostname]++
249
250				rpcsByPeerForType := watcher.rpcsByType[info.typ]
251				if rpcsByPeerForType == nil {
252					rpcsByPeerForType = make(map[string]int32)
253					watcher.rpcsByType[info.typ] = rpcsByPeerForType
254				}
255				rpcsByPeerForType[info.hostname]++
256			} else {
257				watcher.numFailures++
258			}
259			watcher.remainingRPCs--
260			if watcher.remainingRPCs == 0 {
261				return watcher.buildResp(), nil
262			}
263		case <-ctx.Done():
264			logger.Info("Timed out, returning partial stats")
265			return watcher.buildResp(), nil
266		}
267	}
268}
269
270func (s *statsService) GetClientAccumulatedStats(ctx context.Context, in *testpb.LoadBalancerAccumulatedStatsRequest) (*testpb.LoadBalancerAccumulatedStatsResponse, error) {
271	return accStats.buildResp(), nil
272}
273
// configureService implements the XdsUpdateClientConfigureService, whose
// Configure method lets the test driver change which RPCs are sent and with
// what metadata and timeout. The embedded generated Unimplemented server
// supplies any service method not defined in this file.
type configureService struct {
	testgrpc.UnimplementedXdsUpdateClientConfigureServiceServer
}
277
278func (s *configureService) Configure(ctx context.Context, in *testpb.ClientConfigureRequest) (*testpb.ClientConfigureResponse, error) {
279	rpcsToMD := make(map[testpb.ClientConfigureRequest_RpcType][]string)
280	for _, typ := range in.GetTypes() {
281		rpcsToMD[typ] = nil
282	}
283	for _, md := range in.GetMetadata() {
284		typ := md.GetType()
285		strs, ok := rpcsToMD[typ]
286		if !ok {
287			continue
288		}
289		rpcsToMD[typ] = append(strs, md.GetKey(), md.GetValue())
290	}
291	cfgs := make([]*rpcConfig, 0, len(rpcsToMD))
292	for typ, md := range rpcsToMD {
293		var rpcType string
294		switch typ {
295		case testpb.ClientConfigureRequest_UNARY_CALL:
296			rpcType = unaryCall
297		case testpb.ClientConfigureRequest_EMPTY_CALL:
298			rpcType = emptyCall
299		default:
300			return nil, fmt.Errorf("unsupported RPC type: %v", typ)
301		}
302		cfgs = append(cfgs, &rpcConfig{
303			typ:     rpcType,
304			md:      metadata.Pairs(md...),
305			timeout: in.GetTimeoutSec(),
306		})
307	}
308	rpcCfgs.Store(cfgs)
309	return &testpb.ClientConfigureResponse{}, nil
310}
311
// Internal names of the RPC types this client can send. They are also the
// spellings accepted by the -rpc and -metadata flags.
const (
	emptyCall string = "EmptyCall"
	unaryCall string = "UnaryCall"
)
316
317func parseRPCTypes(rpcStr string) (ret []string) {
318	if len(rpcStr) == 0 {
319		return []string{unaryCall}
320	}
321
322	rpcs := strings.Split(rpcStr, ",")
323	for _, r := range rpcs {
324		switch r {
325		case unaryCall, emptyCall:
326			ret = append(ret, r)
327		default:
328			flag.PrintDefaults()
329			log.Fatalf("unsupported RPC type: %v", r)
330		}
331	}
332	return
333}
334
// rpcConfig describes one kind of RPC that sendRPCs issues on every tick.
type rpcConfig struct {
	typ     string      // internal RPC type name: unaryCall or emptyCall
	md      metadata.MD // outgoing metadata attached to each RPC when non-empty
	timeout int32       // per-RPC deadline in seconds; 0 means use -rpc_timeout
}
340
341// parseRPCMetadata turns EmptyCall:key1:value1 into
342//   {typ: emptyCall, md: {key1:value1}}.
343func parseRPCMetadata(rpcMetadataStr string, rpcs []string) []*rpcConfig {
344	rpcMetadataSplit := strings.Split(rpcMetadataStr, ",")
345	rpcsToMD := make(map[string][]string)
346	for _, rm := range rpcMetadataSplit {
347		rmSplit := strings.Split(rm, ":")
348		if len(rmSplit)%2 != 1 {
349			log.Fatalf("invalid metadata config %v, want EmptyCall:key1:value1", rm)
350		}
351		rpcsToMD[rmSplit[0]] = append(rpcsToMD[rmSplit[0]], rmSplit[1:]...)
352	}
353	ret := make([]*rpcConfig, 0, len(rpcs))
354	for _, rpcT := range rpcs {
355		rpcC := &rpcConfig{
356			typ: rpcT,
357		}
358		if md := rpcsToMD[string(rpcT)]; len(md) > 0 {
359			rpcC.md = metadata.Pairs(md...)
360		}
361		ret = append(ret, rpcC)
362	}
363	return ret
364}
365
366func main() {
367	flag.Parse()
368	rpcCfgs.Store(parseRPCMetadata(*rpcMetadata, parseRPCTypes(*rpc)))
369
370	lis, err := net.Listen("tcp", fmt.Sprintf(":%d", *statsPort))
371	if err != nil {
372		logger.Fatalf("failed to listen: %v", err)
373	}
374	s := grpc.NewServer()
375	defer s.Stop()
376	testgrpc.RegisterLoadBalancerStatsServiceServer(s, &statsService{})
377	testgrpc.RegisterXdsUpdateClientConfigureServiceServer(s, &configureService{})
378	reflection.Register(s)
379	cleanup, err := admin.Register(s)
380	if err != nil {
381		logger.Fatalf("Failed to register admin: %v", err)
382	}
383	defer cleanup()
384	go s.Serve(lis)
385
386	creds := insecure.NewCredentials()
387	if *secureMode {
388		var err error
389		creds, err = xds.NewClientCredentials(xds.ClientOptions{FallbackCreds: insecure.NewCredentials()})
390		if err != nil {
391			logger.Fatalf("Failed to create xDS credentials: %v", err)
392		}
393	}
394
395	clients := make([]testgrpc.TestServiceClient, *numChannels)
396	for i := 0; i < *numChannels; i++ {
397		conn, err := grpc.Dial(*server, grpc.WithTransportCredentials(creds))
398		if err != nil {
399			logger.Fatalf("Fail to dial: %v", err)
400		}
401		defer conn.Close()
402		clients[i] = testgrpc.NewTestServiceClient(conn)
403	}
404	ticker := time.NewTicker(time.Second / time.Duration(*qps**numChannels))
405	defer ticker.Stop()
406	sendRPCs(clients, ticker)
407}
408
409func makeOneRPC(c testgrpc.TestServiceClient, cfg *rpcConfig) (*peer.Peer, *rpcInfo, error) {
410	timeout := *rpcTimeout
411	if cfg.timeout != 0 {
412		timeout = time.Duration(cfg.timeout) * time.Second
413	}
414	ctx, cancel := context.WithTimeout(context.Background(), timeout)
415	defer cancel()
416
417	if len(cfg.md) != 0 {
418		ctx = metadata.NewOutgoingContext(ctx, cfg.md)
419	}
420	info := rpcInfo{typ: cfg.typ}
421
422	var (
423		p      peer.Peer
424		header metadata.MD
425		err    error
426	)
427	accStats.startRPC(cfg.typ)
428	switch cfg.typ {
429	case unaryCall:
430		var resp *testpb.SimpleResponse
431		resp, err = c.UnaryCall(ctx, &testpb.SimpleRequest{FillServerId: true}, grpc.Peer(&p), grpc.Header(&header))
432		// For UnaryCall, also read hostname from response, in case the server
433		// isn't updated to send headers.
434		if resp != nil {
435			info.hostname = resp.Hostname
436		}
437	case emptyCall:
438		_, err = c.EmptyCall(ctx, &testpb.Empty{}, grpc.Peer(&p), grpc.Header(&header))
439	}
440	accStats.finishRPC(cfg.typ, err)
441	if err != nil {
442		return nil, nil, err
443	}
444
445	hosts := header["hostname"]
446	if len(hosts) > 0 {
447		info.hostname = hosts[0]
448	}
449	return &p, &info, err
450}
451
452func sendRPCs(clients []testgrpc.TestServiceClient, ticker *time.Ticker) {
453	var i int
454	for range ticker.C {
455		// Get and increment request ID, and save a list of watchers that are
456		// interested in this RPC.
457		mu.Lock()
458		savedRequestID := currentRequestID
459		currentRequestID++
460		savedWatchers := []*statsWatcher{}
461		for key, value := range watchers {
462			if key.startID <= savedRequestID && savedRequestID < key.endID {
463				savedWatchers = append(savedWatchers, value)
464			}
465		}
466		mu.Unlock()
467
468		// Get the RPC metadata configurations from the Configure RPC.
469		cfgs := rpcCfgs.Load().([]*rpcConfig)
470
471		c := clients[i]
472		for _, cfg := range cfgs {
473			go func(cfg *rpcConfig) {
474				p, info, err := makeOneRPC(c, cfg)
475
476				for _, watcher := range savedWatchers {
477					// This sends an empty string if the RPC failed.
478					watcher.chanHosts <- info
479				}
480				if err != nil && *failOnFailedRPC && hasRPCSucceeded() {
481					logger.Fatalf("RPC failed: %v", err)
482				}
483				if err == nil {
484					setRPCSucceeded()
485				}
486				if *printResponse {
487					if err == nil {
488						if cfg.typ == unaryCall {
489							// Need to keep this format, because some tests are
490							// relying on stdout.
491							fmt.Printf("Greeting: Hello world, this is %s, from %v\n", info.hostname, p.Addr)
492						} else {
493							fmt.Printf("RPC %q, from host %s, addr %v\n", cfg.typ, info.hostname, p.Addr)
494						}
495					} else {
496						fmt.Printf("RPC %q, failed with %v\n", cfg.typ, err)
497					}
498				}
499			}(cfg)
500		}
501		i = (i + 1) % len(clients)
502	}
503}
504