1package dnsstamps
2
3import (
4	"encoding/base64"
5	"encoding/binary"
6	"encoding/hex"
7	"errors"
8	"fmt"
9	"net"
10	"strconv"
11	"strings"
12)
13
// DefaultPort is the port implied when a stamp's server address has no
// explicit port. Parsers append it to port-less addresses and serializers
// drop a trailing ":443" again.
const DefaultPort = 443

// ServerInformalProperties is the 64-bit flag set stored little-endian in the
// props field (bytes 1..8) of a decoded stamp.
type ServerInformalProperties uint64

// Flag bits of ServerInformalProperties, starting at bit 0.
const (
	// ServerInformalPropertyDNSSEC is bit 0 (value 1).
	ServerInformalPropertyDNSSEC ServerInformalProperties = 1 << iota
	// ServerInformalPropertyNoLog is bit 1 (value 2).
	ServerInformalPropertyNoLog
	// ServerInformalPropertyNoFilter is bit 2 (value 4).
	ServerInformalPropertyNoFilter
)
23
// StampProtoType identifies the protocol of a stamp. It is stored as the first
// byte of the decoded stamp.
type StampProtoType uint8

// Known protocol identifiers.
const (
	StampProtoTypePlain         = StampProtoType(0x00)
	StampProtoTypeDNSCrypt      = StampProtoType(0x01)
	StampProtoTypeDoH           = StampProtoType(0x02)
	StampProtoTypeTLS           = StampProtoType(0x03)
	StampProtoTypeDoQ           = StampProtoType(0x04)
	StampProtoTypeODoHTarget    = StampProtoType(0x05)
	StampProtoTypeDNSCryptRelay = StampProtoType(0x81)
	StampProtoTypeODoHRelay     = StampProtoType(0x85)
)

