1// +build codegen
2
3package api
4
5import (
6	"bytes"
7	"fmt"
8	"io"
9	"strings"
10	"text/template"
11)
12
13// EventStreamAPI provides details about the event stream async API and
14// associated EventStream shapes.
15type EventStreamAPI struct {
16	API       *API
17	Name      string
18	Operation *Operation
19	Shape     *Shape
20	Inbound   *EventStream
21	Outbound  *EventStream
22}
23
24// EventStream represents a single eventstream group (input/output) and the
25// modeled events that are known for the stream.
26type EventStream struct {
27	Name       string
28	Shape      *Shape
29	Events     []*Event
30	Exceptions []*Event
31}
32
33// Event is a single EventStream event that can be sent or received in an
34// EventStream.
35type Event struct {
36	Name  string
37	Shape *Shape
38	For   *EventStream
39}
40
41// ShapeDoc returns the docstring for the EventStream API.
42func (esAPI *EventStreamAPI) ShapeDoc() string {
43	tmpl := template.Must(template.New("eventStreamShapeDoc").Parse(`
44{{- $.Name }} provides handling of EventStreams for
45the {{ $.Operation.ExportedName }} API.
46{{- if $.Inbound }}
47
48Use this type to receive {{ $.Inbound.Name }} events. The events
49can be read from the Events channel member.
50
51The events that can be received are:
52{{ range $_, $event := $.Inbound.Events }}
53    * {{ $event.Shape.ShapeName }}
54{{- end }}
55
56{{- end }}
57
58{{- if $.Outbound }}
59
60Use this type to send {{ $.Outbound.Name }} events. The events
61can be sent with the Send method.
62
63The events that can be sent are:
64{{ range $_, $event := $.Outbound.Events -}}
65    * {{ $event.Shape.ShapeName }}
66{{- end }}
67
68{{- end }}`))
69
70	var w bytes.Buffer
71	if err := tmpl.Execute(&w, esAPI); err != nil {
72		panic(fmt.Sprintf("failed to generate eventstream shape template for %v, %v", esAPI.Name, err))
73	}
74
75	return commentify(w.String())
76}
77
78func hasEventStream(topShape *Shape) bool {
79	for _, ref := range topShape.MemberRefs {
80		if ref.Shape.IsEventStream {
81			return true
82		}
83	}
84
85	return false
86}
87
88func eventStreamAPIShapeRefDoc(refName string) string {
89	return commentify(fmt.Sprintf("Use %s to use the API's stream.", refName))
90}
91
// setupEventStreams adds EventStream support to every operation in a whose
// input or output shape has a member marked IsEventStream.
//
// For each such operation it:
//   - builds op.EventStreamAPI from the inbound (output shape) stream,
//   - synthesizes a structure shape named op.ExportedName+"EventStream" that
//     describes the stream and registers it in a.Shapes,
//   - adds a member named "EventStream" referring to that shape to the
//     operation's output shape, and
//   - sets a.HasEventStream.
//
// It panics if an operation has an outbound (input shape) stream, if the API
// protocol is not rest-json, rest-xml or json, or if the output member or the
// synthesized shape name already exists.
func (a *API) setupEventStreams() {
	const eventStreamMemberName = "EventStream"

	for _, op := range a.Operations {
		// setupEventStream also detaches the eventstream member from the
		// shape it inspects; it returns nil when the shape has none.
		outbound := setupEventStream(op.InputRef.Shape)
		inbound := setupEventStream(op.OutputRef.Shape)

		if outbound == nil && inbound == nil {
			continue
		}

		if outbound != nil {
			panic(fmt.Sprintf("Outbound stream support not implemented, %s, %s",
				outbound.Name, outbound.Shape.ShapeName))
		}

		// Only these protocols are supported for EventStream APIs.
		switch a.Metadata.Protocol {
		case `rest-json`, `rest-xml`, `json`:
		default:
			panic(fmt.Sprintf("EventStream not supported for protocol %v",
				a.Metadata.Protocol))
		}

		op.EventStreamAPI = &EventStreamAPI{
			API:       a,
			Name:      op.ExportedName + eventStreamMemberName,
			Operation: op,
			Outbound:  outbound,
			Inbound:   inbound,
		}

		// At this point outbound is nil, so inbound is non-nil and the
		// synthesized shape only needs an "Inbound" member.
		// NOTE(review): the "Inbound" ShapeRef only carries a ShapeName; its
		// Shape field is not set here. Confirm it is resolved elsewhere.
		streamShape := &Shape{
			API:            a,
			ShapeName:      op.EventStreamAPI.Name,
			Documentation:  op.EventStreamAPI.ShapeDoc(),
			Type:           "structure",
			EventStreamAPI: op.EventStreamAPI,
			IsEventStream:  true,
			MemberRefs: map[string]*ShapeRef{
				"Inbound": &ShapeRef{
					ShapeName: inbound.Shape.ShapeName,
				},
			},
		}
		// Record the new "Inbound" member ref on the inbound eventstream
		// shape's refs list, presumably its list of referencing refs.
		inbound.Shape.refs = append(inbound.Shape.refs, streamShape.MemberRefs["Inbound"])
		streamShapeRef := &ShapeRef{
			API:           a,
			ShapeName:     streamShape.ShapeName,
			Shape:         streamShape,
			Documentation: eventStreamAPIShapeRefDoc(eventStreamMemberName),
		}
		streamShape.refs = []*ShapeRef{streamShapeRef}
		op.EventStreamAPI.Shape = streamShape

		// Attach the stream to the output shape without overwriting an
		// existing member. Note that the output shape is already modified
		// when the a.Shapes collision check below runs.
		if _, ok := op.OutputRef.Shape.MemberRefs[eventStreamMemberName]; ok {
			panic(fmt.Sprintf("shape ref already exists, %s.%s",
				op.OutputRef.Shape.ShapeName, eventStreamMemberName))
		}
		op.OutputRef.Shape.MemberRefs[eventStreamMemberName] = streamShapeRef
		op.OutputRef.Shape.EventStreamsMemberName = eventStreamMemberName
		if _, ok := a.Shapes[streamShape.ShapeName]; ok {
			panic("shape already exists, " + streamShape.ShapeName)
		}
		a.Shapes[streamShape.ShapeName] = streamShape

		a.HasEventStream = true
	}
}
160
// setupEventStream looks for the member of topShape whose shape is marked
// IsEventStream and converts it into an EventStream. It returns nil if
// topShape has no such member.
//
// Every member of the eventstream shape must be an event or an exception.
// Members are visited in MemberNames order and sorted into the EventStream's
// Events and Exceptions lists. Each member's payload ref is adjusted by
// updateEventPayloadRef, and the stream is appended to that member shape's
// EventFor. For the "json" protocol topShape itself is also appended to the
// stream's EventFor list. Finally the eventstream member is removed from
// topShape.MemberRefs, and its shape is deleted from topShape.API.Shapes.
//
// It panics if topShape has more than one eventstream member, or if the
// eventstream shape has a member that is neither an event nor an exception.
func setupEventStream(topShape *Shape) *EventStream {
	var eventStream *EventStream
	for refName, ref := range topShape.MemberRefs {
		if !ref.Shape.IsEventStream {
			continue
		}
		if eventStream != nil {
			panic(fmt.Sprintf("multiple shape ref eventstreams, %s, prev: %s",
				refName, eventStream.Name))
		}

		eventStream = &EventStream{
			Name:  ref.Shape.ShapeName,
			Shape: ref.Shape,
		}

		// For json, the top-level shape is itself treated as an event of
		// the stream (the generated reader returns it for the
		// "initial-response" event type).
		if topShape.API.Metadata.Protocol == "json" {
			topShape.EventFor = append(topShape.EventFor, eventStream)
		}

		for _, eventRefName := range ref.Shape.MemberNames() {
			eventRef := ref.Shape.MemberRefs[eventRefName]
			if !(eventRef.Shape.IsEvent || eventRef.Shape.Exception) {
				panic(fmt.Sprintf("unexpected non-event member reference %s.%s",
					ref.Shape.ShapeName, eventRefName))
			}

			updateEventPayloadRef(eventRef.Shape)

			eventRef.Shape.EventFor = append(eventRef.Shape.EventFor, eventStream)

			// Exceptions and events are two different lists to allow the SDK
			// to easily generate code with the two handled differently.
			event := &Event{
				Name:  eventRefName,
				Shape: eventRef.Shape,
				For:   eventStream,
			}
			if eventRef.Shape.Exception {
				eventStream.Exceptions = append(eventStream.Exceptions, event)
			} else {
				eventStream.Events = append(eventStream.Events, event)
			}
		}

		// Remove the eventstream references as they will be added elsewhere.
		// (Deleting from a map while ranging over it is permitted in Go.)
		ref.Shape.removeRef(ref)
		delete(topShape.MemberRefs, refName)
		delete(topShape.API.Shapes, ref.Shape.ShapeName)
	}

	return eventStream
}
214
215func updateEventPayloadRef(parent *Shape) {
216	refName := parent.PayloadRefName()
217	if len(refName) == 0 {
218		return
219	}
220
221	payloadRef := parent.MemberRefs[refName]
222
223	if payloadRef.Shape.Type == "blob" {
224		return
225	}
226
227	if len(payloadRef.LocationName) != 0 {
228		return
229	}
230
231	payloadRef.LocationName = refName
232}
233
234func renderEventStreamAPIShape(w io.Writer, s *Shape) error {
235	// Imports needed by the EventStream APIs.
236	s.API.AddImport("fmt")
237	s.API.AddImport("bytes")
238	s.API.AddImport("io")
239	s.API.AddImport("sync")
240	s.API.AddImport("sync/atomic")
241	s.API.AddSDKImport("aws")
242	s.API.AddSDKImport("aws/awserr")
243	s.API.AddSDKImport("private/protocol/eventstream")
244	s.API.AddSDKImport("private/protocol/eventstream/eventstreamapi")
245
246	return eventStreamAPIShapeTmpl.Execute(w, s)
247}
248
249// Template for an EventStream API Shape that will provide read/writing events
250// across the EventStream. This is a special shape that's only public members
251// are the Events channel and a Close and Err method.
252//
253// Executed in the context of a Shape.
254var eventStreamAPIShapeTmpl = func() *template.Template {
255	t := template.Must(
256		template.New("eventStreamAPIShapeTmpl").
257			Funcs(template.FuncMap{}).
258			Parse(eventStreamAPITmplDef),
259	)
260
261	template.Must(
262		t.AddParseTree(
263			"eventStreamAPIReaderTmpl", eventStreamAPIReaderTmpl.Tree),
264	)
265
266	return t
267}()
268
// eventStreamAPITmplDef is the text of eventStreamAPIShapeTmpl. It renders an
// operation's EventStream type:
//   - the struct holding the Reader (inbound) and/or Writer (outbound) and the
//     StreamCloser,
//   - the Close and Err methods, and
//   - for inbound streams, the Events accessor followed by the reader code
//     from the associated "eventStreamAPIReaderTmpl" template.
//
// The Writer documentation refers only to the Outbound stream, so rendering
// does not dereference a nil Inbound for outbound-only streams.
//
// Executed in the context of a Shape whose EventStreamAPI is set.
const eventStreamAPITmplDef = `
{{ $.Documentation }}
type {{ $.ShapeName }} struct {
	{{- if $.EventStreamAPI.Inbound }}
		// Reader is the EventStream reader for the {{ $.EventStreamAPI.Inbound.Name }}
		// events. This value is automatically set by the SDK when the API call is made.
		// Use this member when unit testing your code with the SDK to mock out the
		// EventStream Reader.
		//
		// Must not be nil.
		Reader {{ $.ShapeName }}Reader

	{{ end -}}

	{{- if $.EventStreamAPI.Outbound }}
		// Writer is the EventStream writer for the {{ $.EventStreamAPI.Outbound.Name }}
		// events. This value is automatically set by the SDK when the API call is made.
		// Use this member when unit testing your code with the SDK to mock out the
		// EventStream Writer.
		//
		// Must not be nil.
		Writer *{{ $.ShapeName }}Writer

	{{ end -}}

	// StreamCloser is the io.Closer for the EventStream connection. For HTTP
	// EventStream this is the response Body. The stream will be closed when
	// the Close method of the EventStream is called.
	StreamCloser io.Closer
}

// Close closes the EventStream. This will also cause the Events channel to be
// closed. You can use the closing of the Events channel to terminate your
// application's read from the API's EventStream.
{{- if $.EventStreamAPI.Inbound }}
//
// Will close the underlying EventStream reader. For EventStream over HTTP
// connection this will also close the HTTP connection.
{{ end -}}
//
// Close must be called when done using the EventStream API. Not calling Close
// may result in resource leaks.
func (es *{{ $.ShapeName }}) Close() (err error) {
	{{- if $.EventStreamAPI.Inbound }}
		es.Reader.Close()
	{{ end -}}
	{{- if $.EventStreamAPI.Outbound }}
		es.Writer.Close()
	{{ end -}}

	return es.Err()
}

// Err returns any error that occurred while reading EventStream Events from
// the service API's response. Returns nil if there were no errors.
func (es *{{ $.ShapeName }}) Err() error {
	{{- if $.EventStreamAPI.Outbound }}
		if err := es.Writer.Err(); err != nil {
			return err
		}
	{{ end -}}

	{{- if $.EventStreamAPI.Inbound }}
		if err := es.Reader.Err(); err != nil {
			return err
		}
	{{ end -}}

	es.StreamCloser.Close()

	return nil
}

{{ if $.EventStreamAPI.Inbound }}
	// Events returns a channel to read EventStream Events from the
	// {{ $.EventStreamAPI.Operation.ExportedName }} API.
	//
	// These events are:
	// {{ range $_, $event := $.EventStreamAPI.Inbound.Events }}
	//     * {{ $event.Shape.ShapeName }}
	{{- end }}
	func (es *{{ $.ShapeName }}) Events() <-chan {{ $.EventStreamAPI.Inbound.Name }}Event {
		return es.Reader.Events()
	}

	{{ template "eventStreamAPIReaderTmpl" $ }}
{{ end }}

{{ if $.EventStreamAPI.Outbound }}
	// TODO writer helper method.
{{ end }}

`
362
// eventStreamAPIReaderTmpl renders the inbound side of an EventStream API:
//   - the <Stream>Event interface implemented by inbound events,
//   - the <Shape>Reader interface, and
//   - the default read<Shape> implementation, which decodes events with an
//     eventstreamapi.EventReader.
//
// The generated reader keeps its error in an atomic.Value that always holds
// a *error. atomic.Value panics when values of different concrete types are
// stored. The errors from ReadEvent and from Close are not guaranteed to
// share a type, so storing them directly could panic at runtime.
//
// It is associated with eventStreamAPIShapeTmpl and executed in the context
// of the EventStream Shape.
var eventStreamAPIReaderTmpl = template.Must(template.New("eventStreamAPIReaderTmpl").
	Funcs(template.FuncMap{}).
	Parse(`
// {{ $.EventStreamAPI.Inbound.Name }}Event groups together all EventStream
// events read from the {{ $.EventStreamAPI.Operation.ExportedName }} API.
//
// These events are:
// {{ range $_, $event := $.EventStreamAPI.Inbound.Events }}
//     * {{ $event.Shape.ShapeName }}
{{- end }}
type {{ $.EventStreamAPI.Inbound.Name }}Event interface {
	event{{ $.EventStreamAPI.Inbound.Name }}()
}

// {{ $.ShapeName }}Reader provides the interface for reading EventStream
// Events from the {{ $.EventStreamAPI.Operation.ExportedName }} API. The
// default implementation for this interface will be {{ $.ShapeName }}.
//
// The reader's Close method must allow multiple concurrent calls.
//
// These events are:
// {{ range $_, $event := $.EventStreamAPI.Inbound.Events }}
//     * {{ $event.Shape.ShapeName }}
{{- end }}
type {{ $.ShapeName }}Reader interface {
	// Returns a channel of events as they are read from the event stream.
	Events() <-chan {{ $.EventStreamAPI.Inbound.Name }}Event

	// Close will close the underlying event stream reader. For event stream over
	// HTTP this will also close the HTTP connection.
	Close() error

	// Returns any error that has occurred while reading from the event stream.
	Err() error
}

type read{{ $.ShapeName }} struct {
	eventReader *eventstreamapi.EventReader
	stream chan {{ $.EventStreamAPI.Inbound.Name }}Event

	// errVal always holds a *error. atomic.Value panics if values of
	// different concrete types are stored, and the errors returned by
	// ReadEvent and Close may differ in type.
	errVal atomic.Value

	done      chan struct{}
	closeOnce sync.Once

	{{ if eq $.API.Metadata.Protocol "json" -}}
		initResp eventstreamapi.Unmarshaler
	{{ end -}}
}

func newRead{{ $.ShapeName }}(
	reader io.ReadCloser,
	unmarshalers request.HandlerList,
	logger aws.Logger,
	logLevel aws.LogLevelType,
	{{ if eq $.API.Metadata.Protocol "json" -}}
		initResp eventstreamapi.Unmarshaler,
	{{ end -}}
) *read{{ $.ShapeName }} {
	r := &read{{ $.ShapeName }}{
		stream: make(chan {{ $.EventStreamAPI.Inbound.Name }}Event),
		done: make(chan struct{}),
		{{ if eq $.API.Metadata.Protocol "json" -}}
			initResp: initResp,
		{{ end -}}
	}

	r.eventReader = eventstreamapi.NewEventReader(
		reader,
		protocol.HandlerPayloadUnmarshal{
			Unmarshalers: unmarshalers,
		},
		r.unmarshalerForEventType,
	)
	r.eventReader.UseLogger(logger, logLevel)

	return r
}

// Close will close the underlying event stream reader. For EventStream over
// HTTP this will also close the HTTP connection.
func (r *read{{ $.ShapeName }}) Close() error {
	r.closeOnce.Do(r.safeClose)

	return r.Err()
}

func (r *read{{ $.ShapeName }}) safeClose() {
	close(r.done)
	err := r.eventReader.Close()
	if err != nil {
		r.errVal.Store(&err)
	}
}

func (r *read{{ $.ShapeName }}) Err() error {
	if v := r.errVal.Load(); v != nil {
		return *(v.(*error))
	}

	return nil
}

func (r *read{{ $.ShapeName }}) Events() <-chan {{ $.EventStreamAPI.Inbound.Name }}Event {
	return r.stream
}

func (r *read{{ $.ShapeName }}) readEventStream() {
	defer close(r.stream)

	for {
		event, err := r.eventReader.ReadEvent()
		if err != nil {
			if err == io.EOF {
				return
			}
			select {
			case <-r.done:
				// If closed already ignore the error
				return
			default:
			}
			r.errVal.Store(&err)
			return
		}

		select {
		case r.stream <- event.({{ $.EventStreamAPI.Inbound.Name }}Event):
		case <-r.done:
			return
		}
	}
}

func (r *read{{ $.ShapeName }}) unmarshalerForEventType(
	eventType string,
) (eventstreamapi.Unmarshaler, error) {
	switch eventType {
		{{- if eq $.API.Metadata.Protocol "json" }}
			case "initial-response":
				return r.initResp, nil
		{{ end -}}
		{{- range $_, $event := $.EventStreamAPI.Inbound.Events }}
			case {{ printf "%q" $event.Name }}:
				return &{{ $event.Shape.ShapeName }}{}, nil
		{{ end -}}
		{{- range $_, $event := $.EventStreamAPI.Inbound.Exceptions }}
			case {{ printf "%q" $event.Name }}:
				return &{{ $event.Shape.ShapeName }}{}, nil
		{{ end -}}
	default:
		return nil, awserr.New(
			request.ErrCodeSerialization,
			fmt.Sprintf("unknown event type name, %s, for {{ $.ShapeName }}", eventType),
			nil,
		)
	}
}
`))
521
// eventStreamAPILoopMethodTmpl renders helper methods on an operation's
// output shape, which is the shape that holds the EventStream member:
//   - runEventStreamLoop starts the inbound reader on the HTTP response body
//     and stores the resulting EventStream in the output's stream member.
//   - For the json protocol, unmarshalInitialResponse waits for the first
//     event, which must be the output shape itself, and copies it into the
//     output while keeping the stream member.
//
// The stream member is always referenced through EventStreamsMemberName. The
// type-mismatch error names the actual output shape rather than a hard-coded
// one.
//
// Executed in the context of a Shape.
var eventStreamAPILoopMethodTmpl = template.Must(
	template.New("eventStreamAPILoopMethodTmpl").Parse(`
func (s *{{ $.ShapeName }}) runEventStreamLoop(r *request.Request) {
	if r.Error != nil {
		return
	}

	{{- $esMemberRef := index $.MemberRefs $.EventStreamsMemberName }}
	{{- if $esMemberRef.Shape.EventStreamAPI.Inbound }}
		reader := newRead{{ $esMemberRef.ShapeName }}(
			r.HTTPResponse.Body,
			r.Handlers.UnmarshalStream,
			r.Config.Logger,
			r.Config.LogLevel.Value(),
			{{ if eq $.API.Metadata.Protocol "json" -}}
				s,
			{{ end -}}
		)
		go reader.readEventStream()

		eventStream := &{{ $esMemberRef.ShapeName }} {
			StreamCloser: r.HTTPResponse.Body,
			Reader: reader,
		}
	{{ end -}}

	s.{{ $.EventStreamsMemberName }} = eventStream
}

{{ if eq $.API.Metadata.Protocol "json" -}}
	func (s *{{ $.ShapeName }}) unmarshalInitialResponse(r *request.Request) {
		// Wait for the initial response event, which must be the first event to be
		// received from the API.
		select {
		case event, ok := <-s.{{ $.EventStreamsMemberName }}.Events():
			if !ok {
				return
			}
			es := s.{{ $.EventStreamsMemberName }}
			v, ok := event.(*{{ $.ShapeName }})
			if !ok || v == nil {
				r.Error = awserr.New(
					request.ErrCodeSerialization,
					fmt.Sprintf("invalid event, %T, expect *{{ $.ShapeName }}, %v", event, v),
					nil,
				)
				return
			}
			*s = *v
			s.{{ $.EventStreamsMemberName }} = es
		}
	}
{{ end -}}
`))
580
// EventStreamHeaderTypeMap pairs two Go types for an event header member:
// Header is the type the decoded header value is asserted to, and Member is
// the type of the shape member it is stored in (for example int8 and int64
// for a "byte" member).
type EventStreamHeaderTypeMap struct {
	Header, Member string
}
587
588var eventStreamEventShapeTmplFuncs = template.FuncMap{
589	"EventStreamHeaderTypeMap": func(ref *ShapeRef) EventStreamHeaderTypeMap {
590		switch ref.Shape.Type {
591		case "boolean":
592			return EventStreamHeaderTypeMap{Header: "bool", Member: "bool"}
593		case "byte":
594			return EventStreamHeaderTypeMap{Header: "int8", Member: "int64"}
595		case "short":
596			return EventStreamHeaderTypeMap{Header: "int16", Member: "int64"}
597		case "integer":
598			return EventStreamHeaderTypeMap{Header: "int32", Member: "int64"}
599		case "long":
600			return EventStreamHeaderTypeMap{Header: "int64", Member: "int64"}
601		case "timestamp":
602			return EventStreamHeaderTypeMap{Header: "time.Time", Member: "time.Time"}
603		case "blob":
604			return EventStreamHeaderTypeMap{Header: "[]byte", Member: "[]byte"}
605		case "string":
606			return EventStreamHeaderTypeMap{Header: "string", Member: "string"}
607		// TODO case "uuid"  what is modeled type
608		default:
609			panic("unsupported EventStream header type, " + ref.Shape.Type)
610		}
611	},
612	"HasNonBlobPayloadMembers": eventHasNonBlobPayloadMembers,
613}
614
615// Returns if the event has any members which are not the event's blob payload,
616// nor a header.
617func eventHasNonBlobPayloadMembers(s *Shape) bool {
618	num := len(s.MemberRefs)
619	for _, ref := range s.MemberRefs {
620		if ref.IsEventHeader || (ref.IsEventPayload && (ref.Shape.Type == "blob" || ref.Shape.Type == "string")) {
621			num--
622		}
623	}
624	return num > 0
625}
626
627// Template for an EventStream Event shape. This is a normal API shape that is
628// decorated as an EventStream Event.
629//
630// Executed in the context of a Shape.
631var eventStreamEventShapeTmpl = template.Must(template.New("eventStreamEventShapeTmpl").
632	Funcs(eventStreamEventShapeTmplFuncs).Parse(`
633{{ range $_, $eventstream := $.EventFor }}
634	// The {{ $.ShapeName }} is and event in the {{ $eventstream.Name }} group of events.
635	func (s *{{ $.ShapeName }}) event{{ $eventstream.Name }}() {}
636{{ end }}
637
638// UnmarshalEvent unmarshals the EventStream Message into the {{ $.ShapeName }} value.
639// This method is only used internally within the SDK's EventStream handling.
640func (s *{{ $.ShapeName }}) UnmarshalEvent(
641	payloadUnmarshaler protocol.PayloadUnmarshaler,
642	msg eventstream.Message,
643) error {
644	{{- range $memName, $memRef := $.MemberRefs }}
645		{{- if $memRef.IsEventHeader }}
646			if hv := msg.Headers.Get("{{ $memName }}"); hv != nil {
647				{{ $types := EventStreamHeaderTypeMap $memRef -}}
648				v := hv.Get().({{ $types.Header }})
649				{{- if ne $types.Header $types.Member }}
650					m := {{ $types.Member }}(v)
651					s.{{ $memName }} = {{ if $memRef.UseIndirection }}&{{ end }}m
652				{{- else }}
653					s.{{ $memName }} = {{ if $memRef.UseIndirection }}&{{ end }}v
654				{{- end }}
655			}
656		{{- else if (and ($memRef.IsEventPayload) (eq $memRef.Shape.Type "blob")) }}
657			s.{{ $memName }} = make([]byte, len(msg.Payload))
658			copy(s.{{ $memName }}, msg.Payload)
659		{{- else if (and ($memRef.IsEventPayload) (eq $memRef.Shape.Type "string")) }}
660			s.{{ $memName }} = aws.String(string(msg.Payload))
661		{{- end }}
662	{{- end }}
663	{{- if HasNonBlobPayloadMembers $ }}
664		if err := payloadUnmarshaler.UnmarshalPayload(
665			bytes.NewReader(msg.Payload), s,
666		); err != nil {
667			return err
668		}
669	{{- end }}
670	return nil
671}
672`))
673
// eventStreamExceptionEventShapeTmpl renders the methods that let an
// EventStream exception shape satisfy the awserr.Error interface: Code,
// Message, OrigErr and Error.
//
// The generated Message checks Message_ for nil before dereferencing it, so
// an exception decoded without a message yields "" instead of panicking.
//
// Executed in the context of a Shape.
var eventStreamExceptionEventShapeTmpl = template.Must(
	template.New("eventStreamExceptionEventShapeTmpl").Parse(`
// Code returns the exception type name.
func (s {{ $.ShapeName }}) Code() string {
	{{- if $.ErrorInfo.Code }}
		return "{{ $.ErrorInfo.Code }}"
	{{- else }}
		return "{{ $.ShapeName }}"
	{{ end -}}
}

// Message returns the exception's message, or an empty string if the
// exception has no message.
func (s {{ $.ShapeName }}) Message() string {
	{{- if index $.MemberRefs "Message_" }}
		if s.Message_ != nil {
			return *s.Message_
		}
	{{- end }}
	return ""
}

// OrigErr always returns nil, satisfies awserr.Error interface.
func (s {{ $.ShapeName }}) OrigErr() error {
	return nil
}

func (s {{ $.ShapeName }}) Error() string {
	return fmt.Sprintf("%s: %s", s.Code(), s.Message())
}
`))
703
704// APIEventStreamTestGoCode generates Go code for EventStream operation tests.
705func (a *API) APIEventStreamTestGoCode() string {
706	var buf bytes.Buffer
707
708	a.resetImports()
709	a.AddImport("bytes")
710	a.AddImport("io/ioutil")
711	a.AddImport("net/http")
712	a.AddImport("reflect")
713	a.AddImport("testing")
714	a.AddImport("time")
715	a.AddSDKImport("aws")
716	a.AddSDKImport("aws/corehandlers")
717	a.AddSDKImport("aws/request")
718	a.AddSDKImport("aws/awserr")
719	a.AddSDKImport("awstesting/unit")
720	a.AddSDKImport("private/protocol")
721	a.AddSDKImport("private/protocol/", a.ProtocolPackage())
722	a.AddSDKImport("private/protocol/eventstream")
723	a.AddSDKImport("private/protocol/eventstream/eventstreamapi")
724	a.AddSDKImport("private/protocol/eventstream/eventstreamtest")
725
726	unused := `
727	var _ time.Time
728	var _ awserr.Error
729	`
730
731	if err := eventStreamTestTmpl.Execute(&buf, a); err != nil {
732		panic(err)
733	}
734
735	return a.importsGoCode() + unused + strings.TrimSpace(buf.String())
736}
737
738func valueForType(s *Shape, visited []string) string {
739	for _, v := range visited {
740		if v == s.ShapeName {
741			return "nil"
742		}
743	}
744
745	visited = append(visited, s.ShapeName)
746
747	switch s.Type {
748	case "blob":
749		return `[]byte("blob value goes here")`
750	case "string":
751		return `aws.String("string value goes here")`
752	case "boolean":
753		return `aws.Bool(true)`
754	case "byte":
755		return `aws.Int64(1)`
756	case "short":
757		return `aws.Int64(12)`
758	case "integer":
759		return `aws.Int64(123)`
760	case "long":
761		return `aws.Int64(1234)`
762	case "float":
763		return `aws.Float64(123.4)`
764	case "double":
765		return `aws.Float64(123.45)`
766	case "timestamp":
767		return `aws.Time(time.Unix(1396594860, 0).UTC())`
768	case "structure":
769		w := bytes.NewBuffer(nil)
770		fmt.Fprintf(w, "&%s{\n", s.ShapeName)
771		for _, refName := range s.MemberNames() {
772			fmt.Fprintf(w, "%s: %s,\n", refName, valueForType(s.MemberRefs[refName].Shape, visited))
773		}
774		fmt.Fprintf(w, "}")
775		return w.String()
776	case "list":
777		w := bytes.NewBuffer(nil)
778		fmt.Fprintf(w, "%s{\n", s.GoType())
779		for i := 0; i < 3; i++ {
780			fmt.Fprintf(w, "%s,\n", valueForType(s.MemberRef.Shape, visited))
781		}
782		fmt.Fprintf(w, "}")
783		return w.String()
784
785	case "map":
786		w := bytes.NewBuffer(nil)
787		fmt.Fprintf(w, "%s{\n", s.GoType())
788		for _, k := range []string{"a", "b", "c"} {
789			fmt.Fprintf(w, "%q: %s,\n", k, valueForType(s.ValueRef.Shape, visited))
790		}
791		fmt.Fprintf(w, "}")
792		return w.String()
793
794	default:
795		panic(fmt.Sprintf("valueForType does not support %s, %s", s.ShapeName, s.Type))
796	}
797}
798
799func setEventHeaderValueForType(s *Shape, memVar string) string {
800	switch s.Type {
801	case "blob":
802		return fmt.Sprintf("eventstream.BytesValue(%s)", memVar)
803	case "string":
804		return fmt.Sprintf("eventstream.StringValue(*%s)", memVar)
805	case "boolean":
806		return fmt.Sprintf("eventstream.BoolValue(*%s)", memVar)
807	case "byte":
808		return fmt.Sprintf("eventstream.Int8Value(int8(*%s))", memVar)
809	case "short":
810		return fmt.Sprintf("eventstream.Int16Value(int16(*%s))", memVar)
811	case "integer":
812		return fmt.Sprintf("eventstream.Int32Value(int32(*%s))", memVar)
813	case "long":
814		return fmt.Sprintf("eventstream.Int64Value(*%s)", memVar)
815	case "float":
816		return fmt.Sprintf("eventstream.Float32Value(float32(*%s))", memVar)
817	case "double":
818		return fmt.Sprintf("eventstream.Float64Value(*%s)", memVar)
819	case "timestamp":
820		return fmt.Sprintf("eventstream.TimestampValue(*%s)", memVar)
821	default:
822		panic(fmt.Sprintf("value type %s not supported for event headers, %s", s.Type, s.ShapeName))
823	}
824}
825
// templateMap builds a map from alternating key/value arguments, letting
// templates pass several named values to a sub-template. Keys must be
// strings. It panics if the argument count is odd or a key is not a string.
func templateMap(args ...interface{}) map[string]interface{} {
	if len(args)%2 == 1 {
		panic(fmt.Sprintf("invalid map call, non-even args %v", args))
	}

	out := make(map[string]interface{}, len(args)/2)
	for rest := args; len(rest) > 0; rest = rest[2:] {
		key, isString := rest[0].(string)
		if !isString {
			panic(fmt.Sprintf("invalid map call, arg is not string, %T, %v", rest[0], rest[0]))
		}
		out[key] = rest[1]
	}

	return out
}
842
843var eventStreamTestTmpl = template.Must(
844	template.New("eventStreamTestTmpl").Funcs(template.FuncMap{
845		"ValueForType":               valueForType,
846		"HasNonBlobPayloadMembers":   eventHasNonBlobPayloadMembers,
847		"SetEventHeaderValueForType": setEventHeaderValueForType,
848		"Map":                        templateMap,
849		"OptionalAddInt": func(do bool, a, b int) int {
850			if !do {
851				return a
852			}
853			return a + b
854		},
855		"HasNonEventStreamMember": func(s *Shape) bool {
856			for _, ref := range s.MemberRefs {
857				if !ref.Shape.IsEventStream {
858					return true
859				}
860			}
861			return false
862		},
863	}).Parse(`
864{{ range $opName, $op := $.Operations }}
865	{{ if $op.EventStreamAPI }}
866		{{ if $op.EventStreamAPI.Inbound }}
867			{{ template "event stream inbound tests" $op.EventStreamAPI }}
868		{{ end }}
869	{{ end }}
870{{ end }}
871
872type loopReader struct {
873	source *bytes.Reader
874}
875
876func (c *loopReader) Read(p []byte) (int, error) {
877	if c.source.Len() == 0 {
878		c.source.Seek(0, 0)
879	}
880
881	return c.source.Read(p)
882}
883
884{{ define "event stream inbound tests" }}
885	func Test{{ $.Operation.ExportedName }}_Read(t *testing.T) {
886		expectEvents, eventMsgs := mock{{ $.Operation.ExportedName }}ReadEvents()
887		sess, cleanupFn, err := eventstreamtest.SetupEventStreamSession(t,
888			eventstreamtest.ServeEventStream{
889				T:      t,
890				Events: eventMsgs,
891			},
892			true,
893		)
894		if err != nil {
895			t.Fatalf("expect no error, %v", err)
896		}
897		defer cleanupFn()
898
899		svc := New(sess)
900		resp, err := svc.{{ $.Operation.ExportedName }}(nil)
901		if err != nil {
902			t.Fatalf("expect no error got, %v", err)
903		}
904		defer resp.EventStream.Close()
905
906		{{- if eq $.Operation.API.Metadata.Protocol "json" }}
907			{{- if HasNonEventStreamMember $.Operation.OutputRef.Shape }}
908				expectResp := expectEvents[0].(*{{ $.Operation.OutputRef.Shape.ShapeName }})
909				{{- range $name, $ref := $.Operation.OutputRef.Shape.MemberRefs }}
910					{{- if not $ref.Shape.IsEventStream }}
911						if e, a := expectResp.{{ $name }}, resp.{{ $name }}; !reflect.DeepEqual(e,a) {
912							t.Errorf("expect %v, got %v", e, a)
913						}
914					{{- end }}
915				{{- end }}
916			{{- end }}
917			// Trim off response output type pseudo event so only event messages remain.
918			expectEvents = expectEvents[1:]
919		{{ end }}
920
921		var i int
922		for event := range resp.EventStream.Events() {
923			if event == nil {
924				t.Errorf("%d, expect event, got nil", i)
925			}
926			if e, a := expectEvents[i], event; !reflect.DeepEqual(e, a) {
927				t.Errorf("%d, expect %T %v, got %T %v", i, e, e, a, a)
928			}
929			i++
930		}
931
932		if err := resp.EventStream.Err(); err != nil {
933			t.Errorf("expect no error, %v", err)
934		}
935	}
936
937	func Test{{ $.Operation.ExportedName }}_ReadClose(t *testing.T) {
938		_, eventMsgs := mock{{ $.Operation.ExportedName }}ReadEvents()
939		sess, cleanupFn, err := eventstreamtest.SetupEventStreamSession(t,
940			eventstreamtest.ServeEventStream{
941				T:      t,
942				Events: eventMsgs,
943			},
944			true,
945		)
946		if err != nil {
947			t.Fatalf("expect no error, %v", err)
948		}
949		defer cleanupFn()
950
951		svc := New(sess)
952		resp, err := svc.{{ $.Operation.ExportedName }}(nil)
953		if err != nil {
954			t.Fatalf("expect no error got, %v", err)
955		}
956
957		resp.EventStream.Close()
958		<-resp.EventStream.Events()
959
960		if err := resp.EventStream.Err(); err != nil {
961			t.Errorf("expect no error, %v", err)
962		}
963	}
964
965	func Benchmark{{ $.Operation.ExportedName }}_Read(b *testing.B) {
966		_, eventMsgs := mock{{ $.Operation.ExportedName }}ReadEvents()
967		var buf bytes.Buffer
968		encoder := eventstream.NewEncoder(&buf)
969		for _, msg := range eventMsgs {
970			if err := encoder.Encode(msg); err != nil {
971				b.Fatalf("failed to encode message, %v", err)
972			}
973		}
974		stream := &loopReader{source: bytes.NewReader(buf.Bytes())}
975
976		sess := unit.Session
977		svc := New(sess, &aws.Config{
978			Endpoint:               aws.String("https://example.com"),
979			DisableParamValidation: aws.Bool(true),
980		})
981		svc.Handlers.Send.Swap(corehandlers.SendHandler.Name,
982			request.NamedHandler{Name: "mockSend",
983				Fn: func(r *request.Request) {
984					r.HTTPResponse = &http.Response{
985						Status:     "200 OK",
986						StatusCode: 200,
987						Header:     http.Header{},
988						Body:       ioutil.NopCloser(stream),
989					}
990				},
991			},
992		)
993
994		resp, err := svc.{{ $.Operation.ExportedName }}(nil)
995		if err != nil {
996			b.Fatalf("failed to create request, %v", err)
997		}
998		defer resp.EventStream.Close()
999		b.ResetTimer()
1000
1001		for i := 0; i < b.N; i++ {
1002			if err = resp.EventStream.Err(); err != nil {
1003				b.Fatalf("expect no error, got %v", err)
1004			}
1005			event := <-resp.EventStream.Events()
1006			if event == nil {
1007				b.Fatalf("expect event, got nil, %v, %d", resp.EventStream.Err(), i)
1008			}
1009		}
1010	}
1011
1012	func mock{{ $.Operation.ExportedName }}ReadEvents() (
1013		[]{{ $.Inbound.Name }}Event,
1014		[]eventstream.Message,
1015	) {
1016		expectEvents := []{{ $.Inbound.Name }}Event {
1017			{{- if eq $.Operation.API.Metadata.Protocol "json" }}
1018				{{- template "set event type" $.Operation.OutputRef.Shape }}
1019			{{- end }}
1020			{{- range $_, $event := $.Inbound.Events }}
1021				{{- template "set event type" $event.Shape }}
1022			{{- end }}
1023		}
1024
1025		var marshalers request.HandlerList
1026		marshalers.PushBackNamed({{ $.API.ProtocolPackage }}.BuildHandler)
1027		payloadMarshaler := protocol.HandlerPayloadMarshal{
1028			Marshalers: marshalers,
1029		}
1030		_ = payloadMarshaler
1031
1032		eventMsgs := []eventstream.Message{
1033			{{- if eq $.Operation.API.Metadata.Protocol "json" }}
1034				{{- template "set event message" Map "idx" 0 "parentShape" $.Operation.OutputRef.Shape "eventName" "initial-response" }}
1035			{{- end }}
1036			{{- range $idx, $event := $.Inbound.Events }}
1037				{{- $offsetIdx := OptionalAddInt (eq $.Operation.API.Metadata.Protocol "json") $idx 1 }}
1038				{{- template "set event message" Map "idx" $offsetIdx "parentShape" $event.Shape "eventName" $event.Name }}
1039			{{- end }}
1040		}
1041
1042		return expectEvents, eventMsgs
1043	}
1044
1045	{{- if $.Inbound.Exceptions }}
1046		func Test{{ $.Operation.ExportedName }}_ReadException(t *testing.T) {
1047			expectEvents := []{{ $.Inbound.Name }}Event {
1048				{{- if eq $.Operation.API.Metadata.Protocol "json" }}
1049					{{- template "set event type" $.Operation.OutputRef.Shape }}
1050				{{- end }}
1051
1052				{{- $exception := index $.Inbound.Exceptions 0 }}
1053				{{- template "set event type" $exception.Shape }}
1054			}
1055
1056			var marshalers request.HandlerList
1057			marshalers.PushBackNamed({{ $.API.ProtocolPackage }}.BuildHandler)
1058			payloadMarshaler := protocol.HandlerPayloadMarshal{
1059				Marshalers: marshalers,
1060			}
1061
1062			eventMsgs := []eventstream.Message{
1063				{{- if eq $.Operation.API.Metadata.Protocol "json" }}
1064					{{- template "set event message" Map "idx" 0 "parentShape" $.Operation.OutputRef.Shape "eventName" "initial-response" }}
1065				{{- end }}
1066
1067				{{- $offsetIdx := OptionalAddInt (eq $.Operation.API.Metadata.Protocol "json") 0 1 }}
1068				{{- $exception := index $.Inbound.Exceptions 0 }}
1069				{{- template "set event message" Map "idx" $offsetIdx "parentShape" $exception.Shape "eventName" $exception.Name }}
1070			}
1071
1072			sess, cleanupFn, err := eventstreamtest.SetupEventStreamSession(t,
1073				eventstreamtest.ServeEventStream{
1074					T:      t,
1075					Events: eventMsgs,
1076				},
1077				true,
1078			)
1079			if err != nil {
1080				t.Fatalf("expect no error, %v", err)
1081			}
1082			defer cleanupFn()
1083
1084			svc := New(sess)
1085			resp, err := svc.{{ $.Operation.ExportedName }}(nil)
1086			if err != nil {
1087				t.Fatalf("expect no error got, %v", err)
1088			}
1089
1090			defer resp.EventStream.Close()
1091
1092			<-resp.EventStream.Events()
1093
1094			err = resp.EventStream.Err()
1095			if err == nil {
1096				t.Fatalf("expect err, got none")
1097			}
1098
1099			expectErr := {{ ValueForType $exception.Shape nil }}
1100			aerr, ok := err.(awserr.Error)
1101			if !ok {
1102				t.Errorf("expect exception, got %T, %#v", err, err)
1103			}
1104			if e, a := expectErr.Code(), aerr.Code(); e != a {
1105				t.Errorf("expect %v, got %v", e, a)
1106			}
1107			if e, a := expectErr.Message(), aerr.Message(); e != a {
1108				t.Errorf("expect %v, got %v", e, a)
1109			}
1110
1111			if e, a := expectErr, aerr; !reflect.DeepEqual(e, a) {
1112				t.Errorf("expect %#v, got %#v", e, a)
1113			}
1114		}
1115
1116		{{- range $_, $exception := $.Inbound.Exceptions }}
1117			var _ awserr.Error = (*{{ $exception.Shape.ShapeName }})(nil)
1118		{{- end }}
1119
1120	{{ end }}
1121{{ end }}
1122
1123{{/* Params: *Shape */}}
1124{{ define "set event type" }}
1125	&{{ $.ShapeName }}{
1126		{{- range $memName, $memRef := $.MemberRefs }}
1127			{{- if not $memRef.Shape.IsEventStream }}
1128				{{ $memName }}: {{ ValueForType $memRef.Shape nil }},
1129			{{- end }}
1130		{{- end }}
1131	},
1132{{- end }}
1133
1134{{/* Params: idx:int, parentShape:*Shape, eventName:string */}}
1135{{ define "set event message" }}
1136	{
1137		Headers: eventstream.Headers{
1138			{{- if $.parentShape.Exception }}
1139				eventstreamtest.EventExceptionTypeHeader,
1140				{
1141					Name:  eventstreamapi.ExceptionTypeHeader,
1142					Value: eventstream.StringValue("{{ $.eventName }}"),
1143				},
1144			{{- else }}
1145				eventstreamtest.EventMessageTypeHeader,
1146				{
1147					Name:  eventstreamapi.EventTypeHeader,
1148					Value: eventstream.StringValue("{{ $.eventName }}"),
1149				},
1150			{{- end }}
1151			{{- range $memName, $memRef := $.parentShape.MemberRefs }}
1152				{{- template "set event message header" Map "idx" $.idx "parentShape" $.parentShape "memName" $memName "memRef" $memRef }}
1153			{{- end }}
1154		},
1155		{{- template "set event message payload" Map "idx" $.idx "parentShape" $.parentShape }}
1156	},
1157{{- end }}
1158
1159{{/* Params: idx:int, parentShape:*Shape, memName:string, memRef:*ShapeRef */}}
1160{{ define "set event message header" }}
1161	{{- if $.memRef.IsEventHeader }}
1162		{
1163			Name: "{{ $.memName }}",
1164			{{- $shapeValueVar := printf "expectEvents[%d].(%s).%s" $.idx $.parentShape.GoType $.memName }}
1165			Value: {{ SetEventHeaderValueForType $.memRef.Shape $shapeValueVar }},
1166		},
1167	{{- end }}
1168{{- end }}
1169
{{/* Params: idx:int, parentShape:*Shape */}}
1171{{ define "set event message payload" }}
1172	{{- $payloadMemName := $.parentShape.PayloadRefName }}
1173	{{- if HasNonBlobPayloadMembers $.parentShape }}
1174		Payload: eventstreamtest.MarshalEventPayload(payloadMarshaler, expectEvents[{{ $.idx }}]),
1175	{{- else if $payloadMemName }}
1176		{{- $shapeType := (index $.parentShape.MemberRefs $payloadMemName).Shape.Type }}
1177		{{- if eq $shapeType "blob" }}
1178			Payload: expectEvents[{{ $.idx }}].({{ $.parentShape.GoType }}).{{ $payloadMemName }},
1179		{{- else if eq $shapeType "string" }}
1180			Payload: []byte(*expectEvents[{{ $.idx }}].({{ $.parentShape.GoType }}).{{ $payloadMemName }}),
1181		{{- end }}
1182	{{- end }}
1183{{- end }}
1184`))
1185