1package eventstream
2
3import (
4	"bytes"
5	"encoding/binary"
6	"encoding/hex"
7	"encoding/json"
8	"fmt"
9	"hash"
10	"hash/crc32"
11	"io"
12
13	"github.com/aws/aws-sdk-go/aws"
14)
15
16// Decoder provides decoding of an Event Stream messages.
17type Decoder struct {
18	r      io.Reader
19	logger aws.Logger
20}
21
22// NewDecoder initializes and returns a Decoder for decoding event
23// stream messages from the reader provided.
24func NewDecoder(r io.Reader, opts ...func(*Decoder)) *Decoder {
25	d := &Decoder{
26		r: r,
27	}
28
29	for _, opt := range opts {
30		opt(d)
31	}
32
33	return d
34}
35
36// DecodeWithLogger adds a logger to be used by the decoder when decoding
37// stream events.
38func DecodeWithLogger(logger aws.Logger) func(*Decoder) {
39	return func(d *Decoder) {
40		d.logger = logger
41	}
42}
43
44// Decode attempts to decode a single message from the event stream reader.
45// Will return the event stream message, or error if Decode fails to read
46// the message from the stream.
47func (d *Decoder) Decode(payloadBuf []byte) (m Message, err error) {
48	reader := d.r
49	if d.logger != nil {
50		debugMsgBuf := bytes.NewBuffer(nil)
51		reader = io.TeeReader(reader, debugMsgBuf)
52		defer func() {
53			logMessageDecode(d.logger, debugMsgBuf, m, err)
54		}()
55	}
56
57	m, err = Decode(reader, payloadBuf)
58
59	return m, err
60}
61
62// Decode attempts to decode a single message from the event stream reader.
63// Will return the event stream message, or error if Decode fails to read
64// the message from the reader.
65func Decode(reader io.Reader, payloadBuf []byte) (m Message, err error) {
66	crc := crc32.New(crc32IEEETable)
67	hashReader := io.TeeReader(reader, crc)
68
69	prelude, err := decodePrelude(hashReader, crc)
70	if err != nil {
71		return Message{}, err
72	}
73
74	if prelude.HeadersLen > 0 {
75		lr := io.LimitReader(hashReader, int64(prelude.HeadersLen))
76		m.Headers, err = decodeHeaders(lr)
77		if err != nil {
78			return Message{}, err
79		}
80	}
81
82	if payloadLen := prelude.PayloadLen(); payloadLen > 0 {
83		buf, err := decodePayload(payloadBuf, io.LimitReader(hashReader, int64(payloadLen)))
84		if err != nil {
85			return Message{}, err
86		}
87		m.Payload = buf
88	}
89
90	msgCRC := crc.Sum32()
91	if err := validateCRC(reader, msgCRC); err != nil {
92		return Message{}, err
93	}
94
95	return m, nil
96}
97
98func logMessageDecode(logger aws.Logger, msgBuf *bytes.Buffer, msg Message, decodeErr error) {
99	w := bytes.NewBuffer(nil)
100	defer func() { logger.Log(w.String()) }()
101
102	fmt.Fprintf(w, "Raw message:\n%s\n",
103		hex.Dump(msgBuf.Bytes()))
104
105	if decodeErr != nil {
106		fmt.Fprintf(w, "Decode error: %v\n", decodeErr)
107		return
108	}
109
110	rawMsg, err := msg.rawMessage()
111	if err != nil {
112		fmt.Fprintf(w, "failed to create raw message, %v\n", err)
113		return
114	}
115
116	decodedMsg := decodedMessage{
117		rawMessage: rawMsg,
118		Headers:    decodedHeaders(msg.Headers),
119	}
120
121	fmt.Fprintf(w, "Decoded message:\n")
122	encoder := json.NewEncoder(w)
123	if err := encoder.Encode(decodedMsg); err != nil {
124		fmt.Fprintf(w, "failed to generate decoded message, %v\n", err)
125	}
126}
127
128func decodePrelude(r io.Reader, crc hash.Hash32) (messagePrelude, error) {
129	var p messagePrelude
130
131	var err error
132	p.Length, err = decodeUint32(r)
133	if err != nil {
134		return messagePrelude{}, err
135	}
136
137	p.HeadersLen, err = decodeUint32(r)
138	if err != nil {
139		return messagePrelude{}, err
140	}
141
142	if err := p.ValidateLens(); err != nil {
143		return messagePrelude{}, err
144	}
145
146	preludeCRC := crc.Sum32()
147	if err := validateCRC(r, preludeCRC); err != nil {
148		return messagePrelude{}, err
149	}
150
151	p.PreludeCRC = preludeCRC
152
153	return p, nil
154}
155
// decodePayload reads r until EOF and returns everything that was read.
//
// Writing starts at buf[:0], so buf's existing contents are discarded but its
// backing array may be reused. The returned slice can alias buf. Any error
// other than io.EOF from r is returned along with the bytes read so far.
func decodePayload(buf []byte, r io.Reader) ([]byte, error) {
	dst := bytes.NewBuffer(buf[:0])
	if _, err := io.Copy(dst, r); err != nil {
		return dst.Bytes(), err
	}
	return dst.Bytes(), nil
}
162
// decodeUint8 reads a single byte from r.
//
// When r implements io.ByteReader, the byte is read with ReadByte. Otherwise
// exactly one byte is read with io.ReadFull. On any error, 0 is returned
// together with the error, consistent with decodeUint16/32/64.
// io.ByteReader leaves the byte value undefined when ReadByte fails, so that
// value is not passed through.
func decodeUint8(r io.Reader) (uint8, error) {
	if br, ok := r.(io.ByteReader); ok {
		v, err := br.ReadByte()
		if err != nil {
			return 0, err
		}
		return v, nil
	}

	var b [1]byte
	if _, err := io.ReadFull(r, b[:]); err != nil {
		return 0, err
	}
	return b[0], nil
}
// decodeUint16 reads exactly two bytes from r and returns them as a
// big-endian uint16. If both bytes cannot be read, it returns 0 together with
// the error from io.ReadFull. That error is io.EOF if r was already empty,
// or io.ErrUnexpectedEOF if r ended partway through.
func decodeUint16(r io.Reader) (uint16, error) {
	var raw [2]byte
	if _, err := io.ReadFull(r, raw[:]); err != nil {
		return 0, err
	}
	return binary.BigEndian.Uint16(raw[:]), nil
}
// decodeUint32 reads exactly four bytes from r and returns them as a
// big-endian uint32. If all four bytes cannot be read, it returns 0 together
// with the error from io.ReadFull. That error is io.EOF if r was already
// empty, or io.ErrUnexpectedEOF if r ended partway through.
func decodeUint32(r io.Reader) (uint32, error) {
	var raw [4]byte
	if _, err := io.ReadFull(r, raw[:]); err != nil {
		return 0, err
	}
	return binary.BigEndian.Uint32(raw[:]), nil
}
// decodeUint64 reads exactly eight bytes from r and returns them as a
// big-endian uint64. If all eight bytes cannot be read, it returns 0 together
// with the error from io.ReadFull. That error is io.EOF if r was already
// empty, or io.ErrUnexpectedEOF if r ended partway through.
func decodeUint64(r io.Reader) (uint64, error) {
	var raw [8]byte
	if _, err := io.ReadFull(r, raw[:]); err != nil {
		return 0, err
	}
	return binary.BigEndian.Uint64(raw[:]), nil
}
204
205func validateCRC(r io.Reader, expect uint32) error {
206	msgCRC, err := decodeUint32(r)
207	if err != nil {
208		return err
209	}
210
211	if msgCRC != expect {
212		return ChecksumError{}
213	}
214
215	return nil
216}
217