1package identify
2
3import (
4	"context"
5	"errors"
6	"fmt"
7	"sync"
8	"time"
9
10	"github.com/libp2p/go-libp2p-core/network"
11	"github.com/libp2p/go-libp2p-core/peer"
12	"github.com/libp2p/go-libp2p-core/protocol"
13	"github.com/libp2p/go-libp2p-core/record"
14
15	pb "github.com/libp2p/go-libp2p/p2p/protocol/identify/pb"
16
17	"github.com/libp2p/go-msgio/protoio"
18	ma "github.com/multiformats/go-multiaddr"
19)
20
var (
	// errProtocolNotSupported is returned by openStream when the peerstore
	// shows the remote peer supports none of the requested protocols.
	errProtocolNotSupported = errors.New("protocol not supported")
)
22
23type identifySnapshot struct {
24	protocols []string
25	addrs     []ma.Multiaddr
26	record    *record.Envelope
27}
28
29type peerHandler struct {
30	ids     *IDService
31	started bool
32
33	ctx    context.Context
34	cancel context.CancelFunc
35
36	pid peer.ID
37
38	snapshotMu sync.RWMutex
39	snapshot   *identifySnapshot
40
41	pushCh  chan struct{}
42	deltaCh chan struct{}
43}
44
45func newPeerHandler(pid peer.ID, ids *IDService) *peerHandler {
46	ph := &peerHandler{
47		ids: ids,
48		pid: pid,
49
50		snapshot: ids.getSnapshot(),
51
52		pushCh:  make(chan struct{}, 1),
53		deltaCh: make(chan struct{}, 1),
54	}
55
56	return ph
57}
58
59// start starts a handler. This may only be called on a stopped handler, and must
60// not be called concurrently with start/stop.
61//
62// This may _not_ be called on a _canceled_ handler. I.e., a handler where the
63// passed in context expired.
64func (ph *peerHandler) start(ctx context.Context, onExit func()) {
65	if ph.cancel != nil {
66		// If this happens, we have a bug. It means we tried to start
67		// before we stopped.
68		panic("peer handler already running")
69	}
70
71	ctx, cancel := context.WithCancel(ctx)
72	ph.cancel = cancel
73
74	go ph.loop(ctx, onExit)
75}
76
77// stop stops a handler. This may not be called concurrently with any
78// other calls to stop/start.
79func (ph *peerHandler) stop() error {
80	if ph.cancel != nil {
81		ph.cancel()
82		ph.cancel = nil
83	}
84	return nil
85}
86
87// per peer loop for pushing updates
88func (ph *peerHandler) loop(ctx context.Context, onExit func()) {
89	defer onExit()
90
91	for {
92		select {
93		// our listen addresses have changed, send an IDPush.
94		case <-ph.pushCh:
95			if err := ph.sendPush(ctx); err != nil {
96				log.Warnw("failed to send Identify Push", "peer", ph.pid, "error", err)
97			}
98
99		case <-ph.deltaCh:
100			if err := ph.sendDelta(ctx); err != nil {
101				log.Warnw("failed to send Identify Delta", "peer", ph.pid, "error", err)
102			}
103
104		case <-ctx.Done():
105			return
106		}
107	}
108}
109
110func (ph *peerHandler) sendDelta(ctx context.Context) error {
111	// send a push if the peer does not support the Delta protocol.
112	if !ph.peerSupportsProtos(ctx, []string{IDDelta}) {
113		log.Debugw("will send push as peer does not support delta", "peer", ph.pid)
114		if err := ph.sendPush(ctx); err != nil {
115			return fmt.Errorf("failed to send push on delta message: %w", err)
116		}
117		return nil
118	}
119
120	// extract a delta message, updating the last state.
121	mes := ph.nextDelta()
122	if mes == nil || (len(mes.AddedProtocols) == 0 && len(mes.RmProtocols) == 0) {
123		return nil
124	}
125
126	ds, err := ph.openStream(ctx, []string{IDDelta})
127	if err != nil {
128		return fmt.Errorf("failed to open delta stream: %w", err)
129	}
130
131	defer ds.Close()
132
133	c := ds.Conn()
134	if err := protoio.NewDelimitedWriter(ds).WriteMsg(&pb.Identify{Delta: mes}); err != nil {
135		_ = ds.Reset()
136		return fmt.Errorf("failed to send delta message, %w", err)
137	}
138	log.Debugw("sent identify update", "protocol", ds.Protocol(), "peer", c.RemotePeer(),
139		"peer address", c.RemoteMultiaddr())
140
141	return nil
142}
143
144func (ph *peerHandler) sendPush(ctx context.Context) error {
145	dp, err := ph.openStream(ctx, []string{IDPush})
146	if err == errProtocolNotSupported {
147		log.Debugw("not sending push as peer does not support protocol", "peer", ph.pid)
148		return nil
149	}
150	if err != nil {
151		return fmt.Errorf("failed to open push stream: %w", err)
152	}
153	defer dp.Close()
154
155	snapshot := ph.ids.getSnapshot()
156	ph.snapshotMu.Lock()
157	ph.snapshot = snapshot
158	ph.snapshotMu.Unlock()
159	if err := ph.ids.writeChunkedIdentifyMsg(dp.Conn(), snapshot, dp); err != nil {
160		_ = dp.Reset()
161		return fmt.Errorf("failed to send push message: %w", err)
162	}
163
164	return nil
165}
166
167func (ph *peerHandler) openStream(ctx context.Context, protos []string) (network.Stream, error) {
168	// wait for the other peer to send us an Identify response on "all" connections we have with it
169	// so we can look at it's supported protocols and avoid a multistream-select roundtrip to negotiate the protocol
170	// if we know for a fact that it dosen't support the protocol.
171	conns := ph.ids.Host.Network().ConnsToPeer(ph.pid)
172	for _, c := range conns {
173		select {
174		case <-ph.ids.IdentifyWait(c):
175		case <-ctx.Done():
176			return nil, ctx.Err()
177		}
178	}
179
180	if !ph.peerSupportsProtos(ctx, protos) {
181		return nil, errProtocolNotSupported
182	}
183
184	// negotiate a stream without opening a new connection as we "should" already have a connection.
185	ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
186	defer cancel()
187	ctx = network.WithNoDial(ctx, "should already have connection")
188
189	// newstream will open a stream on the first protocol the remote peer supports from the among
190	// the list of protocols passed to it.
191	s, err := ph.ids.Host.NewStream(ctx, ph.pid, protocol.ConvertFromStrings(protos)...)
192	if err != nil {
193		return nil, err
194	}
195
196	return s, err
197}
198
199// returns true if the peer supports atleast one of the given protocols
200func (ph *peerHandler) peerSupportsProtos(ctx context.Context, protos []string) bool {
201	conns := ph.ids.Host.Network().ConnsToPeer(ph.pid)
202	for _, c := range conns {
203		select {
204		case <-ph.ids.IdentifyWait(c):
205		case <-ctx.Done():
206			return false
207		}
208	}
209
210	pstore := ph.ids.Host.Peerstore()
211
212	if sup, err := pstore.SupportsProtocols(ph.pid, protos...); err == nil && len(sup) == 0 {
213		return false
214	}
215	return true
216}
217
218func (ph *peerHandler) nextDelta() *pb.Delta {
219	curr := ph.ids.Host.Mux().Protocols()
220
221	// Extract the old protocol list and replace the old snapshot with an
222	// updated one.
223	ph.snapshotMu.Lock()
224	snapshot := *ph.snapshot
225	old := snapshot.protocols
226	snapshot.protocols = curr
227	ph.snapshot = &snapshot
228	ph.snapshotMu.Unlock()
229
230	oldProtos := make(map[string]struct{}, len(old))
231	currProtos := make(map[string]struct{}, len(curr))
232
233	for _, proto := range old {
234		oldProtos[proto] = struct{}{}
235	}
236
237	for _, proto := range curr {
238		currProtos[proto] = struct{}{}
239	}
240
241	var added []string
242	var removed []string
243
244	// has it been added ?
245	for p := range currProtos {
246		if _, ok := oldProtos[p]; !ok {
247			added = append(added, p)
248		}
249	}
250
251	// has it been removed ?
252	for p := range oldProtos {
253		if _, ok := currProtos[p]; !ok {
254			removed = append(removed, p)
255		}
256	}
257
258	return &pb.Delta{
259		AddedProtocols: added,
260		RmProtocols:    removed,
261	}
262}
263