1// +build codegen
2
3package api
4
5import (
6	"text/template"
7)
8
9var eventStreamReaderTestTmpl = template.Must(
10	template.New("eventStreamReaderTestTmpl").Funcs(template.FuncMap{
11		"ValueForType":             valueForType,
12		"HasNonBlobPayloadMembers": eventHasNonBlobPayloadMembers,
13		"EventHeaderValueForType":  setEventHeaderValueForType,
14		"Map":                      templateMap,
15		"OptionalAddInt": func(do bool, a, b int) int {
16			if !do {
17				return a
18			}
19			return a + b
20		},
21		"HasNonEventStreamMember": func(s *Shape) bool {
22			for _, ref := range s.MemberRefs {
23				if !ref.Shape.IsEventStream {
24					return true
25				}
26			}
27			return false
28		},
29	}).Parse(`
30{{ range $opName, $op := $.Operations }}
31	{{ if $op.EventStreamAPI }}
32		{{ if  $op.EventStreamAPI.OutputStream }}
33			{{ template "event stream outputStream tests" $op.EventStreamAPI }}
34		{{ end }}
35	{{ end }}
36{{ end }}
37
38type loopReader struct {
39	source *bytes.Reader
40}
41
42func (c *loopReader) Read(p []byte) (int, error) {
43	if c.source.Len() == 0 {
44		c.source.Seek(0, 0)
45	}
46
47	return c.source.Read(p)
48}
49
50{{ define "event stream outputStream tests" }}
51	func Test{{ $.Operation.ExportedName }}_Read(t *testing.T) {
52		expectEvents, eventMsgs := mock{{ $.Operation.ExportedName }}ReadEvents()
53		sess, cleanupFn, err := eventstreamtest.SetupEventStreamSession(t,
54			eventstreamtest.ServeEventStream{
55				T:      t,
56				Events: eventMsgs,
57			},
58			true,
59		)
60		if err != nil {
61			t.Fatalf("expect no error, %v", err)
62		}
63		defer cleanupFn()
64
65		svc := New(sess)
66		resp, err := svc.{{ $.Operation.ExportedName }}(nil)
67		if err != nil {
68			t.Fatalf("expect no error got, %v", err)
69		}
70		defer resp.GetStream().Close()
71
72		{{- if eq $.Operation.API.Metadata.Protocol "json" }}
73			{{- if HasNonEventStreamMember $.Operation.OutputRef.Shape }}
74				expectResp := expectEvents[0].(*{{ $.Operation.OutputRef.Shape.ShapeName }})
75				{{- range $name, $ref := $.Operation.OutputRef.Shape.MemberRefs }}
76					{{- if not $ref.Shape.IsEventStream }}
77						if e, a := expectResp.{{ $name }}, resp.{{ $name }}; !reflect.DeepEqual(e,a) {
78							t.Errorf("expect %v, got %v", e, a)
79						}
80					{{- end }}
81				{{- end }}
82			{{- end }}
83			// Trim off response output type pseudo event so only event messages remain.
84			expectEvents = expectEvents[1:]
85		{{ end }}
86
87		var i int
88		for event := range resp.GetStream().Events() {
89			if event == nil {
90				t.Errorf("%d, expect event, got nil", i)
91			}
92			if e, a := expectEvents[i], event; !reflect.DeepEqual(e, a) {
93				t.Errorf("%d, expect %T %v, got %T %v", i, e, e, a, a)
94			}
95			i++
96		}
97
98		if err := resp.GetStream().Err(); err != nil {
99			t.Errorf("expect no error, %v", err)
100		}
101	}
102
103	func Test{{ $.Operation.ExportedName }}_ReadClose(t *testing.T) {
104		_, eventMsgs := mock{{ $.Operation.ExportedName }}ReadEvents()
105		sess, cleanupFn, err := eventstreamtest.SetupEventStreamSession(t,
106			eventstreamtest.ServeEventStream{
107				T:      t,
108				Events: eventMsgs,
109			},
110			true,
111		)
112		if err != nil {
113			t.Fatalf("expect no error, %v", err)
114		}
115		defer cleanupFn()
116
117		svc := New(sess)
118		resp, err := svc.{{ $.Operation.ExportedName }}(nil)
119		if err != nil {
120			t.Fatalf("expect no error got, %v", err)
121		}
122
123		{{ if gt (len $.OutputStream.Events) 0 -}}
124			// Assert calling Err before close does not close the stream.
125			resp.GetStream().Err()
126			select {
127			case _, ok := <-resp.GetStream().Events():
128				if !ok {
129					t.Fatalf("expect stream not to be closed, but was")
130				}
131			default:
132			}
133		{{- end }}
134
135		resp.GetStream().Close()
136		<-resp.GetStream().Events()
137
138		if err := resp.GetStream().Err(); err != nil {
139			t.Errorf("expect no error, %v", err)
140		}
141	}
142
143	func Test{{ $.Operation.ExportedName }}_ReadUnknownEvent(t *testing.T) {
144		expectEvents, eventMsgs := mock{{ $.Operation.ExportedName }}ReadEvents()
145
146		{{- if eq $.Operation.API.Metadata.Protocol "json" }}
147			eventOffset := 1
148		{{- else }}
149			var eventOffset int
150		{{- end }}
151
152		unknownEvent := eventstream.Message{
153			Headers: eventstream.Headers{
154				eventstreamtest.EventMessageTypeHeader,
155				{
156					Name:  eventstreamapi.EventTypeHeader,
157					Value: eventstream.StringValue("UnknownEventName"),
158				},
159			},
160			Payload: []byte("some unknown event"),
161		}
162
163		eventMsgs = append(eventMsgs[:eventOffset],
164			append([]eventstream.Message{unknownEvent}, eventMsgs[eventOffset:]...)...)
165
166		expectEvents = append(expectEvents[:eventOffset],
167			append([]{{ $.OutputStream.Name }}Event{
168					&{{ $.OutputStream.StreamUnknownEventName }}{
169						Type: "UnknownEventName",
170						Message: unknownEvent,
171					},
172				},
173				expectEvents[eventOffset:]...)...)
174
175		sess, cleanupFn, err := eventstreamtest.SetupEventStreamSession(t,
176			eventstreamtest.ServeEventStream{
177				T:      t,
178				Events: eventMsgs,
179			},
180			true,
181		)
182		if err != nil {
183			t.Fatalf("expect no error, %v", err)
184		}
185		defer cleanupFn()
186
187		svc := New(sess)
188		resp, err := svc.{{ $.Operation.ExportedName }}(nil)
189		if err != nil {
190			t.Fatalf("expect no error got, %v", err)
191		}
192		defer resp.GetStream().Close()
193
194		{{- if eq $.Operation.API.Metadata.Protocol "json" }}
195			// Trim off response output type pseudo event so only event messages remain.
196			expectEvents = expectEvents[1:]
197		{{ end }}
198
199		var i int
200		for event := range resp.GetStream().Events() {
201			if event == nil {
202				t.Errorf("%d, expect event, got nil", i)
203			}
204			if e, a := expectEvents[i], event; !reflect.DeepEqual(e, a) {
205				t.Errorf("%d, expect %T %v, got %T %v", i, e, e, a, a)
206			}
207			i++
208		}
209
210		if err := resp.GetStream().Err(); err != nil {
211			t.Errorf("expect no error, %v", err)
212		}
213	}
214
215	func Benchmark{{ $.Operation.ExportedName }}_Read(b *testing.B) {
216		_, eventMsgs := mock{{ $.Operation.ExportedName }}ReadEvents()
217		var buf bytes.Buffer
218		encoder := eventstream.NewEncoder(&buf)
219		for _, msg := range eventMsgs {
220			if err := encoder.Encode(msg); err != nil {
221				b.Fatalf("failed to encode message, %v", err)
222			}
223		}
224		stream := &loopReader{source: bytes.NewReader(buf.Bytes())}
225
226		sess := unit.Session
227		svc := New(sess, &aws.Config{
228			Endpoint:               aws.String("https://example.com"),
229			DisableParamValidation: aws.Bool(true),
230		})
231		svc.Handlers.Send.Swap(corehandlers.SendHandler.Name,
232			request.NamedHandler{Name: "mockSend",
233				Fn: func(r *request.Request) {
234					r.HTTPResponse = &http.Response{
235						Status:     "200 OK",
236						StatusCode: 200,
237						Header:     http.Header{},
238						Body:       ioutil.NopCloser(stream),
239					}
240				},
241			},
242		)
243
244		resp, err := svc.{{ $.Operation.ExportedName }}(nil)
245		if err != nil {
246			b.Fatalf("failed to create request, %v", err)
247		}
248		defer resp.GetStream().Close()
249		b.ResetTimer()
250
251		for i := 0; i < b.N; i++ {
252			if err = resp.GetStream().Err(); err != nil {
253				b.Fatalf("expect no error, got %v", err)
254			}
255			event := <-resp.GetStream().Events()
256			if event == nil {
257				b.Fatalf("expect event, got nil, %v, %d", resp.GetStream().Err(), i)
258			}
259		}
260	}
261
262	func mock{{ $.Operation.ExportedName }}ReadEvents() (
263		[]{{ $.OutputStream.Name }}Event,
264		[]eventstream.Message,
265	) {
266		expectEvents := []{{ $.OutputStream.Name }}Event {
267			{{- if eq $.Operation.API.Metadata.Protocol "json" }}
268				{{- template "set event type" $.Operation.OutputRef.Shape }}
269			{{- end }}
270			{{- range $_, $event := $.OutputStream.Events }}
271				{{- template "set event type" $event.Shape }}
272			{{- end }}
273		}
274
275		var marshalers request.HandlerList
276		marshalers.PushBackNamed({{ $.API.ProtocolPackage }}.BuildHandler)
277		payloadMarshaler := protocol.HandlerPayloadMarshal{
278			Marshalers: marshalers,
279		}
280		_ = payloadMarshaler
281
282		eventMsgs := []eventstream.Message{
283			{{- if eq $.Operation.API.Metadata.Protocol "json" }}
284				{{- template "set event message" Map "idx" 0 "parentShape" $.Operation.OutputRef.Shape "eventName" "initial-response" }}
285			{{- end }}
286			{{- range $idx, $event := $.OutputStream.Events }}
287				{{- $offsetIdx := OptionalAddInt (eq $.Operation.API.Metadata.Protocol "json") $idx 1 }}
288				{{- template "set event message" Map "idx" $offsetIdx "parentShape" $event.Shape "eventName" $event.Name }}
289			{{- end }}
290		}
291
292		return expectEvents, eventMsgs
293	}
294
295	{{- if $.OutputStream.Exceptions }}
296		func Test{{ $.Operation.ExportedName }}_ReadException(t *testing.T) {
297			expectEvents := []{{ $.OutputStream.Name }}Event {
298				{{- if eq $.Operation.API.Metadata.Protocol "json" }}
299					{{- template "set event type" $.Operation.OutputRef.Shape }}
300				{{- end }}
301
302				{{- $exception := index $.OutputStream.Exceptions 0 }}
303				{{- template "set event type" $exception.Shape }}
304			}
305
306			var marshalers request.HandlerList
307			marshalers.PushBackNamed({{ $.API.ProtocolPackage }}.BuildHandler)
308			payloadMarshaler := protocol.HandlerPayloadMarshal{
309				Marshalers: marshalers,
310			}
311
312			eventMsgs := []eventstream.Message{
313				{{- if eq $.Operation.API.Metadata.Protocol "json" }}
314					{{- template "set event message" Map "idx" 0 "parentShape" $.Operation.OutputRef.Shape "eventName" "initial-response" }}
315				{{- end }}
316
317				{{- $offsetIdx := OptionalAddInt (eq $.Operation.API.Metadata.Protocol "json") 0 1 }}
318				{{- $exception := index $.OutputStream.Exceptions 0 }}
319				{{- template "set event message" Map "idx" $offsetIdx "parentShape" $exception.Shape "eventName" $exception.Name }}
320			}
321
322			sess, cleanupFn, err := eventstreamtest.SetupEventStreamSession(t,
323				eventstreamtest.ServeEventStream{
324					T:      t,
325					Events: eventMsgs,
326				},
327				true,
328			)
329			if err != nil {
330				t.Fatalf("expect no error, %v", err)
331			}
332			defer cleanupFn()
333
334			svc := New(sess)
335			resp, err := svc.{{ $.Operation.ExportedName }}(nil)
336			if err != nil {
337				t.Fatalf("expect no error got, %v", err)
338			}
339
340			defer resp.GetStream().Close()
341
342			<-resp.GetStream().Events()
343
344			err = resp.GetStream().Err()
345			if err == nil {
346				t.Fatalf("expect err, got none")
347			}
348
349			expectErr := {{ ValueForType $exception.Shape nil }}
350			aerr, ok := err.(awserr.Error)
351			if !ok {
352				t.Errorf("expect exception, got %T, %#v", err, err)
353			}
354			if e, a := expectErr.Code(), aerr.Code(); e != a {
355				t.Errorf("expect %v, got %v", e, a)
356			}
357			if e, a := expectErr.Message(), aerr.Message(); e != a {
358				t.Errorf("expect %v, got %v", e, a)
359			}
360
361			if e, a := expectErr, aerr; !reflect.DeepEqual(e, a) {
362				t.Errorf("expect error %+#v, got %+#v", e, a)
363			}
364		}
365
366		{{- range $_, $exception := $.OutputStream.Exceptions }}
367			var _ awserr.Error = (*{{ $exception.Shape.ShapeName }})(nil)
368		{{- end }}
369
370	{{ end }}
371{{ end }}
372
373{{/* Params: *Shape */}}
374{{ define "set event type" }}
375	&{{ $.ShapeName }}{
376		{{- if $.Exception }}
377			RespMetadata: protocol.ResponseMetadata{
378				StatusCode: 200,
379			},
380		{{- end }}
381		{{- range $memName, $memRef := $.MemberRefs }}
382			{{- if not $memRef.Shape.IsEventStream }}
383				{{ $memName }}: {{ ValueForType $memRef.Shape nil }},
384			{{- end }}
385		{{- end }}
386	},
387{{- end }}
388
389{{/* Params: idx:int, parentShape:*Shape, eventName:string */}}
390{{ define "set event message" }}
391	{
392		Headers: eventstream.Headers{
393			{{- if $.parentShape.Exception }}
394				eventstreamtest.EventExceptionTypeHeader,
395				{
396					Name:  eventstreamapi.ExceptionTypeHeader,
397					Value: eventstream.StringValue("{{ $.eventName }}"),
398				},
399			{{- else }}
400				eventstreamtest.EventMessageTypeHeader,
401				{
402					Name:  eventstreamapi.EventTypeHeader,
403					Value: eventstream.StringValue("{{ $.eventName }}"),
404				},
405			{{- end }}
406			{{- range $memName, $memRef := $.parentShape.MemberRefs }}
407				{{- template "set event message header" Map "idx" $.idx "parentShape" $.parentShape "memName" $memName "memRef" $memRef }}
408			{{- end }}
409		},
410		{{- template "set event message payload" Map "idx" $.idx "parentShape" $.parentShape }}
411	},
412{{- end }}
413
414{{/* Params: idx:int, parentShape:*Shape, memName:string, memRef:*ShapeRef */}}
415{{ define "set event message header" }}
416	{{- if $.memRef.IsEventHeader }}
417		{
418			Name: "{{ $.memName }}",
419			{{- $shapeValueVar := printf "expectEvents[%d].(%s).%s" $.idx $.parentShape.GoType $.memName }}
420			Value: {{ EventHeaderValueForType $.memRef.Shape $shapeValueVar }},
421		},
422	{{- end }}
423{{- end }}
424
425{{/* Params: idx:int, parentShape:*Shape, memName:string, memRef:*ShapeRef */}}
426{{ define "set event message payload" }}
427	{{- $payloadMemName := $.parentShape.PayloadRefName }}
428	{{- if HasNonBlobPayloadMembers $.parentShape }}
429		Payload: eventstreamtest.MarshalEventPayload(payloadMarshaler, expectEvents[{{ $.idx }}]),
430	{{- else if $payloadMemName }}
431		{{- $shapeType := (index $.parentShape.MemberRefs $payloadMemName).Shape.Type }}
432		{{- if eq $shapeType "blob" }}
433			Payload: expectEvents[{{ $.idx }}].({{ $.parentShape.GoType }}).{{ $payloadMemName }},
434		{{- else if eq $shapeType "string" }}
435			Payload: []byte(*expectEvents[{{ $.idx }}].({{ $.parentShape.GoType }}).{{ $payloadMemName }}),
436		{{- end }}
437	{{- end }}
438{{- end }}
439`))
440