1package pubsub
2
3import (
4	"math/rand"
5	"sync"
6	"time"
7
8	"github.com/libp2p/go-libp2p-core/peer"
9	"github.com/libp2p/go-libp2p-core/protocol"
10)
11
12// gossipTracer is an internal tracer that tracks IWANT requests in order to penalize
13// peers who don't follow up on IWANT requests after an IHAVE advertisement.
14// The tracking of promises is probabilistic to avoid using too much memory.
15type gossipTracer struct {
16	sync.Mutex
17
18	msgID MsgIdFunction
19
20	// promises for messages by message ID; for each message tracked, we track the promise
21	// expiration time for each peer.
22	promises map[string]map[peer.ID]time.Time
23	// promises for each peer; for each peer, we track the promised message IDs.
24	// this index allows us to quickly void promises when a peer is throttled.
25	peerPromises map[peer.ID]map[string]struct{}
26}
27
28func newGossipTracer() *gossipTracer {
29	return &gossipTracer{
30		msgID:        DefaultMsgIdFn,
31		promises:     make(map[string]map[peer.ID]time.Time),
32		peerPromises: make(map[peer.ID]map[string]struct{}),
33	}
34}
35
36func (gt *gossipTracer) Start(gs *GossipSubRouter) {
37	if gt == nil {
38		return
39	}
40
41	gt.msgID = gs.p.msgID
42}
43
44// track a promise to deliver a message from a list of msgIDs we are requesting
45func (gt *gossipTracer) AddPromise(p peer.ID, msgIDs []string) {
46	if gt == nil {
47		return
48	}
49
50	idx := rand.Intn(len(msgIDs))
51	mid := msgIDs[idx]
52
53	gt.Lock()
54	defer gt.Unlock()
55
56	promises, ok := gt.promises[mid]
57	if !ok {
58		promises = make(map[peer.ID]time.Time)
59		gt.promises[mid] = promises
60	}
61
62	_, ok = promises[p]
63	if !ok {
64		promises[p] = time.Now().Add(GossipSubIWantFollowupTime)
65		peerPromises, ok := gt.peerPromises[p]
66		if !ok {
67			peerPromises = make(map[string]struct{})
68			gt.peerPromises[p] = peerPromises
69		}
70		peerPromises[mid] = struct{}{}
71	}
72}
73
74// returns the number of broken promises for each peer who didn't follow up
75// on an IWANT request.
76func (gt *gossipTracer) GetBrokenPromises() map[peer.ID]int {
77	if gt == nil {
78		return nil
79	}
80
81	gt.Lock()
82	defer gt.Unlock()
83
84	var res map[peer.ID]int
85	now := time.Now()
86
87	// find broken promises from peers
88	for mid, promises := range gt.promises {
89		for p, expire := range promises {
90			if expire.Before(now) {
91				if res == nil {
92					res = make(map[peer.ID]int)
93				}
94				res[p]++
95
96				delete(promises, p)
97
98				peerPromises := gt.peerPromises[p]
99				delete(peerPromises, mid)
100				if len(peerPromises) == 0 {
101					delete(gt.peerPromises, p)
102				}
103			}
104		}
105
106		if len(promises) == 0 {
107			delete(gt.promises, mid)
108		}
109	}
110
111	return res
112}
113
// Compile-time check that *gossipTracer satisfies the internalTracer interface.
var _ internalTracer = (*gossipTracer)(nil)
115
116func (gt *gossipTracer) fulfillPromise(msg *Message) {
117	mid := gt.msgID(msg.Message)
118
119	gt.Lock()
120	defer gt.Unlock()
121
122	delete(gt.promises, mid)
123}
124
125func (gt *gossipTracer) DeliverMessage(msg *Message) {
126	// someone delivered a message, fulfill promises for it
127	gt.fulfillPromise(msg)
128}
129
130func (gt *gossipTracer) RejectMessage(msg *Message, reason string) {
131	// A message got rejected, so we can fulfill promises and let the score penalty apply
132	// from invalid message delivery.
133	// We do take exception and apply promise penalty regardless in the following cases, where
134	// the peer delivered an obviously invalid message.
135	switch reason {
136	case rejectMissingSignature:
137		return
138	case rejectInvalidSignature:
139		return
140	}
141
142	gt.fulfillPromise(msg)
143}
144
145func (gt *gossipTracer) ValidateMessage(msg *Message) {
146	// we consider the promise fulfilled as soon as the message begins validation
147	// if it was a case of signature issue it would have been rejected immediately
148	// without triggering the Validate trace
149	gt.fulfillPromise(msg)
150}
151
152func (gt *gossipTracer) AddPeer(p peer.ID, proto protocol.ID) {}
153func (gt *gossipTracer) RemovePeer(p peer.ID)                 {}
154func (gt *gossipTracer) Join(topic string)                    {}
155func (gt *gossipTracer) Leave(topic string)                   {}
156func (gt *gossipTracer) Graft(p peer.ID, topic string)        {}
157func (gt *gossipTracer) Prune(p peer.ID, topic string)        {}
158func (gt *gossipTracer) DuplicateMessage(msg *Message)        {}
159
160func (gt *gossipTracer) ThrottlePeer(p peer.ID) {
161	gt.Lock()
162	defer gt.Unlock()
163
164	peerPromises, ok := gt.peerPromises[p]
165	if !ok {
166		return
167	}
168
169	for mid := range peerPromises {
170		promises := gt.promises[mid]
171		delete(promises, p)
172		if len(promises) == 0 {
173			delete(gt.promises, mid)
174		}
175	}
176
177	delete(gt.peerPromises, p)
178}
179