1package memberlist
2
3import (
4	"bufio"
5	"fmt"
6	"io"
7	"net"
8)
9
// Every labeled packet and stream is prefixed with the same header structure
// (unlabeled traffic carries no header at all):
//
// magic type byte (244): uint8
// length of label name:  uint8 (so a label can be at most 255 bytes long)
// label name:            []uint8
15
// LabelMaxSize is the largest permitted length, in bytes, of a packet or
// stream label. The length is encoded on the wire in a single byte, which
// caps it at 255.
const LabelMaxSize = 1<<8 - 1
18
19// AddLabelHeaderToPacket prefixes outgoing packets with the correct header if
20// the label is not empty.
21func AddLabelHeaderToPacket(buf []byte, label string) ([]byte, error) {
22	if label == "" {
23		return buf, nil
24	}
25	if len(label) > LabelMaxSize {
26		return nil, fmt.Errorf("label %q is too long", label)
27	}
28
29	return makeLabelHeader(label, buf), nil
30}
31
32// RemoveLabelHeaderFromPacket removes any label header from the provided
33// packet and returns it along with the remaining packet contents.
34func RemoveLabelHeaderFromPacket(buf []byte) (newBuf []byte, label string, err error) {
35	if len(buf) == 0 {
36		return buf, "", nil // can't possibly be labeled
37	}
38
39	// [type:byte] [size:byte] [size bytes]
40
41	msgType := messageType(buf[0])
42	if msgType != hasLabelMsg {
43		return buf, "", nil
44	}
45
46	if len(buf) < 2 {
47		return nil, "", fmt.Errorf("cannot decode label; packet has been truncated")
48	}
49
50	size := int(buf[1])
51	if size < 1 {
52		return nil, "", fmt.Errorf("label header cannot be empty when present")
53	}
54
55	if len(buf) < 2+size {
56		return nil, "", fmt.Errorf("cannot decode label; packet has been truncated")
57	}
58
59	label = string(buf[2 : 2+size])
60	newBuf = buf[2+size:]
61
62	return newBuf, label, nil
63}
64
65// AddLabelHeaderToStream prefixes outgoing streams with the correct header if
66// the label is not empty.
67func AddLabelHeaderToStream(conn net.Conn, label string) error {
68	if label == "" {
69		return nil
70	}
71	if len(label) > LabelMaxSize {
72		return fmt.Errorf("label %q is too long", label)
73	}
74
75	header := makeLabelHeader(label, nil)
76
77	_, err := conn.Write(header)
78	return err
79}
80
81// RemoveLabelHeaderFromStream removes any label header from the beginning of
82// the stream if present and returns it along with an updated conn with that
83// header removed.
84//
85// Note that on error it is the caller's responsibility to close the
86// connection.
87func RemoveLabelHeaderFromStream(conn net.Conn) (net.Conn, string, error) {
88	br := bufio.NewReader(conn)
89
90	// First check for the type byte.
91	peeked, err := br.Peek(1)
92	if err != nil {
93		if err == io.EOF {
94			// It is safe to return the original net.Conn at this point because
95			// it never contained any data in the first place so we don't have
96			// to splice the buffer into the conn because both are empty.
97			return conn, "", nil
98		}
99		return nil, "", err
100	}
101
102	msgType := messageType(peeked[0])
103	if msgType != hasLabelMsg {
104		conn, err = newPeekedConnFromBufferedReader(conn, br, 0)
105		return conn, "", err
106	}
107
108	// We are guaranteed to get a size byte as well.
109	peeked, err = br.Peek(2)
110	if err != nil {
111		if err == io.EOF {
112			return nil, "", fmt.Errorf("cannot decode label; stream has been truncated")
113		}
114		return nil, "", err
115	}
116
117	size := int(peeked[1])
118	if size < 1 {
119		return nil, "", fmt.Errorf("label header cannot be empty when present")
120	}
121	// NOTE: we don't have to check this against LabelMaxSize because a byte
122	// already has a max value of 255.
123
124	// Once we know the size we can peek the label as well. Note that since we
125	// are using the default bufio.Reader size of 4096, the entire label header
126	// fits in the initial buffer fill so this should be free.
127	peeked, err = br.Peek(2 + size)
128	if err != nil {
129		if err == io.EOF {
130			return nil, "", fmt.Errorf("cannot decode label; stream has been truncated")
131		}
132		return nil, "", err
133	}
134
135	label := string(peeked[2 : 2+size])
136
137	conn, err = newPeekedConnFromBufferedReader(conn, br, 2+size)
138	if err != nil {
139		return nil, "", err
140	}
141
142	return conn, label, nil
143}
144
145// newPeekedConnFromBufferedReader will splice the buffer contents after the
146// offset into the provided net.Conn and return the result so that the rest of
147// the buffer contents are returned first when reading from the returned
148// peekedConn before moving on to the unbuffered conn contents.
149func newPeekedConnFromBufferedReader(conn net.Conn, br *bufio.Reader, offset int) (*peekedConn, error) {
150	// Extract any of the readahead buffer.
151	peeked, err := br.Peek(br.Buffered())
152	if err != nil {
153		return nil, err
154	}
155
156	return &peekedConn{
157		Peeked: peeked[offset:],
158		Conn:   conn,
159	}, nil
160}
161
162func makeLabelHeader(label string, rest []byte) []byte {
163	newBuf := make([]byte, 2, 2+len(label)+len(rest))
164	newBuf[0] = byte(hasLabelMsg)
165	newBuf[1] = byte(len(label))
166	newBuf = append(newBuf, []byte(label)...)
167	if len(rest) > 0 {
168		newBuf = append(newBuf, []byte(rest)...)
169	}
170	return newBuf
171}
172
// labelOverhead reports how many bytes a label header adds to a message.
// An empty label adds nothing. Otherwise the header costs one type byte, one
// length byte, and the label bytes themselves.
func labelOverhead(label string) int {
	extra := 0
	if n := len(label); n > 0 {
		extra = n + 2
	}
	return extra
}
179