1// SPDX-License-Identifier: ISC
2// Copyright (c) 2014-2020 Bitmark Inc.
3// Use of this source code is governed by an ISC
4// license that can be found in the LICENSE file.
5
6package domain
7
8import (
9	"encoding/hex"
10	"net"
11	"strconv"
12	"strings"
13
14	"github.com/bitmark-inc/bitmarkd/fault"
15)
16
// supported is the set of version tags accepted by Parse as the first
// space-separated word of a DNS TXT record; any other first word is rejected.
var supported = map[string]struct{}{
	"bitmark=v3": struct{}{},
	"bitmark=v2": struct{}{},
}
22
// Expected lengths, in hex characters, of the p= and f= parameters:
// each encodes 32 bytes, i.e. two hex characters per byte.
const (
	publicKeyLength   = 64
	fingerprintLength = 64
)

// Inclusive bounds accepted for any port parameter (c=, r=, s=).
const (
	minPortNumber = 1
	maxPortNumber = 1<<16 - 1
)
29
// DnsTXT holds the values decoded from a bitmark DNS TXT record by Parse.
type DnsTXT struct {
	// IPv4 is the last IPv4 address listed in the a= parameter, or nil.
	IPv4 net.IP
	// IPv6 is the last non-IPv4 address listed in the a= parameter, or nil.
	IPv6 net.IP
	// RPCPort is the port given by the r= parameter.
	RPCPort uint16
	// ConnectPort is the port given by the c= parameter.
	ConnectPort uint16
	// CertificateFingerprint is the hex-decoded f= parameter.
	CertificateFingerprint []byte
	// PublicKey is the hex-decoded p= parameter.
	PublicKey []byte
}
39
40// Parse - parse a dns txt record
41func Parse(s string) (*DnsTXT, error) {
42	t := &DnsTXT{}
43
44	countA := 0
45	countC := 0
46	countF := 0
47	countP := 0
48	countR := 0
49
50words:
51	for i, w := range strings.Split(strings.TrimSpace(s), " ") {
52
53		if 0 == i {
54			if _, ok := supported[w]; ok {
55				continue words
56			}
57			return nil, fault.InvalidDnsTxtRecord
58		}
59
60		// ignore empty
61		if "" == w {
62			continue words
63		}
64
65		// require form: <letter>=<word>
66		if len(w) < 3 || '=' != w[1] {
67			return nil, fault.InvalidDnsTxtRecord
68		}
69
70		// w[0]=tag character; w[1]= char('='); w[2:]=parameter
71		parameter := w[2:]
72		err := error(nil)
73		switch w[0] {
74		case 'a':
75		addresses:
76			for _, address := range strings.Split(parameter, ";") {
77				if '[' == address[0] {
78					end := len(address) - 1
79					if ']' == address[end] {
80						address = address[1:end]
81					}
82				}
83				IP := net.ParseIP(address)
84				if nil == IP {
85					err = fault.InvalidIpAddress
86					break addresses
87				} else {
88					err = nil
89					if nil != IP.To4() {
90						t.IPv4 = IP
91					} else {
92						t.IPv6 = IP
93					}
94				}
95			}
96			countA += 1
97
98		case 'c':
99			t.ConnectPort, err = getPort(parameter)
100			countC += 1
101		case 's': // not actually used but still check
102			_, err = getPort(parameter)
103		case 'r':
104			t.RPCPort, err = getPort(parameter)
105			countR += 1
106		case 'p':
107			if len(parameter) != publicKeyLength {
108				err = fault.InvalidPublicKey
109			} else {
110				t.PublicKey, err = hex.DecodeString(parameter)
111				if nil != err {
112					err = fault.InvalidPublicKey
113				}
114			}
115			countP += 1
116		case 'f':
117			if len(parameter) != fingerprintLength {
118				err = fault.InvalidFingerprint
119			} else {
120				t.CertificateFingerprint, err = hex.DecodeString(parameter)
121				if nil != err {
122					err = fault.InvalidFingerprint
123				}
124			}
125			countF += 1
126		default:
127			err = fault.InvalidDnsTxtRecord
128		}
129		if nil != err {
130			return nil, err
131		}
132	}
133
134	// ensure that there is only one each of the required items
135	if countA != 1 || countC != 1 || countF != 1 || countP != 1 || countR != 1 {
136		return nil, fault.InvalidDnsTxtRecord
137	}
138
139	return t, nil
140}
141
142func getPort(s string) (uint16, error) {
143	port, err := strconv.Atoi(s)
144	if nil != err {
145		return 0, fault.InvalidPortNumber
146	}
147	if port < minPortNumber || port > maxPortNumber {
148		return 0, fault.InvalidPortNumber
149	}
150	return uint16(port), nil
151}
152