1// Copyright 2018 The go-ethereum Authors
2// This file is part of the go-ethereum library.
3//
4// The go-ethereum library is free software: you can redistribute it and/or modify
5// it under the terms of the GNU Lesser General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8//
9// The go-ethereum library is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU Lesser General Public License for more details.
13//
14// You should have received a copy of the GNU Lesser General Public License
15// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
16
17package enode
18
19import (
20	"crypto/ecdsa"
21	"fmt"
22	"net"
23	"reflect"
24	"strconv"
25	"sync"
26	"sync/atomic"
27	"time"
28
29	"github.com/ethereum/go-ethereum/log"
30	"github.com/ethereum/go-ethereum/p2p/enr"
31	"github.com/ethereum/go-ethereum/p2p/netutil"
32)
33
// Tuning parameters handed to netutil.NewIPTracker for both the IPv4 and the IPv6
// endpoint predictors (argument order there: window, contact window, min statements).
const (
	iptrackWindow        = 5 * time.Minute
	iptrackContactWindow = 10 * time.Minute
	iptrackMinStatements = 10
)

// recordUpdateThrottle is the minimum interval LocalNode.Node keeps between two
// signings of the local ENR.
const recordUpdateThrottle = time.Millisecond
43
44// LocalNode produces the signed node record of a local node, i.e. a node run in the
45// current process. Setting ENR entries via the Set method updates the record. A new version
46// of the record is signed on demand when the Node method is called.
47type LocalNode struct {
48	cur atomic.Value // holds a non-nil node pointer while the record is up-to-date
49
50	id  ID
51	key *ecdsa.PrivateKey
52	db  *DB
53
54	// everything below is protected by a lock
55	mu        sync.RWMutex
56	seq       uint64
57	update    time.Time // timestamp when the record was last updated
58	entries   map[string]enr.Entry
59	endpoint4 lnEndpoint
60	endpoint6 lnEndpoint
61}
62
63type lnEndpoint struct {
64	track                *netutil.IPTracker
65	staticIP, fallbackIP net.IP
66	fallbackUDP          uint16 // port
67}
68
69// NewLocalNode creates a local node.
70func NewLocalNode(db *DB, key *ecdsa.PrivateKey) *LocalNode {
71	ln := &LocalNode{
72		id:      PubkeyToIDV4(&key.PublicKey),
73		db:      db,
74		key:     key,
75		entries: make(map[string]enr.Entry),
76		endpoint4: lnEndpoint{
77			track: netutil.NewIPTracker(iptrackWindow, iptrackContactWindow, iptrackMinStatements),
78		},
79		endpoint6: lnEndpoint{
80			track: netutil.NewIPTracker(iptrackWindow, iptrackContactWindow, iptrackMinStatements),
81		},
82	}
83	ln.seq = db.localSeq(ln.id)
84	ln.update = time.Now()
85	ln.cur.Store((*Node)(nil))
86	return ln
87}
88
89// Database returns the node database associated with the local node.
90func (ln *LocalNode) Database() *DB {
91	return ln.db
92}
93
94// Node returns the current version of the local node record.
95func (ln *LocalNode) Node() *Node {
96	// If we have a valid record, return that
97	n := ln.cur.Load().(*Node)
98	if n != nil {
99		return n
100	}
101
102	// Record was invalidated, sign a new copy.
103	ln.mu.Lock()
104	defer ln.mu.Unlock()
105
106	// Double check the current record, since multiple goroutines might be waiting
107	// on the write mutex.
108	if n = ln.cur.Load().(*Node); n != nil {
109		return n
110	}
111
112	// The initial sequence number is the current timestamp in milliseconds. To ensure
113	// that the initial sequence number will always be higher than any previous sequence
114	// number (assuming the clock is correct), we want to avoid updating the record faster
115	// than once per ms. So we need to sleep here until the next possible update time has
116	// arrived.
117	lastChange := time.Since(ln.update)
118	if lastChange < recordUpdateThrottle {
119		time.Sleep(recordUpdateThrottle - lastChange)
120	}
121
122	ln.sign()
123	ln.update = time.Now()
124	return ln.cur.Load().(*Node)
125}
126
127// Seq returns the current sequence number of the local node record.
128func (ln *LocalNode) Seq() uint64 {
129	ln.mu.Lock()
130	defer ln.mu.Unlock()
131
132	return ln.seq
133}
134
135// ID returns the local node ID.
136func (ln *LocalNode) ID() ID {
137	return ln.id
138}
139
140// Set puts the given entry into the local record, overwriting any existing value.
141// Use Set*IP and SetFallbackUDP to set IP addresses and UDP port, otherwise they'll
142// be overwritten by the endpoint predictor.
143//
144// Since node record updates are throttled to one per second, Set is asynchronous.
145// Any update will be queued up and published when at least one second passes from
146// the last change.
147func (ln *LocalNode) Set(e enr.Entry) {
148	ln.mu.Lock()
149	defer ln.mu.Unlock()
150
151	ln.set(e)
152}
153
154func (ln *LocalNode) set(e enr.Entry) {
155	val, exists := ln.entries[e.ENRKey()]
156	if !exists || !reflect.DeepEqual(val, e) {
157		ln.entries[e.ENRKey()] = e
158		ln.invalidate()
159	}
160}
161
162// Delete removes the given entry from the local record.
163func (ln *LocalNode) Delete(e enr.Entry) {
164	ln.mu.Lock()
165	defer ln.mu.Unlock()
166
167	ln.delete(e)
168}
169
170func (ln *LocalNode) delete(e enr.Entry) {
171	_, exists := ln.entries[e.ENRKey()]
172	if exists {
173		delete(ln.entries, e.ENRKey())
174		ln.invalidate()
175	}
176}
177
178func (ln *LocalNode) endpointForIP(ip net.IP) *lnEndpoint {
179	if ip.To4() != nil {
180		return &ln.endpoint4
181	}
182	return &ln.endpoint6
183}
184
185// SetStaticIP sets the local IP to the given one unconditionally.
186// This disables endpoint prediction.
187func (ln *LocalNode) SetStaticIP(ip net.IP) {
188	ln.mu.Lock()
189	defer ln.mu.Unlock()
190
191	ln.endpointForIP(ip).staticIP = ip
192	ln.updateEndpoints()
193}
194
195// SetFallbackIP sets the last-resort IP address. This address is used
196// if no endpoint prediction can be made and no static IP is set.
197func (ln *LocalNode) SetFallbackIP(ip net.IP) {
198	ln.mu.Lock()
199	defer ln.mu.Unlock()
200
201	ln.endpointForIP(ip).fallbackIP = ip
202	ln.updateEndpoints()
203}
204
205// SetFallbackUDP sets the last-resort UDP-on-IPv4 port. This port is used
206// if no endpoint prediction can be made.
207func (ln *LocalNode) SetFallbackUDP(port int) {
208	ln.mu.Lock()
209	defer ln.mu.Unlock()
210
211	ln.endpoint4.fallbackUDP = uint16(port)
212	ln.endpoint6.fallbackUDP = uint16(port)
213	ln.updateEndpoints()
214}
215
216// UDPEndpointStatement should be called whenever a statement about the local node's
217// UDP endpoint is received. It feeds the local endpoint predictor.
218func (ln *LocalNode) UDPEndpointStatement(fromaddr, endpoint *net.UDPAddr) {
219	ln.mu.Lock()
220	defer ln.mu.Unlock()
221
222	ln.endpointForIP(endpoint.IP).track.AddStatement(fromaddr.String(), endpoint.String())
223	ln.updateEndpoints()
224}
225
226// UDPContact should be called whenever the local node has announced itself to another node
227// via UDP. It feeds the local endpoint predictor.
228func (ln *LocalNode) UDPContact(toaddr *net.UDPAddr) {
229	ln.mu.Lock()
230	defer ln.mu.Unlock()
231
232	ln.endpointForIP(toaddr.IP).track.AddContact(toaddr.String())
233	ln.updateEndpoints()
234}
235
236// updateEndpoints updates the record with predicted endpoints.
237func (ln *LocalNode) updateEndpoints() {
238	ip4, udp4 := ln.endpoint4.get()
239	ip6, udp6 := ln.endpoint6.get()
240
241	if ip4 != nil && !ip4.IsUnspecified() {
242		ln.set(enr.IPv4(ip4))
243	} else {
244		ln.delete(enr.IPv4{})
245	}
246	if ip6 != nil && !ip6.IsUnspecified() {
247		ln.set(enr.IPv6(ip6))
248	} else {
249		ln.delete(enr.IPv6{})
250	}
251	if udp4 != 0 {
252		ln.set(enr.UDP(udp4))
253	} else {
254		ln.delete(enr.UDP(0))
255	}
256	if udp6 != 0 && udp6 != udp4 {
257		ln.set(enr.UDP6(udp6))
258	} else {
259		ln.delete(enr.UDP6(0))
260	}
261}
262
263// get returns the endpoint with highest precedence.
264func (e *lnEndpoint) get() (newIP net.IP, newPort uint16) {
265	newPort = e.fallbackUDP
266	if e.fallbackIP != nil {
267		newIP = e.fallbackIP
268	}
269	if e.staticIP != nil {
270		newIP = e.staticIP
271	} else if ip, port := predictAddr(e.track); ip != nil {
272		newIP = ip
273		newPort = port
274	}
275	return newIP, newPort
276}
277
278// predictAddr wraps IPTracker.PredictEndpoint, converting from its string-based
279// endpoint representation to IP and port types.
280func predictAddr(t *netutil.IPTracker) (net.IP, uint16) {
281	ep := t.PredictEndpoint()
282	if ep == "" {
283		return nil, 0
284	}
285	ipString, portString, _ := net.SplitHostPort(ep)
286	ip := net.ParseIP(ipString)
287	port, err := strconv.ParseUint(portString, 10, 16)
288	if err != nil {
289		return nil, 0
290	}
291	return ip, uint16(port)
292}
293
294func (ln *LocalNode) invalidate() {
295	ln.cur.Store((*Node)(nil))
296}
297
298func (ln *LocalNode) sign() {
299	if n := ln.cur.Load().(*Node); n != nil {
300		return // no changes
301	}
302
303	var r enr.Record
304	for _, e := range ln.entries {
305		r.Set(e)
306	}
307	ln.bumpSeq()
308	r.SetSeq(ln.seq)
309	if err := SignV4(&r, ln.key); err != nil {
310		panic(fmt.Errorf("enode: can't sign record: %v", err))
311	}
312	n, err := New(ValidSchemes, &r)
313	if err != nil {
314		panic(fmt.Errorf("enode: can't verify local record: %v", err))
315	}
316	ln.cur.Store(n)
317	log.Info("New local node record", "seq", ln.seq, "id", n.ID(), "ip", n.IP(), "udp", n.UDP(), "tcp", n.TCP())
318}
319
320func (ln *LocalNode) bumpSeq() {
321	ln.seq++
322	ln.db.storeLocalSeq(ln.id, ln.seq)
323}
324
// nowMilliseconds reports the current Unix time in whole milliseconds. A clock
// set before the Unix epoch yields 0 instead of a wrapped-around unsigned value.
func nowMilliseconds() uint64 {
	nanos := time.Now().UnixNano()
	if nanos >= 0 {
		return uint64(nanos) / uint64(time.Millisecond)
	}
	return 0
}
333