1package stun
2
3import (
4	"crypto/rand"
5	"encoding/base64"
6	"errors"
7	"fmt"
8	"io"
9)
10
// Fixed wire-format values for STUN messages.
const (
	// magicCookie is the fixed value 0x2112A442, stored in network byte order
	// in bytes 4..7 of every STUN header. It helps distinguish STUN packets
	// from packets of other protocols multiplexed with STUN on the same port.
	//
	// Defined in "STUN Message Structure", section 6.
	magicCookie = 0x2112A442

	// messageHeaderSize is the size of the fixed STUN message header in bytes.
	messageHeaderSize = 20

	// attributeHeaderSize is the size in bytes of an attribute's type and
	// length fields, which precede its value.
	attributeHeaderSize = 4

	// TransactionIDSize is length of transaction id array (in bytes).
	TransactionIDSize = 12 // 96 bit
)
27
28// NewTransactionID returns new random transaction ID using crypto/rand
29// as source.
30func NewTransactionID() (b [TransactionIDSize]byte) {
31	readFullOrPanic(rand.Reader, b[:])
32	return b
33}
34
35// IsMessage returns true if b looks like STUN message.
36// Useful for multiplexing. IsMessage does not guarantee
37// that decoding will be successful.
38func IsMessage(b []byte) bool {
39	return len(b) >= messageHeaderSize && bin.Uint32(b[4:8]) == magicCookie
40}
41
42// New returns *Message with pre-allocated Raw.
43func New() *Message {
44	const defaultRawCapacity = 120
45	return &Message{
46		Raw: make([]byte, messageHeaderSize, defaultRawCapacity),
47	}
48}
49
50// ErrDecodeToNil occurs on Decode(data, nil) call.
51var ErrDecodeToNil = errors.New("attempt to decode to nil message")
52
53// Decode decodes Message from data to m, returning error if any.
54func Decode(data []byte, m *Message) error {
55	if m == nil {
56		return ErrDecodeToNil
57	}
58	m.Raw = append(m.Raw[:0], data...)
59	return m.Decode()
60}
61
// Message represents a single STUN packet. It uses aggressive internal
// buffering to enable zero-allocation encoding and decoding,
// so there are some usage constraints:
//
//	Message, its fields, results of m.Get or any attribute a.GetFrom
//	are valid only until Message.Raw is modified.
//
// In particular, attribute values produced by Decode and Add are sub-slices
// of Raw rather than copies.
type Message struct {
	// Type is the message type; WriteType encodes it into Raw[0:2].
	Type MessageType
	// Length is the size in bytes of the attribute section (len(Raw) minus
	// the header). Only its low 16 bits are written to the header by
	// WriteLength.
	Length uint32 // len(Raw) not including header
	// TransactionID is copied into Raw[8:20] by WriteTransactionID.
	TransactionID [TransactionIDSize]byte
	// Attributes lists the attributes in wire order.
	Attributes Attributes
	// Raw holds the encoded message bytes.
	Raw []byte
}
75
76// AddTo sets b.TransactionID to m.TransactionID.
77//
78// Implements Setter to aid in crafting responses.
79func (m *Message) AddTo(b *Message) error {
80	b.TransactionID = m.TransactionID
81	b.WriteTransactionID()
82	return nil
83}
84
85// NewTransactionID sets m.TransactionID to random value from crypto/rand
86// and returns error if any.
87func (m *Message) NewTransactionID() error {
88	_, err := io.ReadFull(rand.Reader, m.TransactionID[:])
89	if err == nil {
90		m.WriteTransactionID()
91	}
92	return err
93}
94
95func (m *Message) String() string {
96	tID := base64.StdEncoding.EncodeToString(m.TransactionID[:])
97	return fmt.Sprintf("%s l=%d attrs=%d id=%s", m.Type, m.Length, len(m.Attributes), tID)
98}
99
100// Reset resets Message, attributes and underlying buffer length.
101func (m *Message) Reset() {
102	m.Raw = m.Raw[:0]
103	m.Length = 0
104	m.Attributes = m.Attributes[:0]
105}
106
107// grow ensures that internal buffer has n length.
108func (m *Message) grow(n int) {
109	if len(m.Raw) >= n {
110		return
111	}
112	if cap(m.Raw) >= n {
113		m.Raw = m.Raw[:n]
114		return
115	}
116	m.Raw = append(m.Raw, make([]byte, n-len(m.Raw))...)
117}
118
119// Add appends new attribute to message. Not goroutine-safe.
120//
121// Value of attribute is copied to internal buffer so
122// it is safe to reuse v.
123func (m *Message) Add(t AttrType, v []byte) {
124	// Allocating buffer for TLV (type-length-value).
125	// T = t, L = len(v), V = v.
126	// m.Raw will look like:
127	// [0:20]                               <- message header
128	// [20:20+m.Length]                     <- existing message attributes
129	// [20+m.Length:20+m.Length+len(v) + 4] <- allocated buffer for new TLV
130	// [first:last]                         <- same as previous
131	// [0 1|2 3|4    4 + len(v)]            <- mapping for allocated buffer
132	//   T   L        V
133	allocSize := attributeHeaderSize + len(v)  // ~ len(TLV) = len(TL) + len(V)
134	first := messageHeaderSize + int(m.Length) // first byte number
135	last := first + allocSize                  // last byte number
136	m.grow(last)                               // growing cap(Raw) to fit TLV
137	m.Raw = m.Raw[:last]                       // now len(Raw) = last
138	m.Length += uint32(allocSize)              // rendering length change
139
140	// Sub-slicing internal buffer to simplify encoding.
141	buf := m.Raw[first:last]           // slice for TLV
142	value := buf[attributeHeaderSize:] // slice for V
143	attr := RawAttribute{
144		Type:   t,              // T
145		Length: uint16(len(v)), // L
146		Value:  value,          // V
147	}
148
149	// Encoding attribute TLV to allocated buffer.
150	bin.PutUint16(buf[0:2], attr.Type.Value()) // T
151	bin.PutUint16(buf[2:4], attr.Length)       // L
152	copy(value, v)                             // V
153
154	// Checking that attribute value needs padding.
155	if attr.Length%padding != 0 {
156		// Performing padding.
157		bytesToAdd := nearestPaddedValueLength(len(v)) - len(v)
158		last += bytesToAdd
159		m.grow(last)
160		// setting all padding bytes to zero
161		// to prevent data leak from previous
162		// data in next bytesToAdd bytes
163		buf = m.Raw[last-bytesToAdd : last]
164		for i := range buf {
165			buf[i] = 0
166		}
167		m.Raw = m.Raw[:last]           // increasing buffer length
168		m.Length += uint32(bytesToAdd) // rendering length change
169	}
170	m.Attributes = append(m.Attributes, attr)
171	m.WriteLength()
172}
173
174func attrSliceEqual(a, b Attributes) bool {
175	for _, attr := range a {
176		found := false
177		for _, attrB := range b {
178			if attrB.Type != attr.Type {
179				continue
180			}
181			if attrB.Equal(attr) {
182				found = true
183				break
184			}
185		}
186		if !found {
187			return false
188		}
189	}
190	return true
191}
192
193func attrEqual(a, b Attributes) bool {
194	if a == nil && b == nil {
195		return true
196	}
197	if a == nil || b == nil {
198		return false
199	}
200	if len(a) != len(b) {
201		return false
202	}
203	if !attrSliceEqual(a, b) {
204		return false
205	}
206	if !attrSliceEqual(b, a) {
207		return false
208	}
209	return true
210}
211
212// Equal returns true if Message b equals to m.
213// Ignores m.Raw.
214func (m *Message) Equal(b *Message) bool {
215	if m == nil && b == nil {
216		return true
217	}
218	if m == nil || b == nil {
219		return false
220	}
221	if m.Type != b.Type {
222		return false
223	}
224	if m.TransactionID != b.TransactionID {
225		return false
226	}
227	if m.Length != b.Length {
228		return false
229	}
230	if !attrEqual(m.Attributes, b.Attributes) {
231		return false
232	}
233	return true
234}
235
236// WriteLength writes m.Length to m.Raw. Call is valid only if len(m.Raw) >= 4.
237func (m *Message) WriteLength() {
238	_ = m.Raw[4] // early bounds check to guarantee safety of writes below
239	bin.PutUint16(m.Raw[2:4], uint16(m.Length))
240}
241
242// WriteHeader writes header to underlying buffer. Not goroutine-safe.
243func (m *Message) WriteHeader() {
244	if len(m.Raw) < messageHeaderSize {
245		// Making WriteHeader call valid even when m.Raw
246		// is nil or len(m.Raw) is less than needed for header.
247		m.grow(messageHeaderSize)
248	}
249	_ = m.Raw[:messageHeaderSize] // early bounds check to guarantee safety of writes below
250
251	m.WriteType()
252	m.WriteLength()
253	bin.PutUint32(m.Raw[4:8], magicCookie)               // magic cookie
254	copy(m.Raw[8:messageHeaderSize], m.TransactionID[:]) // transaction ID
255}
256
257// WriteTransactionID writes m.TransactionID to m.Raw.
258func (m *Message) WriteTransactionID() {
259	copy(m.Raw[8:messageHeaderSize], m.TransactionID[:]) // transaction ID
260}
261
// WriteAttributes encodes all m.Attributes to m.
//
// Each attribute is re-added through m.Add, which appends its TLV at offset
// messageHeaderSize+m.Length of m.Raw. Encode resets m.Raw and m.Length and
// writes the header before calling this.
func (m *Message) WriteAttributes() {
	attributes := m.Attributes
	// Truncate to zero length but keep the backing array. Add appends the
	// re-encoded attribute at index i, overwriting attributes[i] only after the
	// range loop has already copied that element into a.
	m.Attributes = attributes[:0]
	for _, a := range attributes {
		// NOTE(review): a.Value may alias m.Raw (e.g. after Decode). Add copies
		// it to the new offset in m.Raw. This is fine while the layout is
		// unchanged, but could overwrite not-yet-copied values if an earlier
		// attribute grew; confirm callers never do that.
		m.Add(a.Type, a.Value)
	}
	// Restore the full length. Because appends went into the same backing
	// array, the elements now hold the re-encoded attributes.
	m.Attributes = attributes
}
271
272// WriteType writes m.Type to m.Raw.
273func (m *Message) WriteType() {
274	bin.PutUint16(m.Raw[0:2], m.Type.Value()) // message type
275}
276
277// SetType sets m.Type and writes it to m.Raw.
278func (m *Message) SetType(t MessageType) {
279	m.Type = t
280	m.WriteType()
281}
282
283// Encode re-encodes message into m.Raw.
284func (m *Message) Encode() {
285	m.Raw = m.Raw[:0]
286	m.WriteHeader()
287	m.Length = 0
288	m.WriteAttributes()
289}
290
291// WriteTo implements WriterTo via calling Write(m.Raw) on w and returning
292// call result.
293func (m *Message) WriteTo(w io.Writer) (int64, error) {
294	n, err := w.Write(m.Raw)
295	return int64(n), err
296}
297
298// ReadFrom implements ReaderFrom. Reads message from r into m.Raw,
299// Decodes it and return error if any. If m.Raw is too small, will return
300// ErrUnexpectedEOF, ErrUnexpectedHeaderEOF or *DecodeErr.
301//
302// Can return *DecodeErr while decoding too.
303func (m *Message) ReadFrom(r io.Reader) (int64, error) {
304	tBuf := m.Raw[:cap(m.Raw)]
305	var (
306		n   int
307		err error
308	)
309	if n, err = r.Read(tBuf); err != nil {
310		return int64(n), err
311	}
312	m.Raw = tBuf[:n]
313	return int64(n), m.Decode()
314}
315
// ErrUnexpectedHeaderEOF means that there were not enough bytes in
// m.Raw to read header.
var ErrUnexpectedHeaderEOF = errors.New("unexpected EOF: not enough bytes to read header")

