1// Copyright (C) 2019 Storj Labs, Inc.
2// See LICENSE for copying information.
3
4package storj
5
6import (
7	"crypto/x509/pkix"
8	"database/sql/driver"
9	"encoding/binary"
10	"encoding/json"
11	"math/bits"
12
13	"github.com/btcsuite/btcutil/base58"
14	"github.com/zeebo/errs"
15
16	"storj.io/common/peertls/extensions"
17)
18
19var (
20	// ErrNodeID is used when something goes wrong with a node id.
21	ErrNodeID = errs.Class("node ID")
22	// ErrVersion is used for identity version related errors.
23	ErrVersion = errs.Class("node ID version")
24)
25
// NodeIDSize is the number of bytes in a NodeID.
const NodeIDSize = 32

// NodeID is a unique node identifier made of NodeIDSize raw bytes; the final
// byte carries the identity version number.
type NodeID [NodeIDSize]byte

// NodeIDList is a slice of NodeIDs; it implements sort.Interface.
type NodeIDList []NodeID
34
35// NewVersionedID adds an identity version to a node ID.
36func NewVersionedID(id NodeID, version IDVersion) NodeID {
37	var versionedID NodeID
38	copy(versionedID[:], id[:])
39
40	versionedID[NodeIDSize-1] = byte(version.Number)
41	return versionedID
42}
43
44// NewVersionExt creates a new identity version certificate extension for the
45// given identity version.
46func NewVersionExt(version IDVersion) pkix.Extension {
47	return pkix.Extension{
48		Id:    extensions.IdentityVersionExtID,
49		Value: []byte{byte(version.Number)},
50	}
51}
52
53// NodeIDFromString decodes a base58check encoded node id string.
54func NodeIDFromString(s string) (NodeID, error) {
55	idBytes, versionNumber, err := base58.CheckDecode(s)
56	if err != nil {
57		return NodeID{}, ErrNodeID.Wrap(err)
58	}
59	unversionedID, err := NodeIDFromBytes(idBytes)
60	if err != nil {
61		return NodeID{}, err
62	}
63
64	version := IDVersions[IDVersionNumber(versionNumber)]
65	return NewVersionedID(unversionedID, version), nil
66}
67
68// NodeIDsFromBytes converts a 2d byte slice into a list of nodes.
69func NodeIDsFromBytes(b [][]byte) (ids NodeIDList, err error) {
70	var idErrs []error
71	for _, idBytes := range b {
72		id, err := NodeIDFromBytes(idBytes)
73		if err != nil {
74			idErrs = append(idErrs, err)
75			continue
76		}
77
78		ids = append(ids, id)
79	}
80
81	if err = errs.Combine(idErrs...); err != nil {
82		return nil, err
83	}
84	return ids, nil
85}
86
87// NodeIDFromBytes converts a byte slice into a node id.
88func NodeIDFromBytes(b []byte) (NodeID, error) {
89	bLen := len(b)
90	if bLen != len(NodeID{}) {
91		return NodeID{}, ErrNodeID.New("not enough bytes to make a node id; have %d, need %d", bLen, len(NodeID{}))
92	}
93
94	var id NodeID
95	copy(id[:], b)
96	return id, nil
97}
98
99// String returns NodeID as base58 encoded string with checksum and version bytes.
100func (id NodeID) String() string {
101	unversionedID := id.unversioned()
102	return base58.CheckEncode(unversionedID[:], byte(id.Version().Number))
103}
104
105// IsZero returns whether NodeID is unassigned.
106func (id NodeID) IsZero() bool {
107	return id == NodeID{}
108}
109
110// Bytes returns raw bytes of the id.
111func (id NodeID) Bytes() []byte { return id[:] }
112
113// Less returns whether id is smaller than other in lexicographic order.
114func (id NodeID) Less(other NodeID) bool {
115	a0, b0 := binary.BigEndian.Uint64(id[0:]), binary.BigEndian.Uint64(other[0:])
116	if a0 < b0 {
117		return true
118	} else if a0 > b0 {
119		return false
120	}
121
122	a1, b1 := binary.BigEndian.Uint64(id[8:]), binary.BigEndian.Uint64(other[8:])
123	if a1 < b1 {
124		return true
125	} else if a1 > b1 {
126		return false
127	}
128
129	a2, b2 := binary.BigEndian.Uint64(id[16:]), binary.BigEndian.Uint64(other[16:])
130	if a2 < b2 {
131		return true
132	} else if a2 > b2 {
133		return false
134	}
135
136	a3, b3 := binary.BigEndian.Uint64(id[24:]), binary.BigEndian.Uint64(other[24:])
137	if a3 < b3 {
138		return true
139	} else if a3 > b3 {
140		return false
141	}
142
143	return false
144}
145
146// Compare returns an integer comparing id and other lexicographically.
147// The result will be 0 if id==other, -1 if id < other, and +1 if id > other.
148func (id NodeID) Compare(other NodeID) int {
149	a0, b0 := binary.BigEndian.Uint64(id[0:]), binary.BigEndian.Uint64(other[0:])
150	if a0 < b0 {
151		return -1
152	} else if a0 > b0 {
153		return 1
154	}
155
156	a1, b1 := binary.BigEndian.Uint64(id[8:]), binary.BigEndian.Uint64(other[8:])
157	if a1 < b1 {
158		return -1
159	} else if a1 > b1 {
160		return 1
161	}
162
163	a2, b2 := binary.BigEndian.Uint64(id[16:]), binary.BigEndian.Uint64(other[16:])
164	if a2 < b2 {
165		return -1
166	} else if a2 > b2 {
167		return 1
168	}
169
170	a3, b3 := binary.BigEndian.Uint64(id[24:]), binary.BigEndian.Uint64(other[24:])
171	if a3 < b3 {
172		return -1
173	} else if a3 > b3 {
174		return 1
175	}
176
177	return 0
178}
179
180// Version returns the version of the identity format.
181func (id NodeID) Version() IDVersion {
182	versionNumber := id.versionByte()
183	if versionNumber == 0 {
184		return IDVersions[V0]
185	}
186
187	version, err := GetIDVersion(IDVersionNumber(versionNumber))
188	// NB: when in doubt, use V0
189	if err != nil {
190		return IDVersions[V0]
191	}
192
193	return version
194}
195
196// Difficulty returns the number of trailing zero bits in a node ID.
197func (id NodeID) Difficulty() (uint16, error) {
198	idLen := len(id)
199	var b byte
200	var zeroBits int
201	// NB: last difficulty byte is used for version
202	for i := 2; i <= idLen; i++ {
203		b = id[idLen-i]
204
205		if b != 0 {
206			zeroBits = bits.TrailingZeros16(uint16(b))
207			if zeroBits == 16 {
208				// we already checked that b != 0.
209				return 0, ErrNodeID.New("impossible codepath!")
210			}
211
212			return uint16((i-1)*8 + zeroBits), nil
213		}
214	}
215
216	return 0, ErrNodeID.New("difficulty matches id hash length: %d; hash (hex): % x", idLen, id)
217}
218
219// Marshal serializes a node id.
220func (id NodeID) Marshal() ([]byte, error) {
221	return id.Bytes(), nil
222}
223
224// MarshalTo serializes a node ID into the passed byte slice.
225func (id *NodeID) MarshalTo(data []byte) (n int, err error) {
226	n = copy(data, id.Bytes())
227	return n, nil
228}
229
230// Unmarshal deserializes a node ID.
231func (id *NodeID) Unmarshal(data []byte) error {
232	var err error
233	*id, err = NodeIDFromBytes(data)
234	return err
235}
236
237func (id NodeID) versionByte() byte {
238	return id[NodeIDSize-1]
239}
240
241// unversioned returns the node ID with the version byte replaced with `0`.
242// NB: Legacy node IDs (i.e. pre-identity-versions) with a difficulty less
243// than `8` are unsupported.
244func (id NodeID) unversioned() NodeID {
245	unversionedID := NodeID{}
246	copy(unversionedID[:], id[:NodeIDSize-1])
247	return unversionedID
248}
249
250// Size returns the length of a node ID (implements gogo's custom type interface).
251func (id *NodeID) Size() int {
252	return len(id)
253}
254
255// MarshalJSON serializes a node ID to a json string as bytes.
256func (id NodeID) MarshalJSON() ([]byte, error) {
257	return []byte(`"` + id.String() + `"`), nil
258}
259
260// Value converts a NodeID to a database field.
261func (id NodeID) Value() (driver.Value, error) {
262	return id.Bytes(), nil
263}
264
265// Scan extracts a NodeID from a database field.
266func (id *NodeID) Scan(src interface{}) (err error) {
267	b, ok := src.([]byte)
268	if !ok {
269		return ErrNodeID.New("NodeID Scan expects []byte")
270	}
271	n, err := NodeIDFromBytes(b)
272	*id = n
273	return err
274}
275
276// UnmarshalJSON deserializes a json string (as bytes) to a node ID.
277func (id *NodeID) UnmarshalJSON(data []byte) error {
278	var unquoted string
279	err := json.Unmarshal(data, &unquoted)
280	if err != nil {
281		return err
282	}
283
284	*id, err = NodeIDFromString(unquoted)
285	if err != nil {
286		return err
287	}
288	return nil
289}
290
291// Strings returns a string slice of the node IDs.
292func (n NodeIDList) Strings() []string {
293	var strings []string
294	for _, nid := range n {
295		strings = append(strings, nid.String())
296	}
297	return strings
298}
299
300// Bytes returns a 2d byte slice of the node IDs.
301func (n NodeIDList) Bytes() (idsBytes [][]byte) {
302	for _, nid := range n {
303		idsBytes = append(idsBytes, nid.Bytes())
304	}
305	return idsBytes
306}
307
308// Len implements sort.Interface.Len().
309func (n NodeIDList) Len() int { return len(n) }
310
311// Swap implements sort.Interface.Swap().
312func (n NodeIDList) Swap(i, j int) { n[i], n[j] = n[j], n[i] }
313
314// Less implements sort.Interface.Less().
315func (n NodeIDList) Less(i, j int) bool { return n[i].Less(n[j]) }
316
317// Contains tests if the node IDs contain id.
318func (n NodeIDList) Contains(id NodeID) bool {
319	for _, nid := range n {
320		if nid == id {
321			return true
322		}
323	}
324	return false
325}
326
327// Unique returns slice of the unique node IDs.
328func (n NodeIDList) Unique() NodeIDList {
329	var result []NodeID
330next:
331	for _, id := range n {
332		for _, added := range result {
333			if added == id {
334				continue next
335			}
336		}
337		result = append(result, id)
338	}
339
340	return result
341}
342