1// Copyright (C) 2019 Storj Labs, Inc. 2// See LICENSE for copying information. 3 4package storj 5 6import ( 7 "crypto/x509/pkix" 8 "database/sql/driver" 9 "encoding/binary" 10 "encoding/json" 11 "math/bits" 12 13 "github.com/btcsuite/btcutil/base58" 14 "github.com/zeebo/errs" 15 16 "storj.io/common/peertls/extensions" 17) 18 19var ( 20 // ErrNodeID is used when something goes wrong with a node id. 21 ErrNodeID = errs.Class("node ID") 22 // ErrVersion is used for identity version related errors. 23 ErrVersion = errs.Class("node ID version") 24) 25 26// NodeIDSize is the byte length of a NodeID. 27const NodeIDSize = 32 28 29// NodeID is a unique node identifier. 30type NodeID [NodeIDSize]byte 31 32// NodeIDList is a slice of NodeIDs (implements sort). 33type NodeIDList []NodeID 34 35// NewVersionedID adds an identity version to a node ID. 36func NewVersionedID(id NodeID, version IDVersion) NodeID { 37 var versionedID NodeID 38 copy(versionedID[:], id[:]) 39 40 versionedID[NodeIDSize-1] = byte(version.Number) 41 return versionedID 42} 43 44// NewVersionExt creates a new identity version certificate extension for the 45// given identity version. 46func NewVersionExt(version IDVersion) pkix.Extension { 47 return pkix.Extension{ 48 Id: extensions.IdentityVersionExtID, 49 Value: []byte{byte(version.Number)}, 50 } 51} 52 53// NodeIDFromString decodes a base58check encoded node id string. 54func NodeIDFromString(s string) (NodeID, error) { 55 idBytes, versionNumber, err := base58.CheckDecode(s) 56 if err != nil { 57 return NodeID{}, ErrNodeID.Wrap(err) 58 } 59 unversionedID, err := NodeIDFromBytes(idBytes) 60 if err != nil { 61 return NodeID{}, err 62 } 63 64 version := IDVersions[IDVersionNumber(versionNumber)] 65 return NewVersionedID(unversionedID, version), nil 66} 67 68// NodeIDsFromBytes converts a 2d byte slice into a list of nodes. 69func NodeIDsFromBytes(b [][]byte) (ids NodeIDList, err error) { 70 var idErrs []error 71 for _, idBytes := range b { 72 id, err := NodeIDFromBytes(idBytes) 73 if err != nil { 74 idErrs = append(idErrs, err) 75 continue 76 } 77 78 ids = append(ids, id) 79 } 80 81 if err = errs.Combine(idErrs...); err != nil { 82 return nil, err 83 } 84 return ids, nil 85} 86 87// NodeIDFromBytes converts a byte slice into a node id. 88func NodeIDFromBytes(b []byte) (NodeID, error) { 89 bLen := len(b) 90 if bLen != len(NodeID{}) { 91 return NodeID{}, ErrNodeID.New("not enough bytes to make a node id; have %d, need %d", bLen, len(NodeID{})) 92 } 93 94 var id NodeID 95 copy(id[:], b) 96 return id, nil 97} 98 99// String returns NodeID as base58 encoded string with checksum and version bytes. 100func (id NodeID) String() string { 101 unversionedID := id.unversioned() 102 return base58.CheckEncode(unversionedID[:], byte(id.Version().Number)) 103} 104 105// IsZero returns whether NodeID is unassigned. 106func (id NodeID) IsZero() bool { 107 return id == NodeID{} 108} 109 110// Bytes returns raw bytes of the id. 111func (id NodeID) Bytes() []byte { return id[:] } 112 113// Less returns whether id is smaller than other in lexicographic order. 114func (id NodeID) Less(other NodeID) bool { 115 a0, b0 := binary.BigEndian.Uint64(id[0:]), binary.BigEndian.Uint64(other[0:]) 116 if a0 < b0 { 117 return true 118 } else if a0 > b0 { 119 return false 120 } 121 122 a1, b1 := binary.BigEndian.Uint64(id[8:]), binary.BigEndian.Uint64(other[8:]) 123 if a1 < b1 { 124 return true 125 } else if a1 > b1 { 126 return false 127 } 128 129 a2, b2 := binary.BigEndian.Uint64(id[16:]), binary.BigEndian.Uint64(other[16:]) 130 if a2 < b2 { 131 return true 132 } else if a2 > b2 { 133 return false 134 } 135 136 a3, b3 := binary.BigEndian.Uint64(id[24:]), binary.BigEndian.Uint64(other[24:]) 137 if a3 < b3 { 138 return true 139 } else if a3 > b3 { 140 return false 141 } 142 143 return false 144} 145 146// Compare returns an integer comparing id and other lexicographically. 147// The result will be 0 if id==other, -1 if id < other, and +1 if id > other. 148func (id NodeID) Compare(other NodeID) int { 149 a0, b0 := binary.BigEndian.Uint64(id[0:]), binary.BigEndian.Uint64(other[0:]) 150 if a0 < b0 { 151 return -1 152 } else if a0 > b0 { 153 return 1 154 } 155 156 a1, b1 := binary.BigEndian.Uint64(id[8:]), binary.BigEndian.Uint64(other[8:]) 157 if a1 < b1 { 158 return -1 159 } else if a1 > b1 { 160 return 1 161 } 162 163 a2, b2 := binary.BigEndian.Uint64(id[16:]), binary.BigEndian.Uint64(other[16:]) 164 if a2 < b2 { 165 return -1 166 } else if a2 > b2 { 167 return 1 168 } 169 170 a3, b3 := binary.BigEndian.Uint64(id[24:]), binary.BigEndian.Uint64(other[24:]) 171 if a3 < b3 { 172 return -1 173 } else if a3 > b3 { 174 return 1 175 } 176 177 return 0 178} 179 180// Version returns the version of the identity format. 181func (id NodeID) Version() IDVersion { 182 versionNumber := id.versionByte() 183 if versionNumber == 0 { 184 return IDVersions[V0] 185 } 186 187 version, err := GetIDVersion(IDVersionNumber(versionNumber)) 188 // NB: when in doubt, use V0 189 if err != nil { 190 return IDVersions[V0] 191 } 192 193 return version 194} 195 196// Difficulty returns the number of trailing zero bits in a node ID. 197func (id NodeID) Difficulty() (uint16, error) { 198 idLen := len(id) 199 var b byte 200 var zeroBits int 201 // NB: last difficulty byte is used for version 202 for i := 2; i <= idLen; i++ { 203 b = id[idLen-i] 204 205 if b != 0 { 206 zeroBits = bits.TrailingZeros16(uint16(b)) 207 if zeroBits == 16 { 208 // we already checked that b != 0. 209 return 0, ErrNodeID.New("impossible codepath!") 210 } 211 212 return uint16((i-1)*8 + zeroBits), nil 213 } 214 } 215 216 return 0, ErrNodeID.New("difficulty matches id hash length: %d; hash (hex): % x", idLen, id) 217} 218 219// Marshal serializes a node id. 220func (id NodeID) Marshal() ([]byte, error) { 221 return id.Bytes(), nil 222} 223 224// MarshalTo serializes a node ID into the passed byte slice. 225func (id *NodeID) MarshalTo(data []byte) (n int, err error) { 226 n = copy(data, id.Bytes()) 227 return n, nil 228} 229 230// Unmarshal deserializes a node ID. 231func (id *NodeID) Unmarshal(data []byte) error { 232 var err error 233 *id, err = NodeIDFromBytes(data) 234 return err 235} 236 237func (id NodeID) versionByte() byte { 238 return id[NodeIDSize-1] 239} 240 241// unversioned returns the node ID with the version byte replaced with `0`. 242// NB: Legacy node IDs (i.e. pre-identity-versions) with a difficulty less 243// than `8` are unsupported. 244func (id NodeID) unversioned() NodeID { 245 unversionedID := NodeID{} 246 copy(unversionedID[:], id[:NodeIDSize-1]) 247 return unversionedID 248} 249 250// Size returns the length of a node ID (implements gogo's custom type interface). 251func (id *NodeID) Size() int { 252 return len(id) 253} 254 255// MarshalJSON serializes a node ID to a json string as bytes. 256func (id NodeID) MarshalJSON() ([]byte, error) { 257 return []byte(`"` + id.String() + `"`), nil 258} 259 260// Value converts a NodeID to a database field. 261func (id NodeID) Value() (driver.Value, error) { 262 return id.Bytes(), nil 263} 264 265// Scan extracts a NodeID from a database field. 266func (id *NodeID) Scan(src interface{}) (err error) { 267 b, ok := src.([]byte) 268 if !ok { 269 return ErrNodeID.New("NodeID Scan expects []byte") 270 } 271 n, err := NodeIDFromBytes(b) 272 *id = n 273 return err 274} 275 276// UnmarshalJSON deserializes a json string (as bytes) to a node ID. 277func (id *NodeID) UnmarshalJSON(data []byte) error { 278 var unquoted string 279 err := json.Unmarshal(data, &unquoted) 280 if err != nil { 281 return err 282 } 283 284 *id, err = NodeIDFromString(unquoted) 285 if err != nil { 286 return err 287 } 288 return nil 289} 290 291// Strings returns a string slice of the node IDs. 292func (n NodeIDList) Strings() []string { 293 var strings []string 294 for _, nid := range n { 295 strings = append(strings, nid.String()) 296 } 297 return strings 298} 299 300// Bytes returns a 2d byte slice of the node IDs. 301func (n NodeIDList) Bytes() (idsBytes [][]byte) { 302 for _, nid := range n { 303 idsBytes = append(idsBytes, nid.Bytes()) 304 } 305 return idsBytes 306} 307 308// Len implements sort.Interface.Len(). 309func (n NodeIDList) Len() int { return len(n) } 310 311// Swap implements sort.Interface.Swap(). 312func (n NodeIDList) Swap(i, j int) { n[i], n[j] = n[j], n[i] } 313 314// Less implements sort.Interface.Less(). 315func (n NodeIDList) Less(i, j int) bool { return n[i].Less(n[j]) } 316 317// Contains tests if the node IDs contain id. 318func (n NodeIDList) Contains(id NodeID) bool { 319 for _, nid := range n { 320 if nid == id { 321 return true 322 } 323 } 324 return false 325} 326 327// Unique returns slice of the unique node IDs. 328func (n NodeIDList) Unique() NodeIDList { 329 var result []NodeID 330next: 331 for _, id := range n { 332 for _, added := range result { 333 if added == id { 334 continue next 335 } 336 } 337 result = append(result, id) 338 } 339 340 return result 341} 342