1package packets
2
3import (
4	"bytes"
5	"encoding/binary"
6	"errors"
7	"fmt"
8	"io"
9)
10
11//ControlPacket defines the interface for structs intended to hold
12//decoded MQTT packets, either from being read or before being
13//written
14type ControlPacket interface {
15	Write(io.Writer) error
16	Unpack(io.Reader) error
17	String() string
18	Details() Details
19}
20
// Packet type identifiers. A packet's type is carried in the upper four bits
// of the first byte of its fixed header.
const (
	Connect = iota + 1
	Connack
	Publish
	Puback
	Pubrec
	Pubrel
	Pubcomp
	Subscribe
	Suback
	Unsubscribe
	Unsuback
	Pingreq
	Pingresp
	Disconnect
)

// PacketNames maps each packet type constant to the upper-case name of that
// MQTT packet type.
var PacketNames = map[uint8]string{
	Connect:     "CONNECT",
	Connack:     "CONNACK",
	Publish:     "PUBLISH",
	Puback:      "PUBACK",
	Pubrec:      "PUBREC",
	Pubrel:      "PUBREL",
	Pubcomp:     "PUBCOMP",
	Subscribe:   "SUBSCRIBE",
	Suback:      "SUBACK",
	Unsubscribe: "UNSUBSCRIBE",
	Unsuback:    "UNSUBACK",
	Pingreq:     "PINGREQ",
	Pingresp:    "PINGRESP",
	Disconnect:  "DISCONNECT",
}
57
// Return codes for a connection attempt, as used by Connect(). Values
// 0x00-0x05 correspond to the CONNACK return codes described in
// ConnackReturnCodes; 0xFE and 0xFF are extra codes defined by this package
// (NOTE(review): presumably for local network/protocol failures — confirm
// against callers).
const (
	Accepted                        = 0x00
	ErrRefusedBadProtocolVersion    = 0x01
	ErrRefusedIDRejected            = 0x02
	ErrRefusedServerUnavailable     = 0x03
	ErrRefusedBadUsernameOrPassword = 0x04
	ErrRefusedNotAuthorised         = 0x05
	ErrNetworkError                 = 0xFE
	ErrProtocolViolation            = 0xFF
)

// ConnackReturnCodes maps each connection return code constant to a
// human-readable description of that outcome.
var ConnackReturnCodes = map[uint8]string{
	Accepted:                        "Connection Accepted",
	ErrRefusedBadProtocolVersion:    "Connection Refused: Bad Protocol Version",
	ErrRefusedIDRejected:            "Connection Refused: Client Identifier Rejected",
	ErrRefusedServerUnavailable:     "Connection Refused: Server Unavailable",
	ErrRefusedBadUsernameOrPassword: "Connection Refused: Username or Password in unknown format",
	ErrRefusedNotAuthorised:         "Connection Refused: Not Authorised",
	ErrNetworkError:                 "Connection Error",
	ErrProtocolViolation:            "Connection Refused: Protocol Violation",
}