// String returns a human-readable name for the protocol, or "(unknown)" for
// values not listed above.
//
// It uses a value receiver so that plain StampProtoType values, such as the
// Proto field of a ServerStamp, satisfy fmt.Stringer and print by name. With a
// pointer receiver only *StampProtoType implemented Stringer, and values were
// printed as bare numbers. Calls through a pointer keep working.
func (t StampProtoType) String() string {
	switch t {
	case StampProtoTypePlain:
		return "Plain"
	case StampProtoTypeDNSCrypt:
		return "DNSCrypt"
	case StampProtoTypeDoH:
		return "DoH"
	case StampProtoTypeTLS:
		return "TLS"
	case StampProtoTypeDoQ:
		return "QUIC"
	case StampProtoTypeODoHTarget:
		return "oDoH target"
	case StampProtoTypeDNSCryptRelay:
		return "DNSCrypt relay"
	case StampProtoTypeODoHRelay:
		return "oDoH relay"
	default:
		return "(unknown)"
	}
}
59
60type ServerStamp struct {
61	ServerAddrStr string
62	ServerPk      []uint8
63	Hashes        [][]uint8
64	ProviderName  string
65	Path          string
66	Props         ServerInformalProperties
67	Proto         StampProtoType
68}
69
70func NewDNSCryptServerStampFromLegacy(serverAddrStr string, serverPkStr string, providerName string, props ServerInformalProperties) (ServerStamp, error) {
71	if net.ParseIP(serverAddrStr) != nil {
72		serverAddrStr = fmt.Sprintf("%s:%d", serverAddrStr, DefaultPort)
73	}
74	serverPk, err := hex.DecodeString(strings.Replace(serverPkStr, ":", "", -1))
75	if err != nil || len(serverPk) != 32 {
76		return ServerStamp{}, fmt.Errorf("Unsupported public key: [%s]", serverPkStr)
77	}
78	return ServerStamp{
79		ServerAddrStr: serverAddrStr,
80		ServerPk:      serverPk,
81		ProviderName:  providerName,
82		Props:         props,
83		Proto:         StampProtoTypeDNSCrypt,
84	}, nil
85}
86
87func NewServerStampFromString(stampStr string) (ServerStamp, error) {
88	if !strings.HasPrefix(stampStr, "sdns:") {
89		return ServerStamp{}, errors.New("Stamps are expected to start with \"sdns:\"")
90	}
91	stampStr = stampStr[5:]
92	stampStr = strings.TrimPrefix(stampStr, "//")
93	bin, err := base64.RawURLEncoding.Strict().DecodeString(stampStr)
94	if err != nil {
95		return ServerStamp{}, err
96	}
97	if len(bin) < 1 {
98		return ServerStamp{}, errors.New("Stamp is too short")
99	}
100	if bin[0] == uint8(StampProtoTypeDNSCrypt) {
101		return newDNSCryptServerStamp(bin)
102	} else if bin[0] == uint8(StampProtoTypeDoH) {
103		return newDoHServerStamp(bin)
104	} else if bin[0] == uint8(StampProtoTypeODoHTarget) {
105		return newODoHTargetStamp(bin)
106	} else if bin[0] == uint8(StampProtoTypeDNSCryptRelay) {
107		return newDNSCryptRelayStamp(bin)
108	} else if bin[0] == uint8(StampProtoTypeODoHRelay) {
109		return newODoHRelayStamp(bin)
110	}
111	return ServerStamp{}, errors.New("Unsupported stamp version or protocol")
112}
113
114func NewRelayAndServerStampFromString(stampStr string) (ServerStamp, ServerStamp, error) {
115	if !strings.HasPrefix(stampStr, "sdns://") {
116		return ServerStamp{}, ServerStamp{}, errors.New("Stamps are expected to start with \"sdns://\"")
117	}
118	stampStr = stampStr[7:]
119	parts := strings.Split(stampStr, "/")
120	if len(parts) != 2 {
121		return ServerStamp{}, ServerStamp{}, errors.New("This is not a relay+server stamp")
122	}
123	relayStamp, err := NewServerStampFromString("sdns://" + parts[0])
124	if err != nil {
125		return ServerStamp{}, ServerStamp{}, err
126	}
127	serverStamp, err := NewServerStampFromString("sdns://" + parts[1])
128	if err != nil {
129		return ServerStamp{}, ServerStamp{}, err
130	}
131	if relayStamp.Proto != StampProtoTypeDNSCryptRelay && relayStamp.Proto != StampProtoTypeODoHRelay {
132		return ServerStamp{}, ServerStamp{}, errors.New("First stamp is not a relay")
133	}
134	if !(serverStamp.Proto != StampProtoTypeDNSCryptRelay && serverStamp.Proto != StampProtoTypeODoHRelay) {
135		return ServerStamp{}, ServerStamp{}, errors.New("Second stamp is a relay")
136	}
137	return relayStamp, serverStamp, nil
138}
139
140// id(u8)=0x01 props addrLen(1) serverAddr pkStrlen(1) pkStr providerNameLen(1) providerName
141
142func newDNSCryptServerStamp(bin []byte) (ServerStamp, error) {
143	stamp := ServerStamp{Proto: StampProtoTypeDNSCrypt}
144	if len(bin) < 66 {
145		return stamp, errors.New("Stamp is too short")
146	}
147	stamp.Props = ServerInformalProperties(binary.LittleEndian.Uint64(bin[1:9]))
148	binLen := len(bin)
149	pos := 9
150
151	length := int(bin[pos])
152	if 1+length >= binLen-pos {
153		return stamp, errors.New("Invalid stamp")
154	}
155	pos++
156	stamp.ServerAddrStr = string(bin[pos : pos+length])
157	pos += length
158
159	colIndex := strings.LastIndex(stamp.ServerAddrStr, ":")
160	bracketIndex := strings.LastIndex(stamp.ServerAddrStr, "]")
161	if colIndex < bracketIndex {
162		colIndex = -1
163	}
164	if colIndex < 0 {
165		colIndex = len(stamp.ServerAddrStr)
166		stamp.ServerAddrStr = fmt.Sprintf("%s:%d", stamp.ServerAddrStr, DefaultPort)
167	}
168	if colIndex >= len(stamp.ServerAddrStr)-1 {
169		return stamp, errors.New("Invalid stamp (empty port)")
170	}
171	ipOnly := stamp.ServerAddrStr[:colIndex]
172	portOnly := stamp.ServerAddrStr[colIndex+1:]
173	if _, err := strconv.ParseUint(portOnly, 10, 16); err != nil {
174		return stamp, errors.New("Invalid stamp (port range)")
175	}
176	if net.ParseIP(strings.TrimRight(strings.TrimLeft(ipOnly, "["), "]")) == nil {
177		return stamp, errors.New("Invalid stamp (IP address)")
178	}
179
180	length = int(bin[pos])
181	if 1+length >= binLen-pos {
182		return stamp, errors.New("Invalid stamp")
183	}
184	pos++
185	stamp.ServerPk = bin[pos : pos+length]
186	pos += length
187
188	length = int(bin[pos])
189	if length >= binLen-pos {
190		return stamp, errors.New("Invalid stamp")
191	}
192	pos++
193	stamp.ProviderName = string(bin[pos : pos+length])
194	pos += length
195
196	if pos != binLen {
197		return stamp, errors.New("Invalid stamp (garbage after end)")
198	}
199	return stamp, nil
200}
201
202// id(u8)=0x02 props addrLen(1) serverAddr hashLen(1) hash hostNameLen(1) hostName pathLen(1) path
203
204func newDoHServerStamp(bin []byte) (ServerStamp, error) {
205	stamp := ServerStamp{Proto: StampProtoTypeDoH}
206	if len(bin) < 22 {
207		return stamp, errors.New("Stamp is too short")
208	}
209	stamp.Props = ServerInformalProperties(binary.LittleEndian.Uint64(bin[1:9]))
210	binLen := len(bin)
211	pos := 9
212
213	length := int(bin[pos])
214	if 1+length >= binLen-pos {
215		return stamp, errors.New("Invalid stamp")
216	}
217	pos++
218	stamp.ServerAddrStr = string(bin[pos : pos+length])
219	pos += length
220
221	for {
222		vlen := int(bin[pos])
223		length = vlen & ^0x80
224		if 1+length >= binLen-pos {
225			return stamp, errors.New("Invalid stamp")
226		}
227		pos++
228		if length > 0 {
229			stamp.Hashes = append(stamp.Hashes, bin[pos:pos+length])
230		}
231		pos += length
232		if vlen&0x80 != 0x80 {
233			break
234		}
235	}
236
237	length = int(bin[pos])
238	if 1+length >= binLen-pos {
239		return stamp, errors.New("Invalid stamp")
240	}
241	pos++
242	stamp.ProviderName = string(bin[pos : pos+length])
243	pos += length
244
245	length = int(bin[pos])
246	if length >= binLen-pos {
247		return stamp, errors.New("Invalid stamp")
248	}
249	pos++
250	stamp.Path = string(bin[pos : pos+length])
251	pos += length
252
253	if pos != binLen {
254		return stamp, errors.New("Invalid stamp (garbage after end)")
255	}
256
257	if len(stamp.ServerAddrStr) > 0 {
258		colIndex := strings.LastIndex(stamp.ServerAddrStr, ":")
259		bracketIndex := strings.LastIndex(stamp.ServerAddrStr, "]")
260		if colIndex < bracketIndex {
261			colIndex = -1
262		}
263		if colIndex < 0 {
264			colIndex = len(stamp.ServerAddrStr)
265			stamp.ServerAddrStr = fmt.Sprintf("%s:%d", stamp.ServerAddrStr, DefaultPort)
266		}
267		if colIndex >= len(stamp.ServerAddrStr)-1 {
268			return stamp, errors.New("Invalid stamp (empty port)")
269		}
270		ipOnly := stamp.ServerAddrStr[:colIndex]
271		portOnly := stamp.ServerAddrStr[colIndex+1:]
272		if _, err := strconv.ParseUint(portOnly, 10, 16); err != nil {
273			return stamp, errors.New("Invalid stamp (port range)")
274		}
275		if net.ParseIP(strings.TrimRight(strings.TrimLeft(ipOnly, "["), "]")) == nil {
276			return stamp, errors.New("Invalid stamp (IP address)")
277		}
278	}
279
280	return stamp, nil
281}
282
283// id(u8)=0x05 props hostNameLen(1) hostName pathLen(1) path
284
285func newODoHTargetStamp(bin []byte) (ServerStamp, error) {
286	stamp := ServerStamp{Proto: StampProtoTypeODoHTarget}
287	if len(bin) < 12 {
288		return stamp, errors.New("Stamp is too short")
289	}
290	stamp.Props = ServerInformalProperties(binary.LittleEndian.Uint64(bin[1:9]))
291	binLen := len(bin)
292	pos := 9
293
294	length := int(bin[pos])
295	if 1+length >= binLen-pos {
296		return stamp, errors.New("Invalid stamp")
297	}
298	pos++
299	stamp.ProviderName = string(bin[pos : pos+length])
300	pos += length
301
302	length = int(bin[pos])
303	if length >= binLen-pos {
304		return stamp, errors.New("Invalid stamp")
305	}
306	pos++
307	stamp.Path = string(bin[pos : pos+length])
308	pos += length
309
310	if pos != binLen {
311		return stamp, errors.New("Invalid stamp (garbage after end)")
312	}
313
314	return stamp, nil
315}
316
317// id(u8)=0x81 addrLen(1) serverAddr
318
319func newDNSCryptRelayStamp(bin []byte) (ServerStamp, error) {
320	stamp := ServerStamp{Proto: StampProtoTypeDNSCryptRelay}
321	if len(bin) < 13 {
322		return stamp, errors.New("Stamp is too short")
323	}
324	binLen := len(bin)
325	pos := 1
326	length := int(bin[pos])
327	if 1+length > binLen-pos {
328		return stamp, errors.New("Invalid stamp")
329	}
330	pos++
331	stamp.ServerAddrStr = string(bin[pos : pos+length])
332	pos += length
333
334	colIndex := strings.LastIndex(stamp.ServerAddrStr, ":")
335	bracketIndex := strings.LastIndex(stamp.ServerAddrStr, "]")
336	if colIndex < bracketIndex {
337		colIndex = -1
338	}
339	if colIndex < 0 {
340		colIndex = len(stamp.ServerAddrStr)
341		stamp.ServerAddrStr = fmt.Sprintf("%s:%d", stamp.ServerAddrStr, DefaultPort)
342	}
343	if colIndex >= len(stamp.ServerAddrStr)-1 {
344		return stamp, errors.New("Invalid stamp (empty port)")
345	}
346	ipOnly := stamp.ServerAddrStr[:colIndex]
347	portOnly := stamp.ServerAddrStr[colIndex+1:]
348	if _, err := strconv.ParseUint(portOnly, 10, 16); err != nil {
349		return stamp, errors.New("Invalid stamp (port range)")
350	}
351	if net.ParseIP(strings.TrimRight(strings.TrimLeft(ipOnly, "["), "]")) == nil {
352		return stamp, errors.New("Invalid stamp (IP address)")
353	}
354	if pos != binLen {
355		return stamp, errors.New("Invalid stamp (garbage after end)")
356	}
357	return stamp, nil
358}
359
360// id(u8)=0x85 props addrLen(1) serverAddr hashLen(1) hash hostNameLen(1) hostName pathLen(1) path
361
362func newODoHRelayStamp(bin []byte) (ServerStamp, error) {
363	stamp := ServerStamp{Proto: StampProtoTypeODoHRelay}
364	if len(bin) < 13 {
365		return stamp, errors.New("Stamp is too short")
366	}
367	stamp.Props = ServerInformalProperties(binary.LittleEndian.Uint64(bin[1:9]))
368	binLen := len(bin)
369	pos := 9
370
371	length := int(bin[pos])
372	if 1+length >= binLen-pos {
373		return stamp, errors.New("Invalid stamp")
374	}
375	pos++
376	stamp.ServerAddrStr = string(bin[pos : pos+length])
377	pos += length
378
379	for {
380		vlen := int(bin[pos])
381		length = vlen & ^0x80
382		if 1+length >= binLen-pos {
383			return stamp, errors.New("Invalid stamp")
384		}
385		pos++
386		if length > 0 {
387			stamp.Hashes = append(stamp.Hashes, bin[pos:pos+length])
388		}
389		pos += length
390		if vlen&0x80 != 0x80 {
391			break
392		}
393	}
394
395	length = int(bin[pos])
396	if 1+length >= binLen-pos {
397		return stamp, errors.New("Invalid stamp")
398	}
399	pos++
400	stamp.ProviderName = string(bin[pos : pos+length])
401	pos += length
402
403	length = int(bin[pos])
404	if length >= binLen-pos {
405		return stamp, errors.New("Invalid stamp")
406	}
407	pos++
408	stamp.Path = string(bin[pos : pos+length])
409	pos += length
410
411	if pos != binLen {
412		return stamp, errors.New("Invalid stamp (garbage after end)")
413	}
414
415	if len(stamp.ServerAddrStr) > 0 {
416		colIndex := strings.LastIndex(stamp.ServerAddrStr, ":")
417		bracketIndex := strings.LastIndex(stamp.ServerAddrStr, "]")
418		if colIndex < bracketIndex {
419			colIndex = -1
420		}
421		if colIndex < 0 {
422			colIndex = len(stamp.ServerAddrStr)
423			stamp.ServerAddrStr = fmt.Sprintf("%s:%d", stamp.ServerAddrStr, DefaultPort)
424		}
425		if colIndex >= len(stamp.ServerAddrStr)-1 {
426			return stamp, errors.New("Invalid stamp (empty port)")
427		}
428		ipOnly := stamp.ServerAddrStr[:colIndex]
429		portOnly := stamp.ServerAddrStr[colIndex+1:]
430		if _, err := strconv.ParseUint(portOnly, 10, 16); err != nil {
431			return stamp, errors.New("Invalid stamp (port range)")
432		}
433		if net.ParseIP(strings.TrimRight(strings.TrimLeft(ipOnly, "["), "]")) == nil {
434			return stamp, errors.New("Invalid stamp (IP address)")
435		}
436	}
437
438	return stamp, nil
439}
440
441func (stamp *ServerStamp) String() string {
442	if stamp.Proto == StampProtoTypeDNSCrypt {
443		return stamp.dnsCryptString()
444	} else if stamp.Proto == StampProtoTypeDoH {
445		return stamp.dohString()
446	} else if stamp.Proto == StampProtoTypeODoHTarget {
447		return stamp.oDohTargetString()
448	} else if stamp.Proto == StampProtoTypeDNSCryptRelay {
449		return stamp.dnsCryptRelayString()
450	} else if stamp.Proto == StampProtoTypeODoHRelay {
451		return stamp.oDohRelayString()
452	}
453	panic("Unsupported protocol")
454}
455
456func (stamp *ServerStamp) dnsCryptString() string {
457	bin := make([]uint8, 9)
458	bin[0] = uint8(StampProtoTypeDNSCrypt)
459	binary.LittleEndian.PutUint64(bin[1:9], uint64(stamp.Props))
460
461	serverAddrStr := stamp.ServerAddrStr
462	if strings.HasSuffix(serverAddrStr, ":"+strconv.Itoa(DefaultPort)) {
463		serverAddrStr = serverAddrStr[:len(serverAddrStr)-1-len(strconv.Itoa(DefaultPort))]
464	}
465	bin = append(bin, uint8(len(serverAddrStr)))
466	bin = append(bin, []uint8(serverAddrStr)...)
467
468	bin = append(bin, uint8(len(stamp.ServerPk)))
469	bin = append(bin, stamp.ServerPk...)
470
471	bin = append(bin, uint8(len(stamp.ProviderName)))
472	bin = append(bin, []uint8(stamp.ProviderName)...)
473
474	str := base64.RawURLEncoding.EncodeToString(bin)
475
476	return "sdns://" + str
477}
478
479func (stamp *ServerStamp) dohString() string {
480	bin := make([]uint8, 9)
481	bin[0] = uint8(StampProtoTypeDoH)
482	binary.LittleEndian.PutUint64(bin[1:9], uint64(stamp.Props))
483
484	serverAddrStr := stamp.ServerAddrStr
485	if strings.HasSuffix(serverAddrStr, ":"+strconv.Itoa(DefaultPort)) {
486		serverAddrStr = serverAddrStr[:len(serverAddrStr)-1-len(strconv.Itoa(DefaultPort))]
487	}
488	bin = append(bin, uint8(len(serverAddrStr)))
489	bin = append(bin, []uint8(serverAddrStr)...)
490
491	if len(stamp.Hashes) == 0 {
492		bin = append(bin, uint8(0))
493	} else {
494		last := len(stamp.Hashes) - 1
495		for i, hash := range stamp.Hashes {
496			vlen := len(hash)
497			if i < last {
498				vlen |= 0x80
499			}
500			bin = append(bin, uint8(vlen))
501			bin = append(bin, hash...)
502		}
503	}
504
505	bin = append(bin, uint8(len(stamp.ProviderName)))
506	bin = append(bin, []uint8(stamp.ProviderName)...)
507
508	bin = append(bin, uint8(len(stamp.Path)))
509	bin = append(bin, []uint8(stamp.Path)...)
510
511	str := base64.RawURLEncoding.EncodeToString(bin)
512
513	return "sdns://" + str
514}
515
516func (stamp *ServerStamp) oDohTargetString() string {
517	bin := make([]uint8, 9)
518	bin[0] = uint8(StampProtoTypeODoHTarget)
519	binary.LittleEndian.PutUint64(bin[1:9], uint64(stamp.Props))
520
521	bin = append(bin, uint8(len(stamp.ProviderName)))
522	bin = append(bin, []uint8(stamp.ProviderName)...)
523
524	bin = append(bin, uint8(len(stamp.Path)))
525	bin = append(bin, []uint8(stamp.Path)...)
526
527	str := base64.RawURLEncoding.EncodeToString(bin)
528
529	return "sdns://" + str
530}
531
532func (stamp *ServerStamp) dnsCryptRelayString() string {
533	bin := make([]uint8, 1)
534	bin[0] = uint8(StampProtoTypeDNSCryptRelay)
535
536	serverAddrStr := stamp.ServerAddrStr
537	if strings.HasSuffix(serverAddrStr, ":"+strconv.Itoa(DefaultPort)) {
538		serverAddrStr = serverAddrStr[:len(serverAddrStr)-1-len(strconv.Itoa(DefaultPort))]
539	}
540	bin = append(bin, uint8(len(serverAddrStr)))
541	bin = append(bin, []uint8(serverAddrStr)...)
542
543	str := base64.RawURLEncoding.EncodeToString(bin)
544
545	return "sdns://" + str
546}
547
548func (stamp *ServerStamp) oDohRelayString() string {
549	bin := make([]uint8, 9)
550	bin[0] = uint8(StampProtoTypeODoHRelay)
551	binary.LittleEndian.PutUint64(bin[1:9], uint64(stamp.Props))
552
553	serverAddrStr := stamp.ServerAddrStr
554	if strings.HasSuffix(serverAddrStr, ":"+strconv.Itoa(DefaultPort)) {
555		serverAddrStr = serverAddrStr[:len(serverAddrStr)-1-len(strconv.Itoa(DefaultPort))]
556	}
557	bin = append(bin, uint8(len(serverAddrStr)))
558	bin = append(bin, []uint8(serverAddrStr)...)
559
560	if len(stamp.Hashes) == 0 {
561		bin = append(bin, uint8(0))
562	} else {
563		last := len(stamp.Hashes) - 1
564		for i, hash := range stamp.Hashes {
565			vlen := len(hash)
566			if i < last {
567				vlen |= 0x80
568			}
569			bin = append(bin, uint8(vlen))
570			bin = append(bin, hash...)
571		}
572	}
573
574	bin = append(bin, uint8(len(stamp.ProviderName)))
575	bin = append(bin, []uint8(stamp.ProviderName)...)
576
577	bin = append(bin, uint8(len(stamp.Path)))
578	bin = append(bin, []uint8(stamp.Path)...)
579
580	str := base64.RawURLEncoding.EncodeToString(bin)
581
582	return "sdns://" + str
583}
584