// Decode decodes m.Raw into m.
//
// It checks that m.Raw holds a full header, that the magic cookie is valid and
// that m.Raw holds the whole message as declared by the header length field.
// It then fills m.Type, m.Length, m.TransactionID and m.Attributes. Attribute
// values are sub-slices of m.Raw, not copies. Bytes past the declared message
// length are ignored.
//
// Errors:
//   - ErrUnexpectedHeaderEOF if len(m.Raw) < messageHeaderSize;
//   - an error from newDecodeErr for a bad magic cookie;
//   - an error from newAttrDecodeErr if m.Raw is shorter than the declared
//     message, or if an attribute header or padded value is truncated.
//
// Header and length errors leave m unchanged. An attribute error leaves the
// header fields set and m.Attributes holding the attributes decoded so far.
func (m *Message) Decode() error {
	// decoding message header
	buf := m.Raw
	if len(buf) < messageHeaderSize {
		return ErrUnexpectedHeaderEOF
	}
	var (
		t        = bin.Uint16(buf[0:2])      // first 2 bytes
		size     = int(bin.Uint16(buf[2:4])) // second 2 bytes
		cookie   = bin.Uint32(buf[4:8])      // last 4 bytes
		fullSize = messageHeaderSize + size  // len(m.Raw)
	)
	if cookie != magicCookie {
		msg := fmt.Sprintf("%x is invalid magic cookie (should be %x)", cookie, magicCookie)
		return newDecodeErr("message", "cookie", msg)
	}
	if len(buf) < fullSize {
		msg := fmt.Sprintf("buffer length %d is less than %d (expected message size)", len(buf), fullSize)
		return newAttrDecodeErr("message", msg)
	}
	// saving header data
	m.Type.ReadValue(t)
	m.Length = uint32(size)
	copy(m.TransactionID[:], buf[8:messageHeaderSize])

	// Reuse the Attributes backing array.
	m.Attributes = m.Attributes[:0]
	var (
		// offset counts attribute-section bytes consumed, including padding.
		offset = 0
		// b is bounded by the declared size, so attributes cannot read past it.
		b = buf[messageHeaderSize:fullSize]
	)
	for offset < size {
		// checking that we have enough bytes to read header
		if len(b) < attributeHeaderSize {
			msg := fmt.Sprintf("buffer length %d is less than %d (expected header size)", len(b), attributeHeaderSize)
			return newAttrDecodeErr("header", msg)
		}
		var (
			a = RawAttribute{
				Type:   compatAttrType(bin.Uint16(b[0:2])), // first 2 bytes
				Length: bin.Uint16(b[2:4]),                 // second 2 bytes
			}
			aL     = int(a.Length)                // attribute length
			aBuffL = nearestPaddedValueLength(aL) // expected buffer length (with padding)
		)
		b = b[attributeHeaderSize:] // slicing again to simplify value read
		offset += attributeHeaderSize
		if len(b) < aBuffL { // checking size
			msg := fmt.Sprintf("buffer length %d is less than %d (expected value size for %s)", len(b), aBuffL, a.Type)
			return newAttrDecodeErr("value", msg)
		}
		// The value excludes padding but aliases m.Raw (no copy).
		a.Value = b[:aL]
		offset += aBuffL
		b = b[aBuffL:]

		m.Attributes = append(m.Attributes, a)
	}
	return nil
}
379
380// Write decodes message and return error if any.
381//
382// Any error is unrecoverable, but message could be partially decoded.
383func (m *Message) Write(tBuf []byte) (int, error) {
384	m.Raw = append(m.Raw[:0], tBuf...)
385	return len(tBuf), m.Decode()
386}
387
388// CloneTo clones m to b securing any further m mutations.
389func (m *Message) CloneTo(b *Message) error {
390	// TODO(ar): implement low-level copy.
391	b.Raw = append(b.Raw[:0], m.Raw...)
392	return b.Decode()
393}
394
// MessageClass holds the 2-bit STUN message class (bits C1 C0) in a byte.
type MessageClass byte