// ConnErrors maps each connection return code constant to a Go error.
// Accepted maps to nil. Each error is created once here, so the map's values
// can be compared by identity.
var ConnErrors = map[byte]error{
	Accepted:                        nil,
	ErrRefusedBadProtocolVersion:    errors.New("Unacceptable protocol version"),
	ErrRefusedIDRejected:            errors.New("Identifier rejected"),
	ErrRefusedServerUnavailable:     errors.New("Server Unavailable"),
	ErrRefusedBadUsernameOrPassword: errors.New("Bad user name or password"),
	ErrRefusedNotAuthorised:         errors.New("Not Authorized"),
	ErrNetworkError:                 errors.New("Network Error"),
	ErrProtocolViolation:            errors.New("Protocol Violation"),
}
96
97//ReadPacket takes an instance of an io.Reader (such as net.Conn) and attempts
98//to read an MQTT packet from the stream. It returns a ControlPacket
99//representing the decoded MQTT packet and an error. One of these returns will
100//always be nil, a nil ControlPacket indicating an error occurred.
101func ReadPacket(r io.Reader) (cp ControlPacket, err error) {
102	var fh FixedHeader
103	b := make([]byte, 1)
104
105	_, err = io.ReadFull(r, b)
106	if err != nil {
107		return nil, err
108	}
109	fh.unpack(b[0], r)
110	cp = NewControlPacketWithHeader(fh)
111	if cp == nil {
112		return nil, errors.New("Bad data from client")
113	}
114	packetBytes := make([]byte, fh.RemainingLength)
115	_, err = io.ReadFull(r, packetBytes)
116	if err != nil {
117		return nil, err
118	}
119	err = cp.Unpack(bytes.NewBuffer(packetBytes))
120	return cp, err
121}
122
123//NewControlPacket is used to create a new ControlPacket of the type specified
124//by packetType, this is usually done by reference to the packet type constants
125//defined in packets.go. The newly created ControlPacket is empty and a pointer
126//is returned.
127func NewControlPacket(packetType byte) (cp ControlPacket) {
128	switch packetType {
129	case Connect:
130		cp = &ConnectPacket{FixedHeader: FixedHeader{MessageType: Connect}}
131	case Connack:
132		cp = &ConnackPacket{FixedHeader: FixedHeader{MessageType: Connack}}
133	case Disconnect:
134		cp = &DisconnectPacket{FixedHeader: FixedHeader{MessageType: Disconnect}}
135	case Publish:
136		cp = &PublishPacket{FixedHeader: FixedHeader{MessageType: Publish}}
137	case Puback:
138		cp = &PubackPacket{FixedHeader: FixedHeader{MessageType: Puback}}
139	case Pubrec:
140		cp = &PubrecPacket{FixedHeader: FixedHeader{MessageType: Pubrec}}
141	case Pubrel:
142		cp = &PubrelPacket{FixedHeader: FixedHeader{MessageType: Pubrel, Qos: 1}}
143	case Pubcomp:
144		cp = &PubcompPacket{FixedHeader: FixedHeader{MessageType: Pubcomp}}
145	case Subscribe:
146		cp = &SubscribePacket{FixedHeader: FixedHeader{MessageType: Subscribe, Qos: 1}}
147	case Suback:
148		cp = &SubackPacket{FixedHeader: FixedHeader{MessageType: Suback}}
149	case Unsubscribe:
150		cp = &UnsubscribePacket{FixedHeader: FixedHeader{MessageType: Unsubscribe, Qos: 1}}
151	case Unsuback:
152		cp = &UnsubackPacket{FixedHeader: FixedHeader{MessageType: Unsuback}}
153	case Pingreq:
154		cp = &PingreqPacket{FixedHeader: FixedHeader{MessageType: Pingreq}}
155	case Pingresp:
156		cp = &PingrespPacket{FixedHeader: FixedHeader{MessageType: Pingresp}}
157	default:
158		return nil
159	}
160	return cp
161}
162
163//NewControlPacketWithHeader is used to create a new ControlPacket of the type
164//specified within the FixedHeader that is passed to the function.
165//The newly created ControlPacket is empty and a pointer is returned.
166func NewControlPacketWithHeader(fh FixedHeader) (cp ControlPacket) {
167	switch fh.MessageType {
168	case Connect:
169		cp = &ConnectPacket{FixedHeader: fh}
170	case Connack:
171		cp = &ConnackPacket{FixedHeader: fh}
172	case Disconnect:
173		cp = &DisconnectPacket{FixedHeader: fh}
174	case Publish:
175		cp = &PublishPacket{FixedHeader: fh}
176	case Puback:
177		cp = &PubackPacket{FixedHeader: fh}
178	case Pubrec:
179		cp = &PubrecPacket{FixedHeader: fh}
180	case Pubrel:
181		cp = &PubrelPacket{FixedHeader: fh}
182	case Pubcomp:
183		cp = &PubcompPacket{FixedHeader: fh}
184	case Subscribe:
185		cp = &SubscribePacket{FixedHeader: fh}
186	case Suback:
187		cp = &SubackPacket{FixedHeader: fh}
188	case Unsubscribe:
189		cp = &UnsubscribePacket{FixedHeader: fh}
190	case Unsuback:
191		cp = &UnsubackPacket{FixedHeader: fh}
192	case Pingreq:
193		cp = &PingreqPacket{FixedHeader: fh}
194	case Pingresp:
195		cp = &PingrespPacket{FixedHeader: fh}
196	default:
197		return nil
198	}
199	return cp
200}
201
// Details is the value returned by a ControlPacket's Details method. It
// summarises the packet's Qos level and MessageID.
type Details struct {
	Qos       byte   // quality of service level of the packet
	MessageID uint16 // message identifier of the packet
}
209
// FixedHeader holds the decoded fields of the fixed header that begins every
// MQTT control packet.
type FixedHeader struct {
	MessageType     byte // packet type (upper four bits of the first header byte)
	Dup             bool // DUP flag (bit 3 of the first header byte)
	Qos             byte // QoS level (bits 2-1 of the first header byte)
	Retain          bool // RETAIN flag (bit 0 of the first header byte)
	RemainingLength int  // number of packet bytes that follow the fixed header
}
219
220func (fh FixedHeader) String() string {
221	return fmt.Sprintf("%s: dup: %t qos: %d retain: %t rLength: %d", PacketNames[fh.MessageType], fh.Dup, fh.Qos, fh.Retain, fh.RemainingLength)
222}
223
// boolToByte converts b to a single bit value: 1 for true, 0 for false.
func boolToByte(b bool) byte {
	if b {
		return 1
	}
	return 0
}
232
233func (fh *FixedHeader) pack() bytes.Buffer {
234	var header bytes.Buffer
235	header.WriteByte(fh.MessageType<<4 | boolToByte(fh.Dup)<<3 | fh.Qos<<1 | boolToByte(fh.Retain))
236	header.Write(encodeLength(fh.RemainingLength))
237	return header
238}
239
240func (fh *FixedHeader) unpack(typeAndFlags byte, r io.Reader) {
241	fh.MessageType = typeAndFlags >> 4
242	fh.Dup = (typeAndFlags>>3)&0x01 > 0
243	fh.Qos = (typeAndFlags >> 1) & 0x03
244	fh.Retain = typeAndFlags&0x01 > 0
245	fh.RemainingLength = decodeLength(r)
246}
247
// decodeByte reads a single byte from b.
//
// io.ReadFull is used because an io.Reader may legally return (0, nil) from a
// Read, which would otherwise leave the byte unread. The signature has no
// error return, so a failed read yields 0.
func decodeByte(b io.Reader) byte {
	var num [1]byte
	// Error deliberately ignored (no way to report it); num stays zeroed.
	_, _ = io.ReadFull(b, num[:])
	return num[0]
}
253
// decodeUint16 reads a big-endian uint16 from b.
//
// io.ReadFull is used because a single Read may legally return fewer bytes
// than requested (e.g. on a network stream). A plain Read would then leave
// the low byte as zero. The signature has no error return, so any bytes that
// cannot be read are treated as zero.
func decodeUint16(b io.Reader) uint16 {
	var num [2]byte
	// Error deliberately ignored (no way to report it); missing bytes stay zero.
	_, _ = io.ReadFull(b, num[:])
	return binary.BigEndian.Uint16(num[:])
}
259
// encodeUint16 returns num as two big-endian bytes.
func encodeUint16(num uint16) []byte {
	var out [2]byte
	binary.BigEndian.PutUint16(out[:], num)
	return out[:]
}
265
// encodeString returns field as a big-endian uint16 byte length followed by
// the string's bytes. The length is truncated to 16 bits, so callers must
// keep field at or below 65535 bytes or the prefix will not match the data.
func encodeString(field string) []byte {
	out := make([]byte, 2, 2+len(field))
	binary.BigEndian.PutUint16(out, uint16(len(field)))
	return append(out, field...)
}
271
// decodeString reads a length-prefixed string from b: a big-endian uint16
// byte count followed by that many bytes.
//
// Both the prefix and the data are read with io.ReadFull because a single
// Read may return fewer bytes than requested. The signature has no error
// return, so bytes that cannot be read are left as zero bytes in the result.
func decodeString(b io.Reader) string {
	var prefix [2]byte
	// Errors deliberately ignored (no way to report them).
	_, _ = io.ReadFull(b, prefix[:])
	field := make([]byte, binary.BigEndian.Uint16(prefix[:]))
	_, _ = io.ReadFull(b, field)
	return string(field)
}
278
// decodeBytes reads a length-prefixed byte field from b: a big-endian uint16
// length followed by that many bytes. The returned slice always has the
// decoded length and is non-nil even when that length is zero.
//
// Both the prefix and the data are read with io.ReadFull because a single
// Read may return fewer bytes than requested. The signature has no error
// return, so bytes that cannot be read are left as zero.
func decodeBytes(b io.Reader) []byte {
	var prefix [2]byte
	// Errors deliberately ignored (no way to report them).
	_, _ = io.ReadFull(b, prefix[:])
	field := make([]byte, binary.BigEndian.Uint16(prefix[:]))
	_, _ = io.ReadFull(b, field)
	return field
}
285
// encodeBytes returns field prefixed by its length as a big-endian uint16.
// The length is truncated to 16 bits, so callers must keep field at or below
// 65535 bytes or the prefix will not match the data.
func encodeBytes(field []byte) []byte {
	encoded := make([]byte, 2+len(field))
	binary.BigEndian.PutUint16(encoded, uint16(len(field)))
	copy(encoded[2:], field)
	return encoded
}
291
// encodeLength encodes length in the MQTT variable-length format. Each output
// byte carries seven bits of the value, least-significant group first. The
// high bit (0x80) is set on every byte that is followed by another. At least
// one byte is always produced, so 0 encodes as {0x00}.
func encodeLength(length int) []byte {
	out := make([]byte, 0, 4)
	for {
		group := byte(length % 128)
		length /= 128
		if length > 0 {
			group |= 0x80
		}
		out = append(out, group)
		if length == 0 {
			return out
		}
	}
}
307
// decodeLength reads an MQTT variable-length value (the fixed header's
// Remaining Length) from r. Each byte carries seven bits of the value,
// least-significant group first. The high bit (0x80) signals that another
// byte follows. At most four bytes are consumed, which bounds the loop even
// when every byte of malformed input has the continuation bit set.
//
// Read errors cannot be reported through this signature. If r fails before
// the encoding is complete, decoding stops and the value accumulated so far
// is returned. The previous byte is never re-used as if it had been read
// again.
func decodeLength(r io.Reader) int {
	const maxLengthBytes = 4

	var (
		length uint32
		digit  [1]byte
	)
	for i := 0; i < maxLengthBytes; i++ {
		if _, err := io.ReadFull(r, digit[:]); err != nil {
			// Stop rather than fabricate continuation bytes from stale data.
			break
		}
		length |= uint32(digit[0]&0x7F) << (7 * uint(i))
		if digit[0]&0x80 == 0 {
			break
		}
	}
	return int(length)
}
323