1/*
2 *
3 * Copyright 2018 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19// Package channelz defines APIs for enabling channelz service, entry
20// registration/deletion, and accessing channelz data. It also defines channelz
21// metric struct formats.
22//
23// All APIs in this package are experimental.
24package channelz
25
26import (
27	"fmt"
28	"sort"
29	"sync"
30	"sync/atomic"
31	"time"
32
33	"google.golang.org/grpc/grpclog"
34)
35
36const (
37	defaultMaxTraceEntry int32 = 30
38)
39
40var (
41	db    dbWrapper
42	idGen idGenerator
43	// EntryPerPage defines the number of channelz entries to be shown on a web page.
44	EntryPerPage  = int64(50)
45	curState      int32
46	maxTraceEntry = defaultMaxTraceEntry
47)
48
49// TurnOn turns on channelz data collection.
50func TurnOn() {
51	if !IsOn() {
52		NewChannelzStorage()
53		atomic.StoreInt32(&curState, 1)
54	}
55}
56
57// IsOn returns whether channelz data collection is on.
58func IsOn() bool {
59	return atomic.CompareAndSwapInt32(&curState, 1, 1)
60}
61
62// SetMaxTraceEntry sets maximum number of trace entry per entity (i.e. channel/subchannel).
63// Setting it to 0 will disable channel tracing.
64func SetMaxTraceEntry(i int32) {
65	atomic.StoreInt32(&maxTraceEntry, i)
66}
67
68// ResetMaxTraceEntryToDefault resets the maximum number of trace entry per entity to default.
69func ResetMaxTraceEntryToDefault() {
70	atomic.StoreInt32(&maxTraceEntry, defaultMaxTraceEntry)
71}
72
73func getMaxTraceEntry() int {
74	i := atomic.LoadInt32(&maxTraceEntry)
75	return int(i)
76}
77
78// dbWarpper wraps around a reference to internal channelz data storage, and
79// provide synchronized functionality to set and get the reference.
80type dbWrapper struct {
81	mu sync.RWMutex
82	DB *channelMap
83}
84
85func (d *dbWrapper) set(db *channelMap) {
86	d.mu.Lock()
87	d.DB = db
88	d.mu.Unlock()
89}
90
91func (d *dbWrapper) get() *channelMap {
92	d.mu.RLock()
93	defer d.mu.RUnlock()
94	return d.DB
95}
96
97// NewChannelzStorage initializes channelz data storage and id generator.
98//
99// This function returns a cleanup function to wait for all channelz state to be reset by the
100// grpc goroutines when those entities get closed. By using this cleanup function, we make sure tests
101// don't mess up each other, i.e. lingering goroutine from previous test doing entity removal happen
102// to remove some entity just register by the new test, since the id space is the same.
103//
104// Note: This function is exported for testing purpose only. User should not call
105// it in most cases.
106func NewChannelzStorage() (cleanup func() error) {
107	db.set(&channelMap{
108		topLevelChannels: make(map[int64]struct{}),
109		channels:         make(map[int64]*channel),
110		listenSockets:    make(map[int64]*listenSocket),
111		normalSockets:    make(map[int64]*normalSocket),
112		servers:          make(map[int64]*server),
113		subChannels:      make(map[int64]*subChannel),
114	})
115	idGen.reset()
116	return func() error {
117		var err error
118		cm := db.get()
119		if cm == nil {
120			return nil
121		}
122		for i := 0; i < 1000; i++ {
123			cm.mu.Lock()
124			if len(cm.topLevelChannels) == 0 && len(cm.servers) == 0 && len(cm.channels) == 0 && len(cm.subChannels) == 0 && len(cm.listenSockets) == 0 && len(cm.normalSockets) == 0 {
125				cm.mu.Unlock()
126				// all things stored in the channelz map have been cleared.
127				return nil
128			}
129			cm.mu.Unlock()
130			time.Sleep(10 * time.Millisecond)
131		}
132
133		cm.mu.Lock()
134		err = fmt.Errorf("after 10s the channelz map has not been cleaned up yet, topchannels: %d, servers: %d, channels: %d, subchannels: %d, listen sockets: %d, normal sockets: %d", len(cm.topLevelChannels), len(cm.servers), len(cm.channels), len(cm.subChannels), len(cm.listenSockets), len(cm.normalSockets))
135		cm.mu.Unlock()
136		return err
137	}
138}
139
140// GetTopChannels returns a slice of top channel's ChannelMetric, along with a
141// boolean indicating whether there's more top channels to be queried for.
142//
143// The arg id specifies that only top channel with id at or above it will be included
144// in the result. The returned slice is up to a length of the arg maxResults or
145// EntryPerPage if maxResults is zero, and is sorted in ascending id order.
146func GetTopChannels(id int64, maxResults int64) ([]*ChannelMetric, bool) {
147	return db.get().GetTopChannels(id, maxResults)
148}
149
150// GetServers returns a slice of server's ServerMetric, along with a
151// boolean indicating whether there's more servers to be queried for.
152//
153// The arg id specifies that only server with id at or above it will be included
154// in the result. The returned slice is up to a length of the arg maxResults or
155// EntryPerPage if maxResults is zero, and is sorted in ascending id order.
156func GetServers(id int64, maxResults int64) ([]*ServerMetric, bool) {
157	return db.get().GetServers(id, maxResults)
158}
159
160// GetServerSockets returns a slice of server's (identified by id) normal socket's
161// SocketMetric, along with a boolean indicating whether there's more sockets to
162// be queried for.
163//
164// The arg startID specifies that only sockets with id at or above it will be
165// included in the result. The returned slice is up to a length of the arg maxResults
166// or EntryPerPage if maxResults is zero, and is sorted in ascending id order.
167func GetServerSockets(id int64, startID int64, maxResults int64) ([]*SocketMetric, bool) {
168	return db.get().GetServerSockets(id, startID, maxResults)
169}
170
171// GetChannel returns the ChannelMetric for the channel (identified by id).
172func GetChannel(id int64) *ChannelMetric {
173	return db.get().GetChannel(id)
174}
175
176// GetSubChannel returns the SubChannelMetric for the subchannel (identified by id).
177func GetSubChannel(id int64) *SubChannelMetric {
178	return db.get().GetSubChannel(id)
179}
180
181// GetSocket returns the SocketInternalMetric for the socket (identified by id).
182func GetSocket(id int64) *SocketMetric {
183	return db.get().GetSocket(id)
184}
185
186// GetServer returns the ServerMetric for the server (identified by id).
187func GetServer(id int64) *ServerMetric {
188	return db.get().GetServer(id)
189}
190
191// RegisterChannel registers the given channel c in channelz database with ref
192// as its reference name, and add it to the child list of its parent (identified
193// by pid). pid = 0 means no parent. It returns the unique channelz tracking id
194// assigned to this channel.
195func RegisterChannel(c Channel, pid int64, ref string) int64 {
196	id := idGen.genID()
197	cn := &channel{
198		refName:     ref,
199		c:           c,
200		subChans:    make(map[int64]string),
201		nestedChans: make(map[int64]string),
202		id:          id,
203		pid:         pid,
204		trace:       &channelTrace{createdTime: time.Now(), events: make([]*TraceEvent, 0, getMaxTraceEntry())},
205	}
206	if pid == 0 {
207		db.get().addChannel(id, cn, true, pid, ref)
208	} else {
209		db.get().addChannel(id, cn, false, pid, ref)
210	}
211	return id
212}
213
214// RegisterSubChannel registers the given channel c in channelz database with ref
215// as its reference name, and add it to the child list of its parent (identified
216// by pid). It returns the unique channelz tracking id assigned to this subchannel.
217func RegisterSubChannel(c Channel, pid int64, ref string) int64 {
218	if pid == 0 {
219		grpclog.Error("a SubChannel's parent id cannot be 0")
220		return 0
221	}
222	id := idGen.genID()
223	sc := &subChannel{
224		refName: ref,
225		c:       c,
226		sockets: make(map[int64]string),
227		id:      id,
228		pid:     pid,
229		trace:   &channelTrace{createdTime: time.Now(), events: make([]*TraceEvent, 0, getMaxTraceEntry())},
230	}
231	db.get().addSubChannel(id, sc, pid, ref)
232	return id
233}
234
235// RegisterServer registers the given server s in channelz database. It returns
236// the unique channelz tracking id assigned to this server.
237func RegisterServer(s Server, ref string) int64 {
238	id := idGen.genID()
239	svr := &server{
240		refName:       ref,
241		s:             s,
242		sockets:       make(map[int64]string),
243		listenSockets: make(map[int64]string),
244		id:            id,
245	}
246	db.get().addServer(id, svr)
247	return id
248}
249
250// RegisterListenSocket registers the given listen socket s in channelz database
251// with ref as its reference name, and add it to the child list of its parent
252// (identified by pid). It returns the unique channelz tracking id assigned to
253// this listen socket.
254func RegisterListenSocket(s Socket, pid int64, ref string) int64 {
255	if pid == 0 {
256		grpclog.Error("a ListenSocket's parent id cannot be 0")
257		return 0
258	}
259	id := idGen.genID()
260	ls := &listenSocket{refName: ref, s: s, id: id, pid: pid}
261	db.get().addListenSocket(id, ls, pid, ref)
262	return id
263}
264
265// RegisterNormalSocket registers the given normal socket s in channelz database
266// with ref as its reference name, and add it to the child list of its parent
267// (identified by pid). It returns the unique channelz tracking id assigned to
268// this normal socket.
269func RegisterNormalSocket(s Socket, pid int64, ref string) int64 {
270	if pid == 0 {
271		grpclog.Error("a NormalSocket's parent id cannot be 0")
272		return 0
273	}
274	id := idGen.genID()
275	ns := &normalSocket{refName: ref, s: s, id: id, pid: pid}
276	db.get().addNormalSocket(id, ns, pid, ref)
277	return id
278}
279
280// RemoveEntry removes an entry with unique channelz trakcing id to be id from
281// channelz database.
282func RemoveEntry(id int64) {
283	db.get().removeEntry(id)
284}
285
286// TraceEventDesc is what the caller of AddTraceEvent should provide to describe the event to be added
287// to the channel trace.
288// The Parent field is optional. It is used for event that will be recorded in the entity's parent
289// trace also.
290type TraceEventDesc struct {
291	Desc     string
292	Severity Severity
293	Parent   *TraceEventDesc
294}
295
296// AddTraceEvent adds trace related to the entity with specified id, using the provided TraceEventDesc.
297func AddTraceEvent(id int64, desc *TraceEventDesc) {
298	if getMaxTraceEntry() == 0 {
299		return
300	}
301	db.get().traceEvent(id, desc)
302}
303
304// channelMap is the storage data structure for channelz.
305// Methods of channelMap can be divided in two two categories with respect to locking.
306// 1. Methods acquire the global lock.
307// 2. Methods that can only be called when global lock is held.
308// A second type of method need always to be called inside a first type of method.
309type channelMap struct {
310	mu               sync.RWMutex
311	topLevelChannels map[int64]struct{}
312	servers          map[int64]*server
313	channels         map[int64]*channel
314	subChannels      map[int64]*subChannel
315	listenSockets    map[int64]*listenSocket
316	normalSockets    map[int64]*normalSocket
317}
318
319func (c *channelMap) addServer(id int64, s *server) {
320	c.mu.Lock()
321	s.cm = c
322	c.servers[id] = s
323	c.mu.Unlock()
324}
325
326func (c *channelMap) addChannel(id int64, cn *channel, isTopChannel bool, pid int64, ref string) {
327	c.mu.Lock()
328	cn.cm = c
329	cn.trace.cm = c
330	c.channels[id] = cn
331	if isTopChannel {
332		c.topLevelChannels[id] = struct{}{}
333	} else {
334		c.findEntry(pid).addChild(id, cn)
335	}
336	c.mu.Unlock()
337}
338
339func (c *channelMap) addSubChannel(id int64, sc *subChannel, pid int64, ref string) {
340	c.mu.Lock()
341	sc.cm = c
342	sc.trace.cm = c
343	c.subChannels[id] = sc
344	c.findEntry(pid).addChild(id, sc)
345	c.mu.Unlock()
346}
347
348func (c *channelMap) addListenSocket(id int64, ls *listenSocket, pid int64, ref string) {
349	c.mu.Lock()
350	ls.cm = c
351	c.listenSockets[id] = ls
352	c.findEntry(pid).addChild(id, ls)
353	c.mu.Unlock()
354}
355
356func (c *channelMap) addNormalSocket(id int64, ns *normalSocket, pid int64, ref string) {
357	c.mu.Lock()
358	ns.cm = c
359	c.normalSockets[id] = ns
360	c.findEntry(pid).addChild(id, ns)
361	c.mu.Unlock()
362}
363
364// removeEntry triggers the removal of an entry, which may not indeed delete the entry, if it has to
365// wait on the deletion of its children and until no other entity's channel trace references it.
366// It may lead to a chain of entry deletion. For example, deleting the last socket of a gracefully
367// shutting down server will lead to the server being also deleted.
368func (c *channelMap) removeEntry(id int64) {
369	c.mu.Lock()
370	c.findEntry(id).triggerDelete()
371	c.mu.Unlock()
372}
373
374// c.mu must be held by the caller
375func (c *channelMap) decrTraceRefCount(id int64) {
376	e := c.findEntry(id)
377	if v, ok := e.(tracedChannel); ok {
378		v.decrTraceRefCount()
379		e.deleteSelfIfReady()
380	}
381}
382
383// c.mu must be held by the caller.
384func (c *channelMap) findEntry(id int64) entry {
385	var v entry
386	var ok bool
387	if v, ok = c.channels[id]; ok {
388		return v
389	}
390	if v, ok = c.subChannels[id]; ok {
391		return v
392	}
393	if v, ok = c.servers[id]; ok {
394		return v
395	}
396	if v, ok = c.listenSockets[id]; ok {
397		return v
398	}
399	if v, ok = c.normalSockets[id]; ok {
400		return v
401	}
402	return &dummyEntry{idNotFound: id}
403}
404
405// c.mu must be held by the caller
406// deleteEntry simply deletes an entry from the channelMap. Before calling this
407// method, caller must check this entry is ready to be deleted, i.e removeEntry()
408// has been called on it, and no children still exist.
409// Conditionals are ordered by the expected frequency of deletion of each entity
410// type, in order to optimize performance.
411func (c *channelMap) deleteEntry(id int64) {
412	var ok bool
413	if _, ok = c.normalSockets[id]; ok {
414		delete(c.normalSockets, id)
415		return
416	}
417	if _, ok = c.subChannels[id]; ok {
418		delete(c.subChannels, id)
419		return
420	}
421	if _, ok = c.channels[id]; ok {
422		delete(c.channels, id)
423		delete(c.topLevelChannels, id)
424		return
425	}
426	if _, ok = c.listenSockets[id]; ok {
427		delete(c.listenSockets, id)
428		return
429	}
430	if _, ok = c.servers[id]; ok {
431		delete(c.servers, id)
432		return
433	}
434}
435
436func (c *channelMap) traceEvent(id int64, desc *TraceEventDesc) {
437	c.mu.Lock()
438	child := c.findEntry(id)
439	childTC, ok := child.(tracedChannel)
440	if !ok {
441		c.mu.Unlock()
442		return
443	}
444	childTC.getChannelTrace().append(&TraceEvent{Desc: desc.Desc, Severity: desc.Severity, Timestamp: time.Now()})
445	if desc.Parent != nil {
446		parent := c.findEntry(child.getParentID())
447		var chanType RefChannelType
448		switch child.(type) {
449		case *channel:
450			chanType = RefChannel
451		case *subChannel:
452			chanType = RefSubChannel
453		}
454		if parentTC, ok := parent.(tracedChannel); ok {
455			parentTC.getChannelTrace().append(&TraceEvent{
456				Desc:      desc.Parent.Desc,
457				Severity:  desc.Parent.Severity,
458				Timestamp: time.Now(),
459				RefID:     id,
460				RefName:   childTC.getRefName(),
461				RefType:   chanType,
462			})
463			childTC.incrTraceRefCount()
464		}
465	}
466	c.mu.Unlock()
467}
468
// int64Slice attaches the methods of sort.Interface to []int64, ordering the
// elements in increasing order.
type int64Slice []int64

// Len returns the number of elements.
func (s int64Slice) Len() int { return len(s) }

// Swap exchanges the elements at indices i and j.
func (s int64Slice) Swap(i, j int) {
	s[j], s[i] = s[i], s[j]
}

// Less reports whether the element at i is smaller than the element at j.
func (s int64Slice) Less(i, j int) bool {
	return s[j] > s[i]
}
474
// copyMap returns a shallow copy of m. The result is always a non-nil map, even
// when m is nil or empty, so callers may write to it freely. The result is
// pre-sized to len(m) to avoid rehashing while it is filled.
func copyMap(m map[int64]string) map[int64]string {
	n := make(map[int64]string, len(m))
	for k, v := range m {
		n[k] = v
	}
	return n
}
482
// min returns the smaller of a and b.
func min(a, b int64) int64 {
	if b < a {
		return b
	}
	return a
}
489
490func (c *channelMap) GetTopChannels(id int64, maxResults int64) ([]*ChannelMetric, bool) {
491	if maxResults <= 0 {
492		maxResults = EntryPerPage
493	}
494	c.mu.RLock()
495	l := int64(len(c.topLevelChannels))
496	ids := make([]int64, 0, l)
497	cns := make([]*channel, 0, min(l, maxResults))
498
499	for k := range c.topLevelChannels {
500		ids = append(ids, k)
501	}
502	sort.Sort(int64Slice(ids))
503	idx := sort.Search(len(ids), func(i int) bool { return ids[i] >= id })
504	count := int64(0)
505	var end bool
506	var t []*ChannelMetric
507	for i, v := range ids[idx:] {
508		if count == maxResults {
509			break
510		}
511		if cn, ok := c.channels[v]; ok {
512			cns = append(cns, cn)
513			t = append(t, &ChannelMetric{
514				NestedChans: copyMap(cn.nestedChans),
515				SubChans:    copyMap(cn.subChans),
516			})
517			count++
518		}
519		if i == len(ids[idx:])-1 {
520			end = true
521			break
522		}
523	}
524	c.mu.RUnlock()
525	if count == 0 {
526		end = true
527	}
528
529	for i, cn := range cns {
530		t[i].ChannelData = cn.c.ChannelzMetric()
531		t[i].ID = cn.id
532		t[i].RefName = cn.refName
533		t[i].Trace = cn.trace.dumpData()
534	}
535	return t, end
536}
537
538func (c *channelMap) GetServers(id, maxResults int64) ([]*ServerMetric, bool) {
539	if maxResults <= 0 {
540		maxResults = EntryPerPage
541	}
542	c.mu.RLock()
543	l := int64(len(c.servers))
544	ids := make([]int64, 0, l)
545	ss := make([]*server, 0, min(l, maxResults))
546	for k := range c.servers {
547		ids = append(ids, k)
548	}
549	sort.Sort(int64Slice(ids))
550	idx := sort.Search(len(ids), func(i int) bool { return ids[i] >= id })
551	count := int64(0)
552	var end bool
553	var s []*ServerMetric
554	for i, v := range ids[idx:] {
555		if count == maxResults {
556			break
557		}
558		if svr, ok := c.servers[v]; ok {
559			ss = append(ss, svr)
560			s = append(s, &ServerMetric{
561				ListenSockets: copyMap(svr.listenSockets),
562			})
563			count++
564		}
565		if i == len(ids[idx:])-1 {
566			end = true
567			break
568		}
569	}
570	c.mu.RUnlock()
571	if count == 0 {
572		end = true
573	}
574
575	for i, svr := range ss {
576		s[i].ServerData = svr.s.ChannelzMetric()
577		s[i].ID = svr.id
578		s[i].RefName = svr.refName
579	}
580	return s, end
581}
582
583func (c *channelMap) GetServerSockets(id int64, startID int64, maxResults int64) ([]*SocketMetric, bool) {
584	if maxResults <= 0 {
585		maxResults = EntryPerPage
586	}
587	var svr *server
588	var ok bool
589	c.mu.RLock()
590	if svr, ok = c.servers[id]; !ok {
591		// server with id doesn't exist.
592		c.mu.RUnlock()
593		return nil, true
594	}
595	svrskts := svr.sockets
596	l := int64(len(svrskts))
597	ids := make([]int64, 0, l)
598	sks := make([]*normalSocket, 0, min(l, maxResults))
599	for k := range svrskts {
600		ids = append(ids, k)
601	}
602	sort.Sort(int64Slice(ids))
603	idx := sort.Search(len(ids), func(i int) bool { return ids[i] >= startID })
604	count := int64(0)
605	var end bool
606	for i, v := range ids[idx:] {
607		if count == maxResults {
608			break
609		}
610		if ns, ok := c.normalSockets[v]; ok {
611			sks = append(sks, ns)
612			count++
613		}
614		if i == len(ids[idx:])-1 {
615			end = true
616			break
617		}
618	}
619	c.mu.RUnlock()
620	if count == 0 {
621		end = true
622	}
623	var s []*SocketMetric
624	for _, ns := range sks {
625		sm := &SocketMetric{}
626		sm.SocketData = ns.s.ChannelzMetric()
627		sm.ID = ns.id
628		sm.RefName = ns.refName
629		s = append(s, sm)
630	}
631	return s, end
632}
633
634func (c *channelMap) GetChannel(id int64) *ChannelMetric {
635	cm := &ChannelMetric{}
636	var cn *channel
637	var ok bool
638	c.mu.RLock()
639	if cn, ok = c.channels[id]; !ok {
640		// channel with id doesn't exist.
641		c.mu.RUnlock()
642		return nil
643	}
644	cm.NestedChans = copyMap(cn.nestedChans)
645	cm.SubChans = copyMap(cn.subChans)
646	// cn.c can be set to &dummyChannel{} when deleteSelfFromMap is called. Save a copy of cn.c when
647	// holding the lock to prevent potential data race.
648	chanCopy := cn.c
649	c.mu.RUnlock()
650	cm.ChannelData = chanCopy.ChannelzMetric()
651	cm.ID = cn.id
652	cm.RefName = cn.refName
653	cm.Trace = cn.trace.dumpData()
654	return cm
655}
656
657func (c *channelMap) GetSubChannel(id int64) *SubChannelMetric {
658	cm := &SubChannelMetric{}
659	var sc *subChannel
660	var ok bool
661	c.mu.RLock()
662	if sc, ok = c.subChannels[id]; !ok {
663		// subchannel with id doesn't exist.
664		c.mu.RUnlock()
665		return nil
666	}
667	cm.Sockets = copyMap(sc.sockets)
668	// sc.c can be set to &dummyChannel{} when deleteSelfFromMap is called. Save a copy of sc.c when
669	// holding the lock to prevent potential data race.
670	chanCopy := sc.c
671	c.mu.RUnlock()
672	cm.ChannelData = chanCopy.ChannelzMetric()
673	cm.ID = sc.id
674	cm.RefName = sc.refName
675	cm.Trace = sc.trace.dumpData()
676	return cm
677}
678
679func (c *channelMap) GetSocket(id int64) *SocketMetric {
680	sm := &SocketMetric{}
681	c.mu.RLock()
682	if ls, ok := c.listenSockets[id]; ok {
683		c.mu.RUnlock()
684		sm.SocketData = ls.s.ChannelzMetric()
685		sm.ID = ls.id
686		sm.RefName = ls.refName
687		return sm
688	}
689	if ns, ok := c.normalSockets[id]; ok {
690		c.mu.RUnlock()
691		sm.SocketData = ns.s.ChannelzMetric()
692		sm.ID = ns.id
693		sm.RefName = ns.refName
694		return sm
695	}
696	c.mu.RUnlock()
697	return nil
698}
699
700func (c *channelMap) GetServer(id int64) *ServerMetric {
701	sm := &ServerMetric{}
702	var svr *server
703	var ok bool
704	c.mu.RLock()
705	if svr, ok = c.servers[id]; !ok {
706		c.mu.RUnlock()
707		return nil
708	}
709	sm.ListenSockets = copyMap(svr.listenSockets)
710	c.mu.RUnlock()
711	sm.ID = svr.id
712	sm.RefName = svr.refName
713	sm.ServerData = svr.s.ChannelzMetric()
714	return sm
715}
716
// idGenerator hands out channelz tracking ids using atomic operations. Each
// genID call returns the previous value plus one, so the zero value (or a
// freshly reset generator) starts at 1.
type idGenerator struct {
	id int64
}

// reset rewinds the generator so that the next genID call returns 1.
func (g *idGenerator) reset() {
	atomic.StoreInt64(&g.id, 0)
}

// genID atomically allocates and returns the next id.
func (g *idGenerator) genID() int64 {
	next := atomic.AddInt64(&g.id, 1)
	return next
}
728