// Possible values for message class in STUN Message Type, numbered in order.
const (
	ClassRequest         MessageClass = iota // 0b00
	ClassIndication                          // 0b01
	ClassSuccessResponse                     // 0b10
	ClassErrorResponse                       // 0b11
)
405
406// Common STUN message types.
407var (
408	// Binding request message type.
409	BindingRequest = NewType(MethodBinding, ClassRequest)
410	// Binding success response message type
411	BindingSuccess = NewType(MethodBinding, ClassSuccessResponse)
412	// Binding error response message type.
413	BindingError = NewType(MethodBinding, ClassErrorResponse)
414)
415
416func (c MessageClass) String() string {
417	switch c {
418	case ClassRequest:
419		return "request"
420	case ClassIndication:
421		return "indication"
422	case ClassSuccessResponse:
423		return "success response"
424	case ClassErrorResponse:
425		return "error response"
426	default:
427		panic("unknown message class")
428	}
429}
430
// Method holds the 12-bit STUN method in a uint16.
type Method uint16

// Possible methods for STUN Message.
const (
	MethodBinding          Method = 0x0001
	MethodAllocate         Method = 0x0003
	MethodRefresh          Method = 0x0004
	MethodSend             Method = 0x0006
	MethodData             Method = 0x0007
	MethodCreatePermission Method = 0x0008
	MethodChannelBind      Method = 0x0009
)

