1package eventstreamapi
2
3import (
4	"fmt"
5	"io"
6
7	"github.com/aws/aws-sdk-go/aws"
8	"github.com/aws/aws-sdk-go/private/protocol"
9	"github.com/aws/aws-sdk-go/private/protocol/eventstream"
10)
11
12// Unmarshaler provides the interface for unmarshaling a EventStream
13// message into a SDK type.
14type Unmarshaler interface {
15	UnmarshalEvent(protocol.PayloadUnmarshaler, eventstream.Message) error
16}
17
// EventStream headers with specific meaning to async API functionality.
//
// MessageTypeHeader names the header whose value selects how a message is
// handled; the *MessageType constants are the values ReadEvent recognizes
// for it. The remaining constants name headers read for a specific message
// type.
const (
	MessageTypeHeader    = `:message-type` // Identifies type of message.
	EventMessageType     = `event`         // Message carries an event; see EventTypeHeader.
	ErrorMessageType     = `error`         // Message carries an error; see ErrorCodeHeader and ErrorMessageHeader.
	ExceptionMessageType = `exception`     // Message carries an exception; see ExceptionTypeHeader.

	// Message Events
	EventTypeHeader = `:event-type` // Identifies message event type e.g. "Stats".

	// Message Error
	ErrorCodeHeader    = `:error-code`    // Holds the code of an error message.
	ErrorMessageHeader = `:error-message` // Holds the description of an error message.

	// Message Exception
	ExceptionTypeHeader = `:exception-type` // Identifies the exception type of an exception message.
)
35
// EventReader reads EventStream messages from a reader and unmarshals them
// into event values.
type EventReader struct {
	// reader is the underlying stream, closed by Close. decoder decodes
	// EventStream messages from that same stream.
	reader  io.ReadCloser
	decoder *eventstream.Decoder

	// unmarshalerForEventType maps an event or exception type name to the
	// Unmarshaler that receives the message. payloadUnmarshaler is passed
	// through to that Unmarshaler's UnmarshalEvent call.
	unmarshalerForEventType func(string) (Unmarshaler, error)
	payloadUnmarshaler      protocol.PayloadUnmarshaler

	// payloadBuf is the buffer handed to the decoder on each read. After a
	// message is processed it is replaced by the message's payload slice,
	// truncated to zero length, so the storage is reused by the next read.
	payloadBuf []byte
}
46
47// NewEventReader returns a EventReader built from the reader and unmarshaler
48// provided.  Use ReadStream method to start reading from the EventStream.
49func NewEventReader(
50	reader io.ReadCloser,
51	payloadUnmarshaler protocol.PayloadUnmarshaler,
52	unmarshalerForEventType func(string) (Unmarshaler, error),
53) *EventReader {
54	return &EventReader{
55		reader:                  reader,
56		decoder:                 eventstream.NewDecoder(reader),
57		payloadUnmarshaler:      payloadUnmarshaler,
58		unmarshalerForEventType: unmarshalerForEventType,
59		payloadBuf:              make([]byte, 10*1024),
60	}
61}
62
63// UseLogger instructs the EventReader to use the logger and log level
64// specified.
65func (r *EventReader) UseLogger(logger aws.Logger, logLevel aws.LogLevelType) {
66	if logger != nil && logLevel.Matches(aws.LogDebugWithEventStreamBody) {
67		r.decoder.UseLogger(logger)
68	}
69}
70
71// ReadEvent attempts to read a message from the EventStream and return the
72// unmarshaled event value that the message is for.
73//
74// For EventStream API errors check if the returned error satisfies the
75// awserr.Error interface to get the error's Code and Message components.
76//
77// EventUnmarshalers called with EventStream messages must take copies of the
78// message's Payload. The payload will is reused between events read.
79func (r *EventReader) ReadEvent() (event interface{}, err error) {
80	msg, err := r.decoder.Decode(r.payloadBuf)
81	if err != nil {
82		return nil, err
83	}
84	defer func() {
85		// Reclaim payload buffer for next message read.
86		r.payloadBuf = msg.Payload[0:0]
87	}()
88
89	typ, err := GetHeaderString(msg, MessageTypeHeader)
90	if err != nil {
91		return nil, err
92	}
93
94	switch typ {
95	case EventMessageType:
96		return r.unmarshalEventMessage(msg)
97	case ExceptionMessageType:
98		err = r.unmarshalEventException(msg)
99		return nil, err
100	case ErrorMessageType:
101		return nil, r.unmarshalErrorMessage(msg)
102	default:
103		return nil, fmt.Errorf("unknown eventstream message type, %v", typ)
104	}
105}
106
107func (r *EventReader) unmarshalEventMessage(
108	msg eventstream.Message,
109) (event interface{}, err error) {
110	eventType, err := GetHeaderString(msg, EventTypeHeader)
111	if err != nil {
112		return nil, err
113	}
114
115	ev, err := r.unmarshalerForEventType(eventType)
116	if err != nil {
117		return nil, err
118	}
119
120	err = ev.UnmarshalEvent(r.payloadUnmarshaler, msg)
121	if err != nil {
122		return nil, err
123	}
124
125	return ev, nil
126}
127
128func (r *EventReader) unmarshalEventException(
129	msg eventstream.Message,
130) (err error) {
131	eventType, err := GetHeaderString(msg, ExceptionTypeHeader)
132	if err != nil {
133		return err
134	}
135
136	ev, err := r.unmarshalerForEventType(eventType)
137	if err != nil {
138		return err
139	}
140
141	err = ev.UnmarshalEvent(r.payloadUnmarshaler, msg)
142	if err != nil {
143		return err
144	}
145
146	var ok bool
147	err, ok = ev.(error)
148	if !ok {
149		err = messageError{
150			code: "SerializationError",
151			msg: fmt.Sprintf(
152				"event stream exception %s mapped to non-error %T, %v",
153				eventType, ev, ev,
154			),
155		}
156	}
157
158	return err
159}
160
161func (r *EventReader) unmarshalErrorMessage(msg eventstream.Message) (err error) {
162	var msgErr messageError
163
164	msgErr.code, err = GetHeaderString(msg, ErrorCodeHeader)
165	if err != nil {
166		return err
167	}
168
169	msgErr.msg, err = GetHeaderString(msg, ErrorMessageHeader)
170	if err != nil {
171		return err
172	}
173
174	return msgErr
175}
176
// Close closes the reader the EventReader was created with and returns the
// error, if any, from that reader's Close method.
func (r *EventReader) Close() error {
	return r.reader.Close()
}
181
182// GetHeaderString returns the value of the header as a string. If the header
183// is not set or the value is not a string an error will be returned.
184func GetHeaderString(msg eventstream.Message, headerName string) (string, error) {
185	headerVal := msg.Headers.Get(headerName)
186	if headerVal == nil {
187		return "", fmt.Errorf("error header %s not present", headerName)
188	}
189
190	v, ok := headerVal.Get().(string)
191	if !ok {
192		return "", fmt.Errorf("error header value is not a string, %T", headerVal)
193	}
194
195	return v, nil
196}
197