1// Copyright 2018 The go-ethereum Authors
2// This file is part of the go-ethereum library.
3//
4// The go-ethereum library is free software: you can redistribute it and/or modify
5// it under the terms of the GNU Lesser General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8//
9// The go-ethereum library is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU Lesser General Public License for more details.
13//
14// You should have received a copy of the GNU Lesser General Public License
15// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
16
17package dnsdisc
18
19import (
20	"bytes"
21	"crypto/ecdsa"
22	"encoding/base32"
23	"encoding/base64"
24	"fmt"
25	"io"
26	"sort"
27	"strings"
28
29	"github.com/ethereum/go-ethereum/crypto"
30	"github.com/ethereum/go-ethereum/p2p/enode"
31	"github.com/ethereum/go-ethereum/p2p/enr"
32	"github.com/ethereum/go-ethereum/rlp"
33	"golang.org/x/crypto/sha3"
34)
35
// Tree is a merkle tree of node records.
type Tree struct {
	root    *rootEntry       // root entry; Sign and SetSignature replace it with a signed copy
	entries map[string]entry // all non-root entries, keyed by their subdomain hash
}
41
42// Sign signs the tree with the given private key and sets the sequence number.
43func (t *Tree) Sign(key *ecdsa.PrivateKey, domain string) (url string, err error) {
44	root := *t.root
45	sig, err := crypto.Sign(root.sigHash(), key)
46	if err != nil {
47		return "", err
48	}
49	root.sig = sig
50	t.root = &root
51	link := newLinkEntry(domain, &key.PublicKey)
52	return link.String(), nil
53}
54
55// SetSignature verifies the given signature and assigns it as the tree's current
56// signature if valid.
57func (t *Tree) SetSignature(pubkey *ecdsa.PublicKey, signature string) error {
58	sig, err := b64format.DecodeString(signature)
59	if err != nil || len(sig) != crypto.SignatureLength {
60		return errInvalidSig
61	}
62	root := *t.root
63	root.sig = sig
64	if !root.verifySignature(pubkey) {
65		return errInvalidSig
66	}
67	t.root = &root
68	return nil
69}
70
71// Seq returns the sequence number of the tree.
72func (t *Tree) Seq() uint {
73	return t.root.seq
74}
75
76// Signature returns the signature of the tree.
77func (t *Tree) Signature() string {
78	return b64format.EncodeToString(t.root.sig)
79}
80
81// ToTXT returns all DNS TXT records required for the tree.
82func (t *Tree) ToTXT(domain string) map[string]string {
83	records := map[string]string{domain: t.root.String()}
84	for _, e := range t.entries {
85		sd := subdomain(e)
86		if domain != "" {
87			sd = sd + "." + domain
88		}
89		records[sd] = e.String()
90	}
91	return records
92}
93
94// Links returns all links contained in the tree.
95func (t *Tree) Links() []string {
96	var links []string
97	for _, e := range t.entries {
98		if le, ok := e.(*linkEntry); ok {
99			links = append(links, le.String())
100		}
101	}
102	return links
103}
104
105// Nodes returns all nodes contained in the tree.
106func (t *Tree) Nodes() []*enode.Node {
107	var nodes []*enode.Node
108	for _, e := range t.entries {
109		if ee, ok := e.(*enrEntry); ok {
110			nodes = append(nodes, ee.node)
111		}
112	}
113	return nodes
114}
115
116/*
117We want to keep the UDP size below 512 bytes. The UDP size is roughly:
118UDP length = 8 + UDP payload length ( 229 )
UDP payload length:
120 - dns.id 2
121 - dns.flags 2
122 - dns.count.queries 2
123 - dns.count.answers 2
124 - dns.count.auth_rr 2
125 - dns.count.add_rr 2
126 - queries (query-size + 6)
127 - answers :
128 	- dns.resp.name 2
129 	- dns.resp.type 2
130 	- dns.resp.class 2
131 	- dns.resp.ttl 4
132 	- dns.resp.len 2
133 	- dns.txt.length 1
134 	- dns.txt resp_data_size
135
136So the total size is roughly a fixed overhead of `39`, and the size of the
137query (domain name) and response.
138The query size is, for example, FVY6INQ6LZ33WLCHO3BPR3FH6Y.snap.mainnet.ethdisco.net (52)
139
140We also have some static data in the response, such as `enrtree-branch:`, and potentially
141splitting the response up with `" "`, leaving us with a size of roughly `400` that we need
142to stay below.
143
144The number `370` is used to have some margin for extra overhead (for example, the dns query
145may be larger - more subdomains).
146*/
const (
	hashAbbrevSize = 1 + 16*13/8          // Size of an encoded hash (plus comma)
	maxChildren    = 370 / hashAbbrevSize // 13 children
	// minHashLength is the smallest decoded hash length, in bytes, that
	// isValidHash accepts; truncateHash cuts hashes down to its encoded length.
	minHashLength  = 12
)
152
153// MakeTree creates a tree containing the given nodes and links.
154func MakeTree(seq uint, nodes []*enode.Node, links []string) (*Tree, error) {
155	// Sort records by ID and ensure all nodes have a valid record.
156	records := make([]*enode.Node, len(nodes))
157
158	copy(records, nodes)
159	sortByID(records)
160	for _, n := range records {
161		if len(n.Record().Signature()) == 0 {
162			return nil, fmt.Errorf("can't add node %v: unsigned node record", n.ID())
163		}
164	}
165
166	// Create the leaf list.
167	enrEntries := make([]entry, len(records))
168	for i, r := range records {
169		enrEntries[i] = &enrEntry{r}
170	}
171	linkEntries := make([]entry, len(links))
172	for i, l := range links {
173		le, err := parseLink(l)
174		if err != nil {
175			return nil, err
176		}
177		linkEntries[i] = le
178	}
179
180	// Create intermediate nodes.
181	t := &Tree{entries: make(map[string]entry)}
182	eroot := t.build(enrEntries)
183	t.entries[subdomain(eroot)] = eroot
184	lroot := t.build(linkEntries)
185	t.entries[subdomain(lroot)] = lroot
186	t.root = &rootEntry{seq: seq, eroot: subdomain(eroot), lroot: subdomain(lroot)}
187	return t, nil
188}
189
190func (t *Tree) build(entries []entry) entry {
191	if len(entries) == 1 {
192		return entries[0]
193	}
194	if len(entries) <= maxChildren {
195		hashes := make([]string, len(entries))
196		for i, e := range entries {
197			hashes[i] = subdomain(e)
198			t.entries[hashes[i]] = e
199		}
200		return &branchEntry{hashes}
201	}
202	var subtrees []entry
203	for len(entries) > 0 {
204		n := maxChildren
205		if len(entries) < n {
206			n = len(entries)
207		}
208		sub := t.build(entries[:n])
209		entries = entries[n:]
210		subtrees = append(subtrees, sub)
211		t.entries[subdomain(sub)] = sub
212	}
213	return t.build(subtrees)
214}
215
216func sortByID(nodes []*enode.Node) []*enode.Node {
217	sort.Slice(nodes, func(i, j int) bool {
218		return bytes.Compare(nodes[i].ID().Bytes(), nodes[j].ID().Bytes()) < 0
219	})
220	return nodes
221}
222
223// Entry Types
224
// entry is implemented by every record type stored in a tree. The String form
// is the text placed in the TXT record and the input hashed by subdomain.
type entry interface {
	fmt.Stringer
}
228
type (
	// rootEntry is the root record of a tree; it names the two subtree roots.
	rootEntry struct {
		eroot string // subdomain hash of the root of the node-record subtree
		lroot string // subdomain hash of the root of the link subtree
		seq   uint   // sequence number of the tree
		sig   []byte // signature over sigHash(); nil on a tree fresh from MakeTree
	}
	// branchEntry is an intermediate tree node listing its children.
	branchEntry struct {
		children []string // subdomain hashes of the child entries
	}
	// enrEntry is a leaf holding a single node.
	enrEntry struct {
		node *enode.Node
	}
	// linkEntry is a leaf referring to another tree by public key and domain.
	linkEntry struct {
		str    string           // link text after the scheme prefix: "<base32 key>@<domain>"
		domain string           // domain part of the link
		pubkey *ecdsa.PublicKey // public key part of the link
	}
)
248
249// Entry Encoding
250
var (
	// b32format is unpadded standard base32; it encodes subdomain hashes and
	// the public key inside links.
	b32format = base32.StdEncoding.WithPadding(base32.NoPadding)
	// b64format is unpadded URL-safe base64; it encodes root signatures and
	// the payload of ENR entries.
	b64format = base64.RawURLEncoding
)
255
// Text prefixes that identify each kind of entry in its TXT record form.
const (
	rootPrefix   = "enrtree-root:v1"
	linkPrefix   = "enrtree://"
	branchPrefix = "enrtree-branch:"
	enrPrefix    = "enr:"
)
262
263func subdomain(e entry) string {
264	h := sha3.NewLegacyKeccak256()
265	io.WriteString(h, e.String())
266	return b32format.EncodeToString(h.Sum(nil)[:16])
267}
268
269func (e *rootEntry) String() string {
270	return fmt.Sprintf(rootPrefix+" e=%s l=%s seq=%d sig=%s", e.eroot, e.lroot, e.seq, b64format.EncodeToString(e.sig))
271}
272
273func (e *rootEntry) sigHash() []byte {
274	h := sha3.NewLegacyKeccak256()
275	fmt.Fprintf(h, rootPrefix+" e=%s l=%s seq=%d", e.eroot, e.lroot, e.seq)
276	return h.Sum(nil)
277}
278
279func (e *rootEntry) verifySignature(pubkey *ecdsa.PublicKey) bool {
280	sig := e.sig[:crypto.RecoveryIDOffset] // remove recovery id
281	enckey := crypto.FromECDSAPub(pubkey)
282	return crypto.VerifySignature(enckey, e.sigHash(), sig)
283}
284
285func (e *branchEntry) String() string {
286	return branchPrefix + strings.Join(e.children, ",")
287}
288
289func (e *enrEntry) String() string {
290	return e.node.String()
291}
292
293func (e *linkEntry) String() string {
294	return linkPrefix + e.str
295}
296
297func newLinkEntry(domain string, pubkey *ecdsa.PublicKey) *linkEntry {
298	key := b32format.EncodeToString(crypto.CompressPubkey(pubkey))
299	str := key + "@" + domain
300	return &linkEntry{str, domain, pubkey}
301}
302
303// Entry Parsing
304
305func parseEntry(e string, validSchemes enr.IdentityScheme) (entry, error) {
306	switch {
307	case strings.HasPrefix(e, linkPrefix):
308		return parseLinkEntry(e)
309	case strings.HasPrefix(e, branchPrefix):
310		return parseBranch(e)
311	case strings.HasPrefix(e, enrPrefix):
312		return parseENR(e, validSchemes)
313	default:
314		return nil, errUnknownEntry
315	}
316}
317
318func parseRoot(e string) (rootEntry, error) {
319	var eroot, lroot, sig string
320	var seq uint
321	if _, err := fmt.Sscanf(e, rootPrefix+" e=%s l=%s seq=%d sig=%s", &eroot, &lroot, &seq, &sig); err != nil {
322		return rootEntry{}, entryError{"root", errSyntax}
323	}
324	if !isValidHash(eroot) || !isValidHash(lroot) {
325		return rootEntry{}, entryError{"root", errInvalidChild}
326	}
327	sigb, err := b64format.DecodeString(sig)
328	if err != nil || len(sigb) != crypto.SignatureLength {
329		return rootEntry{}, entryError{"root", errInvalidSig}
330	}
331	return rootEntry{eroot, lroot, seq, sigb}, nil
332}
333
334func parseLinkEntry(e string) (entry, error) {
335	le, err := parseLink(e)
336	if err != nil {
337		return nil, err
338	}
339	return le, nil
340}
341
342func parseLink(e string) (*linkEntry, error) {
343	if !strings.HasPrefix(e, linkPrefix) {
344		return nil, fmt.Errorf("wrong/missing scheme 'enrtree' in URL")
345	}
346	e = e[len(linkPrefix):]
347	pos := strings.IndexByte(e, '@')
348	if pos == -1 {
349		return nil, entryError{"link", errNoPubkey}
350	}
351	keystring, domain := e[:pos], e[pos+1:]
352	keybytes, err := b32format.DecodeString(keystring)
353	if err != nil {
354		return nil, entryError{"link", errBadPubkey}
355	}
356	key, err := crypto.DecompressPubkey(keybytes)
357	if err != nil {
358		return nil, entryError{"link", errBadPubkey}
359	}
360	return &linkEntry{e, domain, key}, nil
361}
362
363func parseBranch(e string) (entry, error) {
364	e = e[len(branchPrefix):]
365	if e == "" {
366		return &branchEntry{}, nil // empty entry is OK
367	}
368	hashes := make([]string, 0, strings.Count(e, ","))
369	for _, c := range strings.Split(e, ",") {
370		if !isValidHash(c) {
371			return nil, entryError{"branch", errInvalidChild}
372		}
373		hashes = append(hashes, c)
374	}
375	return &branchEntry{hashes}, nil
376}
377
378func parseENR(e string, validSchemes enr.IdentityScheme) (entry, error) {
379	e = e[len(enrPrefix):]
380	enc, err := b64format.DecodeString(e)
381	if err != nil {
382		return nil, entryError{"enr", errInvalidENR}
383	}
384	var rec enr.Record
385	if err := rlp.DecodeBytes(enc, &rec); err != nil {
386		return nil, entryError{"enr", err}
387	}
388	n, err := enode.New(validSchemes, &rec)
389	if err != nil {
390		return nil, entryError{"enr", err}
391	}
392	return &enrEntry{n}, nil
393}
394
395func isValidHash(s string) bool {
396	dlen := b32format.DecodedLen(len(s))
397	if dlen < minHashLength || dlen > 32 || strings.ContainsAny(s, "\n\r") {
398		return false
399	}
400	buf := make([]byte, 32)
401	_, err := b32format.Decode(buf, []byte(s))
402	return err == nil
403}
404
405// truncateHash truncates the given base32 hash string to the minimum acceptable length.
406func truncateHash(hash string) string {
407	maxLen := b32format.EncodedLen(minHashLength)
408	if len(hash) < maxLen {
409		panic(fmt.Errorf("dnsdisc: hash %q is too short", hash))
410	}
411	return hash[:maxLen]
412}
413
414// URL encoding
415
416// ParseURL parses an enrtree:// URL and returns its components.
417func ParseURL(url string) (domain string, pubkey *ecdsa.PublicKey, err error) {
418	le, err := parseLink(url)
419	if err != nil {
420		return "", nil, err
421	}
422	return le.domain, le.pubkey, nil
423}
424