1package eventstreamapi
2
3import (
4	"fmt"
5
6	"github.com/aws/aws-sdk-go/private/protocol"
7	"github.com/aws/aws-sdk-go/private/protocol/eventstream"
8)
9
10// Unmarshaler provides the interface for unmarshaling a EventStream
11// message into a SDK type.
12type Unmarshaler interface {
13	UnmarshalEvent(protocol.PayloadUnmarshaler, eventstream.Message) error
14}
15
16// EventReader provides reading from the EventStream of an reader.
17type EventReader struct {
18	decoder *eventstream.Decoder
19
20	unmarshalerForEventType func(string) (Unmarshaler, error)
21	payloadUnmarshaler      protocol.PayloadUnmarshaler
22
23	payloadBuf []byte
24}
25
26// NewEventReader returns a EventReader built from the reader and unmarshaler
27// provided.  Use ReadStream method to start reading from the EventStream.
28func NewEventReader(
29	decoder *eventstream.Decoder,
30	payloadUnmarshaler protocol.PayloadUnmarshaler,
31	unmarshalerForEventType func(string) (Unmarshaler, error),
32) *EventReader {
33	return &EventReader{
34		decoder:                 decoder,
35		payloadUnmarshaler:      payloadUnmarshaler,
36		unmarshalerForEventType: unmarshalerForEventType,
37		payloadBuf:              make([]byte, 10*1024),
38	}
39}
40
41// ReadEvent attempts to read a message from the EventStream and return the
42// unmarshaled event value that the message is for.
43//
44// For EventStream API errors check if the returned error satisfies the
45// awserr.Error interface to get the error's Code and Message components.
46//
47// EventUnmarshalers called with EventStream messages must take copies of the
48// message's Payload. The payload will is reused between events read.
49func (r *EventReader) ReadEvent() (event interface{}, err error) {
50	msg, err := r.decoder.Decode(r.payloadBuf)
51	if err != nil {
52		return nil, err
53	}
54	defer func() {
55		// Reclaim payload buffer for next message read.
56		r.payloadBuf = msg.Payload[0:0]
57	}()
58
59	typ, err := GetHeaderString(msg, MessageTypeHeader)
60	if err != nil {
61		return nil, err
62	}
63
64	switch typ {
65	case EventMessageType:
66		return r.unmarshalEventMessage(msg)
67	case ExceptionMessageType:
68		return nil, r.unmarshalEventException(msg)
69	case ErrorMessageType:
70		return nil, r.unmarshalErrorMessage(msg)
71	default:
72		return nil, &UnknownMessageTypeError{
73			Type: typ, Message: msg.Clone(),
74		}
75	}
76}
77
78// UnknownMessageTypeError provides an error when a message is received from
79// the stream, but the reader is unable to determine what kind of message it is.
80type UnknownMessageTypeError struct {
81	Type    string
82	Message eventstream.Message
83}
84
85func (e *UnknownMessageTypeError) Error() string {
86	return "unknown eventstream message type, " + e.Type
87}
88
89func (r *EventReader) unmarshalEventMessage(
90	msg eventstream.Message,
91) (event interface{}, err error) {
92	eventType, err := GetHeaderString(msg, EventTypeHeader)
93	if err != nil {
94		return nil, err
95	}
96
97	ev, err := r.unmarshalerForEventType(eventType)
98	if err != nil {
99		return nil, err
100	}
101
102	err = ev.UnmarshalEvent(r.payloadUnmarshaler, msg)
103	if err != nil {
104		return nil, err
105	}
106
107	return ev, nil
108}
109
110func (r *EventReader) unmarshalEventException(
111	msg eventstream.Message,
112) (err error) {
113	eventType, err := GetHeaderString(msg, ExceptionTypeHeader)
114	if err != nil {
115		return err
116	}
117
118	ev, err := r.unmarshalerForEventType(eventType)
119	if err != nil {
120		return err
121	}
122
123	err = ev.UnmarshalEvent(r.payloadUnmarshaler, msg)
124	if err != nil {
125		return err
126	}
127
128	var ok bool
129	err, ok = ev.(error)
130	if !ok {
131		err = messageError{
132			code: "SerializationError",
133			msg: fmt.Sprintf(
134				"event stream exception %s mapped to non-error %T, %v",
135				eventType, ev, ev,
136			),
137		}
138	}
139
140	return err
141}
142
143func (r *EventReader) unmarshalErrorMessage(msg eventstream.Message) (err error) {
144	var msgErr messageError
145
146	msgErr.code, err = GetHeaderString(msg, ErrorCodeHeader)
147	if err != nil {
148		return err
149	}
150
151	msgErr.msg, err = GetHeaderString(msg, ErrorMessageHeader)
152	if err != nil {
153		return err
154	}
155
156	return msgErr
157}
158
159// GetHeaderString returns the value of the header as a string. If the header
160// is not set or the value is not a string an error will be returned.
161func GetHeaderString(msg eventstream.Message, headerName string) (string, error) {
162	headerVal := msg.Headers.Get(headerName)
163	if headerVal == nil {
164		return "", fmt.Errorf("error header %s not present", headerName)
165	}
166
167	v, ok := headerVal.Get().(string)
168	if !ok {
169		return "", fmt.Errorf("error header value is not a string, %T", headerVal)
170	}
171
172	return v, nil
173}
174