// Methods from RFC 6062.
const (
	MethodConnect           Method = 0x000a
	MethodConnectionBind    Method = 0x000b
	MethodConnectionAttempt Method = 0x000c
)

// methodName maps each known method to its display name.
var methodName = map[Method]string{
	MethodBinding:          "Binding",
	MethodAllocate:         "Allocate",
	MethodRefresh:          "Refresh",
	MethodSend:             "Send",
	MethodData:             "Data",
	MethodCreatePermission: "CreatePermission",
	MethodChannelBind:      "ChannelBind",

	// RFC 6062.
	MethodConnect:           "Connect",
	MethodConnectionBind:    "ConnectionBind",
	MethodConnectionAttempt: "ConnectionAttempt",
}

// String returns the method's name from methodName. For unknown methods it
// returns the lowercase hex value, e.g. "0x123".
func (m Method) String() string {
	if name, ok := methodName[m]; ok {
		return name
	}
	return fmt.Sprintf("0x%x", uint16(m))
}
475
// MessageType is STUN Message Type Field.
//
// Value packs it into the 16-bit form stored in the first two header bytes,
// and ReadValue unpacks it.
type MessageType struct {
	Method Method       // e.g. binding
	Class  MessageClass // e.g. request
}
481
482// AddTo sets m type to t.
483func (t MessageType) AddTo(m *Message) error {
484	m.SetType(t)
485	return nil
486}
487
488// NewType returns new message type with provided method and class.
489func NewType(method Method, class MessageClass) MessageType {
490	return MessageType{
491		Method: method,
492		Class:  class,
493	}
494}
495
// Masks and shifts for packing Method and MessageClass into the 14-bit STUN
// Message Type field (RFC 5389, Figure 3).
const (
	// The method is split into three groups around the two class bits.
	methodABits = 0xf   // M0-M3:  0b0000000000001111
	methodBBits = 0x70  // M4-M6:  0b0000000001110000
	methodDBits = 0xf80 // M7-M11: 0b0000111110000000

	methodBShift = 1 // B moves up one bit to leave room for C0
	methodDShift = 2 // D moves up two bits to leave room for C0 and C1

	firstBit  = 0x1
	secondBit = 0x2

	c0Bit = firstBit  // C0 is bit 0 of MessageClass
	c1Bit = secondBit // C1 is bit 1 of MessageClass

	classC0Shift = 4 // C0 (bit 0) lands at bit 4 of the type field
	classC1Shift = 7 // C1 (bit 1) lands at bit 8 of the type field
)
513
514// Value returns bit representation of messageType.
515func (t MessageType) Value() uint16 {
516	//	 0                 1
517	//	 2  3  4 5 6 7 8 9 0 1 2 3 4 5
518	//	+--+--+-+-+-+-+-+-+-+-+-+-+-+-+
519	//	|M |M |M|M|M|C|M|M|M|C|M|M|M|M|
520	//	|11|10|9|8|7|1|6|5|4|0|3|2|1|0|
521	//	+--+--+-+-+-+-+-+-+-+-+-+-+-+-+
522	// Figure 3: Format of STUN Message Type Field
523
524	// Warning: Abandon all hope ye who enter here.
525	// Splitting M into A(M0-M3), B(M4-M6), D(M7-M11).
526	m := uint16(t.Method)
527	a := m & methodABits // A = M * 0b0000000000001111 (right 4 bits)
528	b := m & methodBBits // B = M * 0b0000000001110000 (3 bits after A)
529	d := m & methodDBits // D = M * 0b0000111110000000 (5 bits after B)
530
531	// Shifting to add "holes" for C0 (at 4 bit) and C1 (8 bit).
532	m = a + (b << methodBShift) + (d << methodDShift)
533
534	// C0 is zero bit of C, C1 is first bit.
535	// C0 = C * 0b01, C1 = (C * 0b10) >> 1
536	// Ct = C0 << 4 + C1 << 8.
537	// Optimizations: "((C * 0b10) >> 1) << 8" as "(C * 0b10) << 7"
538	// We need C0 shifted by 4, and C1 by 8 to fit "11" and "7" positions
539	// (see figure 3).
540	c := uint16(t.Class)
541	c0 := (c & c0Bit) << classC0Shift
542	c1 := (c & c1Bit) << classC1Shift
543	class := c0 + c1
544
545	return m + class
546}
547
548// ReadValue decodes uint16 into MessageType.
549func (t *MessageType) ReadValue(v uint16) {
550	// Decoding class.
551	// We are taking first bit from v >> 4 and second from v >> 7.
552	c0 := (v >> classC0Shift) & c0Bit
553	c1 := (v >> classC1Shift) & c1Bit
554	class := c0 + c1
555	t.Class = MessageClass(class)
556
557	// Decoding method.
558	a := v & methodABits                   // A(M0-M3)
559	b := (v >> methodBShift) & methodBBits // B(M4-M6)
560	d := (v >> methodDShift) & methodDBits // D(M7-M11)
561	m := a + b + d
562	t.Method = Method(m)
563}
564
565func (t MessageType) String() string {
566	return fmt.Sprintf("%s %s", t.Method, t.Class)
567}
568
569// Contains return true if message contain t attribute.
570func (m *Message) Contains(t AttrType) bool {
571	for _, a := range m.Attributes {
572		if a.Type == t {
573			return true
574		}
575	}
576	return false
577}
578
579type transactionIDValueSetter [TransactionIDSize]byte
580
581// NewTransactionIDSetter returns new Setter that sets message transaction id
582// to provided value.
583func NewTransactionIDSetter(value [TransactionIDSize]byte) Setter {
584	return transactionIDValueSetter(value)
585}
586
587func (t transactionIDValueSetter) AddTo(m *Message) error {
588	m.TransactionID = t
589	m.WriteTransactionID()
590	return nil
591}
592