1package stun 2 3import ( 4 "crypto/rand" 5 "encoding/base64" 6 "errors" 7 "fmt" 8 "io" 9) 10 11const ( 12 // magicCookie is fixed value that aids in distinguishing STUN packets 13 // from packets of other protocols when STUN is multiplexed with those 14 // other protocols on the same Port. 15 // 16 // The magic cookie field MUST contain the fixed value 0x2112A442 in 17 // network byte order. 18 // 19 // Defined in "STUN Message Structure", section 6. 20 magicCookie = 0x2112A442 21 attributeHeaderSize = 4 22 messageHeaderSize = 20 23 24 // TransactionIDSize is length of transaction id array (in bytes). 25 TransactionIDSize = 12 // 96 bit 26) 27 28// NewTransactionID returns new random transaction ID using crypto/rand 29// as source. 30func NewTransactionID() (b [TransactionIDSize]byte) { 31 readFullOrPanic(rand.Reader, b[:]) 32 return b 33} 34 35// IsMessage returns true if b looks like STUN message. 36// Useful for multiplexing. IsMessage does not guarantee 37// that decoding will be successful. 38func IsMessage(b []byte) bool { 39 return len(b) >= messageHeaderSize && bin.Uint32(b[4:8]) == magicCookie 40} 41 42// New returns *Message with pre-allocated Raw. 43func New() *Message { 44 const defaultRawCapacity = 120 45 return &Message{ 46 Raw: make([]byte, messageHeaderSize, defaultRawCapacity), 47 } 48} 49 50// ErrDecodeToNil occurs on Decode(data, nil) call. 51var ErrDecodeToNil = errors.New("attempt to decode to nil message") 52 53// Decode decodes Message from data to m, returning error if any. 54func Decode(data []byte, m *Message) error { 55 if m == nil { 56 return ErrDecodeToNil 57 } 58 m.Raw = append(m.Raw[:0], data...) 59 return m.Decode() 60} 61 62// Message represents a single STUN packet. It uses aggressive internal 63// buffering to enable zero-allocation encoding and decoding, 64// so there are some usage constraints: 65// 66// Message, its fields, results of m.Get or any attribute a.GetFrom 67// are valid only until Message.Raw is not modified. 68type Message struct { 69 Type MessageType 70 Length uint32 // len(Raw) not including header 71 TransactionID [TransactionIDSize]byte 72 Attributes Attributes 73 Raw []byte 74} 75 76// AddTo sets b.TransactionID to m.TransactionID. 77// 78// Implements Setter to aid in crafting responses. 79func (m *Message) AddTo(b *Message) error { 80 b.TransactionID = m.TransactionID 81 b.WriteTransactionID() 82 return nil 83} 84 85// NewTransactionID sets m.TransactionID to random value from crypto/rand 86// and returns error if any. 87func (m *Message) NewTransactionID() error { 88 _, err := io.ReadFull(rand.Reader, m.TransactionID[:]) 89 if err == nil { 90 m.WriteTransactionID() 91 } 92 return err 93} 94 95func (m *Message) String() string { 96 tID := base64.StdEncoding.EncodeToString(m.TransactionID[:]) 97 return fmt.Sprintf("%s l=%d attrs=%d id=%s", m.Type, m.Length, len(m.Attributes), tID) 98} 99 100// Reset resets Message, attributes and underlying buffer length. 101func (m *Message) Reset() { 102 m.Raw = m.Raw[:0] 103 m.Length = 0 104 m.Attributes = m.Attributes[:0] 105} 106 107// grow ensures that internal buffer has n length. 108func (m *Message) grow(n int) { 109 if len(m.Raw) >= n { 110 return 111 } 112 if cap(m.Raw) >= n { 113 m.Raw = m.Raw[:n] 114 return 115 } 116 m.Raw = append(m.Raw, make([]byte, n-len(m.Raw))...) 117} 118 119// Add appends new attribute to message. Not goroutine-safe. 120// 121// Value of attribute is copied to internal buffer so 122// it is safe to reuse v. 123func (m *Message) Add(t AttrType, v []byte) { 124 // Allocating buffer for TLV (type-length-value). 125 // T = t, L = len(v), V = v. 126 // m.Raw will look like: 127 // [0:20] <- message header 128 // [20:20+m.Length] <- existing message attributes 129 // [20+m.Length:20+m.Length+len(v) + 4] <- allocated buffer for new TLV 130 // [first:last] <- same as previous 131 // [0 1|2 3|4 4 + len(v)] <- mapping for allocated buffer 132 // T L V 133 allocSize := attributeHeaderSize + len(v) // ~ len(TLV) = len(TL) + len(V) 134 first := messageHeaderSize + int(m.Length) // first byte number 135 last := first + allocSize // last byte number 136 m.grow(last) // growing cap(Raw) to fit TLV 137 m.Raw = m.Raw[:last] // now len(Raw) = last 138 m.Length += uint32(allocSize) // rendering length change 139 140 // Sub-slicing internal buffer to simplify encoding. 141 buf := m.Raw[first:last] // slice for TLV 142 value := buf[attributeHeaderSize:] // slice for V 143 attr := RawAttribute{ 144 Type: t, // T 145 Length: uint16(len(v)), // L 146 Value: value, // V 147 } 148 149 // Encoding attribute TLV to allocated buffer. 150 bin.PutUint16(buf[0:2], attr.Type.Value()) // T 151 bin.PutUint16(buf[2:4], attr.Length) // L 152 copy(value, v) // V 153 154 // Checking that attribute value needs padding. 155 if attr.Length%padding != 0 { 156 // Performing padding. 157 bytesToAdd := nearestPaddedValueLength(len(v)) - len(v) 158 last += bytesToAdd 159 m.grow(last) 160 // setting all padding bytes to zero 161 // to prevent data leak from previous 162 // data in next bytesToAdd bytes 163 buf = m.Raw[last-bytesToAdd : last] 164 for i := range buf { 165 buf[i] = 0 166 } 167 m.Raw = m.Raw[:last] // increasing buffer length 168 m.Length += uint32(bytesToAdd) // rendering length change 169 } 170 m.Attributes = append(m.Attributes, attr) 171 m.WriteLength() 172} 173 174func attrSliceEqual(a, b Attributes) bool { 175 for _, attr := range a { 176 found := false 177 for _, attrB := range b { 178 if attrB.Type != attr.Type { 179 continue 180 } 181 if attrB.Equal(attr) { 182 found = true 183 break 184 } 185 } 186 if !found { 187 return false 188 } 189 } 190 return true 191} 192 193func attrEqual(a, b Attributes) bool { 194 if a == nil && b == nil { 195 return true 196 } 197 if a == nil || b == nil { 198 return false 199 } 200 if len(a) != len(b) { 201 return false 202 } 203 if !attrSliceEqual(a, b) { 204 return false 205 } 206 if !attrSliceEqual(b, a) { 207 return false 208 } 209 return true 210} 211 212// Equal returns true if Message b equals to m. 213// Ignores m.Raw. 214func (m *Message) Equal(b *Message) bool { 215 if m == nil && b == nil { 216 return true 217 } 218 if m == nil || b == nil { 219 return false 220 } 221 if m.Type != b.Type { 222 return false 223 } 224 if m.TransactionID != b.TransactionID { 225 return false 226 } 227 if m.Length != b.Length { 228 return false 229 } 230 if !attrEqual(m.Attributes, b.Attributes) { 231 return false 232 } 233 return true 234} 235 236// WriteLength writes m.Length to m.Raw. Call is valid only if len(m.Raw) >= 4. 237func (m *Message) WriteLength() { 238 _ = m.Raw[4] // early bounds check to guarantee safety of writes below 239 bin.PutUint16(m.Raw[2:4], uint16(m.Length)) 240} 241 242// WriteHeader writes header to underlying buffer. Not goroutine-safe. 243func (m *Message) WriteHeader() { 244 if len(m.Raw) < messageHeaderSize { 245 // Making WriteHeader call valid even when m.Raw 246 // is nil or len(m.Raw) is less than needed for header. 247 m.grow(messageHeaderSize) 248 } 249 _ = m.Raw[:messageHeaderSize] // early bounds check to guarantee safety of writes below 250 251 m.WriteType() 252 m.WriteLength() 253 bin.PutUint32(m.Raw[4:8], magicCookie) // magic cookie 254 copy(m.Raw[8:messageHeaderSize], m.TransactionID[:]) // transaction ID 255} 256 257// WriteTransactionID writes m.TransactionID to m.Raw. 258func (m *Message) WriteTransactionID() { 259 copy(m.Raw[8:messageHeaderSize], m.TransactionID[:]) // transaction ID 260} 261 262// WriteAttributes encodes all m.Attributes to m. 263func (m *Message) WriteAttributes() { 264 attributes := m.Attributes 265 m.Attributes = attributes[:0] 266 for _, a := range attributes { 267 m.Add(a.Type, a.Value) 268 } 269 m.Attributes = attributes 270} 271 272// WriteType writes m.Type to m.Raw. 273func (m *Message) WriteType() { 274 bin.PutUint16(m.Raw[0:2], m.Type.Value()) // message type 275} 276 277// SetType sets m.Type and writes it to m.Raw. 278func (m *Message) SetType(t MessageType) { 279 m.Type = t 280 m.WriteType() 281} 282 283// Encode re-encodes message into m.Raw. 284func (m *Message) Encode() { 285 m.Raw = m.Raw[:0] 286 m.WriteHeader() 287 m.Length = 0 288 m.WriteAttributes() 289} 290 291// WriteTo implements WriterTo via calling Write(m.Raw) on w and returning 292// call result. 293func (m *Message) WriteTo(w io.Writer) (int64, error) { 294 n, err := w.Write(m.Raw) 295 return int64(n), err 296} 297 298// ReadFrom implements ReaderFrom. Reads message from r into m.Raw, 299// Decodes it and return error if any. If m.Raw is too small, will return 300// ErrUnexpectedEOF, ErrUnexpectedHeaderEOF or *DecodeErr. 301// 302// Can return *DecodeErr while decoding too. 303func (m *Message) ReadFrom(r io.Reader) (int64, error) { 304 tBuf := m.Raw[:cap(m.Raw)] 305 var ( 306 n int 307 err error 308 ) 309 if n, err = r.Read(tBuf); err != nil { 310 return int64(n), err 311 } 312 m.Raw = tBuf[:n] 313 return int64(n), m.Decode() 314} 315 316// ErrUnexpectedHeaderEOF means that there were not enough bytes in 317// m.Raw to read header. 318var ErrUnexpectedHeaderEOF = errors.New("unexpected EOF: not enough bytes to read header") 319 320// Decode decodes m.Raw into m. 321func (m *Message) Decode() error { 322 // decoding message header 323 buf := m.Raw 324 if len(buf) < messageHeaderSize { 325 return ErrUnexpectedHeaderEOF 326 } 327 var ( 328 t = bin.Uint16(buf[0:2]) // first 2 bytes 329 size = int(bin.Uint16(buf[2:4])) // second 2 bytes 330 cookie = bin.Uint32(buf[4:8]) // last 4 bytes 331 fullSize = messageHeaderSize + size // len(m.Raw) 332 ) 333 if cookie != magicCookie { 334 msg := fmt.Sprintf("%x is invalid magic cookie (should be %x)", cookie, magicCookie) 335 return newDecodeErr("message", "cookie", msg) 336 } 337 if len(buf) < fullSize { 338 msg := fmt.Sprintf("buffer length %d is less than %d (expected message size)", len(buf), fullSize) 339 return newAttrDecodeErr("message", msg) 340 } 341 // saving header data 342 m.Type.ReadValue(t) 343 m.Length = uint32(size) 344 copy(m.TransactionID[:], buf[8:messageHeaderSize]) 345 346 m.Attributes = m.Attributes[:0] 347 var ( 348 offset = 0 349 b = buf[messageHeaderSize:fullSize] 350 ) 351 for offset < size { 352 // checking that we have enough bytes to read header 353 if len(b) < attributeHeaderSize { 354 msg := fmt.Sprintf("buffer length %d is less than %d (expected header size)", len(b), attributeHeaderSize) 355 return newAttrDecodeErr("header", msg) 356 } 357 var ( 358 a = RawAttribute{ 359 Type: compatAttrType(bin.Uint16(b[0:2])), // first 2 bytes 360 Length: bin.Uint16(b[2:4]), // second 2 bytes 361 } 362 aL = int(a.Length) // attribute length 363 aBuffL = nearestPaddedValueLength(aL) // expected buffer length (with padding) 364 ) 365 b = b[attributeHeaderSize:] // slicing again to simplify value read 366 offset += attributeHeaderSize 367 if len(b) < aBuffL { // checking size 368 msg := fmt.Sprintf("buffer length %d is less than %d (expected value size for %s)", len(b), aBuffL, a.Type) 369 return newAttrDecodeErr("value", msg) 370 } 371 a.Value = b[:aL] 372 offset += aBuffL 373 b = b[aBuffL:] 374 375 m.Attributes = append(m.Attributes, a) 376 } 377 return nil 378} 379 380// Write decodes message and return error if any. 381// 382// Any error is unrecoverable, but message could be partially decoded. 383func (m *Message) Write(tBuf []byte) (int, error) { 384 m.Raw = append(m.Raw[:0], tBuf...) 385 return len(tBuf), m.Decode() 386} 387 388// CloneTo clones m to b securing any further m mutations. 389func (m *Message) CloneTo(b *Message) error { 390 // TODO(ar): implement low-level copy. 391 b.Raw = append(b.Raw[:0], m.Raw...) 392 return b.Decode() 393} 394 395// MessageClass is 8-bit representation of 2-bit class of STUN Message Class. 396type MessageClass byte 397 398// Possible values for message class in STUN Message Type. 399const ( 400 ClassRequest MessageClass = 0x00 // 0b00 401 ClassIndication MessageClass = 0x01 // 0b01 402 ClassSuccessResponse MessageClass = 0x02 // 0b10 403 ClassErrorResponse MessageClass = 0x03 // 0b11 404) 405 406// Common STUN message types. 407var ( 408 // Binding request message type. 409 BindingRequest = NewType(MethodBinding, ClassRequest) 410 // Binding success response message type 411 BindingSuccess = NewType(MethodBinding, ClassSuccessResponse) 412 // Binding error response message type. 413 BindingError = NewType(MethodBinding, ClassErrorResponse) 414) 415 416func (c MessageClass) String() string { 417 switch c { 418 case ClassRequest: 419 return "request" 420 case ClassIndication: 421 return "indication" 422 case ClassSuccessResponse: 423 return "success response" 424 case ClassErrorResponse: 425 return "error response" 426 default: 427 panic("unknown message class") 428 } 429} 430 431// Method is uint16 representation of 12-bit STUN method. 432type Method uint16 433 434// Possible methods for STUN Message. 435const ( 436 MethodBinding Method = 0x001 437 MethodAllocate Method = 0x003 438 MethodRefresh Method = 0x004 439 MethodSend Method = 0x006 440 MethodData Method = 0x007 441 MethodCreatePermission Method = 0x008 442 MethodChannelBind Method = 0x009 443) 444 445// Methods from RFC 6062. 446const ( 447 MethodConnect Method = 0x000a 448 MethodConnectionBind Method = 0x000b 449 MethodConnectionAttempt Method = 0x000c 450) 451 452var methodName = map[Method]string{ 453 MethodBinding: "Binding", 454 MethodAllocate: "Allocate", 455 MethodRefresh: "Refresh", 456 MethodSend: "Send", 457 MethodData: "Data", 458 MethodCreatePermission: "CreatePermission", 459 MethodChannelBind: "ChannelBind", 460 461 // RFC 6062. 462 MethodConnect: "Connect", 463 MethodConnectionBind: "ConnectionBind", 464 MethodConnectionAttempt: "ConnectionAttempt", 465} 466 467func (m Method) String() string { 468 s, ok := methodName[m] 469 if !ok { 470 // Falling back to hex representation. 471 s = fmt.Sprintf("0x%x", uint16(m)) 472 } 473 return s 474} 475 476// MessageType is STUN Message Type Field. 477type MessageType struct { 478 Method Method // e.g. binding 479 Class MessageClass // e.g. request 480} 481 482// AddTo sets m type to t. 483func (t MessageType) AddTo(m *Message) error { 484 m.SetType(t) 485 return nil 486} 487 488// NewType returns new message type with provided method and class. 489func NewType(method Method, class MessageClass) MessageType { 490 return MessageType{ 491 Method: method, 492 Class: class, 493 } 494} 495 496const ( 497 methodABits = 0xf // 0b0000000000001111 498 methodBBits = 0x70 // 0b0000000001110000 499 methodDBits = 0xf80 // 0b0000111110000000 500 501 methodBShift = 1 502 methodDShift = 2 503 504 firstBit = 0x1 505 secondBit = 0x2 506 507 c0Bit = firstBit 508 c1Bit = secondBit 509 510 classC0Shift = 4 511 classC1Shift = 7 512) 513 514// Value returns bit representation of messageType. 515func (t MessageType) Value() uint16 { 516 // 0 1 517 // 2 3 4 5 6 7 8 9 0 1 2 3 4 5 518 // +--+--+-+-+-+-+-+-+-+-+-+-+-+-+ 519 // |M |M |M|M|M|C|M|M|M|C|M|M|M|M| 520 // |11|10|9|8|7|1|6|5|4|0|3|2|1|0| 521 // +--+--+-+-+-+-+-+-+-+-+-+-+-+-+ 522 // Figure 3: Format of STUN Message Type Field 523 524 // Warning: Abandon all hope ye who enter here. 525 // Splitting M into A(M0-M3), B(M4-M6), D(M7-M11). 526 m := uint16(t.Method) 527 a := m & methodABits // A = M * 0b0000000000001111 (right 4 bits) 528 b := m & methodBBits // B = M * 0b0000000001110000 (3 bits after A) 529 d := m & methodDBits // D = M * 0b0000111110000000 (5 bits after B) 530 531 // Shifting to add "holes" for C0 (at 4 bit) and C1 (8 bit). 532 m = a + (b << methodBShift) + (d << methodDShift) 533 534 // C0 is zero bit of C, C1 is first bit. 535 // C0 = C * 0b01, C1 = (C * 0b10) >> 1 536 // Ct = C0 << 4 + C1 << 8. 537 // Optimizations: "((C * 0b10) >> 1) << 8" as "(C * 0b10) << 7" 538 // We need C0 shifted by 4, and C1 by 8 to fit "11" and "7" positions 539 // (see figure 3). 540 c := uint16(t.Class) 541 c0 := (c & c0Bit) << classC0Shift 542 c1 := (c & c1Bit) << classC1Shift 543 class := c0 + c1 544 545 return m + class 546} 547 548// ReadValue decodes uint16 into MessageType. 549func (t *MessageType) ReadValue(v uint16) { 550 // Decoding class. 551 // We are taking first bit from v >> 4 and second from v >> 7. 552 c0 := (v >> classC0Shift) & c0Bit 553 c1 := (v >> classC1Shift) & c1Bit 554 class := c0 + c1 555 t.Class = MessageClass(class) 556 557 // Decoding method. 558 a := v & methodABits // A(M0-M3) 559 b := (v >> methodBShift) & methodBBits // B(M4-M6) 560 d := (v >> methodDShift) & methodDBits // D(M7-M11) 561 m := a + b + d 562 t.Method = Method(m) 563} 564 565func (t MessageType) String() string { 566 return fmt.Sprintf("%s %s", t.Method, t.Class) 567} 568 569// Contains return true if message contain t attribute. 570func (m *Message) Contains(t AttrType) bool { 571 for _, a := range m.Attributes { 572 if a.Type == t { 573 return true 574 } 575 } 576 return false 577} 578 579type transactionIDValueSetter [TransactionIDSize]byte 580 581// NewTransactionIDSetter returns new Setter that sets message transaction id 582// to provided value. 583func NewTransactionIDSetter(value [TransactionIDSize]byte) Setter { 584 return transactionIDValueSetter(value) 585} 586 587func (t transactionIDValueSetter) AddTo(m *Message) error { 588 m.TransactionID = t 589 m.WriteTransactionID() 590 return nil 591} 592